Coverage for tests\unit\test_dictmagic.py: 100%

130 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-15 20:56 -0600

1from __future__ import annotations 

2 

3import pytest 

4 

5from muutils.dictmagic import ( 

6 condense_nested_dicts, 

7 condense_nested_dicts_matching_values, 

8 condense_tensor_dict, 

9 dotlist_to_nested_dict, 

10 is_numeric_consecutive, 

11 kwargs_to_nested_dict, 

12 nested_dict_to_dotlist, 

13 tuple_dims_replace, 

14 update_with_nested_dict, 

15) 

16from muutils.json_serialize import SerializableDataclass, serializable_dataclass 

17 

18 

def test_dotlist_to_nested_dict():
    """Check `dotlist_to_nested_dict` with dotted keys, a bad key, and a custom separator."""
    # dotted keys expand into nested dictionaries
    from_dots = dotlist_to_nested_dict({"a.b.c": 1, "a.b.d": 2, "a.e": 3})
    assert from_dots == {"a": {"b": {"c": 1, "d": 2}, "e": 3}}

    # an integer key is expected to be rejected with TypeError
    with pytest.raises(TypeError):
        dotlist_to_nested_dict({1: 1})

    # the same nesting is produced when "/" is passed as the separator
    from_slashes = dotlist_to_nested_dict(
        {"a/b/c": 1, "a/b/d": 2, "a/e": 3},
        sep="/",
    )
    assert from_slashes == {"a": {"b": {"c": 1, "d": 2}, "e": 3}}

33 

34 

def test_update_with_nested_dict():
    """Check `update_with_nested_dict` when overriding, adding, and merging keys."""
    # an existing nested value is replaced
    merged = update_with_nested_dict({"a": {"b": 1}, "c": -1}, {"a": {"b": 2}})
    assert merged == {"a": {"b": 2}, "c": -1}

    # a key absent from the original dict is added
    merged = update_with_nested_dict({"a": {"b": 1}, "c": -1}, {"d": 3})
    assert merged == {"a": {"b": 1}, "c": -1, "d": 3}

    # only the overridden nested key changes; its sibling "d" is kept
    merged = update_with_nested_dict(
        {"a": {"b": 1, "d": 3}, "c": -1},
        {"a": {"b": 2}},
    )
    assert merged == {"a": {"b": 2, "d": 3}, "c": -1}

    # a nested dict with no counterpart in the original is inserted as-is
    merged = update_with_nested_dict({"a": 1}, {"b": {"c": 2}})
    assert merged == {"a": 1, "b": {"c": 2}}

56 

57 

def test_kwargs_to_nested_dict():
    """Check `kwargs_to_nested_dict` with and without `strip_prefix` and for each
    `when_unknown_prefix` mode exercised here ("raise", "ignore", "warn")."""
    # plain dotted keys
    result = kwargs_to_nested_dict({"a.b.c": 1, "a.b.d": 2, "a.e": 3})
    assert result == {"a": {"b": {"c": 1, "d": 2}, "e": 3}}

    # every key carries the prefix that is stripped
    result = kwargs_to_nested_dict(
        {"prefix.a.b.c": 1, "prefix.a.b.d": 2, "prefix.a.e": 3},
        strip_prefix="prefix.",
    )
    assert result == {"a": {"b": {"c": 1, "d": 2}, "e": 3}}

    # no key carries the prefix and the mode is "raise": ValueError expected
    with pytest.raises(ValueError):
        kwargs_to_nested_dict(
            {"a.b.c": 1, "a.b.d": 2, "a.e": 3},
            strip_prefix="prefix.",
            when_unknown_prefix="raise",
        )

    # mixed "--"-prefixed and unprefixed keys, unknown prefixes ignored
    result = kwargs_to_nested_dict(
        {"--a.b.c": 1, "--a.b.d": 2, "a.e": 3},
        strip_prefix="--",
        when_unknown_prefix="ignore",
    )
    assert result == {"a": {"b": {"c": 1, "d": 2}, "e": 3}}

    # mixed "--", "-" and unprefixed keys in "warn" mode: a UserWarning is expected
    # (the returned value is intentionally not checked here, as in the original)
    with pytest.warns(UserWarning):
        kwargs_to_nested_dict(
            {"--a.b.c": 1, "-a.b.d": 2, "a.e": 3},
            strip_prefix="-",
            when_unknown_prefix="warn",
        )

91 

92 

