# Stats
- 30 files
- 4359 (4.4K) lines
- 139619 (140K) chars
- 13458 (13K) `whitespace-split` tokens

# File Tree

```
ZANJ                                                 
├── .github                                          
│   └── workflows                                    
│    └── checks.yml                                  [  110L     2,758C   303T]
├── tests                                            
│   ├── input_data                                   
│   │   ├── brain_networks.csv                       [  924L 1,075,910C   924T]
│   │   └── iris.csv                                 [  151L     3,857C   151T]
│   ├── unit                                         
│   │   ├── no_torch                                 
│   │   │   ├── test_bool_array.py                   [   60L     1,169C   113T]
│   │   │   ├── test_isolate_zanj_handler_store.py   [   78L     1,784C   158T]
│   │   │   ├── test_load_item_recursive.py          [  232L     7,469C   714T]
│   │   │   ├── test_shared_prefix_keys.py           [   40L     1,139C   112T]
│   │   │   ├── test_zanj_basic.py                   [   86L     2,607C   197T]
│   │   │   ├── test_zanj_edge_cases.py              [  299L     8,973C   940T]
│   │   │   ├── test_zanj_populate_nested.py         [   57L     1,386C   103T]
│   │   │   └── test_zanj_serializable_dataclass.py  [  168L     4,351C   404T]
│   │   └── with_torch                               
│   │    ├── test_bool_array_torch.py                [   35L       793C    72T]
│   │    ├── test_get_module_device.py               [   61L     1,870C   203T]
│   │    ├── test_sdc_torch.py                       [  105L     2,555C   219T]
│   │    ├── test_torch_edge_cases.py                [  229L     7,182C   642T]
│   │    ├── test_zanj_sdc_modelcfg.py               [  169L     4,750C   375T]
│   │    ├── test_zanj_torch.py                      [  169L     5,042C   350T]
│   │    └── test_zanj_torch_cfgmismatch.py          [  161L     3,983C   334T]
│   └── assert_no_torch.py                           [    8L       167C    12T]
├── zanj                                             
│   ├── __init__.py                                  [   19L       299C    30T]
│   ├── externals.py                                 [   52L     1,530C   166T]
│   ├── loading.py                                   [  450L    16,794C 1,590T]
│   ├── py.typed                                     [    0L         0C     0T]
│   ├── serializing.py                               [  271L     9,709C   912T]
│   ├── torchutil.py                                 [  294L    10,176C   915T]
│   └── zanj.py                                      [  250L     8,705C   798T]
├── README.md                                        [  235L    10,944C 1,365T]
├── demo.ipynb                                       [  306L     8,941C   920T]
├── makefile                                         [1,661L    50,910C 6,079T]
├── pyproject.toml                                   [  167L     4,669C   483T]
```

# File Contents

``````{ path=".github/workflows/checks.yml"  }
# CI workflow: a formatting job plus a test job that fans out over a
# Python / torch / numpy matrix.
name: Checks

# Triggers: pull requests targeting main, pushes to main, and manual runs
# (workflow_dispatch).
on:
  pull_request:
    branches:
      - main
  push:
    branches:
      - main
  
  workflow_dispatch:

jobs:
  # Formatting-only job: installs the lint requirements with pip and runs
  # the makefile's `format-check` target.
  lint:
    name: Formatting
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with: 
          # 0 asks actions/checkout for the full history instead of a shallow clone
          fetch-depth: 0

      - name: Install linters
        run: pip install -r .meta/requirements/requirements-lint.txt

      # NOTE(review): RUN_GLOBAL=1 is interpreted by the makefile; presumably it
      # makes the target use the globally pip-installed tools — confirm in makefile
      - name: Run Format Checks
        run: make format-check RUN_GLOBAL=1

  # Test job: runs once per (python, pkg) combination of the matrix below.
  test:
    name: Test
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
        # Each `pkg` entry is one dependency scenario:
        #   legacy  - pin torch 1.13.1 and numpy 1.24.4
        #   latest  - empty strings, so the version-override steps below are skipped
        #   notorch - torch "None", so torch is uninstalled and `test-notorch` is run
        pkg:
          - torch: "1.13.1"
            numpy: "1.24.4"
            group: "legacy"
          - torch: ""
            numpy: ""
            group: "latest"
          - torch: "None"
            numpy: ""
            group: "notorch"
        # The legacy pins are not run on Python 3.12 and 3.13.
        exclude:
          - python: "3.12"
            pkg:
              group: "legacy"
          - python: "3.13"
            pkg:
              group: "legacy"
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with: 
          # shallow clone: only the latest commit is needed for testing
          fetch-depth: 1

      - name: Set up python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}


      - name: set up uv
        run: curl -LsSf https://astral.sh/uv/install.sh | sh
  
      - name: check dependencies
        run: make dep-check

      - name: install dependencies and package
        run: make dep

      # Only runs when the matrix entry pins a numpy version (non-empty string).
      - name: Install different numpy version
        if: ${{ matrix.pkg.numpy != '' }}
        run: uv pip install numpy==${{ matrix.pkg.numpy }}

      # Only runs when a concrete torch version is pinned; the CPU-only wheel
      # is pulled from the PyTorch CPU index.
      - name: Install different pytorch version
        if: ${{ matrix.pkg.torch != '' && matrix.pkg.torch != 'None' }}
        run: |
          uv pip install torch==${{ matrix.pkg.torch }}+cpu --extra-index-url https://download.pytorch.org/whl/cpu

      # The "notorch" scenario removes torch after `make dep` has run.
      - name: remove torch if testing torchless
        if: ${{ matrix.pkg.torch == 'None' }}
        run: uv pip uninstall torch

      # NOTE(review): UV_NOSYNC=1 is interpreted by the makefile; presumably it
      # stops uv from re-syncing and undoing the overrides above — confirm in makefile
      - name: make info
        run: make info-long UV_NOSYNC=1

      # Informational only: a failure here does not fail the job.
      - name: torch dep info
        run: make dep-check-torch UV_NOSYNC=1
        continue-on-error: true
      
      - name: format check
        run: make format-check UV_NOSYNC=1

      # Exactly one of the two test steps below runs, depending on whether
      # this matrix entry is the torch-less scenario.
      - name: tests
        if: ${{ matrix.pkg.torch != 'None' }}
        run: make test UV_NOSYNC=1

      - name: tests without torch
        if: ${{ matrix.pkg.torch == 'None' }}
        run: make test-notorch UV_NOSYNC=1

      # Disabled: strict-warnings test run, pending the 3.8/3.9 port noted in the TODO.
      # - name: tests in strict mode
      #   # TODO: until zanj ported to 3.8 and 3.9
      #   if: ${{ matrix.python != '3.8' && matrix.python != '3.9' }}
      #   run: make test WARN_STRICT=1 RUN_GLOBAL=1

      - name: check typing
        run: make typing

``````{ end_of_file=".github/workflows/checks.yml" }

``````{ path="tests/input_data/brain_networks.csv" processed_with="csv_preview_5_lines" }
network,1,1,2,2,3,3,4,4,5,5,6,6,6,6,7,7,7,7,7,7,8,8,8,8,8,8,9,9,10,10,11,11,12,12,12,12,12,13,13,13,13,13,13,14,14,15,15,16,16,16,16,16,16,16,16,17,17,17,17,17,17,17
node,1,1,1,1,1,1,1,1,1,1,1,1,2,2,1,1,2,2,3,3,1,1,2,2,3,3,1,1,1,1,1,1,1,1,2,2,3,1,1,2,2,3,4,1,1,1,1,1,1,2,2,3,3,4,4,1,1,2,2,3,3,4
hemi,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,lh,rh,lh,rh,rh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh,rh,lh
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
0,56.05574417114258,92.03103637695312,3.391575574874878,38.65968322753906,26.203819274902344,-49.71556854248047,47.4610366821289,26.746612548828125,-35.898860931396484,-1.8891807794570925,5.898688316345215,-43.69232177734375,-47.66426467895508,12.2841215133667,1.5665380954742432,-13.042585372924805,-1.8552596569061282,-39.80590057373047,-30.831512451171875,-61.13700866699219,-25.82785606384277,39.02416229248047,-29.97164535522461,-6.1323723793029785,-56.75698852539063,0.
... (truncated)
``````{ end_of_file="tests/input_data/brain_networks.csv" }

``````{ path="tests/input_data/iris.csv" processed_with="csv_preview_5_lines" }
sepal_length,sepal_width,petal_length,petal_width,species
5.1,3.5,1.4,0.2,setosa
4.9,3.0,1.4,0.2,setosa
4.7,3.2,1.3,0.2,setosa
4.6,3.1,1.5,0.2,setosa
... (truncated)
``````{ end_of_file="tests/input_data/iris.csv" }

``````{ path="tests/unit/no_torch/test_bool_array.py"  }
from pathlib import Path

import numpy as np

from muutils.json_serialize import SerializableDataclass, serializable_dataclass

from zanj import ZANJ

TEST_DATA_PATH: Path = Path("tests/junk_data")


# Fixture dataclass whose two "array" fields are plain Python lists;
# used by test_list_bool_array to round-trip lists of booleans.
@serializable_dataclass
class MyClass_list(SerializableDataclass):
    name: str
    arr_1: list
    arr_2: list


def test_list_bool_array():
    """Round-trip a dataclass holding lists of booleans through a ZANJ file."""
    original: MyClass_list = MyClass_list(
        name="test",
        arr_1=[True, False, True],
        arr_2=[True, False, True],
    )
    target: Path = TEST_DATA_PATH / "test_list_bool_array.zanj"

    zanj_io = ZANJ()
    zanj_io.save(original, target)
    restored: MyClass_list = zanj_io.read(target)

    # the loaded object must compare equal to what was saved
    assert original == restored


# Fixture dataclass whose two array fields are numpy arrays;
# used by test_np_bool_array to check that the bool dtype survives a round trip.
@serializable_dataclass
class MyClass_np(SerializableDataclass):
    name: str
    arr_1: np.ndarray
    arr_2: np.ndarray


def test_np_bool_array():
    """Round-trip numpy boolean arrays and check that dtype and equality are preserved."""
    original: MyClass_np = MyClass_np(
        name="test",
        arr_1=np.array([True, False, True]),
        arr_2=np.array([True, False, True]),
    )
    target: Path = TEST_DATA_PATH / "test_np_bool_array.zanj"

    zanj_io = ZANJ()
    zanj_io.save(original, target)
    restored: MyClass_np = zanj_io.read(target)

    # the arrays must come back as booleans, not as ints or objects
    assert restored.arr_1.dtype == np.bool_
    assert restored.arr_2.dtype == np.bool_

    assert original == restored

``````{ end_of_file="tests/unit/no_torch/test_bool_array.py" }

``````{ path="tests/unit/no_torch/test_isolate_zanj_handler_store.py"  }
from __future__ import annotations

import json
import typing
import zipfile
from pathlib import Path

import numpy as np
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

from zanj import ZANJ
from zanj.loading import LOADER_MAP

np.random.seed(0)

# pylint: disable=missing-function-docstring,missing-class-docstring

TEST_DATA_PATH: Path = Path("tests/junk_data")


# Minimal fixture dataclass: one required string, one int with a default,
# and a list field built through a default factory.
@serializable_dataclass
class Basic(SerializableDataclass):
    a: str
    q: int = 42
    c: typing.List[int] = serializable_field(default_factory=list)


def test_Basic():
    """Save a `Basic` instance to a ZANJ file and check it reads back equal."""
    original = Basic("hello", 42, [1, 2, 3])

    zanj_io = ZANJ()
    target = TEST_DATA_PATH / "test_Basic.zanj"
    zanj_io.save(original, target)

    restored = zanj_io.read(target)
    assert original == restored


# Debug output at import time: the loader keys registered so far
# (printed after `Basic` has been defined).
print(list(LOADER_MAP.keys()))


# Second fixture dataclass, defined after `Basic`; test_isolate_handlers
# expects both classes to show up in LOADER_MAP and in saved file metadata.
@serializable_dataclass
class ModelCfg(SerializableDataclass):
    name: str
    num_layers: int
    hidden_size: int
    dropout: float


# Debug output at import time: the loader keys again, now that `ModelCfg`
# has been defined as well.
print(list(LOADER_MAP.keys()))


def test_isolate_handlers():
    """Check that loaders for both fixture dataclasses are registered and stored.

    Saves and reloads a `ModelCfg`, then verifies that the loader keys for
    `Basic` and `ModelCfg` are present both in the global `LOADER_MAP` and in
    the `load_handlers` section of the `__zanj_meta__.json` written to the file.
    """
    instance = ModelCfg("lstm", 3, 128, 0.1)

    print(list(LOADER_MAP.keys()))

    z = ZANJ()
    path = TEST_DATA_PATH / "00-test_isolate_handlers.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered

    assert "Basic(SerializableDataclass)" in LOADER_MAP
    assert "ModelCfg(SerializableDataclass)" in LOADER_MAP

    # check they are in the zanj file; both the archive and the member file
    # are opened via context managers so neither handle is leaked
    with zipfile.ZipFile(path, "r") as zfile:
        with zfile.open("__zanj_meta__.json", "r") as meta_file:
            zmeta = json.load(meta_file)

    assert "Basic(SerializableDataclass)" in zmeta["zanj_cfg"]["load_handlers"]
    assert "ModelCfg(SerializableDataclass)" in zmeta["zanj_cfg"]["load_handlers"]


# Allow running this module directly as a script; only the handler
# isolation test is executed in that case.
if __name__ == "__main__":
    test_isolate_handlers()

``````{ end_of_file="tests/unit/no_torch/test_isolate_zanj_handler_store.py" }

``````{ path="tests/unit/no_torch/test_load_item_recursive.py"  }
from __future__ import annotations

import typing
from pathlib import Path

import numpy as np
import pytest
from muutils.errormode import ErrorMode
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from muutils.json_serialize.util import _FORMAT_KEY

from zanj import ZANJ
from zanj.loading import LoadedZANJ, load_item_recursive

TEST_DATA_PATH: Path = Path("tests/junk_data")


def test_load_item_recursive_basic():
    """Plain JSON (no format keys) must pass through `load_item_recursive` unchanged."""
    plain_json = {
        "name": "test",
        "value": 42,
        "list": [1, 2, 3],
        "nested": {"a": 1, "b": 2},
    }

    # empty path, no ZANJ instance
    loaded = load_item_recursive(plain_json, tuple(), None)

    assert loaded == plain_json
    # spot-check every top-level entry individually
    expected_entries = (
        ("name", "test"),
        ("value", 42),
        ("list", [1, 2, 3]),
        ("nested", {"a": 1, "b": 2}),
    )
    for field, expected in expected_entries:
        assert loaded[field] == expected


def test_load_item_recursive_numpy_array():
    """A dict tagged as `numpy.ndarray:array_list_meta` must load as an ndarray."""
    source_array = np.random.rand(5, 5)
    # hand-built serialized form of the array: format tag, dtype, shape, nested lists
    serialized_array = {
        _FORMAT_KEY: "numpy.ndarray:array_list_meta",
        "dtype": str(source_array.dtype),
        "shape": list(source_array.shape),
        "data": source_array.tolist(),
    }

    loaded = load_item_recursive(serialized_array, tuple(), None)

    assert isinstance(loaded, np.ndarray)
    assert loaded.shape == tuple(serialized_array["shape"])
    assert loaded.dtype == np.dtype(serialized_array["dtype"])
    assert np.allclose(loaded, source_array)


def test_load_item_recursive_serializable_dataclass():
    """The output of `SerializableDataclass.serialize()` must load back as that class."""

    @serializable_dataclass
    class TestClass(SerializableDataclass):
        name: str
        value: int
        data: typing.List[int] = serializable_field(default_factory=list)

    # serialize an instance, then feed the raw serialized form to the loader
    payload = TestClass("test", 42, [1, 2, 3]).serialize()
    loaded = load_item_recursive(payload, tuple(), None)

    assert isinstance(loaded, TestClass)
    assert loaded.name == "test"
    assert loaded.value == 42
    assert loaded.data == [1, 2, 3]


def test_load_item_recursive_nested_container():
    """Serialized arrays nested inside lists and dicts must all be converted to ndarrays."""

    def _serialized_square(side: int) -> dict:
        # serialized form of a random (side, side) float64 array
        return {
            _FORMAT_KEY: "numpy.ndarray:array_list_meta",
            "dtype": "float64",
            "shape": [side, side],
            "data": np.random.rand(side, side).tolist(),
        }

    nested_json = {
        "name": "test",
        "arrays": [_serialized_square(3), _serialized_square(2)],
        "nested": {"dict_with_array": _serialized_square(4)},
    }

    loaded = load_item_recursive(nested_json, tuple(), None)

    assert loaded["name"] == "test"

    # arrays inside a list
    list_arrays = loaded["arrays"]
    assert len(list_arrays) == 2
    assert isinstance(list_arrays[0], np.ndarray)
    assert isinstance(list_arrays[1], np.ndarray)
    assert list_arrays[0].shape == (3, 3)
    assert list_arrays[1].shape == (2, 2)

    # array inside a nested dict
    assert isinstance(loaded["nested"]["dict_with_array"], np.ndarray)
    assert loaded["nested"]["dict_with_array"].shape == (4, 4)


def test_load_item_recursive_unknown_format():
    """An item with an unregistered format key is returned as-is when not-loading is allowed."""
    unknown_item = {
        _FORMAT_KEY: "unknown.format.that.definitely.does.not.exist",
        "data": [1, 2, 3],
    }

    # with allow_not_loading=True the raw JSON should be handed back untouched
    passthrough = load_item_recursive(
        unknown_item, tuple(), None, allow_not_loading=True
    )
    assert passthrough == unknown_item

    # TODO: this doesn't raise any errors
    # Intended check: with allow_not_loading=False and EXCEPT error mode the
    # call should raise; for now the call is only exercised, not asserted on.
    strict_zanj = ZANJ(error_mode=ErrorMode.EXCEPT)
    load_item_recursive(
        unknown_item,
        tuple(),
        strict_zanj,
        error_mode=ErrorMode.EXCEPT,
        allow_not_loading=False,
    )


def test_load_item_recursive_with_external_reference():
    """A saved array above the external threshold must show up in the loaded externals."""
    # threshold of 10 with a 20x20 array, as in the other external-storage tests
    zanj_io = ZANJ(external_array_threshold=10)
    payload = {"large_array": np.random.rand(20, 20)}
    target = TEST_DATA_PATH / "test_load_item_recursive_external.zanj"
    zanj_io.save(payload, target)

    # open the archive through the low-level loader and resolve its externals
    archive = LoadedZANJ(target, zanj_io)
    archive.populate_externals()

    assert len(archive._externals) > 0

    # NOTE(review): this tests for the literal string "_REF_KEY", not a
    # reference-key constant, so in practice it reduces to the isinstance check
    assert "_REF_KEY" in archive._json_data or isinstance(archive._json_data, dict)


