Coverage for physioblocks / simulation / setup.py: 95%

195 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-09 16:40 +0100

1# SPDX-FileCopyrightText: Copyright INRIA 

2# 

3# SPDX-License-Identifier: LGPL-3.0-only 

4# 

5# Copyright INRIA 

6# 

7# This file is part of PhysioBlocks, a library mostly developed by the 

8# [Ananke project-team](https://team.inria.fr/ananke) at INRIA. 

9# 

10# Authors: 

11# - Colin Drieu 

12# - Dominique Chapelle 

13# - François Kimmig 

14# - Philippe Moireau 

15# 

16# PhysioBlocks is free software: you can redistribute it and/or modify it under the 

17# terms of the GNU Lesser General Public License as published by the Free Software 

18# Foundation, version 3 of the License. 

19# 

20# PhysioBlocks is distributed in the hope that it will be useful, but WITHOUT ANY 

21# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 

22# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. 

23# 

24# You should have received a copy of the GNU Lesser General Public License along with 

25# PhysioBlocks. If not, see <https://www.gnu.org/licenses/>. 

26 

27""" 

28Defines functions to setup the simulation 

29""" 

30 

31from __future__ import annotations 

32 

33import logging 

34from collections.abc import Mapping 

35from dataclasses import dataclass 

36from os import linesep 

37from typing import Any, TypeAlias 

38 

39from physioblocks.computing.assembling import EqSystem 

40from physioblocks.computing.models import ( 

41 Expression, 

42 ModelComponent, 

43 SystemFunction, 

44) 

45from physioblocks.computing.quantities import ( 

46 Quantity, 

47 mid_point, 

48) 

49from physioblocks.description.blocks import BlockDescription, ModelComponentDescription 

50from physioblocks.description.flux import ( 

51 get_flux_dof_register, 

52) 

53from physioblocks.description.nets import BoundaryCondition, Net 

54from physioblocks.simulation.runtime import AbstractSimulation, Parameters 

55from physioblocks.simulation.saved_quantities import SavedQuantities 

56from physioblocks.simulation.solvers import AbstractSolver, NewtonSolver 

57from physioblocks.simulation.state import State 

58from physioblocks.simulation.time_manager import TIME_QUANTITY_ID, TimeManager 

59 

# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)

# Each entry is read by build_eq_system as
# (line index in the residual, expression, object passed along with the expression).
SystemExpressions: TypeAlias = list[tuple[int, Expression, Any]]
"""
Type Alias matching a set of :class:`~physioblocks.computing.models.Expression` objects
with their model instance and their line in the residual.
"""

# Separator used to join a parent model id and a submodel id into a unique id,
# and to split such an id back into its parts.
__ID_SEPARATOR = "."

# Register returned by get_flux_dof_register(), fetched once when the module is
# imported. Its ``flux_dof_couples`` attribute is used further down to recognise
# boundary conditions applied on a flux.
_flux_dof_register = get_flux_dof_register()

71 

72 

@dataclass
class _BoundaryConditionsQuantities:
    """
    Holder for the quantity used by a flux boundary condition expression.

    The two methods are registered (unbound) as the function and the gradient
    of the boundary condition expression, the instance being supplied alongside
    them as the expression parameters.
    """

    # Quantity holding the flux imposed by the boundary condition.
    flux: Quantity[Any]

    def boundary_condition_func(self) -> Any:
        """
        Evaluate the boundary condition.

        :return: the result of ``mid_point`` applied to the flux quantity
        :rtype: Any
        """
        imposed_flux = self.flux
        return mid_point(imposed_flux)

    def boundary_condition_grad_func(self) -> Any:
        """
        Evaluate the gradient of the boundary condition for the flux.

        :return: the constant 0.5
        :rtype: Any
        """
        gradient = 0.5
        return gradient

82 

83 

