Coverage for tests / unit / no_torch / test_zanj_serializable_dataclass.py: 100%
91 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-21 22:18 -0700
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-21 22:18 -0700
1from __future__ import annotations
3import json
4import sys
5import typing
6from pathlib import Path
8import numpy as np
9import pandas as pd # type: ignore[import]
10from muutils.json_serialize import (
11 SerializableDataclass,
12 serializable_dataclass,
13 serializable_field,
14)
16from zanj import ZANJ
# Seed NumPy's global RNG so the arrays produced by `np.random.rand` in the
# tests below are the same on every run.
np.random.seed(0)

# Directory that the tests in this module write their `.zanj` files into.
TEST_DATA_PATH: Path = Path("tests") / "junk_data"

# True when the interpreter is Python 3.10 or newer.
# NOTE(review): presumably gates keyword-only dataclass features; it is not
# referenced in the visible part of this module -- confirm before removing.
SUPPORTS_KW_ONLY: bool = sys.version_info >= (3, 10)
@serializable_dataclass
class BasicZanj(SerializableDataclass):
    """Flat serializable dataclass with a string, an int and a list of ints."""

    # required string field
    a: str
    # integer field, defaults to 42
    q: int = 42
    # list field; `default_factory` gives each instance its own fresh list
    c: typing.List[int] = serializable_field(default_factory=list)
def test_Basic():
    """Save a `BasicZanj` to a `.zanj` file, read it back, and compare."""
    original = BasicZanj("hello", 42, [1, 2, 3])
    archive_path = TEST_DATA_PATH / "test_BasicZanj.zanj"

    zanj = ZANJ()
    zanj.save(original, archive_path)
    loaded = zanj.read(archive_path)

    # the object read back must compare equal to the one that was saved
    assert original == loaded
@serializable_dataclass
class Nested(SerializableDataclass):
    """Serializable dataclass that holds another serializable dataclass."""

    # label for this instance
    name: str
    # embedded `BasicZanj` instance
    basic: BasicZanj
    # floating-point payload
    val: float
def test_Nested():
    """Save a `Nested` (which embeds a `BasicZanj`), read it back, and compare."""
    inner = BasicZanj("hello", 42, [1, 2, 3])
    original = Nested("hello", inner, 3.14)
    archive_path = TEST_DATA_PATH / "test_Nested.zanj"

    zanj = ZANJ()
    zanj.save(original, archive_path)
    loaded = zanj.read(archive_path)

    # the object read back must compare equal to the one that was saved
    assert original == loaded
@serializable_dataclass
class Nested_with_container(SerializableDataclass):
    """Serializable dataclass with an embedded dataclass and a list of `Nested`."""

    # label for this instance
    name: str
    # embedded `BasicZanj` instance
    basic: BasicZanj
    # floating-point payload
    val: float
    # list of `Nested` items; `default_factory` gives each instance a fresh list
    container: typing.List[Nested] = serializable_field(default_factory=list)
def test_Nested_with_container():
    """Save a `Nested_with_container` holding two `Nested` items, read it back, and compare."""
    children = [
        Nested("n1", BasicZanj("n1_b", 123, [4, 5, 7]), 2.71),
        Nested("n2", BasicZanj("n2_b", 456, [7, 8, 9]), 6.28),
    ]
    original = Nested_with_container(
        "hello",
        basic=BasicZanj("hello", 42, [1, 2, 3]),
        val=3.14,
        container=children,
    )
    archive_path = TEST_DATA_PATH / "test_Nested_with_container.zanj"

    zanj = ZANJ()
    zanj.save(original, archive_path)
    loaded = zanj.read(archive_path)

    # the object read back must compare equal to the one that was saved
    assert original == loaded
@serializable_dataclass
class sdc_with_np_array(SerializableDataclass):
    """Serializable dataclass carrying two NumPy arrays next to a name."""

    # label for this instance
    name: str
    # first array payload
    arr1: np.ndarray
    # second array payload
    arr2: np.ndarray
