LICENSE
MANIFEST.in
README.md
pyproject.toml
setup.py
kfac_jax/__init__.py
kfac_jax/py.typed
kfac_jax.egg-info/PKG-INFO
kfac_jax.egg-info/SOURCES.txt
kfac_jax.egg-info/dependency_links.txt
kfac_jax.egg-info/requires.txt
kfac_jax.egg-info/top_level.txt
kfac_jax/_src/__init__.py
kfac_jax/_src/layers_and_loss_tags.py
kfac_jax/_src/loss_functions.py
kfac_jax/_src/optimizer.py
kfac_jax/_src/patches_second_moment.py
kfac_jax/_src/tag_graph_matcher.py
kfac_jax/_src/tracer.py
kfac_jax/_src/curvature_blocks/__init__.py
kfac_jax/_src/curvature_blocks/curvature_block.py
kfac_jax/_src/curvature_blocks/diagonal.py
kfac_jax/_src/curvature_blocks/full.py
kfac_jax/_src/curvature_blocks/kronecker_factored.py
kfac_jax/_src/curvature_blocks/tnt.py
kfac_jax/_src/curvature_blocks/utils.py
kfac_jax/_src/curvature_estimator/__init__.py
kfac_jax/_src/curvature_estimator/block_diagonal.py
kfac_jax/_src/curvature_estimator/curvature_estimator.py
kfac_jax/_src/curvature_estimator/explicit_exact.py
kfac_jax/_src/curvature_estimator/implicit_exact.py
kfac_jax/_src/curvature_estimator/optax_interface.py
kfac_jax/_src/utils/__init__.py
kfac_jax/_src/utils/accumulators.py
kfac_jax/_src/utils/math.py
kfac_jax/_src/utils/misc.py
kfac_jax/_src/utils/parallel.py
kfac_jax/_src/utils/staging.py
kfac_jax/_src/utils/types.py
tests/test_estimator.py
tests/test_graph_matcher.py
tests/test_optax_interface.py
tests/test_patches_second_moment.py
tests/test_tracer.py
tests/test_utils.py