Welcome to Torch2Jax
JAX
{% for field in pytree_fields %}
{{ field.path }} {{ field.shape }}
{% endfor %}
PyTorch
{% for field in torch_fields %}
{{ field.path }} {{ field.shape }}
{% endfor %}
Convert!
Visualize with Penzai!