tensorflow>=2.18.0
tensorflow_probability>=0.25.0
tf_keras>=2.18.0
numpy
pyDOE
matplotlib
jax>=0.7.0
flax>=0.11.1
optax>=0.2.2
