Coverage for tests\unit\test_dictmagic.py: 100%
130 statements
coverage.py v7.6.1, created at 2024-10-09 01:48 -0600
1from __future__ import annotations
3import pytest
5from muutils.dictmagic import (
6 condense_nested_dicts,
7 condense_nested_dicts_matching_values,
8 condense_tensor_dict,
9 dotlist_to_nested_dict,
10 is_numeric_consecutive,
11 kwargs_to_nested_dict,
12 nested_dict_to_dotlist,
13 tuple_dims_replace,
14 update_with_nested_dict,
15)
16from muutils.json_serialize import SerializableDataclass, serializable_dataclass
def test_dotlist_to_nested_dict():
    """`dotlist_to_nested_dict` expands separator-joined keys into nested dicts.

    Also checks that a non-string key raises `TypeError` and that a custom
    `sep` is honored.
    """
    expected = {"a": {"b": {"c": 1, "d": 2}, "e": 3}}

    # default "." separator
    dotted = {"a.b.c": 1, "a.b.d": 2, "a.e": 3}
    assert dotlist_to_nested_dict(dotted) == expected

    # an int key is expected to be rejected with TypeError
    with pytest.raises(TypeError):
        dotlist_to_nested_dict({1: 1})

    # same structure, "/" as the separator
    slashed = {"a/b/c": 1, "a/b/d": 2, "a/e": 3}
    assert dotlist_to_nested_dict(slashed, sep="/") == expected
def test_update_with_nested_dict():
    """`update_with_nested_dict` merges an update dict into an original dict.

    Each case is (original, update, expected merged result).
    """
    cases = [
        # an existing nested key is overwritten
        (
            {"a": {"b": 1}, "c": -1},
            {"a": {"b": 2}},
            {"a": {"b": 2}, "c": -1},
        ),
        # a key missing from the original is added
        (
            {"a": {"b": 1}, "c": -1},
            {"d": 3},
            {"a": {"b": 1}, "c": -1, "d": 3},
        ),
        # only the updated nested key changes; its siblings survive
        (
            {"a": {"b": 1, "d": 3}, "c": -1},
            {"a": {"b": 2}},
            {"a": {"b": 2, "d": 3}, "c": -1},
        ),
        # a whole nested dict absent from the original is inserted
        (
            {"a": 1},
            {"b": {"c": 2}},
            {"a": 1, "b": {"c": 2}},
        ),
    ]
    for original, update, merged in cases:
        assert update_with_nested_dict(original, update) == merged
def test_kwargs_to_nested_dict():
    """Exercise prefix stripping and `when_unknown_prefix` in `kwargs_to_nested_dict`.

    Covers: plain dotted expansion, stripping a prefix every key carries,
    and the "raise" / "ignore" / "warn" policies for keys that lack the prefix.
    """
    nested = {"a": {"b": {"c": 1, "d": 2}, "e": 3}}

    # no prefix handling: behaves like a plain dotlist expansion
    assert kwargs_to_nested_dict({"a.b.c": 1, "a.b.d": 2, "a.e": 3}) == nested

    # every key carries the prefix, so it is stripped from all of them
    assert (
        kwargs_to_nested_dict(
            {"prefix.a.b.c": 1, "prefix.a.b.d": 2, "prefix.a.e": 3},
            strip_prefix="prefix.",
        )
        == nested
    )

    # no key carries the prefix and the policy is "raise" -> ValueError
    with pytest.raises(ValueError):
        kwargs_to_nested_dict(
            {"a.b.c": 1, "a.b.d": 2, "a.e": 3},
            strip_prefix="prefix.",
            when_unknown_prefix="raise",
        )

    # "--" is stripped where present; "a.e" lacks the prefix and is kept
    # unchanged under the "ignore" policy
    assert (
        kwargs_to_nested_dict(
            {"--a.b.c": 1, "--a.b.d": 2, "a.e": 3},
            strip_prefix="--",
            when_unknown_prefix="ignore",
        )
        == nested
    )

    # "a.e" lacks the "-" prefix, so the "warn" policy must emit a UserWarning
    with pytest.warns(UserWarning):
        kwargs_to_nested_dict(
            {"--a.b.c": 1, "-a.b.d": 2, "a.e": 3},
            strip_prefix="-",
            when_unknown_prefix="warn",
        )