def test_kwargs_to_nested_dict_transform_key():
    """Check `kwargs_to_nested_dict` when a `transform_key` callable is supplied,
    alone and combined with `strip_prefix` / `when_unknown_prefix`."""

    def dashes_to_underscores(key):
        # key transform shared by every case below
        return key.replace("-", "_")

    # transform only: dashes become underscores, keys stay flat
    result = kwargs_to_nested_dict(
        {"a-b-c": 1, "a-b-d": 2, "a-e": 3},
        transform_key=dashes_to_underscores,
    )
    assert result == {"a_b_c": 1, "a_b_d": 2, "a_e": 3}

    # prefix stripping combined with the key transform
    result = kwargs_to_nested_dict(
        {"prefix.a-b-c": 1, "prefix.a-b-d": 2, "prefix.a-e": 3},
        strip_prefix="prefix.",
        transform_key=dashes_to_underscores,
    )
    assert result == {"a_b_c": 1, "a_b_d": 2, "a_e": 3}

    # one key lacks the prefix and the mode is "raise": ValueError expected
    with pytest.raises(ValueError):
        kwargs_to_nested_dict(
            {"a-b-c": 1, "prefix.a-b-d": 2, "prefix.a-e": 3},
            strip_prefix="prefix.",
            transform_key=dashes_to_underscores,
            when_unknown_prefix="raise",
        )

    # same input in "warn" mode: a UserWarning is expected and the result is still checked
    with pytest.warns(UserWarning):
        result = kwargs_to_nested_dict(
            {"a-b-c": 1, "prefix.a-b-d": 2, "prefix.a-e": 3},
            strip_prefix="prefix.",
            transform_key=dashes_to_underscores,
            when_unknown_prefix="warn",
        )
        assert result == {"a_b_c": 1, "a_b_d": 2, "a_e": 3}

123 

124 

# Inner dataclass used as the nested `b` field of `ParentData` in the update tests.
@serializable_dataclass
class ChildData(SerializableDataclass):
    # two integer fields that the tests update independently of each other
    x: int
    y: int

129 

130 

# Outer dataclass for the update tests: one plain field plus one nested dataclass field.
@serializable_dataclass
class ParentData(SerializableDataclass):
    # top-level integer field
    a: int
    # nested dataclass field
    b: ChildData

135 

136 

def test_update_from_nested_dict():
    """Apply two successive nested-dict updates to a `ParentData` instance and
    check that only the named fields change."""
    parent = ParentData(a=1, b=ChildData(x=2, y=3))

    # first update touches `a` and `b.x`; `b.y` keeps its initial value
    parent.update_from_nested_dict({"a": 5, "b": {"x": 6}})
    assert (parent.a, parent.b.x, parent.b.y) == (5, 6, 3)

    # second update touches only `b.y`; the earlier changes remain in place
    parent.update_from_nested_dict({"b": {"y": 7}})
    assert (parent.a, parent.b.x, parent.b.y) == (5, 6, 7)

152 

153 

def test_update_from_dotlists():
    """Same scenario as the nested-dict update test, but each update is written
    as a dotlist and converted with `dotlist_to_nested_dict` first."""
    parent = ParentData(a=1, b=ChildData(x=2, y=3))

    # first update touches `a` and `b.x`; `b.y` keeps its initial value
    first_update = dotlist_to_nested_dict({"a": 5, "b.x": 6})
    parent.update_from_nested_dict(first_update)
    assert (parent.a, parent.b.x, parent.b.y) == (5, 6, 3)

    # second update touches only `b.y`; the earlier changes remain in place
    second_update = dotlist_to_nested_dict({"b.y": 7})
    parent.update_from_nested_dict(second_update)
    assert (parent.a, parent.b.x, parent.b.y) == (5, 6, 7)

169 

170 

171# Tests for is_numeric_consecutive 

@pytest.mark.parametrize(
    "test_input,expected",
    [
        # consecutive, in order
        (["1", "2", "3"], True),
        # consecutive, out of order
        (["1", "3", "2"], True),
        # gap between 2 and 4
        (["1", "4", "2"], False),
        # empty input
        ([], False),
        # non-numeric entry
        (["a", "2", "3"], False),
    ],
)
def test_is_numeric_consecutive(test_input, expected):
    """Compare `is_numeric_consecutive` against the expected flag for each case."""
    actual = is_numeric_consecutive(test_input)
    assert actual == expected

184 

185 

186# Tests for condense_nested_dicts 

def test_condense_nested_dicts_single_level():
    """Flat dict with numeric keys: keys "1" and "2" share a value and are
    expected under the single key "[1-2]"."""
    condensed = condense_nested_dicts({"1": "a", "2": "a", "3": "b"})
    assert condensed == {"[1-2]": "a", "3": "b"}

191 

192 