def create_models(
    model_id: str,
    description: ModelComponentDescription,
    parameters: dict[str, Quantity[Any]],
) -> dict[str, ModelComponent]:
    """
    Create a model component instance and its submodels from the given parameters.

    :param model_id: the id under which the created model is registered.
      Submodels are registered under the unique id built from this id and
      their own id.
    :type model_id: str

    :param description: the description of the model to create
    :type description: ModelComponentDescription

    :param parameters: the available quantities
    :type parameters: dict[str, Quantity]

    :raises KeyError: if a global id needed by the model is not in the parameters

    :return: a dict containing the created model and all its submodels recursively.
    :rtype: dict[str, ModelComponent].
    """
    # Create the submodels first, each under its unique id.
    submodels: dict[str, ModelComponent] = {}
    for submodel_id, submodel_desc in description.submodels.items():
        unique_id = __get_submodel_unique_id(model_id, submodel_id)
        submodels.update(create_models(unique_id, submodel_desc, parameters))

    # Terms declared as saved quantities are not constructor arguments.
    # Their ids only depend on the described type: collect them once instead
    # of rebuilding the list for every term.
    saved_term_ids = {
        saved_quantity.term_id
        for saved_quantity in description.described_type.saved_quantities
    }
    model_params: dict[str, Quantity[Any]] = {
        term_id: parameters[global_id]
        for term_id, global_id in description.global_ids.items()
        if term_id not in saved_term_ids
    }

    model = description.described_type(**model_params)
    return {model_id: model, **submodels}

118 

119 

def __get_submodel_unique_id(model_id: str, submodel_id: str) -> str:
    """
    Build the unique id of a submodel from its parent id and its own id.

    :param model_id: the id of the parent model
    :type model_id: str

    :param submodel_id: the id of the submodel in its parent
    :type submodel_id: str

    :return: both ids joined with the id separator
    :rtype: str
    """
    return model_id + __ID_SEPARATOR + submodel_id

122 

123 

def build_state(net: Net) -> State:
    """
    Build the state of the simulation from the net description.

    The state receives the internal variables of every block, the dofs of
    every non-boundary node, and, for boundary nodes, the flux dofs of the
    boundary conditions applied on a flux.

    :param net: the net description
    :type net: Net

    :return: the initial state
    :rtype: State
    """
    initial_state = State()

    # Internal variables of the blocks, zero-initialised to their size.
    for block_desc in net.blocks.values():
        for variable_id, variable_size in block_desc.internal_variables:
            initial_state.add_variable(variable_id, variable_size * [0.0])

    for net_node in net.nodes.values():
        if net_node.is_boundary is not True:
            # Regular node: its dofs are external variables of the state.
            for node_dof in net_node.dofs:
                initial_state.add_variable(node_dof.dof_id, 0.0)
            continue

        # Boundary node: the dof is a variable only when the boundary condition
        # is on the flux. Otherwise the dof is given: it is a parameter.
        for condition in net_node.boundary_conditions:
            if net_node.has_flux_type(condition.condition_type) is True:
                flux_dof = net_node.get_flux_dof(condition.condition_type)
                initial_state.add_variable(flux_dof.dof_id, 0.0)

    return initial_state

157 

158 

def build_parameters(net: Net, state: State) -> Parameters:
    """
    Build the initial parameter register from the net description and the initial state.

    Every quantity id referenced by a block or one of its submodels becomes a
    parameter initialised to ``Quantity(0)``, unless it is the time quantity,
    a variable of the state or a saved quantity of the block. The boundary
    conditions applied on a flux are registered the same way.

    :param net: the net description
    :type net: Net

    :param state: the state
    :type state: State

    :return: the initial parameter register
    :rtype: Parameters
    """
    parameters = {}

    for block in net.blocks.values():
        # The saved quantities ids only depend on the block: collect them once
        # instead of rebuilding the list for every quantity id.
        saved_quantity_ids = {
            saved_quantity[0] for saved_quantity in block.saved_quantities
        }
        for qty_id in _get_block_qty_ids(block):
            if (
                qty_id != TIME_QUANTITY_ID
                and qty_id not in state
                and qty_id not in saved_quantity_ids
            ):
                parameters[qty_id] = Quantity(0)

    # add flux boundary conditions
    for node in net.nodes.values():
        for bc in node.boundary_conditions:
            if node.has_flux_type(bc.condition_type):
                parameters[bc.condition_id] = Quantity(0)

    return parameters

191 

192 

