Coverage for physioblocks / simulation / runtime.py: 98%
153 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-09 16:40 +0100
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-09 16:40 +0100
1# SPDX-FileCopyrightText: Copyright INRIA
2#
3# SPDX-License-Identifier: LGPL-3.0-only
4#
5# Copyright INRIA
6#
7# This file is part of PhysioBlocks, a library mostly developed by the
8# [Ananke project-team](https://team.inria.fr/ananke) at INRIA.
9#
10# Authors:
11# - Colin Drieu
12# - Dominique Chapelle
13# - François Kimmig
14# - Philippe Moireau
15#
16# PhysioBlocks is free software: you can redistribute it and/or modify it under the
17# terms of the GNU Lesser General Public License as published by the Free Software
18# Foundation, version 3 of the License.
19#
20# PhysioBlocks is distributed in the hope that it will be useful, but WITHOUT ANY
21# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
22# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
23#
24# You should have received a copy of the GNU Lesser General Public License along with
25# PhysioBlocks. If not, see <https://www.gnu.org/licenses/>.
27"""
28Defines the **Simulation** classes that control how simulations are run.
29"""
31from __future__ import annotations
33import logging
34from abc import ABC, abstractmethod
35from collections.abc import Iterable
36from typing import Any, TypeAlias
38import numpy as np
39from numpy.typing import NDArray
41from physioblocks.computing.assembling import EqSystem
42from physioblocks.computing.models import ModelComponent
43from physioblocks.computing.quantities import Quantity
44from physioblocks.registers.type_register import register_type
45from physioblocks.simulation.functions import (
46 AbstractFunction,
47 is_state_function,
48 is_time_function,
49)
50from physioblocks.simulation.saved_quantities import SavedQuantities
51from physioblocks.simulation.solvers import AbstractSolver, ConvergenceError
52from physioblocks.simulation.state import STATE_NAME_ID, State
53from physioblocks.simulation.time_manager import TIME_QUANTITY_ID, TimeManager
54from physioblocks.utils.exceptions_utils import log_exception
_logger = logging.getLogger(__name__)

Parameters: TypeAlias = dict[str, Quantity[Any]]
"""Type alias for quantities collection (quantities indexed by global name)"""

Result: TypeAlias = dict[str, np.float64 | NDArray[np.float64]]
"""Type alias for a single result line (values indexed by global name)"""