def test_load_item_recursive_error_modes():
    """Exercise `load_item_recursive` under the WARN, IGNORE and EXCEPT error modes."""
    unknown_item = {
        _FORMAT_KEY: "unknown.format.that.definitely.does.not.exist",
        "data": [1, 2, 3],
    }

    # WARN first, then IGNORE: neither may raise, both must return the raw item
    for lenient_mode in (ErrorMode.WARN, ErrorMode.IGNORE):
        passthrough = load_item_recursive(
            unknown_item,
            tuple(),
            None,
            error_mode=lenient_mode,
            allow_not_loading=True,
        )
        assert passthrough == unknown_item

    # Handler stub that claims the unknown format and always fails on load
    class CustomHandler:
        def check(self, json_item, path=None, z=None):
            return (
                json_item.get(_FORMAT_KEY)
                == "unknown.format.that.definitely.does.not.exist"
            )

        def load(self, json_item, path=None, z=None):
            raise ValueError("Forced error for testing purposes")

    import zanj.loading

    def mock_get_item_loader(*args, **kwargs):
        # ignore all arguments and always hand out the failing handler
        return CustomHandler()

    # manual monkeypatch of the module-level lookup; undone in `finally`
    original_get_item_loader = zanj.loading.get_item_loader
    zanj.loading.get_item_loader = mock_get_item_loader
    try:
        # EXCEPT mode must let the handler's ValueError propagate
        with pytest.raises(ValueError):
            load_item_recursive(
                unknown_item,
                tuple(),
                None,
                error_mode=ErrorMode.EXCEPT,
                allow_not_loading=True,
            )
    finally:
        zanj.loading.get_item_loader = original_get_item_loader

``````{ end_of_file="tests/unit/no_torch/test_load_item_recursive.py" }

``````{ path="tests/unit/no_torch/test_shared_prefix_keys.py"  }
from pathlib import Path
import typing
import numpy as np

import pytest

from zanj import ZANJ

_TEMP_PATH: Path = Path("tests/.temp/")


# NOTE: as of 2025-11-06 15:32 (v0.5.1), the first parametrized case (longer key first) fails, while the second case (shorter key first) passes. The cause is not yet understood.


@pytest.mark.parametrize(
    ("keys", "name"),
    [
        (["layer.1.weight", "layer.1"], "longer_key_first"),
        (["layer.1", "layer.1.weight"], "shorter_key_first"),
    ],
)
def test_shared_prefix_keys(keys: typing.List[str], name: str):
    """Round-trip a dict whose keys share a dotted prefix, in both insertion orders.

    Every array is forced into external storage (`external_array_threshold=0`)
    and must come back with the same type and contents under the same key.
    """
    target: Path = _TEMP_PATH / f"shared_prefix_keys-{name}.zanj"

    # one random 10x10 array per key, created in the parametrized order
    data = {}
    for key in keys:
        data[key] = np.random.rand(10, 10)

    writer = ZANJ(external_array_threshold=0)
    writer.save(data, target)
    print("saved successfully")

    loaded = ZANJ().read(target)
    assert set(data.keys()) == set(loaded.keys())

    for key in data:
        # debug output, shown by pytest on failure
        print(f"{key = }")
        print(f"{type(data[key]) = }")
        print(f"{data[key] = }")
        print(f"{type(loaded[key]) = }")
        print(f"{loaded[key] = }")
        assert type(loaded[key]) == type(data[key])  # noqa: E721
        np.testing.assert_array_equal(data[key], loaded[key])

``````{ end_of_file="tests/unit/no_torch/test_shared_prefix_keys.py" }

``````{ path="tests/unit/no_torch/test_zanj_basic.py"  }
from __future__ import annotations

import json
import typing
from pathlib import Path

import numpy as np
import pandas as pd  # type: ignore

from zanj import ZANJ

np.random.seed(0)


TEST_DATA_PATH: Path = Path("tests/junk_data")


def array_meta(x: typing.Any) -> dict:
    """Summarize a value as a small JSON-friendly dict for debug printing.

    For a numpy array the summary holds its `shape` (as a list), `dtype`
    (as a string) and `contents` (`str(x)`); for anything else it holds the
    `type` name and `contents`.
    """
    if not isinstance(x, np.ndarray):
        return {"type": type(x).__name__, "contents": str(x)}

    return {
        "shape": list(x.shape),
        "dtype": str(x.dtype),
        "contents": str(x),
    }


def test_numpy():
    """Round-trip a dict of a string plus three numpy arrays of different sizes."""
    data = {
        "name": "testing zanj",
        "some_array": np.random.rand(128, 128),
        "some_other_array": np.random.rand(16, 64),
        "small_array": np.random.rand(4, 4),
    }
    zanj_path: Path = TEST_DATA_PATH / "test_numpy.zanj"
    zanj_io: ZANJ = ZANJ()
    zanj_io.save(data, zanj_path)
    recovered_data = zanj_io.read(zanj_path)

    # debug output: keys, then per-key metadata for the original and the recovered data
    print(f"{list(data.keys()) = }")
    print(f"{list(recovered_data.keys()) = }")
    for side in (data, recovered_data):
        print(json.dumps({k: array_meta(v) for k, v in side.items()}, indent=2))

    assert sorted(data.keys()) == sorted(recovered_data.keys())
    # assert all([type(data[k]) == type(recovered_data[k]) for k in data.keys()])

    # evaluate every comparison up front, then assert on all of them at once
    checks = [data["name"] == recovered_data["name"]]
    for array_key in ("some_array", "some_other_array", "small_array"):
        checks.append(np.allclose(data[array_key], recovered_data[array_key]))
    assert all(checks), f"assert failed:\n{data = }\n{recovered_data = }"


def test_jsonl():
    """Round-trip a dict mixing a string, two pandas DataFrames and a numpy array."""
    payload = {
        "name": "testing zanj jsonl",
        "iris_data": pd.read_csv("tests/input_data/iris.csv"),
        "brain_data": pd.read_csv("tests/input_data/brain_networks.csv"),
        "some_array": np.random.rand(128, 128),
    }
    zanj_path: Path = TEST_DATA_PATH / "test_jsonl.zanj"
    zanj_io: ZANJ = ZANJ()
    zanj_io.save(payload, zanj_path)
    restored = zanj_io.read(zanj_path)

    assert sorted(payload.keys()) == sorted(restored.keys())
    # assert all([type(data[k]) == type(recovered_data[k]) for k in data.keys()])

    # evaluate every comparison up front, then assert on all of them at once
    checks = [
        payload["name"] == restored["name"],
        np.allclose(payload["some_array"], restored["some_array"]),
    ]
    for frame_key in ("iris_data", "brain_data"):
        checks.append(payload[frame_key].equals(restored[frame_key]))
    assert all(checks)

``````{ end_of_file="tests/unit/no_torch/test_zanj_basic.py" }

``````{ path="tests/unit/no_torch/test_zanj_edge_cases.py"  }
from __future__ import annotations

import os
import zipfile
from pathlib import Path

import numpy as np
import pytest
from muutils.errormode import ErrorMode

from zanj import ZANJ

TEST_DATA_PATH: Path = Path("tests/junk_data")


def test_zanj_with_different_configs():
    """Save and reload the same payload under three `external_array_threshold` settings."""
    data = {
        "name": "test_config",
        "array": np.random.rand(50, 50),  # 2500 elements
    }

    # constructor kwargs and output file per scenario:
    # library default, low threshold, high threshold
    scenarios = [
        ({}, "test_default_config.zanj"),
        ({"external_array_threshold": 10}, "test_low_threshold.zanj"),
        ({"external_array_threshold": 10000}, "test_high_threshold.zanj"),
    ]

    # phase 1: create each ZANJ instance and save with it
    saved = []
    for zanj_kwargs, file_name in scenarios:
        zanj_io = ZANJ(**zanj_kwargs)
        target = TEST_DATA_PATH / file_name
        zanj_io.save(data, target)
        saved.append((zanj_io, target))

    # phase 2: every file must exist
    for _, target in saved:
        assert target.exists()

    # phase 3: read each file back with the instance that wrote it
    reloaded = [zanj_io.read(target) for zanj_io, target in saved]

    for restored in reloaded:
        assert restored["name"] == data["name"]

    for restored in reloaded:
        assert np.allclose(restored["array"], data["array"])


def test_zanj_compression_options():
    """Save and reload the same payload with three different `compress` settings."""
    data = {
        "name": "compression_test",
        "array": np.random.rand(100, 100),
    }

    # `compress` value and output file per scenario:
    # True, False, and an explicit zipfile compression constant
    scenarios = [
        (True, "test_default_compression.zanj"),
        (False, "test_no_compression.zanj"),
        (zipfile.ZIP_DEFLATED, "test_explicit_compression.zanj"),
    ]

    # phase 1: create each ZANJ instance and save with it
    saved = []
    for compress_setting, file_name in scenarios:
        zanj_io = ZANJ(compress=compress_setting)
        target = TEST_DATA_PATH / file_name
        zanj_io.save(data, target)
        saved.append((zanj_io, target))

    # phase 2: every file must exist
    for _, target in saved:
        assert target.exists()

    # phase 3: all three files must load correctly
    reloaded = [zanj_io.read(target) for zanj_io, target in saved]

    for restored in reloaded:
        assert restored["name"] == data["name"]

    for restored in reloaded:
        assert np.allclose(restored["array"], data["array"])


def test_zanj_error_modes():
    """With `ErrorMode.EXCEPT`, an exception raised during serialization must reach the caller."""

    # Object whose repr blows up, used as the problematic payload entry
    class ForceExceptionOnSerialize:
        def __repr__(self):
            raise Exception("Forced exception during serialization")

    # ZANJ subclass that raises as soon as it is asked to serialize a dict
    # containing the "unserializable" key
    class ExceptionForcingZANJ(ZANJ):
        def json_serialize(self, obj):
            if isinstance(obj, dict) and "unserializable" in obj:
                raise Exception("Forced exception")
            return super().json_serialize(obj)

    payload = {
        "name": "error_test",
        "unserializable": ForceExceptionOnSerialize(),
    }

    failing_zanj = ExceptionForcingZANJ(error_mode=ErrorMode.EXCEPT)
    target = TEST_DATA_PATH / "test_error_except.zanj"

    # saving must fail rather than silently skip the bad entry
    with pytest.raises(Exception):
        failing_zanj.save(payload, target)


def test_zanj_array_modes():
    """Save and reload a small array under each `internal_array_mode` string value."""
    data = {
        "name": "array_mode_test",
        "array": np.random.rand(5, 5),
    }

    # the modes are passed as plain strings, not enum attributes
    scenarios = [
        ("list", "test_array_mode_list.zanj"),
        ("array_list", "test_array_mode_array_list.zanj"),
        ("array_list_meta", "test_array_mode_array_list_meta.zanj"),
    ]

    # phase 1: create each ZANJ instance and save with it
    saved = []
    for array_mode, file_name in scenarios:
        zanj_io = ZANJ(internal_array_mode=array_mode)
        target = TEST_DATA_PATH / file_name
        zanj_io.save(data, target)
        saved.append((zanj_io, target))

    # phase 2: read each file back with the instance that wrote it
    from_list, from_array_list, from_array_list_meta = [
        zanj_io.read(target) for zanj_io, target in saved
    ]

    for restored in (from_list, from_array_list, from_array_list_meta):
        assert restored["name"] == data["name"]

    assert np.allclose(from_list["array"], data["array"])
    # TODO: some sort of error here?
    # assert np.allclose(data2["array"], data["array"])
    assert np.allclose(from_array_list_meta["array"], data["array"])


def test_zanj_meta():
    """`ZANJ.meta()` must expose the expected top-level and `zanj_cfg` keys."""
    meta = ZANJ().meta()

    # top-level sections of the metadata
    for section in ("zanj_cfg", "sysinfo", "externals_info", "timestamp"):
        assert section in meta

    # configuration entries recorded under "zanj_cfg"
    expected_cfg_keys = (
        "error_mode",
        "array_mode",
        "external_array_threshold",
        "external_list_threshold",
        "compress",
        "serialization_handlers",
        "load_handlers",
    )
    for cfg_key in expected_cfg_keys:
        assert cfg_key in meta["zanj_cfg"]


def test_zanj_externals_info():
    """After a save the writer's externals are empty, and the data still round-trips."""
    # low threshold with a 20x20 array, as in the other external-storage tests
    zanj_io = ZANJ(external_array_threshold=10)
    payload = {
        "name": "externals_test",
        "array": np.random.rand(20, 20),
    }

    target = TEST_DATA_PATH / "test_externals_info.zanj"
    zanj_io.save(payload, target)

    # saving must not leave entries behind in the instance's externals
    assert len(zanj_io._externals) == 0

    restored = zanj_io.read(target)
    assert restored["name"] == payload["name"]
    assert np.allclose(restored["array"], payload["array"])


def test_zanj_save_extension():
    """`save` appends `.zanj` when it is missing and leaves the path alone otherwise."""
    payload = {"name": "extension_test"}
    zanj_io = ZANJ()

    # --- case 1: path given without extension ---
    bare_path = str(TEST_DATA_PATH / "test_no_extension")
    written_path = zanj_io.save(payload, bare_path)

    assert written_path.endswith(".zanj")
    assert written_path == bare_path + ".zanj"
    assert os.path.exists(written_path)

    restored = zanj_io.read(written_path)
    assert restored["name"] == payload["name"]

    # --- case 2: path already ends in .zanj ---
    suffixed_path = str(TEST_DATA_PATH / "test_with_extension.zanj")
    written_path2 = zanj_io.save(payload, suffixed_path)

    # the extension must not be doubled
    assert written_path2 == suffixed_path
    assert not written_path2.endswith(".zanj.zanj")
    assert os.path.exists(written_path2)

    restored2 = zanj_io.read(written_path2)
    assert restored2["name"] == payload["name"]


def test_zanj_file_not_found():
    """Reading a missing path or a directory must raise `FileNotFoundError`."""
    z = ZANJ()

    # Try to read a non-existent file
    non_existent_path = TEST_DATA_PATH / "non_existent_file.zanj"

    # Should raise FileNotFoundError
    with pytest.raises(FileNotFoundError):
        z.read(non_existent_path)

    # Try to read a directory (not a file)
    # First ensure the directory exists. `parents=True` matters when this test
    # runs before anything has created TEST_DATA_PATH: without it, mkdir itself
    # raises FileNotFoundError here, outside the `pytest.raises` block.
    dir_path = TEST_DATA_PATH / "test_dir"
    dir_path.mkdir(parents=True, exist_ok=True)

    # Should raise FileNotFoundError with "not a file" message
    with pytest.raises(FileNotFoundError):
        z.read(dir_path)


def test_zanj_create_directory():
    """`save` must create any missing parent directories of the target path."""
    import shutil

    payload = {"name": "dir_test"}

    # three directory levels that should not exist before the save
    scratch_root = TEST_DATA_PATH / "new_dir"
    leaf_dir = scratch_root / "nested" / "structure"
    target = leaf_dir / "test_file.zanj"

    # start from a clean slate in case an earlier run left the tree behind
    if scratch_root.exists():
        shutil.rmtree(scratch_root)

    zanj_io = ZANJ()
    zanj_io.save(payload, target)

    # the directory chain was created ...
    assert leaf_dir.exists()
    assert leaf_dir.is_dir()

    # ... and the file was written into it
    assert target.exists()
    assert target.is_file()

    restored = zanj_io.read(target)
    assert restored["name"] == payload["name"]

``````{ end_of_file="tests/unit/no_torch/test_zanj_edge_cases.py" }

``````{ path="tests/unit/no_torch/test_zanj_populate_nested.py"  }
from __future__ import annotations

import typing
from pathlib import Path

import numpy as np
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

from zanj import ZANJ

np.random.seed(0)

# pylint: disable=missing-function-docstring,missing-class-docstring

TEST_DATA_PATH: Path = Path("tests/junk_data")


# Inner fixture dataclass: a string plus a numpy array; instances are stored
# in the list field of OuterClassWithNestedList.
@serializable_dataclass
class InnerClassWithArray(SerializableDataclass):
    some_string: str
    arr_numbers: np.ndarray


# Outer fixture dataclass holding a list of InnerClassWithArray.
# The list field uses custom hooks: each element is serialized via its own
# `serialize()`, and loading rebuilds the elements with
# `InnerClassWithArray.load`, reading them from the "lst_basic" entry of the
# value passed to `loading_fn`.
@serializable_dataclass
class OuterClassWithNestedList(SerializableDataclass):
    name: str
    lst_basic: typing.List[InnerClassWithArray] = serializable_field(
        serialization_fn=lambda x: [b.serialize() for b in x],
        loading_fn=lambda x: [InnerClassWithArray.load(b) for b in x["lst_basic"]],
    )


def test_nested_populate():
    """Round-trip an object with 20 inner records of 20-element arrays.

    Both thresholds passed to `ZANJ` are 10, i.e. below the list length and
    the array sizes used here.
    """
    inner_records = []
    for i in range(20):
        inner_records.append(
            InnerClassWithArray(
                some_string=f"hello_{i}",
                arr_numbers=np.random.rand(20),
            )
        )
    original = OuterClassWithNestedList(name="hello", lst_basic=inner_records)

    zanj_io = ZANJ(external_array_threshold=10, external_list_threshold=10)

    # remove any archive left over from a previous run before writing
    archive_path = TEST_DATA_PATH / "test_nested_populate.zanj"
    archive_path.unlink(missing_ok=True)

    zanj_io.save(original, archive_path)
    assert original == zanj_io.read(archive_path)

``````{ end_of_file="tests/unit/no_torch/test_zanj_populate_nested.py" }

``````{ path="tests/unit/no_torch/test_zanj_serializable_dataclass.py"  }
from __future__ import annotations

import json
import sys
import typing
from pathlib import Path

import numpy as np
import pandas as pd  # type: ignore[import]
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

from zanj import ZANJ

# seed numpy's global RNG so the random arrays built in this module are reproducible
np.random.seed(0)

# directory (relative to the current working directory) that receives the
# archives written by the tests in this module
TEST_DATA_PATH: Path = Path("tests/junk_data")

# True when the running interpreter is Python 3.10 or newer
SUPPORTS_KW_ONLY: bool = bool(sys.version_info >= (3, 10))


@serializable_dataclass
class BasicZanj(SerializableDataclass):
    """Flat record: a required string, an int defaulting to 42 and a list of ints."""

    a: str
    q: int = 42
    # a fresh empty list is created per instance when no value is given
    c: typing.List[int] = serializable_field(default_factory=list)


def test_Basic():
    """A flat `BasicZanj` survives a save/read round trip unchanged."""
    original = BasicZanj("hello", 42, [1, 2, 3])
    archive_path = TEST_DATA_PATH / "test_BasicZanj.zanj"

    zanj_io = ZANJ()
    zanj_io.save(original, archive_path)

    assert original == zanj_io.read(archive_path)


@serializable_dataclass
class Nested(SerializableDataclass):
    """Record that embeds a `BasicZanj` next to a name and a float value."""

    name: str
    basic: BasicZanj
    val: float


def test_Nested():
    """A dataclass containing another dataclass survives a save/read round trip."""
    inner = BasicZanj("hello", 42, [1, 2, 3])
    original = Nested("hello", inner, 3.14)
    archive_path = TEST_DATA_PATH / "test_Nested.zanj"

    zanj_io = ZANJ()
    zanj_io.save(original, archive_path)

    assert original == zanj_io.read(archive_path)


@serializable_dataclass
class Nested_with_container(SerializableDataclass):
    """Like `Nested`, with an additional list of `Nested` records."""

    name: str
    basic: BasicZanj
    val: float
    # a fresh empty list is created per instance when no value is given
    container: typing.List[Nested] = serializable_field(default_factory=list)