def build_eq_system(expressions: SystemExpressions, state: State) -> EqSystem:
    """build_eq_system(expressions: SystemExpressions, state: State) -> EqSystem

    Assemble an :class:`~physioblocks.computing.assembling.EqSystem` from a set of
    :class:`~physioblocks.computing.models.Expression` objects.

    The system is sized from the state. Each expression is added at its line
    index with its size, its function, its gradients indexed by state variable
    and the parameters object it was registered with.

    :param expressions: The expressions representing the system
    :type expressions: SystemExpressions

    :param state: the state for the system
    :type state: State

    :return: an equation system initialized with the given expressions
    :rtype: EqSystem
    """
    system = EqSystem(state.size)
    for residual_line, expr, expr_parameters in expressions:
        # Only the gradients for variables of the state are kept.
        gradient_by_index = _build_gradient(expr, state)
        system.add_system_part(
            residual_line,
            expr.size,
            expr.expr_func,
            gradient_by_index,
            expr_parameters,
        )
    return system

215 

216 

def _build_quantities(
    parameters: Parameters, state: State, time_manager: TimeManager
) -> dict[str, Quantity[Any]]:
    """
    Merge the parameters, the state variables and the time into one dictionary.

    On a shared id, the state variable replaces the parameter. The time
    quantity is stored last under the time quantity id.

    :param parameters: the parameters register
    :type parameters: Parameters

    :param state: the state
    :type state: State

    :param time_manager: the time manager
    :type time_manager: TimeManager

    :return: a dictionary containing all the simulation quantities
    :rtype: dict[str, Quantity]
    """
    all_quantities: dict[str, Quantity[Any]] = dict(parameters)
    all_quantities |= state.variables
    all_quantities[TIME_QUANTITY_ID] = time_manager.time
    return all_quantities

244 

245 

def _build_gradient(eq: Expression, state: State) -> dict[int, SystemFunction]:
    """
    Index the gradients of an expression by state variable index.

    Gradients defined for an id that is not a variable of the state are left
    out.

    :param eq: the expression holding the gradients
    :type eq: Expression

    :param state: the state providing the variables indexes
    :type state: State

    :return: the gradient functions by index of their variable in the state
    :rtype: dict[int, SystemFunction]
    """
    return {
        state.get_variable_index(variable_id): gradient_func
        for variable_id, gradient_func in eq.expr_gradients.items()
        if variable_id in state.variables
    }

253 

254 

def _build_boundary_condition_expression(
    boundary_condition: BoundaryCondition, quantities: dict[str, Quantity[Any]]
) -> tuple[Expression, _BoundaryConditionsQuantities]:
    """
    Build the expression of a boundary condition with its quantities holder.

    The expression is sized like the condition quantity. Its function and its
    only gradient (for the condition id) are the unbound methods of
    ``_BoundaryConditionsQuantities``; the returned holder is the object they
    are meant to receive.

    :param boundary_condition: the boundary condition
    :type boundary_condition: BoundaryCondition

    :param quantities: the simulation quantities, holding the condition quantity
    :type quantities: dict[str, Quantity]

    :raises KeyError: if the condition id is not in the quantities

    :return: the expression and the holder of its quantities
    :rtype: tuple[Expression, _BoundaryConditionsQuantities]
    """
    condition_id = boundary_condition.condition_id
    imposed_flux = quantities[condition_id]

    holder = _BoundaryConditionsQuantities(imposed_flux)
    gradients = {
        condition_id: _BoundaryConditionsQuantities.boundary_condition_grad_func
    }
    expression = Expression(
        imposed_flux.size,
        _BoundaryConditionsQuantities.boundary_condition_func,
        gradients,
    )
    return expression, holder

267 

268 

def _get_block_qty_ids(block: ModelComponentDescription) -> list[str]:
    """
    Collect the global quantity ids of a model and of all its submodels.

    The ids of the model come first, followed by the ids of each submodel
    (collected recursively). Duplicates are kept.

    :param block: the description of the model
    :type block: ModelComponentDescription

    :return: the global ids used by the model and its submodels
    :rtype: list[str]
    """
    collected = [*block.global_ids.values()]

    for child_desc in block.submodels.values():
        collected += _get_block_qty_ids(child_desc)

    return collected

277 

278 

