# Stats
- 35 files
- 5412 (5.4K) lines
- 176331 (176K) chars
- 16813 (17K) `whitespace-split` tokens

# File Tree

```
ZANJ                                                 
├── .github                                          
│   └── workflows                                    
│    └── checks.yml                                  [118L     3,026C   334T]
├── tests                                            
│   ├── input_data                                   
│   │   ├── brain_networks.csv                       [924L 1,075,910C   924T]
│   │   └── iris.csv                                 [151L     3,857C   151T]
│   ├── unit                                         
│   │   ├── no_torch                                 
│   │   │   ├── test_bool_array.py                   [ 60L     1,169C   113T]
│   │   │   ├── test_dataframe_serialization.py      [156L     4,949C   499T]
│   │   │   ├── test_isolate_zanj_handler_store.py   [ 78L     1,784C   158T]
│   │   │   ├── test_load_item_recursive.py          [231L     7,452C   714T]
│   │   │   ├── test_loading_edge_cases.py           [192L     6,807C   538T]
│   │   │   ├── test_serializing_edge_cases.py       [179L     5,816C   459T]
│   │   │   ├── test_shared_prefix_keys.py           [ 40L     1,139C   112T]
│   │   │   ├── test_zanj_basic.py                   [184L     5,402C   458T]
│   │   │   ├── test_zanj_edge_cases.py              [290L     8,643C   907T]
│   │   │   ├── test_zanj_populate_nested.py         [ 57L     1,386C   103T]
│   │   │   └── test_zanj_serializable_dataclass.py  [168L     4,351C   404T]
│   │   └── with_torch                               
│   │    ├── test_bool_array_torch.py                [ 35L       793C    72T]
│   │    ├── test_get_module_device.py               [ 61L     1,870C   203T]
│   │    ├── test_sdc_torch.py                       [105L     2,555C   219T]
│   │    ├── test_torch_edge_cases.py                [229L     7,182C   642T]
│   │    ├── test_torchutil_edge_cases.py            [225L     7,310C   596T]
│   │    ├── test_zanj_sdc_modelcfg.py               [169L     4,750C   375T]
│   │    ├── test_zanj_torch.py                      [169L     5,042C   350T]
│   │    └── test_zanj_torch_cfgmismatch.py          [161L     3,983C   334T]
│   └── assert_no_torch.py                           [  8L       167C    12T]
├── zanj                                             
│   ├── __init__.py                                  [ 19L       299C    30T]
│   ├── consts.py                                    [ 29L       877C    93T]
│   ├── externals.py                                 [ 53L     1,515C   166T]
│   ├── loading.py                                   [488L    18,735C 1,762T]
│   ├── py.typed                                     [  0L         0C     0T]
│   ├── serializing.py                               [293L    10,703C   996T]
│   ├── torchutil.py                                 [295L    10,161C   915T]
│   └── zanj.py                                      [251L     8,708C   802T]
├── README.md                                        [235L    10,944C 1,365T]
├── demo.ipynb                                       [317L     9,678C   976T]
├── makefile                                         [728L    28,521C 3,688T]
├── pyproject.toml                                   [176L     5,085C   529T]
```

# File Contents

``````{ path=".github/workflows/checks.yml"  }
# CI workflow: a formatting job plus a python x dependency-group test matrix.
name: Checks

# Run on PRs targeting main, on pushes to main, and on manual dispatch.
on:
  pull_request:
    branches:
      - main
  push:
    branches:
      - main
  
  workflow_dispatch:

jobs:
  # Formatting-only job. There is no setup-python step, so the lint
  # requirements are pip-installed into the runner's default python.
  lint:
    name: Formatting
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with: 
          # full git history (fetch-depth 0)
          fetch-depth: 0

      - name: Install linters
        run: pip install -r .meta/requirements/requirements-lint.txt

      - name: Run Format Checks
        run: make format-check RUN_GLOBAL=1

  # Test matrix: every python version crossed with three dependency groups.
  test:
    name: Test
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
        # Dependency groups. An empty string leaves the version installed by
        # `make dep` untouched (the override steps below are skipped);
        # torch: "None" is a sentinel that uninstalls torch entirely.
        pkg:
          # pinned older versions
          - torch: "1.13.1"
            numpy: "1.24.4"
            pandas: "2.0.3"
            group: "legacy"
          - torch: ""
            numpy: ""
            pandas: ""
            group: "latest"
          - torch: "None"
            numpy: ""
            pandas: ""
            group: "notorch"
        # the legacy pins are not run on python 3.12 / 3.13
        exclude:
          - python: "3.12"
            pkg:
              group: "legacy"
          - python: "3.13"
            pkg:
              group: "legacy"
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with: 
          fetch-depth: 1

      - name: Set up python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}


      # Installs uv with the official install script.
      # NOTE(review): later steps call `uv` directly, so this relies on the
      # installer's bin directory being on PATH for subsequent steps — confirm.
      - name: set up uv
        run: curl -LsSf https://astral.sh/uv/install.sh | sh
  
      - name: check dependencies
        run: make dep-check

      - name: install dependencies and package
        run: make dep

      # Per-group version overrides, applied on top of `make dep`.
      # Each step is skipped when its matrix value is an empty string.
      - name: Install different numpy version
        if: ${{ matrix.pkg.numpy != '' }}
        run: uv pip install numpy==${{ matrix.pkg.numpy }}

      - name: Install different pandas version
        if: ${{ matrix.pkg.pandas != '' }}
        run: uv pip install pandas==${{ matrix.pkg.pandas }}

      # CPU-only wheel from the pytorch index; skipped for "" and the "None" sentinel
      - name: Install different pytorch version
        if: ${{ matrix.pkg.torch != '' && matrix.pkg.torch != 'None' }}
        run: |
          uv pip install torch==${{ matrix.pkg.torch }}+cpu --extra-index-url https://download.pytorch.org/whl/cpu

      # the "None" sentinel: remove torch for the torchless test run
      - name: remove torch if testing torchless
        if: ${{ matrix.pkg.torch == 'None' }}
        run: uv pip uninstall torch

      # From here on, make targets get UV_NOSYNC=1. Presumably this stops the
      # makefile from re-syncing the environment, which would undo the version
      # overrides above — confirm in the makefile.
      - name: make info
        run: make info-long UV_NOSYNC=1

      # informational only: continue-on-error means a failure does not fail the job
      - name: torch dep info
        run: make dep-check-torch UV_NOSYNC=1
        continue-on-error: true
      
      - name: format check
        run: make format-check UV_NOSYNC=1

      # groups with torch installed run `make test`; the torchless group runs `make test-notorch`
      - name: tests
        if: ${{ matrix.pkg.torch != 'None' }}
        run: make test UV_NOSYNC=1

      - name: tests without torch
        if: ${{ matrix.pkg.torch == 'None' }}
        run: make test-notorch UV_NOSYNC=1

      # - name: tests in strict mode
      #   # TODO: until zanj ported to 3.8 and 3.9
      #   if: ${{ matrix.python != '3.8' && matrix.python != '3.9' }}
      #   run: make test WARN_STRICT=1 RUN_GLOBAL=1

      # Type checking is skipped on python 3.8.
      # NOTE(review): unlike the make steps above, this runs without
      # UV_NOSYNC=1 — confirm whether an environment re-sync here is intended.
      - name: check typing
        if: ${{ matrix.python != '3.8' }}
        run: make typing

``````{ end_of_file=".github/workflows/checks.yml" }

``````{ path="tests/input_data/brain_networks.csv" processed_with="csv_preview_5_lines" }
network,1,1,2,2,3,3,4,4,5,5,6,6,6,6,7,7,7,7,7,7,8,8,8,8,8,8,9,9,10,10,11,11,12,12,12,12,12,13,13,13,13,13,13,14,14,15,15,16,16,16,16,16,16,16,16,17,17,17,17,17,17,17
node,1,1,1,1,1,1,1,1,1,1,1,1,2,2,1,1,2,2,3,3,1,1,2,2,3,3,1,1,1,1,1,1,1,1,2,2,3,1,1,2,2,3,4,1,1,1,1,1,1,2,2,3,3,4,4,1,1,2,2,3,3,4
hemi,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,lh,rh,lh,rh,rh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
0,56.05574417114258,92.03103637695312,3.391575574874878,38.65968322753906,26.203819274902344,-49.71556854248047,47.4610366821289,26.746612548828125,-35.898860931396484,-1.8891807794570925,5.898688316345215,-43.69232177734375,-47.66426467895508,12.2841215133667,1.5665380954742432,-13.042585372924805,-1.8552596569061282,-39.80590057373047,-30.831512451171875,-61.13700866699219,-25.82785606384277,39.02416229248047,-29.97164535522461,-6.1323723793029785,-56.75698852539063,0.
... (truncated)
``````{ end_of_file="tests/input_data/brain_networks.csv" }

``````{ path="tests/input_data/iris.csv" processed_with="csv_preview_5_lines" }
sepal_length,sepal_width,petal_length,petal_width,species
5.1,3.5,1.4,0.2,setosa
4.9,3.0,1.4,0.2,setosa
4.7,3.2,1.3,0.2,setosa
4.6,3.1,1.5,0.2,setosa
... (truncated)
``````{ end_of_file="tests/input_data/iris.csv" }

``````{ path="tests/unit/no_torch/test_bool_array.py"  }
from pathlib import Path

import numpy as np

from muutils.json_serialize import SerializableDataclass, serializable_dataclass

from zanj import ZANJ

# Scratch directory for the .zanj files these tests write. It is a relative
# path, so it resolves against the directory pytest is launched from.
TEST_DATA_PATH: Path = Path("tests/junk_data")


@serializable_dataclass
class MyClass_list(SerializableDataclass):
    """Fixture dataclass holding boolean data as plain Python lists."""

    name: str
    arr_1: list
    arr_2: list


def test_list_bool_array():
    """A dataclass whose fields are plain lists of bools round-trips through ZANJ."""
    original: MyClass_list = MyClass_list(
        name="test",
        arr_1=[True, False, True],
        arr_2=[True, False, True],
    )
    out_path: Path = TEST_DATA_PATH / "test_list_bool_array.zanj"

    zanj_obj = ZANJ()
    zanj_obj.save(original, out_path)
    loaded: MyClass_list = zanj_obj.read(out_path)

    assert original == loaded


@serializable_dataclass
class MyClass_np(SerializableDataclass):
    """Fixture dataclass holding boolean data as numpy arrays."""

    name: str
    arr_1: np.ndarray
    arr_2: np.ndarray


def test_np_bool_array():
    """Numpy bool arrays come back from a .zanj file with dtype ``np.bool_``."""
    flags = [True, False, True]
    original: MyClass_np = MyClass_np(
        name="test",
        arr_1=np.array(flags),
        arr_2=np.array(flags),
    )
    out_path: Path = TEST_DATA_PATH / "test_np_bool_array.zanj"

    zanj_obj = ZANJ()
    zanj_obj.save(original, out_path)
    loaded: MyClass_np = zanj_obj.read(out_path)

    # the dtype itself must survive, not just the truthiness of the values
    for arr in (loaded.arr_1, loaded.arr_2):
        assert arr.dtype == np.bool_

    assert original == loaded

``````{ end_of_file="tests/unit/no_torch/test_bool_array.py" }

