Coverage for tests / tests_simulation / test_runtime.py: 99%
125 statements
coverage.py v7.13.1, created at 2026-01-09 16:40 +0100
1# SPDX-FileCopyrightText: Copyright INRIA
2#
3# SPDX-License-Identifier: LGPL-3.0-only
4#
5# Copyright INRIA
6#
7# This file is part of PhysioBlocks, a library mostly developed by the
8# [Ananke project-team](https://team.inria.fr/ananke) at INRIA.
9#
10# Authors:
11# - Colin Drieu
12# - Dominique Chapelle
13# - François Kimmig
14# - Philippe Moireau
15#
16# PhysioBlocks is free software: you can redistribute it and/or modify it under the
17# terms of the GNU Lesser General Public License as published by the Free Software
18# Foundation, version 3 of the License.
19#
20# PhysioBlocks is distributed in the hope that it will be useful, but WITHOUT ANY
21# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
22# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
23#
24# You should have received a copy of the GNU Lesser General Public License along with
25# PhysioBlocks. If not, see <https://www.gnu.org/licenses/>.
import re
from unittest.mock import Mock, patch

import numpy as np
import pytest

from physioblocks.computing.models import (
    ModelComponent,
)
from physioblocks.computing.quantities import Quantity
from physioblocks.simulation.functions import AbstractFunction
from physioblocks.simulation.runtime import (
    AbstractSimulation,
    ForwardSimulation,
    SimulationError,
)
from physioblocks.simulation.setup import SimulationFactory
from physioblocks.simulation.solvers import AbstractSolver, Solution
from physioblocks.simulation.state import State
def get_solution(converged: bool) -> Solution:
    """Build a fixed three-component solver solution for the tests.

    :param converged: convergence flag stored in the returned solution.
    :return: a ``Solution`` built from the vector ``[0.1, 0.2, 0.3]`` and
        ``converged``.
    """
    solution_vector = np.array([0.1, 0.2, 0.3])
    return Solution(solution_vector, converged)
def block_qty_update_func(model: ModelComponent):
    """Constant placeholder quantity-update function.

    ``model`` is accepted (presumably to match the signature expected for
    block update functions — confirm against callers) but is never read.

    :return: always ``0.0``.
    """
    constant_value = 0.0
    return constant_value
def time_func(self, time: np.float64) -> np.float64:
    """Identity function of time, patched in as ``AbstractFunction.eval``.

    :param time: the time value passed by the caller.
    :return: ``time``, unchanged.
    """
    evaluated = time
    return evaluated
def no_param_func(self) -> np.float64:
    """Argument-free function stub, patched in as ``AbstractFunction.eval``.

    :return: always ``0.0``.
    """
    result = 0.0
    return result
def state_func(self, state: State) -> np.ndarray:
    """State-dependent function stub, patched in as ``AbstractFunction.eval``.

    The return annotation is ``np.ndarray`` because the function hands back
    the whole state vector (an array in the tests using it), not a scalar.

    :param state: the simulation state to read from.
    :return: ``state.state_vector``, unchanged.
    """
    return state.state_vector
@pytest.fixture
@patch.multiple(AbstractSolver, __abstractmethods__=set())
@patch.multiple(AbstractSimulation, __abstractmethods__=set())
def simulation() -> AbstractSimulation:
    """Provide a bare ``AbstractSimulation`` created through the factory.

    Emptying ``__abstractmethods__`` on both abstract classes (only while
    this fixture body runs) allows them to be instantiated directly.
    """
    factory = SimulationFactory(AbstractSimulation, AbstractSolver())
    created = factory.create_simulation()
    return created