Results: TypeAlias = list[Result]
"""Type alias for all the results of the simulation (one line per recorded time)"""
class AbstractSimulation(ABC):
    """
    Base class for **Simulations**

    .. note:: Use a :class:`~physioblocks.simulation.setup.SimulationFactory` instance
        to instantiate simulations.

    :param factory: the factory that created the simulation instance.
    :type factory: SimulationFactory

    :param time_manager: the simulation time manager
    :type time_manager: TimeManager

    :param state: the simulation state
    :type state: State

    :param parameters: the simulations quantities for parameters.
    :type parameters: Parameters

    :param saved_quantities: the **Saved Quantities** register
    :type saved_quantities: SavedQuantities

    :param models: the mapping of used models with their names
    :type models: dict[str, ModelComponent]

    :param solver: the solver to use for simulation steps
    :type solver: AbstractSolver

    :param eq_system: the equation system to solve at each time step
    :type eq_system: EqSystem

    :param magnitudes: magnitude of the state variables. Missing or zero
        magnitudes are replaced by 1.0 (a warning is logged).
    :type magnitudes: dict[str, float] | None
    """

    def __init__(
        self,
        factory: Any,
        time_manager: TimeManager,
        state: State,
        parameters: Parameters,
        saved_quantities: SavedQuantities,
        models: dict[str, ModelComponent],
        solver: AbstractSolver,
        eq_system: EqSystem,
        magnitudes: dict[str, float] | None = None,
    ):
        self.factory = factory
        self.state = state
        self.parameters = parameters
        self.saved_quantities = saved_quantities
        self.models = models
        self.time_manager = time_manager
        self.solver = solver
        self.eq_system = eq_system
        if magnitudes is None:
            magnitudes = {}
        # one magnitude per state variable, never 0.0
        self.magnitudes = self._check_magnitudes(magnitudes, state)
        # parameter global name -> time function evaluated at each time step
        self._timed_updates: dict[str, AbstractFunction] = {}
        # output global name -> function evaluated when a result is recorded
        self._output_functions_updates: dict[str, AbstractFunction] = {}

    @property
    def update_functions(self) -> dict[str, AbstractFunction]:
        """
        Get all functions to update at each time step with their matching quantity
        global name.

        :return: a copy of the update functions mapping
        :rtype: dict[str, AbstractFunction]
        """
        return self._timed_updates.copy()

    @property
    def outputs_functions(self) -> dict[str, AbstractFunction]:
        """
        Get all functions that compute the additional output after a time step
        with their matching output global names.

        :return: a copy of the output functions mapping
        :rtype: dict[str, AbstractFunction]
        """
        return self._output_functions_updates.copy()

    @property
    def quantities(self) -> dict[str, Quantity[Any]]:
        """
        Get all the quantities in the simulation from the parameters, the state
        and the time manager.

        On name collisions, state variables override parameters, which override
        the time quantity.

        :return: a dictionary containing all the simulation quantities
        :rtype: dict[str, Quantity]
        """
        quantities: dict[str, Quantity[Any]] = {
            TIME_QUANTITY_ID: self.time_manager.time
        }
        quantities.update(self.parameters)
        quantities.update(self.state.variables)

        return quantities

    def register_timed_parameter_update(
        self, parameter_id: str, update_function: AbstractFunction
    ) -> None:
        """
        Register a simulation function to update the parameters with the given global
        name at each time step.

        :param parameter_id: the global name of the parameter to update
        :type parameter_id: str

        :param update_function: the function to call to evaluate the parameter value
        :type update_function: AbstractFunction

        :raise KeyError: if the parameter is not in the simulation parameters
        :raise TypeError: if the function is not a time function
        """

        if parameter_id not in self.parameters:
            raise KeyError(str.format("{0} not found in parameters", parameter_id))

        if not isinstance(update_function, AbstractFunction) or not is_time_function(
            update_function
        ):
            raise TypeError(
                str.format(
                    "{0} is not a time function",
                    type(update_function).__name__,
                )
            )

        self._timed_updates[parameter_id] = update_function

    def unregister_timed_parameter_update(self, parameter_id: str) -> None:
        """
        Unregister a simulation function from the timed updates.

        :param parameter_id: the global name of the parameter to unregister.
        :type parameter_id: str

        :raise KeyError: if no function is registered for the parameter
        """
        self._timed_updates.pop(parameter_id)

    def register_output_function(
        self, output_id: str, update_function: AbstractFunction
    ) -> None:
        """
        Register a function that is called to compute an additional output.

        :param output_id: the global name of the output in the results
        :type output_id: str

        :param update_function: the function to compute the output
        :type update_function: AbstractFunction

        :raise KeyError: Raises a key error when the output id is already defined
          in the results (output functions, saved quantities or state)
        :raise TypeError: Raises a type error when the function is not an
          :class:`AbstractFunction`
        """
        if (
            output_id in self._output_functions_updates
            or output_id in self.saved_quantities
            or output_id in self.state
        ):
            raise KeyError(str.format("Output {0} is already defined.", output_id))

        if not isinstance(update_function, AbstractFunction):
            raise TypeError(
                str.format(
                    "{0} is not a valid output function",
                    type(update_function).__name__,
                )
            )

        self._output_functions_updates[output_id] = update_function

    def unregister_output_function(self, output_id: str) -> None:
        """
        Unregister a function from the outputs updates.

        :param output_id: the global name of the output.
        :type output_id: str

        :raise KeyError: if no function is registered for the output
        """
        self._output_functions_updates.pop(output_id)

    def _initialize(self) -> Results:
        """Initialize the simulation with current parameters.

        Stores the current state vector (restored by :meth:`_finalize`), initializes
        the models, records a first result, then initializes and advances the time
        manager.

        This method should be called when overriding the run method.

        :return: the results list holding the result recorded before the time
          manager is initialized
        :rtype: Results
        """
        self._initial_state = self.state.state_vector
        _initialize_models(self.models.values())

        # save the initialization
        results = [self._get_current_result()]

        self.time_manager.initialize()
        self.time_manager.update_time()

        # NOTE(review): re-applies the current state vector, presumably to
        # synchronize the state variables after model initialization - confirm.
        self.state.set_state_vector(self.state.state_vector)

        return results

    def _finalize(self) -> None:
        """Terminate the simulation reinitializing state and time to initial values.

        This method should be called when overriding the run method.
        """
        self.time_manager.time.initialize(self.time_manager.start)
        self.state.set_state_vector(self._initial_state)

    def _check_magnitudes(
        self, magnitudes: dict[str, float], state: State
    ) -> dict[str, float]:
        """
        Build the magnitude of every state variable.

        Missing or zero magnitudes are replaced by 1.0 and a warning is logged.
        Magnitudes given for names that are not state variables are dropped.

        :param magnitudes: the provided magnitudes
        :type magnitudes: dict[str, float]

        :param state: the simulation state
        :type state: State

        :return: one magnitude per state variable
        :rtype: dict[str, float]
        """
        checked_magnitudes = {}

        for variable_id in state:
            if variable_id not in magnitudes:
                message = str.format(
                    "No magnitude initialized for variable {0}. Magnitude set to 1.0",
                    variable_id,
                )
                _logger.warning(message)
                checked_magnitudes[variable_id] = 1.0

            elif magnitudes[variable_id] == 0.0:
                message = str.format(
                    "Magnitude for variable {0} is initialized to 0.0. "
                    "Replacing with 1.0",
                    variable_id,
                )
                _logger.warning(message)
                checked_magnitudes[variable_id] = 1.0
            else:
                checked_magnitudes[variable_id] = magnitudes[variable_id]

        return checked_magnitudes

    @abstractmethod
    def run(self) -> Results:
        """
        Run the simulation, this method should be implemented in child classes.

        :return: the list of results for each time step
        :rtype: Results
        """

    def _update_time(self) -> None:
        """
        Updates all the time triggered updatable parameters.

        Each parameter is initialized with its function value at the current time
        and updated with its function value at the new time of the time manager.
        """
        for param_id, func in self._timed_updates.items():
            self.parameters[param_id].initialize(
                func.eval(self.time_manager.time.current)
            )
            self.parameters[param_id].update(func.eval(self.time_manager.time.new))

    def _get_current_result(self) -> Result:
        """
        Build the result line for the current time.

        The line holds the current time, the current value of every state
        variable, the current value of every saved quantity (refreshed first) and
        the value of every registered output function.

        :return: the result line
        :rtype: Result
        """
        result: Result = {}

        result[TIME_QUANTITY_ID] = self.time_manager.time.current
        result.update(
            {var_id: qty.current for var_id, qty in self.state.variables.items()}
        )

        self.saved_quantities.update()
        result.update(
            {qty_id: qty.current for qty_id, qty in self.saved_quantities.items()}
        )

        for output_id, update_function in self._output_functions_updates.items():
            # only pass the arguments the function declares it needs
            arguments: dict[str, Any] = {}
            if is_time_function(update_function):
                arguments[TIME_QUANTITY_ID] = self.time_manager.time.current
            if is_state_function(update_function):
                arguments[STATE_NAME_ID] = self.state

            result[output_id] = update_function.eval(**arguments)

        return result