def test_Nested_with_container():
    """A dataclass holding a list of nested dataclasses survives a round trip."""
    children = [
        Nested("n1", BasicZanj("n1_b", 123, [4, 5, 7]), 2.71),
        Nested("n2", BasicZanj("n2_b", 456, [7, 8, 9]), 6.28),
    ]
    original = Nested_with_container(
        "hello",
        basic=BasicZanj("hello", 42, [1, 2, 3]),
        val=3.14,
        container=children,
    )
    archive_path = TEST_DATA_PATH / "test_Nested_with_container.zanj"

    zanj_io = ZANJ()
    zanj_io.save(original, archive_path)

    assert original == zanj_io.read(archive_path)


@serializable_dataclass
class sdc_with_np_array(SerializableDataclass):
    """Record with a name and two numpy arrays."""

    name: str
    arr1: np.ndarray
    arr2: np.ndarray


def test_sdc_with_np_array_small():
    """A dataclass with two short 1-D arrays (10 and 20 elements) survives a round trip."""
    instance = sdc_with_np_array("small arrays", np.random.rand(10), np.random.rand(20))

    z = ZANJ()
    # use a file name of its own: sharing the archive path with
    # `test_sdc_with_np_array` lets the two tests overwrite each other's
    # output when they are run concurrently
    path = TEST_DATA_PATH / "test_sdc_with_np_array_small.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered


def test_sdc_with_np_array():
    """A dataclass with two 2-D arrays (128x128 and 256x256) survives a round trip."""
    first = np.random.rand(128, 128)
    second = np.random.rand(256, 256)
    original = sdc_with_np_array("bigger arrays", first, second)
    archive_path = TEST_DATA_PATH / "test_sdc_with_np_array.zanj"

    zanj_io = ZANJ()
    zanj_io.save(original, archive_path)

    assert original == zanj_io.read(archive_path)


@serializable_dataclass
class sdc_with_df(SerializableDataclass):
    """Record with a name and two pandas DataFrames."""

    name: str
    iris_data: pd.DataFrame
    brain_data: pd.DataFrame


def test_sdc_with_df():
    """A dataclass holding two DataFrames loaded from CSV survives a round trip."""
    iris_frame = pd.read_csv("tests/input_data/iris.csv")
    brain_frame = pd.read_csv("tests/input_data/brain_networks.csv")
    original = sdc_with_df(
        "downloaded_data",
        iris_data=iris_frame,
        brain_data=brain_frame,
    )
    archive_path = TEST_DATA_PATH / "test_sdc_with_df.zanj"

    zanj_io = ZANJ()
    zanj_io.save(original, archive_path)

    assert original == zanj_io.read(archive_path)


@serializable_dataclass
class sdc_container_explicit(SerializableDataclass):
    """Record whose list of `Nested` items is stored as one JSONL string."""

    name: str
    container: typing.List[Nested] = serializable_field(
        default_factory=list,
        # as jsonl string, for whatever reason:
        # one JSON document per item, joined with newlines
        serialization_fn=lambda items: "\n".join(
            json.dumps(item.serialize()) for item in items
        ),
        # the loader receives the whole serialized object; split this field's
        # string back into lines and load each line as a `Nested`
        loading_fn=lambda data: [
            Nested.load(json.loads(line)) for line in data["container"].split("\n")
        ],
        # TODO: explicitly specifying the following does not work, since it gets automatically converted before we call load in `loading_fn`:
        # serialization_fn=lambda c: [n.serialize() for n in c],
        # loading_fn=lambda data: [Nested.load(n) for n in data["container"]],
    )


def test_sdc_container_explicit():
    """A container with explicit JSONL (de)serialization survives a round trip."""
    nested_items = []
    for n in range(10):
        inner = BasicZanj(f"n-{n}_b", n * 10 + 1, [n + 1, n + 2, n + 10])
        nested_items.append(Nested(f"n-{n}", inner, n * np.pi))

    original = sdc_container_explicit("container explicit", container=nested_items)
    archive_path = TEST_DATA_PATH / "test_sdc_container_explicit.zanj"

    zanj_io = ZANJ()
    zanj_io.save(original, archive_path)

    assert original == zanj_io.read(archive_path)

``````{ end_of_file="tests/unit/no_torch/test_zanj_serializable_dataclass.py" }

``````{ path="tests/unit/with_torch/test_bool_array_torch.py"  }
from pathlib import Path

import torch  # type: ignore[import-not-found]
from muutils.json_serialize import SerializableDataclass, serializable_dataclass

from zanj import ZANJ

# directory (relative to the current working directory) that receives the
# archives written by the tests in this module
TEST_DATA_PATH: Path = Path("tests/junk_data")


@serializable_dataclass
class MyClass_torch(SerializableDataclass):
    """Record with a name and two torch tensors."""

    name: str
    arr_1: torch.Tensor
    arr_2: torch.Tensor


def test_torch_bool_array():
    """Boolean tensors keep their `torch.bool` dtype across a save/read round trip."""
    archive_path: Path = TEST_DATA_PATH / "test_torch_bool_array.zanj"
    original: MyClass_torch = MyClass_torch(
        name="test",
        arr_1=torch.tensor([True, False, True]),
        arr_2=torch.tensor([True, False, True]),
    )

    zanj_io = ZANJ()
    zanj_io.save(original, archive_path)
    restored: MyClass_torch = zanj_io.read(archive_path)

    # the dtype is checked on both tensors before comparing whole objects
    for tensor in (restored.arr_1, restored.arr_2):
        assert tensor.dtype == torch.bool

    assert original == restored

``````{ end_of_file="tests/unit/with_torch/test_bool_array_torch.py" }

``````{ path="tests/unit/with_torch/test_get_module_device.py"  }
from __future__ import annotations

import pytest
import torch  # type: ignore[import-not-found]

from zanj.torchutil import get_module_device


def test_get_module_device_single_device():
    """A module placed entirely on one device is reported as single-device."""
    # first CUDA device when available, CPU otherwise
    target = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    net = torch.nn.Linear(10, 2)
    net.to(target)

    single, reported = get_module_device(net)

    # the flag must be truthy and the returned value must be the device itself
    assert single
    assert reported == target


def test_get_module_device_multiple_devices():
    """A module with parameters on different devices yields a per-parameter dict."""
    # NOTE(review): only the "meta" and "cpu" devices are used below, yet the
    # test is skipped on machines without CUDA -- confirm the guard is needed
    if torch.cuda.device_count() == 0:
        pytest.skip("This test requires at least one CUDA device")

    # parameter name -> device it is moved to
    placement = {"weight": "meta", "bias": "cpu"}

    with torch.no_grad():
        model = torch.nn.Linear(10, 2)
        print(f"{model = }")
        # re-wrap each tensor in a new Parameter after moving it
        for param_name, device_name in placement.items():
            moved = getattr(model, param_name).to(device_name)
            setattr(model, param_name, torch.nn.Parameter(moved))

    print(f"{model = }")
    print(f"{model.weight = }")
    print(f"{model.bias = }")

    is_single, device_or_dict = get_module_device(model)

    print(f"{is_single = }, {device_or_dict = }")

    # mixed placement: flag is falsy and a dict is returned
    assert not is_single
    assert isinstance(device_or_dict, dict)

    # every parameter is mapped to the device it was moved to
    for param_name, device_name in placement.items():
        assert device_or_dict[param_name] == torch.device(device_name)


def test_get_module_device_no_parameters():
    """An empty `Sequential` is reported as not single-device with an empty dict."""
    parameterless = torch.nn.Sequential()

    single, reported = get_module_device(parameterless)

    assert not single
    assert reported == {}

``````{ end_of_file="tests/unit/with_torch/test_get_module_device.py" }

``````{ path="tests/unit/with_torch/test_sdc_torch.py"  }
from __future__ import annotations

import sys
import typing
from pathlib import Path

import numpy as np
import pandas as pd  # type: ignore[import]
import torch  # type: ignore[import-not-found]
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

from zanj import ZANJ

# seed numpy's global RNG so the numpy arrays built in this module are reproducible
np.random.seed(0)

# directory (relative to the current working directory) that receives the
# archives written by the tests in this module
TEST_DATA_PATH: Path = Path("tests/junk_data")

# True on Python 3.10 or newer; passed as `kw_only=` to `serializable_dataclass` below
SUPPORTS_KW_ONLY: bool = bool(sys.version_info >= (3, 10))


@serializable_dataclass
class BasicZanjTorch(SerializableDataclass):
    """Flat record: a required string, an int defaulting to 42 and a list of ints."""

    a: str
    q: int = 42
    # a fresh empty list is created per instance when no value is given
    c: typing.List[int] = serializable_field(default_factory=list)


@serializable_dataclass
class NestedTorch(SerializableDataclass):
    """Record that embeds a `BasicZanjTorch` next to a name and a float value."""

    name: str
    basic: BasicZanjTorch
    val: float


@serializable_dataclass
class sdc_with_torch_tensor(SerializableDataclass):
    """Record with a name and two torch tensors."""

    name: str
    tensor1: torch.Tensor
    tensor2: torch.Tensor


def test_sdc_tensor_small():
    """A dataclass with two short 1-D tensors (8 and 16 elements) survives a round trip."""
    first = torch.rand(8)
    second = torch.rand(16)
    original = sdc_with_torch_tensor("small tensors", first, second)
    archive_path = TEST_DATA_PATH / "test_sdc_tensor_small.zanj"

    zanj_io = ZANJ()
    zanj_io.save(original, archive_path)

    assert original == zanj_io.read(archive_path)


def test_sdc_tensor():
    """A dataclass with two 2-D tensors (128x128 and 256x256) survives a round trip."""
    first = torch.rand(128, 128)
    second = torch.rand(256, 256)
    original = sdc_with_torch_tensor("bigger tensors", first, second)
    archive_path = TEST_DATA_PATH / "test_sdc_tensor.zanj"

    zanj_io = ZANJ()
    zanj_io.save(original, archive_path)

    assert original == zanj_io.read(archive_path)


@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class sdc_complicated(SerializableDataclass):
    """Record mixing numpy arrays, DataFrames, nested dataclasses and a tensor."""

    name: str
    arr1: np.ndarray
    arr2: np.ndarray
    iris_data: pd.DataFrame
    brain_data: pd.DataFrame
    container: typing.List[NestedTorch]

    tensor: torch.Tensor

    # NOTE(review): explicitly delegates equality to the parent class;
    # presumably defined here so the decorator does not install its own
    # `__eq__` -- confirm against `serializable_dataclass`
    def __eq__(self, value):
        return super().__eq__(value)


def test_sdc_complicated():
    """A dataclass mixing arrays, DataFrames, nested records and a tensor survives a round trip."""
    # the field values are built in declaration order, so the random draws
    # happen in the same sequence as the fields are listed
    field_values = dict(
        name="complicated data",
        arr1=np.random.rand(128, 128),
        arr2=np.random.rand(256, 256),
        iris_data=pd.read_csv("tests/input_data/iris.csv"),
        brain_data=pd.read_csv("tests/input_data/brain_networks.csv"),
        container=[
            NestedTorch(
                f"n-{n}",
                BasicZanjTorch(f"n-{n}_b", n * 10 + 1, [n + 1, n + 2, n + 10]),
                n * np.pi,
            )
            for n in range(10)
        ],
        tensor=torch.rand(512, 512),
    )
    original = sdc_complicated(**field_values)
    archive_path = TEST_DATA_PATH / "test_sdc_complicated.zanj"

    zanj_io = ZANJ()
    zanj_io.save(original, archive_path)

    assert original == zanj_io.read(archive_path)

``````{ end_of_file="tests/unit/with_torch/test_sdc_torch.py" }

``````{ path="tests/unit/with_torch/test_torch_edge_cases.py"  }
from __future__ import annotations

from pathlib import Path

import pytest
import torch  # type: ignore[import-not-found]
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
)

from zanj import ZANJ
from zanj.torchutil import (
    ConfiguredModel,
    assert_model_exact_equality,
    get_module_device,
    num_params,
    set_config_class,
)

# directory (relative to the current working directory) that receives the
# archives written by the tests in this module
TEST_DATA_PATH: Path = Path("tests/junk_data")


def test_num_params():
    """`num_params` counts parameters, optionally restricted to trainable ones."""
    # Fully trainable network:
    #   Linear(10, 20): 10*20 weights + 20 biases = 220
    #   Linear(20, 1):  20*1 weights + 1 bias     = 21
    #   total                                     = 241
    plain_net = torch.nn.Sequential(
        torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 1)
    )
    assert num_params(plain_net) == 241

    # Same layout with a batch-norm layer in the middle, which is then frozen
    frozen_bn_net = torch.nn.Sequential(
        torch.nn.Linear(10, 20),
        torch.nn.BatchNorm1d(20, track_running_stats=True),
        torch.nn.Linear(20, 1),
    )
    for bn_param in frozen_bn_net[1].parameters():
        bn_param.requires_grad = False

    trainable_count = num_params(frozen_bn_net, only_trainable=True)
    total_count = num_params(frozen_bn_net, only_trainable=False)

    # The frozen batch norm contributes weight + bias = 2*20 = 40 parameters,
    # which is exactly the gap between the two counts
    assert total_count - trainable_count == 40


def test_get_module_device_empty():
    """`get_module_device` on a parameterless custom module gives `(False, {})`."""

    class EmptyModule(torch.nn.Module):
        # identity module: registers no parameters
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x

    single, devices = get_module_device(EmptyModule())

    # no parameters means no single device and nothing to map
    assert single is False
    assert devices == {}


def test_load_state_dict_wrapper():
    """An overridden `_load_state_dict_wrapper` runs when a model is read from disk."""

    @serializable_dataclass
    class SimpleConfig(SerializableDataclass):
        size: int

    @set_config_class(SimpleConfig)
    class SimpleModel(ConfiguredModel[SimpleConfig]):
        def __init__(self, config: SimpleConfig):
            super().__init__(config)
            self.linear = torch.nn.Linear(config.size, config.size)

        def forward(self, x):
            return self.linear(x)

        def _load_state_dict_wrapper(self, state_dict, **kwargs):
            # leave a marker on the instance, then defer to the parent
            # implementation (the keyword arguments are not forwarded)
            self.custom_wrapper_called = True
            return super()._load_state_dict_wrapper(state_dict)

    # write a freshly built model to disk
    archive_path = TEST_DATA_PATH / "test_load_state_dict_wrapper.zanj"
    source_model = SimpleModel(SimpleConfig(10))
    ZANJ().save(source_model, archive_path)

    # a second instance whose marker starts out cleared
    reader_model = SimpleModel(SimpleConfig(10))
    reader_model.custom_wrapper_called = False

    # reading through the instance is expected to go through the override
    restored = reader_model.read(archive_path)

    assert hasattr(restored, "custom_wrapper_called")
    assert restored.custom_wrapper_called is True


def test_deprecated_load_file():
    """`load_file` still loads a model but emits a `DeprecationWarning`."""

    @serializable_dataclass
    class SimpleConfig(SerializableDataclass):
        size: int

    @set_config_class(SimpleConfig)
    class SimpleModel(ConfiguredModel[SimpleConfig]):
        def __init__(self, config: SimpleConfig):
            super().__init__(config)
            self.linear = torch.nn.Linear(config.size, config.size)

        def forward(self, x):
            return self.linear(x)

    reference = SimpleModel(SimpleConfig(10))
    archive_path = TEST_DATA_PATH / "test_deprecated_load_file.zanj"
    ZANJ().save(reference, archive_path)

    # the call must warn ...
    with pytest.warns(DeprecationWarning):
        restored = SimpleModel.load_file(archive_path)

    # ... and still hand back a model identical to the saved one
    assert_model_exact_equality(reference, restored)


def test_configmodel_training_records():
    """`training_records` set on a model are restored when it is read back."""

    @serializable_dataclass
    class SimpleConfig(SerializableDataclass):
        size: int

    @set_config_class(SimpleConfig)
    class SimpleModel(ConfiguredModel[SimpleConfig]):
        def __init__(self, config: SimpleConfig):
            super().__init__(config)
            self.linear = torch.nn.Linear(config.size, config.size)
            self.training_records = None  # Initialize to None

        def forward(self, x):
            return self.linear(x)

    source_model = SimpleModel(SimpleConfig(10))
    source_model.training_records = {
        "loss": [1.0, 0.9, 0.8, 0.7],
        "accuracy": [0.6, 0.7, 0.8, 0.9],
        "learning_rate": [0.01, 0.005, 0.001, 0.0005],
    }

    archive_path = TEST_DATA_PATH / "test_training_records.zanj"
    ZANJ().save(source_model, archive_path)

    restored = SimpleModel.read(archive_path)

    # whole-dict comparison first, then each series against literal values
    assert restored.training_records == source_model.training_records
    expected_series = (
        ("loss", [1.0, 0.9, 0.8, 0.7]),
        ("accuracy", [0.6, 0.7, 0.8, 0.9]),
        ("learning_rate", [0.01, 0.005, 0.001, 0.0005]),
    )
    for metric, values in expected_series:
        assert restored.training_records[metric] == values


def test_configmodel_with_custom_settings():
    """`custom_settings` given to `ZANJ` reach `_load_state_dict_wrapper` as kwargs."""

    @serializable_dataclass
    class SimpleConfig(SerializableDataclass):
        size: int

    @set_config_class(SimpleConfig)
    class CustomStateModel(ConfiguredModel[SimpleConfig]):
        def __init__(self, config: SimpleConfig):
            super().__init__(config)
            self.linear = torch.nn.Linear(config.size, config.size)
            self.custom_param_received = None

        def forward(self, x):
            return self.linear(x)

        def _load_state_dict_wrapper(self, state_dict, **kwargs):
            # remember the custom keyword (None when absent), then defer to
            # the parent implementation
            self.custom_param_received = kwargs.get("custom_param")
            return super(CustomStateModel, self)._load_state_dict_wrapper(state_dict)

    # save with a default ZANJ
    archive_path = TEST_DATA_PATH / "test_custom_settings.zanj"
    ZANJ().save(CustomStateModel(SimpleConfig(10)), archive_path)

    # read with a ZANJ carrying settings keyed by the wrapper's name
    wrapper_settings = {"_load_state_dict_wrapper": {"custom_param": "test_value"}}
    configured_reader = ZANJ(custom_settings=wrapper_settings)
    restored = configured_reader.read(archive_path)

    assert restored.custom_param_received == "test_value"

``````{ end_of_file="tests/unit/with_torch/test_torch_edge_cases.py" }

``````{ path="tests/unit/with_torch/test_zanj_sdc_modelcfg.py"  }
from __future__ import annotations

import json
import sys
import typing
from pathlib import Path

import numpy as np
import torch  # type: ignore[import-not-found]
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

from zanj import ZANJ

# seed numpy's global RNG
np.random.seed(0)

# pylint: disable=missing-function-docstring,missing-class-docstring

# directory (relative to the current working directory) that receives the
# files written by the tests in this module
TEST_DATA_PATH: Path = Path("tests/junk_data")


# True on Python 3.10 or newer; passed as `kw_only=` to `serializable_dataclass` below
SUPPORTS_KW_ONLY: bool = bool(sys.version_info >= (3, 10))


@serializable_dataclass
class MyModelCfg(SerializableDataclass):
    """Model configuration used as a fixture: name, layer count, hidden size, dropout."""

    name: str
    num_layers: int
    hidden_size: int
    dropout: float


