Coverage for tests/unit/with_torch/test_bool_array_torch.py: 100%
19 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-21 22:18 -0700
1from pathlib import Path
3import torch # type: ignore[import-not-found]
4from muutils.json_serialize import SerializableDataclass, serializable_dataclass
6from zanj import ZANJ
# Directory where test artifacts are written. This is a relative path, so it is
# resolved against the current working directory (presumably the repository
# root when the test suite is run — confirm against the test runner config).
TEST_DATA_PATH: Path = Path("tests/junk_data")
@serializable_dataclass
class MyClass_torch(SerializableDataclass):
    """Minimal serializable dataclass holding two torch tensors.

    Used as the payload for the ZANJ boolean-tensor round-trip test. Field
    order is significant: it fixes the generated ``__init__`` signature.
    """

    # Plain string field, included so equality covers a non-tensor field too.
    name: str
    # Tensor fields; the test fills them with boolean tensors and checks that
    # they are reloaded with dtype torch.bool.
    arr_1: torch.Tensor
    arr_2: torch.Tensor
def test_torch_bool_array():
    """Round-trip a dataclass holding boolean torch tensors through ZANJ.

    Saves a ``MyClass_torch`` instance under ``TEST_DATA_PATH``, reads it back,
    and checks that both tensor fields keep ``dtype == torch.bool``, keep
    their exact values, and that the reloaded object compares equal to the
    original.
    """
    fname: Path = TEST_DATA_PATH / "test_torch_bool_array.zanj"
    # Distinct values per field, so a swapped or aliased field on load would
    # be caught by the per-field value checks below.
    c: MyClass_torch = MyClass_torch(
        name="test",
        arr_1=torch.tensor([True, False, True]),
        arr_2=torch.tensor([False, False, True]),
    )

    z = ZANJ()
    z.save(c, fname)
    c2: MyClass_torch = z.read(fname)

    # Serialization must not change the dtype of boolean tensors.
    assert c2.arr_1.dtype == torch.bool
    assert c2.arr_2.dtype == torch.bool

    # Values must survive the round trip, field by field.
    assert torch.equal(c2.arr_1, c.arr_1)
    assert torch.equal(c2.arr_2, c.arr_2)

    assert c == c2