``````{ path="tests/unit/no_torch/test_dataframe_serialization.py"  }
"""Tests for pandas DataFrame serialization edge cases and regression prevention."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import pandas as pd

from zanj import ZANJ

# Scratch directory for the .zanj files these tests write. It is a relative
# path, so it resolves against the directory pytest is launched from.
TEST_DATA_PATH: Path = Path("tests/junk_data")


def test_dataframe_detection_logic():
    """Check the module/class-name test that serializing.py uses to spot DataFrames.

    Guards against the pandas 3.0 regression in which the MRO string changed
    from 'pandas.core.frame.DataFrame' to 'pandas.DataFrame'.
    """
    frame = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    module_name = frame.__class__.__module__
    class_name = frame.__class__.__name__

    # mirror the exact detection conditions used in serializing.py
    assert "pandas" in module_name, f"Expected 'pandas' in module, got {module_name}"
    assert class_name == "DataFrame", (
        f"Expected class name 'DataFrame', got {class_name}"
    )


def test_small_dataframe_roundtrip():
    """Round-trip a DataFrame with fewer rows than external_list_threshold (256)."""
    rows = range(10)
    df = pd.DataFrame(
        {
            "int_col": list(rows),
            "float_col": [i * 0.1 for i in rows],
            "str_col": [f"row_{i}" for i in rows],
        }
    )
    path = TEST_DATA_PATH / "test_small_dataframe.zanj"

    z = ZANJ()
    z.save({"df": df}, path)
    loaded_df = z.read(path)["df"]

    assert isinstance(loaded_df, pd.DataFrame), (
        f"Expected DataFrame, got {type(loaded_df)}"
    )
    assert df.equals(loaded_df), "DataFrames should be equal"


def test_single_row_dataframe():
    """A one-row DataFrame (the smallest non-empty case) survives a round-trip."""
    path = TEST_DATA_PATH / "test_single_row_dataframe.zanj"
    single_row = pd.DataFrame({"a": [1], "b": [2]})

    z = ZANJ()
    z.save({"df": single_row}, path)
    loaded_df = z.read(path)["df"]

    assert isinstance(loaded_df, pd.DataFrame), (
        f"Expected DataFrame, got {type(loaded_df)}"
    )
    assert len(loaded_df) == 1, "DataFrame should have 1 row"
    assert list(loaded_df.columns) == ["a", "b"], "Columns should be preserved"


def test_empty_dataframe():
    """A zero-row DataFrame stays empty but keeps its column names."""
    path = TEST_DATA_PATH / "test_empty_dataframe.zanj"
    empty = pd.DataFrame({"a": [], "b": []})

    z = ZANJ()
    z.save({"df": empty}, path)
    loaded_df = z.read(path)["df"]

    assert isinstance(loaded_df, pd.DataFrame), (
        f"Expected DataFrame, got {type(loaded_df)}"
    )
    assert len(loaded_df) == 0, "DataFrame should be empty"
    assert list(loaded_df.columns) == ["a", "b"], "Columns should be preserved"


def test_dataframe_dtype_preservation():
    """Column values of several dtypes survive the round-trip.

    Only the values are compared: the exact dtypes may change because the
    data passes through JSON.
    """
    column_specs = {
        "int_col": ([1, 2, 3], "int64"),
        "float_col": ([1.1, 2.2, 3.3], "float64"),
        "str_col": (["a", "b", "c"], "object"),
        "bool_col": ([True, False, True], "bool"),
    }
    df = pd.DataFrame(
        {
            name: pd.array(values, dtype=dtype)
            for name, (values, dtype) in column_specs.items()
        }
    )
    path = TEST_DATA_PATH / "test_dataframe_dtypes.zanj"

    z = ZANJ()
    z.save({"df": df}, path)
    loaded_df = z.read(path)["df"]

    assert isinstance(loaded_df, pd.DataFrame)

    for col in df.columns:
        expected = df[col].tolist()
        actual = loaded_df[col].tolist()
        assert expected == actual, (
            f"Column {col} values don't match: {expected} != {actual}"
        )


def test_dataframe_with_nan_values():
    """Test DataFrame containing NaN and None values.

    In a numeric column pandas stores both ``np.nan`` and ``None`` as NaN (the
    column becomes float64), so both missing-value columns are checked with
    ``pd.isna``. The fully populated ``normal`` column must come back with
    its values unchanged.
    """
    df = pd.DataFrame(
        {
            "with_nan": [1.0, np.nan, 3.0],
            "with_none": [1, None, 3],
            "normal": [1, 2, 3],
        }
    )

    z = ZANJ()
    path = TEST_DATA_PATH / "test_dataframe_nan.zanj"
    z.save({"df": df}, path)
    recovered = z.read(path)

    assert isinstance(recovered["df"], pd.DataFrame)
    loaded_df = recovered["df"]

    # NaN != NaN, so missing values are checked with isna(); both columns
    # (np.nan and None inputs) must keep their missing entry and neighbours
    for col in ("with_nan", "with_none"):
        assert pd.isna(loaded_df[col].iloc[1]), f"{col}: NaN should be preserved"
        assert loaded_df[col].iloc[0] == 1.0
        assert loaded_df[col].iloc[2] == 3.0

    # the column without missing values must be unaffected
    assert loaded_df["normal"].tolist() == df["normal"].tolist()


def test_dataframe_special_column_names():
    """Column names with spaces, dashes, leading digits and punctuation are kept."""
    df = pd.DataFrame(
        {
            "normal_name": [1, 2],
            "with spaces": [3, 4],
            "with-dashes": [5, 6],
            "123_numeric_start": [7, 8],
            "special!@#chars": [9, 10],
        }
    )
    path = TEST_DATA_PATH / "test_dataframe_special_cols.zanj"

    z = ZANJ()
    z.save({"df": df}, path)
    loaded_df = z.read(path)["df"]

    assert isinstance(loaded_df, pd.DataFrame)
    assert list(loaded_df.columns) == list(df.columns), (
        "Special column names should be preserved"
    )
    assert df.equals(loaded_df), "DataFrames should be equal"

``````{ end_of_file="tests/unit/no_torch/test_dataframe_serialization.py" }

``````{ path="tests/unit/no_torch/test_isolate_zanj_handler_store.py"  }
from __future__ import annotations

import json
import typing
import zipfile
from pathlib import Path

import numpy as np
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

from zanj import ZANJ
from zanj.loading import LOADER_MAP

# Fixed global numpy seed.
# NOTE(review): nothing visible in this module draws random numbers, so the
# seed looks vestigial — confirm before relying on it.
np.random.seed(0)

# pylint: disable=missing-function-docstring,missing-class-docstring

# Scratch directory for the .zanj files these tests write. It is a relative
# path, so it resolves against the directory pytest is launched from.
TEST_DATA_PATH: Path = Path("tests/junk_data")


@serializable_dataclass
class Basic(SerializableDataclass):
    """Simple fixture dataclass.

    test_isolate_handlers expects its loader key, "Basic(SerializableDataclass)",
    to be in LOADER_MAP even though that test never saves a Basic.
    """

    a: str
    q: int = 42
    c: typing.List[int] = serializable_field(default_factory=list)


def test_Basic():
    """A Basic instance round-trips through a .zanj file unchanged."""
    path = TEST_DATA_PATH / "test_Basic.zanj"
    original = Basic("hello", 42, [1, 2, 3])

    zanj_obj = ZANJ()
    zanj_obj.save(original, path)
    loaded = zanj_obj.read(path)

    assert original == loaded


# debug output at import time: the loader keys registered before ModelCfg is defined
print(list(LOADER_MAP.keys()))


@serializable_dataclass
class ModelCfg(SerializableDataclass):
    """Config-like fixture dataclass saved and reloaded by test_isolate_handlers."""

    name: str
    num_layers: int
    hidden_size: int
    dropout: float


# debug output at import time: the loader keys registered once ModelCfg is defined
print(list(LOADER_MAP.keys()))


def test_isolate_handlers():
    """Loader handlers for both Basic and ModelCfg are registered and persisted.

    Only a ModelCfg is saved, yet both loader keys must be present in the
    global LOADER_MAP and in the ``load_handlers`` recorded in the file's
    ``__zanj_meta__.json``.
    """
    instance = ModelCfg("lstm", 3, 128, 0.1)

    print(list(LOADER_MAP.keys()))

    z = ZANJ()
    path = TEST_DATA_PATH / "00-test_isolate_handlers.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered

    assert "Basic(SerializableDataclass)" in LOADER_MAP
    assert "ModelCfg(SerializableDataclass)" in LOADER_MAP

    # check they are in the zanj file; the archive member handle is closed
    # explicitly instead of being left for the garbage collector
    with zipfile.ZipFile(path, "r") as zfile:
        with zfile.open("__zanj_meta__.json", "r") as meta_file:
            zmeta = json.load(meta_file)

    load_handlers = zmeta["zanj_cfg"]["load_handlers"]
    assert "Basic(SerializableDataclass)" in load_handlers
    assert "ModelCfg(SerializableDataclass)" in load_handlers


# allow running this module directly (outside pytest) to execute test_isolate_handlers
if __name__ == "__main__":
    test_isolate_handlers()

``````{ end_of_file="tests/unit/no_torch/test_isolate_zanj_handler_store.py" }

``````{ path="tests/unit/no_torch/test_load_item_recursive.py"  }
from __future__ import annotations

import typing
from pathlib import Path

import numpy as np
import pytest
from muutils.errormode import ErrorMode
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from zanj import ZANJ
from zanj.consts import _FORMAT_KEY
from zanj.loading import LoadedZANJ, load_item_recursive

# Scratch directory for the .zanj files these tests write. It is a relative
# path, so it resolves against the directory pytest is launched from.
TEST_DATA_PATH: Path = Path("tests/junk_data")


def test_load_item_recursive_basic():
    """Plain JSON with no format keys comes back equal to the input."""
    json_data = {
        "name": "test",
        "value": 42,
        "list": [1, 2, 3],
        "nested": {"a": 1, "b": 2},
    }

    result = load_item_recursive(json_data, tuple(), None)

    assert result == json_data
    expected_fields = [
        ("name", "test"),
        ("value", 42),
        ("list", [1, 2, 3]),
        ("nested", {"a": 1, "b": 2}),
    ]
    for key, expected in expected_fields:
        assert result[key] == expected


def test_load_item_recursive_numpy_array():
    """A dict tagged ``numpy.ndarray:array_list_meta`` is rebuilt as an ndarray."""
    source = np.random.rand(5, 5)
    json_data = {
        # the format suffix must match the array_list_meta loader
        _FORMAT_KEY: "numpy.ndarray:array_list_meta",
        "dtype": str(source.dtype),
        "shape": list(source.shape),
        "data": source.tolist(),
    }

    result = load_item_recursive(json_data, tuple(), None)

    expected_shape = tuple(json_data["shape"])
    expected_dtype = np.dtype(json_data["dtype"])
    assert isinstance(result, np.ndarray)
    assert result.shape == expected_shape
    assert result.dtype == expected_dtype
    assert np.allclose(result, source)


def test_load_item_recursive_serializable_dataclass():
    """A serialized SerializableDataclass is loaded back as an instance of its class."""

    @serializable_dataclass
    class TestClass(SerializableDataclass):
        name: str
        value: int
        data: typing.List[int] = serializable_field(default_factory=list)

    serialized = TestClass("test", 42, [1, 2, 3]).serialize()

    result = load_item_recursive(serialized, tuple(), None)

    assert isinstance(result, TestClass)
    assert result.name == "test"
    assert result.value == 42
    assert result.data == [1, 2, 3]


def test_load_item_recursive_nested_container():
    """Arrays nested inside lists and dicts are each loaded as ndarrays."""

    def array_json(*shape):
        # JSON form of a fresh random float64 array in array_list_meta format
        return {
            _FORMAT_KEY: "numpy.ndarray:array_list_meta",
            "dtype": "float64",
            "shape": list(shape),
            "data": np.random.rand(*shape).tolist(),
        }

    json_data = {
        "name": "test",
        "arrays": [array_json(3, 3), array_json(2, 2)],
        "nested": {"dict_with_array": array_json(4, 4)},
    }

    result = load_item_recursive(json_data, tuple(), None)

    arrays = result["arrays"]
    nested_array = result["nested"]["dict_with_array"]
    assert result["name"] == "test"
    assert len(arrays) == 2
    assert isinstance(arrays[0], np.ndarray)
    assert isinstance(arrays[1], np.ndarray)
    assert arrays[0].shape == (3, 3)
    assert arrays[1].shape == (2, 2)
    assert isinstance(nested_array, np.ndarray)
    assert nested_array.shape == (4, 4)


def test_load_item_recursive_unknown_format():
    """An unregistered format key passes through untouched when loading is optional."""
    json_data = {
        _FORMAT_KEY: "unknown.format.that.definitely.does.not.exist",
        "data": [1, 2, 3],
    }

    # with allow_not_loading=True the item comes back as the raw JSON
    passthrough = load_item_recursive(json_data, tuple(), None, allow_not_loading=True)
    assert passthrough == json_data

    # TODO: this doesn't raise any errors
    # With allow_not_loading=False and a ZANJ in EXCEPT error mode this was
    # meant to raise, but it currently does not. The call is kept, so any
    # exception it does raise still fails the test.
    strict_zanj = ZANJ(error_mode=ErrorMode.EXCEPT)
    load_item_recursive(
        json_data,
        tuple(),
        strict_zanj,
        error_mode=ErrorMode.EXCEPT,
        allow_not_loading=False,
    )


def test_load_item_recursive_with_external_reference():
    """An array stored as an external is discovered and loads back correctly.

    A 20x20 array is saved with ``external_array_threshold=10``. The test checks
    that ``LoadedZANJ.populate_externals`` finds at least one external, and that
    a full ``ZANJ.read`` returns the original values.
    """
    z = ZANJ(external_array_threshold=10)
    data = {"large_array": np.random.rand(20, 20)}
    path = TEST_DATA_PATH / "test_load_item_recursive_external.zanj"
    z.save(data, path)

    # inspect the raw archive contents
    loaded_zanj = LoadedZANJ(path, z)
    loaded_zanj.populate_externals()

    # the large array must have been written (and found) as an external
    assert len(loaded_zanj._externals) > 0

    # the saved top-level object was a dict, so the loaded JSON must be one too
    assert isinstance(loaded_zanj._json_data, dict)

    # full round-trip: the externally stored array must come back intact
    recovered = z.read(path)
    assert isinstance(recovered["large_array"], np.ndarray)
    assert np.allclose(recovered["large_array"], data["large_array"])


def test_load_item_recursive_error_modes():
    """WARN and IGNORE pass unknown formats through; EXCEPT propagates loader errors."""
    json_data = {
        _FORMAT_KEY: "unknown.format.that.definitely.does.not.exist",
        "data": [1, 2, 3],
    }

    # lenient modes: no exception, and the raw data comes back unchanged
    for lenient_mode in (ErrorMode.WARN, ErrorMode.IGNORE):
        result = load_item_recursive(
            json_data,
            tuple(),
            None,
            error_mode=lenient_mode,
            allow_not_loading=True,
        )
        assert result == json_data

    class CustomHandler:
        """Fake loader that claims the unknown format and always fails to load it."""

        def check(self, json_item, path=None, z=None):
            return (
                json_item.get(_FORMAT_KEY)
                == "unknown.format.that.definitely.does.not.exist"
            )

        def load(self, json_item, path=None, z=None):
            raise ValueError("Forced error for testing purposes")

    import zanj.loading

    saved_get_item_loader = zanj.loading.get_item_loader

    def always_custom_handler(*args, **kwargs):
        # every loader lookup resolves to the failing handler
        return CustomHandler()

    try:
        zanj.loading.get_item_loader = always_custom_handler

        # strict mode: the handler's ValueError must propagate to the caller
        with pytest.raises(ValueError):
            load_item_recursive(
                json_data,
                tuple(),
                None,
                error_mode=ErrorMode.EXCEPT,
                allow_not_loading=True,
            )
    finally:
        # always restore the real lookup so other tests are unaffected
        zanj.loading.get_item_loader = saved_get_item_loader

``````{ end_of_file="tests/unit/no_torch/test_load_item_recursive.py" }

``````{ path="tests/unit/no_torch/test_loading_edge_cases.py"  }
"""Edge case tests for zanj/loading.py to improve coverage."""

from __future__ import annotations

import pytest

from zanj.consts import _FORMAT_KEY
from zanj.loading import (
    LoaderHandler,
    _populate_externals_error_checking,
    get_item_loader,
    load_item_recursive,
)


class TestPopulateExternalsErrorChecking:
    """Error paths and return values of ``_populate_externals_error_checking``."""

    def test_external_item_missing_data_field(self):
        """An item tagged as external but lacking a ``data`` field raises KeyError."""
        item_without_data = {
            _FORMAT_KEY: "list:external",
            "some_field": "value",
        }
        with pytest.raises(KeyError, match="expected an external item"):
            _populate_externals_error_checking("key", item_without_data)

    def test_sequence_with_non_int_key(self):
        """Indexing a list with a str key raises TypeError."""
        with pytest.raises(TypeError, match="expected int"):
            _populate_externals_error_checking("string_key", [1, 2, 3])

    def test_sequence_with_out_of_range_key(self):
        """Indexing a list past its end raises IndexError."""
        with pytest.raises(IndexError, match="index out of range"):
            _populate_externals_error_checking(100, [1, 2, 3])

    def test_mapping_with_non_str_key(self):
        """Indexing a dict with a non-str key raises TypeError."""
        with pytest.raises(TypeError, match="expected str"):
            _populate_externals_error_checking(123, {"a": 1, "b": 2})

    def test_mapping_with_missing_key(self):
        """Indexing a dict with an absent key raises KeyError."""
        with pytest.raises(KeyError, match="key not in dict"):
            _populate_externals_error_checking("missing_key", {"a": 1, "b": 2})

    def test_invalid_item_type(self):
        """A container that is neither a sequence nor a mapping raises TypeError."""
        not_a_container = 42
        with pytest.raises(TypeError, match="expected dict or list"):
            _populate_externals_error_checking("key", not_a_container)

    def test_valid_sequence_access(self):
        """A valid list index is accepted and reported as non-external (False)."""
        assert _populate_externals_error_checking(1, [1, 2, 3]) is False

    def test_valid_mapping_access(self):
        """A valid dict key is accepted and reported as non-external (False)."""
        assert _populate_externals_error_checking("a", {"a": 1, "b": 2}) is False

    def test_external_item_with_data_returns_true(self):
        """A well-formed external item (format key plus data) yields True."""
        well_formed = {
            _FORMAT_KEY: "list:external",
            "data": [1, 2, 3],
        }
        assert _populate_externals_error_checking("key", well_formed) is True


class TestLoaderHandlerFromFormattedClass:
    """Validation and construction in ``LoaderHandler.from_formattedclass``."""

    @staticmethod
    def _expect_rejected(candidate):
        # every malformed class must be refused with an AssertionError
        with pytest.raises(AssertionError):
            LoaderHandler.from_formattedclass(candidate)

    def test_missing_serialize_method(self):
        """A class without ``serialize`` is rejected."""

        class MissingSerialize:
            __muutils_format__ = "test"

            @classmethod
            def load(cls, json_item, path=None, z=None):
                pass

        self._expect_rejected(MissingSerialize)

    def test_missing_load_method(self):
        """A class without ``load`` is rejected."""

        class MissingLoad:
            __muutils_format__ = "test"

            def serialize(self):
                pass

        self._expect_rejected(MissingLoad)

    def test_missing_format_key(self):
        """A class without ``__muutils_format__`` is rejected."""

        class MissingFormat:
            def serialize(self):
                pass

            @classmethod
            def load(cls, json_item, path=None, z=None):
                pass

        self._expect_rejected(MissingFormat)

    def test_non_string_format_key(self):
        """A class whose ``__muutils_format__`` is not a str is rejected."""

        class NonStringFormat:
            __muutils_format__ = 12345  # deliberately not a string

            def serialize(self):
                pass

            @classmethod
            def load(cls, json_item, path=None, z=None):
                pass

        self._expect_rejected(NonStringFormat)

    def test_valid_formatted_class(self):
        """A well-formed class yields a handler keyed by its format string."""

        class ValidClass:
            __muutils_format__ = "test.ValidClass"

            def serialize(self):
                return {}

            @classmethod
            def load(cls, json_item, path=None, z=None):
                return cls()

        handler = LoaderHandler.from_formattedclass(ValidClass)

        assert handler.uid == "test.ValidClass"
        assert "ValidClass" in handler.desc


class TestGetItemLoader:
    """Input validation in ``get_item_loader``."""

    def test_non_string_format_key(self):
        """A non-str ``__muutils_format__`` value raises TypeError."""
        bad_format_value = 12345  # format keys must be strings
        malformed_item = {_FORMAT_KEY: bad_format_value, "data": [1, 2, 3]}

        with pytest.raises(TypeError, match="invalid __muutils_format__ type"):
            get_item_loader(malformed_item, ("test",))


class TestLoadItemRecursive:
    """How ``load_item_recursive`` treats objects no loader understands."""

    def test_unloadable_type_strict_mode(self):
        """With allow_not_loading=False an unknown object type raises ValueError."""

        class CustomUnloadable:
            pass

        with pytest.raises(ValueError, match="unknown type"):
            load_item_recursive(
                CustomUnloadable(), ("test",), None, allow_not_loading=False
            )

    def test_unloadable_type_permissive_mode(self):
        """With allow_not_loading=True the very same object is returned."""

        class CustomUnloadable:
            pass

        unloadable = CustomUnloadable()
        returned = load_item_recursive(
            unloadable, ("test",), None, allow_not_loading=True
        )
        assert returned is unloadable

``````{ end_of_file="tests/unit/no_torch/test_loading_edge_cases.py" }

``````{ path="tests/unit/no_torch/test_serializing_edge_cases.py"  }
"""Edge case tests for zanj/serializing.py to improve coverage."""

from __future__ import annotations

import numpy as np
import pytest

from zanj import ZANJ
from zanj.serializing import zanj_external_serialize


class TestZanjExternalSerialize:
    """Tests for zanj_external_serialize function edge cases."""

    def test_duplicate_external_path(self):
        """Lines 124-127: registering the same external path twice raises ValueError."""
        z = ZANJ()
        array1 = np.random.rand(10, 10)
        array2 = np.random.rand(5, 5)

        # first registration of ("data",) succeeds
        zanj_external_serialize(
            z, array1, path=("data",), item_type="npy", _format="numpy.ndarray:external"
        )

        # a second registration of the identical path is rejected
        with pytest.raises(ValueError, match="already exists"):
            zanj_external_serialize(
                z,
                array2,
                path=("data",),
                item_type="npy",
                _format="numpy.ndarray:external",
            )

    def test_path_prefix_conflict_child_first(self):
        """Lines 137-142: Path prefix conflict where new path is parent of existing."""
        z = ZANJ()
        array1 = np.random.rand(10, 10)
        array2 = np.random.rand(5, 5)

        # first: register the child path 'layer/1/weight'
        zanj_external_serialize(
            z,
            array1,
            path=("layer", "1", "weight"),
            item_type="npy",
            _format="numpy.ndarray:external",
        )

        # then: registering its parent 'layer/1' must conflict
        with pytest.raises(ValueError, match="is a prefix of another path"):
            zanj_external_serialize(
                z,
                array2,
                path=("layer", "1"),
                item_type="npy",
                _format="numpy.ndarray:external",
            )

    def test_path_prefix_conflict_parent_first(self):
        """Lines 134-136: Path prefix conflict where new path is child of existing."""
        z = ZANJ()
        array1 = np.random.rand(10, 10)
        array2 = np.random.rand(5, 5)

        # first: register the parent path 'layer/1'
        zanj_external_serialize(
            z,
            array1,
            path=("layer", "1"),
            item_type="npy",
            _format="numpy.ndarray:external",
        )

        # then: registering its child 'layer/1/weight' must conflict
        with pytest.raises(ValueError, match="is a prefix of another path"):
            zanj_external_serialize(
                z,
                array2,
                path=("layer", "1", "weight"),
                item_type="npy",
                _format="numpy.ndarray:external",
            )

    def test_invalid_npy_data_type(self):
        """Line 160: Invalid data type for NPY serialization should raise TypeError."""
        z = ZANJ()

        # a string is not a numpy array, so the npy writer must refuse it
        with pytest.raises(TypeError, match="expected numpy.ndarray"):
            zanj_external_serialize(
                z,
                data="not an array",
                path=("test",),
                item_type="npy",
                _format="numpy.ndarray:external",
            )

    def test_invalid_jsonl_data_type(self):
        """Line 184: Invalid data type for JSONL serialization should raise TypeError."""
        z = ZANJ()

        # a plain object that is neither a list nor a DataFrame
        class NotSerializableAsJsonl:
            pass

        obj = NotSerializableAsJsonl()

        with pytest.raises(TypeError, match="expected list or pandas.DataFrame"):
            zanj_external_serialize(
                z,
                data=obj,
                path=("test",),
                item_type="jsonl",
                _format="list:external",
            )

    def test_valid_npy_serialization(self):
        """A valid array returns a format-tagged stub and is registered as '<path>.npy'."""
        z = ZANJ()
        array = np.random.rand(10, 10)

        result = zanj_external_serialize(
            z,
            array,
            path=("valid_array",),
            item_type="npy",
            _format="numpy.ndarray:external",
        )

        assert "__muutils_format__" in result
        assert result["__muutils_format__"] == "numpy.ndarray:external"
        assert "valid_array.npy" in z._externals

    def test_valid_jsonl_list_serialization(self):
        """A valid list returns a format-tagged stub and is registered as '<path>.jsonl'."""
        z = ZANJ()
        data = [{"a": 1}, {"b": 2}, {"c": 3}]

        result = zanj_external_serialize(
            z,
            data,
            path=("valid_list",),
            item_type="jsonl",
            _format="list:external",
        )

        assert "__muutils_format__" in result
        assert result["__muutils_format__"] == "list:external"
        assert "valid_list.jsonl" in z._externals

    def test_non_overlapping_paths_allowed(self):
        """Paths sharing a string prefix, but not a path-component prefix, do not conflict."""
        z = ZANJ()
        array1 = np.random.rand(10, 10)
        array2 = np.random.rand(5, 5)

        # ("layer", "1") vs ("layer", "10"): the joined string "layer/1" is a
        # string prefix of "layer/10", but ("layer", "1") is not a component-wise
        # prefix of ("layer", "10"), so neither registration may raise
        zanj_external_serialize(
            z,
            array1,
            path=("layer", "1"),
            item_type="npy",
            _format="numpy.ndarray:external",
        )

        zanj_external_serialize(
            z,
            array2,
            path=("layer", "10"),
            item_type="npy",
            _format="numpy.ndarray:external",
        )

        # both externals are registered under their own file names
        assert "layer/1.npy" in z._externals
        assert "layer/10.npy" in z._externals

``````{ end_of_file="tests/unit/no_torch/test_serializing_edge_cases.py" }

``````{ path="tests/unit/no_torch/test_shared_prefix_keys.py"  }
from pathlib import Path
import typing
import numpy as np

import pytest

from zanj import ZANJ

_TEMP_PATH: Path = Path("tests/.temp/")


# NOTE: as of 2025-11-06 15:32 (v0.5.1), the first case (longer key first) fails while the second case (shorter key first) passes; the cause is not yet understood.


@pytest.mark.parametrize(
    ("keys", "name"),
    [
        (["layer.1.weight", "layer.1"], "longer_key_first"),
        (["layer.1", "layer.1.weight"], "shorter_key_first"),
    ],
)
def test_shared_prefix_keys(keys: typing.List[str], name: str):
    """Round-trip a dict whose keys share a dotted prefix, with every array stored externally."""
    fname: Path = _TEMP_PATH / f"shared_prefix_keys-{name}.zanj"

    # threshold 0 pushes every array out of the inline JSON
    data = dict((key, np.random.rand(10, 10)) for key in keys)
    ZANJ(external_array_threshold=0).save(data, fname)
    print("saved successfully")

    loaded = ZANJ().read(fname)
    assert set(data.keys()) == set(loaded.keys())

    for key, expected in data.items():
        actual = loaded[key]
        print(f"{key = }")
        print(f"{type(data[key]) = }")
        print(f"{data[key] = }")
        print(f"{type(loaded[key]) = }")
        print(f"{loaded[key] = }")
        assert type(actual) is type(expected)
        np.testing.assert_array_equal(expected, actual)

``````{ end_of_file="tests/unit/no_torch/test_shared_prefix_keys.py" }

``````{ path="tests/unit/no_torch/test_zanj_basic.py"  }
from __future__ import annotations

import json
import typing
from pathlib import Path

import numpy as np
import pandas as pd  # type: ignore

from zanj import ZANJ

# seed numpy's global RNG so the arrays generated by these tests are reproducible
np.random.seed(0)

# directory where the test archives are written
TEST_DATA_PATH: Path = Path("tests", "junk_data")


def array_meta(x: typing.Any) -> dict:
    """Summarize a value for debug printing.

    numpy arrays are described by shape, dtype and string contents; any other
    value by its type name and string contents.
    """
    text = str(x)
    if not isinstance(x, np.ndarray):
        return {"type": type(x).__name__, "contents": text}
    return {"shape": list(x.shape), "dtype": str(x.dtype), "contents": text}


def test_numpy():
    """A string and arrays of several sizes survive a ZANJ save/read round trip."""
    data = {
        "name": "testing zanj",
        "some_array": np.random.rand(128, 128),
        "some_other_array": np.random.rand(16, 64),
        "small_array": np.random.rand(4, 4),
    }
    fname: Path = TEST_DATA_PATH / "test_numpy.zanj"
    z: ZANJ = ZANJ()
    z.save(data, fname)
    recovered_data = z.read(fname)

    print(f"{list(data.keys()) = }")
    print(f"{list(recovered_data.keys()) = }")
    # dump a summary of the original values, then of the recovered ones
    for source in (data, recovered_data):
        summary: dict = {k: array_meta(v) for k, v in source.items()}
        print(json.dumps(summary, indent=2))

    assert sorted(data.keys()) == sorted(recovered_data.keys())
    # assert all([type(data[k]) == type(recovered_data[k]) for k in data.keys()])

    array_keys = ("some_array", "some_other_array", "small_array")
    checks = [data["name"] == recovered_data["name"]]
    checks.extend(np.allclose(data[k], recovered_data[k]) for k in array_keys)
    assert all(checks), f"assert failed:\n{data = }\n{recovered_data = }"


def test_jsonl():
    """Two pandas DataFrames, a string, and an array survive a ZANJ round trip."""
    data = dict(
        name="testing zanj jsonl",
        iris_data=pd.read_csv("tests/input_data/iris.csv"),
        brain_data=pd.read_csv("tests/input_data/brain_networks.csv"),
        some_array=np.random.rand(128, 128),
    )
    fname: Path = TEST_DATA_PATH / "test_jsonl.zanj"
    z: ZANJ = ZANJ()
    z.save(data, fname)
    recovered_data = z.read(fname)

    assert sorted(list(data.keys())) == sorted(list(recovered_data.keys()))
    # assert all([type(data[k]) == type(recovered_data[k]) for k in data.keys()])

    # one assert per value (instead of `assert all([...])`) so a failure
    # names the field that did not round-trip
    assert data["name"] == recovered_data["name"]
    assert np.allclose(data["some_array"], recovered_data["some_array"])
    assert data["iris_data"].equals(recovered_data["iris_data"]), "iris_data mismatch"
    assert data["brain_data"].equals(
        recovered_data["brain_data"]
    ), "brain_data mismatch"


def test_polars_dataframe():
    """A polars DataFrame with int/str/float columns round-trips alongside an array."""
    import polars as pl

    # basic dataframe with various types
    data = dict(
        name="testing zanj polars",
        df=pl.DataFrame(
            {
                "a": [1, 2, 3],
                "b": ["x", "y", "z"],
                "c": [1.1, 2.2, 3.3],
            }
        ),
        some_array=np.random.rand(128, 128),
    )
    fname: Path = TEST_DATA_PATH / "test_polars.zanj"
    z: ZANJ = ZANJ()
    z.save(data, fname)
    recovered_data = z.read(fname)

    assert sorted(list(data.keys())) == sorted(list(recovered_data.keys()))

    # separate asserts (instead of `assert all([...])`) so a failure
    # identifies which value did not round-trip
    assert data["name"] == recovered_data["name"]
    assert np.allclose(data["some_array"], recovered_data["some_array"])
    assert data["df"].equals(recovered_data["df"]), "polars DataFrame mismatch"


def test_polars_dataframe_empty():
    """An empty three-column polars DataFrame keeps its shape and column names."""
    import polars as pl

    empty = pl.DataFrame({col: [] for col in ("a", "b", "c")})
    data = {"name": "testing empty polars df", "empty_df": empty}
    fname: Path = TEST_DATA_PATH / "test_polars_empty.zanj"

    zanj_obj: ZANJ = ZANJ()
    zanj_obj.save(data, fname)
    recovered_data = zanj_obj.read(fname)

    assert data["name"] == recovered_data["name"]
    restored = recovered_data["empty_df"]
    assert restored.shape == (0, 3)
    assert restored.columns == ["a", "b", "c"]


def test_polars_dataframe_large():
    """A 1000-row polars DataFrame with int/float/str/bool columns round-trips exactly.

    The size is presumably large enough to exercise external storage.
    """
    import polars as pl

    n_rows = 1000
    indices = range(n_rows)
    columns = {
        "int_col": list(indices),
        "float_col": [float(i) * 0.1 for i in indices],
        "str_col": [f"row_{i}" for i in indices],
        "bool_col": [i % 2 == 0 for i in indices],
    }
    data = {"name": "testing large polars df", "large_df": pl.DataFrame(columns)}
    fname: Path = TEST_DATA_PATH / "test_polars_large.zanj"

    zanj_obj: ZANJ = ZANJ()
    zanj_obj.save(data, fname)
    recovered_data = zanj_obj.read(fname)

    assert data["name"] == recovered_data["name"]
    assert data["large_df"].equals(recovered_data["large_df"])


def test_polars_with_nulls():
    """Null entries in int, str, and float polars columns survive the round trip."""
    import polars as pl

    frame = pl.DataFrame(
        {
            "a": [1, None, 3],
            "b": ["x", "y", None],
            "c": [1.1, None, 3.3],
        }
    )
    data = {"name": "testing polars with nulls", "df_with_nulls": frame}
    fname: Path = TEST_DATA_PATH / "test_polars_nulls.zanj"

    zanj_obj: ZANJ = ZANJ()
    zanj_obj.save(data, fname)
    recovered_data = zanj_obj.read(fname)

    assert data["name"] == recovered_data["name"]
    assert data["df_with_nulls"].equals(recovered_data["df_with_nulls"])

``````{ end_of_file="tests/unit/no_torch/test_zanj_basic.py" }

``````{ path="tests/unit/no_torch/test_zanj_edge_cases.py"  }
from __future__ import annotations

import os
import zipfile
from pathlib import Path

import numpy as np
import pytest
from muutils.errormode import ErrorMode

from zanj import ZANJ

# directory where the test archives are written
TEST_DATA_PATH: Path = Path("tests", "junk_data")


def test_zanj_with_different_configs():
    """Round-trip the same data under default, low, and high `external_array_threshold`.

    The 50x50 array has 2500 elements. That is above the low threshold (10), so
    the array must be written as a separate `.npy` member of the archive. It is
    below the high threshold (10000), so it must stay inline.
    """
    data = {
        "name": "test_config",
        # 2500 elements: above the low (10) and default (256) thresholds,
        # below the high (10000) threshold
        "array": np.random.rand(50, 50),
    }

    def npy_members(zanj_path: Path) -> list:
        # names of external `.npy` files stored inside the zip archive
        with zipfile.ZipFile(zanj_path, "r") as archive:
            return [name for name in archive.namelist() if name.endswith(".npy")]

    # default config (external_array_threshold=256)
    z1 = ZANJ()
    path1 = TEST_DATA_PATH / "test_default_config.zanj"
    z1.save(data, path1)

    # low threshold: the array must be stored externally
    z2 = ZANJ(external_array_threshold=10)
    path2 = TEST_DATA_PATH / "test_low_threshold.zanj"
    z2.save(data, path2)

    # high threshold: the array must stay inline
    z3 = ZANJ(external_array_threshold=10000)
    path3 = TEST_DATA_PATH / "test_high_threshold.zanj"
    z3.save(data, path3)

    assert path1.exists()
    assert path2.exists()
    assert path3.exists()

    # the threshold actually controls where the array is stored
    assert npy_members(path2), "array above threshold was not stored externally"
    assert not npy_members(path3), "array below threshold was stored externally"

    # all three archives load back correctly
    data1 = z1.read(path1)
    data2 = z2.read(path2)
    data3 = z3.read(path3)

    assert data1["name"] == data["name"]
    assert data2["name"] == data["name"]
    assert data3["name"] == data["name"]

    assert np.allclose(data1["array"], data["array"])
    assert np.allclose(data2["array"], data["array"])
    assert np.allclose(data3["array"], data["array"])


def test_zanj_compression_options():
    """Archives written with compression on, off, and an explicit zipfile method all read back."""
    data = {
        "name": "compression_test",
        "array": np.random.rand(100, 100),
    }

    # (value passed as `compress`, output file name)
    cases = [
        (True, "test_default_compression.zanj"),  # True -> ZIP_DEFLATED
        (False, "test_no_compression.zanj"),  # stored uncompressed
        (zipfile.ZIP_DEFLATED, "test_explicit_compression.zanj"),  # explicit method constant
    ]

    written = []
    for compress, filename in cases:
        zanj_obj = ZANJ(compress=compress)
        out_path = TEST_DATA_PATH / filename
        zanj_obj.save(data, out_path)
        written.append((zanj_obj, out_path))

    for _, out_path in written:
        assert out_path.exists()

    # every variant must load back to the same contents
    for zanj_obj, out_path in written:
        loaded = zanj_obj.read(out_path)
        assert loaded["name"] == data["name"]
        assert np.allclose(loaded["array"], data["array"])


def test_zanj_error_modes():
    """With ErrorMode.EXCEPT, an exception raised during serialization propagates out of `save`."""

    # object whose repr raises, in case anything tries to stringify it
    class ForceExceptionOnSerialize:
        def __repr__(self):
            raise Exception("Forced exception during serialization")

    data = {
        "name": "error_test",
        "unserializable": ForceExceptionOnSerialize(),
    }

    # ZANJ subclass that fails deterministically on the top-level dict
    class ExceptionForcingZANJ(ZANJ):
        def json_serialize(self, obj, *args, **kwargs):
            # accept and forward any extra arguments (e.g. a `path`) so the
            # override is not narrower than the base method; otherwise a call
            # with extra args would raise TypeError and pass the test spuriously
            if isinstance(obj, dict) and "unserializable" in obj:
                raise Exception("Forced exception")
            return super().json_serialize(obj, *args, **kwargs)

    z_except = ExceptionForcingZANJ(error_mode=ErrorMode.EXCEPT)
    path_except = TEST_DATA_PATH / "test_error_except.zanj"
    # both forced messages contain "Forced exception"; matching on it keeps the
    # test from passing when `save` fails for an unrelated reason
    with pytest.raises(Exception, match="Forced exception"):
        z_except.save(data, path_except)


def test_zanj_array_modes():
    """Arrays round-trip under the "list" and "array_list_meta" internal array modes."""
    data = {
        "name": "array_mode_test",
        "array": np.random.rand(5, 5),
    }

    # internal_array_mode takes the mode's string value, not the enum attribute
    mode_to_file = {
        "list": "test_array_mode_list.zanj",
        "array_list_meta": "test_array_mode_array_list_meta.zanj",
    }

    saved = []
    for mode, filename in mode_to_file.items():
        zanj_obj = ZANJ(internal_array_mode=mode)
        out_path = TEST_DATA_PATH / filename
        zanj_obj.save(data, out_path)
        saved.append((zanj_obj, out_path))

    for zanj_obj, out_path in saved:
        loaded = zanj_obj.read(out_path)
        assert loaded["name"] == data["name"]
        assert np.allclose(loaded["array"], data["array"])


def test_zanj_meta():
    """`ZANJ.meta()` exposes config, system info, externals info, and a timestamp."""
    meta = ZANJ().meta()

    for top_key in ("zanj_cfg", "sysinfo", "externals_info", "timestamp"):
        assert top_key in meta

    # the config section lists every configurable option
    cfg = meta["zanj_cfg"]
    for cfg_key in (
        "error_mode",
        "array_mode",
        "external_array_threshold",
        "external_list_threshold",
        "compress",
        "serialization_handlers",
        "load_handlers",
    ):
        assert cfg_key in cfg


def test_zanj_externals_info():
    """Arrays above `external_array_threshold` are written as external archive members.

    Also checks that `save` leaves `_externals` empty afterwards and that the
    externally stored array reads back intact.
    """
    z = ZANJ(external_array_threshold=10)

    # 20x20 = 400 elements, well above the threshold of 10
    data = {
        "name": "externals_test",
        "array": np.random.rand(20, 20),
    }

    path = TEST_DATA_PATH / "test_externals_info.zanj"
    z.save(data, path)

    # no pending externals remain on the instance after saving
    assert len(z._externals) == 0

    # the array was written to its own `.npy` member of the zip archive
    with zipfile.ZipFile(path, "r") as archive:
        npy_members = [n for n in archive.namelist() if n.endswith(".npy")]
    assert len(npy_members) == 1, f"expected one external .npy member, got {npy_members}"

    # reading restores the externally stored array
    loaded_data = z.read(path)
    assert loaded_data["name"] == data["name"]
    assert np.allclose(loaded_data["array"], data["array"])


def test_zanj_save_extension():
    """`save` appends ".zanj" to a bare path and leaves an existing ".zanj" suffix alone."""
    data = {"name": "extension_test"}
    z = ZANJ()

    # case 1: no extension given, so ".zanj" is appended
    bare = str(TEST_DATA_PATH / "test_no_extension")
    saved_bare = z.save(data, bare)
    assert saved_bare.endswith(".zanj")
    assert saved_bare == bare + ".zanj"
    assert os.path.exists(saved_bare)
    assert z.read(saved_bare)["name"] == data["name"]

    # case 2: extension already present, so the path is returned unchanged
    with_ext = str(TEST_DATA_PATH / "test_with_extension.zanj")
    saved_with_ext = z.save(data, with_ext)
    assert saved_with_ext == with_ext
    assert not saved_with_ext.endswith(".zanj.zanj")
    assert os.path.exists(saved_with_ext)
    assert z.read(saved_with_ext)["name"] == data["name"]


def test_zanj_file_not_found():
    """Reading a missing path, or a directory, raises FileNotFoundError."""
    z = ZANJ()

    # a path that does not exist
    non_existent_path = TEST_DATA_PATH / "non_existent_file.zanj"
    with pytest.raises(FileNotFoundError):
        z.read(non_existent_path)

    # a directory is not a readable archive file either.
    # parents=True: TEST_DATA_PATH itself may not exist yet on a fresh checkout,
    # and without it mkdir would raise FileNotFoundError outside the check below
    dir_path = TEST_DATA_PATH / "test_dir"
    dir_path.mkdir(parents=True, exist_ok=True)

    with pytest.raises(FileNotFoundError):
        z.read(dir_path)


def test_zanj_create_directory():
    """`save` creates any missing parent directories of the target path."""
    import shutil

    data = {"name": "dir_test"}

    root = TEST_DATA_PATH / "new_dir"
    nested_dir = root / "nested" / "structure"
    nested_path = nested_dir / "test_file.zanj"

    # start from a clean slate so the directories really have to be created
    if root.exists():
        shutil.rmtree(root)

    z = ZANJ()
    z.save(data, nested_path)

    assert nested_dir.exists() and nested_dir.is_dir()
    assert nested_path.exists() and nested_path.is_file()

    assert z.read(nested_path)["name"] == data["name"]

``````{ end_of_file="tests/unit/no_torch/test_zanj_edge_cases.py" }

``````{ path="tests/unit/no_torch/test_zanj_populate_nested.py"  }
from __future__ import annotations

import typing
from pathlib import Path

import numpy as np
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

from zanj import ZANJ

# seed numpy's global RNG so the arrays generated by these tests are reproducible
np.random.seed(0)

# pylint: disable=missing-function-docstring,missing-class-docstring

TEST_DATA_PATH: Path = Path("tests", "junk_data")


@serializable_dataclass
class InnerClassWithArray(SerializableDataclass):
    """Leaf record holding a string and a numpy array; used as a list element in `OuterClassWithNestedList`."""

    some_string: str
    arr_numbers: np.ndarray


@serializable_dataclass
class OuterClassWithNestedList(SerializableDataclass):
    """Container whose `lst_basic` list is (de)serialized by hand-written functions."""

    name: str
    # each inner record is serialized individually. `loading_fn` indexes
    # x["lst_basic"], i.e. it is handed the serialized parent dict rather than
    # the field's own value, and rebuilds each inner record with `.load`
    lst_basic: typing.List[InnerClassWithArray] = serializable_field(
        serialization_fn=lambda x: [b.serialize() for b in x],
        loading_fn=lambda x: [InnerClassWithArray.load(b) for b in x["lst_basic"]],
    )


def test_nested_populate():
    """A list of array-holding dataclasses round-trips with low external thresholds."""
    inner_items = [
        InnerClassWithArray(some_string=f"hello_{i}", arr_numbers=np.random.rand(20))
        for i in range(20)
    ]
    instance = OuterClassWithNestedList(name="hello", lst_basic=inner_items)

    # low thresholds so the 20-element arrays (and presumably the 20-item list)
    # are stored outside the inline JSON
    zanj_obj = ZANJ(external_array_threshold=10, external_list_threshold=10)
    path = TEST_DATA_PATH / "test_nested_populate.zanj"
    # discard any archive left over from a previous run
    path.unlink(missing_ok=True)
    zanj_obj.save(instance, path)
    assert instance == zanj_obj.read(path)

``````{ end_of_file="tests/unit/no_torch/test_zanj_populate_nested.py" }

``````{ path="tests/unit/no_torch/test_zanj_serializable_dataclass.py"  }
from __future__ import annotations

import json
import sys
import typing
from pathlib import Path

import numpy as np
import pandas as pd  # type: ignore[import]
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

from zanj import ZANJ

# reproducible random test data
np.random.seed(0)

TEST_DATA_PATH: Path = Path("tests", "junk_data")

# dataclass `kw_only` is available only from Python 3.10 onward
SUPPORTS_KW_ONLY: bool = sys.version_info[:2] >= (3, 10)


@serializable_dataclass
class BasicZanj(SerializableDataclass):
    """Flat dataclass: a required str, an int defaulting to 42, and an int list."""

    a: str
    q: int = 42
    # default_factory gives each instance its own fresh list
    c: typing.List[int] = serializable_field(default_factory=list)


def test_Basic():
    """A flat dataclass with str/int/list fields round-trips to an equal instance."""
    original = BasicZanj("hello", 42, [1, 2, 3])
    save_path = TEST_DATA_PATH / "test_BasicZanj.zanj"

    zj = ZANJ()
    zj.save(original, save_path)
    assert original == zj.read(save_path)


@serializable_dataclass
class Nested(SerializableDataclass):
    """Dataclass embedding a `BasicZanj` instance as one of its fields."""

    name: str
    basic: BasicZanj
    val: float


def test_Nested():
    """A dataclass containing another dataclass round-trips to an equal instance."""
    inner = BasicZanj("hello", 42, [1, 2, 3])
    original = Nested("hello", inner, 3.14)
    save_path = TEST_DATA_PATH / "test_Nested.zanj"

    zj = ZANJ()
    zj.save(original, save_path)
    assert original == zj.read(save_path)


@serializable_dataclass
class Nested_with_container(SerializableDataclass):
    """Dataclass with a nested `BasicZanj` plus a list of `Nested` records."""

    name: str
    basic: BasicZanj
    val: float
    # default_factory gives each instance its own fresh list
    container: typing.List[Nested] = serializable_field(default_factory=list)


def test_Nested_with_container():
    """A dataclass holding a list of nested dataclasses round-trips to an equal instance."""
    first = Nested("n1", BasicZanj("n1_b", 123, [4, 5, 7]), 2.71)
    second = Nested("n2", BasicZanj("n2_b", 456, [7, 8, 9]), 6.28)
    original = Nested_with_container(
        "hello",
        basic=BasicZanj("hello", 42, [1, 2, 3]),
        val=3.14,
        container=[first, second],
    )
    save_path = TEST_DATA_PATH / "test_Nested_with_container.zanj"

    zj = ZANJ()
    zj.save(original, save_path)
    assert original == zj.read(save_path)


@serializable_dataclass
class sdc_with_np_array(SerializableDataclass):
    """Dataclass with a name and two numpy array fields."""

    name: str
    arr1: np.ndarray
    arr2: np.ndarray


def test_sdc_with_np_array_small():
    """Small numpy arrays inside a serializable dataclass round-trip to an equal instance."""
    instance = sdc_with_np_array("small arrays", np.random.rand(10), np.random.rand(20))

    z = ZANJ()
    # own file name, distinct from the one `test_sdc_with_np_array` writes, so the
    # two tests cannot overwrite each other's archive (e.g. under parallel runs)
    path = TEST_DATA_PATH / "test_sdc_with_np_array_small.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered


def test_sdc_with_np_array():
    """Larger 2-D numpy arrays inside a serializable dataclass round-trip to an equal instance."""
    first = np.random.rand(128, 128)
    second = np.random.rand(256, 256)
    original = sdc_with_np_array("bigger arrays", first, second)
    save_path = TEST_DATA_PATH / "test_sdc_with_np_array.zanj"

    zj = ZANJ()
    zj.save(original, save_path)
    assert original == zj.read(save_path)


@serializable_dataclass
class sdc_with_df(SerializableDataclass):
    """Dataclass with a name and two pandas DataFrame fields."""

    name: str
    iris_data: pd.DataFrame
    brain_data: pd.DataFrame


def test_sdc_with_df():
    """pandas DataFrames inside a serializable dataclass round-trip to an equal instance."""
    iris = pd.read_csv("tests/input_data/iris.csv")
    brain = pd.read_csv("tests/input_data/brain_networks.csv")
    original = sdc_with_df("downloaded_data", iris_data=iris, brain_data=brain)
    save_path = TEST_DATA_PATH / "test_sdc_with_df.zanj"

    zj = ZANJ()
    zj.save(original, save_path)
    assert original == zj.read(save_path)


@serializable_dataclass
class sdc_container_explicit(SerializableDataclass):
    """Dataclass whose list of `Nested` is serialized by hand as a newline-joined JSON string."""

    name: str
    container: typing.List[Nested] = serializable_field(
        default_factory=list,
        # as jsonl string, for whatever reason: one JSON document per line
        serialization_fn=lambda c: "\n".join([json.dumps(n.serialize()) for n in c]),
        # `loading_fn` indexes data["container"], i.e. it receives the serialized
        # parent dict. Empty lines are skipped: an empty container serializes to
        # "", and "".split("\n") == [""] would otherwise hit json.loads("")
        loading_fn=lambda data: [
            Nested.load(json.loads(n)) for n in data["container"].split("\n") if n
        ],
        # TODO: explicitly specifying the following does not work, since it gets automatically converted before we call load in `loading_fn`:
        # serialization_fn=lambda c: [n.serialize() for n in c],
        # loading_fn=lambda data: [Nested.load(n) for n in data["container"]],
    )


def test_sdc_container_explicit():
    """A container field serialized through custom string (de)serializers round-trips."""
    items = []
    for n in range(10):
        inner = BasicZanj(f"n-{n}_b", n * 10 + 1, [n + 1, n + 2, n + 10])
        items.append(Nested(f"n-{n}", inner, n * np.pi))
    original = sdc_container_explicit("container explicit", container=items)
    save_path = TEST_DATA_PATH / "test_sdc_container_explicit.zanj"

    zj = ZANJ()
    zj.save(original, save_path)
    assert original == zj.read(save_path)

``````{ end_of_file="tests/unit/no_torch/test_zanj_serializable_dataclass.py" }

``````{ path="tests/unit/with_torch/test_bool_array_torch.py"  }
from pathlib import Path

import torch  # type: ignore[import-not-found]
from muutils.json_serialize import SerializableDataclass, serializable_dataclass

from zanj import ZANJ

# directory where the test archives are written
TEST_DATA_PATH: Path = Path("tests", "junk_data")


@serializable_dataclass
class MyClass_torch(SerializableDataclass):
    """Dataclass with a name and two torch tensor fields (used with bool tensors below)."""

    name: str
    arr_1: torch.Tensor
    arr_2: torch.Tensor


def test_torch_bool_array():
    """Boolean torch tensors keep dtype torch.bool through a ZANJ round trip."""
    fname: Path = TEST_DATA_PATH / "test_torch_bool_array.zanj"
    flags = [True, False, True]
    c: MyClass_torch = MyClass_torch(
        name="test",
        arr_1=torch.tensor(flags),
        arr_2=torch.tensor(flags),
    )

    z = ZANJ()
    z.save(c, fname)
    c2: MyClass_torch = z.read(fname)

    for restored in (c2.arr_1, c2.arr_2):
        assert restored.dtype == torch.bool

    assert c == c2

``````{ end_of_file="tests/unit/with_torch/test_bool_array_torch.py" }

``````{ path="tests/unit/with_torch/test_get_module_device.py"  }
from __future__ import annotations

import pytest
import torch  # type: ignore[import-not-found]

from zanj.torchutil import get_module_device


def test_get_module_device_single_device():
    """A module whose parameters all share one device reports (True, that device)."""
    target = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    model = torch.nn.Linear(10, 2)
    model.to(target)

    all_same, result = get_module_device(model)

    assert all_same
    assert result == target


def test_get_module_device_multiple_devices():
    """Parameters on two different devices yield (False, {param_name: device}).

    Only the "meta" and "cpu" devices are used, so no GPU is required. The test
    is skipped only when this torch build cannot create meta tensors.
    """
    try:
        torch.empty(1, device="meta")
    except (RuntimeError, TypeError, NotImplementedError):
        pytest.skip("this torch build does not support the 'meta' device")

    with torch.no_grad():
        model = torch.nn.Linear(10, 2)
        print(f"{model = }")
        # put the weight on "meta" and the bias on "cpu" so devices differ
        model.weight = torch.nn.Parameter(model.weight.to("meta"))
        model.bias = torch.nn.Parameter(model.bias.to("cpu"))

    print(f"{model = }")
    print(f"{model.weight = }")
    print(f"{model.bias = }")

    is_single, device_or_dict = get_module_device(model)

    print(f"{is_single = }, {device_or_dict = }")

    # mixed devices: not a single device, and a per-parameter dict is returned
    assert not is_single
    assert isinstance(device_or_dict, dict)

    assert device_or_dict["weight"] == torch.device("meta")
    assert device_or_dict["bias"] == torch.device("cpu")


def test_get_module_device_no_parameters():
    """A module with no parameters reports (False, {})."""
    is_single, device_or_dict = get_module_device(torch.nn.Sequential())

    assert not is_single
    assert device_or_dict == {}

``````{ end_of_file="tests/unit/with_torch/test_get_module_device.py" }

``````{ path="tests/unit/with_torch/test_sdc_torch.py"  }
from __future__ import annotations

import sys
import typing
from pathlib import Path

import numpy as np
import pandas as pd  # type: ignore[import]
import torch  # type: ignore[import-not-found]
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

from zanj import ZANJ

# reproducible random test data
np.random.seed(0)

TEST_DATA_PATH: Path = Path("tests", "junk_data")

# dataclass `kw_only` is available only from Python 3.10 onward
SUPPORTS_KW_ONLY: bool = sys.version_info[:2] >= (3, 10)


@serializable_dataclass
class BasicZanjTorch(SerializableDataclass):
    """Flat dataclass: a required str, an int defaulting to 42, and an int list."""

    a: str
    q: int = 42
    # default_factory gives each instance its own fresh list
    c: typing.List[int] = serializable_field(default_factory=list)


@serializable_dataclass
class NestedTorch(SerializableDataclass):
    """Dataclass embedding a `BasicZanjTorch` instance as one of its fields."""

    name: str
    basic: BasicZanjTorch
    val: float


@serializable_dataclass
class sdc_with_torch_tensor(SerializableDataclass):
    """Dataclass with a name and two torch tensor fields."""

    name: str
    tensor1: torch.Tensor
    tensor2: torch.Tensor


def test_sdc_tensor_small():
    """Small 1-D torch tensors inside a dataclass round-trip to an equal instance."""
    first, second = torch.rand(8), torch.rand(16)
    original = sdc_with_torch_tensor("small tensors", first, second)
    save_path = TEST_DATA_PATH / "test_sdc_tensor_small.zanj"

    zj = ZANJ()
    zj.save(original, save_path)
    assert original == zj.read(save_path)


def test_sdc_tensor():
    """Larger 2-D torch tensors inside a dataclass round-trip to an equal instance."""
    first, second = torch.rand(128, 128), torch.rand(256, 256)
    original = sdc_with_torch_tensor("bigger tensors", first, second)
    save_path = TEST_DATA_PATH / "test_sdc_tensor.zanj"

    zj = ZANJ()
    zj.save(original, save_path)
    assert original == zj.read(save_path)


@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class sdc_complicated(SerializableDataclass):
    """Mixed-content dataclass: numpy arrays, pandas DataFrames, a list of dataclasses, and a torch tensor."""

    name: str
    arr1: np.ndarray
    arr2: np.ndarray
    iris_data: pd.DataFrame
    brain_data: pd.DataFrame
    container: typing.List[NestedTorch]

    tensor: torch.Tensor

    # delegates straight to the parent class's equality.
    # NOTE(review): presumably defined so a class-body __eq__ takes precedence
    # over a generated one; confirm whether it is still needed
    def __eq__(self, value):
        return super().__eq__(value)


def test_sdc_complicated():
    """Round-trip a kw-only dataclass mixing arrays, DataFrames, nested dataclasses and a tensor."""
    # built in the same order as the fields so RNG consumption is unchanged
    big_a = np.random.rand(128, 128)
    big_b = np.random.rand(256, 256)
    iris = pd.read_csv("tests/input_data/iris.csv")
    brain = pd.read_csv("tests/input_data/brain_networks.csv")

    nested = []
    for i in range(10):
        inner = BasicZanjTorch(f"n-{i}_b", i * 10 + 1, [i + 1, i + 2, i + 10])
        nested.append(NestedTorch(f"n-{i}", inner, i * np.pi))

    instance = sdc_complicated(
        name="complicated data",
        arr1=big_a,
        arr2=big_b,
        iris_data=iris,
        brain_data=brain,
        container=nested,
        tensor=torch.rand(512, 512),
    )

    zanj = ZANJ()
    archive_path = TEST_DATA_PATH / "test_sdc_complicated.zanj"
    zanj.save(instance, archive_path)

    assert instance == zanj.read(archive_path)

``````{ end_of_file="tests/unit/with_torch/test_sdc_torch.py" }

``````{ path="tests/unit/with_torch/test_torch_edge_cases.py"  }
from __future__ import annotations

from pathlib import Path

import pytest
import torch  # type: ignore[import-not-found]
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
)

from zanj import ZANJ
from zanj.torchutil import (
    ConfiguredModel,
    assert_model_exact_equality,
    get_module_device,
    num_params,
    set_config_class,
)

TEST_DATA_PATH: Path = Path("tests", "junk_data")  # scratch dir for written archives


def test_num_params():
    """`num_params` counts trainable parameters, and can optionally include frozen ones."""
    # Linear(10, 20): 10*20 weights + 20 biases = 220
    # Linear(20, 1):  20*1 weights + 1 bias    = 21
    mlp = torch.nn.Sequential(
        torch.nn.Linear(10, 20),
        torch.nn.ReLU(),
        torch.nn.Linear(20, 1),
    )
    assert num_params(mlp) == 220 + 21

    with_norm = torch.nn.Sequential(
        torch.nn.Linear(10, 20),
        torch.nn.BatchNorm1d(20, track_running_stats=True),
        torch.nn.Linear(20, 1),
    )
    # freeze the batch-norm layer: its weight and bias have 20 entries each
    for frozen in with_norm[1].parameters():
        frozen.requires_grad = False

    n_trainable = num_params(with_norm, only_trainable=True)
    n_all = num_params(with_norm, only_trainable=False)

    # the gap is exactly the frozen batch-norm parameters
    assert n_all - n_trainable == 2 * 20


def test_get_module_device_empty():
    """A custom module that registers no parameters reports no devices."""

    class EmptyModule(torch.nn.Module):
        # identity module; nothing is registered, so it owns no parameters
        def forward(self, x):
            return x

    is_single, device_dict = get_module_device(EmptyModule())

    # no parameters -> (False, {})
    assert is_single is False
    assert device_dict == {}


def test_load_state_dict_wrapper():
    """`ConfiguredModel.read` must route state-dict loading through `_load_state_dict_wrapper`."""

    @serializable_dataclass
    class SimpleConfig(SerializableDataclass):
        size: int

    @set_config_class(SimpleConfig)
    class SimpleModel(ConfiguredModel[SimpleConfig]):
        def __init__(self, config: SimpleConfig):
            super().__init__(config)
            self.linear = torch.nn.Linear(config.size, config.size)

        def forward(self, x):
            return self.linear(x)

        # override to record that loading went through this hook
        def _load_state_dict_wrapper(self, state_dict, **kwargs):
            self.custom_wrapper_called = True
            # still delegate so the weights are actually loaded
            return super()._load_state_dict_wrapper(state_dict)

    model = SimpleModel(SimpleConfig(10))
    path = TEST_DATA_PATH / "test_load_state_dict_wrapper.zanj"
    ZANJ().save(model, path)

    # `read` is called on the class and builds a fresh instance, so
    # attributes preset on some other instance would never be observed here
    loaded_model = SimpleModel.read(path)

    # the flag only exists if the hook ran during loading
    assert getattr(loaded_model, "custom_wrapper_called", False) is True
    # the hook delegated to the base implementation, so weights must match exactly
    assert_model_exact_equality(model, loaded_model)


def test_deprecated_load_file():
    """`load_file` still loads a saved model, but emits a DeprecationWarning."""

    @serializable_dataclass
    class SimpleConfig(SerializableDataclass):
        size: int

    @set_config_class(SimpleConfig)
    class SimpleModel(ConfiguredModel[SimpleConfig]):
        def __init__(self, config: SimpleConfig):
            super().__init__(config)
            self.linear = torch.nn.Linear(config.size, config.size)

        def forward(self, x):
            return self.linear(x)

    original = SimpleModel(SimpleConfig(10))
    archive = TEST_DATA_PATH / "test_deprecated_load_file.zanj"
    ZANJ().save(original, archive)

    with pytest.warns(DeprecationWarning):
        restored = SimpleModel.load_file(archive)

    # the deprecated entry point must still produce an exact copy
    assert_model_exact_equality(original, restored)


def test_configmodel_training_records():
    """`training_records` attached to a model survive a ZANJ save/read round trip."""

    @serializable_dataclass
    class SimpleConfig(SerializableDataclass):
        size: int

    @set_config_class(SimpleConfig)
    class SimpleModel(ConfiguredModel[SimpleConfig]):
        def __init__(self, config: SimpleConfig):
            super().__init__(config)
            self.linear = torch.nn.Linear(config.size, config.size)
            self.training_records = None  # filled in by the test below

        def forward(self, x):
            return self.linear(x)

    model = SimpleModel(SimpleConfig(10))
    model.training_records = {
        "loss": [1.0, 0.9, 0.8, 0.7],
        "accuracy": [0.6, 0.7, 0.8, 0.9],
        "learning_rate": [0.01, 0.005, 0.001, 0.0005],
    }

    archive = TEST_DATA_PATH / "test_training_records.zanj"
    ZANJ().save(model, archive)
    restored = SimpleModel.read(archive)

    assert restored.training_records == model.training_records

    # also compare against independent literals, not just the source object
    expected = {
        "loss": [1.0, 0.9, 0.8, 0.7],
        "accuracy": [0.6, 0.7, 0.8, 0.9],
        "learning_rate": [0.01, 0.005, 0.001, 0.0005],
    }
    for record_name, history in expected.items():
        assert restored.training_records[record_name] == history


def test_configmodel_with_custom_settings():
    """`custom_settings["_load_state_dict_wrapper"]` on a ZANJ is passed to the wrapper as kwargs."""

    @serializable_dataclass
    class SimpleConfig(SerializableDataclass):
        size: int

    @set_config_class(SimpleConfig)
    class CustomStateModel(ConfiguredModel[SimpleConfig]):
        def __init__(self, config: SimpleConfig):
            super().__init__(config)
            self.linear = torch.nn.Linear(config.size, config.size)
            self.custom_param_received = None

        def forward(self, x):
            return self.linear(x)

        def _load_state_dict_wrapper(self, state_dict, **kwargs):
            # remember what (if anything) arrived as `custom_param`
            received = kwargs.get("custom_param", None)
            self.custom_param_received = received
            return super(CustomStateModel, self)._load_state_dict_wrapper(state_dict)

    archive = TEST_DATA_PATH / "test_custom_settings.zanj"
    ZANJ().save(CustomStateModel(SimpleConfig(10)), archive)

    wrapper_settings = {"custom_param": "test_value"}
    custom_zanj = ZANJ(custom_settings={"_load_state_dict_wrapper": wrapper_settings})

    restored = custom_zanj.read(archive)

    assert restored.custom_param_received == "test_value"

``````{ end_of_file="tests/unit/with_torch/test_torch_edge_cases.py" }

``````{ path="tests/unit/with_torch/test_torchutil_edge_cases.py"  }
"""Edge case tests for zanj/torchutil.py to improve coverage."""

from __future__ import annotations

from pathlib import Path

import pytest
import torch

from muutils.json_serialize import SerializableDataclass, serializable_dataclass

from zanj import ZANJ
from zanj.torchutil import (
    ConfiguredModel,
    assert_model_exact_equality,
    num_params,
    set_config_class,
)


# not referenced by the tests in this module, which write to pytest's `tmp_path`
TEST_DATA_PATH: Path = Path("tests/junk_data")


@serializable_dataclass
class EdgeCaseTestConfig(SerializableDataclass):
    """Minimal config for `EdgeCaseTestModel`; built positionally as (hidden_size, num_layers)."""

    # width of the model's single square Linear layer
    hidden_size: int
    # carried in the config only; EdgeCaseTestModel builds one layer regardless
    num_layers: int


@set_config_class(EdgeCaseTestConfig)
class EdgeCaseTestModel(ConfiguredModel[EdgeCaseTestConfig]):
    """One square ``Linear(hidden_size, hidden_size)`` layer wrapped as a ConfiguredModel."""

    def __init__(self, cfg: EdgeCaseTestConfig):
        super().__init__(cfg)
        width = cfg.hidden_size
        self.linear = torch.nn.Linear(width, width)

    def forward(self, x):
        layer = self.linear
        return layer(x)


class TestConfiguredModelValidation:
    """Tests for ConfiguredModel validation edge cases."""

    def test_missing_config_class_decorator(self):
        """Line 102: Model without @set_config_class decorator should raise NotImplementedError."""

        class UnconfiguredModel(ConfiguredModel):
            def __init__(self, cfg):
                super().__init__(cfg)

        with pytest.raises(NotImplementedError, match="need to set"):
            UnconfiguredModel(EdgeCaseTestConfig(32, 2))

    def test_wrong_config_type(self):
        """Line 104: Passing wrong config type should raise TypeError."""

        # a SerializableDataclass subclass, but not the model's EdgeCaseTestConfig
        class WrongConfig(SerializableDataclass):
            other_field: str = "test"

        with pytest.raises(TypeError, match="must be an instance of"):
            EdgeCaseTestModel(WrongConfig())

    def test_config_not_serializable_dataclass(self):
        """Passing a plain dict instead of a config dataclass instance should raise TypeError."""
        with pytest.raises(TypeError):
            EdgeCaseTestModel({"hidden_size": 32})  # type: ignore


class TestSetConfigClass:
    """Tests for set_config_class decorator."""

    def test_invalid_config_class_type(self):
        """Line 227: Passing non-SerializableDataclass should raise TypeError."""

        # plain class, deliberately not a SerializableDataclass subclass
        class NotSerializable:
            pass

        # the error is expected while the decorator is applied, i.e. while the
        # class statement below executes inside the `raises` context
        with pytest.raises(
            TypeError, match="must be a subclass of SerializableDataclass"
        ):

            @set_config_class(NotSerializable)  # type: ignore
            class BadModel(ConfiguredModel):
                pass


class TestConfiguredModelSaveLoad:
    """Saving, loading and serializing a ConfiguredModel without passing a ZANJ instance."""

    def test_save_with_default_zanj(self, tmp_path):
        """Lines 134-136: `save` with no ZANJ argument creates a default one."""
        model = EdgeCaseTestModel(EdgeCaseTestConfig(16, 1))

        target = str(tmp_path / "model_default_zanj.zanj")
        model.save(target)  # no zanj argument
        assert Path(target).exists()

        # the archive must also read back into an equivalent config
        restored = EdgeCaseTestModel.read(target)
        assert restored.zanj_model_config.hidden_size == 16

    def test_load_with_default_zanj(self, tmp_path):
        """Line 154: `read` with no ZANJ argument creates a default one."""
        model = EdgeCaseTestModel(EdgeCaseTestConfig(16, 1))

        # write the serialized form with an explicit ZANJ ...
        target = str(tmp_path / "model_for_load_test.zanj")
        writer = ZANJ()
        writer.save(model.serialize(), target)

        # ... and read it back without one
        restored = EdgeCaseTestModel.read(target)
        assert restored.zanj_model_config.hidden_size == 16

    def test_serialize_with_default_zanj(self):
        """Lines 114-115: `serialize` with no ZANJ argument creates a default one."""
        serialized = EdgeCaseTestModel(EdgeCaseTestConfig(16, 1)).serialize()

        for required_key in ("zanj_model_config", "state_dict", "__muutils_format__"):
            assert required_key in serialized


class TestNumParamsWrapper:
    """Tests for num_params instance method."""

    def test_num_params_instance_method(self):
        """Line 220: the `num_params` method agrees with the module-level function."""
        cfg = EdgeCaseTestConfig(16, 1)
        model = EdgeCaseTestModel(cfg)

        instance_result = model.num_params()
        function_result = num_params(model)

        assert instance_result == function_result
        # EdgeCaseTestModel is a single Linear(16, 16): 16*16 weights + 16 biases,
        # all trainable -- pin the exact count rather than just "> 0"
        assert instance_result == 16 * 16 + 16


class TestAssertModelExactEquality:
    """Behaviour of `assert_model_exact_equality` on matching and mismatching models."""

    def test_state_dict_mismatch(self):
        """Lines 289-290: differing weights fail with a state-dict value mismatch."""
        cfg = EdgeCaseTestConfig(16, 1)
        reference = EdgeCaseTestModel(cfg)
        shifted = EdgeCaseTestModel(cfg)

        # shift every parameter of the second model so no value can coincide
        with torch.no_grad():
            for p in shifted.parameters():
                p.add_(1.0)

        with pytest.raises(AssertionError, match="state dict elements don't match"):
            assert_model_exact_equality(reference, shifted)

    def test_equal_models_pass(self):
        """Models sharing config and weights pass the check."""
        cfg = EdgeCaseTestConfig(16, 1)
        reference = EdgeCaseTestModel(cfg)
        duplicate = EdgeCaseTestModel(cfg)

        # make the weights identical
        duplicate.load_state_dict(reference.state_dict())

        assert_model_exact_equality(reference, duplicate)  # must not raise

    def test_state_dict_keys_mismatch(self):
        """Models whose parameters are named differently fail on the state-dict keys."""

        class DifferentModel(ConfiguredModel[EdgeCaseTestConfig]):
            # config class assigned directly instead of via @set_config_class
            _config_class = EdgeCaseTestConfig

            def __init__(self, cfg: EdgeCaseTestConfig):
                super().__init__(cfg)
                self.different_linear = torch.nn.Linear(
                    cfg.hidden_size, cfg.hidden_size
                )

            def forward(self, x):
                return self.different_linear(x)

        cfg = EdgeCaseTestConfig(16, 1)
        reference = EdgeCaseTestModel(cfg)
        renamed = DifferentModel(cfg)

        with pytest.raises(AssertionError, match="state dict keys don't match"):
            assert_model_exact_equality(reference, renamed)


class TestConfiguredModelReadRoundTrip:
    """End-to-end `save` then `read` of a ConfiguredModel."""

    def test_full_round_trip(self, tmp_path):
        """Config, training records and weights all survive a save/read cycle."""
        original = EdgeCaseTestModel(EdgeCaseTestConfig(32, 2))
        original.training_records = {"epochs": 10, "loss": 0.5}

        target = str(tmp_path / "round_trip_test.zanj")
        original.save(target)
        restored = EdgeCaseTestModel.read(target)

        restored_cfg = restored.zanj_model_config
        assert restored_cfg.hidden_size == 32
        assert restored_cfg.num_layers == 2
        assert restored.training_records == {"epochs": 10, "loss": 0.5}

        # weights (and config) must match exactly
        assert_model_exact_equality(original, restored)

``````{ end_of_file="tests/unit/with_torch/test_torchutil_edge_cases.py" }

``````{ path="tests/unit/with_torch/test_zanj_sdc_modelcfg.py"  }
from __future__ import annotations

import json
import sys
import typing
from pathlib import Path

import numpy as np
import torch  # type: ignore[import-not-found]
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

from zanj import ZANJ

np.random.seed(0)

# pylint: disable=missing-function-docstring,missing-class-docstring

# scratch directory where these tests write their JSON/ZANJ files
TEST_DATA_PATH: Path = Path("tests", "junk_data")


# `kw_only` dataclass fields only exist on python >= 3.10
SUPPORTS_KW_ONLY: bool = sys.version_info >= (3, 10)


@serializable_dataclass
class MyModelCfg(SerializableDataclass):
    """Plain model config of scalar fields, nested inside `BasicCfgHolder`."""

    # constructed positionally in this module: (name, num_layers, hidden_size, dropout)
    name: str
    num_layers: int
    hidden_size: int
    dropout: float


@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class TrainCfg(SerializableDataclass):
    """Training config whose `optimizer` field holds an optimizer *class*, not an instance."""

    name: str
    weight_decay: float
    # stored as the class name and resolved back with getattr on torch.optim,
    # so only optimizers that live directly in torch.optim can round-trip
    optimizer: typing.Type[torch.optim.Optimizer] = serializable_field(
        default_factory=lambda: torch.optim.Adam,
        serialization_fn=lambda x: x.__name__,
        loading_fn=lambda data: getattr(torch.optim, data["optimizer"]),
    )
    # keyword arguments meant for the optimizer constructor
    optimizer_kwargs: typing.Dict[str, typing.Any] = serializable_field(  # type: ignore
        default_factory=lambda: dict(lr=0.000001)
    )


class CustomCfg:
    """Plain (non-dataclass) config with hand-written `serialize`/`load` hooks."""

    def __init__(self, x: int, y: str):
        self.x = x
        self.y = y

    def __eq__(self, other):
        # defer to Python's default handling for foreign types (e.g. None)
        # instead of raising AttributeError on the missing `.x` / `.y`
        if not isinstance(other, CustomCfg):
            return NotImplemented
        return self.x == other.x and self.y == other.y

    def serialize(self):
        """Return a JSON-compatible dict holding both fields."""
        return {"x": self.x, "y": self.y}

    @classmethod
    def load(cls, data):
        """Rebuild an instance from the dict produced by `serialize`."""
        return cls(x=data["x"], y=data["y"])


@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class BasicCfgHolder(SerializableDataclass):
    """Holds a model config, a training config, and an optional non-dataclass `CustomCfg`."""

    model: MyModelCfg
    optimizer: TrainCfg
    # `custom` defaults to None, so both hooks must tolerate a missing value
    # rather than calling `.serialize()` / `CustomCfg.load` on None
    custom: typing.Optional[CustomCfg] = serializable_field(
        default=None,
        serialization_fn=lambda x: None if x is None else x.serialize(),
        loading_fn=lambda data: (
            None if data["custom"] is None else CustomCfg.load(data["custom"])
        ),
    )


# shared fixture for the BasicCfgHolder tests; every field, including the
# optional `custom`, is populated
instance_basic: BasicCfgHolder = BasicCfgHolder(  # type: ignore
    model=MyModelCfg("lstm", 3, 128, 0.1),  # type: ignore
    optimizer=TrainCfg(  # type: ignore
        name="adamw",
        weight_decay=0.2,
        # a non-default optimizer class, so the round trip is meaningful
        optimizer=torch.optim.AdamW,
        optimizer_kwargs=dict(lr=0.0001),
    ),
    custom=CustomCfg(42, "forty-two"),
)


def test_config_holder():
    """Round-trip ``instance_basic`` through ``serialize`` -> JSON on disk -> ``load``."""
    json_path = TEST_DATA_PATH / "test_config_holder.json"
    payload = instance_basic.serialize()
    with open(json_path, "w") as out_f:
        json.dump(payload, out_f, indent="\t")
    with open(json_path, "r") as in_f:
        payload_from_disk = json.load(in_f)

    recovered = BasicCfgHolder.load(payload_from_disk)

    expected_types = (
        ("model", MyModelCfg),
        ("optimizer", TrainCfg),
        ("custom", CustomCfg),
    )
    for attr, expected_type in expected_types:
        assert isinstance(getattr(recovered, attr), expected_type)
    assert recovered.custom.x == 42
    assert instance_basic == recovered


def test_config_holder_zanj():
    """Round-trip ``instance_basic`` through a ZANJ archive."""
    archive = TEST_DATA_PATH / "test_config_holder.zanj"
    zanj = ZANJ()
    zanj.save(instance_basic, archive)
    recovered = zanj.read(archive)

    expected_types = (
        ("model", MyModelCfg),
        ("optimizer", TrainCfg),
        ("custom", CustomCfg),
    )
    for attr, expected_type in expected_types:
        assert isinstance(getattr(recovered, attr), expected_type)
    assert recovered.custom.x == 42
    assert instance_basic == recovered


@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class BaseGPTConfig(SerializableDataclass):
    """Scalar-only GPT-style model config, nested inside `AdvCfgHolder`."""

    name: str
    # activation function name, e.g. "gelu"
    act_fn: str
    d_model: int
    d_head: int
    n_layers: int


def _load_tokenizer(data):
    """Loading hook for `AdvCfgHolder.tokenizer`.

    Only a stored ``None`` can be restored. A non-None tokenizer is stored via
    ``repr`` and cannot be rebuilt, so this raises instead of silently returning
    a placeholder (previously the ``NotImplementedError`` class itself was
    returned as the field value).
    """
    if data["tokenizer"] is None:
        return None
    raise NotImplementedError("loading a non-None tokenizer is not supported")


@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class AdvCfgHolder(SerializableDataclass):
    """Holds a GPT config plus an optional tokenizer that is only stored by ``repr``."""

    model_cfg: BaseGPTConfig
    name: str = serializable_field(default="default")
    tokenizer: typing.Optional[CustomCfg] = serializable_field(
        default=None,
        serialization_fn=lambda x: repr(x) if x is not None else None,
        loading_fn=_load_tokenizer,
    )


# shared fixture for the AdvCfgHolder tests; `name` keeps its default and the
# tokenizer is left as None, the only value its loading hook can round-trip
instance_adv: AdvCfgHolder = AdvCfgHolder(  # type: ignore
    model_cfg=BaseGPTConfig(  # type: ignore
        name="gpt2",
        act_fn="gelu",
        d_model=128,
        d_head=64,
        n_layers=3,
    ),
    tokenizer=None,
)


def test_adv_config_holder():
    """Round-trip ``instance_adv`` through ``serialize`` -> JSON on disk -> ``load``.

    Loads from the file that was just written (as ``test_config_holder`` does),
    so the JSON encoding step is exercised, not only the in-memory dict.
    """
    json_path = TEST_DATA_PATH / "test_adv_config_holder.json"
    instance_stored = instance_adv.serialize()
    with open(json_path, "w") as f:
        json.dump(instance_stored, f, indent="\t")
    with open(json_path, "r") as f:
        instance_stored_read = json.load(f)
    recovered = AdvCfgHolder.load(instance_stored_read)
    assert isinstance(recovered.model_cfg, BaseGPTConfig)
    assert instance_adv == recovered


def test_adv_config_holder_zanj():
    """Round-trip ``instance_adv`` through a ZANJ archive."""
    archive = TEST_DATA_PATH / "test_adv_config_holder.zanj"
    zanj = ZANJ()
    zanj.save(instance_adv, archive)

    recovered = zanj.read(archive)

    assert isinstance(recovered.model_cfg, BaseGPTConfig)
    assert instance_adv == recovered

``````{ end_of_file="tests/unit/with_torch/test_zanj_sdc_modelcfg.py" }

``````{ path="tests/unit/with_torch/test_zanj_torch.py"  }
from __future__ import annotations

from pathlib import Path

import numpy as np
import torch  # type: ignore[import-not-found]
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from muutils.tensor_utils import compare_state_dicts

from zanj import ZANJ
from zanj.torchutil import (
    ConfiguredModel,
    assert_model_exact_equality,
    set_config_class,
)

np.random.seed(0)

TEST_DATA_PATH: Path = Path("tests", "junk_data")  # scratch dir for written archives


def test_torch_configmodel_minimal():
    """Save a minimal ConfiguredModel, then load it via the class and via a bare ZANJ."""

    @serializable_dataclass
    class MyNNConfig(SerializableDataclass):
        n_layers: int

    @set_config_class(MyNNConfig)
    class MyNN(ConfiguredModel[MyNNConfig]):
        def __init__(self, config: MyNNConfig):
            super().__init__(config)

            self.layer = torch.nn.Linear(config.n_layers, 1)

        def forward(self, x):
            return self.layer(x)

    def check_matches(loaded: MyNN) -> None:
        # config, training records and every weight must survive the round trip
        assert model.zanj_model_config == loaded.zanj_model_config
        assert model.training_records == loaded.training_records
        compare_state_dicts(model.state_dict(), loaded.state_dict())
        assert_model_exact_equality(model, loaded)

    model: MyNN = MyNN(MyNNConfig(n_layers=2))

    fname: Path = TEST_DATA_PATH / "test_torch_configmodel.zanj"
    ZANJ().save(model, fname)
    print(f"saved model to {fname}")
    print(f"{model.zanj_model_config = }")

    # load through the class' own reader
    model2: MyNN = MyNN.read(fname)
    print(f"loaded model from {fname}")
    print(f"{model2.zanj_model_config = }")
    check_matches(model2)

    # load through a plain ZANJ, without naming the model class
    model3: MyNN = ZANJ().read(fname)
    print(f"loaded model from {fname}")
    print(f"{model3.zanj_model_config = }")
    check_matches(model3)


def test_torch_configmodel():
    """Round-trip a small transformer-based ConfiguredModel whose config stores classes.

    The loss and optimizer factories are serialized by class name and resolved
    back with ``getattr`` on ``torch.nn`` / ``torch.optim``. Writes to its own
    archive path so it cannot clobber ``test_torch_configmodel_minimal``'s file
    when tests run in parallel.
    """

    @serializable_dataclass
    class MyGPTConfig(SerializableDataclass):
        """basic test GPT config"""

        n_layers: int
        n_heads: int
        embedding_size: int
        n_positions: int
        n_vocab: int

        # stored by class name, resolved back from torch.nn on load
        loss_factory: torch.nn.modules.loss._Loss = serializable_field(
            default_factory=lambda: torch.nn.CrossEntropyLoss,
            serialization_fn=lambda x: x.__name__,
            loading_fn=lambda x: getattr(torch.nn, x["loss_factory"]),
        )

        loss_kwargs: dict = serializable_field(default_factory=dict)

        @property
        def loss(self):
            return self.loss_factory(**self.loss_kwargs)

        # stored by class name, resolved back from torch.optim on load
        optim_factory: torch.optim.Optimizer = serializable_field(
            default_factory=lambda: torch.optim.Adam,
            serialization_fn=lambda x: x.__name__,
            loading_fn=lambda x: getattr(torch.optim, x["optim_factory"]),
        )

        optim_kwargs: dict = serializable_field(default_factory=dict)

        def optim(self, model):
            return self.optim_factory(model.parameters(), **self.optim_kwargs)  # type: ignore

    @set_config_class(MyGPTConfig)
    class MyGPT(ConfiguredModel[MyGPTConfig]):
        """basic GPT model"""

        def __init__(self, config: MyGPTConfig):
            super().__init__(config)

            # decoder-only stack: no encoder layers
            self.transformer = torch.nn.Transformer(
                d_model=config.embedding_size,
                nhead=config.n_heads,
                num_encoder_layers=0,
                num_decoder_layers=config.n_layers,
            )

        def forward(self, x):
            return self.transformer(x)

    config: MyGPTConfig = MyGPTConfig(
        n_layers=2,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
        loss_factory=torch.nn.CrossEntropyLoss,
    )

    model: MyGPT = MyGPT(config)
    model.training_records = dict(loss=[3, 2, 1], accuracy=[0.1, 0.2, 0.3])

    # distinct from the path used by test_torch_configmodel_minimal
    fname: Path = TEST_DATA_PATH / "test_torch_configmodel_gpt.zanj"
    ZANJ().save(model, fname)

    print(f"saved model to {fname}")
    print(f"{model.zanj_model_config = }")

    # load through the class' own reader
    model2: MyGPT = MyGPT.read(fname)
    print(f"loaded model from {fname}")
    print(f"{model2.zanj_model_config = }")

    assert model.zanj_model_config == model2.zanj_model_config
    assert model.training_records == model2.training_records

    compare_state_dicts(model.state_dict(), model2.state_dict())
    assert_model_exact_equality(model, model2)

    # load through a plain ZANJ, without naming the model class
    model3: MyGPT = ZANJ().read(fname)
    print(f"loaded model from {fname}")
    print(f"{model3.zanj_model_config = }")

    assert model.zanj_model_config == model3.zanj_model_config
    assert model.training_records == model3.training_records

    compare_state_dicts(model.state_dict(), model3.state_dict())
    assert_model_exact_equality(model, model3)

``````{ end_of_file="tests/unit/with_torch/test_zanj_torch.py" }

``````{ path="tests/unit/with_torch/test_zanj_torch_cfgmismatch.py"  }
from __future__ import annotations

from typing import Any

import torch  # type: ignore[import-not-found]
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

from zanj.torchutil import (
    ConfigMismatchException,
    ConfiguredModel,
    assert_model_cfg_equality,
    set_config_class,
)

# Module-level config and model shared by the config-mismatch tests below.


@serializable_dataclass
class MyGPTConfig(SerializableDataclass):
    """Test GPT config whose fields are compared by `assert_model_cfg_equality`."""

    n_layers: int
    n_heads: int
    embedding_size: int
    n_positions: int
    n_vocab: int
    # arbitrary payload: the mismatch tests put dicts and even a plain string
    # here to check how differing values show up in the reported diff
    junk_data: Any = serializable_field(default_factory=dict)


@set_config_class(MyGPTConfig)
class MyGPT(ConfiguredModel[MyGPTConfig]):
    """Stand-in "GPT": a single linear map from embedding size to vocabulary size."""

    def __init__(self, config: MyGPTConfig):
        super().__init__(config)
        in_features, out_features = config.embedding_size, config.n_vocab
        self.transformer = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        projection = self.transformer
        return projection(x)


def test_config_mismatch_exception_direct():
    """`ConfigMismatchException` keeps its diff and renders it after the message."""
    diff = {"model_cfg": {"are_weights_processed": {"self": False, "other": True}}}
    exc = ConfigMismatchException("Configs don't match", diff)

    assert exc.diff == diff

    expected_text = "Configs don't match: {'model_cfg': {'are_weights_processed': {'self': False, 'other': True}}}"
    assert str(exc) == expected_text


def test_equal_configs():
    """Two models built from the same config pass the config-equality check."""
    shared_config = MyGPTConfig(
        n_layers=2,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
        junk_data={"a": 1, "b": 2},
    )

    # must not raise
    assert_model_cfg_equality(MyGPT(shared_config), MyGPT(shared_config))


def test_unequal_configs():
    """Differing configs raise ConfigMismatchException whose diff lists only the changed fields."""
    common = dict(n_heads=2, embedding_size=16, n_positions=16, n_vocab=128)
    model_a = MyGPT(MyGPTConfig(n_layers=2, junk_data={"a": 1, "b": 2}, **common))
    # a different config: n_layers and junk_data change
    model_b = MyGPT(
        MyGPTConfig(n_layers=3, junk_data={"a": 7, "something": "or other"}, **common)
    )

    caught = None
    try:
        assert_model_cfg_equality(model_a, model_b)
    except ConfigMismatchException as exc:
        caught = exc

    if caught is None:
        raise AssertionError("Expected a ConfigMismatchException!")

    assert caught.diff == {
        "n_layers": {"self": 2, "other": 3},
        "junk_data": {
            "self": {"a": 1, "b": 2},
            "other": {"a": 7, "something": "or other"},
        },
    }


def test_unequal_configs_2():
    """The diff also reports fields whose values differ in type (dict vs str)."""
    common = dict(n_heads=2, embedding_size=16, n_positions=16, n_vocab=128)
    model_a = MyGPT(MyGPTConfig(n_layers=2, junk_data={"a": 1, "b": 2}, **common))
    # a different config: junk_data is not even a dict here
    model_b = MyGPT(
        MyGPTConfig(n_layers=3, junk_data="this isnt even a dict lol", **common)
    )

    caught = None
    try:
        assert_model_cfg_equality(model_a, model_b)
    except ConfigMismatchException as exc:
        caught = exc

    if caught is None:
        raise AssertionError("Expected a ConfigMismatchException!")

    assert caught.diff == {
        "n_layers": {"self": 2, "other": 3},
        "junk_data": {
            "self": {"a": 1, "b": 2},
            "other": "this isnt even a dict lol",
        },
    }


def test_incorrect_instance():
    """Passing a non-`ConfiguredModel` as `model_b` must raise an AssertionError."""
    config = MyGPTConfig(
        n_layers=2,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
    )

    model_a = MyGPT(config)
    model_b = "Not a ConfiguredModel instance"

    try:
        assert_model_cfg_equality(model_a, model_b)  # type: ignore
    except AssertionError as exc:
        assert str(exc) == "model_b must be a ConfiguredModel"
    else:
        # without this the test would silently pass if nothing were raised
        raise AssertionError("Expected an AssertionError for a non-ConfiguredModel!")

``````{ end_of_file="tests/unit/with_torch/test_zanj_torch_cfgmismatch.py" }

``````{ path="tests/assert_no_torch.py"  }
import pytest


def test_assert_no_torch():
    """In the torch-free test environment, importing torch must fail."""
    with pytest.raises(ImportError):
        torch_module = __import__("torch")

        # only reached if the import unexpectedly succeeds
        print(torch_module.rand(10))

``````{ end_of_file="tests/assert_no_torch.py" }

``````{ path="zanj/__init__.py"  }
"""
.. include:: ../README.md
"""

from __future__ import annotations

from zanj.loading import register_loader_handler
from zanj.zanj import ZANJ

__all__ = [
    "register_loader_handler",
    "ZANJ",
    # modules
    # NOTE(review): submodule names listed in a package's __all__ are imported
    # by `from zanj import *`. If `torchutil` imports torch at module level
    # (not visible here -- confirm), that star-import would fail in a
    # torch-free environment.
    "externals",
    "loading",
    "serializing",
    "torchutil",
    "zanj",
]

``````{ end_of_file="zanj/__init__.py" }

``````{ path="zanj/consts.py"  }
"""Constants and re-exports from muutils with version compatibility."""

from __future__ import annotations

# Items that exist in muutils.json_serialize.util across all versions
from muutils.json_serialize.util import (
    JSONdict,
    JSONitem,
    MonoTuple,
    safe_getsource,
    string_as_lines,
)

# _FORMAT_KEY and _REF_KEY moved from .util to .types in muutils >= 0.9
try:
    from muutils.json_serialize.types import _FORMAT_KEY, _REF_KEY  # type: ignore[import-not-found]
except ImportError:
    # fallback for muutils < 0.9 where these lived in .util; mypy can't resolve this across try/except
    from muutils.json_serialize.util import _FORMAT_KEY, _REF_KEY  # type: ignore[import-not-found, attr-defined, no-redef]

# public re-exports; `_FORMAT_KEY` / `_REF_KEY` are listed explicitly because
# `from zanj.consts import *` would otherwise skip underscore-prefixed names
__all__ = [
    "JSONdict",
    "JSONitem",
    "MonoTuple",
    "_FORMAT_KEY",
    "_REF_KEY",
    "safe_getsource",
    "string_as_lines",
]

``````{ end_of_file="zanj/consts.py" }

``````{ path="zanj/externals.py"  }
"""for storing/retrieving an item externally in a ZANJ archive"""

from __future__ import annotations

import json
from typing import IO, Any, Callable, Literal, NamedTuple, get_args

import numpy as np
from muutils.json_serialize.json_serialize import ObjectPath

from zanj.consts import JSONitem

# stand-in for the `ZANJ` type so type checking works without a circular
# import -- it will later be overridden
_ZANJ_pre = Any

# names of the main and metadata JSON members read from a ZANJ archive
ZANJ_MAIN: str = "__zanj__.json"
ZANJ_META: str = "__zanj_meta__.json"

# storage formats an external item may use, plus their runtime values
ExternalItemType = Literal["jsonl", "npy"]
ExternalItemType_vals = get_args(ExternalItemType)


class ExternalItem(NamedTuple):
    """an item stored as its own file inside a ZANJ archive"""

    # storage format of the item (also its file extension)
    item_type: ExternalItemType
    # the item's payload
    data: Any
    # location of the item's placeholder within the main json data
    path: ObjectPath


def load_jsonl(zanj: "LoadedZANJ", fp: IO[bytes]) -> list[JSONitem]:  # type: ignore[name-defined] # noqa: F821
    """load a `.jsonl` stream: one JSON value per line

    blank (whitespace-only) lines are skipped, so a stray empty line or an
    extra trailing newline does not make `json.loads` fail. `zanj` is unused;
    it is accepted so every external loader shares one signature.
    """
    return [json.loads(line) for line in fp if line.strip()]


def load_npy(zanj: "LoadedZANJ", fp: IO[bytes]) -> np.ndarray:  # type: ignore[name-defined] # noqa: F821
    """read one `.npy` array from the open binary stream `fp` (`zanj` is unused)"""
    array: np.ndarray = np.load(fp)
    return array


# external item type -> function that decodes it from an archive member
EXTERNAL_LOAD_FUNCS: dict[ExternalItemType, Callable[[_ZANJ_pre, IO[bytes]], Any]] = dict(
    jsonl=load_jsonl,
    npy=load_npy,
)


def GET_EXTERNAL_LOAD_FUNC(item_type: str) -> Callable[[_ZANJ_pre, IO[bytes]], Any]:
    """look up the loader for an external item type

    raises `ValueError` if `item_type` has no entry in `EXTERNAL_LOAD_FUNCS`
    """
    # no registered loader is ever None, so None means "not found"
    load_func = EXTERNAL_LOAD_FUNCS.get(item_type)  # type: ignore[call-overload]
    if load_func is None:
        raise ValueError(
            f"unknown external item type: {item_type}, needs to be one of {EXTERNAL_LOAD_FUNCS.keys()}"
        )
    return load_func

``````{ end_of_file="zanj/externals.py" }

``````{ path="zanj/loading.py"  }
from __future__ import annotations

import json
import threading
import typing
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable

import numpy as np

# pandas is optional: when it is not installed, `pandas_DataFrame` is a stand-in
# whose constructor raises ImportError, so the failure only surfaces when a
# pandas DataFrame is actually constructed (i.e. loaded)
try:
    import pandas as pd  # type: ignore[import]

    pandas_DataFrame = pd.DataFrame  # type: ignore[no-redef]
except ImportError:

    class pandas_DataFrame:  # type: ignore[no-redef]
        """placeholder used when pandas is missing; constructing it raises ImportError"""

        def __init__(self, *args, **kwargs):
            raise ImportError("cannot load pandas DataFrame, pandas is not installed")


# same optional-dependency pattern for polars
try:
    import polars as pl  # type: ignore[import]

    polars_DataFrame = pl.DataFrame  # type: ignore[no-redef]
except ImportError:

    class polars_DataFrame:  # type: ignore[no-redef]
        """placeholder used when polars is missing; constructing it raises ImportError"""

        def __init__(self, *args, **kwargs):
            raise ImportError("cannot load polars DataFrame, polars is not installed")


from muutils.errormode import ErrorMode
from muutils.json_serialize.array import load_array
from muutils.json_serialize.json_serialize import ObjectPath

from zanj.consts import (
    JSONdict,
    JSONitem,
    _FORMAT_KEY,
    _REF_KEY,
    safe_getsource,
    string_as_lines,
)

from zanj.externals import (
    GET_EXTERNAL_LOAD_FUNC,
    ZANJ_MAIN,
    ZANJ_META,
    ExternalItem,
    _ZANJ_pre,
)

# pylint: disable=protected-access, dangerous-default-value


def _populate_externals_error_checking(key, item) -> bool:
    """checks that the key is valid for the item. returns "True" we need to augment the path by accessing the "data" element"""

    # special case for not fully loaded external item which we still need to populate
    if isinstance(item, typing.Mapping):
        if (_FORMAT_KEY in item) and item[_FORMAT_KEY].endswith(":external"):
            if "data" in item:
                return True
            else:
                raise KeyError(
                    f"expected an external item, but could not find data: {list(item.keys())}",
                    f"{item[_FORMAT_KEY]}, {len(item) = }, {item.get('data', '<EMPTY>') = }",
                )

    # if it's a list, make sure the key is an int and that it's in range
    if isinstance(item, typing.Sequence):
        if not isinstance(key, int):
            raise TypeError(f"improper type: '{type(key) = }', expected int")
        if key >= len(item):
            raise IndexError(f"index out of range: '{key = }', expected < {len(item)}")

    # if it's a dict, make sure that the key is a str and that it's in the dict
    elif isinstance(item, typing.Mapping):
        if not isinstance(key, str):
            raise TypeError(f"improper type: '{type(key) = }', expected str")
        if key not in item:
            raise KeyError(f"key not in dict: '{key = }', expected in {item.keys()}")

    # otherwise, raise an error
    else:
        raise TypeError(f"improper type: '{type(item) = }', expected dict or list")

    return False


@dataclass
class LoaderHandler:
    """handler for loading an object from a json file or a ZANJ archive"""

    # TODO: add a separate "asserts" function?
    # right now, any asserts must happen in `check` or `load` which is annoying with lambdas

    # (json_data, path, zanj) -> whether to use this handler
    check: Callable[[JSONitem, ObjectPath, _ZANJ_pre], bool]
    # function to load the object (json_data, path, zanj) -> loaded_obj
    load: Callable[[JSONitem, ObjectPath, _ZANJ_pre], Any]
    # unique identifier for the handler, saved in __muutils_format__ field
    uid: str
    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str
    # priority of the handler, defaults are all 0
    priority: int = 0
    # description of the handler
    desc: str = "(no description)"

    def serialize(self) -> JSONdict:
        """serialize the handler info: source code and docstrings of `check` and
        `load`, plus `uid`, `source_pckg`, `priority`, and `desc`"""
        return {
            # get the code and doc of the check function
            "check": {
                "code": safe_getsource(self.check),
                "doc": string_as_lines(self.check.__doc__),
            },
            # get the code and doc of the load function
            "load": {
                "code": safe_getsource(self.load),
                "doc": string_as_lines(self.load.__doc__),
            },
            # get the uid, source_pckg, priority, and desc
            "uid": str(self.uid),
            "source_pckg": str(self.source_pckg),
            "priority": int(self.priority),
            "desc": str(self.desc),
        }

    @classmethod
    def from_formattedclass(cls, fc: type, priority: int = 0):
        """create a loader from a class with `serialize`, `load` methods and a
        `_FORMAT_KEY` (`__muutils_format__`) string attribute

        the resulting `check` matches only mappings whose `_FORMAT_KEY` entry equals
        the class's format string. any other json item -- a non-mapping, or a
        mapping without that key -- yields `False` instead of raising, because
        `get_item_loader` falls back to calling every registered handler's `check`
        on arbitrary json items (ints, lists, plain dicts).
        """
        assert hasattr(fc, "serialize")
        assert callable(fc.serialize)  # type: ignore
        assert hasattr(fc, "load")
        assert callable(fc.load)  # type: ignore
        assert hasattr(fc, _FORMAT_KEY)
        fmt: str = getattr(fc, _FORMAT_KEY)
        assert isinstance(fmt, str)

        return cls(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and json_item.get(_FORMAT_KEY) == fmt
            ),
            load=lambda json_item, path=None, z=None: fc.load(json_item, path, z),  # type: ignore[misc]
            uid=fmt,
            source_pckg=str(fc.__module__),
            priority=priority,
            desc=f"formatted class loader for {fc.__name__}",
        )


# TODO: how can we type hint this without actually importing torch?
def _torch_loaderhandler_load(
    json_item: JSONitem,
    path: ObjectPath,
    z: _ZANJ_pre | None = None,
):
    """rebuild a `torch.Tensor` from its serialized json form

    torch (and `muutils.tensor_utils`) are imported lazily so zanj works without
    torch installed; if either import fails, an `ImportError` naming `path` is raised.
    """
    try:
        import torch  # type: ignore[import-not-found]
        from muutils.tensor_utils import TORCH_DTYPE_MAP
    except ImportError as e:
        err_msg: str = f"could not import torch, which we need to load the object at {path = }: {json_item = }"
        raise ImportError(err_msg) from e

    # json_item is JSONitem but load_array expects narrower types; runtime check is in LoaderHandler.check
    array_data = load_array(json_item)  # type: ignore[no-matching-overload, call-overload]
    # the serialized "dtype" string selects the torch dtype
    target_dtype = TORCH_DTYPE_MAP[json_item["dtype"]]  # type: ignore[index, call-overload]
    return torch.tensor(array_data, dtype=target_dtype)


# NOTE: there are type ignores on the loaders, since the type checking should be the responsibility of the check function

# guards writes to `LOADER_MAP` (see `register_loader_handler`)
LOADER_MAP_LOCK = threading.Lock()


def _load_sequence_elements(
    json_item: JSONitem,
    path: ObjectPath | None,
    z: _ZANJ_pre | None,
) -> list[Any]:
    """recursively load every element of `json_item["data"]` for the list/tuple loaders

    element `i` is loaded at path `(*path, i)`, matching the per-element path
    `tuple(path) + (i,)` used when jsonl externals are serialized and by the
    plain-list branch of `load_item_recursive`. `path=None` (the loader lambdas'
    default) is treated as the root path `()`.
    """
    base_path: tuple = tuple(path) if path is not None else ()
    return [
        load_item_recursive(x, base_path + (i,), z)  # type: ignore[arg-type]
        for i, x in enumerate(json_item["data"])  # type: ignore[index, call-overload, arg-type]
    ]


# uid -> handler for the built-in loaders. dict order matters: `get_item_loader`
# falls back to trying each handler's `check` in insertion order
LOADER_MAP: dict[str, LoaderHandler] = {
    lh.uid: lh
    for lh in [
        # numpy arrays: any format starting with "numpy.ndarray" (incl. ":external")
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("numpy.ndarray")
                # and json_item["data"].dtype.name == json_item["dtype"]
                # and tuple(json_item["data"].shape) == tuple(json_item["shape"])
            ),
            load=lambda json_item, path=None, z=None: np.array(  # type: ignore[misc]
                load_array(json_item), dtype=np.dtype(json_item["dtype"])
            ),
            uid="numpy.ndarray",
            source_pckg="zanj",
            desc="numpy.ndarray loader",
        ),
        # torch tensors: any format starting with "torch.Tensor"; torch imported lazily
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("torch.Tensor")
                # and json_item["data"].dtype.name == json_item["dtype"]
                # and tuple(json_item["data"].shape) == tuple(json_item["shape"])
            ),
            load=_torch_loaderhandler_load,
            uid="torch.Tensor",
            source_pckg="zanj",
            desc="torch.Tensor loader",
        ),
        # pandas
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("pandas.DataFrame")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                pandas_DataFrame(json_item["data"])
                # if there is no data, load just the columns (this is for empty dataframes)
                if json_item["data"]
                else pandas_DataFrame(columns=json_item.get("columns"))
            ),
            uid="pandas.DataFrame",
            source_pckg="zanj",
            desc="pandas.DataFrame loader",
        ),
        # polars
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("polars.DataFrame")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                polars_DataFrame(json_item["data"])
                if json_item["data"]
                # empty dataframe: rebuild only the columns, each typed as `str`
                else polars_DataFrame(
                    schema={col: str for col in json_item.get("columns", [])}
                )
            ),
            uid="polars.DataFrame",
            source_pckg="zanj",
            desc="polars.DataFrame loader",
        ),
        # list/tuple external: each element is loaded recursively at `path + (i,)`
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("list")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: _load_sequence_elements(  # type: ignore[misc]
                json_item, path, z
            ),
            uid="list",
            source_pckg="zanj",
            desc="list loader, for externals",
        ),
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("tuple")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: tuple(  # type: ignore[misc]
                _load_sequence_elements(json_item, path, z)
            ),
            uid="tuple",
            source_pckg="zanj",
            desc="tuple loader, for externals",
        ),
    ]
}


def register_loader_handler(handler: LoaderHandler):
    """register a custom loader handler under its `uid`

    a handler already registered with the same `uid` is replaced. the shared map
    is only mutated in place (never rebound), so no `global` statement is needed.
    """
    with LOADER_MAP_LOCK:
        LOADER_MAP.update({handler.uid: handler})


def get_item_loader(
    json_item: JSONitem,
    path: ObjectPath,
    zanj: _ZANJ_pre | None = None,
    error_mode: ErrorMode = ErrorMode.WARN,
    # lh_map: dict[str, LoaderHandler] = LOADER_MAP,
) -> LoaderHandler | None:
    """get the loader for a json item

    lookup order:
     1. if `json_item` is a mapping whose `_FORMAT_KEY` entry names a registered
        uid, that handler is returned directly
     2. otherwise each registered handler's `check` is tried in registration
        order, and the first one returning a truthy value is returned
     3. `None` if nothing matches

    `error_mode` is accepted for interface compatibility but is not used here.

    # Raises:
     - `TypeError` if the `_FORMAT_KEY` entry is present but is not a string
    """
    # check if we recognize the format
    if isinstance(json_item, typing.Mapping) and _FORMAT_KEY in json_item:
        if not isinstance(json_item[_FORMAT_KEY], str):  # type: ignore[index]
            raise TypeError(
                f"invalid __muutils_format__ type '{type(json_item[_FORMAT_KEY])}' in '{path=}': '{json_item[_FORMAT_KEY] = }'"  # type: ignore[index]
            )
        direct: LoaderHandler | None = LOADER_MAP.get(json_item[_FORMAT_KEY])  # type: ignore[index, arg-type]
        if direct is not None:
            return direct

    # take a snapshot under the lock: iterating the live dict while another thread
    # registers a handler would raise "dictionary changed size during iteration".
    # the checks themselves run outside the lock to keep the critical section short
    # and avoid deadlock if a check calls back into this module.
    with LOADER_MAP_LOCK:
        handlers: list[LoaderHandler] = list(LOADER_MAP.values())

    # if we dont recognize the format, try to find a loader that can handle it
    for lh in handlers:
        if lh.check(json_item, path, zanj):
            return lh

    # if we still dont have a loader, return None
    return None


def load_item_recursive(
    json_item: JSONitem,
    path: ObjectPath,
    zanj: _ZANJ_pre | None = None,
    error_mode: ErrorMode = ErrorMode.WARN,
    allow_not_loading: bool = True,
) -> Any:
    """recursively load a json item, dispatching to registered loader handlers

    # Parameters:
     - `json_item : JSONitem`
       the json data to load
     - `path : ObjectPath`
       location of `json_item` in the root data; the dict key or list index is
       appended at each level of recursion
     - `zanj : _ZANJ_pre | None`
       passed through to loader handlers
     - `error_mode : ErrorMode`
       passed through to `get_item_loader`
     - `allow_not_loading : bool`
       if `False`, raise instead of returning a value of an unrecognized type
       unchanged. passed down to every nested dict value, list element, and
       dataclass field this function recurses into itself (handler `load`
       functions make their own calls)

    # Returns:
     - the handler's result if a loader matches; otherwise a dict/list with
       recursively loaded contents, or the primitive value itself

    # Raises:
     - `ValueError` if `allow_not_loading` is `False` and a value of unknown type is found
    """
    lh: LoaderHandler | None = get_item_loader(
        json_item=json_item,
        path=path,
        zanj=zanj,
        error_mode=error_mode,
        # lh_map=lh_map,
    )

    if lh is not None:
        # special case for serializable dataclasses
        if (
            isinstance(json_item, typing.Mapping)
            and (_FORMAT_KEY in json_item)
            and ("SerializableDataclass" in json_item[_FORMAT_KEY])  # type: ignore[operator]
        ):
            # why this horribleness?
            # SerializableDataclass, if it has a field `x` which is also a SerializableDataclass, will automatically call `x.__class__.load()`
            # However, we need to load things in containers, as well as arrays
            processed_json_item: dict = {
                key: (
                    val
                    if (
                        isinstance(val, typing.Mapping)
                        and (_FORMAT_KEY in val)
                        and ("SerializableDataclass" in val[_FORMAT_KEY])  # type: ignore[operator, index]
                    )
                    else load_item_recursive(
                        json_item=val,  # type: ignore[arg-type]
                        path=tuple(path) + (key,),  # type: ignore[arg-type]
                        zanj=zanj,
                        error_mode=error_mode,
                        allow_not_loading=allow_not_loading,
                    )
                )
                for key, val in json_item.items()
            }

            return lh.load(processed_json_item, path, zanj)

        else:
            return lh.load(json_item, path, zanj)
    else:
        if isinstance(json_item, dict):
            return {
                key: load_item_recursive(
                    # ty doesn't narrow JSONitem to dict after isinstance check; string key indexing is safe here
                    json_item=json_item[key],  # type: ignore[invalid-argument-type, call-overload]
                    path=tuple(path) + (key,),
                    zanj=zanj,
                    error_mode=error_mode,
                    allow_not_loading=allow_not_loading,
                    # lh_map=lh_map,
                )
                for key in json_item
            }
        elif isinstance(json_item, list):
            return [
                load_item_recursive(
                    json_item=x,
                    path=tuple(path) + (i,),
                    zanj=zanj,
                    error_mode=error_mode,
                    allow_not_loading=allow_not_loading,
                    # lh_map=lh_map,
                )
                for i, x in enumerate(json_item)
            ]
        elif isinstance(json_item, (str, int, float, bool, type(None))):
            return json_item
        else:
            if allow_not_loading:
                return json_item
            else:
                raise ValueError(
                    f"unknown type {type(json_item)} at {path}\n{json_item}"
                )


def _each_item_in_externals(
    externals: dict[str, ExternalItem],
    json_data: JSONitem,
) -> typing.Iterable[tuple[str, ExternalItem, Any, ObjectPath]]:
    """note that you MUST use the raw iterator, dont try to turn into a list or something"""

    sorted_externals: list[tuple[str, ExternalItem]] = sorted(
        externals.items(), key=lambda x: len(x[1].path)
    )

    for ext_path, ext_item in sorted_externals:
        # get the path to the item
        path: ObjectPath = tuple(ext_item.path)
        assert len(path) > 0
        assert all(isinstance(key, (str, int)) for key in path), (
            f"improper types in path {path=}"
        )
        # get the item
        item = json_data
        for i, key in enumerate(path):
            try:
                # ignores in this block are because we cannot know the type is indexable in static analysis
                # but, we check the types in the line below
                external_unloaded: bool = _populate_externals_error_checking(key, item)
                if external_unloaded:
                    item = item["data"]  # type: ignore
                item = item[key]  # type: ignore[index]

            except (KeyError, IndexError, TypeError) as e:
                raise KeyError(
                    f"could not find '{key = }' at path '{ext_path = }', specifically at index '{i = }'",
                    f"'{type(item) =}', '{len(item) = }', '{item.keys() if isinstance(item, dict) else None = }'",  # type: ignore
                    f"From error: {e = }",
                    f"\n\n{item=}\n\n{ext_item=}",
                ) from e

        yield (ext_path, ext_item, item, path)


class LoadedZANJ:
    """for loading a zanj file: reads the main json data, the metadata, and every
    external item listed in the metadata's `externals_info` into memory"""

    def __init__(
        self,
        path: str | Path,
        zanj: _ZANJ_pre,
    ) -> None:
        """open the archive at `path` and read its contents

        the zip file and every member opened from it are closed by context
        managers, including when reading or decoding fails part-way through.
        """
        # path and zanj object
        self._path: str = str(path)
        self._zanj: _ZANJ_pre = zanj

        with zipfile.ZipFile(file=self._path, mode="r") as _zipf:
            # load data
            with _zipf.open(ZANJ_META, "r") as fp_meta:
                self._meta: JSONdict = json.load(fp_meta)
            with _zipf.open(ZANJ_MAIN, "r") as fp_main:
                self._json_data: JSONitem = json.load(fp_main)

            # read externals: each entry in `externals_info` names an archive member
            self._externals: dict[str, ExternalItem] = dict()
            for fname, ext_item in self._meta["externals_info"].items():  # type: ignore
                item_type: str = ext_item["item_type"]  # type: ignore
                with _zipf.open(fname, "r") as fp:
                    self._externals[fname] = ExternalItem(
                        item_type=item_type,  # type: ignore[arg-type]
                        data=GET_EXTERNAL_LOAD_FUNC(item_type)(self, fp),
                        path=ext_item["path"],  # type: ignore
                    )

    def populate_externals(self) -> None:
        """put all external items into the main json data

        each external's placeholder (found by following its path) must carry a
        `_REF_KEY` equal to the external's archive filename; its `"data"` entry
        is then set, in place, to the loaded external data
        """

        # loop over once, populating the externals only
        for ext_path, ext_item, item, path in _each_item_in_externals(
            self._externals, self._json_data
        ):
            # replace the item with the external item
            assert _REF_KEY in item  # type: ignore
            assert item[_REF_KEY] == ext_path  # type: ignore
            item["data"] = ext_item.data  # type: ignore

