Coverage for tests\unit\test_tensor_utils.py: 100%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-15 20:56 -0600

1from __future__ import annotations 

2 

3import jaxtyping 

4import numpy as np 

5import pytest 

6import torch 

7 

8from muutils.tensor_utils import ( 

9 DTYPE_MAP, 

10 TORCH_DTYPE_MAP, 

11 StateDictKeysError, 

12 StateDictShapeError, 

13 compare_state_dicts, 

14 get_dict_shapes, 

15 jaxtype_factory, 

16 lpad_array, 

17 lpad_tensor, 

18 numpy_to_torch_dtype, 

19 pad_array, 

20 pad_tensor, 

21 rpad_array, 

22 rpad_tensor, 

23) 

24 

25 

def test_jaxtype_factory():
    """Check a factory-built type's name and docstring, then subscript it.

    The type's docstring is expected to record the configured default jaxtyping
    dtype and array type. Subscripting is exercised with both an explicit shape
    tuple and a space-separated dim-name string; the results are only printed.
    """
    ATensor = jaxtype_factory(
        "ATensor", torch.Tensor, jaxtyping.Float, legacy_mode="ignore"
    )
    assert ATensor.__name__ == "ATensor"

    type_doc = ATensor.__doc__
    for expected_fragment in (
        "default_jax_dtype = <class 'jaxtyping.Float'",
        "array_type = <class 'torch.Tensor'>",
    ):
        assert expected_fragment in type_doc

    # shape given as a tuple of sizes, then as a string of dimension names
    for shape_spec in ((1, 2, 3), "dim1 dim2"):
        print(ATensor[shape_spec, np.float32])

38 

39 

def test_numpy_to_torch_dtype():
    """numpy dtypes convert to their torch equivalents; a torch dtype is returned as-is."""
    cases = [
        (np.float32, torch.float32),
        (np.int32, torch.int32),
        (torch.float32, torch.float32),
    ]
    for source_dtype, expected_dtype in cases:
        assert numpy_to_torch_dtype(source_dtype) == expected_dtype

44 

45 

def test_dtype_maps():
    """DTYPE_MAP and TORCH_DTYPE_MAP must agree.

    They must be the same size, every DTYPE_MAP key must appear in
    TORCH_DTYPE_MAP, and converting each DTYPE_MAP value with
    numpy_to_torch_dtype must give the entry stored under that key in
    TORCH_DTYPE_MAP.
    """
    assert len(DTYPE_MAP) == len(TORCH_DTYPE_MAP)
    for dtype_name in DTYPE_MAP:
        assert dtype_name in TORCH_DTYPE_MAP
        converted = numpy_to_torch_dtype(DTYPE_MAP[dtype_name])
        assert converted == TORCH_DTYPE_MAP[dtype_name]

51 

52 

def test_pad_tensor():
    """Check the tensor padding helpers.

    pad_tensor and lpad_tensor left-pad with zeros; rpad_tensor right-pads.

    torch.equal is used instead of torch.all(result == expected). Elementwise
    `==` broadcasts, so a result with an unexpected shape (e.g. (1, 5) instead
    of (5,)) would still pass. torch.equal requires the same size and elements.
    """
    tensor = torch.tensor([1, 2, 3])
    left_padded = torch.tensor([0, 0, 1, 2, 3])
    right_padded = torch.tensor([1, 2, 3, 0, 0])

    assert torch.equal(pad_tensor(tensor, 5), left_padded)
    assert torch.equal(lpad_tensor(tensor, 5), left_padded)
    assert torch.equal(rpad_tensor(tensor, 5), right_padded)

58 

59 

def test_pad_array():
    """Check the numpy padding helpers.

    pad_array and lpad_array prepend zeros; rpad_array appends them.
    """
    source = np.array([1, 2, 3])
    expectations = (
        (pad_array, [0, 0, 1, 2, 3]),
        (lpad_array, [0, 0, 1, 2, 3]),
        (rpad_array, [1, 2, 3, 0, 0]),
    )
    for pad_fn, expected in expectations:
        assert np.array_equal(pad_fn(source, 5), np.array(expected))

65 

66 

def test_compare_state_dicts():
    """compare_state_dicts accepts equal dicts and raises a distinct error per mismatch kind.

    Each failure scenario is built from a fresh copy of the reference dict, so
    it contains exactly one kind of mismatch. Previously a single dict was
    mutated step by step, so the extra-key case also carried a shape mismatch
    and silently depended on the order of checks inside compare_state_dicts.
    """

    def fresh_reference():
        # new tensor objects each call, so no scenario shares state with another
        return {"a": torch.tensor([1, 2, 3]), "b": torch.tensor([4, 5, 6])}

    reference = fresh_reference()

    # equal contents held in distinct tensors: must not raise
    compare_state_dicts(reference, fresh_reference())

    # same keys and shapes, different values in "a"
    wrong_values = {**fresh_reference(), "a": torch.tensor([7, 8, 9])}
    with pytest.raises(AssertionError):
        compare_state_dicts(reference, wrong_values)

    # same keys, but "a" has a different shape
    wrong_shape = {**fresh_reference(), "a": torch.tensor([7, 8, 9, 10])}
    with pytest.raises(StateDictShapeError):
        compare_state_dicts(reference, wrong_shape)

    # an extra key "c"; all shared entries identical to the reference
    extra_key = {**fresh_reference(), "c": torch.tensor([10, 11, 12])}
    with pytest.raises(StateDictKeysError):
        compare_state_dicts(reference, extra_key)

83 

84 

def test_get_dict_shapes():
    """get_dict_shapes maps each key to its tensor's shape.

    The result must equal a dict of plain shape tuples.
    """
    expected_shapes = {"a": (2, 3), "b": (1, 3, 5), "c": (2,)}
    tensors = {name: torch.rand(*dims) for name, dims in expected_shapes.items()}
    assert get_dict_shapes(tensors) == expected_shapes