Coverage for tests / unit / no_torch / test_shared_prefix_keys.py: 100%

22 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-21 22:18 -0700

1from pathlib import Path 

2import typing 

3import numpy as np 

4 

5import pytest 

6 

7from zanj import ZANJ 

8 

9_TEMP_PATH: Path = Path("tests/.temp/") 

10 

11 

# NOTE: as of 2025-11-06 15:32 (v0.5.1), the "longer_key_first" case fails while the "shorter_key_first" case passes; the only difference is key insertion order, and the cause is not yet understood.

13 

14 

@pytest.mark.parametrize(
    ("keys", "name"),
    [
        (["layer.1.weight", "layer.1"], "longer_key_first"),
        (["layer.1", "layer.1.weight"], "shorter_key_first"),
    ],
)
def test_shared_prefix_keys(keys: typing.List[str], name: str):
    """Round-trip a dict whose keys share a dotted prefix through ZANJ.

    One key (``"layer.1"``) is a strict dotted prefix of the other
    (``"layer.1.weight"``). The two parametrized cases differ only in the
    insertion order of those keys.

    The dict is saved with ``external_array_threshold=0``, which presumably
    forces every array to be stored externally (TODO confirm against ZANJ
    docs). It is then read back and checked for identical keys, identical
    value types, and identical array contents.

    Args:
        keys: dict keys to populate, each mapped to a random 10x10 array.
        name: human-readable case label, used in the output file name.
    """
    fname: Path = _TEMP_PATH / f"shared_prefix_keys-{name}.zanj"
    # Create the scratch directory so the test does not depend on it
    # already existing (e.g. on a fresh checkout or clean CI runner).
    fname.parent.mkdir(parents=True, exist_ok=True)

    # Seeded generator so a failing case is reproducible run-to-run.
    rng = np.random.default_rng(0)
    data = {key: rng.random((10, 10)) for key in keys}

    ZANJ(external_array_threshold=0).save(data, fname)
    loaded = ZANJ().read(fname)

    assert set(loaded) == set(data), (
        f"key mismatch: saved {sorted(data)}, loaded {sorted(loaded)}"
    )
    for key, expected in data.items():
        actual = loaded[key]
        # Exact type match (identity, not isinstance): a subclass or any
        # wrapper type returned by the loader makes this fail.
        assert type(actual) is type(expected), (
            f"{key = }: expected {type(expected)}, got {type(actual)}"
        )
        # (actual, desired) order so numpy's failure report is labelled
        # correctly; err_msg identifies which key diverged.
        np.testing.assert_array_equal(actual, expected, err_msg=f"{key = }")