Coverage for tests / tests_simulation / test_runtime.py: 99%

125 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-09 16:40 +0100

1# SPDX-FileCopyrightText: Copyright INRIA 

2# 

3# SPDX-License-Identifier: LGPL-3.0-only 

4# 

5# Copyright INRIA 

6# 

7# This file is part of PhysioBlocks, a library mostly developed by the 

8# [Ananke project-team](https://team.inria.fr/ananke) at INRIA. 

9# 

10# Authors: 

11# - Colin Drieu 

12# - Dominique Chapelle 

13# - François Kimmig 

14# - Philippe Moireau 

15# 

16# PhysioBlocks is free software: you can redistribute it and/or modify it under the 

17# terms of the GNU Lesser General Public License as published by the Free Software 

18# Foundation, version 3 of the License. 

19# 

20# PhysioBlocks is distributed in the hope that it will be useful, but WITHOUT ANY 

21# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 

22# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. 

23# 

24# You should have received a copy of the GNU Lesser General Public License along with 

25# PhysioBlocks. If not, see <https://www.gnu.org/licenses/>. 

26 

27from unittest.mock import Mock, patch 

28 

29import numpy as np 

30import pytest 

31 

32from physioblocks.computing.models import ( 

33 ModelComponent, 

34) 

35from physioblocks.computing.quantities import Quantity 

36from physioblocks.simulation.functions import AbstractFunction 

37from physioblocks.simulation.runtime import ( 

38 AbstractSimulation, 

39 ForwardSimulation, 

40 SimulationError, 

41) 

42from physioblocks.simulation.setup import SimulationFactory 

43from physioblocks.simulation.solvers import AbstractSolver, Solution 

44from physioblocks.simulation.state import State 

45 

46 

def get_solution(converged: bool) -> Solution:
    """Build a fixed three-component solver result for the tests.

    Args:
        converged: convergence flag stored on the returned solution.

    Returns:
        A ``Solution`` whose vector is ``[0.1, 0.2, 0.3]``.
    """
    solution_vector = np.array([0.1, 0.2, 0.3])
    return Solution(solution_vector, converged)

54 

55 

def block_qty_update_func(model: ModelComponent):
    """Stub quantity-update callback: ignores *model* and yields ``0.0``."""
    constant_update = 0.0
    return constant_update

58 

59 

def time_func(self, time: np.float64) -> np.float64:
    """Stand-in ``eval`` for ``AbstractFunction`` taking a ``time`` argument.

    Returns:
        *time* unchanged, i.e. the identity function of time.
    """
    result = time
    return result

62 

63 

def no_param_func(self) -> float:
    """Stand-in ``eval`` for ``AbstractFunction`` that takes no inputs.

    Returns:
        The constant ``0.0``. This is a plain Python ``float``, which is what
        the return annotation now states.
    """
    return 0.0

66 

67 

def state_func(self, state: State) -> np.ndarray:
    """Stand-in ``eval`` for ``AbstractFunction`` taking a ``state`` argument.

    Returns:
        ``state.state_vector`` as-is. This is a vector rather than a scalar; the
        tests patch it with a NumPy array, hence the ``np.ndarray`` annotation.
    """
    return state.state_vector

70 

71 

@pytest.fixture
@patch.multiple(AbstractSolver, __abstractmethods__=set())
@patch.multiple(AbstractSimulation, __abstractmethods__=set())
def simulation() -> AbstractSimulation:
    """Provide a bare simulation built directly from the abstract base classes.

    The abstract-method sets of ``AbstractSolver`` and ``AbstractSimulation``
    are cleared only while this body runs. That lets the factory instantiate
    them; the created objects outlive the patch.
    """
    solver = AbstractSolver()
    factory = SimulationFactory(AbstractSimulation, solver)
    return factory.create_simulation()

78 

79 