@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class TrainCfg(SerializableDataclass):
    """Training configuration whose optimizer class is stored by its name."""

    name: str
    weight_decay: float
    optimizer: typing.Type[torch.optim.Optimizer] = serializable_field(
        default_factory=lambda: torch.optim.Adam,
        # the class is written out as its bare name ...
        serialization_fn=lambda optim_cls: optim_cls.__name__,
        # ... and looked up again as an attribute of `torch.optim` on load
        loading_fn=lambda data: getattr(torch.optim, data["optimizer"]),
    )
    optimizer_kwargs: typing.Dict[str, typing.Any] = serializable_field(  # type: ignore
        default_factory=lambda: {"lr": 0.000001}
    )


class CustomCfg:
    """Plain (non-dataclass) config object with hand-written (de)serialization.

    Holds an int `x` and a string `y`; `serialize` and `load` convert to and
    from a `{"x": ..., "y": ...}` dict.
    """

    def __init__(self, x: int, y: str):
        self.x = x
        self.y = y

    def __eq__(self, other):
        """Compare by field values; defer to Python for non-`CustomCfg` operands."""
        # Returning NotImplemented (instead of touching `other.x`) keeps
        # comparisons with `None` or unrelated objects from raising
        # AttributeError; Python then falls back to its default handling,
        # so `cfg == None` is simply False.
        if not isinstance(other, CustomCfg):
            return NotImplemented
        return self.x == other.x and self.y == other.y

    def serialize(self):
        """Return the fields as a dict with keys `"x"` and `"y"`."""
        return {"x": self.x, "y": self.y}

    @classmethod
    def load(cls, data):
        """Build an instance from a mapping containing the keys `"x"` and `"y"`.

        Raises `KeyError` if either key is missing; other keys are ignored.
        """
        return cls(x=data["x"], y=data["y"])


@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class BasicCfgHolder(SerializableDataclass):
    """Holder combining a model config, a training config and an optional `CustomCfg`."""

    model: MyModelCfg
    optimizer: TrainCfg
    # `custom` defaults to None, so both conversion functions pass None
    # through instead of calling `.serialize()` / `CustomCfg.load` on it
    custom: typing.Optional[CustomCfg] = serializable_field(
        default=None,
        serialization_fn=lambda x: None if x is None else x.serialize(),
        # the loader receives the whole serialized holder and picks out this field
        loading_fn=lambda data: (
            None if data["custom"] is None else CustomCfg.load(data["custom"])
        ),
    )


# module-level fixture shared by `test_config_holder` and `test_config_holder_zanj`;
# every field is given a non-default value
instance_basic: BasicCfgHolder = BasicCfgHolder(  # type: ignore
    model=MyModelCfg("lstm", 3, 128, 0.1),  # type: ignore
    optimizer=TrainCfg(  # type: ignore
        name="adamw",
        weight_decay=0.2,
        optimizer=torch.optim.AdamW,
        optimizer_kwargs=dict(lr=0.0001),
    ),
    custom=CustomCfg(42, "forty-two"),
)


def test_config_holder():
    """Round-trip `instance_basic` through a plain JSON file on disk."""
    # the JSON file is written with plain `open`, which does not create
    # missing parent directories, so make sure the output directory exists
    TEST_DATA_PATH.mkdir(parents=True, exist_ok=True)
    json_path = TEST_DATA_PATH / "test_config_holder.json"

    instance_stored = instance_basic.serialize()
    with open(json_path, "w") as f:
        json.dump(instance_stored, f, indent="\t")
    with open(json_path, "r") as f:
        instance_stored_read = json.load(f)

    # load from what was read back from disk, not from the in-memory dict
    recovered = BasicCfgHolder.load(instance_stored_read)
    assert isinstance(recovered.model, MyModelCfg)
    assert isinstance(recovered.optimizer, TrainCfg)
    assert isinstance(recovered.custom, CustomCfg)
    assert recovered.custom.x == 42
    assert instance_basic == recovered


def test_config_holder_zanj():
    """Round-trip `instance_basic` through a ZANJ archive and check field types."""
    archive_path = TEST_DATA_PATH / "test_config_holder.zanj"

    zanj_io = ZANJ()
    zanj_io.save(instance_basic, archive_path)
    restored = zanj_io.read(archive_path)

    # each field must come back as its own class
    expected_types = (
        ("model", MyModelCfg),
        ("optimizer", TrainCfg),
        ("custom", CustomCfg),
    )
    for field_name, field_type in expected_types:
        assert isinstance(getattr(restored, field_name), field_type)

    assert restored.custom.x == 42
    assert instance_basic == restored


@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class BaseGPTConfig(SerializableDataclass):
    """GPT-style model configuration used as a fixture; all fields are required."""

    name: str
    act_fn: str
    d_model: int
    d_head: int
    n_layers: int


@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class AdvCfgHolder(SerializableDataclass):
    """Holder with a required model config, a defaulted name and an optional tokenizer."""

    model_cfg: BaseGPTConfig
    name: str = serializable_field(default="default")
    tokenizer: typing.Optional[CustomCfg] = serializable_field(
        default=None,
        # a present tokenizer is stored as its `repr`; None stays None
        serialization_fn=lambda x: repr(x) if x is not None else None,
        # NOTE(review): for a non-None tokenizer this returns the
        # `NotImplementedError` class as the field value rather than raising
        # it -- confirm whether raising was intended
        loading_fn=lambda data: (
            None if data["tokenizer"] is None else NotImplementedError
        ),
    )


# module-level fixture shared by `test_adv_config_holder` and
# `test_adv_config_holder_zanj`; `name` is left at its default and the
# tokenizer is explicitly None
instance_adv: AdvCfgHolder = AdvCfgHolder(  # type: ignore
    model_cfg=BaseGPTConfig(  # type: ignore
        name="gpt2",
        act_fn="gelu",
        d_model=128,
        d_head=64,
        n_layers=3,
    ),
    tokenizer=None,
)


def test_adv_config_holder():
    """Serialize `instance_adv`, dump it to a JSON file and reload it from the dict."""
    # the JSON file is written with plain `open`, which does not create
    # missing parent directories, so make sure the output directory exists
    TEST_DATA_PATH.mkdir(parents=True, exist_ok=True)

    instance_stored = instance_adv.serialize()
    with open(TEST_DATA_PATH / "test_adv_config_holder.json", "w") as f:
        json.dump(instance_stored, f, indent="\t")

    # the object is rebuilt from the in-memory dict; the file is only written
    recovered = AdvCfgHolder.load(instance_stored)
    assert isinstance(recovered.model_cfg, BaseGPTConfig)
    assert instance_adv == recovered


def test_adv_config_holder_zanj():
    """Round-trip `instance_adv` through a ZANJ archive."""
    archive_path = TEST_DATA_PATH / "test_adv_config_holder.zanj"

    zanj_io = ZANJ()
    zanj_io.save(instance_adv, archive_path)
    restored = zanj_io.read(archive_path)

    # the nested config must come back as its own class, and the whole
    # object must compare equal to the original
    assert isinstance(restored.model_cfg, BaseGPTConfig)
    assert instance_adv == restored

``````{ end_of_file="tests/unit/with_torch/test_zanj_sdc_modelcfg.py" }

``````{ path="tests/unit/with_torch/test_zanj_torch.py"  }
from __future__ import annotations

from pathlib import Path

import numpy as np
import torch  # type: ignore[import-not-found]
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from muutils.tensor_utils import compare_state_dicts

from zanj import ZANJ
from zanj.torchutil import (
    ConfiguredModel,
    assert_model_exact_equality,
    set_config_class,
)

# seed numpy's global RNG
np.random.seed(0)

# directory (relative to the current working directory) that receives the
# archives written by the tests in this module
TEST_DATA_PATH: Path = Path("tests/junk_data")


def test_torch_configmodel_minimal():
    """Save a one-layer `ConfiguredModel` and load it back in two ways.

    The model is read once through the class (`MyNN.read`) and once through a
    plain `ZANJ().read`; both results must match the original config,
    training records and weights.
    """

    @serializable_dataclass
    class MyNNConfig(SerializableDataclass):
        n_layers: int

    @set_config_class(MyNNConfig)
    class MyNN(ConfiguredModel[MyNNConfig]):
        def __init__(self, config: MyNNConfig):
            super().__init__(config)

            # `n_layers` is used as the input width of the single linear layer
            self.layer = torch.nn.Linear(config.n_layers, 1)

        def forward(self, x):
            return self.layer(x)

    config: MyNNConfig = MyNNConfig(
        n_layers=2,
    )

    model: MyNN = MyNN(config)

    # use a file name of its own: sharing the archive path with
    # `test_torch_configmodel` lets the two tests overwrite each other's
    # output when they are run concurrently
    fname: Path = TEST_DATA_PATH / "test_torch_configmodel_minimal.zanj"
    ZANJ().save(model, fname)

    print(f"saved model to {fname}")
    print(f"{model.zanj_model_config = }")

    # try to load the model through the model class
    model2: MyNN = MyNN.read(fname)
    print(f"loaded model from {fname}")
    print(f"{model2.zanj_model_config = }")

    assert model.zanj_model_config == model2.zanj_model_config
    assert model.training_records == model2.training_records

    compare_state_dicts(model.state_dict(), model2.state_dict())
    assert_model_exact_equality(model, model2)

    # load again through a generic ZANJ reader
    model3: MyNN = ZANJ().read(fname)
    print(f"loaded model from {fname}")
    print(f"{model3.zanj_model_config = }")

    assert model.zanj_model_config == model3.zanj_model_config
    assert model.training_records == model3.training_records

    compare_state_dicts(model.state_dict(), model3.state_dict())
    assert_model_exact_equality(model, model3)


def test_torch_configmodel():
    """Save a transformer-based `ConfiguredModel` and load it back in two ways.

    The config carries loss/optimizer classes that are stored by name. The
    model is read once through the class (`MyGPT.read`) and once through a
    plain `ZANJ().read`; both results must match the original config,
    training records and weights.
    """
    import torch  # type: ignore[import-not-found]

    from zanj.torchutil import ConfiguredModel, set_config_class

    @serializable_dataclass
    class MyGPTConfig(SerializableDataclass):
        """basic test GPT config"""

        n_layers: int
        n_heads: int
        embedding_size: int
        n_positions: int
        n_vocab: int

        # stored as the class name and looked up again in `torch.nn` on load
        loss_factory: torch.nn.modules.loss._Loss = serializable_field(
            default_factory=lambda: torch.nn.CrossEntropyLoss,
            serialization_fn=lambda x: x.__name__,
            loading_fn=lambda x: getattr(torch.nn, x["loss_factory"]),
        )

        loss_kwargs: dict = serializable_field(default_factory=dict)

        @property
        def loss(self):
            # instantiate the configured loss with its keyword arguments
            return self.loss_factory(**self.loss_kwargs)

        # stored as the class name and looked up again in `torch.optim` on load
        optim_factory: torch.optim.Optimizer = serializable_field(
            default_factory=lambda: torch.optim.Adam,
            serialization_fn=lambda x: x.__name__,
            loading_fn=lambda x: getattr(torch.optim, x["optim_factory"]),
        )

        optim_kwargs: dict = serializable_field(default_factory=dict)

        def optim(self, model):
            # instantiate the configured optimizer over the model's parameters
            return self.optim_factory(model.parameters(), **self.optim_kwargs)  # type: ignore

    @set_config_class(MyGPTConfig)
    class MyGPT(ConfiguredModel[MyGPTConfig]):
        """basic GPT model"""

        def __init__(self, config: MyGPTConfig):
            super().__init__(config)

            # implementation of a GPT style model with decoders only

            self.transformer = torch.nn.Transformer(
                d_model=config.embedding_size,
                nhead=config.n_heads,
                num_encoder_layers=0,
                num_decoder_layers=config.n_layers,
            )

        def forward(self, x):
            return self.transformer(x)

    config: MyGPTConfig = MyGPTConfig(
        n_layers=2,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
        loss_factory=torch.nn.CrossEntropyLoss,
    )

    model: MyGPT = MyGPT(config)
    # attach some records so that their round trip is checked as well
    model.training_records = dict(loss=[3, 2, 1], accuracy=[0.1, 0.2, 0.3])

    fname: Path = TEST_DATA_PATH / "test_torch_configmodel.zanj"
    ZANJ().save(model, fname)

    print(f"saved model to {fname}")
    print(f"{model.zanj_model_config = }")

    # try to load the model
    model2: MyGPT = MyGPT.read(fname)
    print(f"loaded model from {fname}")
    print(f"{model2.zanj_model_config = }")

    assert model.zanj_model_config == model2.zanj_model_config
    assert model.training_records == model2.training_records

    compare_state_dicts(model.state_dict(), model2.state_dict())
    assert_model_exact_equality(model, model2)

    # load again through a generic ZANJ reader
    model3: MyGPT = ZANJ().read(fname)
    print(f"loaded model from {fname}")
    print(f"{model3.zanj_model_config = }")

    assert model.zanj_model_config == model3.zanj_model_config
    assert model.training_records == model3.training_records

    compare_state_dicts(model.state_dict(), model3.state_dict())
    assert_model_exact_equality(model, model3)

``````{ end_of_file="tests/unit/with_torch/test_zanj_torch.py" }

``````{ path="tests/unit/with_torch/test_zanj_torch_cfgmismatch.py"  }
from __future__ import annotations

from typing import Any

import torch  # type: ignore[import-not-found]
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

from zanj.torchutil import (
    ConfigMismatchException,
    ConfiguredModel,
    assert_model_cfg_equality,
    set_config_class,
)

# Minimal config and model classes used by the config-equality tests below


@serializable_dataclass
class MyGPTConfig(SerializableDataclass):
    """basic test GPT config"""

    n_layers: int
    n_heads: int
    embedding_size: int
    n_positions: int
    n_vocab: int
    # free-form payload; the tests below put dicts and a plain string here
    junk_data: Any = serializable_field(default_factory=dict)


@set_config_class(MyGPTConfig)
class MyGPT(ConfiguredModel[MyGPTConfig]):
    """basic "GPT" model"""

    def __init__(self, config: MyGPTConfig):
        super().__init__(config)
        # despite the attribute name this is a single linear layer mapping
        # `embedding_size` inputs to `n_vocab` outputs
        self.transformer = torch.nn.Linear(config.embedding_size, config.n_vocab)

    def forward(self, x):
        return self.transformer(x)


def test_config_mismatch_exception_direct():
    """`ConfigMismatchException` keeps the diff and renders it after the message."""
    message = "Configs don't match"
    difference = {"model_cfg": {"are_weights_processed": {"self": False, "other": True}}}

    error = ConfigMismatchException(message, difference)

    # the diff is available on the exception unchanged
    assert error.diff == difference

    # the string form is "<message>: <diff>"
    expected_text = r"Configs don't match: {'model_cfg': {'are_weights_processed': {'self': False, 'other': True}}}"
    assert str(error) == expected_text


def test_equal_configs():
    """Two models built from the same config object pass the equality check."""
    config_fields = dict(
        n_layers=2,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
        junk_data={"a": 1, "b": 2},
    )
    shared_config = MyGPTConfig(**config_fields)

    first_model = MyGPT(shared_config)
    second_model = MyGPT(shared_config)

    # must return without raising
    assert_model_cfg_equality(first_model, second_model)


def test_unequal_configs():
    """Configs differing in `n_layers` and `junk_data` raise with exactly that diff."""
    # fields that are identical in both configs
    common_fields = dict(n_heads=2, embedding_size=16, n_positions=16, n_vocab=128)

    config_a = MyGPTConfig(n_layers=2, junk_data={"a": 1, "b": 2}, **common_fields)
    # a different config
    config_b = MyGPTConfig(
        n_layers=3,
        junk_data={"a": 7, "something": "or other"},
        **common_fields,
    )

    model_a = MyGPT(config_a)
    model_b = MyGPT(config_b)

    caught = None
    try:
        assert_model_cfg_equality(model_a, model_b)
    except ConfigMismatchException as exc:
        caught = exc

    # the check must have raised ...
    if caught is None:
        raise AssertionError("Expected a ConfigMismatchException!")

    # ... and the diff lists only the two differing fields, each with both sides
    assert caught.diff == {
        "n_layers": {"self": 2, "other": 3},
        "junk_data": {
            "self": {"a": 1, "b": 2},
            "other": {"a": 7, "something": "or other"},
        },
    }


def test_unequal_configs_2():
    """Like `test_unequal_configs`, but one `junk_data` value is a string, not a dict."""
    # fields that are identical in both configs
    common_fields = dict(n_heads=2, embedding_size=16, n_positions=16, n_vocab=128)

    config_a = MyGPTConfig(n_layers=2, junk_data={"a": 1, "b": 2}, **common_fields)
    # a different config
    config_b = MyGPTConfig(
        n_layers=3,
        junk_data="this isnt even a dict lol",
        **common_fields,
    )

    model_a = MyGPT(config_a)
    model_b = MyGPT(config_b)

    caught = None
    try:
        assert_model_cfg_equality(model_a, model_b)
    except ConfigMismatchException as exc:
        caught = exc

    # the check must have raised ...
    if caught is None:
        raise AssertionError("Expected a ConfigMismatchException!")

    # ... and the mismatched-type field is reported with both raw values
    assert caught.diff == {
        "n_layers": {"self": 2, "other": 3},
        "junk_data": {
            "self": {"a": 1, "b": 2},
            "other": "this isnt even a dict lol",
        },
    }


def test_incorrect_instance():
    """passing something that is not a model as `model_b` must raise an `AssertionError` naming `model_b`"""
    config = MyGPTConfig(
        n_layers=2,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
    )

    model_a = MyGPT(config)
    model_b = "Not a ConfiguredModel instance"

    try:
        assert_model_cfg_equality(model_a, model_b)  # type: ignore
    except AssertionError as exc:
        assert str(exc) == "model_b must be a ConfiguredModel"
    else:
        # without this branch the test would pass vacuously when nothing is raised
        raise AssertionError("Expected an AssertionError!")

``````{ end_of_file="tests/unit/with_torch/test_zanj_torch_cfgmismatch.py" }

``````{ path="tests/assert_no_torch.py"  }
import pytest


def test_assert_no_torch():
    """check that `torch` cannot be imported in this environment"""
    with pytest.raises(ImportError):
        import torch  # type: ignore[import-not-found]

        # only reached if the import unexpectedly succeeds
        print(torch.rand(10))

``````{ end_of_file="tests/assert_no_torch.py" }

``````{ path="zanj/__init__.py"  }
"""
.. include:: ../README.md
"""

from __future__ import annotations

from zanj.loading import register_loader_handler
from zanj.zanj import ZANJ

# public names of the package: the two objects imported above, followed by submodule names
__all__ = [
    "register_loader_handler",
    "ZANJ",
    # modules
    "externals",
    "loading",
    "serializing",
    "torchutil",
    "zanj",
]

``````{ end_of_file="zanj/__init__.py" }

``````{ path="zanj/externals.py"  }
"""for storing/retrieving an item externally in a ZANJ archive"""

from __future__ import annotations

import json
from typing import IO, Any, Callable, Literal, NamedTuple, get_args

import numpy as np
from muutils.json_serialize.json_serialize import ObjectPath
from muutils.json_serialize.util import JSONitem

# this is to make type checking work -- it will later be overridden
_ZANJ_pre = Any

# names of the main json data file and of the metadata file inside a ZANJ archive
ZANJ_MAIN: str = "__zanj__.json"
ZANJ_META: str = "__zanj_meta__.json"