class TestSimulation:
    """Tests for the generic ``AbstractSimulation`` runtime behaviour.

    Error messages are passed to ``pytest.raises(match=...)`` through
    ``re.escape`` because ``match`` is a regular expression searched in the
    exception text; escaping makes the check a literal substring match.
    """

    def test_register_time_update_exceptions(self, simulation: AbstractSimulation):
        # Registering an update for an unknown parameter raises a KeyError.
        with patch.multiple(
            AbstractFunction, __abstractmethods__=set(), eval=time_func
        ):
            wrong_param_id = "no_param"
            err_message = re.escape(f"{wrong_param_id} not found in parameters")
            with pytest.raises(KeyError, match=err_message):
                simulation.register_timed_parameter_update(
                    wrong_param_id, AbstractFunction()
                )

        # A function whose ``eval`` takes no time argument is rejected with a
        # TypeError, even when the target parameter exists.
        with patch.multiple(
            AbstractFunction, __abstractmethods__=set(), eval=no_param_func
        ):
            time_triggered_qty_id = "time_triggered_qty"
            time_triggered_qty = Quantity(0.0)
            simulation.parameters[time_triggered_qty_id] = time_triggered_qty

            err_message = re.escape(
                f"{type(AbstractFunction()).__name__} is not a time function"
            )
            with pytest.raises(TypeError, match=err_message):
                simulation.register_timed_parameter_update(
                    time_triggered_qty_id, AbstractFunction()
                )

    @patch.multiple(AbstractFunction, __abstractmethods__=set(), eval=time_func)
    def test_update_time(self, simulation: AbstractSimulation):
        simulation.time_manager.step_size = 0.1
        simulation.time_manager.start = 0.0
        simulation.time_manager.duration = 0.2

        time_triggered_qty_id = "time_triggered_qty"
        time_triggered_qty = Quantity(0.0)
        simulation.parameters[time_triggered_qty_id] = time_triggered_qty

        simulation.register_timed_parameter_update(
            time_triggered_qty_id, AbstractFunction()
        )
        assert simulation.parameters[time_triggered_qty_id].current == pytest.approx(
            0.0
        )
        assert simulation.parameters[time_triggered_qty_id].new == pytest.approx(0.0)

        # After one step, the identity time function writes the new time (0.1)
        # into ``new`` while ``current`` keeps its previous value.
        simulation.time_manager.update_time()
        simulation._update_time()  # noqa SLF001
        assert simulation.parameters[time_triggered_qty_id].current == pytest.approx(
            0.0
        )
        assert simulation.parameters[time_triggered_qty_id].new == pytest.approx(0.1)
        simulation.unregister_timed_parameter_update(time_triggered_qty_id)

    @patch.multiple(AbstractFunction, __abstractmethods__=set(), eval=no_param_func)
    def test_register_simulation_outputs(self, simulation: AbstractSimulation):
        no_param_func_id = "no_param_func"
        output_func = AbstractFunction()
        simulation.register_output_function(no_param_func_id, output_func)

        # Popping from the mapping returned by ``outputs_functions`` must not
        # unregister the function: this checks the property hands out a copy
        # rather than the simulation's internal registry.
        simulation.outputs_functions.pop(no_param_func_id)
        assert no_param_func_id in simulation.outputs_functions

        simulation.unregister_output_function(no_param_func_id)
        assert no_param_func_id not in simulation.outputs_functions

    @patch.multiple(AbstractFunction, __abstractmethods__=set(), eval=no_param_func)
    def test_register_simulation_exceptions(self, simulation: AbstractSimulation):
        output_id = "output"
        error_message = re.escape(f"Output {output_id} is already defined.")

        # The id clashes with an already registered output function.
        with (
            patch.object(
                simulation, attribute="_output_functions_updates", new={output_id: None}
            ),
            pytest.raises(KeyError, match=error_message),
        ):
            simulation.register_output_function(output_id, AbstractFunction())

        # The id clashes with a saved quantity.
        with (
            patch.object(
                simulation.saved_quantities,
                attribute="_saved_quantities",
                new={output_id: None},
            ),
            pytest.raises(KeyError, match=error_message),
        ):
            simulation.register_output_function(output_id, AbstractFunction())

        # The id clashes with a state variable (``in state`` forced to True).
        with (
            patch.object(State, attribute="__contains__", return_value=True),
            pytest.raises(KeyError, match=error_message),
        ):
            simulation.register_output_function(output_id, AbstractFunction())

        # Objects that are not functions are rejected with a TypeError.
        error_message = re.escape(f"{object.__name__} is not a valid output function")
        with pytest.raises(TypeError, match=error_message):
            simulation.register_output_function(output_id, object())

    @patch.multiple(AbstractFunction, __abstractmethods__=set(), eval=no_param_func)
    def test_no_parameter_output_function(self, simulation: AbstractSimulation):
        no_param_func_id = "no_param_func"
        output_func = AbstractFunction()
        simulation.register_output_function(no_param_func_id, output_func)

        results = simulation._get_current_result()  # noqa SLF001
        assert results[no_param_func_id] == pytest.approx(0.0)

    @patch.multiple(AbstractFunction, __abstractmethods__=set(), eval=time_func)
    def test_time_parameter_output_function(self, simulation: AbstractSimulation):
        time_func_id = "time_func"
        output_func = AbstractFunction()
        simulation.register_output_function(time_func_id, output_func)

        # The identity time function should report the current time.
        simulation.time_manager.time.initialize(0.001)
        results = simulation._get_current_result()  # noqa SLF001
        assert results[time_func_id] == pytest.approx(0.001)

    @patch.multiple(AbstractFunction, __abstractmethods__=set(), eval=state_func)
    def test_state_parameter_output_function(self, simulation: AbstractSimulation):
        state_vector_value = np.array([0.1, 0.2])
        # ``state_vector`` is patched onto State so the output reads a known array.
        with patch.object(
            State, attribute="state_vector", create=True, new=state_vector_value
        ):
            state_func_id = "state_func"
            output_func = AbstractFunction()
            simulation.register_output_function(state_func_id, output_func)

            results = simulation._get_current_result()  # noqa SLF001
            assert results[state_func_id] == pytest.approx(state_vector_value)