``````{ end_of_file="zanj/loading.py" }

``````{ path="zanj/py.typed"  }

``````{ end_of_file="zanj/py.typed" }

``````{ path="zanj/serializing.py"  }
from __future__ import annotations

import json
import sys
from dataclasses import dataclass
from typing import IO, Any, Callable, Iterable, Sequence
import warnings

import numpy as np
from muutils.json_serialize.array import arr_metadata
from muutils.json_serialize.json_serialize import (  # JsonSerializer,
    DEFAULT_HANDLERS,
    ObjectPath,
    SerializerHandler,
)

from zanj.consts import JSONdict, JSONitem, MonoTuple, _FORMAT_KEY, _REF_KEY

from zanj.externals import ExternalItem, ExternalItemType, _ZANJ_pre

# extra `dataclass` kwargs for `ZANJSerializerHandler`: keyword-only fields on
# python >= 3.10; older versions do not accept `kw_only`
KW_ONLY_KWARGS: dict = {"kw_only": True} if sys.version_info >= (3, 10) else {}

# pylint: disable=unused-argument, protected-access, unexpected-keyword-arg
# for some reason pylint complains about kwargs to ZANJSerializerHandler


def jsonl_metadata(data: list[JSONdict]) -> dict[str, Any]:
    """metadata about a jsonl object

    # Parameters:
     - `data : list[JSONdict]`
       the records (the caller only invokes this when every element is a dict)

    # Returns:
     - dict with:
       - `"len(data)"`: number of records
       - `"columns"`: for each key except `_FORMAT_KEY`, in first-seen order, the
         sorted names of the value types seen and how many records contain the key
       - `"data[0]"`: the first record, only when `data` is non-empty

    the output is deterministic: iterating a `set` of strings depends on the
    per-process hash seed, so columns keep first-appearance order and type names
    are sorted instead.
    """
    # dict.fromkeys deduplicates while preserving first-seen order
    all_cols: dict = dict.fromkeys(col for item in data for col in item.keys())
    output: dict[str, Any] = {
        "len(data)": len(data),
        "columns": {
            col: {
                "types": sorted(
                    {type(item[col]).__name__ for item in data if col in item}
                ),
                "len": sum(1 for item in data if col in item),
            }
            for col in all_cols
            if col != _FORMAT_KEY
        },
    }
    if len(data) > 0:
        output["data[0]"] = data[0]
    return output


