Coverage for gemlib/deterministic/ode_model.py: 86%

43 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-12-02 22:09 +0000

1"""ODE solver with gemlib state transition model interface""" 

2 

3from __future__ import annotations 

4 

5from collections.abc import Callable 

6from typing import NamedTuple 

7 

8import jax.numpy as jnp 

9import numpy as np 

10import tensorflow_probability.substrates.jax as tfp 

11from jax import Array 

12from jax.typing import ArrayLike 

13 

14from gemlib.func_util import maybe_combine_fn 

15from gemlib.tensor_util import broadcast_fn_to 

16from gemlib.util import batch_gather, transition_coords 

17 

18__all__ = ["ode_model"] 

19 

20 

def _total_flux(
    transition_rates: ArrayLike, state: ArrayLike, incidence_matrix: ArrayLike
) -> Array:
    """Compute the total flux along every transition.

    The per-unit rate of each transition is scaled by the amount of its
    source state present in each aggregation unit.

    Args
    ----
      transition_rates: either a `[R, N]` tensor or a length-`R` sequence of
        `[N]` tensors holding per-unit transition rates for `R` transitions
        and `N` aggregation units (leading batch dimensions are allowed).
      state: a `[N, S]` tensor of `N` aggregation units and `S` states
        (leading batch dimensions are allowed).
      incidence_matrix: a `[S, R]` matrix describing the change in `S` for
        each transition `R`.

    Returns
    -------
      A `[R, N]` tensor of total flux along each transition, taking into
      account the availability of units in the source state.
    """
    # Column 0 of the transition coordinates is the source-state index of
    # each transition. `np.array` requires a concrete (non-traced) matrix.
    src_idx = transition_coords(np.array(incidence_matrix))[:, 0]

    # Occupancy of each transition's source state, laid out so that the
    # einsum below can read it as `...nr`.
    src_occupancy = batch_gather(state, indices=src_idx[:, jnp.newaxis])

    # Stack the R per-transition rates along a new trailing axis -> [..., N, R].
    rate_matrix = jnp.stack(transition_rates, axis=-1)

    # Elementwise product, transposed to put transitions first: [..., R, N].
    return jnp.einsum("...nr,...nr->...rn", rate_matrix, src_occupancy)

48 

49 

class ODEResults(NamedTuple):
    """Container for the output of :code:`ode_model`."""

    # Times at which the ODE solution was evaluated, as reported by the solver.
    times: Array
    # Solver-reported state values at each entry of `times`.
    states: Array

53 

54 

def _solution_times(
    num_steps: int | None,
    initial_time: float,
    time_delta: float,
    times: ArrayLike | None,
) -> Array:
    """Resolve the solution time grid from exactly one of `num_steps`/`times`.

    When `num_steps` is given, the grid is `initial_time + k * time_delta`
    for `k = 0, ..., num_steps - 1`; otherwise `times` is used as supplied.

    Raises:
      ValueError: if both, or neither, of `num_steps` and `times` are given.
    """
    if (num_steps is not None) and (times is not None):
        raise ValueError("Must specify exactly one of `num_steps` or `times`")

    if num_steps is not None:
        # Build the grid from integer step indices instead of
        # `arange(start, stop, step)`: the stop value must be offset by
        # `initial_time` (otherwise a non-zero `initial_time` yields too few,
        # or no, steps), and a float-valued `arange` can produce one extra
        # element when `stop / step` rounds upwards.
        return initial_time + jnp.arange(num_steps) * time_delta
    if times is not None:
        return jnp.asarray(times)
    raise ValueError("Must specify either `num_steps` or `times`")


def ode_model(
    transition_rate_fn: list[Callable[[float, ArrayLike], Array]]
    | Callable[[float, ArrayLike], tuple[Array, ...]],
    incidence_matrix: ArrayLike,
    initial_state: ArrayLike,
    num_steps: int | None = None,
    initial_time: float = 0.0,
    time_delta: float = 1.0,
    times: ArrayLike | None = None,
    solver: str = "DormandPrince",
    solver_kwargs: dict | None = None,
) -> ODEResults:
    """Solve a system of differential equations

    Args:
      transition_rate_fn: Either a list of callables of the form
        :code:`fn(t: float, state: Tensor) -> Tensor` or a Python callable
        of the form :code:`fn(t: float, state: Tensor) -> tuple(Tensor,...)`.
        In the first (preferred) form, each callable in the list corresponds
        to the respective transition in :code:`incidence_matrix`. In the
        second form, the callable should return a :code:`tuple` of transition
        rate tensors corresponding to transitions in :code:`incidence_matrix`.
        **Note**: the second form will be deprecated in future releases of
        :code:`gemlib`.
      incidence_matrix: a :code:`[S, R]` matrix describing the change in
        :code:`S` resulting from transitions :code:`R`.
      initial_state: a :code:`[..., N, S]` (batched) tensor with the state
        values for :code:`N` units and :code:`S` states.
      num_steps: python integer giving the number of solution times. The
        solution is returned at :code:`initial_time + k * time_delta` for
        :code:`k = 0, ..., num_steps - 1`. Mutually exclusive with
        :code:`times`.
      initial_time: the time at which :code:`initial_state` applies; also
        the first solution time when :code:`num_steps` is given.
      time_delta: the spacing between solution times when :code:`num_steps`
        is given.
      times: a 1-D tensor of times for which the ODE solutions are required.
        Mutually exclusive with :code:`num_steps`.
      solver: a string giving the ODE solver method to use. Can be
        "DormandPrince" (default) or "BDF". See the `TensorFlow Probability
        documentation`_ for details.
      solver_kwargs: a dictionary of keyword arguments passed to the solver's
        constructor. See the solver documentation for details.

    Returns:
      An :code:`ODEResults` tuple holding the solution :code:`times` and the
      corresponding :code:`states`, as returned by the solver.

    Raises:
      ValueError: if both or neither of :code:`num_steps` and :code:`times`
        are given, or if :code:`solver` is not one of the supported names.

    .. _TensorFlow Probability documentation:
       https://www.tensorflow.org/probability/api_docs/python/tfp/math/ode
    """
    times = _solution_times(num_steps, initial_time, time_delta, times)

    transition_rate_fn = maybe_combine_fn(transition_rate_fn)

    if solver_kwargs is None:
        solver_kwargs = {}

    if solver == "DormandPrince":
        solver_fn = tfp.math.ode.DormandPrince(**solver_kwargs)
    elif solver == "BDF":
        solver_fn = tfp.math.ode.BDF(**solver_kwargs)
    else:
        raise ValueError("`solver` must be one of 'DormandPrince' or 'BDF'")

    # Loop-invariant work hoisted out of `derivs`, which the solver calls
    # repeatedly. `np.shape` also accepts plain sequences, unlike `.shape`.
    rate_fn = broadcast_fn_to(transition_rate_fn, np.shape(initial_state)[:-1])
    incidence = jnp.asarray(incidence_matrix)

    def derivs(t, state):
        rates = rate_fn(t, state)
        flux = _total_flux(rates, state, incidence_matrix)  # [..., R, N]
        dstate_dt = jnp.matmul(incidence, flux)  # [..., S, N]
        # Swap only the two trailing axes: `.T` would reverse *every* axis
        # and scramble the leading batch dimensions of a batched state.
        return jnp.swapaxes(dstate_dt, -1, -2)  # [..., N, S]

    solver_results = solver_fn.solve(
        ode_fn=derivs,
        initial_time=initial_time,
        initial_state=initial_state,
        solution_times=times,
    )

    return ODEResults(
        times=solver_results.times,
        states=solver_results.states,
    )