Coverage for physioblocks / simulation / runtime.py: 98%

153 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-09 16:40 +0100

1# SPDX-FileCopyrightText: Copyright INRIA 

2# 

3# SPDX-License-Identifier: LGPL-3.0-only 

4# 

5# Copyright INRIA 

6# 

7# This file is part of PhysioBlocks, a library mostly developed by the 

8# [Ananke project-team](https://team.inria.fr/ananke) at INRIA. 

9# 

10# Authors: 

11# - Colin Drieu 

12# - Dominique Chapelle 

13# - François Kimmig 

14# - Philippe Moireau 

15# 

16# PhysioBlocks is free software: you can redistribute it and/or modify it under the 

17# terms of the GNU Lesser General Public License as published by the Free Software 

18# Foundation, version 3 of the License. 

19# 

20# PhysioBlocks is distributed in the hope that it will be useful, but WITHOUT ANY 

21# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 

22# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. 

23# 

24# You should have received a copy of the GNU Lesser General Public License along with 

25# PhysioBlocks. If not, see <https://www.gnu.org/licenses/>. 

26 

27""" 

Defines the **Simulation** classes that define how the simulations run.

29""" 

30 

31from __future__ import annotations 

32 

33import logging 

34from abc import ABC, abstractmethod 

35from collections.abc import Iterable 

36from typing import Any, TypeAlias 

37 

38import numpy as np 

39from numpy.typing import NDArray 

40 

41from physioblocks.computing.assembling import EqSystem 

42from physioblocks.computing.models import ModelComponent 

43from physioblocks.computing.quantities import Quantity 

44from physioblocks.registers.type_register import register_type 

45from physioblocks.simulation.functions import ( 

46 AbstractFunction, 

47 is_state_function, 

48 is_time_function, 

49) 

50from physioblocks.simulation.saved_quantities import SavedQuantities 

51from physioblocks.simulation.solvers import AbstractSolver, ConvergenceError 

52from physioblocks.simulation.state import STATE_NAME_ID, State 

53from physioblocks.simulation.time_manager import TIME_QUANTITY_ID, TimeManager 

54from physioblocks.utils.exceptions_utils import log_exception 

55 

_logger = logging.getLogger(__name__)

# Simulation quantities indexed by their global name.
Parameters: TypeAlias = dict[str, Quantity[Any]]
"""Type alias for quantities collection"""

# One line of results: every recorded value indexed by its global name.
Result: TypeAlias = dict[str, np.float64 | NDArray[np.float64]]
"""Type alias for a single result line"""

# The whole simulation output: one result line per recorded time.
Results: TypeAlias = list[Result]
"""Type alias for all the results of the simulation"""

67 

68 

