Coverage for tests/unit/with_torch/test_sdc_torch.py: 98%
59 statements
« prev ^ index » next — coverage.py v7.13.4, created at 2026-02-21 22:18 -0700
1from __future__ import annotations
3import sys
4import typing
5from pathlib import Path
7import numpy as np
8import pandas as pd # type: ignore[import]
9import torch # type: ignore[import-not-found]
10from muutils.json_serialize import (
11 SerializableDataclass,
12 serializable_dataclass,
13 serializable_field,
14)
16from zanj import ZANJ
# Seed NumPy's global RNG so the `np.random.rand` data used in these tests is
# reproducible. NOTE: torch's RNG is not seeded here, so `torch.rand` values vary.
np.random.seed(0)

# Output directory for the round-trip .zanj files, relative to the current
# working directory the tests are run from.
TEST_DATA_PATH: Path = Path("tests/junk_data")

# `kw_only` dataclass support arrived in Python 3.10; the comparison already
# evaluates to a bool, so no explicit `bool(...)` conversion is needed.
SUPPORTS_KW_ONLY: bool = sys.version_info >= (3, 10)
@serializable_dataclass
class BasicZanjTorch(SerializableDataclass):
    """Small serializable dataclass used as the `basic` member of `NestedTorch`."""

    a: str
    q: int = 42
    # A fresh list per instance via `default_factory`, so instances never share
    # one mutable default.
    c: typing.List[int] = serializable_field(default_factory=list)
@serializable_dataclass
class NestedTorch(SerializableDataclass):
    """Dataclass nesting a `BasicZanjTorch`; used as the element type of
    `sdc_complicated.container`.

    Despite its name, it holds no tensor fields itself.
    """

    name: str
    basic: BasicZanjTorch
    val: float
@serializable_dataclass
class sdc_with_torch_tensor(SerializableDataclass):
    """Serializable dataclass with two `torch.Tensor` fields.

    Saved and re-read through ZANJ by `test_sdc_tensor_small` and
    `test_sdc_tensor`.
    """

    name: str
    tensor1: torch.Tensor
    tensor2: torch.Tensor
def test_sdc_tensor_small():
    """A dataclass holding two small 1-D tensors survives a ZANJ save/read round trip."""
    original = sdc_with_torch_tensor(
        name="small tensors",
        tensor1=torch.rand(8),
        tensor2=torch.rand(16),
    )

    zanj_obj = ZANJ()
    out_file = TEST_DATA_PATH / "test_sdc_tensor_small.zanj"
    zanj_obj.save(original, out_file)

    # the object read back from disk must compare equal to the one written
    assert original == zanj_obj.read(out_file)
def test_sdc_tensor():
    """A dataclass holding 128x128 and 256x256 tensors survives a ZANJ round trip."""
    original = sdc_with_torch_tensor(
        name="bigger tensors",
        tensor1=torch.rand(128, 128),
        tensor2=torch.rand(256, 256),
    )

    zanj_obj = ZANJ()
    out_file = TEST_DATA_PATH / "test_sdc_tensor.zanj"
    zanj_obj.save(original, out_file)

    # the object read back from disk must compare equal to the one written
    assert original == zanj_obj.read(out_file)
@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class sdc_complicated(SerializableDataclass):
    """Serializable dataclass mixing several kinds of field.

    It holds NumPy arrays, pandas DataFrames, a list of nested dataclasses and
    a torch tensor, and is round-tripped by `test_sdc_complicated`.

    `kw_only=SUPPORTS_KW_ONLY` requests keyword-only fields only on
    interpreters that support them (Python >= 3.10).
    """

    name: str
    arr1: np.ndarray
    arr2: np.ndarray
    iris_data: pd.DataFrame
    brain_data: pd.DataFrame
    container: typing.List[NestedTorch]

    tensor: torch.Tensor

    def __eq__(self, value: object) -> bool:
        """Delegate equality to the parent class's `__eq__`."""
        # NOTE(review): presumably defined explicitly so the decorator does not
        # generate a field-tuple `__eq__` (stdlib dataclasses skip generating
        # `__eq__` when the class body already defines one). Element-wise `==`
        # on ndarray/DataFrame/Tensor fields does not produce a single bool.
        # Confirm against muutils' `serializable_dataclass` /
        # `SerializableDataclass.__eq__`.
        return super().__eq__(value)
def test_sdc_complicated():
    """A dataclass mixing arrays, DataFrames, nested dataclasses and a tensor
    survives a ZANJ save/read round trip."""

    def make_nested(idx: int) -> NestedTorch:
        # each element gets distinct str/int/list/float values derived from idx
        inner = BasicZanjTorch(f"n-{idx}_b", idx * 10 + 1, [idx + 1, idx + 2, idx + 10])
        return NestedTorch(f"n-{idx}", inner, idx * np.pi)

    original = sdc_complicated(
        name="complicated data",
        arr1=np.random.rand(128, 128),
        arr2=np.random.rand(256, 256),
        iris_data=pd.read_csv("tests/input_data/iris.csv"),
        brain_data=pd.read_csv("tests/input_data/brain_networks.csv"),
        container=list(map(make_nested, range(10))),
        tensor=torch.rand(512, 512),
    )

    zanj_obj = ZANJ()
    out_file = TEST_DATA_PATH / "test_sdc_complicated.zanj"
    zanj_obj.save(original, out_file)

    # the object read back from disk must compare equal to the one written
    assert original == zanj_obj.read(out_file)