# the storage formats available for external items
ExternalItemType = Literal["jsonl", "npy"]

# the same formats as a tuple of plain strings
ExternalItemType_vals = get_args(ExternalItemType)

# an item stored outside of the main json data:
# its storage format, its data, and its path within the serialized object
ExternalItem = NamedTuple(
    "ExternalItem",
    [
        ("item_type", ExternalItemType),
        ("data", Any),
        ("path", ObjectPath),
    ],
)


def load_jsonl(zanj: "LoadedZANJ", fp: IO[bytes]) -> list[JSONitem]:  # type: ignore[name-defined] # noqa: F821
    """parse every line of `fp` as one json document, in order; `zanj` is not used"""
    parsed: list[JSONitem] = []
    for raw_line in fp:
        parsed.append(json.loads(raw_line))
    return parsed


def load_npy(zanj: "LoadedZANJ", fp: IO[bytes]) -> np.ndarray:  # type: ignore[name-defined] # noqa: F821
    """read an array from `fp` using `np.load`; `zanj` is not used"""
    loaded_array: np.ndarray = np.load(fp)
    return loaded_array


# maps each external item type to the function which reads it from an open binary file;
# every function is called as `(loaded_zanj, fp)`
EXTERNAL_LOAD_FUNCS: dict[ExternalItemType, Callable[[_ZANJ_pre, IO[bytes]], Any]] = {
    "jsonl": load_jsonl,
    "npy": load_npy,
}


def GET_EXTERNAL_LOAD_FUNC(item_type: str) -> Callable[[_ZANJ_pre, IO[bytes]], Any]:
    """return the load function registered for `item_type`

    raises `ValueError` if `item_type` is not a key of `EXTERNAL_LOAD_FUNCS`
    """
    if item_type in EXTERNAL_LOAD_FUNCS:
        # the membership test above makes this lookup safe
        return EXTERNAL_LOAD_FUNCS[item_type]  # type: ignore[index]

    raise ValueError(
        f"unknown external item type: {item_type}, needs to be one of {EXTERNAL_LOAD_FUNCS.keys()}"
    )

``````{ end_of_file="zanj/externals.py" }

``````{ path="zanj/loading.py"  }
from __future__ import annotations

import json
import threading
import typing
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable

import numpy as np

# pandas is optional: when it cannot be imported, `pandas_DataFrame` becomes a stand-in
# class which raises `ImportError` as soon as something tries to construct it
try:
    import pandas as pd  # type: ignore[import]

    pandas_DataFrame = pd.DataFrame  # type: ignore[no-redef]
except ImportError:

    class pandas_DataFrame:  # type: ignore[no-redef]
        def __init__(self, *args, **kwargs):
            raise ImportError("cannot load pandas DataFrame, pandas is not installed")


from muutils.errormode import ErrorMode
from muutils.json_serialize.array import load_array
from muutils.json_serialize.json_serialize import ObjectPath
from muutils.json_serialize.util import (
    _FORMAT_KEY,
    _REF_KEY,
    JSONdict,
    JSONitem,
    safe_getsource,
    string_as_lines,
)

from zanj.externals import (
    GET_EXTERNAL_LOAD_FUNC,
    ZANJ_MAIN,
    ZANJ_META,
    ExternalItem,
    _ZANJ_pre,
)

# pylint: disable=protected-access, dangerous-default-value


def _populate_externals_error_checking(key, item) -> bool:
    """checks that the key is valid for the item. returns "True" we need to augment the path by accessing the "data" element"""

    # special case for not fully loaded external item which we still need to populate
    if isinstance(item, typing.Mapping):
        if (_FORMAT_KEY in item) and item[_FORMAT_KEY].endswith(":external"):
            if "data" in item:
                return True
            else:
                raise KeyError(
                    f"expected an external item, but could not find data: {list(item.keys())}",
                    f"{item[_FORMAT_KEY]}, {len(item) = }, {item.get('data', '<EMPTY>') = }",
                )

    # if it's a list, make sure the key is an int and that it's in range
    if isinstance(item, typing.Sequence):
        if not isinstance(key, int):
            raise TypeError(f"improper type: '{type(key) = }', expected int")
        if key >= len(item):
            raise IndexError(f"index out of range: '{key = }', expected < {len(item)}")

    # if it's a dict, make sure that the key is a str and that it's in the dict
    elif isinstance(item, typing.Mapping):
        if not isinstance(key, str):
            raise TypeError(f"improper type: '{type(key) = }', expected str")
        if key not in item:
            raise KeyError(f"key not in dict: '{key = }', expected in {item.keys()}")

    # otherwise, raise an error
    else:
        raise TypeError(f"improper type: '{type(item) = }', expected dict or list")

    return False


@dataclass
class LoaderHandler:
    """handler for loading an object from a json file or a ZANJ archive

    `check` decides whether the handler applies to a json item and `load` turns the
    json item into the loaded object. both are called as `(json_item, path, zanj)`
    """

    # TODO: add a separate "asserts" function?
    # right now, any asserts must happen in `check` or `load` which is annoying with lambdas

    # (json_data, path, zanj) -> whether to use this handler
    check: Callable[[JSONitem, ObjectPath, _ZANJ_pre], bool]
    # function to load the object (json_data, path, zanj) -> loaded_obj
    load: Callable[[JSONitem, ObjectPath, _ZANJ_pre], Any]
    # unique identifier for the handler, saved in __muutils_format__ field
    uid: str
    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str
    # priority of the handler, defaults are all 0
    priority: int = 0
    # description of the handler
    desc: str = "(no description)"

    def serialize(self) -> JSONdict:
        """serialize the handler info

        the result holds the source code and docstring of `check` and `load`,
        together with `uid`, `source_pckg`, `priority` and `desc`
        """
        return {
            # get the code and doc of the check function
            "check": {
                "code": safe_getsource(self.check),
                "doc": string_as_lines(self.check.__doc__),
            },
            # get the code and doc of the load function
            "load": {
                "code": safe_getsource(self.load),
                "doc": string_as_lines(self.load.__doc__),
            },
            # get the uid, source_pckg, priority, and desc
            "uid": str(self.uid),
            "source_pckg": str(self.source_pckg),
            "priority": int(self.priority),
            "desc": str(self.desc),
        }

    @classmethod
    def from_formattedclass(cls, fc: type, priority: int = 0):
        """create a loader from a class with `serialize`, `load` methods and `__muutils_format__` attribute

        the resulting `check` accepts exactly those mappings whose format key equals
        `fc.__muutils_format__`, and returns `False` for everything else.
        the resulting `load` forwards to `fc.load(json_item, path, z)`
        """
        assert hasattr(fc, "serialize")
        assert callable(fc.serialize)  # type: ignore
        assert hasattr(fc, "load")
        assert callable(fc.load)  # type: ignore
        assert hasattr(fc, _FORMAT_KEY)
        assert isinstance(fc.__muutils_format__, str)  # type: ignore

        return cls(
            # `check` gets called on arbitrary json items (plain dicts, lists, scalars),
            # so it has to guard the key lookup instead of raising KeyError/TypeError
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY] == fc.__muutils_format__  # type: ignore[attr-defined]
            ),
            load=lambda json_item, path=None, z=None: fc.load(json_item, path, z),  # type: ignore[misc]
            uid=fc.__muutils_format__,  # type: ignore[attr-defined]
            source_pckg=str(fc.__module__),
            priority=priority,
            desc=f"formatted class loader for {fc.__name__}",
        )


# TODO: how can we type hint this without actually importing torch?
def _torch_loaderhandler_load(
    json_item: JSONitem,
    path: ObjectPath,
    z: _ZANJ_pre | None = None,
):
    """load a torch tensor from a json item

    the array is read with `load_array` and converted using the dtype named in `json_item["dtype"]`.
    raises `ImportError` mentioning `path` and `json_item` if torch (or the dtype map) cannot be imported
    """
    # imported here rather than at module level, so a failing import only surfaces on use
    try:
        import torch  # type: ignore[import-not-found]
        from muutils.tensor_utils import TORCH_DTYPE_MAP
    except ImportError as import_err:
        raise ImportError(
            f"could not import torch, which we need to load the object at {path = }: {json_item = }"
        ) from import_err

    array_data = load_array(json_item)
    tensor_dtype = TORCH_DTYPE_MAP[json_item["dtype"]]  # type: ignore[index, call-overload]
    return torch.tensor(array_data, dtype=tensor_dtype)


# NOTE: there are type ignores on the loaders, since the type checking should be the responsibility of the check function

# held by `register_loader_handler` while it writes to `LOADER_MAP`
LOADER_MAP_LOCK = threading.Lock()

# the registered loader handlers, keyed by their `uid`.
# each default `check` requires a mapping whose format key starts with the handler's
# type name; the pandas, list and tuple handlers also require a sequence under "data"
LOADER_MAP: dict[str, LoaderHandler] = {
    lh.uid: lh
    for lh in [
        # array external
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("numpy.ndarray")
                # and json_item["data"].dtype.name == json_item["dtype"]
                # and tuple(json_item["data"].shape) == tuple(json_item["shape"])
            ),
            load=lambda json_item, path=None, z=None: np.array(  # type: ignore[misc]
                load_array(json_item), dtype=np.dtype(json_item["dtype"])
            ),
            uid="numpy.ndarray",
            source_pckg="zanj",
            desc="numpy.ndarray loader",
        ),
        # torch tensor, loading is delegated to `_torch_loaderhandler_load`
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("torch.Tensor")
                # and json_item["data"].dtype.name == json_item["dtype"]
                # and tuple(json_item["data"].shape) == tuple(json_item["shape"])
            ),
            load=_torch_loaderhandler_load,
            uid="torch.Tensor",
            source_pckg="zanj",
            desc="torch.Tensor loader",
        ),
        # pandas
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("pandas.DataFrame")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: pandas_DataFrame(  # type: ignore[misc]
                json_item["data"]
            ),
            uid="pandas.DataFrame",
            source_pckg="zanj",
            desc="pandas.DataFrame loader",
        ),
        # list/tuple external
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("list")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: [  # type: ignore[misc]
                load_item_recursive(x, path, z) for x in json_item["data"]
            ],
            uid="list",
            source_pckg="zanj",
            desc="list loader, for externals",
        ),
        # same as the list loader, but the loaded elements are wrapped in a tuple
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("tuple")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: tuple(  # type: ignore[misc]
                [load_item_recursive(x, path, z) for x in json_item["data"]]
            ),
            uid="tuple",
            source_pckg="zanj",
            desc="tuple loader, for externals",
        ),
    ]
}


def register_loader_handler(handler: LoaderHandler):
    """register a custom loader handler

    stores `handler` in the module-level `LOADER_MAP` under `handler.uid`, replacing any
    handler already stored under that uid. `LOADER_MAP_LOCK` is held during the write
    """
    # the map is only mutated, never rebound, so no `global` declaration is needed
    with LOADER_MAP_LOCK:
        LOADER_MAP[handler.uid] = handler


def get_item_loader(
    json_item: JSONitem,
    path: ObjectPath,
    zanj: _ZANJ_pre | None = None,
    error_mode: ErrorMode = ErrorMode.WARN,
) -> LoaderHandler | None:
    """get the loader for a json item

    a mapping whose format key is exactly the uid of a registered handler gets that handler.
    otherwise the first registered handler whose `check` accepts the item is returned,
    or `None` if no handler accepts it. `error_mode` is not used in this function.

    raises `TypeError` if the item has a format key whose value is not a `str`
    """
    # direct lookup by format key
    if isinstance(json_item, typing.Mapping) and _FORMAT_KEY in json_item:
        format_value = json_item[_FORMAT_KEY]
        if not isinstance(format_value, str):
            raise TypeError(
                f"invalid __muutils_format__ type '{type(json_item[_FORMAT_KEY])}' in '{path=}': '{json_item[_FORMAT_KEY] = }'"
            )
        if format_value in LOADER_MAP:
            return LOADER_MAP[format_value]  # type: ignore[index]

    # no direct hit: ask every handler in registration order, `None` if none matches
    return next(
        (
            candidate
            for candidate in LOADER_MAP.values()
            if candidate.check(json_item, path, zanj)
        ),
        None,
    )


def load_item_recursive(
    json_item: JSONitem,
    path: ObjectPath,
    zanj: _ZANJ_pre | None = None,
    error_mode: ErrorMode = ErrorMode.WARN,
    allow_not_loading: bool = True,
) -> Any:
    """load `json_item`, using a loader handler where one applies and recursing into containers otherwise

    - if `get_item_loader` finds a handler, the result of its `load` is returned
    - otherwise dicts and lists are rebuilt with every element loaded recursively
    - `str`, `int`, `float`, `bool` and `None` are returned unchanged
    - anything else is returned unchanged if `allow_not_loading`, else a `ValueError` is raised

    note that `allow_not_loading` is not passed on to the recursive calls
    """
    handler: LoaderHandler | None = get_item_loader(
        json_item=json_item,
        path=path,
        zanj=zanj,
        error_mode=error_mode,
    )

    # no handler: walk plain containers, pass everything else through
    if handler is None:
        if isinstance(json_item, dict):
            return {
                key: load_item_recursive(
                    json_item=json_item[key],
                    path=tuple(path) + (key,),
                    zanj=zanj,
                    error_mode=error_mode,
                )
                for key in json_item
            }
        if isinstance(json_item, list):
            return [
                load_item_recursive(
                    json_item=element,
                    path=tuple(path) + (idx,),
                    zanj=zanj,
                    error_mode=error_mode,
                )
                for idx, element in enumerate(json_item)
            ]
        if (
            isinstance(json_item, (str, int, float, bool, type(None)))
            or allow_not_loading
        ):
            return json_item
        raise ValueError(f"unknown type {type(json_item)} at {path}\n{json_item}")

    is_serializable_dataclass: bool = (
        isinstance(json_item, typing.Mapping)
        and (_FORMAT_KEY in json_item)
        and ("SerializableDataclass" in json_item[_FORMAT_KEY])  # type: ignore[operator]
    )
    if not is_serializable_dataclass:
        return handler.load(json_item, path, zanj)

    # special case for serializable dataclasses
    # why this horribleness?
    # SerializableDataclass, if it has a field `x` which is also a SerializableDataclass, will automatically call `x.__class__.load()`
    # However, we need to load things in containers, as well as arrays
    processed_json_item: dict = {}
    for key, val in json_item.items():  # type: ignore[union-attr]
        val_is_nested_dataclass: bool = (
            isinstance(val, typing.Mapping)
            and (_FORMAT_KEY in val)
            and ("SerializableDataclass" in val[_FORMAT_KEY])
        )
        if val_is_nested_dataclass:
            # left untouched, the parent's `load` takes care of it
            processed_json_item[key] = val
        else:
            processed_json_item[key] = load_item_recursive(
                json_item=val,
                path=tuple(path) + (key,),
                zanj=zanj,
                error_mode=error_mode,
            )

    return handler.load(processed_json_item, path, zanj)


def _each_item_in_externals(
    externals: dict[str, ExternalItem],
    json_data: JSONitem,
) -> typing.Iterable[tuple[str, ExternalItem, Any, ObjectPath]]:
    """note that you MUST use the raw iterator, dont try to turn into a list or something"""

    sorted_externals: list[tuple[str, ExternalItem]] = sorted(
        externals.items(), key=lambda x: len(x[1].path)
    )

    for ext_path, ext_item in sorted_externals:
        # get the path to the item
        path: ObjectPath = tuple(ext_item.path)
        assert len(path) > 0
        assert all(isinstance(key, (str, int)) for key in path), (
            f"improper types in path {path=}"
        )
        # get the item
        item = json_data
        for i, key in enumerate(path):
            try:
                # ignores in this block are because we cannot know the type is indexable in static analysis
                # but, we check the types in the line below
                external_unloaded: bool = _populate_externals_error_checking(key, item)
                if external_unloaded:
                    item = item["data"]  # type: ignore
                item = item[key]  # type: ignore[index]

            except (KeyError, IndexError, TypeError) as e:
                raise KeyError(
                    f"could not find '{key = }' at path '{ext_path = }', specifically at index '{i = }'",
                    f"'{type(item) =}', '{len(item) = }', '{item.keys() if isinstance(item, dict) else None = }'",  # type: ignore
                    f"From error: {e = }",
                    f"\n\n{item=}\n\n{ext_item=}",
                ) from e

        yield (ext_path, ext_item, item, path)


class LoadedZANJ:
    """for loading a zanj file

    on construction, reads the metadata, the main json data and all external items
    listed in the metadata from the archive at `path`. the external items are kept
    separately until `populate_externals` is called
    """

    def __init__(
        self,
        path: str | Path,
        zanj: _ZANJ_pre,
    ) -> None:
        # path and zanj object
        self._path: str = str(path)
        self._zanj: _ZANJ_pre = zanj

        # load zip file
        # context managers make sure the archive and every member handle get closed,
        # also when reading or parsing fails part way through
        with zipfile.ZipFile(file=self._path, mode="r") as _zipf:
            # load data
            with _zipf.open(ZANJ_META, "r") as fp_meta:
                self._meta: JSONdict = json.load(fp_meta)
            with _zipf.open(ZANJ_MAIN, "r") as fp_main:
                self._json_data: JSONitem = json.load(fp_main)

            # read externals
            self._externals: dict[str, ExternalItem] = dict()
            for fname, ext_item in self._meta["externals_info"].items():  # type: ignore
                item_type: str = ext_item["item_type"]  # type: ignore
                with _zipf.open(fname, "r") as fp:
                    self._externals[fname] = ExternalItem(
                        item_type=item_type,  # type: ignore[arg-type]
                        data=GET_EXTERNAL_LOAD_FUNC(item_type)(self, fp),
                        path=ext_item["path"],  # type: ignore
                    )

    def populate_externals(self) -> None:
        """put all external items into the main json data

        for every external, the json item at its path must hold a reference equal to
        the external's archive path; the external's data is then stored under "data"
        of that json item. this modifies `self._json_data` in place
        """

        # loop over once, populating the externals only
        for ext_path, ext_item, item, path in _each_item_in_externals(
            self._externals, self._json_data
        ):
            # replace the item with the external item
            assert _REF_KEY in item  # type: ignore
            assert item[_REF_KEY] == ext_path  # type: ignore
            item["data"] = ext_item.data  # type: ignore

``````{ end_of_file="zanj/loading.py" }

``````{ path="zanj/py.typed"  }

``````{ end_of_file="zanj/py.typed" }

``````{ path="zanj/serializing.py"  }
from __future__ import annotations

import json
import sys
from dataclasses import dataclass
from typing import IO, Any, Callable, Iterable, Sequence
import warnings

import numpy as np
from muutils.json_serialize.array import arr_metadata
from muutils.json_serialize.json_serialize import (  # JsonSerializer,
    DEFAULT_HANDLERS,
    ObjectPath,
    SerializerHandler,
)
from muutils.json_serialize.util import (
    JSONdict,
    JSONitem,
    MonoTuple,
    _FORMAT_KEY,
    _REF_KEY,
)

from zanj.externals import ExternalItem, ExternalItemType, _ZANJ_pre

# extra keyword arguments for `@dataclass`: `kw_only=True` is only passed
# on python 3.10 or newer, on older versions this stays empty
KW_ONLY_KWARGS: dict = dict()
if sys.version_info >= (3, 10):
    KW_ONLY_KWARGS["kw_only"] = True