def test_kwargs_to_nested_dict_transform_key():
    """`transform_key` is applied to keys, alone and combined with prefix handling."""

    def dashes_to_underscores(key):
        # turns CLI-style "a-b-c" into identifier-style "a_b_c"
        return key.replace("-", "_")

    snake = {"a_b_c": 1, "a_b_d": 2, "a_e": 3}

    # transform only: no dots, so the result stays flat
    only_transform = kwargs_to_nested_dict(
        {"a-b-c": 1, "a-b-d": 2, "a-e": 3},
        transform_key=dashes_to_underscores,
    )
    assert only_transform == snake

    # prefix is stripped, then the key is transformed
    stripped_and_transformed = kwargs_to_nested_dict(
        {"prefix.a-b-c": 1, "prefix.a-b-d": 2, "prefix.a-e": 3},
        strip_prefix="prefix.",
        transform_key=dashes_to_underscores,
    )
    assert stripped_and_transformed == snake

    # "a-b-c" lacks the prefix: the "raise" policy rejects it
    with pytest.raises(ValueError):
        kwargs_to_nested_dict(
            {"a-b-c": 1, "prefix.a-b-d": 2, "prefix.a-e": 3},
            strip_prefix="prefix.",
            transform_key=dashes_to_underscores,
            when_unknown_prefix="raise",
        )

    # same input under "warn": a warning is emitted and the key is still kept
    with pytest.warns(UserWarning):
        warned = kwargs_to_nested_dict(
            {"a-b-c": 1, "prefix.a-b-d": 2, "prefix.a-e": 3},
            strip_prefix="prefix.",
            transform_key=dashes_to_underscores,
            when_unknown_prefix="warn",
        )
        assert warned == snake
@serializable_dataclass
class ChildData(SerializableDataclass):
    """Leaf record nested inside `ParentData`, used by the update tests."""

    # field order is significant: it defines the generated constructor
    x: int
    y: int
@serializable_dataclass
class ParentData(SerializableDataclass):
    """Two-level record (an int plus a nested `ChildData`) for the update tests."""

    # scalar field, updated directly
    a: int
    # nested dataclass field, updated through a sub-dict
    b: ChildData
def test_update_from_nested_dict():
    """Successive partial nested updates only touch the fields they mention."""
    parent = ParentData(a=1, b=ChildData(x=2, y=3))

    # each step: (nested update, expected (a, b.x, b.y) afterwards)
    steps = [
        ({"a": 5, "b": {"x": 6}}, (5, 6, 3)),
        ({"b": {"y": 7}}, (5, 6, 7)),
    ]
    for patch, (want_a, want_x, want_y) in steps:
        parent.update_from_nested_dict(patch)
        assert parent.a == want_a
        assert parent.b.x == want_x
        assert parent.b.y == want_y
def test_update_from_dotlists():
    """Dotted-key updates, expanded with `dotlist_to_nested_dict`, apply like nested ones."""
    parent = ParentData(a=1, b=ChildData(x=2, y=3))

    def apply(dotlist):
        # expand dotted keys, update in place, then report the observable state
        parent.update_from_nested_dict(dotlist_to_nested_dict(dotlist))
        return parent.a, parent.b.x, parent.b.y

    assert apply({"a": 5, "b.x": 6}) == (5, 6, 3)
    assert apply({"b.y": 7}) == (5, 6, 7)
# Tests for is_numeric_consecutive
_NUMERIC_CONSECUTIVE_CASES = [
    (["1", "2", "3"], True),  # already in order
    (["1", "3", "2"], True),  # order of the input does not matter
    (["1", "4", "2"], False),  # 3 is missing
    ([], False),  # empty input
    (["a", "2", "3"], False),  # a non-numeric entry
]


@pytest.mark.parametrize("test_input,expected", _NUMERIC_CONSECUTIVE_CASES)
def test_is_numeric_consecutive(test_input, expected):
    outcome = is_numeric_consecutive(test_input)
    assert outcome == expected
# Tests for condense_nested_dicts
def test_condense_nested_dicts_single_level():
    """Consecutive numeric keys sharing a value collapse into a "[lo-hi]" key."""
    condensed = condense_nested_dicts({"1": "a", "2": "a", "3": "b"})
    assert condensed == {"[1-2]": "a", "3": "b"}
