import jax
import matplotlib.pyplot as plt
# Print the shape of every leaf in a JAX pytree (usage "pt tree").
# pdb executes rc-file statements only once, in the frame of the first stop,
# so a name bound by a top-level `import` may be missing in other frames;
# importing jax inside the alias makes it work wherever it is used.
# Leaves without a `.shape` (e.g. plain Python numbers) show their type name
# instead of raising AttributeError.
alias pt __import__("jax").tree_util.tree_map(lambda x: getattr(x, "shape", type(x).__name__), %1)

# Print instance variables, one "obj.name = value" per line (usage "pi obj").
# The generator expression keeps `k`/`v` in its own scope, so running the
# alias does not rebind a variable named `k` in the frame being debugged.
alias pi print("".join(f"%1.{k} = {v}\n" for k, v in (%1).__dict__.items()), end="")
# Print instance variables of `self` (usage "ps")
alias ps pi self
# Print the `.shape` of an expression (usage "sh x").
# The leading "!" forces pdb to evaluate the line as Python. Without it, a
# one-word argument that is also a pdb command name (n, c, s, a, b, l, q, ...)
# is dispatched to that command instead, e.g. `sh c` runs `c .shape`.
# Parentheses keep the attribute access applied to the whole argument.
alias sh !(%1).shape

# Draw the argument as a grayscale image with pyplot.imshow (usage "pp img");
# "ss" calls pyplot.show() to display pending figures.
# NOTE(review): this alias overrides pdb's built-in `pp` (pretty-print)
# command for the whole session; rename it if pretty-printing is needed.
# pyplot is imported inside the aliases so they do not depend on `plt`
# being bound in the current frame.
alias pp __import__("matplotlib.pyplot").pyplot.imshow(%1, cmap="gray")
alias ss __import__("matplotlib.pyplot").pyplot.show()