class AbstractSimulation(ABC):
    """
    Base class for **Simulations**

    .. note:: Use a :class:`~physioblocks.simulation.setup.SimulationFactory` instance
        to instantiate simulations.

    :param factory: the factory that created the simulation instance.
    :type factory: SimulationFactory

    :param time_manager: the simulation time manager
    :type time_manager: TimeManager

    :param state: the simulation state
    :type state: State

    :param parameters: the simulation quantities used as parameters.
    :type parameters: Parameters

    :param saved_quantities: the **Saved Quantities** register
    :type saved_quantities: SavedQuantities

    :param models: the mapping of used models with their names
    :type models: dict[str, ModelComponent]

    :param solver: the solver to use for simulation steps
    :type solver: AbstractSolver

    :param eq_system: the equation system to solve at each time step
    :type eq_system: EqSystem

    :param magnitudes: magnitude of the state variables. Missing variables and
        variables with a ``0.0`` magnitude get a magnitude of ``1.0`` (a warning
        is logged). ``None`` is treated as an empty mapping.
    :type magnitudes: dict[str, float] | None
    """

    def __init__(
        self,
        factory: Any,
        time_manager: TimeManager,
        state: State,
        parameters: Parameters,
        saved_quantities: SavedQuantities,
        models: dict[str, ModelComponent],
        solver: AbstractSolver,
        eq_system: EqSystem,
        magnitudes: dict[str, float] | None = None,
    ):
        self.factory = factory
        self.state = state
        self.parameters = parameters
        self.saved_quantities = saved_quantities
        self.models = models
        self.time_manager = time_manager
        self.solver = solver
        self.eq_system = eq_system
        if magnitudes is None:
            magnitudes = {}
        self.magnitudes = self._check_magnitudes(magnitudes, state)
        # parameter global name -> time function re-evaluated at each time step
        self._timed_updates: dict[str, AbstractFunction] = {}
        # output global name -> function computing an extra entry of each result
        self._output_functions_updates: dict[str, AbstractFunction] = {}

    @property
    def update_functions(self) -> dict[str, AbstractFunction]:
        """
        Get all functions to update at each time step with their matching quantity
        global name.

        :return: a copy of the registered update functions
        :rtype: dict[str, AbstractFunction]
        """
        return self._timed_updates.copy()

    @property
    def outputs_functions(self) -> dict[str, AbstractFunction]:
        """
        Get all functions that compute the additional outputs after a time step
        with their matching output global names.

        :return: a copy of the registered output functions
        :rtype: dict[str, AbstractFunction]
        """
        return self._output_functions_updates.copy()

    @property
    def quantities(self) -> dict[str, Quantity[Any]]:
        """
        Get all the quantities in the simulation from the time manager, the
        parameters and the state.

        On name collisions, parameters override the time and state variables
        override parameters (later updates win).

        :return: a dictionary containing all the simulation quantities
        :rtype: dict[str, Quantity]
        """
        quantities: dict[str, Quantity[Any]] = {
            TIME_QUANTITY_ID: self.time_manager.time
        }
        quantities.update(self.parameters)
        quantities.update(self.state.variables)

        return quantities

    def register_timed_parameter_update(
        self, parameter_id: str, update_function: AbstractFunction
    ) -> None:
        """
        Register a simulation function to update the parameter with the given global
        name at each time step.

        :param parameter_id: the global name of the parameter to update
        :type parameter_id: str

        :param update_function: the function to call to evaluate the parameter value
        :type update_function: AbstractFunction

        :raise KeyError: when the parameter is not in the simulation parameters
        :raise TypeError: when the function is not a time function
        """
        if parameter_id not in self.parameters:
            raise KeyError(str.format("{0} not found in parameters", parameter_id))

        if not isinstance(update_function, AbstractFunction) or not is_time_function(
            update_function
        ):
            raise TypeError(
                str.format(
                    "{0} is not a time function",
                    type(update_function).__name__,
                )
            )

        self._timed_updates[parameter_id] = update_function

    def unregister_timed_parameter_update(self, parameter_id: str) -> None:
        """
        Unregister a simulation function from the timed updates.

        :param parameter_id: the global name of the parameter to unregister.
        :type parameter_id: str

        :raise KeyError: when no update is registered for the parameter
        """
        self._timed_updates.pop(parameter_id)

    def register_output_function(
        self, output_id: str, update_function: AbstractFunction
    ) -> None:
        """
        Register a function that is called to compute an additional output.

        :param output_id: the global name of the output in the results
        :type output_id: str

        :param update_function: the function to compute the output
        :type update_function: AbstractFunction

        :raise KeyError: when the output id is already defined in the results
          (time, state variable, saved quantity or another output)
        :raise TypeError: when the function is not an ``AbstractFunction``
        """
        # The time entry is always written in each result line, so an output with
        # the same name would silently overwrite it.
        if (
            output_id == TIME_QUANTITY_ID
            or output_id in self._output_functions_updates
            or output_id in self.saved_quantities
            or output_id in self.state
        ):
            raise KeyError(str.format("Output {0} is already defined.", output_id))

        if not isinstance(update_function, AbstractFunction):
            raise TypeError(
                str.format(
                    "{0} is not a valid output function",
                    type(update_function).__name__,
                )
            )

        self._output_functions_updates[output_id] = update_function

    def unregister_output_function(self, output_id: str) -> None:
        """
        Unregister a function from the outputs updates.

        :param output_id: the global name of the output.
        :type output_id: str

        :raise KeyError: when no function is registered for the output
        """
        self._output_functions_updates.pop(output_id)

    def _initialize(self) -> Results:
        """
        Initialize the simulation with the current parameters.

        Snapshot the initial state, initialize the models, record the initial
        result line, then initialize the time manager and move to the first step.

        This method should be called when overriding the run method.

        :return: the results list holding the initial result line
        :rtype: Results
        """
        # Copy the snapshot so _finalize restores the true initial values even if
        # ``state_vector`` returns a view of storage that is later modified in place.
        self._initial_state = np.copy(self.state.state_vector)
        _initialize_models(self.models.values())

        # save the initialization
        results = [self._get_current_result()]

        self.time_manager.initialize()
        self.time_manager.update_time()

        # NOTE(review): re-applies the current state vector after the time update;
        # presumably this syncs the state's internal values — confirm in State.
        self.state.set_state_vector(self.state.state_vector)

        return results

    def _finalize(self) -> None:
        """
        Terminate the simulation: reset the time to the time manager start and the
        state to the snapshot taken by ``_initialize``.

        This method should be called when overriding the run method.
        """
        self.time_manager.time.initialize(self.time_manager.start)
        self.state.set_state_vector(self._initial_state)

    def _check_magnitudes(
        self, magnitudes: dict[str, float], state: State
    ) -> dict[str, float]:
        """
        Build the magnitude of every state variable.

        :param magnitudes: the user provided magnitudes
        :type magnitudes: dict[str, float]

        :param state: the state whose variables need a magnitude
        :type state: State

        :return: one magnitude per state variable. Missing or ``0.0`` magnitudes
          are replaced with ``1.0`` (with a logged warning); entries that do not
          match a state variable are dropped.
        :rtype: dict[str, float]
        """
        checked_magnitudes = {}

        for variable_id in state:
            if variable_id not in magnitudes:
                _logger.warning(
                    "No magnitude initialized for variable %s. Magnitude set to 1.0",
                    variable_id,
                )
                checked_magnitudes[variable_id] = 1.0
            elif magnitudes[variable_id] == 0.0:
                # a null magnitude cannot be used to scale the variable
                _logger.warning(
                    "Magnitude for variable %s is initialized to 0.0. "
                    "Replacing with 1.0",
                    variable_id,
                )
                checked_magnitudes[variable_id] = 1.0
            else:
                checked_magnitudes[variable_id] = magnitudes[variable_id]

        return checked_magnitudes

    @abstractmethod
    def run(self) -> Results:
        """
        Run the simulation, this method should be implemented in child classes.

        :return: the result lines of the simulation
        :rtype: Results
        """

    def _update_time(self) -> None:
        """
        Update all the time triggered parameters: each one is initialized with its
        function evaluated at the current time, then updated with its function
        evaluated at the new time.
        """
        for param_id, func in self._timed_updates.items():
            parameter = self.parameters[param_id]
            parameter.initialize(func.eval(self.time_manager.time.current))
            parameter.update(func.eval(self.time_manager.time.new))

    def _get_current_result(self) -> Result:
        """
        Build the result line for the current time.

        The line holds the current time, the state variables, the saved quantities
        (refreshed first) and the value of every registered output function.

        :return: the current result line
        :rtype: Result
        """
        result: Result = {}

        result[TIME_QUANTITY_ID] = self.time_manager.time.current
        result.update(
            {var_id: qty.current for var_id, qty in self.state.variables.items()}
        )

        self.saved_quantities.update()
        result.update(
            {qty_id: qty.current for qty_id, qty in self.saved_quantities.items()}
        )

        for output_id, update_function in self._output_functions_updates.items():
            # only pass the arguments the function declares it depends on
            arguments: dict[str, Any] = {}
            if is_time_function(update_function):
                arguments[TIME_QUANTITY_ID] = self.time_manager.time.current
            if is_state_function(update_function):
                arguments[STATE_NAME_ID] = self.state

            result[output_id] = update_function.eval(**arguments)

        return result