class TestSimulation:
    """Tests for the registration and evaluation API of ``AbstractSimulation``.

    Each test makes ``AbstractFunction`` instantiable by clearing its
    ``__abstractmethods__``. It then patches ``eval`` with one of the
    module-level stand-ins (``time_func``, ``no_param_func`` or ``state_func``).
    """

    def test_register_time_update_exceptions(self, simulation: AbstractSimulation):
        """Timed-update registration rejects unknown ids and non-time functions."""
        # An id absent from ``simulation.parameters`` must raise a KeyError.
        with patch.multiple(
            AbstractFunction, __abstractmethods__=set(), eval=time_func
        ):
            unknown_id = "no_param"
            expected_message = f"{unknown_id} not found in parameters"
            with pytest.raises(KeyError, match=expected_message):
                simulation.register_timed_parameter_update(
                    unknown_id, AbstractFunction()
                )

        # A known id paired with an ``eval`` that has no time argument must
        # raise a TypeError.
        with patch.multiple(
            AbstractFunction, __abstractmethods__=set(), eval=no_param_func
        ):
            qty_id = "time_triggered_qty"
            simulation.parameters[qty_id] = Quantity(0.0)

            func_type_name = type(AbstractFunction()).__name__
            expected_message = f"{func_type_name} is not a time function"
            with pytest.raises(TypeError, match=expected_message):
                simulation.register_timed_parameter_update(
                    qty_id, AbstractFunction()
                )

    @patch.multiple(AbstractFunction, __abstractmethods__=set(), eval=time_func)
    def test_update_time(self, simulation: AbstractSimulation):
        """Advancing time refreshes ``new`` of a timed parameter, not ``current``."""
        # Setter order is kept as-is in case the time manager validates
        # these values against each other.
        simulation.time_manager.step_size = 0.1
        simulation.time_manager.start = 0.0
        simulation.time_manager.duration = 0.2

        qty_id = "time_triggered_qty"
        simulation.parameters[qty_id] = Quantity(0.0)
        simulation.register_timed_parameter_update(qty_id, AbstractFunction())

        def tracked() -> Quantity:
            # Re-read the parameter on every access instead of keeping a
            # reference to the originally inserted object.
            return simulation.parameters[qty_id]

        assert tracked().current == pytest.approx(0.0)
        assert tracked().new == pytest.approx(0.0)

        # After one 0.1 step, ``eval`` (identity in time) puts 0.1 into ``new``
        # while ``current`` keeps its previous value.
        simulation.time_manager.update_time()
        simulation._update_time()  # noqa: SLF001
        assert tracked().current == pytest.approx(0.0)
        assert tracked().new == pytest.approx(0.1)

        simulation.unregister_timed_parameter_update(qty_id)

    @patch.multiple(AbstractFunction, __abstractmethods__=set(), eval=no_param_func)
    def test_register_simulation_outputs(self, simulation: AbstractSimulation):
        """Output functions can be registered, then removed again by id."""
        func_id = "no_param_func"
        simulation.register_output_function(func_id, AbstractFunction())

        # Mutating the mapping exposed by ``outputs_functions`` must leave the
        # registration intact; the property presumably hands back a copy.
        simulation.outputs_functions.pop(func_id)
        assert func_id in simulation.outputs_functions

        simulation.unregister_output_function(func_id)
        assert func_id not in simulation.outputs_functions

    @patch.multiple(AbstractFunction, __abstractmethods__=set(), eval=no_param_func)
    def test_register_simulation_exceptions(self, simulation: AbstractSimulation):
        """Output registration rejects ids already in use and non-function objects."""
        output_id = "output"
        clash_message = f"Output {output_id} is already defined."

        # The id is already taken by a registered output function.
        with (
            patch.object(
                simulation, attribute="_output_functions_updates", new={output_id: None}
            ),
            pytest.raises(KeyError, match=clash_message),
        ):
            simulation.register_output_function(output_id, AbstractFunction())

        # The id is already taken by a saved quantity.
        with (
            patch.object(
                simulation.saved_quantities,
                attribute="_saved_quantities",
                new={output_id: None},
            ),
            pytest.raises(KeyError, match=clash_message),
        ):
            simulation.register_output_function(output_id, AbstractFunction())

        # The id is reported as present in the state: every ``in`` test on a
        # ``State`` succeeds while ``__contains__`` is patched.
        with (
            patch.object(State, attribute="__contains__", return_value=True),
            pytest.raises(KeyError, match=clash_message),
        ):
            simulation.register_output_function(output_id, AbstractFunction())

        # An arbitrary object is not accepted as an output function.
        invalid_message = f"{object.__name__} is not a valid output function"
        with pytest.raises(TypeError, match=invalid_message):
            simulation.register_output_function(output_id, object())

    @patch.multiple(AbstractFunction, __abstractmethods__=set(), eval=no_param_func)
    def test_no_parameter_output_function(self, simulation: AbstractSimulation):
        """An argument-less output ``eval`` is reported by its return value."""
        func_id = "no_param_func"
        simulation.register_output_function(func_id, AbstractFunction())

        current = simulation._get_current_result()  # noqa: SLF001
        assert current[func_id] == pytest.approx(0.0)

    @patch.multiple(AbstractFunction, __abstractmethods__=set(), eval=time_func)
    def test_time_parameter_output_function(self, simulation: AbstractSimulation):
        """An output ``eval`` taking ``time`` receives the current simulation time."""
        func_id = "time_func"
        simulation.register_output_function(func_id, AbstractFunction())

        simulation.time_manager.time.initialize(0.001)
        current = simulation._get_current_result()  # noqa: SLF001
        assert current[func_id] == pytest.approx(0.001)

    @patch.multiple(AbstractFunction, __abstractmethods__=set(), eval=state_func)
    def test_state_parameter_output_function(self, simulation: AbstractSimulation):
        """An output ``eval`` taking ``state`` receives the simulation state."""
        expected_vector = np.array([0.1, 0.2])
        with patch.object(
            State, attribute="state_vector", create=True, new=expected_vector
        ):
            func_id = "state_func"
            simulation.register_output_function(func_id, AbstractFunction())

            current = simulation._get_current_result()  # noqa: SLF001
            assert current[func_id] == pytest.approx(expected_vector)