def test_condense_nested_dicts_nested():
    """Nested dict: the inner numeric keys are expected to be condensed to
    "[1-2]" while the outer keys are left as they are."""
    condensed = condense_nested_dicts({"1": {"1": "a", "2": "a"}, "2": "b"})
    assert condensed == {"1": {"[1-2]": "a"}, "2": "b"}

197 

198 

def test_condense_nested_dicts_non_numeric():
    """Non-numeric keys: unchanged when `condense_matching_values=False`,
    grouped by equal value when it is True."""
    data = {"a": "a", "b": "a", "c": "b"}

    untouched = condense_nested_dicts(data, condense_matching_values=False)
    assert untouched == data

    grouped = condense_nested_dicts(data, condense_matching_values=True)
    assert grouped == {"[a, b]": "a", "c": "b"}

206 

207 

def test_condense_nested_dicts_mixed_keys():
    """Mixed numeric and non-numeric keys: the two keys sharing a value are
    expected under "[1, 2]" (list form, not the "[1-2]" range form)."""
    condensed = condense_nested_dicts({"1": "a", "2": "a", "a": "b"})
    assert condensed == {"[1, 2]": "a", "a": "b"}

211 

212 

213# Mocking a Tensor-like object for use in tests 

class MockTensor:
    """Tensor stand-in for the tests: an object that only carries a `shape`."""

    def __init__(self, shape):
        # store the shape exactly as given; nothing else is set on the instance
        self.shape = shape

217 

218 

219# Test cases for `tuple_dims_replace` 

@pytest.mark.parametrize(
    "input_tuple,dims_names_map,expected",
    [
        # two of three dims have a name
        ((1, 2, 3), {1: "A", 2: "B"}, ("A", "B", 3)),
        # empty mapping: tuple unchanged
        ((4, 5, 6), {}, (4, 5, 6)),
        # no mapping at all: tuple unchanged
        ((7, 8), None, (7, 8)),
        # only the last dim has a name
        ((1, 2, 3), {3: "C"}, (1, 2, "C")),
    ],
)
def test_tuple_dims_replace(input_tuple, dims_names_map, expected):
    """Compare `tuple_dims_replace` against the expected tuple for each case."""
    replaced = tuple_dims_replace(input_tuple, dims_names_map)
    assert replaced == expected

231 

232 

@pytest.fixture
def tensor_data():
    """Provide three named `MockTensor` objects; the first two share a shape."""
    shapes_by_name = {
        "tensor1": (10, 256, 256),
        "tensor2": (10, 256, 256),
        "tensor3": (10, 512, 256),
    }
    # wrap every shape in a MockTensor, keeping the key order above
    return {name: MockTensor(shape) for name, shape in shapes_by_name.items()}

241 

242 

def test_condense_tensor_dict_basic(tensor_data):
    """Check `condense_tensor_dict` with `drop_batch_dims=1`, with and without
    grouping of keys that map to the same shape string."""
    # no grouping: one entry per tensor, leading dimension dropped
    per_tensor = condense_tensor_dict(
        tensor_data,
        drop_batch_dims=1,
        condense_matching_values=False,
    )
    assert per_tensor == {
        "tensor1": "(256, 256)",
        "tensor2": "(256, 256)",
        "tensor3": "(512, 256)",
    }

    # grouping: tensor1 and tensor2 collapse into a single key
    grouped = condense_tensor_dict(
        tensor_data,
        drop_batch_dims=1,
        condense_matching_values=True,
    )
    assert grouped == {
        "[tensor1, tensor2]": "(256, 256)",
        "tensor3": "(512, 256)",
    }

262 

263 

def test_condense_tensor_dict_shapes_convert(tensor_data):
    """Check `condense_tensor_dict` with an identity `shapes_convert`, so the
    values stay shape tuples instead of strings."""

    def shapes_convert(x):
        # identity: hand back the shape tuple unchanged
        return x

    # no grouping: one tuple per tensor, leading dimension dropped
    per_tensor = condense_tensor_dict(
        tensor_data,
        shapes_convert=shapes_convert,
        drop_batch_dims=1,
        condense_matching_values=False,
    )
    assert per_tensor == {
        "tensor1": (256, 256),
        "tensor2": (256, 256),
        "tensor3": (512, 256),
    }

    # grouping: tensor1 and tensor2 collapse into a single key
    grouped = condense_tensor_dict(
        tensor_data,
        shapes_convert=shapes_convert,
        drop_batch_dims=1,
        condense_matching_values=True,
    )
    assert grouped == {
        "[tensor1, tensor2]": (256, 256),
        "tensor3": (512, 256),
    }

287 

288 

