Coverage for tests/unit/with_torch/test_bool_array_torch.py: 100%

19 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-21 22:18 -0700

1from pathlib import Path 

2 

3import torch # type: ignore[import-not-found] 

4from muutils.json_serialize import SerializableDataclass, serializable_dataclass 

5 

6from zanj import ZANJ 

7 

# Scratch directory for files written by these tests. The path is relative,
# so it resolves against the working directory pytest is launched from.
TEST_DATA_PATH: Path = Path("tests") / "junk_data"

9 

10 

11@serializable_dataclass 

12class MyClass_torch(SerializableDataclass): 

13 name: str 

14 arr_1: torch.Tensor 

15 arr_2: torch.Tensor 

16 

17 

18def test_torch_bool_array(): 

19 fname: Path = TEST_DATA_PATH / "test_torch_bool_array.zanj" 

20 c: MyClass_torch = MyClass_torch( 

21 name="test", 

22 arr_1=torch.tensor([True, False, True]), 

23 arr_2=torch.tensor([True, False, True]), 

24 ) 

25 

26 z = ZANJ() 

27 

28 z.save(c, fname) 

29 

30 c2: MyClass_torch = z.read(fname) 

31 

32 assert c2.arr_1.dtype == torch.bool 

33 assert c2.arr_2.dtype == torch.bool 

34 

35 assert c == c2