Coverage for gemlib/deterministic/ode_model.py: 86%
43 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-12-02 22:09 +0000
1"""ODE solver with gemlib state transition model interface"""
3from __future__ import annotations
5from collections.abc import Callable
6from typing import NamedTuple
8import jax.numpy as jnp
9import numpy as np
10import tensorflow_probability.substrates.jax as tfp
11from jax import Array
12from jax.typing import ArrayLike
14from gemlib.func_util import maybe_combine_fn
15from gemlib.tensor_util import broadcast_fn_to
16from gemlib.util import batch_gather, transition_coords
# Public API of this module: only the `ode_model` entry point is exported.
__all__ = ["ode_model"]
def _total_flux(
    transition_rates: ArrayLike, state: ArrayLike, incidence_matrix: ArrayLike
) -> Array:
    """Compute the total flux along every transition.

    Each per-unit transition rate is multiplied by the occupancy of that
    transition's source state, so the result accounts for how many units
    are available to leave the source state.

    Args
    ----
    transition_rates: a sequence of `R` arrays (one per transition), each of
        shape `[..., N]` for `N` aggregation units. These are stacked along
        a new trailing axis to give `[..., N, R]`.
    state: a `[..., N, S]` array of `N` aggregation units and `S` states.
    incidence_matrix: a `[S, R]` matrix describing the change in each of the
        `S` states caused by each of the `R` transitions. It is converted
        with `np.array`, so presumably it must be a concrete (non-traced)
        value.

    Returns
    -------
    A `[..., R, N]` array holding the total flux along each transition.
    """
    # Row 0 of each transition coordinate is used as the source-state index
    # (assumes transition_coords yields (source, destination) pairs —
    # TODO confirm).
    coords = transition_coords(np.array(incidence_matrix))
    src_idx = coords[:, 0]

    # Occupancy of each transition's source state. The einsum below requires
    # this to line up with the stacked rates as `[..., N, R]`.
    occupancy = batch_gather(state, indices=src_idx[:, jnp.newaxis])

    stacked_rates = jnp.stack(transition_rates, axis=-1)  # [..., N, R]
    return jnp.einsum("...nr,...nr->...rn", stacked_rates, occupancy)
class ODEResults(NamedTuple):
    """Output of :func:`ode_model`.

    Both fields are taken directly from the results object returned by the
    TensorFlow Probability ODE solver.
    """

    # Times at which the solution was reported by the solver.
    times: Array
    # Solution states corresponding to each entry of `times`.
    states: Array
def ode_model(
    transition_rate_fn: list[Callable[[float, ArrayLike], Array]]
    | Callable[[float, ArrayLike], tuple[Array]],
    incidence_matrix: ArrayLike,
    initial_state: ArrayLike,
    num_steps: int | None = None,
    initial_time: float = 0.0,
    time_delta: float = 1.0,
    times: ArrayLike | None = None,
    solver: str = "DormandPrince",
    solver_kwargs: dict | None = None,
) -> ODEResults:
    """Solve a system of differential equations.

    Args:
        transition_rate_fn: Either a list of callables of the form
            :code:`fn(t: float, state: Tensor) -> Tensor` or a Python callable
            of the form :code:`fn(t: float, state: Tensor) -> tuple(Tensor,...)`
            . In the first (preferred) form, each callable in the list
            corresponds to the respective transition in
            :code:`incidence_matrix`. In the second form, the callable should
            return a :code:`tuple` of transition rate tensors corresponding to
            transitions in :code:`incidence_matrix`.
            **Note**: the second form will be deprecated in future releases of
            :code:`gemlib`.
        incidence_matrix: a :code:`[S, R]` matrix describing the change in
            :code:`S` resulting from transitions :code:`R`.
        initial_state: a :code:`[..., N, S]` (batched) tensor with the state
            values for :code:`N` units and :code:`S` states at
            :code:`initial_time`.
        num_steps: python integer giving the number of solution times,
            :code:`initial_time + k * time_delta` for
            :code:`k = 0, ..., num_steps - 1`. Exactly one of
            :code:`num_steps` or :code:`times` must be supplied.
        initial_time: the time at which :code:`initial_state` applies; also
            the first solution time when :code:`num_steps` is used.
        time_delta: the spacing between solution times when
            :code:`num_steps` is used.
        times: a 1-D tensor of times for which the ODE solutions are
            required. Exactly one of :code:`num_steps` or :code:`times` must
            be supplied.
        solver: a string giving the ODE solver method to use. Can be
            "DormandPrince" (default) or "BDF". See the `TensorFlow
            Probability documentation`_ for details.
        solver_kwargs: a dictionary of keyword arguments passed to the
            solver's constructor. See the solver documentation for details.

    Returns:
        An :code:`ODEResults` holding the solution times and states as
        reported by the solver.

    Raises:
        ValueError: if both or neither of :code:`num_steps` and
            :code:`times` are given, or if :code:`solver` is not one of
            "DormandPrince" or "BDF".

    .. _TensorFlow Probability documentation:
        https://www.tensorflow.org/probability/api_docs/python/tfp/math/ode
    """
    if (num_steps is not None) and (times is not None):
        raise ValueError("Must specify exactly one of `num_steps` or `times`")

    if num_steps is not None:
        # Build the grid as offsets from initial_time. This always gives
        # exactly num_steps points starting at initial_time. An arange whose
        # stop is `time_delta * num_steps` would ignore initial_time, and a
        # floating-point stop can gain an extra point through rounding.
        times = initial_time + time_delta * jnp.arange(num_steps)
    elif times is not None:
        times = jnp.asarray(times)
    else:
        raise ValueError("Must specify either `num_steps` or `times`")

    transition_rate_fn = maybe_combine_fn(transition_rate_fn)

    if solver_kwargs is None:
        solver_kwargs = {}

    if solver == "DormandPrince":
        solver_fn = tfp.math.ode.DormandPrince(**solver_kwargs)
    elif solver == "BDF":
        solver_fn = tfp.math.ode.BDF(**solver_kwargs)
    else:
        raise ValueError("`solver` must be one of 'DormandPrince' or 'BDF'")

    # Accept any array-like initial state; `.shape` is needed below.
    initial_state = jnp.asarray(initial_state)

    # Broadcast the rate function over the leading [..., N] dimensions once,
    # rather than rebuilding the wrapper on every right-hand-side evaluation.
    rate_fn = broadcast_fn_to(transition_rate_fn, initial_state.shape[:-1])

    def derivs(t, state):
        rates = rate_fn(t, state)
        flux = _total_flux(rates, state, incidence_matrix)  # [..., R, N]
        d_state = jnp.linalg.matmul(incidence_matrix, flux)  # [..., S, N]
        # Swap only the trailing two axes back to [..., N, S]; `.T` would
        # reverse every axis and scramble any batch dimensions.
        return jnp.swapaxes(d_state, -1, -2)

    solver_results = solver_fn.solve(
        ode_fn=derivs,
        initial_time=initial_time,
        initial_state=initial_state,
        solution_times=times,
    )

    return ODEResults(
        times=solver_results.times,
        states=solver_results.states,
    )