# pylint: disable=unused-argument, protected-access, unexpected-keyword-arg
# for some reason pylint complains about kwargs to ZANJSerializerHandler


def jsonl_metadata(data: list[JSONdict]) -> dict:
    """metadata about a jsonl object

    # Parameters:
     - `data: list[JSONdict]`
       the rows, each of which must be a dict. may be empty

    # Returns:
     - `dict` with the keys
       - `"data[0]"`: the first row, or `None` if `data` is empty
       - `"len(data)"`: the number of rows
       - `"columns"`: for every key occurring in any row (except the format key),
         the sorted names of the value types and the number of rows containing the key
    """
    # column names in first-seen order, so that the output is deterministic
    all_cols: dict[str, None] = dict.fromkeys(
        col for item in data for col in item.keys()
    )
    return {
        # guard against empty data, where `data[0]` would raise an IndexError
        "data[0]": data[0] if len(data) > 0 else None,
        "len(data)": len(data),
        "columns": {
            col: {
                # sorted, since the iteration order of a set of strings is not stable
                "types": sorted(
                    set(type(item[col]).__name__ for item in data if col in item)
                ),
                "len": sum(1 for item in data if col in item),
            }
            for col in all_cols
            if col != _FORMAT_KEY
        },
    }


def store_npy(self: _ZANJ_pre, fp: IO[bytes], data: np.ndarray) -> None:
    """write `data` to the open binary file `fp` in `.npy` format, with pickling disabled

    `data` is passed through `np.asanyarray` first; `self` is not used
    """
    as_array: np.ndarray = np.asanyarray(data)
    # TODO: Type `<module 'numpy.lib'>` has no attribute `format` --> zanj/serializing.py:54:5
    # info: rule `unresolved-attribute` is enabled by default
    np.lib.format.write_array(  # ty: ignore[unresolved-attribute]
        fp,
        as_array,
        allow_pickle=False,
    )


def store_jsonl(self: _ZANJ_pre, fp: IO[bytes], data: Sequence[JSONitem]) -> None:
    """write every element of `data` to `fp` as one utf-8 encoded json document per line

    `self` is not used
    """
    line_end: bytes = "\n".encode("utf-8")
    for entry in data:
        encoded_entry: bytes = json.dumps(entry).encode("utf-8")
        fp.write(encoded_entry)
        fp.write(line_end)


# maps each external item type to the function which writes it to an open binary file;
# every function is called as `(zanj, fp, data)`
EXTERNAL_STORE_FUNCS: dict[
    ExternalItemType, Callable[[_ZANJ_pre, IO[bytes], Any], None]
] = {
    "npy": store_npy,
    "jsonl": store_jsonl,
}


@dataclass(**KW_ONLY_KWARGS)
class ZANJSerializerHandler(SerializerHandler):
    """a handler for ZANJ serialization

    extends `SerializerHandler` with a `source_pckg` field, and declares `check` and
    `serialize_func` as callables taking `(zanj, object, path)`
    """

    # NOTE(review): `uid` and `desc` are commented out below but are passed when constructing
    # handlers in this file -- presumably they are inherited from `SerializerHandler`, confirm

    # unique identifier for the handler, saved in _FORMAT_KEY field
    # uid: str
    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str
    # (self_config, object, path) -> whether to use this handler
    check: Callable[[_ZANJ_pre, Any, ObjectPath], bool]
    # (self_config, object, path) -> serialized object
    serialize_func: Callable[[_ZANJ_pre, Any, ObjectPath], JSONitem]
    # optional description of how this serializer works
    # desc: str = "(no description)"


def zanj_external_serialize(
    jser: _ZANJ_pre,
    data: Any,
    path: ObjectPath,
    item_type: ExternalItemType,
    _format: str,
) -> JSONitem:
    """stores a numpy array or jsonl externally in a ZANJ object

    the item is registered under the archive path `"/".join(path) + "." + item_type`

    # Parameters:
     - `jser: ZANJ`
     - `data: Any`
       for `"npy"`: a numpy array or a torch tensor (which is detached, moved to cpu and converted).
       for `"jsonl"`: a pandas DataFrame (stored as records) or an iterable,
       whose elements are serialized with `jser.json_serialize`
     - `path: ObjectPath`
       must be a tuple
     - `item_type: ExternalItemType`
     - `_format: str`
       value stored under the format key of the returned reference

    # Returns:
     - `JSONitem`
       json data with reference: the format, the archive path, and metadata about the data.
       for jsonl, metadata is only added if every stored element is a dict

    # Raises:
     - `ValueError` (after emitting a warning with the same message) if the archive path
       is already in use, or if one path is a "/"-separated ancestor of the other
     - `TypeError` if `data` has a type which is not supported for `item_type`

    # Modifies:
     - modifies `jser._externals`
    """
    # get the path, make sure its unique
    assert isinstance(path, tuple), (
        f"path must be a tuple, got {type(path) = } {path = }"
    )
    joined_path: str = "/".join([str(p) for p in path])
    archive_path: str = f"{joined_path}.{item_type}"

    # TODO: somehow need to control whether a failure here causes a fallback to other handlers, or whether the except should propagate
    # this will probably require changes to the upstream muutils.json_serialize code
    if archive_path in jser._externals:
        err_msg = f"external path {archive_path} already exists!"
        warnings.warn(err_msg)
        raise ValueError(err_msg)
    # Check for true path prefix conflicts (not just string prefix)
    # Only flag when one path is a directory ancestor of another (contains "/" separator)
    for p in jser._externals.keys():
        # Remove the file extension to get the joined_path
        existing_joined_path = p.rsplit(".", 1)[0]
        # Check if one is a true path prefix with "/" separator
        if existing_joined_path.startswith(joined_path + "/") or joined_path.startswith(
            existing_joined_path + "/"
        ):
            err_msg = (
                f"external path {joined_path} is a prefix of another path {p}!\n"
                + f"{jser._externals.keys() = }\n{joined_path = }\n{path = }\n{p = }\n{existing_joined_path = }\n{archive_path = }\n{_format = }"
            )
            warnings.warn(err_msg)
            raise ValueError(err_msg)

    # process the data if needed, assemble metadata
    data_new: Any = data
    output: dict = {
        _FORMAT_KEY: _format,
        _REF_KEY: archive_path,
    }
    if item_type == "npy":
        # check type
        # compared by name, so that torch does not need to be imported
        data_type_str: str = str(type(data))
        if data_type_str == "<class 'torch.Tensor'>":
            # detach and convert
            data_new = data.detach().cpu().numpy()
        elif data_type_str == "<class 'numpy.ndarray'>":
            pass
        else:
            # if not a numpy array, except
            raise TypeError(f"expected numpy.ndarray, got {data_type_str}")
        # get metadata
        # note: computed from the original `data`, not from the converted `data_new`
        output.update(arr_metadata(data))
    elif item_type.startswith("jsonl"):
        # check via mro to avoid importing pandas
        if any("pandas.core.frame.DataFrame" in str(t) for t in data.__class__.__mro__):
            output["columns"] = data.columns.tolist()
            data_new = data.to_dict(orient="records")
        elif isinstance(data, (list, tuple, Iterable, Sequence)):
            # each element is serialized with its index appended to the path
            data_new = [
                jser.json_serialize(item, tuple(path) + (i,))
                for i, item in enumerate(data)
            ]
        else:
            raise TypeError(
                f"expected list or pandas.DataFrame for jsonl, got {type(data)}"
            )

        # jsonl metadata is only computed when every row is a dict
        if all([isinstance(item, dict) for item in data_new]):
            output.update(jsonl_metadata(data_new))

    # store the item for external serialization
    jser._externals[archive_path] = ExternalItem(
        item_type=item_type,
        data=data_new,
        path=path,
    )

    return output


# the serializer handlers: first the ZANJ handlers, which store arrays, tensors, lists,
# tuples and dataframes externally once their size reaches the configured threshold,
# then the default muutils handlers
DEFAULT_SERIALIZER_HANDLERS_ZANJ: MonoTuple[ZANJSerializerHandler] = tuple(
    [
        # numpy arrays with at least `external_array_threshold` elements
        ZANJSerializerHandler(
            check=lambda self, obj, path: (
                isinstance(obj, np.ndarray)
                and obj.size >= self.external_array_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="npy", _format="numpy.ndarray:external"
            ),
            uid="numpy.ndarray:external",
            source_pckg="zanj",
            desc="external numpy array",
        ),
        # torch tensors, detected by type name so that torch does not need to be imported
        ZANJSerializerHandler(
            check=lambda self, obj, path: (
                str(type(obj)) == "<class 'torch.Tensor'>"
                and int(obj.nelement()) >= self.external_array_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="npy", _format="torch.Tensor:external"
            ),
            uid="torch.Tensor:external",
            source_pckg="zanj",
            desc="external torch tensor",
        ),
        # lists with at least `external_list_threshold` elements
        ZANJSerializerHandler(
            check=lambda self, obj, path: isinstance(obj, list)
            and len(obj) >= self.external_list_threshold,
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="jsonl", _format="list:external"
            ),
            uid="list:external",
            source_pckg="zanj",
            desc="external list",
        ),
        # tuples with at least `external_list_threshold` elements
        ZANJSerializerHandler(
            check=lambda self, obj, path: isinstance(obj, tuple)
            and len(obj) >= self.external_list_threshold,
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="jsonl", _format="tuple:external"
            ),
            uid="tuple:external",
            source_pckg="zanj",
            desc="external tuple",
        ),
        # pandas dataframes, detected via the mro so that pandas does not need to be imported
        ZANJSerializerHandler(
            check=lambda self, obj, path: (
                any(
                    "pandas.core.frame.DataFrame" in str(t)
                    for t in obj.__class__.__mro__
                )
                and len(obj) >= self.external_list_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="jsonl", _format="pandas.DataFrame:external"
            ),
            uid="pandas.DataFrame:external",
            source_pckg="zanj",
            desc="external pandas DataFrame",
        ),
        # ZANJSerializerHandler(
        #     check=lambda self, obj, path: "<class 'torch.nn.modules.module.Module'>"
        #     in [str(t) for t in obj.__class__.__mro__],
        #     serialize_func=lambda self, obj, path: zanj_serialize_torchmodule(
        #         self, obj, path,
        #     ),
        #     uid="torch.nn.Module",
        #     source_pckg="zanj",
        #     desc="fallback torch serialization",
        # ),
    ]
) + tuple(
    DEFAULT_HANDLERS  # type: ignore[arg-type]
)

# the complaint above is:
# error: Argument 1 to "tuple" has incompatible type "Sequence[SerializerHandler]"; expected "Iterable[ZANJSerializerHandler]"  [arg-type]

``````{ end_of_file="zanj/serializing.py" }

``````{ path="zanj/torchutil.py"  }
"""torch utilities for zanj -- in particular the `ConfiguredModel` base class

note that this requires torch
"""

from __future__ import annotations

import abc
import typing
import warnings
from typing import Any, Type, TypeVar

try:
    import torch  # type: ignore[import-not-found]
except ImportError as e:
    raise ImportError(
        "torch is required for zanj.torchutil, please install it with `pip install torch` or `pip install zanj[torch]`"
    ) from e

from muutils.json_serialize import SerializableDataclass
from muutils.json_serialize.json_serialize import ObjectPath
from muutils.json_serialize.util import safe_getsource, string_as_lines, _FORMAT_KEY

from zanj import ZANJ, register_loader_handler
from zanj.loading import LoaderHandler, load_item_recursive

# pylint: disable=protected-access

KWArgs = Any


def num_params(m: torch.nn.Module, only_trainable: bool = True):
    """return total number of parameters in a model

    - parameters with the same `data_ptr()` are only counted once
    - if `only_trainable` is True (default), parameters with `requires_grad = False` are left out

    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
    """
    # keyed by data pointer, so that a later parameter with the same pointer replaces the earlier one
    params_by_ptr: dict[int, torch.nn.Parameter] = {}
    for param in m.parameters():
        if only_trainable and not param.requires_grad:
            continue
        params_by_ptr[param.data_ptr()] = param

    return sum(param.numel() for param in params_by_ptr.values())


def get_module_device(
    m: torch.nn.Module,
) -> tuple[bool, torch.device | dict[str, torch.device]]:
    """get the device(s) of the parameters of a module

    returns `(True, device)` if all parameters are on the same device.
    otherwise returns `(False, devices)`, where `devices` maps every parameter name
    to its device -- this is also what a module without parameters gets (empty dict)
    """
    devices_by_name: dict[str, torch.device] = {}
    for param_name, param in m.named_parameters():
        devices_by_name[param_name] = param.device

    # no parameters: there is no single device to report
    if not devices_by_name:
        return False, devices_by_name

    # compare everything against the device of the first parameter
    first_device: torch.device = next(iter(devices_by_name.values()))
    for device in devices_by_name.values():
        if device != first_device:
            return False, devices_by_name

    return True, first_device


T_config = TypeVar("T_config", bound=SerializableDataclass)