344 

345 

346def _initialize_models(models: Iterable[ModelComponent]) -> None: 

347 """ 

348 Initialize all provided models 

349 

350 :param blocks: the blocks to initialize 

351 :type blocks: Iterable[Block] 

352 """ 

353 for block in models: 

354 block.initialize() 

355 

356 

# Type identifier under which the forward simulation class is registered.
FORWARD_SIM_ID: str = "forward_simulation"

359 

360 

@register_type(FORWARD_SIM_ID)
class ForwardSimulation(AbstractSimulation):
    """
    Extend :class:`~.AbstractSimulation` class to define a **Forward Simulation**.

    The forward simulation solves the **Equation System** at each time step using
    the simulation **Solver**.

    If the solver does not converge at a given time step, the current time step is
    halved and the solve is tried again. This is repeated until the solver
    converges, and the simulation stops with a ``ConvergenceError`` once the halved
    step would fall under the minimum time step allowed by the time manager.

    When a solution is found for a reduced time step, the simulation then tries to
    solve for the remaining time interval of the current time step.

    .. note::

        When breaking a simulation step, the forward simulation still only provides
        a result for the time step interval given to the time manager.

    """

    def run(self) -> Results:
        """
        Solve the system for each time step.

        :return: the initial result line followed by one result line per time step
        :rtype: Results

        :raise SimulationError: raise a Simulation Error holding the current results
          if the simulation stops before reaching the end time.
        """
        # initialize the simulation and save the initial results
        results = self._initialize()

        try:
            while not self.time_manager.ended:
                next_step = self.time_manager.time.new

                self._update_time()
                sol = self._solve_time_step(next_step)

                self.state.set_state_vector(sol.x)
                results.append(self._get_current_result())
        except Exception as exception:
            log_exception(
                _logger,
                type(exception),
                exception,
                exception.__traceback__,
                logging.DEBUG,
            )
            raise SimulationError(
                "An error caused the simulation to stop prematurely",
                results,
            ) from exception

        self._finalize()
        return results

    def _solve_time_step(self, next_step: Any) -> Any:
        """
        Solve from the current time up to ``next_step``.

        On convergence failure the step is halved; after a converged sub-step the
        remaining interval up to ``next_step`` is solved.

        :param next_step: the time to reach
        :return: the last (converged) solver solution
        :raise ConvergenceError: when the step would go under the minimal step
        :raise ValueError: when the interval to solve is not greater than the
          minimal time step, so no solve could be performed
        """
        time_manager = self.time_manager
        sol = None

        while np.abs(next_step - time_manager.time.current) > time_manager.min_step:
            self.state.reset_state_vector()

            sol = self.solver.solve(self.state, self.eq_system, self.magnitudes)

            if not sol.converged:
                # the interval is unchanged, so the loop retries with a smaller step
                self._halve_current_step()
            else:
                self.state.set_state_vector(sol.x)
                time_manager.update_time()
                self._target_next_step(next_step)

        # The loop only exits after a converged solve, unless it never ran.
        if sol is None:
            raise ValueError(
                str.format(
                    "No solve performed between {0}s and {1}s: the interval is not "
                    "greater than the minimal time step {2}",
                    time_manager.time.current,
                    next_step,
                    time_manager.min_step,
                )
            )
        return sol

    def _halve_current_step(self) -> None:
        """
        Halve the current time step after a convergence failure.

        :raise ConvergenceError: when the halved step is under the minimal step
        """
        time_manager = self.time_manager
        inter_time = 0.5 * time_manager.current_step_size
        if inter_time < time_manager.min_step:
            raise ConvergenceError(
                str.format(
                    "The solver did not converge at {0}s for minimal "
                    "time step {1}",
                    time_manager.time.current,
                    time_manager.min_step,
                )
            )

        time_manager.current_step_size = inter_time
        time_manager.time.update(
            time_manager.time.current + time_manager.current_step_size
        )

    def _target_next_step(self, next_step: Any) -> None:
        """
        After a converged sub-step, aim the time at ``next_step``.

        If the remaining interval is greater than the minimal step, the next solve
        covers it. Otherwise the time snaps to ``next_step`` and the nominal step
        size is restored for the following time step.

        :param next_step: the time to reach
        """
        time_manager = self.time_manager
        remaining = next_step - time_manager.time.current
        # Use the same strict comparison as the solve loop: with ``>=`` an interval
        # exactly equal to ``min_step`` would end the loop without being solved and
        # the time would never advance.
        if np.abs(remaining) > time_manager.min_step:
            time_manager.current_step_size = remaining
            time_manager.time.update(next_step)
        else:
            time_manager.time.initialize(next_step)
            time_manager.current_step_size = time_manager.step_size
            time_manager.time.update(
                time_manager.time.current + time_manager.current_step_size
            )

470 

471 

class SimulationError(Exception):
    """
    Error raised when the simulation encounters a problem.

    :param message: the error message
    :type message: str

    :param intermediate_results: the results obtained before the error occurred
    :type intermediate_results: Results
    """

    intermediate_results: Results
    """Results obtained before the simulation error occurred"""

    def __init__(
        self, message: str, intermediate_results: Results, *args: Any, **kwargs: Any
    ) -> None:
        # NOTE(review): Exception.__init__ rejects keyword arguments, so any
        # ``kwargs`` given here end up raising a TypeError.
        super().__init__(message, *args, **kwargs)
        self.intermediate_results = intermediate_results

    def __reduce__(self) -> tuple[Any, ...]:
        """
        Support pickling (e.g. to send the error across processes).

        The default exception pickling replays only ``self.args``, which does not
        contain ``intermediate_results``, so unpickling would fail on the missing
        required argument.
        """
        args = self.args
        message = args[0] if args else ""
        return (
            type(self),
            (message, self.intermediate_results, *args[1:]),
            self.__dict__,
        )