def store_npy(self: _ZANJ_pre, fp: IO[bytes], data: np.ndarray) -> None:
    """write `data` to the open binary stream `fp` in `.npy` format

    `data` is passed through `np.asanyarray` first, and pickling is disabled.
    `self` is unused; it is part of the shared external-store signature.
    """
    array_to_write = np.asanyarray(data)
    # TODO: Type `<module 'numpy.lib'>` has no attribute `format` --> zanj/serializing.py:54:5
    # info: rule `unresolved-attribute` is enabled by default
    np.lib.format.write_array(  # ty: ignore[unresolved-attribute]
        fp,
        array_to_write,
        allow_pickle=False,
    )


def store_jsonl(self: _ZANJ_pre, fp: IO[bytes], data: Sequence[JSONitem]) -> None:
    """write each element of `data` to `fp` as one UTF-8 encoded JSON line

    `self` is unused; it is part of the shared external-store signature.
    """
    line_end: bytes = "\n".encode("utf-8")
    for record in data:
        encoded: bytes = json.dumps(record).encode("utf-8")
        fp.write(encoded)
        fp.write(line_end)


# external item type -> function that writes it into an archive member
EXTERNAL_STORE_FUNCS: dict[
    ExternalItemType, Callable[[_ZANJ_pre, IO[bytes], Any], None]
] = dict(npy=store_npy, jsonl=store_jsonl)