class ConfiguredModel(
    torch.nn.Module,
    typing.Generic[T_config],
    metaclass=abc.ABCMeta,
):
    """a model that has a configuration, for saving with ZANJ

    ```python
    @set_config_class(YourConfig)
    class YourModule(ConfiguredModel[YourConfig]):
        def __init__(self, cfg: YourConfig):
            super().__init__(cfg)
    ```

    `__init__()` must initialize the model from a config object only, and call
    `super().__init__(zanj_model_config)`

    If you are inheriting from another class + ConfiguredModel,
    ConfiguredModel must be the first class in the inheritance list
    """

    # dont set this directly, use `set_config_class()` decorator
    _config_class: type | None = None
    # read-only view of the config class registered on the concrete model class
    zanj_config_class = property(lambda self: type(self)._config_class)

    def __init__(self, zanj_model_config: T_config, **kwargs):
        """store the config on the model after checking its type

        # Parameters:
        - `zanj_model_config`: instance of the class registered via `set_config_class()`
        - `**kwargs`: forwarded to the next `__init__` in the MRO

        # Raises:
        - `NotImplementedError`: no config class was registered for this model class
        - `TypeError`: `zanj_model_config` is not an instance of the registered config class
        """
        super().__init__(**kwargs)
        if self.zanj_config_class is None:
            raise NotImplementedError("you need to set `config_class` for your model")
        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
            raise TypeError(
                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
            )

        self.zanj_model_config: T_config = zanj_model_config
        # optional free-form records, saved and restored together with the model
        self.training_records: dict | None = None

    def serialize(
        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
    ) -> dict[str, Any]:
        """build the dict that represents this model

        the dict holds the serialized config, metadata about the class
        (name, docstring, source, module, MRO, parameter count, `str(self)`),
        the training records, the state dict, and a `__muutils_format__` entry
        set to the class name.

        `path` and `zanj` are accepted but do not affect the returned dict.
        """
        if zanj is None:
            zanj = ZANJ()
        obj = dict(
            zanj_model_config=self.zanj_model_config.serialize(),
            meta=dict(
                class_name=self.__class__.__name__,
                class_doc=string_as_lines(self.__class__.__doc__),
                class_source=safe_getsource(self.__class__),
                module_name=self.__class__.__module__,
                module_mro=[str(x) for x in self.__class__.__mro__],
                # module-level `num_params`, default settings (trainable only)
                num_params=num_params(self),
                as_str=string_as_lines(str(self)),
            ),
            training_records=self.training_records,
            state_dict=self.state_dict(),
            __muutils_format__=self.__class__.__name__,
        )
        return obj

    def save(self, file_path: str, zanj: ZANJ | None = None):
        """save the serialized model to `file_path`, using `zanj` or a default `ZANJ()`"""
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    def _load_state_dict_wrapper(
        self,
        state_dict: dict[str, torch.Tensor],
        **kwargs,
    ):
        """wrapper for `load_state_dict()` in case you need to override it

        this base implementation accepts no extra kwargs and asserts that none were given
        """
        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
        return self.load_state_dict(state_dict)

    @classmethod
    def load(
        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
    ) -> "ConfiguredModel":
        """load a model from a serialized object

        # Parameters:
        - `obj`: dict with the keys `zanj_model_config` and `state_dict`,
          and optionally `training_records`
        - `path`: path of `obj` inside the enclosing object, extended for the nested items
        - `zanj`: `ZANJ` instance used for loading nested items; a default one is created if `None`

        extra keyword arguments for `_load_state_dict_wrapper()` are read from
        `zanj.custom_settings["_load_state_dict_wrapper"]` if present.
        """

        if zanj is None:
            zanj = ZANJ()

        # get the config
        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore

        # get the training records
        training_records: typing.Any = load_item_recursive(
            obj.get("training_records", None),
            tuple(path) + ("training_records",),
            zanj,
        )

        # initialize the model
        model: "ConfiguredModel" = cls(zanj_model_config)

        # load the state dict
        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
            obj["state_dict"],
            tuple(path) + ("state_dict",),
            zanj,
        )

        model._load_state_dict_wrapper(
            tensored_state_dict,
            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
        )

        # set the training records
        model.training_records = training_records

        return model

    @classmethod
    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file

        asserts that the object loaded from the file is an instance of `cls`
        """
        if zanj is None:
            zanj = ZANJ()

        mdl: ConfiguredModel = zanj.read(file_path)
        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
        return mdl

    @classmethod
    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """deprecated alias of `read()`"""
        warnings.warn(
            "load_file() is deprecated, use read() instead",
            DeprecationWarning,
            # attribute the warning to the caller of `load_file()`, not to this line
            stacklevel=2,
        )
        return cls.read(file_path, zanj)

    @classmethod
    def get_handler(cls) -> LoaderHandler:
        """create the `LoaderHandler` for this model class

        the handler accepts dicts whose `_FORMAT_KEY` entry starts with the
        class name, and loads them through `cls.load()`
        """
        cls_name: str = str(cls.__name__)
        return LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore
                isinstance(json_item, dict)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith(cls_name)
            ),
            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
            uid=cls_name,
            source_pckg=cls.__module__,
            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
        )

    def num_params(self, only_trainable: bool = True) -> int:
        """number of parameters of this model, see the module-level `num_params()`

        `only_trainable=False` also counts parameters with `requires_grad = False`
        """
        # inside a method the bare name refers to the module-level function
        return num_params(self, only_trainable=only_trainable)


def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator tying a `ConfiguredModel` subclass to its config class

    the decorated class gets `config_class` as its `_config_class`, and its
    loader handler (from `get_handler()`) is registered via `register_loader_handler()`

    raises `TypeError` if `config_class` is not a subclass of `SerializableDataclass`
    """
    is_valid_config: bool = issubclass(config_class, SerializableDataclass)
    if not is_valid_config:
        raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")

    def wrapper(cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
        # attach the config class to the model class
        cls._config_class = config_class

        # make the model class loadable
        handler = cls.get_handler()
        register_loader_handler(handler)

        # the same class object is handed back, not a copy
        return cls

    return wrapper


class ConfigMismatchException(ValueError):
    """raised when two configs don't match, the difference is kept in `diff`"""

    def __init__(self, msg: str, diff):
        self.diff = diff
        super().__init__(msg)

    def __str__(self):
        # the plain message, followed by the stored diff
        base_message = super().__str__()
        return "{}: {}".format(base_message, self.diff)


def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check that both arguments are configured models with equal configs

    Raises:
        AssertionError: if a model is not a `ConfiguredModel`, or its
            `zanj_model_config` is not a `SerializableDataclass`
        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
    """
    assert isinstance(model_a, ConfiguredModel), "model_a must be a ConfiguredModel"
    cfg_a = model_a.zanj_model_config
    assert isinstance(cfg_a, SerializableDataclass), (
        "model_a must have a zanj_model_config"
    )
    assert isinstance(model_b, ConfiguredModel), "model_b must be a ConfiguredModel"
    cfg_b = model_b.zanj_model_config
    assert isinstance(cfg_b, SerializableDataclass), (
        "model_b must have a zanj_model_config"
    )

    # matching configs: nothing to report
    if cfg_a == cfg_b:
        return

    # the diff is computed by the config class of the first model
    cfg_type: type = type(cfg_a)
    raise ConfigMismatchException(
        f"configs of type {cfg_type}, {type(cfg_b)} don't match",
        diff=cfg_type.diff(cfg_a, cfg_b),  # type: ignore[attr-defined]
    )


def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    the configs are compared first (see `assert_model_cfg_equality`), then the
    state dict keys, then every state dict entry element-wise. a line
    `passed <key>` or `failed <key>` is printed for each entry.

    Raises:
        ConfigMismatchException: if the configs don't match
        AssertionError: if the state dict keys differ, or if any entry differs
            in shape or in value
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once instead of once per key
    state_a = model_a.state_dict()
    state_b = model_b.state_dict()

    model_a_sd_keys: set[str] = set(state_a.keys())
    model_b_sd_keys: set[str] = set(state_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )
    keys_failed: list[str] = list()
    for k, v_a in state_a.items():
        v_b = state_b[k]
        # `==` broadcasts, so tensors of different shapes could compare as
        # all-equal (or raise on incompatible shapes); check the shape first
        if v_a.shape != v_b.shape or not (v_a == v_b).all():
            keys_failed.append(k)
            print(f"failed {k}")
        else:
            print(f"passed {k}")
    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )

``````{ end_of_file="zanj/torchutil.py" }

``````{ path="zanj/zanj.py"  }
"""
an HDF5/exdir file alternative, which uses json for attributes, allows serialization of arbitrary data

for large arrays, the output is a .tar.gz file with most data in a json file, but with sufficiently large arrays stored in binary .npy files


"ZANJ" is an acronym that the AI tool [Elicit](https://elicit.org) came up with for me. not to be confused with:

- https://en.wikipedia.org/wiki/Zanj
- https://www.plutojournals.com/zanj/

"""

from __future__ import annotations

import json
import os
import time
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Union

import numpy as np
from muutils.errormode import ErrorMode
from muutils.json_serialize.array import ArrayMode, arr_metadata
from muutils.json_serialize.json_serialize import (
    JsonSerializer,
    SerializerHandler,
    json_serialize,
)
from muutils.json_serialize.util import JSONitem, MonoTuple
from muutils.sysinfo import SysInfo

from zanj.externals import ZANJ_MAIN, ZANJ_META, ExternalItem
import zanj.externals
from zanj.loading import LOADER_MAP, LoadedZANJ, load_item_recursive
from zanj.serializing import (
    DEFAULT_SERIALIZER_HANDLERS_ZANJ,
    EXTERNAL_STORE_FUNCS,
    KW_ONLY_KWARGS,
)

# pylint: disable=protected-access, unused-import, dangerous-default-value, line-too-long

ZANJitem = Union[
    JSONitem,
    np.ndarray,
    "pd.DataFrame",  # type: ignore # noqa: F821
]


# default values for the keyword arguments of `ZANJ.__init__`
@dataclass(**KW_ONLY_KWARGS)
class _ZANJ_GLOBAL_DEFAULTS_CLASS:
    # passed to `JsonSerializer.__init__`, and to `load_item_recursive` when reading
    error_mode: ErrorMode = ErrorMode.EXCEPT
    # passed to `JsonSerializer.__init__` as `array_mode`
    internal_array_mode: ArrayMode = "array_list_meta"
    # how big an array / list may be before it is moved out of the main json
    # into an external file -- presumably counted in elements, TODO confirm
    external_array_threshold: int = 256
    external_list_threshold: int = 256
    # `True` -> `zipfile.ZIP_DEFLATED`, `False` -> `zipfile.ZIP_STORED`,
    # an int is used as the `compression` argument of `zipfile.ZipFile` as-is
    compress: bool | int = True
    # `None` is replaced by an empty dict in `ZANJ.__init__`
    custom_settings: dict[str, Any] | None = None


# shared instance, read for the default arguments of `ZANJ.__init__`
ZANJ_GLOBAL_DEFAULTS: _ZANJ_GLOBAL_DEFAULTS_CLASS = _ZANJ_GLOBAL_DEFAULTS_CLASS()


class ZANJ(JsonSerializer):
    """Zip up: Arrays in Numpy, JSON for everything else

    given an arbitrary object, throw into a zip file, with arrays stored in .npy files, and everything else stored in a json file

    (basically npz file with json)

    - numpy (or pytorch) arrays are stored in paths according to their name and structure in the object
    - everything else about the object is stored in the main json file (archive member `ZANJ_MAIN`) in the root of the archive, via `muutils.json_serialize.JsonSerializer`
    - metadata about ZANJ configuration, and optionally packages and versions, is stored in a `__zanj_meta__.json` file (archive member `ZANJ_META`) in the root of the archive

    create a reader/writer via `z = ZANJ()`, then save an object with `z.save(obj, path)` and load it back with `z.read(path)`. be sure to pass an **instance** of the object, to make sure that the attributes of the class can be correctly recognized

    """

    def __init__(
        self,
        error_mode: ErrorMode = ZANJ_GLOBAL_DEFAULTS.error_mode,
        internal_array_mode: ArrayMode = ZANJ_GLOBAL_DEFAULTS.internal_array_mode,
        external_array_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_array_threshold,
        external_list_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_list_threshold,
        compress: bool | int = ZANJ_GLOBAL_DEFAULTS.compress,
        custom_settings: dict[str, Any] | None = ZANJ_GLOBAL_DEFAULTS.custom_settings,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[
            SerializerHandler
        ] = DEFAULT_SERIALIZER_HANDLERS_ZANJ,
    ) -> None:
        """configure a ZANJ reader/writer

        # Parameters:
        - `error_mode`: passed to `JsonSerializer`, also used when reading
        - `internal_array_mode`: passed to `JsonSerializer` as `array_mode`
        - `external_array_threshold`, `external_list_threshold`: stored on the instance
          for use by the serialization handlers
        - `compress`: `True` / `False` select `zipfile.ZIP_DEFLATED` / `zipfile.ZIP_STORED`,
          an int is used as the zip compression constant directly
        - `custom_settings`: arbitrary settings for custom handlers, `None` becomes an empty dict
        - `handlers_pre`, `handlers_default`: serialization handlers passed to `JsonSerializer`
        """
        super().__init__(
            array_mode=internal_array_mode,
            error_mode=error_mode,
            handlers_pre=handlers_pre,
            handlers_default=handlers_default,
        )

        self.external_array_threshold: int = external_array_threshold
        self.external_list_threshold: int = external_list_threshold
        self.custom_settings: dict = (
            custom_settings if custom_settings is not None else dict()
        )

        # process compression to int if bool given
        self.compress = compress
        if isinstance(compress, bool):
            if compress:
                self.compress = zipfile.ZIP_DEFLATED
            else:
                self.compress = zipfile.ZIP_STORED

        # create the externals, leave it empty
        self._externals: dict[str, ExternalItem] = dict()

    def externals_info(self) -> dict[str, dict[str, str | int | list[int]]]:
        """return information about the current externals

        one entry per external item, ordered by the length of the item's path
        """
        output: dict[str, dict] = dict()

        key: str
        item: ExternalItem
        for key, item in self._externals.items():
            data = item.data
            output[key] = {
                "item_type": item.item_type,
                "path": item.path,
                "type(data)": str(type(data)),
                "len(data)": len(data),
            }

            if item.item_type == "ndarray":
                output[key].update(arr_metadata(data))
            elif item.item_type.startswith("jsonl"):
                # keep the first row as a sample of the stored data
                output[key]["data[0]"] = data[0]

        return {
            key: val
            for key, val in sorted(output.items(), key=lambda x: len(x[1]["path"]))
        }

    def meta(self) -> JSONitem:
        """return the metadata of the ZANJ archive"""

        serialization_handlers = {h.uid: h.serialize() for h in self.handlers}
        load_handlers = {h.uid: h.serialize() for h in LOADER_MAP.values()}

        return dict(
            # configuration of this ZANJ instance
            zanj_cfg=dict(
                error_mode=str(self.error_mode),
                array_mode=str(self.array_mode),
                external_array_threshold=self.external_array_threshold,
                external_list_threshold=self.external_list_threshold,
                compress=self.compress,
                serialization_handlers=serialization_handlers,
                load_handlers=load_handlers,
            ),
            # system info (python, pip packages, torch & cuda, platform info, git info)
            sysinfo=json_serialize(SysInfo.get_all(include=("python", "pytorch"))),
            externals_info=self.externals_info(),
            timestamp=time.time(),
        )

    def save(self, obj: Any, file_path: str | Path) -> str:
        """save the object to a ZANJ archive. returns the path to the archive

        `.zanj` is appended to `file_path` if it is missing, and the parent
        directory is created if needed. the archive is closed even if writing fails.
        """

        # adjust extension
        file_path = str(file_path)
        if not file_path.endswith(".zanj"):
            file_path += ".zanj"

        # make directory
        dir_path: str = os.path.dirname(file_path)
        if dir_path != "":
            # `exist_ok=True` avoids failing if the directory shows up
            # between a separate existence check and the creation
            os.makedirs(dir_path, exist_ok=True)

        # clear the externals!
        self._externals = dict()

        # serialize the object -- this will populate self._externals
        # TODO: calling self.json_serialize again here might be slow
        json_data: JSONitem = self.json_serialize(self.json_serialize(obj))

        # open the zip file, the context manager closes it on errors too
        with zipfile.ZipFile(
            file=file_path, mode="w", compression=self.compress
        ) as zipf:
            # store base json data and metadata
            zipf.writestr(
                ZANJ_META,
                json.dumps(
                    self.json_serialize(self.meta()),
                    indent="\t",
                ),
            )
            zipf.writestr(
                ZANJ_MAIN,
                json.dumps(
                    json_data,
                    indent="\t",
                ),
            )

            # store externals
            for key, (ext_type, ext_data, _ext_path) in self._externals.items():
                # why force zip64? numpy.savez does it
                with zipf.open(key, "w", force_zip64=True) as fp:
                    EXTERNAL_STORE_FUNCS[ext_type](self, fp, ext_data)

        # clear the externals, again
        self._externals = dict()

        return file_path

    def read(
        self,
        file_path: Union[str, Path],
    ) -> Any:
        """load the object from a ZANJ archive

        raises `FileNotFoundError` if `file_path` does not exist or is not a file

        # TODO: load only some part of the zanj file by passing an ObjectPath
        """
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"file not found: {file_path}")
        if not file_path.is_file():
            raise FileNotFoundError(f"not a file: {file_path}")

        loaded_zanj: LoadedZANJ = LoadedZANJ(
            path=file_path,
            zanj=self,
        )

        loaded_zanj.populate_externals()

        return load_item_recursive(
            loaded_zanj._json_data,
            path=tuple(),
            zanj=self,
            error_mode=self.error_mode,
            # lh_map=loader_handlers,
        )


zanj.externals._ZANJ_pre = ZANJ  # type: ignore

``````{ end_of_file="zanj/zanj.py" }

``````{ path="README.md"  }
[![PyPI](https://img.shields.io/pypi/v/zanj)](https://pypi.org/project/zanj/)
[![Checks](https://github.com/mivanit/zanj/actions/workflows/checks.yml/badge.svg)](https://github.com/mivanit/zanj/actions/workflows/checks.yml)
[![Coverage](docs/coverage/coverage.svg)](docs/coverage/coverage.txt)
![code size, bytes](https://img.shields.io/github/languages/code-size/mivanit/zanj)
![PyPI - Downloads](https://img.shields.io/pypi/dm/zanj)
[![DOI](https://zenodo.org/badge/618623453.svg)](https://doi.org/10.5281/zenodo.15540392)


<!-- ![GitHub commit activity](https://img.shields.io/github/commit-activity/t/mivanit/zanj)
![GitHub closed pull requests](https://img.shields.io/github/issues-pr-closed/mivanit/zanj) -->
<!-- ![Lines of code](https://img.shields.io/tokei/lines/github.com/mivanit/zanj) -->

# ZANJ

# Overview

The `ZANJ` format is meant to be a way of saving arbitrary objects to disk, in a way that is flexible, allows keeping configuration and data together, and is human readable. It is very loosely inspired by HDF5 and the derived `exdir` format, and the implementation is inspired by `npz` files.

- You can take any `SerializableDataclass` from the [muutils](https://github.com/mivanit/muutils) library and save it to disk -- any large arrays or lists will be stored efficiently as external files in the zip archive, while the basic structure and metadata will be stored in readable JSON files. 
- You can also specify a special `ConfiguredModel`, which inherits from a `torch.nn.Module` which will let you save not just your model weights, but all required configuration information, plus any other metadata (like training logs) in a single file.

This library was originally a module in [muutils](https://github.com/mivanit/muutils/)


# Installation
Available on PyPI as [`zanj`](https://pypi.org/project/zanj/)

```
pip install zanj
```

# Usage

You can find a runnable example of this in [`demo.ipynb`](demo.ipynb)

## Saving a basic object

Any `SerializableDataclass` of basic types can be saved as zanj:

```python
import numpy as np
import pandas as pd
from muutils.json_serialize import SerializableDataclass, serializable_dataclass, serializable_field
from zanj import ZANJ

@serializable_dataclass
class BasicZanj(SerializableDataclass):
    a: str
    q: int = 42
    c: list[int] = serializable_field(default_factory=list)

# initialize a zanj reader/writer
zj = ZANJ()

# create an instance
instance: BasicZanj = BasicZanj("hello", 42, [1, 2, 3])
path: str = "tests/junk_data/path_to_save_instance.zanj"
zj.save(instance, path)
recovered: BasicZanj = zj.read(path)
```

ZANJ will intelligently handle nested serializable dataclasses, numpy arrays, pytorch tensors, and pandas dataframes: 

```python
import torch
import pandas as pd

@serializable_dataclass
class Complicated(SerializableDataclass):
    name: str
    arr1: np.ndarray
    arr2: np.ndarray
    iris_data: pd.DataFrame
    brain_data: pd.DataFrame
    container: list[BasicZanj]
    torch_tensor: torch.Tensor
```

For custom classes, you can specify a `serialization_fn` and `loading_fn` to handle the logic of converting to and from a json-serializable format:

```python
@serializable_dataclass
class Complicated(SerializableDataclass):
    name: str
    device: torch.device = serializable_field(
        serialization_fn=lambda self: str(self.device),
        loading_fn=lambda data: torch.device(data["device"]),
    )
```

Note that `loading_fn` takes the dictionary of the whole class -- this is in case you've stored data in multiple fields of the dict which are needed to reconstruct the object.

## Saving Models

First, define a configuration class for your model. This class will hold the parameters for your model and any associated objects (like losses and optimizers). The configuration class should be a subclass of `SerializableDataclass` and use the `serializable_field` function to define fields that need special serialization.

Here's an example that defines a configuration for a simple neural network:

```python
from zanj.torchutil import ConfiguredModel, set_config_class

@serializable_dataclass
class MyNNConfig(SerializableDataclass):
    input_dim: int
    hidden_dim: int
    output_dim: int

    # store the activation function by name, reconstruct it by looking it up in torch.nn
    act_fn: torch.nn.Module = serializable_field(
        serialization_fn=lambda x: x.__name__,
        loading_fn=lambda x: getattr(torch.nn, x["act_fn"]),
    )

    # same for the loss function
    loss_kwargs: dict = serializable_field(default_factory=dict)
    loss_factory: torch.nn.modules.loss._Loss = serializable_field(
        default_factory=lambda: torch.nn.CrossEntropyLoss,
        serialization_fn=lambda x: x.__name__,
        loading_fn=lambda x: getattr(torch.nn, x["loss_factory"]),
    )
    loss = property(lambda self: self.loss_factory(**self.loss_kwargs))
```

Then, define your model class. It should be a subclass of `ConfiguredModel`, and use the `set_config_class` decorator to associate it with your configuration class. The `__init__` method should take a single argument, which is an instance of your configuration class. You must also call the superclass `__init__` method with the configuration instance.

```python
@set_config_class(MyNNConfig)
class MyNN(ConfiguredModel[MyNNConfig]):
    def __init__(self, config: MyNNConfig):
        # call the superclass init!
        # this will store the config in the zanj_model_config field
        super().__init__(config)

        # whatever you want here
        self.net = torch.nn.Sequential(
            torch.nn.Linear(config.input_dim, config.hidden_dim),
            config.act_fn(),
            torch.nn.Linear(config.hidden_dim, config.output_dim),
        )

    def forward(self, x):
        return self.net(x)
```

You can now create instances of your model, save them to disk, and load them back into memory:

```python
config = MyNNConfig(
    input_dim=10,
    hidden_dim=20,
    output_dim=2,
    act_fn=torch.nn.ReLU,
    loss_kwargs=dict(reduction="mean"),
)

# create your model from the config, and save
model = MyNN(config)
fname = "tests/junk_data/path_to_save_model.zanj"
ZANJ().save(model, fname)
# load by calling the class method `read()`
loaded_model = MyNN.read(fname)
# zanj will actually infer the type of the object in the file 
# -- and will warn you if you don't have the correct package installed
loaded_another_way = ZANJ().read(fname)
```

## Configuration

When initializing a `ZANJ` object, you can specify some configuration info about saving, such as:

- thresholds for how big an array/table has to be before moving to external file
- compression settings
- error modes
- additional handlers for serialization

```python
# how big an array or list (including pandas DataFrame) can be before moving it from the core JSON file
external_array_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_array_threshold
external_list_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_list_threshold
# compression settings passed to `zipfile` package
compress: bool | int = ZANJ_GLOBAL_DEFAULTS.compress
# for doing very cursed things in your own custom loading or serialization functions
custom_settings: dict[str, Any] | None = ZANJ_GLOBAL_DEFAULTS.custom_settings
# specify additional serialization handlers
handlers_pre: MonoTuple[SerializerHandler] = tuple()
handlers_default: MonoTuple[SerializerHandler] = DEFAULT_SERIALIZER_HANDLERS_ZANJ,
```

# Implementation

The on-disk format is a file `<filename>.zanj`, which is a zip file containing:

- `__zanj_meta__.json`: a file containing zanj-specific metadata including:
	- system information
	- installed packages
	- information about external files
- `__zanj__.json`: a file containing user-specified data
	- when an element is too big, it can be moved to an external file
		- `.npy` for numpy arrays or torch tensors
		- `.jsonl` for pandas dataframes or large sequences
	- list of external files stored in `__zanj_meta__.json`
	- "$ref" key, specified in `_REF_KEY` in muutils, will have value pointing to external file
	- `_FORMAT_KEY` key will detail an external format type


# Comparison to other formats



| Format                  | Safe | Zero-copy | Lazy loading | No file size limit | Layout control | Flexibility | Bfloat16 |
| ----------------------- | ---- | --------- | ------------ | ------------------ | -------------- | ----------- | -------- |
| pickle (PyTorch)        | ❌   | ❌        | ❌           | ✅                 | ❌             | ✅          | ✅       |
| H5 (Tensorflow)         | ✅   | ❌        | ✅           | ✅                 | ~              | ~           | ❌       |
| HDF5                    | ✅   | ?         | ✅           | ✅                 | ~              | ✅          | ❌       |
| SavedModel (Tensorflow) | ✅   | ❌        | ❌           | ✅                 | ✅             | ❌          | ✅       |
| MsgPack (flax)          | ✅   | ✅        | ❌           | ✅                 | ❌             | ❌          | ✅       |
| Protobuf (ONNX)         | ✅   | ❌        | ❌           | ❌                 | ❌             | ❌          | ✅       |
| Cap'n'Proto             | ✅   | ✅        | ~            | ✅                 | ✅             | ~           | ❌       |
| Numpy (npy,npz)         | ✅   | ?         | ?            | ❌                 | ✅             | ❌          | ❌       |
| SafeTensors             | ✅   | ✅        | ✅           | ✅                 | ✅             | ❌          | ✅       |
| exdir                   | ✅   | ?         | ?            | ?                  | ?              | ✅          | ❌       |
| ZANJ                    | ✅   | ❌        | ❌*          | ✅                 | ✅             | ✅          | ❌*      |


- Safe: Can I use a file randomly downloaded and expect not to run arbitrary code ?
- Zero-copy: Does reading the file require more memory than the original file ?
- Lazy loading: Can I inspect the file without loading everything ? And loading only some tensors in it without scanning the whole file (distributed setting) ?
- Layout control: Lazy loading, is not necessarily enough since if the information about tensors is spread out in your file, then even if the information is lazily accessible you might have to access most of your file to read the available tensors (incurring many DISK -> RAM copies). Controlling the layout to keep fast access to single tensors is important.
- No file size limit: Is there a limit to the file size ?
- Flexibility: Can I save custom code in the format and be able to use it later with zero extra code ? (~ means we can store more than pure tensors, but no custom code)
- Bfloat16: Does the format support native bfloat16 (meaning no weird workarounds are necessary)? This is becoming increasingly important in the ML world.

`*` denotes this feature may be coming at a future date :)

(This table was stolen from [safetensors](https://github.com/huggingface/safetensors/blob/main/README.md))

``````{ end_of_file="README.md" }

``````{ path="demo.ipynb" processed_with="ipynb_to_md" }
# Installation
Available on PyPI as [`zanj`](https://pypi.org/project/zanj/)

```
pip install zanj
```

```python
import os

import numpy as np
import pandas as pd
import torch

from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from zanj import ZANJ
```


# Usage

## Saving a basic object

Any `SerializableDataclass` of basic types can be saved as zanj:

```python
@serializable_dataclass
class BasicZanj(SerializableDataclass):
    a: str
    q: int = 42
    c: list[int] = serializable_field(default_factory=list)


# initialize a zanj reader/writer
zj = ZANJ()

# create an instance
instance: BasicZanj = BasicZanj("hello", 42, [1, 2, 3])
path: str = "tests/junk_data/path_to_save_instance.zanj"
zj.save(instance, path)
recovered: BasicZanj = zj.read(path)
```

```python
print(f"{type(recovered) = }")  # BasicZanj
print(f"{os.path.getsize(path) = }")
```

ZANJ will intelligently handle nested serializable dataclasses, numpy arrays, pytorch tensors, and pandas dataframes: 

```python
@serializable_dataclass
class Complicated(SerializableDataclass):
    name: str
    arr1: np.ndarray
    arr2: np.ndarray
    iris_data: pd.DataFrame
    brain_data: pd.DataFrame
    container: list[BasicZanj]
    torch_tensor: torch.Tensor
```

For custom classes, you can specify a `serialization_fn` and `loading_fn` to handle the logic of converting to and from a json-serializable format:

```python
@serializable_dataclass
class Complicated2(SerializableDataclass):
    name: str
    device: torch.device = serializable_field(
        serialization_fn=lambda self: str(self.device),
        loading_fn=lambda data: torch.device(data["device"]),
    )
```

Note that `loading_fn` takes the dictionary of the whole class -- this is in case you've stored data in multiple fields of the dict which are needed to reconstruct the object.

## Saving Models

First, define a configuration class for your model. This class will hold the parameters for your model and any associated objects (like losses and optimizers). The configuration class should be a subclass of `SerializableDataclass` and use the `serializable_field` function to define fields that need special serialization.

Here's an example that defines a configuration for a simple neural network:

```python
from zanj.torchutil import ConfiguredModel, set_config_class


@serializable_dataclass
class MyNNConfig(SerializableDataclass):
    input_dim: int
    hidden_dim: int
    output_dim: int

    # store the activation function by name, reconstruct it by looking it up in torch.nn
    act_fn: torch.nn.Module = serializable_field(
        serialization_fn=lambda x: x.__name__,
        loading_fn=lambda x: getattr(torch.nn, x["act_fn"]),
    )

    # same for the loss function
    loss_kwargs: dict = serializable_field(default_factory=dict)
    loss_factory: torch.nn.modules.loss._Loss = serializable_field(
        default_factory=lambda: torch.nn.CrossEntropyLoss,
        serialization_fn=lambda x: x.__name__,
        loading_fn=lambda x: getattr(torch.nn, x["loss_factory"]),
    )
    loss = property(lambda self: self.loss_factory(**self.loss_kwargs))
```

Then, define your model class. It should be a subclass of `ConfiguredModel`, and use the `set_config_class` decorator to associate it with your configuration class. The `__init__` method should take a single argument, which is an instance of your configuration class. You must also call the superclass `__init__` method with the configuration instance.

```python
@set_config_class(MyNNConfig)
class MyNN(ConfiguredModel[MyNNConfig]):
    """Two-layer fully-connected network built from a `MyNNConfig`."""

    def __init__(self, config: MyNNConfig):
        # call the superclass init!
        # this will store the config in the zanj_model_config field
        super().__init__(config)

        # whatever you want here
        # Linear -> activation -> Linear, sized by the config;
        # `config.act_fn` is a class, so it is instantiated here
        self.net = torch.nn.Sequential(
            torch.nn.Linear(config.input_dim, config.hidden_dim),
            config.act_fn(),
            torch.nn.Linear(config.hidden_dim, config.output_dim),
        )

    def forward(self, x):
        # run `x` through the sequential network
        return self.net(x)
```

You can now create instances of your model, save them to disk, and load them back into memory:

```python
# `act_fn` is passed as the class itself (not an instance);
# `loss_factory` is left at its default
config = MyNNConfig(
    input_dim=10,
    hidden_dim=20,
    output_dim=2,
    act_fn=torch.nn.ReLU,
    loss_kwargs=dict(reduction="mean"),
)

# create your model from the config, and save
model = MyNN(config)
fname = "tests/junk_data/path_to_save_model.zanj"
ZANJ().save(model, fname)
# load by calling the class method `read()`
loaded_model = MyNN.read(fname)
# zanj will actually infer the type of the object in the file
# -- and will warn you if you don't have the correct package installed
loaded_another_way = ZANJ().read(fname)
```

```python
# sanity check: the original model and both loaded copies
# must produce the same output for the same input
print(f"{type(loaded_model) = }")
x = torch.randn(config.input_dim)
print(f"{x.shape = }")
out_1 = model(x)
out_2 = loaded_model(x)
out_3 = loaded_another_way(x)

print(f"{out_1 = }, {out_2 = }, {out_3 = }")
assert torch.allclose(out_1, out_2)
assert torch.allclose(out_1, out_3)
```


``````{ end_of_file="demo.ipynb" }

``````{ path="makefile" processed_with="makefile_recipes" }
# first/default target is help
.PHONY: default
default: help
	...

# this recipe is weird. we need it because:
# - a one liner for getting the version with toml is unwieldy, and using regex is fragile
# - using $$SCRIPT_GET_VERSION within $(shell ...) doesn't work because of escaping issues
# - trying to write to the file inside the `gen-version-info` recipe doesn't work, 
# 	shell eval happens before our `python -c ...` gets run and `cat` doesn't see the new file
.PHONY: write-proj-version
write-proj-version:
	...

# gets version info from $(PYPROJECT), last version from $(LAST_VERSION_FILE), and python version
# uses just `python` for everything except getting the python version. no echo here, because this is "private"
.PHONY: gen-version-info
gen-version-info: write-proj-version
	...

# getting commit log since the tag specified in $(LAST_VERSION_FILE)
# will write to $(COMMIT_LOG_FILE)
# when publishing, the contents of $(COMMIT_LOG_FILE) will be used as the tag description (but can be edited during the process)
# no echo here, because this is "private"
.PHONY: gen-commit-log
gen-commit-log: gen-version-info
	...

# force the version info to be read, printing it out
# also force the commit log to be generated, and cat it out
.PHONY: version
version: gen-commit-log
	@echo "Current version is $(PROJ_VERSION), last auto-uploaded version is $(LAST_VERSION)"
	...

.PHONY: setup
setup: dep-check
	@echo "install and update via uv"
	...

.PHONY: dep-check-torch
dep-check-torch:
	@echo "see if torch is installed, and which CUDA version and devices it sees"
	...

.PHONY: dep
dep:
	@echo "Exporting dependencies as per $(PYPROJECT) section 'tool.uv-exports.exports'"
	...

.PHONY: dep-check
dep-check:
	@echo "Checking that exported requirements are up to date"
	...

.PHONY: dep-clean
dep-clean:
	@echo "clean up lock files, .venv, and requirements files"
	...

# runs ruff and pycln to format the code
.PHONY: format
format:
	@echo "format the source code"
	...

# runs ruff and pycln to check if the code is formatted correctly
.PHONY: format-check
format-check:
	@echo "check if the source code is formatted correctly"
	...

# runs type checks with mypy
.PHONY: typing
typing: clean
	@echo "running type checks"
	...

# generates a report of the mypy output
.PHONY: typing-report
typing-report:
	@echo "generate a report of the type check output -- errors per file"
	...

.PHONY: test
test: clean
	@echo "running tests"
	...

.PHONY: test-notorch
test-notorch: clean
	@echo "running only tests without torch"
	...

.PHONY: check
check: clean format-check test typing
	@echo "run format checks, tests, and typing checks"
	...

# generates a whole tree of documentation in html format.
# see `$(MAKE_DOCS_SCRIPT_PATH)` and the templates in `$(DOCS_RESOURCES_DIR)/templates/html/` for more info
.PHONY: docs-html
docs-html:
	@echo "generate html docs"
	...

# instead of a whole website, generates a single markdown file with all docs using the templates in `$(DOCS_RESOURCES_DIR)/templates/markdown/`.
# this is useful if you want to have a copy that you can grep/search, but those docs are much messier.
# docs-combined will use pandoc to convert them to other formats.
.PHONY: docs-md
docs-md:
	@echo "generate combined (single-file) docs in markdown"
	...

# after running docs-md, this will convert the combined markdown file to other formats:
# gfm (github-flavored markdown), plain text, and html
# requires pandoc in path, pointed to by $(PANDOC)
# pdf output would be nice but requires other deps
.PHONY: docs-combined
docs-combined: docs-md
	@echo "generate combined (single-file) docs in markdown and convert to other formats"
	...

# generates coverage reports as html and text with `pytest-cov`, and a badge with `coverage-badge`
# if `.coverage` is not found, will run tests first
# also removes the `.gitignore` file that `coverage html` creates, since we count that as part of the docs
.PHONY: cov
cov:
	@echo "generate coverage reports"
	...

# runs the coverage report, then the docs, then the combined docs
.PHONY: docs
docs: cov docs-html docs-combined todo lmcat
	@echo "generate all documentation and coverage reports"
	...

# removes all generated documentation files, but leaves everything in `$(DOCS_RESOURCES_DIR)`
# and leaves things defined in `pyproject.toml:tool.makefile.docs.no_clean`
# (templates, svg, css, make_docs.py script)
# distinct from `make clean`
.PHONY: docs-clean
docs-clean:
	@echo "remove generated docs except resources"
	...

.PHONY: todo
todo:
	@echo "get all TODO's from the code"
	...

.PHONY: lmcat-tree
lmcat-tree:
	@echo "show in console the lmcat tree view"
	...

.PHONY: lmcat
lmcat:
	@echo "write the lmcat full output to pyproject.toml:[tool.lmcat.output]"
	...

# verifies that the current branch is $(PUBLISH_BRANCH) and that git is clean
# used before publishing
.PHONY: verify-git
verify-git: 
	@echo "checking git status"
	...

.PHONY: build
build: 
	@echo "build the package"
	...

# gets the commit log, checks everything, builds, and then publishes with twine
# will ask the user to confirm the new version number (and this allows for editing the tag info)
# will also print the contents of $(PYPI_TOKEN_FILE) to the console for the user to copy and paste in when prompted by twine
.PHONY: publish
publish: gen-commit-log check build verify-git version gen-version-info
	@echo "run all checks, build, and then publish"
	...

# cleans up temp files from formatter, type checking, tests, coverage
# removes all built files
# removes $(TESTS_TEMP_DIR) to remove temporary test files
# recursively removes all `__pycache__` directories and `*.pyc` or `*.pyo` files
# distinct from `make docs-clean`, which only removes generated documentation files
.PHONY: clean
clean:
	@echo "clean up temporary files"
	...

.PHONY: clean-all
clean-all: clean docs-clean dep-clean
	@echo "clean up all temporary files, dep files, venv, and generated docs"
	...

.PHONY: info
info: gen-version-info
	@echo "# makefile variables"
	...

.PHONY: info-long
info-long: info
	@echo "# other variables"
	...

# immediately print out the help targets, and then local variables (but those take a bit longer)
.PHONY: help
help: help-targets info
	@echo -n ""
	...

``````{ end_of_file="makefile" }

``````{ path="pyproject.toml"  }
[project]
    name = "zanj"
    version = "0.5.2"
    description = "save and load complex objects to disk without pickling"
    # SPDX license expression (PEP 639)
    license = "GPL-3.0-only"
    authors = [
        { name = "Michael Ivanitskiy", email = "mivanits@umich.edu" }
    ]
    readme = "README.md"
    requires-python = ">=3.8"
    # no `License ::` classifier here: PEP 639 deprecates license classifiers
    # in favor of the SPDX `license` expression above, and build tools are
    # allowed to reject metadata that declares both
    classifiers = [
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Development Status :: 4 - Beta",
        "Operating System :: OS Independent",
    ]

    dependencies = [
        "muutils>=0.8.0",
        "numpy>=1.24.4",
        "jaxtyping>=0.2.12",
    ]


# optional extras, e.g. `pip install zanj[torch]`
[project.optional-dependencies]
pandas = ["pandas>=1.5.3"]
torch = [
    "torch>=1.13.1",
    # version floor added in response to a Dependabot security alert
    "h11>=0.16.0", # see https://github.com/mivanit/ZANJ/security/dependabot/13
]

[project.urls]
    Homepage = "https://miv.name/zanj"
    Repository = "https://github.com/mivanit/ZANJ"
    Documentation = "https://miv.name/zanj/"
    Issues = "https://github.com/mivanit/ZANJ/issues"



# development-only dependency groups (not part of the published package requirements)
[dependency-groups]
    dev = [
        # typing
        "mypy>=1.0.1",
        "ty",
        # tests & coverage
        "pytest>=8.2.2",
        "pytest-cov>=4.1.0",
        "coverage-badge>=1.1.1",
        # for testing plotting and notebooks
        "ipykernel>=6.23.2",
        # tornado is pulled in by the notebook tooling; version floor is a security pin, see
        # https://github.com/mivanit/ZANJ/security/dependabot/22
        "tornado>=6.5; python_version >= '3.9'",
        "jupyter",
        "matplotlib>=3.0.0",
        "plotly>=5.0.0",
        # security pin, see https://github.com/mivanit/ZANJ/security/dependabot/18
        "setuptools>=78.1.1; python_version >= '3.9'",
        # generating docs
        "pdoc>=14.6.0",
        # security pin, see https://github.com/mivanit/ZANJ/security/dependabot/4
        "jinja2>=3.1.6",
        # lmcat -- a custom library. not exactly docs, but lets an LLM see all the code
        "lmcat>=0.2.0; python_version >= '3.11'",
        # tomli as a backport, since `tomllib` is only in the standard library from python 3.11
        "tomli>=2.1.0; python_version < '3.11'",
        # for uploading
        "twine",
    ]
    lint = [
        # linting and formatting
        "pycln>=2.1.3",
        "ruff>=0.4.8",
    ]

# the package is built with the hatchling backend
[build-system]
    requires = ["hatchling"]
    build-backend = "hatchling.build"


# ruff, pycln and mypy all skip the same two test data directories
[tool.ruff]
    exclude = ["tests/input_data", "tests/junk_data"]

[tool.pycln]
    all = true
    exclude = ["tests/input_data", "tests/junk_data"]

[tool.mypy]
    exclude = ["tests/input_data", "tests/junk_data"]
    # print the error code next to each reported error
    show_error_codes = true


[tool.lmcat]
    # changing this might mean it won't be accessible from the docs
    output = "docs/other/lmcat.txt"
    # glob patterns for paths that should not appear in the lmcat output
    ignore_patterns = [
        "docs/**",
        ".venv/**",
        ".git/**",
        ".meta/**",
        ".ruff_cache/**",
        "uv.lock",
        "LICENSE",
    ]
    # files matching a glob are passed through the named processor instead of being dumped verbatim
    [tool.lmcat.glob_process]
        "[mM]akefile" = "makefile_recipes"
        "*.ipynb" = "ipynb_to_md"
        "*.csv" = "csv_preview_5_lines"

# for configuring this tool (makefile, make_docs.py)
# ============================================================
[tool.makefile]
[tool.makefile.docs]
    output_dir = "docs"
    # entries kept when generated docs are cleaned; ".nojekyll" is for GitHub Pages
    no_clean = [".nojekyll", "temp"]
    markdown_headings_increment = 2
    warnings_ignore = []

    [tool.makefile.docs.notebooks]
        enabled = true
        source_path = "."
        output_path_relative = "notebooks"
        # notebook name -> short description
        [tool.makefile.docs.notebooks.descriptions]
            demo = "Example notebook showing basic usage"


# NOTE(review): the makefile's `dep` target message refers to the section as
# 'tool.uv-exports.exports' -- confirm which table path the export script reads
[tool.makefile.uv-exports]
    args = ["--no-hashes"]
    exports = [
        # no groups, no extras, just the base dependencies
        { name = "base", groups = false, extras = false },
        # all groups
        { name = "groups", groups = true, extras = false },
        # only the lint group -- custom options for this
        { name = "lint", options = ["--only-group", "lint"] },
        # all groups and extras
        { name = "all", filename = "requirements.txt", groups = true, extras = true },
        # all groups and extras, a different way
        { name = "all", groups = true, options = ["--all-extras"] },
    ]

# configures `make todo`
[tool.makefile.inline-todo]
    search_dir = "."
    out_file_base = "docs/other/todo-inline.md"
    context_lines = 2
    extensions = ["py", "md"]
    tags = ["CRIT", "TODO", "FIXME", "HACK", "BUG", "DOC"]
    exclude = ["docs/**", ".venv/**"]
    # tag -> label; presumably issue-tracker labels -- confirm against the todo script
    [tool.makefile.inline-todo.tag_label_map]
        BUG = "bug"
        TODO = "enhancement"
        DOC = "documentation"

# ============================================================


``````{ end_of_file="pyproject.toml" }