tensorflow>=2.18.0
tensorflow_probability>=0.25.0
tf_keras>=2.18.0
numpy
pyDOE
matplotlib
jax>=0.4.28
flax>=0.8.4
optax>=0.2.2