def __get_model_desc(net: Net, model_id: str) -> ModelComponentDescription:
    """
    Get the description of a block or of one of its submodels from its id.

    The id is split on the id separator: the first part is looked up in the
    blocks of the net, each following part in the submodels of the previous
    match.

    :param net: the net holding the blocks descriptions
    :type net: Net

    :param model_id: the unique id of the model
    :type model_id: str

    :raises ValueError: if a part of the id matches no block or submodel

    :return: the description of the model
    :rtype: ModelComponentDescription
    """
    candidates: Mapping[str, ModelComponentDescription] = net.blocks

    for id_part in model_id.split(__ID_SEPARATOR):
        try:
            model_desc = candidates[id_part]
        except KeyError as error:
            # Report a missing id part with the documented ValueError rather
            # than letting the lookup fail with a bare KeyError.
            raise ValueError(
                str.format("No model named {0} defined in the net.", model_id)
            ) from error
        candidates = model_desc.submodels

    # str.split always yields at least one part, so model_desc is bound here.
    return model_desc

291 

292 

def build_blocks(
    net: Net, quantities: dict[str, Quantity[Any]]
) -> dict[str, ModelComponent]:
    """
    Create the models of all the blocks of the net, submodels included.

    :param net: the net
    :type net: Net

    :param quantities: the simulation quantities
    :type quantities: dict[str, Quantity]

    :return: the created models by id, as returned by ``create_models`` for
      each block
    :rtype: dict[str, ModelComponent]
    """
    return {
        created_id: created_model
        for block_id, block_desc in net.blocks.items()
        for created_id, created_model in create_models(
            block_id, block_desc, quantities
        ).items()
    }

310 

311 

def _get_internal_expressions(
    model_id: str,
    model_desc: ModelComponentDescription,
    state: State,
    models: dict[str, ModelComponent],
) -> SystemExpressions:
    """
    Get the internal expressions defined by a model (submodels excluded).

    Each expression is placed on the line of the state variable matching the
    first term of its definition, and is associated with the model instance.

    :param model_id: the id of the model in ``models``
    :type model_id: str

    :param model_desc: the description of the model
    :type model_desc: ModelComponentDescription

    :param state: the state providing the variables indexes
    :type state: State

    :param models: the created models by id
    :type models: dict[str, ModelComponent]

    :return: the internal expressions of the model
    :rtype: SystemExpressions
    """
    return [
        (
            state.get_variable_index(definition.get_term(0).term_id),
            definition.expression,
            models[model_id],
        )
        for definition in model_desc.internal_expressions
    ]

325 

326 

def _get_fluxes_expressions(
    net: Net,
    block_id: str,
    block_desc: BlockDescription,
    state: State,
    model: ModelComponent,
) -> SystemExpressions:
    """
    Get the flux expressions a block contributes at its nodes.

    :param net: the net, used to map the local nodes of the block to its nodes
    :type net: Net

    :param block_id: the id of the block in the net
    :type block_id: str

    :param block_desc: the description of the block
    :type block_desc: BlockDescription

    :param state: the state providing the variables indexes
    :type state: State

    :param model: the model instance of the block
    :type model: ModelComponent

    :return: the flux expressions, on the line of the flux dof of their node
    :rtype: SystemExpressions
    """
    collected: SystemExpressions = []

    for local_index in block_desc.described_type.nodes:
        node = net.nodes[net.local_to_global_node_id(block_id, local_index)]
        local_flux = block_desc.fluxes[local_index]
        flux_dof_id = node.get_flux_dof(block_desc.flux_type).dof_id

        # A dof missing from the state has been fixed with a boundary
        # condition: the flux is not added.
        if flux_dof_id not in state:
            continue

        collected.append((state.get_variable_index(flux_dof_id), local_flux, model))

    return collected

350 

351 

def _get_model_internal_expressions(
    model_id: str,
    model_desc: ModelComponentDescription,
    state: State,
    models: dict[str, ModelComponent],
) -> SystemExpressions:
    """
    Get the internal expressions of a model and of all its submodels.

    The expressions of the model come first, followed by the expressions of
    each submodel, collected recursively under the submodel unique id.

    :param model_id: the id of the model in ``models``
    :type model_id: str

    :param model_desc: the description of the model
    :type model_desc: ModelComponentDescription

    :param state: the state providing the variables indexes
    :type state: State

    :param models: the created models by id
    :type models: dict[str, ModelComponent]

    :return: the internal expressions of the model and its submodels
    :rtype: SystemExpressions
    """
    collected = _get_internal_expressions(model_id, model_desc, state, models)

    for child_id, child_desc in model_desc.submodels.items():
        collected += _get_model_internal_expressions(
            __get_submodel_unique_id(model_id, child_id), child_desc, state, models
        )

    return collected