@dataclass(**KW_ONLY_KWARGS)
class ZANJSerializerHandler(SerializerHandler):
    """a handler for ZANJ serialization

    extends muutils' `SerializerHandler`; on python >= 3.10 its fields are
    keyword-only (via `KW_ONLY_KWARGS`). `uid` and `desc` are not redeclared here,
    yet the handlers in this module pass them as keyword arguments -- presumably
    they are inherited from `SerializerHandler` (NOTE(review): confirm).
    """

    # unique identifier for the handler, saved in _FORMAT_KEY field
    # uid: str
    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str
    # (self_config, object, path) -> whether to use this handler
    check: Callable[[_ZANJ_pre, Any, ObjectPath], bool]
    # (self_config, object, path) -> serialized object
    serialize_func: Callable[[_ZANJ_pre, Any, ObjectPath], JSONitem]
    # optional description of how this serializer works
    # desc: str = "(no description)"


def zanj_external_serialize(
    jser: _ZANJ_pre,
    data: Any,
    path: ObjectPath,
    item_type: ExternalItemType,
    _format: str,
) -> JSONitem:
    """stores a numpy array or jsonl externally in a ZANJ object

    # Parameters:
     - `jser: ZANJ`
       serializer whose `_externals` dict collects the external items
     - `data: Any`
       for `"npy"`: a numpy array (including ndarray subclasses such as `np.memmap`,
       but not masked arrays) or an exact `torch.Tensor`; for `"jsonl"`: a
       pandas/polars DataFrame or a list/tuple/iterable
     - `path: ObjectPath`
       location of `data` in the object being serialized; must be a tuple
     - `item_type: ExternalItemType`
       storage format, also used as the archive file extension
     - `_format: str`
       value stored under `_FORMAT_KEY` in the returned reference

    # Returns:
     - `JSONitem`
       json data with reference (`_FORMAT_KEY`, `_REF_KEY`) plus metadata: array
       metadata for `"npy"`, `jsonl_metadata` when every jsonl record is a dict,
       and `"columns"` for DataFrames

    # Raises:
     - `ValueError` (after emitting a warning) if the archive path already exists,
       or one external path is a directory-style ancestor of another
     - `TypeError` if `data` does not fit `item_type`

    # Modifies:
     - modifies `jser._externals`
    """
    # get the path, make sure its unique
    assert isinstance(path, tuple), (
        f"path must be a tuple, got {type(path) = } {path = }"
    )
    joined_path: str = "/".join([str(p) for p in path])
    archive_path: str = f"{joined_path}.{item_type}"

    # TODO: somehow need to control whether a failure here causes a fallback to other handlers, or whether the except should propagate
    # this will probably require changes to the upstream muutils.json_serialize code
    if archive_path in jser._externals:
        err_msg = f"external path {archive_path} already exists!"
        warnings.warn(err_msg)
        raise ValueError(err_msg)
    # Check for true path prefix conflicts (not just string prefix)
    # Only flag when one path is a directory ancestor of another (contains "/" separator)
    for p in jser._externals.keys():
        # Remove the file extension to get the joined_path
        existing_joined_path = p.rsplit(".", 1)[0]
        # Check if one is a true path prefix with "/" separator
        if existing_joined_path.startswith(joined_path + "/") or joined_path.startswith(
            existing_joined_path + "/"
        ):
            err_msg = (
                f"external path {joined_path} is a prefix of another path {p}!\n"
                + f"{jser._externals.keys() = }\n{joined_path = }\n{path = }\n{p = }\n{existing_joined_path = }\n{archive_path = }\n{_format = }"
            )
            warnings.warn(err_msg)
            raise ValueError(err_msg)

    # process the data if needed, assemble metadata
    data_new: Any = data
    output: dict = {
        _FORMAT_KEY: _format,
        _REF_KEY: archive_path,
    }
    if item_type == "npy":
        # check type
        data_type_str: str = str(type(data))
        if data_type_str == "<class 'torch.Tensor'>":
            # detach and convert
            data_new = data.detach().cpu().numpy()
        elif isinstance(data, np.ndarray) and not isinstance(data, np.ma.MaskedArray):
            # accept ndarray subclasses (e.g. `np.memmap`): the "numpy.ndarray:external"
            # handler routes anything passing `isinstance(obj, np.ndarray)` here. masked
            # arrays stay rejected since writing them to `.npy` would not keep the mask
            pass
        else:
            # if not a numpy array, except
            raise TypeError(f"expected numpy.ndarray, got {data_type_str}")
        # get metadata from the original `data`
        # NOTE(review): presumably intentional so a tensor keeps its torch dtype name -- confirm
        output.update(arr_metadata(data))
    elif item_type.startswith("jsonl"):
        # check via module and class name to avoid importing pandas (works with pandas 3.0+)
        dataframe_columns = None
        if (
            "pandas" in data.__class__.__module__
            and data.__class__.__name__ == "DataFrame"
        ):
            dataframe_columns = data.columns.tolist()
            data_new = data.to_dict(orient="records")
        elif (
            "polars" in data.__class__.__module__
            and data.__class__.__name__ == "DataFrame"
        ):
            dataframe_columns = data.columns
            data_new = data.to_dicts()
        elif isinstance(data, (list, tuple, Iterable, Sequence)):
            data_new = [
                jser.json_serialize(item, tuple(path) + (i,))
                for i, item in enumerate(data)
            ]
        else:
            raise TypeError(
                f"expected list or pandas.DataFrame for jsonl, got {type(data)}"
            )

        if all([isinstance(item, dict) for item in data_new]):
            output.update(jsonl_metadata(data_new))

        # set DataFrame columns after jsonl_metadata to avoid being overwritten
        if dataframe_columns is not None:
            output["columns"] = dataframe_columns

    # store the item for external serialization
    jser._externals[archive_path] = ExternalItem(
        item_type=item_type,
        data=data_new,
        path=path,
    )

    return output