def test_condense_tensor_dict_named_dims(tensor_data):
    """Check `condense_tensor_dict` with a `dims_names_map` that names every
    dimension size present in the fixture."""
    # no grouping: each size is replaced by its name in the shape string
    per_tensor = condense_tensor_dict(
        tensor_data,
        dims_names_map={10: "B", 256: "A", 512: "C"},
        condense_matching_values=False,
    )
    assert per_tensor == {
        "tensor1": "(B, A, A)",
        "tensor2": "(B, A, A)",
        "tensor3": "(B, C, A)",
    }

    # grouping: tensor1 and tensor2 collapse into a single key
    grouped = condense_tensor_dict(
        tensor_data,
        dims_names_map={10: "B", 256: "A", 512: "C"},
        condense_matching_values=True,
    )
    assert grouped == {"[tensor1, tensor2]": "(B, A, A)", "tensor3": "(B, C, A)"}

305 

306 

@pytest.mark.parametrize(
    "input_data,expected,fallback_mapping",
    [
        # Test 1: Simple dictionary with no identical values
        ({"a": 1, "b": 2}, {"a": 1, "b": 2}, None),
        # Test 2: Dictionary with identical values
        ({"a": 1, "b": 1, "c": 2}, {"[a, b]": 1, "c": 2}, None),
        # Test 3: Nested dictionary with identical values
        ({"a": {"x": 1, "y": 1}, "b": 2}, {"a": {"[x, y]": 1}, "b": 2}, None),
        # Test 4: Nested dictionaries with and without identical values
        (
            {"a": {"x": 1, "y": 2}, "b": {"x": 1, "z": 3}, "c": 1},
            {"a": {"x": 1, "y": 2}, "b": {"x": 1, "z": 3}, "c": 1},
            None,
        ),
        # Test 5: Dictionary with unhashable values and no fallback mapping
        # (not included: it is expected to fail without a fallback mapping)
        # Test 6: Dictionary with unhashable values and a fallback mapping as str
        (
            {"a": [1, 2], "b": [1, 2], "c": "test"},
            {"[a, b]": "[1, 2]", "c": "test"},
            str,
        ),
    ],
)
def test_condense_nested_dicts_matching_values(input_data, expected, fallback_mapping):
    """Compare `condense_nested_dicts_matching_values` against the expected dict,
    passing the fallback mapping positionally only when the case provides one."""
    # leave the second positional argument out entirely when there is no fallback
    call_args = (input_data,)
    if fallback_mapping is not None:
        call_args = (input_data, fallback_mapping)

    result = condense_nested_dicts_matching_values(*call_args)
    assert result == expected, f"Expected {expected}, got {result}"

338 

339 

340# "ndtd" = `nested_dict_to_dotlist` 

def test_nested_dict_to_dotlist_basic():
    """A three-level nested dict is expected to flatten to dot-joined keys."""
    flattened = nested_dict_to_dotlist({"a": {"b": {"c": 1, "d": 2}, "e": 3}})
    assert flattened == {"a.b.c": 1, "a.b.d": 2, "a.e": 3}

345 

346 

def test_nested_dict_to_dotlist_empty():
    """An empty dict is expected to flatten to an empty dict."""
    flattened = nested_dict_to_dotlist({})
    assert flattened == {}

352 

353 

def test_nested_dict_to_dotlist_single_level():
    """A dict without nesting is expected to come back with the same keys and values."""
    flattened = nested_dict_to_dotlist({"a": 1, "b": 2, "c": 3})
    assert flattened == {"a": 1, "b": 2, "c": 3}

358 

359 

def test_nested_dict_to_dotlist_with_list():
    """With `allow_lists=True`, list positions are expected to appear as index
    segments in the dotted keys, including for a dict nested inside the list."""
    flattened = nested_dict_to_dotlist(
        {"a": [1, 2, {"b": 3}], "c": 4},
        allow_lists=True,
    )
    assert flattened == {"a.0": 1, "a.1": 2, "a.2.b": 3, "c": 4}

364 

365 

def test_nested_dict_to_dotlist_nested_empty():
    """An empty dict at a leaf is expected to be kept as the value of its dotted key."""
    flattened = nested_dict_to_dotlist({"a": {"b": {}}})
    assert flattened == {"a.b": {}}

370 

371 

def test_round_trip_conversion():
    """Flattening with `nested_dict_to_dotlist` and rebuilding with
    `dotlist_to_nested_dict` is expected to reproduce the original dict."""
    original = {"a": {"b": {"c": 1, "d": 2}, "e": 3}}
    rebuilt = dotlist_to_nested_dict(nested_dict_to_dotlist(original))
    assert rebuilt == original