371 

372 

def _get_block_expressions(
    net: Net,
    block_id: str,
    state: State,
    models: dict[str, ModelComponent],
) -> SystemExpressions:
    """
    Get all the expressions defined by a block.

    The internal expressions of the block model and of its submodels come
    first, followed by the flux expressions of the block.

    :param net: the net holding the block
    :type net: Net

    :param block_id: the id of the block in the net
    :type block_id: str

    :param state: the state providing the variables indexes
    :type state: State

    :param models: the created models by id
    :type models: dict[str, ModelComponent]

    :return: the expressions of the block
    :rtype: SystemExpressions
    """
    block_desc = net.blocks[block_id]

    # Expressions defined by the model part of the block (and its submodels).
    internal = _get_model_internal_expressions(block_id, block_desc, state, models)

    # Fluxes the block sends to its nodes.
    fluxes = _get_fluxes_expressions(
        net, block_id, block_desc, state, models[block_id]
    )

    return internal + fluxes

394 

395 

def _build_net_expressions(
    net: Net,
    state: State,
    models: dict[str, ModelComponent],
    quantities: dict[str, Quantity[Any]],
) -> SystemExpressions:
    """
    Get all expressions to build the system from the net.

    The expressions of every block come first, in the order of the blocks of
    the net, followed by the boundary condition expressions.

    :param net: the net
    :type net: Net

    :param state: the state providing the variables indexes
    :type state: State

    :param models: the created models by id
    :type models: dict[str, ModelComponent]

    :param quantities: all the quantities available
    :type quantities: dict[str, Quantity]

    :return: a set of expressions
    :rtype: SystemExpressions
    """
    collected: SystemExpressions = [
        block_expression
        for block_id in net.blocks
        for block_expression in _get_block_expressions(net, block_id, state, models)
    ]

    # Boundary conditions are appended after the blocks expressions.
    collected.extend(_build_boundary_condition_expressions(net, state, quantities))

    return collected

430 

431 

def _build_boundary_condition_expressions(
    net: Net, state: State, quantities: dict[str, Quantity[Any]]
) -> SystemExpressions:
    """
    Build the expressions of the boundary conditions applied on a flux.

    A condition is kept when its type is in the flux-dof couples of the
    flux-dof register. Its expression is placed on the line of the matching
    flux dof of the node in the state.

    :param net: the net holding the nodes and their boundary conditions
    :type net: Net

    :param state: the state providing the variables indexes
    :type state: State

    :param quantities: the simulation quantities
    :type quantities: dict[str, Quantity]

    :return: the boundary condition expressions
    :rtype: SystemExpressions
    """
    collected: SystemExpressions = []

    for node in net.nodes.values():
        for bc in node.boundary_conditions:
            # Conditions that are not on a flux define no expression.
            if bc.condition_type not in _flux_dof_register.flux_dof_couples:
                continue

            expression, holder = _build_boundary_condition_expression(bc, quantities)
            flux_dof = node.get_flux_dof(bc.condition_type)
            residual_line = state.get_variable_index(flux_dof.dof_id)
            collected.append((residual_line, expression, holder))

    return collected

450 

451 