# ZANJ-specific handlers that store large/tabular data as separate archive
# members, followed by muutils' `DEFAULT_HANDLERS`.
# NOTE(review): presumably the serializer tries handlers in tuple order, so the
# external-storage handlers take precedence over the defaults -- confirm
DEFAULT_SERIALIZER_HANDLERS_ZANJ: MonoTuple[ZANJSerializerHandler] = tuple(
    [
        # numpy arrays with at least `external_array_threshold` elements -> `.npy`
        ZANJSerializerHandler(
            check=lambda self, obj, path: (
                isinstance(obj, np.ndarray)
                and obj.size >= self.external_array_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="npy", _format="numpy.ndarray:external"
            ),
            uid="numpy.ndarray:external",
            source_pckg="zanj",
            desc="external numpy array",
        ),
        # torch tensors (exact type, matched by type string -- presumably so torch
        # need not be imported) with at least `external_array_threshold` elements -> `.npy`
        ZANJSerializerHandler(
            check=lambda self, obj, path: (
                str(type(obj)) == "<class 'torch.Tensor'>"
                and int(obj.nelement()) >= self.external_array_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="npy", _format="torch.Tensor:external"
            ),
            uid="torch.Tensor:external",
            source_pckg="zanj",
            desc="external torch tensor",
        ),
        # lists with at least `external_list_threshold` elements -> `.jsonl`
        ZANJSerializerHandler(
            check=lambda self, obj, path: (
                isinstance(obj, list) and len(obj) >= self.external_list_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="jsonl", _format="list:external"
            ),
            uid="list:external",
            source_pckg="zanj",
            desc="external list",
        ),
        # tuples with at least `external_list_threshold` elements -> `.jsonl`
        ZANJSerializerHandler(
            check=lambda self, obj, path: (
                isinstance(obj, tuple) and len(obj) >= self.external_list_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="jsonl", _format="tuple:external"
            ),
            uid="tuple:external",
            source_pckg="zanj",
            desc="external tuple",
        ),
        # pandas DataFrames of any size, detected by module/class name -> `.jsonl`
        ZANJSerializerHandler(
            check=lambda self, obj, path: (
                "pandas" in obj.__class__.__module__
                and obj.__class__.__name__ == "DataFrame"
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="jsonl", _format="pandas.DataFrame:external"
            ),
            uid="pandas.DataFrame:external",
            source_pckg="zanj",
            desc="external pandas DataFrame",
        ),
        # polars DataFrames of any size, detected by module/class name -> `.jsonl`
        ZANJSerializerHandler(
            check=lambda self, obj, path: (
                "polars" in obj.__class__.__module__
                and obj.__class__.__name__ == "DataFrame"
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="jsonl", _format="polars.DataFrame:external"
            ),
            uid="polars.DataFrame:external",
            source_pckg="zanj",
            desc="external polars DataFrame",
        ),
        # ZANJSerializerHandler(
        #     check=lambda self, obj, path: "<class 'torch.nn.modules.module.Module'>"
        #     in [str(t) for t in obj.__class__.__mro__],
        #     serialize_func=lambda self, obj, path: zanj_serialize_torchmodule(
        #         self, obj, path,
        #     ),
        #     uid="torch.nn.Module",
        #     source_pckg="zanj",
        #     desc="fallback torch serialization",
        # ),
    ]
) + tuple(
    DEFAULT_HANDLERS  # type: ignore[arg-type]
)

# the complaint above is:
# error: Argument 1 to "tuple" has incompatible type "Sequence[SerializerHandler]"; expected "Iterable[ZANJSerializerHandler]"  [arg-type]

``````{ end_of_file="zanj/serializing.py" }

``````{ path="zanj/torchutil.py"  }
"""torch utilities for zanj -- in particular the `ConfiguredModel` base class

note that this requires torch
"""

from __future__ import annotations

import abc
import typing
import warnings
from typing import Any, Type, TypeVar

try:
    import torch  # type: ignore[import-not-found]
except ImportError as e:
    raise ImportError(
        "torch is required for zanj.torchutil, please install it with `pip install torch` or `pip install zanj[torch]`"
    ) from e

from muutils.json_serialize import SerializableDataclass
from muutils.json_serialize.json_serialize import ObjectPath

from zanj.consts import _FORMAT_KEY, safe_getsource, string_as_lines

from zanj import ZANJ, register_loader_handler
from zanj.loading import LoaderHandler, load_item_recursive

# pylint: disable=protected-access

# alias used to annotate keyword-argument values; it is plain `Any`, so it documents intent but adds no type checking
KWArgs = Any


def num_params(m: torch.nn.Module, only_trainable: bool = True):
    """count the parameters of a model, counting shared parameters only once

    # Parameters:
     - `m : torch.nn.Module`
        module whose parameters are counted
     - `only_trainable : bool`
        if `True`, skip parameters with `requires_grad = False`
        (defaults to `True`)

    # Returns:
     - `int` total number of elements across the distinct parameters

    see https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
    """
    # key by data pointer so that parameters sharing storage are counted once;
    # a later parameter with the same pointer replaces the earlier one
    by_ptr: dict[int, torch.nn.Parameter] = dict()
    for param in m.parameters():
        if only_trainable and not param.requires_grad:
            continue
        by_ptr[param.data_ptr()] = param

    return sum(param.numel() for param in by_ptr.values())


def get_module_device(
    m: torch.nn.Module,
) -> tuple[bool, torch.device | dict[str, torch.device]]:
    """get the current devices"""

    devs: dict[str, torch.device] = {name: p.device for name, p in m.named_parameters()}

    if len(devs) == 0:
        return False, devs

    # check if all devices are the same by getting one device
    dev_uni: torch.device = next(iter(devs.values()))

    if all(dev == dev_uni for dev in devs.values()):
        return True, dev_uni
    else:
        return False, devs


# type of the config held by a `ConfiguredModel`; bound to `SerializableDataclass`
T_config = typing.TypeVar("T_config", bound=SerializableDataclass)


class ConfiguredModel(
    torch.nn.Module,
    typing.Generic[T_config],
    metaclass=abc.ABCMeta,
):
    """a model that has a configuration, for saving with ZANJ

    ```python
    @set_config_class(YourConfig)
    class YourModule(ConfiguredModel[YourConfig]):
        def __init__(self, cfg: YourConfig):
            super().__init__(cfg)
    ```

    `__init__()` must initialize the model from a config object only, and call
    `super().__init__(zanj_model_config)`

    If you are inheriting from another class + ConfiguredModel,
    ConfiguredModel must be the first class in the inheritance list
    """

    # dont set this directly, use `set_config_class()` decorator
    _config_class: type | None = None
    # config class registered on the concrete class of this instance
    zanj_config_class = property(lambda self: type(self)._config_class)

    def __init__(self, zanj_model_config: T_config, **kwargs):
        """store the config after checking it matches the registered config class

        # Parameters:
         - `zanj_model_config : T_config`
            config instance; must be an instance of the class set via `set_config_class()`
         - `**kwargs`
            forwarded to the next `__init__` in the MRO

        # Raises:
         - `NotImplementedError` : if no config class was set for this model class
         - `TypeError` : if `zanj_model_config` is not an instance of the config class
        """
        super().__init__(**kwargs)
        if self.zanj_config_class is None:
            raise NotImplementedError("you need to set `config_class` for your model")
        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
            raise TypeError(
                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
            )

        self.zanj_model_config: T_config = zanj_model_config
        self.training_records: dict | None = None

    def serialize(
        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
    ) -> dict[str, Any]:
        """serialize the model to a dict

        # Parameters:
         - `path : ObjectPath`
            currently unused
         - `zanj : ZANJ | None`
            currently unused

        # Returns:
         - `dict[str, Any]` containing the serialized config, metadata about the class,
           the training records, and the raw `state_dict()` (its tensors are left for the
           caller, e.g. `ZANJ.save`, to serialize). `_FORMAT_KEY` is set to the class name,
           which is what the loader from `get_handler()` matches on
        """
        obj: dict[str, Any] = dict(
            zanj_model_config=self.zanj_model_config.serialize(),
            meta=dict(
                class_name=self.__class__.__name__,
                class_doc=string_as_lines(self.__class__.__doc__),
                class_source=safe_getsource(self.__class__),
                module_name=self.__class__.__module__,
                module_mro=[str(x) for x in self.__class__.__mro__],
                num_params=num_params(self),
                as_str=string_as_lines(str(self)),
            ),
            training_records=self.training_records,
            state_dict=self.state_dict(),
        )
        # use the shared constant so this always agrees with the check in `get_handler()`
        obj[_FORMAT_KEY] = self.__class__.__name__
        return obj

    def save(self, file_path: str, zanj: ZANJ | None = None):
        """save the model to `file_path`, returning the path that was actually written

        `ZANJ.save` appends `.zanj` if the path does not already end with it, so the
        returned path may differ from `file_path`.
        """
        if zanj is None:
            zanj = ZANJ()
        return zanj.save(self.serialize(), file_path)

    def _load_state_dict_wrapper(
        self,
        state_dict: dict[str, torch.Tensor],
        **kwargs,
    ):
        """wrapper for `load_state_dict()` in case you need to override it

        `load()` passes `zanj.custom_settings["_load_state_dict_wrapper"]` as `kwargs`;
        this default implementation accepts none.
        """
        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
        return self.load_state_dict(state_dict)

    @classmethod
    def load(
        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
    ) -> "ConfiguredModel":
        """load a model from a serialized object

        # Parameters:
         - `obj : dict[str, Any]`
            serialized model, in the format produced by `serialize()`
         - `path : ObjectPath`
            path of `obj` within the file; `None` is treated as the root path
         - `zanj : ZANJ | None`
            used to load nested items; its `custom_settings["_load_state_dict_wrapper"]`
            (if present) is passed as kwargs to `_load_state_dict_wrapper()`.
            defaults to a fresh `ZANJ()`

        # Returns:
         - `ConfiguredModel` instance of `cls` with the state dict and training records restored
        """

        if zanj is None:
            zanj = ZANJ()

        # the loader built by `get_handler()` defaults `path` to None
        base_path: tuple = tuple(path) if path is not None else tuple()

        # get the config
        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore

        # get the training records
        training_records: typing.Any = load_item_recursive(
            obj.get("training_records", None),
            base_path + ("training_records",),
            zanj,
        )

        # initialize the model
        model: "ConfiguredModel" = cls(zanj_model_config)

        # load the state dict
        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
            obj["state_dict"],
            base_path + ("state_dict",),
            zanj,
        )

        model._load_state_dict_wrapper(
            tensored_state_dict,
            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
        )

        # set the training records
        model.training_records = training_records

        return model

    @classmethod
    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file, asserting the loaded object is an instance of `cls`"""
        if zanj is None:
            zanj = ZANJ()

        mdl: ConfiguredModel = zanj.read(file_path)
        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
        return mdl

    @classmethod
    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file (deprecated alias for `read()`)"""
        warnings.warn(
            "load_file() is deprecated, use read() instead", DeprecationWarning
        )
        return cls.read(file_path, zanj)

    @classmethod
    def get_handler(cls) -> LoaderHandler:
        """build the `LoaderHandler` that loads serialized instances of this class

        a dict matches if its `_FORMAT_KEY` value is a string equal to the class name,
        or the class name followed by a non-identifier character. a plain prefix match
        would also claim formats of unrelated classes whose names merely start with this
        one (e.g. `MyNNConfig(SerializableDataclass)` for a model class `MyNN`).
        """
        cls_name: str = str(cls.__name__)

        def _check(json_item, path=None, z=None) -> bool:
            if not (isinstance(json_item, dict) and _FORMAT_KEY in json_item):
                return False
            fmt = json_item[_FORMAT_KEY]
            if not (isinstance(fmt, str) and fmt.startswith(cls_name)):
                return False
            # the class name must end here, so `MyNNLarge` does not match `MyNN`
            rest: str = fmt[len(cls_name) :]
            return rest == "" or not (rest[0].isalnum() or rest[0] == "_")

        return LoaderHandler(
            check=_check,  # type: ignore
            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
            uid=cls_name,
            source_pckg=cls.__module__,
            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
        )

    def num_params(self) -> int:
        """total number of trainable parameters, counting shared parameters once"""
        return num_params(self)