209 

210 

@pytest.fixture
@patch.multiple(AbstractSolver, __abstractmethods__=set())
def forward_simulation() -> ForwardSimulation:
    """Provide a forward simulation with a single three-component state variable.

    Timing: start ``0.0``, duration ``0.010``, step size ``0.001``. The state
    holds one variable ``"x"`` initialised to ``[0.0, 0.0, 0.0]``.
    ``AbstractSolver``'s abstract methods are cleared only while this body
    runs, so the factory can instantiate it.
    """
    factory = SimulationFactory(ForwardSimulation, solver=AbstractSolver())
    forward_sim = factory.create_simulation()

    # Setter order is kept as-is in case the time manager validates
    # these values against each other.
    clock = forward_sim.time_manager
    clock.start = 0.0
    clock.duration = 0.010
    clock.step_size = 0.001

    forward_sim.state.add_variable("x", [0.0, 0.0, 0.0])
    return forward_sim

221 

222 

class TestForwardSimulation:
    """Tests of ``ForwardSimulation.run`` with the solver's ``solve`` mocked."""

    def test_run(self, forward_simulation: ForwardSimulation):
        """With a converged solver, every result reports the solved ``x``."""
        sol = get_solution(True)
        forward_simulation.state["x"].initialize(sol.x)
        solve_mock = Mock(return_value=sol)
        with patch.multiple(
            AbstractSolver,
            __abstractmethods__=set(),
            solve=solve_mock,
        ):
            # Materialise the results while the patch is active. An empty
            # result sequence would otherwise let the loop below pass without
            # checking anything.
            results = list(forward_simulation.run())
            assert results, "run() returned no results"
            solve_mock.assert_called()
            for result in results:
                assert result["x"] == pytest.approx(sol.x)

    def test_run_no_solution(self, forward_simulation: ForwardSimulation):
        """A solver reporting non-convergence makes ``run`` raise SimulationError."""
        sol = get_solution(False)
        forward_simulation.state["x"].initialize(sol.x)
        solve_mock = Mock(return_value=sol)
        with (
            patch.multiple(
                AbstractSolver, __abstractmethods__=set(), solve=solve_mock
            ),
            pytest.raises(SimulationError),
        ):
            forward_simulation.run()
        # The error must stem from the unconverged solve, not from a failure
        # that happened before the solver was ever invoked.
        solve_mock.assert_called()