def _get_model_saved_quantities_expressions(
    model_id: str,
    model_desc: ModelComponentDescription,
    models: dict[str, ModelComponent],
) -> list[tuple[str, Expression, ModelComponent, int, int]]:
    """
    Get the saved quantities expressions of a model and of all its submodels.

    One entry is created for each term of each saved quantities expression.

    :param model_id: the id of the model in ``models``
    :type model_id: str

    :param model_desc: the description of the model
    :type model_desc: ModelComponentDescription

    :param models: the created models by id
    :type models: dict[str, ModelComponent]

    :return: entries holding the term id, the expression, the model instance,
      the term size and the term index
    :rtype: list[tuple[str, Expression, ModelComponent, int, int]]
    """
    expressions = [
        (
            term_def.term_id,
            saved_qty_expr_def.expression,
            models[model_id],
            term_def.size,
            term_def.index,
        )
        for saved_qty_expr_def in model_desc.saved_quantities_expressions
        for term_def in saved_qty_expr_def.terms
    ]

    for submodel_id, submodel_desc in model_desc.submodels.items():
        # Submodels are registered in ``models`` under their unique id, not
        # under their local id: the lookup must use the same key.
        submodel_unique_id = __get_submodel_unique_id(model_id, submodel_id)
        expressions.extend(
            _get_model_saved_quantities_expressions(
                submodel_unique_id, submodel_desc, models
            )
        )

    return expressions

474 

475 

def build_saved_quantities(
    net: Net, models: dict[str, ModelComponent]
) -> SavedQuantities:
    """
    Create the saved quantities register for the simulation.

    Every saved quantity term of every block of the net (submodels included)
    is registered with its expression, its model, its size and its index.

    :param net: the simulation net
    :type net: Net
    :param models: the models in the simulations
    :type models: dict[str, ModelComponent]
    :return: the simulation saved quantities
    :rtype: SavedQuantities
    """
    register = SavedQuantities()

    for block_id, block_desc in net.blocks.items():
        block_entries = _get_model_saved_quantities_expressions(
            block_id, block_desc, models
        )
        for entry in block_entries:
            # entry is (quantity id, expression, model, size, index)
            register.register(*entry)

    return register

501 

502 

class SimulationFactory:
    """
    Factory for **Simulation** objects

    :param simulation_type: The simulation type to create
    :type simulation_type: type[AbstractSimulation]

    :param solver: the solver the simulation will use. A ``NewtonSolver`` is
      created when it is not provided.
    :type solver: AbstractSolver

    :param net: the net to initialize the simulation parameters. An empty
      ``Net`` is created when it is not provided.
    :type net: Net

    :param simulation_options: additional simulation options depending on the
      simulation type, forwarded as keyword arguments to the simulation
    :type simulation_options: dict[str, Any]
    """

    def __init__(
        self,
        simulation_type: type[AbstractSimulation],
        solver: AbstractSolver | None = None,
        net: Net | None = None,
        simulation_options: dict[str, Any] | None = None,
    ):
        self.simulation_type = simulation_type

        # Fall back on default objects for everything left unset.
        if solver is None:
            solver = NewtonSolver()
        if net is None:
            net = Net()
        if simulation_options is None:
            simulation_options = {}

        self.solver = solver
        self.net = net
        self.simulation_options = simulation_options

    def create_simulation(self) -> AbstractSimulation:
        """
        Create a **Simulation** instance.

        The state, the parameters, the models, the saved quantities and the
        equation system are built from the net of the factory, then handed to
        the simulation type with the solver and the simulation options.

        :raises TypeError: if the simulation type is not an
          ``AbstractSimulation`` sub-class

        :return: a simulation instance.
        :rtype: AbstractSimulation
        """
        sim_type = self.simulation_type
        if not issubclass(sim_type, AbstractSimulation):
            raise TypeError(
                "{0} is not a {1} sub-class.".format(
                    sim_type.__name__,
                    AbstractSimulation.__name__,
                )
            )

        net = self.net

        # Each step relies on the objects built by the previous ones.
        time_manager = TimeManager()
        state = build_state(net)
        parameters = build_parameters(net, state)
        all_quantities = _build_quantities(parameters, state, time_manager)
        models = build_blocks(net, all_quantities)
        saved_quantities = build_saved_quantities(net, models)
        expressions = _build_net_expressions(net, state, models, all_quantities)
        eq_system = build_eq_system(expressions, state)

        # Log simulation informations
        _logger.info("Net:{1}{0}".format(net, linesep))
        _logger.info("State:{1}{0}".format(state, linesep))
        _logger.info("System:{1}{0}".format(eq_system, linesep))

        return sim_type(
            factory=self,
            time_manager=time_manager,
            state=state,
            parameters=parameters,
            saved_quantities=saved_quantities,
            models=models,
            solver=self.solver,
            eq_system=eq_system,
            **self.simulation_options,
        )