def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator factory attaching `config_class` to a `ConfiguredModel` subclass

    the returned decorator sets `_config_class` on the decorated class, registers the
    loader handler built by its `get_handler()`, and returns the same class.

    # Raises:
     - `TypeError` : if `config_class` is not a subclass of `SerializableDataclass`
    """
    if issubclass(config_class, SerializableDataclass):

        def decorate(model_cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
            model_cls._config_class = config_class
            register_loader_handler(model_cls.get_handler())
            return model_cls

        return decorate

    raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")


class ConfigMismatchException(ValueError):
    """raised when two model configs differ; the difference is kept on `diff`"""

    def __init__(self, msg: str, diff):
        super().__init__(msg)
        # whatever describes the mismatch, appended to the message by `__str__`
        self.diff = diff

    def __str__(self):
        return "{}: {}".format(super().__str__(), self.diff)


def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check both models are `ConfiguredModel`s with `SerializableDataclass` configs, and that the configs are equal

    Raises:
        AssertionError: if either model, or its `zanj_model_config`, has the wrong type
        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
    """
    # type checks, in the order: model_a, its config, model_b, its config
    for model, label in ((model_a, "model_a"), (model_b, "model_b")):
        assert isinstance(model, ConfiguredModel), f"{label} must be a ConfiguredModel"
        assert isinstance(model.zanj_model_config, SerializableDataclass), (
            f"{label} must have a zanj_model_config"
        )

    cfg_a = model_a.zanj_model_config
    cfg_b = model_b.zanj_model_config

    if not (cfg_a == cfg_b):
        cfg_type: type = type(cfg_a)
        raise ConfigMismatchException(
            f"configs of type {type(cfg_a)}, {type(cfg_b)} don't match",
            diff=cfg_type.diff(cfg_a, cfg_b),  # type: ignore[attr-defined]
        )


def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    checks the configs via `assert_model_cfg_equality`, then that both state dicts
    have the same keys, then that every pair of entries has the same shape and
    elementwise-equal values. prints a `passed`/`failed` line for each key.

    Raises:
        ConfigMismatchException: if the configs don't match
        AssertionError: if the state dict keys differ, or any entries differ
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once, instead of rebuilding model_b's for every key
    sd_a = model_a.state_dict()
    sd_b = model_b.state_dict()

    model_a_sd_keys: set[str] = set(sd_a.keys())
    model_b_sd_keys: set[str] = set(sd_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )
    keys_failed: list[str] = list()
    for k, v_a in sd_a.items():
        v_b = sd_b[k]
        # `==` broadcasts, so without the shape check e.g. (3,) vs (1,) could pass
        if v_a.shape != v_b.shape or not (v_a == v_b).all():
            # if not torch.allclose(v, v_load):
            keys_failed.append(k)
            print(f"failed {k}")
        else:
            print(f"passed {k}")
    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )

``````{ end_of_file="zanj/torchutil.py" }

``````{ path="zanj/zanj.py"  }
"""
an HDF5/exdir file alternative, which uses json for attributes and allows serialization of arbitrary data

the output is a `.zanj` file (a zip archive) with most data in a json file, but with sufficiently large arrays stored in binary .npy files


"ZANJ" is an acronym that the AI tool [Elicit](https://elicit.org) came up with for me. not to be confused with:

- https://en.wikipedia.org/wiki/Zanj
- https://www.plutojournals.com/zanj/

"""

from __future__ import annotations

import json
import os
import time
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Union

import numpy as np
from muutils.errormode import ErrorMode
from muutils.json_serialize.array import ArrayMode, arr_metadata
from muutils.json_serialize.json_serialize import (
    JsonSerializer,
    SerializerHandler,
    json_serialize,
)
from muutils.sysinfo import SysInfo

from zanj.consts import JSONitem, MonoTuple

from zanj.externals import ZANJ_MAIN, ZANJ_META, ExternalItem
import zanj.externals
from zanj.loading import LOADER_MAP, LoadedZANJ, load_item_recursive
from zanj.serializing import (
    DEFAULT_SERIALIZER_HANDLERS_ZANJ,
    EXTERNAL_STORE_FUNCS,
    KW_ONLY_KWARGS,
)

# pylint: disable=protected-access, unused-import, dangerous-default-value, line-too-long

# union of item types handled by ZANJ: json data, numpy arrays, and pandas DataFrames
# ("pd.DataFrame" is a string forward reference, since pandas is not imported in this module)
ZANJitem = Union[JSONitem, np.ndarray, "pd.DataFrame"]  # type: ignore # noqa: F821


@dataclass(**KW_ONLY_KWARGS)
class _ZANJ_GLOBAL_DEFAULTS_CLASS:
    """default configuration values for `ZANJ`

    `ZANJ.__init__` reads its default arguments from the module-level
    `ZANJ_GLOBAL_DEFAULTS` instance when that method is defined, so modifying the
    instance's attributes afterwards does not change the defaults of `ZANJ()`.
    """

    # error handling mode, passed to `JsonSerializer` and used by `ZANJ.read`
    error_mode: ErrorMode = ErrorMode.EXCEPT
    # passed to `JsonSerializer` as `array_mode`, for arrays kept inside the json
    internal_array_mode: ArrayMode = "array_list_meta"
    # size thresholds for moving arrays / lists out of the main json into separate files in the archive
    external_array_threshold: int = 256
    external_list_threshold: int = 256
    # `True`/`False` become `zipfile.ZIP_DEFLATED`/`zipfile.ZIP_STORED` in `ZANJ.__init__`; an int is used as-is
    compress: bool | int = True
    # free-form settings for custom serialization/loading code
    # (e.g. `ConfiguredModel.load` reads the "_load_state_dict_wrapper" key); `None` becomes `{}`
    custom_settings: dict[str, Any] | None = None


# module-level defaults instance, read by `ZANJ.__init__` for its default arguments
ZANJ_GLOBAL_DEFAULTS: _ZANJ_GLOBAL_DEFAULTS_CLASS = _ZANJ_GLOBAL_DEFAULTS_CLASS()


class ZANJ(JsonSerializer):
    """Zip up: Arrays in Numpy, JSON for everything else

    given an arbitrary object, throw into a zip file, with arrays stored in .npy files, and everything else stored in a json file

    (basically npz file with json)

    - numpy (or pytorch) arrays are stored in paths according to their name and structure in the object
    - everything else about the object is stored in the main json file (`ZANJ_MAIN`) in the root of the archive, via `muutils.json_serialize.JsonSerializer`
    - metadata about the ZANJ configuration, system info, and the external files is stored in a separate json file (`ZANJ_META`) in the root of the archive

    save an object via `ZANJ().save(obj, path)` and read it back via `ZANJ().read(path)`. be sure to pass an **instance** of the object, to make sure that the attributes of the class can be correctly recognized

    """

    def __init__(
        self,
        error_mode: ErrorMode = ZANJ_GLOBAL_DEFAULTS.error_mode,
        internal_array_mode: ArrayMode = ZANJ_GLOBAL_DEFAULTS.internal_array_mode,
        external_array_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_array_threshold,
        external_list_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_list_threshold,
        compress: bool | int = ZANJ_GLOBAL_DEFAULTS.compress,
        custom_settings: dict[str, Any] | None = ZANJ_GLOBAL_DEFAULTS.custom_settings,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[
            SerializerHandler
        ] = DEFAULT_SERIALIZER_HANDLERS_ZANJ,
    ) -> None:
        """create a ZANJ serializer / loader

        defaults are taken from `ZANJ_GLOBAL_DEFAULTS` when this method is defined

        # Parameters:
         - `error_mode : ErrorMode`
            passed to `JsonSerializer`, and used as the error mode in `read()`
         - `internal_array_mode : ArrayMode`
            passed to `JsonSerializer` as `array_mode`
         - `external_array_threshold : int`, `external_list_threshold : int`
            size thresholds for moving arrays / lists into external files in the archive
         - `compress : bool | int`
            `True`/`False` become `zipfile.ZIP_DEFLATED`/`zipfile.ZIP_STORED`;
            an int is used as the `zipfile` compression method directly
         - `custom_settings : dict[str, Any] | None`
            free-form settings for custom serialization/loading code; `None` becomes `{}`
         - `handlers_pre`, `handlers_default`
            serializer handlers, passed to `JsonSerializer`
        """
        super().__init__(
            array_mode=internal_array_mode,
            error_mode=error_mode,
            handlers_pre=handlers_pre,
            handlers_default=handlers_default,
        )

        self.external_array_threshold: int = external_array_threshold
        self.external_list_threshold: int = external_list_threshold
        self.custom_settings: dict = (
            custom_settings if custom_settings is not None else dict()
        )

        # process compression to int if bool given
        self.compress = compress
        if isinstance(compress, bool):
            if compress:
                self.compress = zipfile.ZIP_DEFLATED
            else:
                self.compress = zipfile.ZIP_STORED

        # create the externals, leave it empty
        # (filled while saving, keyed by the file name inside the archive)
        self._externals: dict[str, ExternalItem] = dict()

    def externals_info(self) -> dict[str, dict[str, str | int | list[int]]]:
        """return information about the current externals

        keyed by archive file name, ordered by `len(item.path)` (shortest first).
        ndarray externals also get `arr_metadata`; non-empty jsonl externals get their first element.
        """
        output: dict[str, dict] = dict()

        key: str
        item: ExternalItem
        for key, item in self._externals.items():
            data = item.data
            output[key] = {
                "item_type": item.item_type,
                "path": item.path,
                "type(data)": str(type(data)),
                "len(data)": len(data),
            }

            if item.item_type == "ndarray":
                output[key].update(arr_metadata(data))
            elif item.item_type.startswith("jsonl") and len(data) > 0:
                output[key]["data[0]"] = data[0]

        return {
            key: val
            for key, val in sorted(output.items(), key=lambda x: len(x[1]["path"]))
        }

    def meta(self) -> JSONitem:
        """return the metadata of the ZANJ archive

        includes this instance's configuration (with all serialization and loading
        handlers), system info, info about the externals, and a timestamp
        """

        serialization_handlers = {h.uid: h.serialize() for h in self.handlers}
        load_handlers = {h.uid: h.serialize() for h in LOADER_MAP.values()}

        return dict(
            # configuration of this ZANJ instance
            zanj_cfg=dict(
                error_mode=str(self.error_mode),
                array_mode=str(self.array_mode),
                external_array_threshold=self.external_array_threshold,
                external_list_threshold=self.external_list_threshold,
                compress=self.compress,
                serialization_handlers=serialization_handlers,
                load_handlers=load_handlers,
            ),
            # system info -- only the python and pytorch sections are requested
            sysinfo=json_serialize(SysInfo.get_all(include=("python", "pytorch"))),
            externals_info=self.externals_info(),
            timestamp=time.time(),
        )

    def save(self, obj: Any, file_path: str | Path) -> str:
        """save the object to a ZANJ archive. returns the path to the archive

        `.zanj` is appended to `file_path` if missing, and the parent directory is
        created if needed. the zip file is closed, and `self._externals` is cleared,
        even if serialization or writing fails partway through.
        """

        # adjust extension
        file_path = str(file_path)
        if not file_path.endswith(".zanj"):
            file_path += ".zanj"

        # make directory (`exist_ok=True` guards against it being created concurrently)
        dir_path: str = os.path.dirname(file_path)
        if dir_path != "":
            if not os.path.exists(dir_path):
                os.makedirs(dir_path, exist_ok=True)

        # clear the externals!
        self._externals = dict()

        try:
            # serialize the object -- this will populate self._externals
            # TODO: calling self.json_serialize again here might be slow
            json_data: JSONitem = self.json_serialize(self.json_serialize(obj))

            # context manager closes the archive even if writing something fails
            with zipfile.ZipFile(
                file=file_path, mode="w", compression=self.compress
            ) as zipf:
                # store base json data and metadata
                # (metadata is built after serializing `obj`, so it includes the externals info)
                zipf.writestr(
                    ZANJ_META,
                    json.dumps(
                        self.json_serialize(self.meta()),
                        indent="\t",
                    ),
                )
                zipf.writestr(
                    ZANJ_MAIN,
                    json.dumps(
                        json_data,
                        indent="\t",
                    ),
                )

                # store externals
                for key, (ext_type, ext_data, ext_path) in self._externals.items():
                    # why force zip64? numpy.savez does it
                    with zipf.open(key, "w", force_zip64=True) as fp:
                        EXTERNAL_STORE_FUNCS[ext_type](self, fp, ext_data)
        finally:
            # clear the externals, again -- also on failure, so nothing stale leaks into the next save
            self._externals = dict()

        return file_path

    def read(
        self,
        file_path: Union[str, Path],
    ) -> Any:
        """load the object from a ZANJ archive

        raises `FileNotFoundError` if `file_path` does not exist or is not a file

        # TODO: load only some part of the zanj file by passing an ObjectPath
        """
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"file not found: {file_path}")
        if not file_path.is_file():
            raise FileNotFoundError(f"not a file: {file_path}")

        loaded_zanj: LoadedZANJ = LoadedZANJ(
            path=file_path,
            zanj=self,
        )

        loaded_zanj.populate_externals()

        return load_item_recursive(
            loaded_zanj._json_data,
            path=tuple(),
            zanj=self,
            error_mode=self.error_mode,
            # lh_map=loader_handlers,
        )


# expose the ZANJ class to `zanj.externals` once it is defined
# NOTE(review): presumably done here rather than by import to avoid a circular import -- confirm
setattr(zanj.externals, "_ZANJ_pre", ZANJ)

``````{ end_of_file="zanj/zanj.py" }