346def _initialize_models(models: Iterable[ModelComponent]) -> None:
347 """
348 Initialize all provided models
350 :param blocks: the blocks to initialize
351 :type blocks: Iterable[Block]
352 """
353 for block in models:
354 block.initialize()
# Type identifier under which the forward simulation class is registered
# with ``register_type``.
FORWARD_SIM_ID: str = "forward_simulation"
@register_type(FORWARD_SIM_ID)
class ForwardSimulation(AbstractSimulation):
    """
    Extend :class:`~.AbstractSimulation` class to define a **Forward Simulation**.

    The forward simulation solves the **Equation System** at each time step using
    the simulation **Solver**.

    If the solver does not converge at a given time step, the current time step is
    halved and the solve is attempted again. Halving is repeated until the solver
    converges; the simulation stops if the halved time step would be under the
    minimum time step allowed by the time manager.

    When a solution is found for a reduced time step, the simulation then tries to
    solve for the remaining time interval in the current time step.

    .. note::

        When breaking a simulation step, the forward simulation still only provides
        a result for the time step interval given to the time manager.
    """

    def run(self) -> Results:
        """
        Solve the system for each time step.

        :return: the list of results: the initial one, then one per time step
        :rtype: Results

        :raise SimulationError: raise a Simulation Error holding the current results
          if the simulation stops before reaching the end time.
        """
        # initialize the simulation and save the initial results
        results = self._initialize()

        try:
            while not self.time_manager.ended:
                sol = self._solve_time_step(self.time_manager.time.new)
                self.state.set_state_vector(sol.x)
                results.append(self._get_current_result())
        except Exception as exception:
            log_exception(
                _logger,
                type(exception),
                exception,
                exception.__traceback__,
                logging.DEBUG,
            )
            raise SimulationError(
                "An error caused the simulation to stop prematurely",
                intermediate_results=results,
            ) from exception

        self._finalize()
        return results

    def _solve_time_step(self, next_step: float) -> Any:
        """
        Advance the simulation from the current time up to ``next_step``.

        The interval is split into sub-steps when the solver does not converge.
        Timed parameters are re-evaluated before each attempt so that they match
        the (possibly reduced) interval being solved.

        :param next_step: the time to reach at the end of the step
        :type next_step: float

        :return: the solver solution of the last solved sub-step

        :raise ConvergenceError: if the solver does not converge even when the
          time step is reduced down to the minimal time step
        :raise ValueError: if the interval to ``next_step`` is empty or smaller
          than the minimal time step, so no solve can be attempted
        """
        sol = None

        while self._has_remaining_interval(next_step):
            self._update_time()
            self.state.reset_state_vector()

            sol = self.solver.solve(self.state, self.eq_system, self.magnitudes)

            if not sol.converged:
                # retry on half the interval
                inter_time = 0.5 * self.time_manager.current_step_size
                if inter_time < self.time_manager.min_step:
                    raise ConvergenceError(
                        str.format(
                            "The solver did not converge at {0}s for minimal "
                            "time step {1}",
                            self.time_manager.time.current,
                            self.time_manager.min_step,
                        )
                    )

                self.time_manager.current_step_size = inter_time
                self.time_manager.time.update(
                    self.time_manager.time.current
                    + self.time_manager.current_step_size
                )
            else:
                self.state.set_state_vector(sol.x)

                self.time_manager.update_time()
                if self._has_remaining_interval(next_step):
                    # solve the rest of the interval in one sub-step
                    self.time_manager.current_step_size = (
                        next_step - self.time_manager.time.current
                    )
                    self.time_manager.time.update(next_step)
                else:
                    # remaining interval is negligible: snap onto next_step and
                    # restore the nominal step size for the following step
                    self.time_manager.time.initialize(next_step)
                    self.time_manager.current_step_size = self.time_manager.step_size
                    self.time_manager.time.update(
                        self.time_manager.time.current
                        + self.time_manager.current_step_size
                    )

        if sol is None:
            # Without this check the caller would reuse a missing or stale
            # solution while the time never advances.
            raise ValueError(
                str.format(
                    "Can not advance the simulation from {0}s to {1}s: the interval "
                    "is empty or smaller than the minimal time step {2}",
                    self.time_manager.time.current,
                    next_step,
                    self.time_manager.min_step,
                )
            )

        return sol

    def _has_remaining_interval(self, next_step: float) -> bool:
        """
        Tell whether the interval between the current time and ``next_step`` still
        has to be solved.

        An interval of exactly the minimal time step is solved, since the minimal
        step is an allowed step size. An empty interval never is, so the step ends
        once the current time has been set to ``next_step``.

        :param next_step: the time to reach at the end of the step
        :type next_step: float

        :return: True if a solve is still needed to reach ``next_step``
        :rtype: bool
        """
        remaining = np.abs(next_step - self.time_manager.time.current)
        return bool(remaining > 0.0 and remaining >= self.time_manager.min_step)
class SimulationError(Exception):
    """
    Error raised when the simulation encounters a problem.

    :param message: the error message
    :type message: str

    :param intermediate_results: the results obtained before the error occurred
    :type intermediate_results: Results
    """

    intermediate_results: Results
    """Results obtained before the simulation error occurred"""

    def __init__(
        self, message: str, intermediate_results: Results, *args: Any, **kwargs: Any
    ) -> None:
        # NOTE: BaseException.__init__ does not accept keyword arguments, so any
        # non-empty ``kwargs`` raises a TypeError here; kept for compatibility.
        super().__init__(message, *args, **kwargs)
        self.intermediate_results = intermediate_results

    def __reduce__(self) -> tuple[Any, ...]:
        """
        Make the error picklable (e.g. to cross a process boundary).

        Default exception pickling only passes ``self.args`` back to the
        constructor, which lacks the required ``intermediate_results`` argument.

        :return: the callable, its arguments and the instance state
        :rtype: tuple
        """
        message, *extra_args = self.args or ("",)
        return (
            type(self),
            (message, self.intermediate_results, *extra_args),
            self.__dict__,
        )