def test_condense_nested_dicts_nested():
    """Range condensing is applied inside nested dicts too."""
    condensed = condense_nested_dicts({"1": {"1": "a", "2": "a"}, "2": "b"})
    assert condensed == {"1": {"[1-2]": "a"}, "2": "b"}
def test_condense_nested_dicts_non_numeric():
    """Non-numeric keys are only grouped when `condense_matching_values` is on."""
    data = {"a": "a", "b": "a", "c": "b"}

    untouched = condense_nested_dicts(data, condense_matching_values=False)
    assert untouched == data

    grouped = condense_nested_dicts(data, condense_matching_values=True)
    assert grouped == {"[a, b]": "a", "c": "b"}
def test_condense_nested_dicts_mixed_keys():
    """With a non-numeric key present, matching numeric keys are listed as "[1, 2]"."""
    condensed = condense_nested_dicts({"1": "a", "2": "a", "a": "b"})
    assert condensed == {"[1, 2]": "a", "a": "b"}
# Mocking a Tensor-like object for use in tests
class MockTensor:
    """Minimal tensor stand-in whose only attribute is `.shape`."""

    def __init__(self, shape):
        super().__init__()
        # kept exactly as passed in: no copy, no validation
        self.shape = shape
# Test cases for `tuple_dims_replace`
_TUPLE_DIMS_CASES = [
    # mapped sizes are replaced by their names, unmapped ones pass through
    ((1, 2, 3), {1: "A", 2: "B"}, ("A", "B", 3)),
    # an empty map leaves the tuple unchanged
    ((4, 5, 6), {}, (4, 5, 6)),
    # a None map leaves the tuple unchanged
    ((7, 8), None, (7, 8)),
    # the map is keyed by size, so it applies wherever that size occurs
    ((1, 2, 3), {3: "C"}, (1, 2, "C")),
]


@pytest.mark.parametrize("input_tuple,dims_names_map,expected", _TUPLE_DIMS_CASES)
def test_tuple_dims_replace(input_tuple, dims_names_map, expected):
    replaced = tuple_dims_replace(input_tuple, dims_names_map)
    assert replaced == expected
@pytest.fixture
def tensor_data():
    """Three mock tensors: two share shape (10, 256, 256), one is (10, 512, 256)."""
    shapes = {
        "tensor1": (10, 256, 256),
        "tensor2": (10, 256, 256),
        "tensor3": (10, 512, 256),
    }
    return {name: MockTensor(shape) for name, shape in shapes.items()}
def test_condense_tensor_dict_basic(tensor_data):
    """Shapes are rendered as strings with the batch dim dropped, optionally grouped."""
    # condense_matching_values flag -> expected output (False is checked first)
    expectations = {
        False: {
            "tensor1": "(256, 256)",
            "tensor2": "(256, 256)",
            "tensor3": "(512, 256)",
        },
        True: {
            "[tensor1, tensor2]": "(256, 256)",
            "tensor3": "(512, 256)",
        },
    }
    for condense, expected in expectations.items():
        result = condense_tensor_dict(
            tensor_data,
            drop_batch_dims=1,
            condense_matching_values=condense,
        )
        assert result == expected
def test_condense_tensor_dict_shapes_convert(tensor_data):
    """An identity `shapes_convert` yields raw shape tuples instead of strings."""

    def keep_shape(shape):
        # identity: leave the shape tuple unconverted
        return shape

    common = dict(shapes_convert=keep_shape, drop_batch_dims=1)

    ungrouped = condense_tensor_dict(
        tensor_data, **common, condense_matching_values=False
    )
    assert ungrouped == {
        "tensor1": (256, 256),
        "tensor2": (256, 256),
        "tensor3": (512, 256),
    }

    grouped = condense_tensor_dict(
        tensor_data, **common, condense_matching_values=True
    )
    assert grouped == {
        "[tensor1, tensor2]": (256, 256),
        "tensor3": (512, 256),
    }