@pytest.fixture
@patch.multiple(AbstractSolver, __abstractmethods__=set())
def forward_simulation() -> ForwardSimulation:
    """Provide a ``ForwardSimulation`` starting at 0.0 for 0.010 with step 0.001.

    A variable ``"x"`` is added to the state with value ``[0.0, 0.0, 0.0]``.
    ``AbstractSolver`` is made instantiable only while this fixture body runs.
    """
    factory = SimulationFactory(ForwardSimulation, solver=AbstractSolver())
    fwd_sim = factory.create_simulation()

    fwd_sim.time_manager.start = 0.0
    fwd_sim.time_manager.duration = 0.010
    fwd_sim.time_manager.step_size = 0.001

    fwd_sim.state.add_variable("x", [0.0] * 3)
    return fwd_sim
class TestForwardSimulation:
    """Tests for ``ForwardSimulation.run`` driven by a mocked solver."""

    def test_run(self, forward_simulation: ForwardSimulation):
        sol = get_solution(True)
        forward_simulation.state["x"].initialize(sol.x)
        with patch.multiple(
            AbstractSolver,
            __abstractmethods__=set(),
            solve=Mock(return_value=sol),
        ):
            # Materialise the results (still inside the patch) so the
            # emptiness check works whatever iterable ``run`` returns.
            results = list(forward_simulation.run())

            # Guard against a vacuous pass: without this, the loop below
            # asserts nothing when ``run`` yields no results.
            assert results
            for result in results:
                assert result["x"] == pytest.approx(sol.x)

    def test_run_no_solution(self, forward_simulation: ForwardSimulation):
        # A solver solution flagged as not converged aborts the run.
        sol = get_solution(False)
        forward_simulation.state["x"].initialize(sol.x)
        with (
            patch.multiple(
                AbstractSolver, __abstractmethods__=set(), solve=Mock(return_value=sol)
            ),
            pytest.raises(SimulationError),
        ):
            forward_simulation.run()