Coverage for tests / unit / with_torch / test_sdc_torch.py: 98%

59 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-21 22:18 -0700

1from __future__ import annotations 

2 

3import sys 

4import typing 

5from pathlib import Path 

6 

7import numpy as np 

8import pandas as pd # type: ignore[import] 

9import torch # type: ignore[import-not-found] 

10from muutils.json_serialize import ( 

11 SerializableDataclass, 

12 serializable_dataclass, 

13 serializable_field, 

14) 

15 

16from zanj import ZANJ 

17 

# Seed NumPy's legacy global RNG so the arrays generated in these tests are
# reproducible from run to run.
np.random.seed(0)

# Directory (relative to the current working directory) that the round-trip
# tests write their `.zanj` files into.
TEST_DATA_PATH: Path = Path("tests") / "junk_data"

# Python 3.10 introduced keyword-only dataclass fields; this flag gates the
# `kw_only` option passed to `serializable_dataclass` further down.
SUPPORTS_KW_ONLY: bool = sys.version_info >= (3, 10)

23 

24 

@serializable_dataclass
class BasicZanjTorch(SerializableDataclass):
    """Small serializable dataclass holding a string, an int and a list of ints.

    Nested inside `NestedTorch.basic` to exercise nested-dataclass round-trips.
    """

    a: str  # required field (no default)
    q: int = 42
    # NOTE: presumably mirrors `dataclasses.field(default_factory=list)`, so
    # each instance gets its own fresh list rather than a shared default.
    c: typing.List[int] = serializable_field(default_factory=list)

30 

31 

@serializable_dataclass
class NestedTorch(SerializableDataclass):
    """Serializable dataclass that embeds a `BasicZanjTorch` instance.

    A list of these forms `sdc_complicated.container`, testing that nested
    serializable dataclasses survive a ZANJ save/read cycle.
    """

    name: str
    basic: BasicZanjTorch  # nested serializable dataclass
    val: float

37 

38 

@serializable_dataclass
class sdc_with_torch_tensor(SerializableDataclass):
    """Serializable dataclass with a name and two `torch.Tensor` fields.

    Used by the tensor round-trip tests to check that tensors are saved and
    restored by ZANJ.
    """

    name: str
    tensor1: torch.Tensor
    tensor2: torch.Tensor

44 

45 

def test_sdc_tensor_small() -> None:
    """Save a dataclass holding two small 1-D tensors with ZANJ, reload it, and compare."""
    original = sdc_with_torch_tensor(
        name="small tensors",
        tensor1=torch.rand(8),
        tensor2=torch.rand(16),
    )

    zanj_io = ZANJ()
    out_file = TEST_DATA_PATH / "test_sdc_tensor_small.zanj"
    zanj_io.save(original, out_file)

    # The object read back from disk must compare equal to the one saved.
    assert original == zanj_io.read(out_file)

54 

55 

def test_sdc_tensor() -> None:
    """Save a dataclass holding two larger 2-D tensors with ZANJ, reload it, and compare."""
    original = sdc_with_torch_tensor(
        name="bigger tensors",
        tensor1=torch.rand(128, 128),
        tensor2=torch.rand(256, 256),
    )

    zanj_io = ZANJ()
    out_file = TEST_DATA_PATH / "test_sdc_tensor.zanj"
    zanj_io.save(original, out_file)

    # The object read back from disk must compare equal to the one saved.
    assert original == zanj_io.read(out_file)

66 

67 

# `kw_only` is only requested on Python >= 3.10 (see SUPPORTS_KW_ONLY);
# presumably forwarded to `dataclasses.dataclass(kw_only=...)` — confirm in muutils.
@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class sdc_complicated(SerializableDataclass):
    """Serializable dataclass mixing many field kinds in one object.

    Holds NumPy arrays, pandas DataFrames, a list of nested serializable
    dataclasses, and a `torch.Tensor`, so a single ZANJ round-trip covers
    all of them together.
    """

    name: str
    arr1: np.ndarray
    arr2: np.ndarray
    iris_data: pd.DataFrame
    brain_data: pd.DataFrame
    container: typing.List[NestedTorch]  # list of nested serializable dataclasses

    tensor: torch.Tensor

    def __eq__(self, value: object) -> bool:
        """Compare by delegating to the parent class's `__eq__`."""
        # NOTE(review): this override only forwards to super(). It may exist to
        # force `SerializableDataclass.__eq__` instead of an auto-generated
        # field-tuple comparison, which could not compare ndarray/DataFrame/
        # Tensor fields. Confirm whether it is still needed and whether the
        # decorator keeps it.
        return super().__eq__(value)

81 

82 

def test_sdc_complicated() -> None:
    """Round-trip an `sdc_complicated` instance through ZANJ and compare.

    The instance combines random arrays, two CSV-backed DataFrames, ten
    nested dataclasses, and a random tensor.
    """

    def make_nested(n: int) -> NestedTorch:
        # Every field of the nested entry is derived deterministically from n.
        return NestedTorch(
            f"n-{n}",
            BasicZanjTorch(f"n-{n}_b", n * 10 + 1, [n + 1, n + 2, n + 10]),
            n * np.pi,
        )

    # Values are produced in the same order as the constructor's keyword
    # arguments, so RNG draws and file reads happen in the original sequence.
    arr1 = np.random.rand(128, 128)
    arr2 = np.random.rand(256, 256)
    iris = pd.read_csv("tests/input_data/iris.csv")
    brain = pd.read_csv("tests/input_data/brain_networks.csv")
    nested_items = [make_nested(idx) for idx in range(10)]
    big_tensor = torch.rand(512, 512)

    original = sdc_complicated(
        name="complicated data",
        arr1=arr1,
        arr2=arr2,
        iris_data=iris,
        brain_data=brain,
        container=nested_items,
        tensor=big_tensor,
    )

    zanj_io = ZANJ()
    out_file = TEST_DATA_PATH / "test_sdc_complicated.zanj"
    zanj_io.save(original, out_file)

    # The object read back from disk must compare equal to the one saved.
    assert original == zanj_io.read(out_file)