def test_condense_tensor_dict_named_dims(tensor_data):
    """Sizes listed in `dims_names_map` are rendered by name.

    `drop_batch_dims` is not passed here, and the expected strings keep the
    leading 10 (shown as "B").
    """
    names = {10: "B", 256: "A", 512: "C"}

    separate = condense_tensor_dict(
        tensor_data,
        dims_names_map=names,
        condense_matching_values=False,
    )
    assert separate == {
        "tensor1": "(B, A, A)",
        "tensor2": "(B, A, A)",
        "tensor3": "(B, C, A)",
    }

    grouped = condense_tensor_dict(
        tensor_data,
        dims_names_map=names,
        condense_matching_values=True,
    )
    assert grouped == {"[tensor1, tensor2]": "(B, A, A)", "tensor3": "(B, C, A)"}
@pytest.mark.parametrize(
    "input_data,expected,fallback_mapping",
    [
        # no repeated values: nothing to condense
        ({"a": 1, "b": 2}, {"a": 1, "b": 2}, None),
        # keys sharing a value are merged into one "[k1, k2]" key
        ({"a": 1, "b": 1, "c": 2}, {"[a, b]": 1, "c": 2}, None),
        # condensing is applied inside nested dicts as well
        ({"a": {"x": 1, "y": 1}, "b": 2}, {"a": {"[x, y]": 1}, "b": 2}, None),
        # equal values at different nesting levels (a.x, b.x, c) are NOT merged
        (
            {"a": {"x": 1, "y": 2}, "b": {"x": 1, "z": 3}, "c": 1},
            {"a": {"x": 1, "y": 2}, "b": {"x": 1, "z": 3}, "c": 1},
            None,
        ),
        # unhashable (list) values are grouped via the fallback mapping, and the
        # condensed value is the mapped form: str([1, 2]) == "[1, 2]".
        # NOTE(review): the same input WITHOUT a fallback mapping is not covered.
        (
            {"a": [1, 2], "b": [1, 2], "c": "test"},
            {"[a, b]": "[1, 2]", "c": "test"},
            str,
        ),
    ],
)
def test_condense_nested_dicts_matching_values(input_data, expected, fallback_mapping):
    """`condense_nested_dicts_matching_values` groups keys whose values are equal.

    A `fallback_mapping` of None means "call with the function's own default"
    rather than passing None explicitly.
    """
    if fallback_mapping is None:
        result = condense_nested_dicts_matching_values(input_data)
    else:
        result = condense_nested_dicts_matching_values(input_data, fallback_mapping)
    assert result == expected, f"Expected {expected}, got {result}"
# "ndtd" = `nested_dict_to_dotlist`
def test_nested_dict_to_dotlist_basic():
    """Nested dict keys are joined with "." into a flat mapping."""
    flattened = nested_dict_to_dotlist({"a": {"b": {"c": 1, "d": 2}, "e": 3}})
    assert flattened == {"a.b.c": 1, "a.b.d": 2, "a.e": 3}
def test_nested_dict_to_dotlist_empty():
    """An empty dict flattens to an empty dict."""
    assert nested_dict_to_dotlist({}) == {}
def test_nested_dict_to_dotlist_single_level():
    """An already-flat dict comes back unchanged."""
    flat = {"a": 1, "b": 2, "c": 3}
    expected = dict(flat)  # independent copy for comparison
    assert nested_dict_to_dotlist(flat) == expected
def test_nested_dict_to_dotlist_with_list():
    """With `allow_lists=True`, list items are addressed by their index."""
    flattened = nested_dict_to_dotlist({"a": [1, 2, {"b": 3}], "c": 4}, allow_lists=True)
    assert flattened == {"a.0": 1, "a.1": 2, "a.2.b": 3, "c": 4}
def test_nested_dict_to_dotlist_nested_empty():
    """An empty inner dict is kept as a leaf value rather than dropped."""
    flattened = nested_dict_to_dotlist({"a": {"b": {}}})
    assert flattened == {"a.b": {}}
def test_round_trip_conversion():
    """`nested_dict_to_dotlist` and `dotlist_to_nested_dict` invert each other.

    Both directions are checked: nested -> dotlist -> nested, and
    dotlist -> nested -> dotlist.
    """
    original = {"a": {"b": {"c": 1, "d": 2}, "e": 3}}
    dotlist = nested_dict_to_dotlist(original)
    assert dotlist_to_nested_dict(dotlist) == original

    # reverse direction, which was previously unchecked
    flat = {"a.b.c": 1, "a.b.d": 2, "a.e": 3}
    assert nested_dict_to_dotlist(dotlist_to_nested_dict(flat)) == flat