Coverage for tests/unit/with_torch/test_bool_array_torch.py: 100%
19 statements
coverage.py v7.6.12, created at 2025-03-20 11:17 -0600
1from pathlib import Path
3import torch # type: ignore[import-not-found]
4from muutils.json_serialize import SerializableDataclass, serializable_dataclass
6from zanj import ZANJ
# Directory, relative to the current working directory, into which the
# round-trip artifacts produced by this test module are written.
TEST_DATA_PATH: Path = Path("tests") / "junk_data"
@serializable_dataclass
class MyClass_torch(SerializableDataclass):
    """Small serializable container used to round-trip tensors through ZANJ.

    Carries a text label and two tensor fields. In this module,
    ``test_torch_bool_array`` populates both tensor fields with boolean
    tensors.
    """

    name: str  # free-form label that should survive the save/read cycle
    arr_1: torch.Tensor  # first tensor payload
    arr_2: torch.Tensor  # second tensor payload
def test_torch_bool_array():
    """Round-trip a dataclass holding boolean tensors through ZANJ.

    Saves a ``MyClass_torch`` instance with ``ZANJ.save``, reads it back with
    ``ZANJ.read``, and checks that each boolean tensor field keeps its dtype,
    shape, and element values, and that the label is preserved.
    """
    fname: Path = TEST_DATA_PATH / "test_torch_bool_array.zanj"
    c: MyClass_torch = MyClass_torch(
        name="test",
        arr_1=torch.tensor([True, False, True]),
        arr_2=torch.tensor([True, False, True]),
    )

    z = ZANJ()
    z.save(c, fname)
    c2: MyClass_torch = z.read(fname)

    assert c2.name == c.name

    # Check each tensor field individually so a failure names the field and
    # the property (dtype / shape / values) that did not survive the round trip.
    for field_name in ("arr_1", "arr_2"):
        original: torch.Tensor = getattr(c, field_name)
        loaded: torch.Tensor = getattr(c2, field_name)
        assert loaded.dtype == torch.bool, f"{field_name}: got dtype {loaded.dtype}"
        assert (
            loaded.shape == original.shape
        ), f"{field_name}: got shape {tuple(loaded.shape)}"
        assert torch.equal(loaded, original), f"{field_name}: values differ"

    # Whole-object equality, as in the original test.
    assert c == c2