def test_sdc_with_np_array_small():
    """Save an `sdc_with_np_array` holding two small 1-D arrays, read it back, and compare.

    The arrays have 10 and 20 random elements respectively.
    """
    instance = sdc_with_np_array("small arrays", np.random.rand(10), np.random.rand(20))

    z = ZANJ()
    # Use a file name unique to this test: it previously shared
    # "test_sdc_with_np_array.zanj" with `test_sdc_with_np_array`, so the two
    # tests overwrote each other's output and could race when run in parallel.
    path = TEST_DATA_PATH / "test_sdc_with_np_array_small.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered
def test_sdc_with_np_array():
    """Save an `sdc_with_np_array` holding two larger 2-D arrays, read it back, and compare."""
    # draw in the same order as the fields (arr1 then arr2) so the seeded RNG
    # stream is consumed identically
    first = np.random.rand(128, 128)
    second = np.random.rand(256, 256)
    original = sdc_with_np_array("bigger arrays", first, second)
    archive_path = TEST_DATA_PATH / "test_sdc_with_np_array.zanj"

    zanj = ZANJ()
    zanj.save(original, archive_path)
    loaded = zanj.read(archive_path)

    # the object read back must compare equal to the one that was saved
    assert original == loaded
@serializable_dataclass
class sdc_with_df(SerializableDataclass):
    """Serializable dataclass carrying two pandas DataFrames next to a name."""

    # label for this instance
    name: str
    # first DataFrame payload
    iris_data: pd.DataFrame
    # second DataFrame payload
    brain_data: pd.DataFrame
def test_sdc_with_df():
    """Save an `sdc_with_df` built from two CSV files, read it back, and compare.

    The DataFrames are loaded from `tests/input_data/iris.csv` and
    `tests/input_data/brain_networks.csv`.
    """
    iris_frame = pd.read_csv("tests/input_data/iris.csv")
    brain_frame = pd.read_csv("tests/input_data/brain_networks.csv")
    original = sdc_with_df(
        "downloaded_data",
        iris_data=iris_frame,
        brain_data=brain_frame,
    )
    archive_path = TEST_DATA_PATH / "test_sdc_with_df.zanj"

    zanj = ZANJ()
    zanj.save(original, archive_path)
    loaded = zanj.read(archive_path)

    # the object read back must compare equal to the one that was saved
    assert original == loaded
@serializable_dataclass
class sdc_container_explicit(SerializableDataclass):
    """Serializable dataclass whose `container` field uses explicit (de)serialization hooks.

    The list of `Nested` items is written as a single JSON-lines string: each
    item is serialized, dumped with `json.dumps`, and the results are joined
    with newlines. Loading splits that string on newlines and rebuilds each
    item with `Nested.load`.
    """

    # label for this instance
    name: str
    container: typing.List[Nested] = serializable_field(
        default_factory=list,
        # as jsonl string, for whatever reason
        serialization_fn=lambda c: "\n".join([json.dumps(n.serialize()) for n in c]),
        # An empty container serializes to "", and "".split("\n") yields [""],
        # which `json.loads` rejects. Skip blank lines so the default (empty)
        # container loads back as an empty list instead of raising.
        loading_fn=lambda data: [
            Nested.load(json.loads(line))
            for line in data["container"].split("\n")
            if line
        ],
        # TODO: explicitly specifying the following does not work, since it gets automatically converted before we call load in `loading_fn`:
        # serialization_fn=lambda c: [n.serialize() for n in c],
        # loading_fn=lambda data: [Nested.load(n) for n in data["container"]],
    )
def test_sdc_container_explicit():
    """Save an `sdc_container_explicit` holding ten `Nested` items, read it back, and compare."""
    items = []
    for idx in range(10):
        inner = BasicZanj(f"n-{idx}_b", idx * 10 + 1, [idx + 1, idx + 2, idx + 10])
        items.append(Nested(f"n-{idx}", inner, idx * np.pi))

    original = sdc_container_explicit("container explicit", container=items)
    archive_path = TEST_DATA_PATH / "test_sdc_container_explicit.zanj"

    zanj = ZANJ()
    zanj.save(original, archive_path)
    loaded = zanj.read(archive_path)

    # the object read back must compare equal to the one that was saved
    assert original == loaded