``````{ path="README.md"  }
[![PyPI](https://img.shields.io/pypi/v/zanj)](https://pypi.org/project/zanj/)
[![Checks](https://github.com/mivanit/zanj/actions/workflows/checks.yml/badge.svg)](https://github.com/mivanit/zanj/actions/workflows/checks.yml)
[![Coverage](docs/coverage/coverage.svg)](docs/coverage/coverage.txt)
![code size, bytes](https://img.shields.io/github/languages/code-size/mivanit/zanj)
![PyPI - Downloads](https://img.shields.io/pypi/dm/zanj)
[![DOI](https://zenodo.org/badge/618623453.svg)](https://doi.org/10.5281/zenodo.15540392)


<!-- ![GitHub commit activity](https://img.shields.io/github/commit-activity/t/mivanit/zanj)
![GitHub closed pull requests](https://img.shields.io/github/issues-pr-closed/mivanit/zanj) -->
<!-- ![Lines of code](https://img.shields.io/tokei/lines/github.com/mivanit/zanj) -->

# ZANJ

# Overview

The `ZANJ` format is meant to be a way of saving arbitrary objects to disk, in a way that is flexible, allows keeping configuration and data together, and is human readable. It is very loosely inspired by HDF5 and the derived `exdir` format, and the implementation is inspired by `npz` files.

- You can take any `SerializableDataclass` from the [muutils](https://github.com/mivanit/muutils) library and save it to disk -- any large arrays or lists will be stored efficiently as external files in the zip archive, while the basic structure and metadata will be stored in readable JSON files. 
- You can also define a `ConfiguredModel`, a subclass of `torch.nn.Module` that lets you save not just your model weights, but also all required configuration information, plus any other metadata (like training logs), in a single file.

This library was originally a module in [muutils](https://github.com/mivanit/muutils/).


# Installation
Available on PyPI as [`zanj`](https://pypi.org/project/zanj/)

```
pip install zanj
```

# Usage

You can find a runnable example of this in [`demo.ipynb`](demo.ipynb)

## Saving a basic object

Any `SerializableDataclass` of basic types can be saved as zanj:

```python
import numpy as np
import pandas as pd
from muutils.json_serialize import SerializableDataclass, serializable_dataclass, serializable_field
from zanj import ZANJ

@serializable_dataclass
class BasicZanj(SerializableDataclass):
    a: str
    q: int = 42
    c: list[int] = serializable_field(default_factory=list)

# initialize a zanj reader/writer
zj = ZANJ()

# create an instance
instance: BasicZanj = BasicZanj("hello", 42, [1, 2, 3])
path: str = "tests/junk_data/path_to_save_instance.zanj"
zj.save(instance, path)
recovered: BasicZanj = zj.read(path)
```

ZANJ will intelligently handle nested serializable dataclasses, numpy arrays, pytorch tensors, and pandas dataframes: 

```python
import torch
import pandas as pd

@serializable_dataclass
class Complicated(SerializableDataclass):
    name: str
    arr1: np.ndarray
    arr2: np.ndarray
    iris_data: pd.DataFrame
    brain_data: pd.DataFrame
    container: list[BasicZanj]
    torch_tensor: torch.Tensor
```

For custom classes, you can specify a `serialization_fn` and `loading_fn` to handle the logic of converting to and from a json-serializable format:

```python
@serializable_dataclass
class Complicated(SerializableDataclass):
    name: str
    device: torch.device = serializable_field(
        serialization_fn=lambda x: str(x),
        loading_fn=lambda data: torch.device(data["device"]),
    )
```

Note that `serialization_fn` receives only the value of the field, while `loading_fn` takes the dictionary of the whole class -- this is in case you've stored data in multiple fields of the dict which are needed to reconstruct the object.

## Saving Models

First, define a configuration class for your model. This class will hold the parameters for your model and any associated objects (like losses and optimizers). The configuration class should be a subclass of `SerializableDataclass` and use the `serializable_field` function to define fields that need special serialization.

Here's an example that defines a configuration for a simple neural network:

```python
from zanj.torchutil import ConfiguredModel, set_config_class

@serializable_dataclass
class MyNNConfig(SerializableDataclass):
    input_dim: int
    hidden_dim: int
    output_dim: int

    # store the activation function by name, reconstruct it by looking it up in torch.nn
    act_fn: torch.nn.Module = serializable_field(
        serialization_fn=lambda x: x.__name__,
        loading_fn=lambda x: getattr(torch.nn, x["act_fn"]),
    )

    # same for the loss function
    loss_kwargs: dict = serializable_field(default_factory=dict)
    loss_factory: torch.nn.modules.loss._Loss = serializable_field(
        default_factory=lambda: torch.nn.CrossEntropyLoss,
        serialization_fn=lambda x: x.__name__,
        loading_fn=lambda x: getattr(torch.nn, x["loss_factory"]),
    )
    loss = property(lambda self: self.loss_factory(**self.loss_kwargs))
```

Then, define your model class. It should be a subclass of `ConfiguredModel`, and use the `set_config_class` decorator to associate it with your configuration class. The `__init__` method should take a single argument, which is an instance of your configuration class. You must also call the superclass `__init__` method with the configuration instance.

```python
@set_config_class(MyNNConfig)
class MyNN(ConfiguredModel[MyNNConfig]):
    def __init__(self, config: MyNNConfig):
        # call the superclass init!
        # this will store the config in the zanj_model_config field
        super().__init__(config)

        # whatever you want here
        self.net = torch.nn.Sequential(
            torch.nn.Linear(config.input_dim, config.hidden_dim),
            config.act_fn(),
            torch.nn.Linear(config.hidden_dim, config.output_dim),
        )

    def forward(self, x):
        return self.net(x)
```

You can now create instances of your model, save them to disk, and load them back into memory:

```python
config = MyNNConfig(
    input_dim=10,
    hidden_dim=20,
    output_dim=2,
    act_fn=torch.nn.ReLU,
    loss_kwargs=dict(reduction="mean"),
)

# create your model from the config, and save
model = MyNN(config)
fname = "tests/junk_data/path_to_save_model.zanj"
ZANJ().save(model, fname)
# load by calling the class method `read()`
loaded_model = MyNN.read(fname)
# zanj will actually infer the type of the object in the file 
# -- and will warn you if you don't have the correct package installed
loaded_another_way = ZANJ().read(fname)
```

## Configuration

When initializing a `ZANJ` object, you can specify some configuration info about saving, such as:

- thresholds for how big an array/table has to be before moving to external file
- compression settings
- error modes
- additional handlers for serialization

```python
# how errors are handled, and how arrays are represented when kept inside the core JSON file
error_mode: ErrorMode = ZANJ_GLOBAL_DEFAULTS.error_mode
internal_array_mode: ArrayMode = ZANJ_GLOBAL_DEFAULTS.internal_array_mode
# how big an array or list (including pandas DataFrame) can be before moving it from the core JSON file
external_array_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_array_threshold
external_list_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_list_threshold
# compression settings passed to `zipfile` package
compress: bool | int = ZANJ_GLOBAL_DEFAULTS.compress
# for doing very cursed things in your own custom loading or serialization functions
custom_settings: dict[str, Any] | None = ZANJ_GLOBAL_DEFAULTS.custom_settings
# specify additional serialization handlers
handlers_pre: MonoTuple[SerializerHandler] = tuple()
handlers_default: MonoTuple[SerializerHandler] = DEFAULT_SERIALIZER_HANDLERS_ZANJ
```

# Implementation

The on-disk format is a file `<filename>.zanj`, which is a zip file containing:

- `__zanj_meta__.json`: a file containing zanj-specific metadata including:
	- system information (python and pytorch)
	- the ZANJ configuration used to save the file, including serialization and loading handlers
	- information about external files
- `__zanj__.json`: a file containing user-specified data
	- when an element is too big, it can be moved to an external file
		- `.npy` for numpy arrays or torch tensors
		- `.jsonl` for pandas dataframes or large sequences
	- list of external files stored in `__zanj_meta__.json`
	- "$ref" key, specified in `_REF_KEY` in muutils, will have value pointing to external file
	- `_FORMAT_KEY` key will detail an external format type


# Comparison to other formats



| Format                  | Safe | Zero-copy | Lazy loading | No file size limit | Layout control | Flexibility | Bfloat16 |
| ----------------------- | ---- | --------- | ------------ | ------------------ | -------------- | ----------- | -------- |
| pickle (PyTorch)        | ❌   | ❌        | ❌           | ✅                 | ❌             | ✅          | ✅       |
| H5 (Tensorflow)         | ✅   | ❌        | ✅           | ✅                 | ~              | ~           | ❌       |
| HDF5                    | ✅   | ?         | ✅           | ✅                 | ~              | ✅          | ❌       |
| SavedModel (Tensorflow) | ✅   | ❌        | ❌           | ✅                 | ✅             | ❌          | ✅       |
| MsgPack (flax)          | ✅   | ✅        | ❌           | ✅                 | ❌             | ❌          | ✅       |
| Protobuf (ONNX)         | ✅   | ❌        | ❌           | ❌                 | ❌             | ❌          | ✅       |
| Cap'n'Proto             | ✅   | ✅        | ~            | ✅                 | ✅             | ~           | ❌       |
| Numpy (npy,npz)         | ✅   | ?         | ?            | ❌                 | ✅             | ❌          | ❌       |
| SafeTensors             | ✅   | ✅        | ✅           | ✅                 | ✅             | ❌          | ✅       |
| exdir                   | ✅   | ?         | ?            | ?                  | ?              | ✅          | ❌       |
| ZANJ                    | ✅   | ❌        | ❌*          | ✅                 | ✅             | ✅          | ❌*      |


- Safe: Can I use a randomly downloaded file and expect it not to run arbitrary code?
- Zero-copy: Does reading the file require more memory than the original file?
- Lazy loading: Can I inspect the file without loading everything? And load only some tensors in it without scanning the whole file (distributed setting)?
- Layout control: Lazy loading is not necessarily enough, since if the information about tensors is spread out in your file, then even if the information is lazily accessible you might have to access most of your file to read the available tensors (incurring many DISK -> RAM copies). Controlling the layout to keep fast access to single tensors is important.
- No file size limit: Is there a limit to the file size?
- Flexibility: Can I save custom code in the format and be able to use it later with zero extra code? (~ means we can store more than pure tensors, but no custom code)
- Bfloat16: Does the format support native bfloat16 (meaning no weird workarounds are necessary)? This is becoming increasingly important in the ML world.

`*` denotes this feature may be coming at a future date :)

(This table was stolen from [safetensors](https://github.com/huggingface/safetensors/blob/main/README.md))

``````{ end_of_file="README.md" }

``````{ path="demo.ipynb" processed_with="ipynb_to_md" }
# Installation
Available on PyPI as [`zanj`](https://pypi.org/project/zanj/)

```
pip install zanj
```

```python
import os

import numpy as np
import pandas as pd
import torch

from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from zanj import ZANJ
```


# Usage

## Saving a basic object

Any `SerializableDataclass` of basic types can be saved as zanj:

```python
@serializable_dataclass
class BasicZanj(SerializableDataclass):
    a: str
    q: int = 42
    c: list[int] = serializable_field(default_factory=list)


# initialize a zanj reader/writer
zj = ZANJ()

# create an instance
instance: BasicZanj = BasicZanj("hello", 42, [1, 2, 3])
path: str = "tests/junk_data/path_to_save_instance.zanj"
zj.save(instance, path)
recovered: BasicZanj = zj.read(path)
```

```python
print(f"{type(recovered) = }")  # BasicZanj
print(f"{os.path.getsize(path) = }")
```

ZANJ will intelligently handle nested serializable dataclasses, numpy arrays, pytorch tensors, and pandas dataframes: 

```python
@serializable_dataclass
class Complicated(SerializableDataclass):
    name: str
    arr1: np.ndarray
    arr2: np.ndarray
    iris_data: pd.DataFrame
    brain_data: pd.DataFrame
    container: list[BasicZanj]
    torch_tensor: torch.Tensor
```

For custom classes, you can specify a `serialization_fn` and `loading_fn` to handle the logic of converting to and from a json-serializable format:

```python
@serializable_dataclass
class Complicated2(SerializableDataclass):
    name: str
    device: torch.device = serializable_field(
        serialization_fn=lambda x: str(x),
        loading_fn=lambda data: torch.device(data["device"]),
    )
```

Note that `serialization_fn` receives only the value of the field, while `loading_fn` takes the dictionary of the whole class -- this is in case you've stored data in multiple fields of the dict which are needed to reconstruct the object.

## Saving Models

First, define a configuration class for your model. This class will hold the parameters for your model and any associated objects (like losses and optimizers). The configuration class should be a subclass of `SerializableDataclass` and use the `serializable_field` function to define fields that need special serialization.

Here's an example that defines a configuration for a simple neural network:

```python
from zanj.torchutil import ConfiguredModel, set_config_class


@serializable_dataclass
class MyNNConfig(SerializableDataclass):
    # layer sizes: input -> hidden -> output
    input_dim: int
    hidden_dim: int
    output_dim: int

    # store the activation function by name, reconstruct it by looking it up in torch.nn
    # NOTE(review): this field holds a class (e.g. `torch.nn.ReLU`), not a Module
    # instance -- the model instantiates it with `config.act_fn()`
    act_fn: torch.nn.Module = serializable_field(
        serialization_fn=lambda x: x.__name__,
        loading_fn=lambda x: getattr(torch.nn, x["act_fn"]),
    )

    # same for the loss function: the loss class is stored by name, and its
    # constructor kwargs are stored separately as a plain dict
    loss_kwargs: dict = serializable_field(default_factory=dict)
    loss_factory: torch.nn.modules.loss._Loss = serializable_field(
        default_factory=lambda: torch.nn.CrossEntropyLoss,
        serialization_fn=lambda x: x.__name__,
        loading_fn=lambda x: getattr(torch.nn, x["loss_factory"]),
    )
    # a fresh loss instance, built from the stored class and kwargs on every access
    loss = property(lambda self: self.loss_factory(**self.loss_kwargs))
```

Then, define your model class. It should be a subclass of `ConfiguredModel`, and use the `set_config_class` decorator to associate it with your configuration class. The `__init__` method should take a single argument, which is an instance of your configuration class. You must also call the superclass `__init__` method with the configuration instance.

```python
@set_config_class(MyNNConfig)
class MyNN(ConfiguredModel[MyNNConfig]):
    """two-layer MLP (`Linear -> act_fn -> Linear`) whose sizes come from a `MyNNConfig`"""

    def __init__(self, config: MyNNConfig):
        # call the superclass init first -- before assigning any submodules!
        # this stores the config (not the model) in the `zanj_model_config` field
        super().__init__(config)

        # whatever you want here
        self.net = torch.nn.Sequential(
            torch.nn.Linear(config.input_dim, config.hidden_dim),
            config.act_fn(),
            torch.nn.Linear(config.hidden_dim, config.output_dim),
        )

    def forward(self, x):
        """apply the network to `x`, whose last dimension must equal `input_dim`"""
        return self.net(x)
```

You can now create instances of your model, save them to disk, and load them back into memory:

```python
config = MyNNConfig(
    input_dim=10,
    hidden_dim=20,
    output_dim=2,
    # pass the activation *class*; the model instantiates it
    act_fn=torch.nn.ReLU,
    loss_kwargs=dict(reduction="mean"),
    # loss_factory is omitted, so it defaults to torch.nn.CrossEntropyLoss
)

# create your model from the config, and save
model = MyNN(config)
fname = "tests/junk_data/path_to_save_model.zanj"
ZANJ().save(model, fname)
# load by calling the class method `read()`
loaded_model = MyNN.read(fname)
# zanj will actually infer the type of the object in the file
# -- and will warn you if you don't have the correct package installed
loaded_another_way = ZANJ().read(fname)
```

```python
print(f"{type(loaded_model) = }")
# a single unbatched input vector of size input_dim
x = torch.randn(config.input_dim)
print(f"{x.shape = }")
# run the same input through the original model and both reloaded copies
out_1 = model(x)
out_2 = loaded_model(x)
out_3 = loaded_another_way(x)

print(f"{out_1 = }, {out_2 = }, {out_3 = }")
# matching outputs indicate the saved weights were restored on load
assert torch.allclose(out_1, out_2)
assert torch.allclose(out_1, out_3)
```


``````{ end_of_file="demo.ipynb" }

``````{ path="makefile" processed_with="makefile_recipes" }
# first/default target is help
.PHONY: default
default: help
	...

# download makefile helper scripts from GitHub
# uses curl to fetch scripts from the template repository
# override version: make self-setup-scripts SCRIPTS_VERSION=v0.5.0
.PHONY: self-setup-scripts
self-setup-scripts:
	@echo "downloading makefile scripts (version: $(SCRIPTS_VERSION))"
	...

# this recipe is weird. we need it because:
# - a one liner for getting the version with toml is unwieldy, and using regex is fragile
# - using $$SCRIPT_GET_VERSION within $(shell ...) doesn't work because of escaping issues
# - trying to write to the file inside the `gen-version-info` recipe doesn't work, 
#   shell eval happens before our `python ...` gets run and `cat` doesn't see the new file
.PHONY: write-proj-version
write-proj-version:
	...

# gets version info from $(PYPROJECT), last version from $(LAST_VERSION_FILE), and python version
# uses just `python` for everything except getting the python version. no echo here, because this is "private"
.PHONY: gen-version-info
gen-version-info: write-proj-version
	...

# getting commit log since the tag specified in $(LAST_VERSION_FILE)
# will write to $(COMMIT_LOG_FILE)
# when publishing, the contents of $(COMMIT_LOG_FILE) will be used as the tag description (but can be edited during the process)
# no echo here, because this is "private"
.PHONY: gen-commit-log
gen-commit-log: gen-version-info
	...

# force the version info to be read, printing it out
# also force the commit log to be generated, and cat it out
.PHONY: version
version: gen-commit-log
	@echo "Current version is $(PROJ_VERSION), last auto-uploaded version is $(LAST_VERSION)"
	...

.PHONY: setup
setup: self-setup-scripts dep-check
	@echo "download scripts and sync dependencies"
	...

.PHONY: dep-check-torch
dep-check-torch:
	@echo "see if torch is installed, and which CUDA version and devices it sees"
	...

# sync dependencies and export to requirements.txt files
# - syncs all extras and groups with uv (including dev dependencies)
# - compiles bytecode for faster imports
# - exports to requirements.txt files per tool.makefile.uv-exports.exports config
# configure via pyproject.toml:[tool.makefile.uv-exports]:
#   [tool.makefile.uv-exports]
#   exports = [
#     { name = "base", extras = [], groups = [] },  # base package deps only
#     { name = "dev", extras = [], groups = ["dev"] },  # dev dependencies
#     { name = "all", extras = ["all"], groups = ["dev"] }  # everything
#   ]
.PHONY: dep
dep:
	@echo "syncing and exporting dependencies as per $(PYPROJECT) section 'tool.uv-exports.exports'"
	...

.PHONY: dep-compile
dep-compile:
	@echo "syncing dependencies with bytecode compilation"
	...

# verify that requirements.txt files match current dependencies
# - exports deps to temp directory
# - diffs temp against existing requirements files
# - FAILS if any differences found (means you need to run `make dep`)
# useful in CI to catch when pyproject.toml changed but requirements weren't regenerated
.PHONY: dep-check
dep-check:
	@echo "Checking that exported requirements are up to date"
	...

.PHONY: dep-clean
dep-clean:
	@echo "clean up lock files, .venv, and requirements files"
	...

# format code AND auto-fix linting issues
# performs TWO operations: reformats code, then auto-fixes safe linting issues
# configure in pyproject.toml:[tool.ruff]
.PHONY: format
format:
	@echo "format the source code"
	...

# runs ruff to check if the code is formatted correctly
.PHONY: format-check
format-check:
	@echo "check if the source code is formatted correctly"
	...

# runs type checks with configured checkers
# set TYPE_CHECKERS to customize which checkers run (e.g., TYPE_CHECKERS=mypy,basedpyright)
# set TYPING_OUTPUT_DIR to save outputs to files (used by typing-summary)
# returns exit code 1 if any checker fails
.PHONY: typing
typing:
	@echo "running type checks"
	...

# save type check outputs and generate detailed breakdown
# outputs are saved to $(TYPE_ERRORS_DIR)/*.txt
# summary is generated to $(TYPING_SUMMARY_FILE)
.PHONY: typing-summary
typing-summary:
	@echo "running type checks and saving to $(TYPE_ERRORS_DIR)/"
	...

.PHONY: test
test: clean
	@echo "running tests"
	...

.PHONY: test-notorch
test-notorch: clean
	@echo "running only tests without torch"
	...

.PHONY: check
check: format-check test typing
	@echo "run format checks, tests, and typing checks"
	...

# generates a whole tree of documentation in html format.
# see `$(MAKE_DOCS_SCRIPT_PATH)` and the templates in `$(DOCS_RESOURCES_DIR)/templates/html/` for more info
.PHONY: docs-html
docs-html:
	@echo "generate html docs"
	...

# instead of a whole website, generates a single markdown file with all docs using the templates in `$(DOCS_RESOURCES_DIR)/templates/markdown/`.
# this is useful if you want to have a copy that you can grep/search, but those docs are much messier.
.PHONY: docs-md
docs-md:
	@echo "generate combined (single-file) docs in markdown"
	...

# generate coverage reports from test results
# WARNING: if .coverage file not found, will automatically run `make test` first
# - generates text report: $(COVERAGE_REPORTS_DIR)/coverage.txt
# - generates SVG badge: $(COVERAGE_REPORTS_DIR)/coverage.svg
# - generates HTML report: $(COVERAGE_REPORTS_DIR)/html/
# - removes .gitignore from html dir (we publish coverage with docs)
.PHONY: cov
cov:
	@echo "generate coverage reports"
	...

# runs the coverage report, the html docs, the combined markdown docs, the inline TODO extraction, and the lmcat output
.PHONY: docs
docs: cov docs-html docs-md todo lmcat
	@echo "generate all documentation and coverage reports"
	...

# remove generated documentation files, but preserve resources
# - removes all docs except those in DOCS_RESOURCES_DIR
# - preserves files/patterns specified in pyproject.toml config
# - distinct from `make clean` (which removes temp build files, not docs)
# configure via pyproject.toml:[tool.makefile.docs]:
#   [tool.makefile.docs]
#   output_dir = "docs"  # must match DOCS_DIR in makefile
#   no_clean = [  # files/patterns to preserve when cleaning
#     "resources/**",
#     "*.svg",
#     "*.css"
#   ]
.PHONY: docs-clean
docs-clean:
	@echo "remove generated docs except resources"
	...

# get all TODO's from the code
# configure via pyproject.toml:[tool.makefile.inline-todo]:
#   [tool.makefile.inline-todo]
#   search_dir = "."  # directory to search for TODOs
#   out_file_base = "docs/other/todo-inline"  # output file path (without extension)
#   context_lines = 2  # lines of context around each TODO
#   extensions = ["py", "md"]  # file extensions to search
#   tags = ["CRIT", "TODO", "FIXME", "HACK", "BUG", "DOC"]  # tags to look for
#   exclude = ["docs/**", ".venv/**", "scripts/get_todos.py"]  # patterns to exclude
#   branch = "main"  # git branch for URLs
#   # repo_url = "..."  # repository URL (defaults to [project.urls.{repository,github}])
#   # template_md = "..."  # custom jinja2 template for markdown output
#   # template_issue = "..."  # custom format string for issues
#   # template_html_source = "..."  # custom html template path
#   tag_label_map = { "BUG" = "bug", "TODO" = "enhancement", "DOC" = "documentation" } # mapping of tags to GitHub issue labels
.PHONY: todo
todo:
	@echo "get all TODO's from the code"
	...

.PHONY: lmcat-tree
lmcat-tree:
	@echo "show in console the lmcat tree view"
	...

.PHONY: lmcat
lmcat:
	@echo "write the lmcat full output to pyproject.toml:[tool.lmcat.output]"
	...

# verify git is ready for publishing
# REQUIRES:
# - current branch must be $(PUBLISH_BRANCH)
# - no uncommitted changes (git status --porcelain must be empty)
# EXITS with error if either condition fails
.PHONY: verify-git
verify-git:
	@echo "checking git status"
	...

# build package distribution files
# creates wheel (.whl) and source distribution (.tar.gz) in dist/
.PHONY: build
build:
	@echo "build the package"
	...

# publish package to PyPI and create git tag
# PREREQUISITES:
# - must be on $(PUBLISH_BRANCH) branch with clean git status (verified by verify-git)
# - must have $(PYPI_TOKEN_FILE) with your PyPI token
# - version in pyproject.toml must be different from $(LAST_VERSION_FILE)
# PROCESS:
# 1. runs checks, validates version, builds package, verifies git clean
# 2. prompts for version confirmation (you can edit $(COMMIT_LOG_FILE) at this point)
# 3. creates git commit updating $(LAST_VERSION_FILE)
# 4. creates annotated git tag with commit log as description
# 5. pushes tag to origin
# 6. uploads to PyPI via twine
.PHONY: publish
publish: check version build verify-git
	@echo "Ready to publish $(PROJ_VERSION) to PyPI"
	...

# cleans up temporary files:
# - caches: .mypy_cache, .ruff_cache, .pytest_cache, .coverage
# - build artifacts: dist/, build/, *.egg-info
# - test temp files: $(TESTS_TEMP_DIR)
# - __pycache__ directories and *.pyc/*.pyo files in $(PACKAGE_NAME), $(TESTS_DIR), $(DOCS_DIR)
# uses `-` prefix on find commands to continue even if directories don't exist
# distinct from `make docs-clean`, which removes generated documentation
.PHONY: clean
clean:
	@echo "clean up temporary files"
	...

# remove all generated/build files including .venv
# runs: clean + docs-clean + dep-clean
# removes .venv, uv.lock, requirements.txt files, generated docs, build artifacts
# run `make dep` after this to reinstall dependencies
.PHONY: clean-all
clean-all: clean docs-clean dep-clean
	@echo "clean up all temporary files, dep files, venv, and generated docs"
	...

.PHONY: info
info: gen-version-info
	@echo "# makefile variables"
	...

.PHONY: info-long
info-long: info
	@echo "# other variables"
	...

# Smart help command: shows general help, or detailed info about specific targets
# Usage:
#   make help              - shows general help (list of targets + makefile variables)
#   make help="test"       - shows detailed info about the 'test' recipe
#   make HELP="test clean" - shows detailed info about multiple recipes
#   make h=*               - shows detailed info about all recipes (wildcard expansion)
#   make H="test"          - same as HELP (case variations supported)
#
# All variations work: help/HELP/h/H with values like "foo", "foo bar", "*", "--all"
.PHONY: help
help:
	...

``````{ end_of_file="makefile" }

``````{ path="pyproject.toml"  }
[project]
    name = "zanj"
    version = "0.6.0"
    description = "save and load complex objects to disk without pickling"
    license = "GPL-3.0-only"
    authors = [
        { name = "Michael Ivanitskiy", email = "mivanits@umich.edu" },
    ]
    readme = "README.md"
    requires-python = ">=3.8"
    # NOTE(review): PEP 639 deprecates `License ::` classifiers when an SPDX `license`
    # expression is given; some build/upload tools may warn about the one below
    classifiers = [
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Development Status :: 4 - Beta",
        "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
        "Operating System :: OS Independent",
    ]

    dependencies = [
        "muutils>=0.8.0",
        "numpy>=1.24.4",
        "jaxtyping>=0.2.12",
    ]


[project.optional-dependencies]
pandas = ["pandas>=1.5.3"]
polars = ["polars>=0.20.0"]
torch = [
    # on Python < 3.9 (i.e. 3.8) torch is pinned to 1.13.1; if that pin does not
    # work on your platform you will have to find a working version yourself
    "torch>=1.13.1; python_version >= '3.9'",
    "torch==1.13.1; python_version < '3.9'",
    "h11>=0.16.0", # see https://github.com/mivanit/ZANJ/security/dependabot/13
]

[project.urls]
    Homepage = "https://miv.name/zanj"
    Repository = "https://github.com/mivanit/zanj"
    Documentation = "https://miv.name/zanj/"
    Issues = "https://github.com/mivanit/ZANJ/issues"



[dependency-groups]
    dev = [
        # typing
        "mypy>=1.0.1",
        "ty>=0.0.18",
        "basedpyright",
        # tests & coverage
        "pytest>=8.2.2",
        "pytest-cov>=4.1.0",
        # for testing plotting and notebooks
        "ipykernel>=6.23.2",
        # tornado depended on by notebook stuff, see
        # https://github.com/mivanit/ZANJ/security/dependabot/22
        "tornado>=6.5; python_version >= '3.9'",
        "jupyter",
        "matplotlib>=3.0.0",
        "plotly>=5.0.0",
        # see https://github.com/mivanit/ZANJ/security/dependabot/18
        "setuptools>=78.1.1; python_version >= '3.9'",
        # generating docs
        "pdoc>=14.6.0",
        # https://github.com/mivanit/ZANJ/security/dependabot/4
        "jinja2>=3.1.6",
        # lmcat -- a custom library. not exactly docs, but lets an LLM see all the code
        "lmcat>=0.2.0; python_version >= '3.11'",
        # tomli is needed because the stdlib tomllib only exists in python >= 3.11
        "tomli>=2.1.0; python_version < '3.11'",
        # for uploading
        "twine",
        # for testing dataframe support
        "pandas>=1.5.3",
        "polars>=0.20.0",
    ]
    lint = [
        # lint
        "ruff>=0.4.8",
    ]

[build-system]
    # PEP 517 build backend: hatchling
    requires = [
        "hatchling",
    ]
    build-backend = "hatchling.build"


[tool.ruff]
    # test data directories and meta scripts are skipped by ruff
    exclude = [
        "tests/input_data",
        "tests/junk_data",
        ".meta/scripts/",
    ]

[tool.mypy]
    # only the `zanj` package is checked; test data directories are excluded
    packages = ["zanj"]
    show_error_codes = true
    exclude = [
        "tests/input_data",
        "tests/junk_data",
    ]

[tool.ty.src]
    # the test suite and the demo notebook are not checked by ty
    exclude = [
        "tests/",
        "demo.ipynb",
    ]

[tool.ty.rules]
    too-many-positional-arguments = "ignore"
    # presumably silenced because ignore comments in the code target other
    # checkers (mypy / basedpyright) -- confirm
    unused-ignore-comment = "ignore"
    unused-type-ignore-comment = "ignore"

[tool.lmcat]
    # changing this might mean it won't be accessible from the docs
    output = "docs/other/lmcat.txt"
    ignore_patterns = [
        "docs/**",
        ".venv/**",
        ".git/**",
        ".meta/**",
        ".ruff_cache/**",
        "uv.lock",
        "LICENSE",
    ]
    [tool.lmcat.glob_process]
        "[mM]akefile" = "makefile_recipes"
        "*.ipynb" = "ipynb_to_md"
        "*.csv" = "csv_preview_5_lines"

# settings read by the makefile and its helper scripts (e.g. make_docs.py)
# ============================================================
[tool.makefile]
[tool.makefile.docs]
    # generated docs location; must match DOCS_DIR in the makefile
    output_dir = "docs"
    markdown_headings_increment = 2
    warnings_ignore = []
    # files/patterns that `make docs-clean` leaves in place
    no_clean = [
        ".nojekyll",  # needed for GitHub Pages
        "temp",
    ]

    [tool.makefile.docs.notebooks]
        enabled = true
        source_path = "."
        output_path_relative = "notebooks"
        [tool.makefile.docs.notebooks.descriptions]
            demo = "Example notebook showing basic usage"


[tool.makefile.uv-exports]
    args = [
        "--no-hashes",
    ]
    exports = [
        # no groups, no extras, just the base dependencies
        { name = "base", groups = false, extras = false },
        # all groups
        { name = "groups", groups = true, extras = false },
        # only the lint group -- custom options for this
        { name = "lint", options = ["--only-group", "lint"] },
        # all groups and extras, written to the top-level requirements.txt
        { name = "all", filename = "requirements.txt", groups = true, extras = true },
        # all groups and extras, a different way
        # NOTE(review): this entry reuses the name "all"; confirm the export script
        # does not key its outputs by `name` alone
        { name = "all", groups = true, options = ["--all-extras"] },
    ]

# configures `make todo`
[tool.makefile.inline-todo]
    search_dir = "."
    # output path *without* extension, as documented for the makefile's `todo` recipe
    out_file_base = "docs/other/todo-inline"
    context_lines = 2
    extensions = ["py", "md"]
    tags = ["CRIT", "TODO", "FIXME", "HACK", "BUG", "DOC"]
    exclude = [
        "docs/**",
        ".venv/**",
    ]
    [tool.makefile.inline-todo.tag_label_map]
        "BUG" = "bug"
        "TODO" = "enhancement"
        "DOC" = "documentation"

# ============================================================


``````{ end_of_file="pyproject.toml" }