Welcome to Torch2Jax

JAX

{% for field in pytree_fields %}
{{ field.path }} {{ field.shape }}
{% endfor %}

PyTorch

{% for field in torch_fields %}
{{ field.path }} {{ field.shape }}
{% endfor %}