Coverage for physioblocks / simulation / setup.py: 95%
195 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-09 16:40 +0100
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-09 16:40 +0100
1# SPDX-FileCopyrightText: Copyright INRIA
2#
3# SPDX-License-Identifier: LGPL-3.0-only
4#
5# Copyright INRIA
6#
7# This file is part of PhysioBlocks, a library mostly developed by the
8# [Ananke project-team](https://team.inria.fr/ananke) at INRIA.
9#
10# Authors:
11# - Colin Drieu
12# - Dominique Chapelle
13# - François Kimmig
14# - Philippe Moireau
15#
16# PhysioBlocks is free software: you can redistribute it and/or modify it under the
17# terms of the GNU Lesser General Public License as published by the Free Software
18# Foundation, version 3 of the License.
19#
20# PhysioBlocks is distributed in the hope that it will be useful, but WITHOUT ANY
21# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
22# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
23#
24# You should have received a copy of the GNU Lesser General Public License along with
25# PhysioBlocks. If not, see <https://www.gnu.org/licenses/>.
27"""
Defines functions to set up the simulation.
29"""
31from __future__ import annotations
33import logging
34from collections.abc import Mapping
35from dataclasses import dataclass
36from os import linesep
37from typing import Any, TypeAlias
39from physioblocks.computing.assembling import EqSystem
40from physioblocks.computing.models import (
41 Expression,
42 ModelComponent,
43 SystemFunction,
44)
45from physioblocks.computing.quantities import (
46 Quantity,
47 mid_point,
48)
49from physioblocks.description.blocks import BlockDescription, ModelComponentDescription
50from physioblocks.description.flux import (
51 get_flux_dof_register,
52)
53from physioblocks.description.nets import BoundaryCondition, Net
54from physioblocks.simulation.runtime import AbstractSimulation, Parameters
55from physioblocks.simulation.saved_quantities import SavedQuantities
56from physioblocks.simulation.solvers import AbstractSolver, NewtonSolver
57from physioblocks.simulation.state import State
58from physioblocks.simulation.time_manager import TIME_QUANTITY_ID, TimeManager
# Module logger.
_logger = logging.getLogger(name=__name__)

SystemExpressions: TypeAlias = list[tuple[int, Expression, Any]]
"""
Type Alias matching a set of :class:`~physioblocks.computing.models.Expression` objects
with their model instance and their line in the residual.
"""

# Separator used to join a parent model id and a submodel id into a unique id.
__ID_SEPARATOR = "."

# Flux / dof register, queried to know whether a boundary condition is on a flux.
_flux_dof_register = get_flux_dof_register()
73@dataclass
74class _BoundaryConditionsQuantities:
75 flux: Quantity[Any]
77 def boundary_condition_func(self) -> Any:
78 return mid_point(self.flux)
80 def boundary_condition_grad_func(self) -> Any:
81 return 0.5
def create_models(
    model_id: str,
    description: ModelComponentDescription,
    parameters: dict[str, Quantity[Any]],
) -> dict[str, ModelComponent]:
    """
    Create a model component instance and its submodels from the given parameters.

    Submodels are created first, recursively, each one under a unique id joining the
    parent model id and the submodel id. The model itself is then instantiated with
    the parameters matching its described terms. Terms that are saved quantities of
    the model type are not passed to the constructor.

    :param model_id: the unique id of the model to create
    :type model_id: str

    :param description: the description of the model to create
    :type description: ModelComponentDescription

    :param parameters: the available quantities, indexed by their global ids
    :type parameters: dict[str, Quantity]

    :return: a dict containing the created model and all its submodels recursively.
    :rtype: dict[str, ModelComponent].

    :raises KeyError: if a global id of a non-saved term is missing from the
      parameters
    """
    submodels: dict[str, ModelComponent] = {}
    for submodel_id, submodel_desc in description.submodels.items():
        unique_id = __get_submodel_unique_id(model_id, submodel_id)
        submodels.update(create_models(unique_id, submodel_desc, parameters))

    # Term ids of the saved quantities of the model type: they are excluded from the
    # constructor parameters. Built once rather than once per term.
    saved_term_ids = {
        saved_quantity.term_id
        for saved_quantity in description.described_type.saved_quantities
    }
    model_params: dict[str, Quantity[Any]] = {
        term_id: parameters[global_id]
        for term_id, global_id in description.global_ids.items()
        if term_id not in saved_term_ids
    }

    model = description.described_type(**model_params)
    # The model comes first, then its submodels.
    models = {model_id: model}
    models.update(submodels)
    return models
def __get_submodel_unique_id(model_id: str, submodel_id: str) -> str:
    # Unique submodel id: "<parent model id><separator><submodel id>".
    return str.join(__ID_SEPARATOR, (model_id, submodel_id))
def build_state(net: Net) -> State:
    """
    Build the initial simulation state from the net description.

    The internal variables of every block are added first, initialized to zeros.
    The node dofs are then added in node order, initialized to ``0.0``:

    * for a non-boundary node, every dof of the node is a variable;
    * for a boundary node, only the dof matching a flux boundary condition is a
      variable. Dofs fixed by another kind of boundary condition are given values,
      so they are parameters rather than variables.

    :param net: the net description
    :type net: Net

    :return: the initial state
    :rtype: State
    """
    state = State()

    for block_desc in net.blocks.values():
        # Internal variables of the block, zero-initialized with their size.
        for variable_id, variable_size in block_desc.internal_variables:
            state.add_variable(variable_id, [0.0] * variable_size)

    for node in net.nodes.values():
        if node.is_boundary is not True:
            # Regular node: all its dofs are variables of the system.
            for node_dof in node.dofs:
                state.add_variable(node_dof.dof_id, 0.0)
            continue

        # Boundary node: only flux conditions turn the matching dof into a variable.
        for condition in node.boundary_conditions:
            if node.has_flux_type(condition.condition_type) is True:
                flux_dof = node.get_flux_dof(condition.condition_type)
                state.add_variable(flux_dof.dof_id, 0.0)

    return state
def build_parameters(net: Net, state: State) -> Parameters:
    """
    Build the initial parameter register from the net description and the initial state.

    Every quantity id used by a block (or by one of its submodels) becomes a parameter
    initialized to ``Quantity(0)``, except for three kinds of ids:

    * the time quantity;
    * the state variables;
    * the saved quantities of the block.

    A parameter initialized to ``Quantity(0)`` is also added for every flux boundary
    condition of the net nodes.

    :param net: the net description
    :type net: Net

    :param state: the state
    :type state: State

    :return: the initial parameter register
    :rtype: Parameters
    """
    parameters = {}

    for block in net.blocks.values():
        # Saved quantity ids of the block are not parameters.
        # Built once per block instead of once per quantity id.
        saved_ids = {saved_quantity[0] for saved_quantity in block.saved_quantities}
        for qty_id in _get_block_qty_ids(block):
            if (
                qty_id != TIME_QUANTITY_ID
                and qty_id not in state
                and qty_id not in saved_ids
            ):
                parameters[qty_id] = Quantity(0)

    # add flux boundary conditions
    for node in net.nodes.values():
        for bc in node.boundary_conditions:
            if node.has_flux_type(bc.condition_type):
                parameters[bc.condition_id] = Quantity(0)

    return parameters
def build_eq_system(expressions: SystemExpressions, state: State) -> EqSystem:
    """build_eq_system(expressions: SystemExpressions, state: State) -> EqSystem

    Build an :class:`~physioblocks.computing.assembling.EqSystem` instance from a set
    of :class:`~physioblocks.computing.models.Expression` objects.

    Each expression is added as a system part at its residual line. Its gradients
    are re-indexed by the state variable indexes, and only gradients with respect
    to state variables are kept.

    :param expressions: The expressions representing the system
    :type expressions: SystemExpressions

    :param state: the state for the system
    :type state: State

    :return: an equation system initialized with the given expressions
    :rtype: EqSystem
    """
    system = EqSystem(state.size)
    for entry in expressions:
        line_index, expression, expr_parameters = entry
        gradient = _build_gradient(expression, state)
        system.add_system_part(
            line_index,
            expression.size,
            expression.expr_func,
            gradient,
            expr_parameters,
        )
    return system
def _build_quantities(
    parameters: Parameters, state: State, time_manager: TimeManager
) -> dict[str, Quantity[Any]]:
    """
    Merge the parameters, the state variables and the time manager time into a
    single dictionary of simulation quantities.

    When ids collide, state variables take precedence over parameters, and the
    time quantity takes precedence over both.

    :param parameters: the parameters register
    :type parameters: Parameters

    :param state: the state
    :type state: State

    :param time_manager: the time manager
    :type time_manager: TimeManager

    :return: a dictionary containing all the simulation quantities
    :rtype: dict[str, Quantity]
    """
    return {
        **parameters,
        **state.variables,
        TIME_QUANTITY_ID: time_manager.time,
    }
246def _build_gradient(eq: Expression, state: State) -> dict[int, SystemFunction]:
247 gradients = {}
248 for var_id in eq.expr_gradients:
249 if var_id in state.variables:
250 gradients[state.get_variable_index(var_id)] = eq.expr_gradients[var_id]
252 return gradients
def _build_boundary_condition_expression(
    boundary_condition: BoundaryCondition, quantities: dict[str, Quantity[Any]]
) -> tuple[Expression, _BoundaryConditionsQuantities]:
    """
    Build the expression of a flux boundary condition and the parameters it uses.

    The expression has the size of the prescribed flux quantity. Its single gradient
    is registered for the condition id.

    :param boundary_condition: the flux boundary condition
    :param quantities: the simulation quantities, indexed by id
    :return: the expression and its parameters object holding the flux
    :raises KeyError: if the condition id is not in ``quantities``
    """
    condition_id = boundary_condition.condition_id
    prescribed_flux = quantities[condition_id]
    expression = Expression(
        prescribed_flux.size,
        _BoundaryConditionsQuantities.boundary_condition_func,
        {condition_id: _BoundaryConditionsQuantities.boundary_condition_grad_func},
    )
    return expression, _BoundaryConditionsQuantities(prescribed_flux)
269def _get_block_qty_ids(block: ModelComponentDescription) -> list[str]:
270 ids = list(block.global_ids.values())
272 for sub_model in block.submodels.values():
273 child_ids = _get_block_qty_ids(sub_model)
274 ids.extend(child_ids)
276 return ids
def __get_model_desc(net: Net, model_id: str) -> ModelComponentDescription:
    """
    Get the description of a model of the net from its unique id.

    The id is split on the id separator. The first part is a block id of the net,
    and each following part is a submodel id of the previously found model.

    :param net: the net holding the block descriptions
    :param model_id: the unique id of the model
    :return: the model description
    :raises ValueError: if no model with this id is defined in the net
    """
    submodels: Mapping[str, ModelComponentDescription] = net.blocks
    model_desc: ModelComponentDescription | None = None

    for id_part in model_id.split(__ID_SEPARATOR):
        try:
            model_desc = submodels[id_part]
        except KeyError:
            # Unknown id part: report the documented ValueError below instead of
            # letting a bare KeyError escape.
            model_desc = None
            break
        submodels = model_desc.submodels

    if model_desc is not None:
        return model_desc

    raise ValueError(str.format("No model named {0} defined in the net.", model_id))
def build_blocks(
    net: Net, quantities: dict[str, Quantity[Any]]
) -> dict[str, ModelComponent]:
    """
    Build all the blocks and their submodels holding the quantities from the net.

    :param net: the net
    :type net: Net

    :param quantities: the simulation quantities
    :type quantities: dict[str, Quantity]

    :return: every created model (blocks and their submodels, recursively), indexed
      by its unique id
    :rtype: dict[str, ModelComponent]
    """
    return {
        unique_id: model
        for block_id, block_desc in net.blocks.items()
        for unique_id, model in create_models(block_id, block_desc, quantities).items()
    }
312def _get_internal_expressions(
313 model_id: str,
314 model_desc: ModelComponentDescription,
315 state: State,
316 models: dict[str, ModelComponent],
317) -> SystemExpressions:
318 expressions = []
319 for expr_def in model_desc.internal_expressions:
320 first_term = expr_def.get_term(0)
321 var_index = state.get_variable_index(first_term.term_id)
322 expressions.append((var_index, expr_def.expression, models[model_id]))
324 return expressions
327def _get_fluxes_expressions(
328 net: Net,
329 block_id: str,
330 block_desc: BlockDescription,
331 state: State,
332 model: ModelComponent,
333) -> SystemExpressions:
334 fluxes_expr: SystemExpressions = []
336 for local_node_index in block_desc.described_type.nodes:
337 global_node_id = net.local_to_global_node_id(block_id, local_node_index)
338 node = net.nodes[global_node_id]
340 # Add the fluxes
341 flux_expr = block_desc.fluxes[local_node_index]
342 dof = node.get_flux_dof(block_desc.flux_type)
343 # if the dof is not in state, it has been fixed with a boundary condition,
344 # don't add the flux
345 if dof.dof_id in state:
346 dof_state_index = state.get_variable_index(dof.dof_id)
347 fluxes_expr.append((dof_state_index, flux_expr, model))
349 return fluxes_expr
def _get_model_internal_expressions(
    model_id: str,
    model_desc: ModelComponentDescription,
    state: State,
    models: dict[str, ModelComponent],
) -> SystemExpressions:
    """
    Get the internal expressions of a model and of all its submodels, recursively.

    Submodels are referenced by their unique id, which joins the parent model id
    and the submodel id.
    """
    collected: SystemExpressions = _get_internal_expressions(
        model_id, model_desc, state, models
    )
    for child_id, child_desc in model_desc.submodels.items():
        child_unique_id = __get_submodel_unique_id(model_id, child_id)
        collected.extend(
            _get_model_internal_expressions(child_unique_id, child_desc, state, models)
        )
    return collected
def _get_block_expressions(
    net: Net,
    block_id: str,
    state: State,
    models: dict[str, ModelComponent],
) -> SystemExpressions:
    """
    Get every system expression defined by a block.

    The internal expressions of the block model (and of its submodels) come first,
    followed by the flux expressions of the block at its nodes.
    """
    block_desc = net.blocks[block_id]
    block_expressions = _get_model_internal_expressions(
        block_id, block_desc, state, models
    )
    block_expressions.extend(
        _get_fluxes_expressions(net, block_id, block_desc, state, models[block_id])
    )
    return block_expressions
def _build_net_expressions(
    net: Net,
    state: State,
    models: dict[str, ModelComponent],
    quantities: dict[str, Quantity[Any]],
) -> SystemExpressions:
    """
    Get all expressions to build the system from the net.

    The expressions of every block come first, in the net block order, followed by
    the boundary condition expressions.

    :param net: the net
    :type net: Net

    :param state: the simulation state
    :type state: State

    :param models: the models defining the expressions, indexed by their unique id
    :type models: dict[str, ModelComponent]

    :param quantities: all the available quantities
    :type quantities: dict[str, Quantity]

    :return: a set of expressions
    :rtype: SystemExpressions
    """
    net_expressions: SystemExpressions = [
        block_expression
        for block_id in net.blocks
        for block_expression in _get_block_expressions(net, block_id, state, models)
    ]
    net_expressions += _build_boundary_condition_expressions(net, state, quantities)
    return net_expressions
def _build_boundary_condition_expressions(
    net: Net, state: State, quantities: dict[str, Quantity[Any]]
) -> SystemExpressions:
    """
    Build the expressions of the flux boundary conditions of all the net nodes.

    Each expression is placed at the state index of the node dof matching the
    condition flux type. Conditions whose type is not a registered flux type are
    skipped.
    """
    collected = []
    for node in net.nodes.values():
        for condition in node.boundary_conditions:
            if condition.condition_type not in _flux_dof_register.flux_dof_couples:
                continue
            expression, expr_parameters = _build_boundary_condition_expression(
                condition, quantities
            )
            flux_dof = node.get_flux_dof(condition.condition_type)
            collected.append(
                (state.get_variable_index(flux_dof.dof_id), expression, expr_parameters)
            )
    return collected
def _get_model_saved_quantities_expressions(
    model_id: str,
    model_desc: ModelComponentDescription,
    models: dict[str, ModelComponent],
) -> list[tuple[str, Expression, ModelComponent, int, int]]:
    """
    Get the saved quantities expressions of a model and of all its submodels,
    recursively.

    :param model_id: the unique id of the model in ``models``
    :param model_desc: the model description
    :param models: the created models, indexed by their unique id

    :return: one ``(term id, expression, model, term size, term index)`` tuple per
      term of each saved quantity expression.
    """
    expressions = [
        (
            term_def.term_id,
            saved_qty_expr_def.expression,
            models[model_id],
            term_def.size,
            term_def.index,
        )
        for saved_qty_expr_def in model_desc.saved_quantities_expressions
        for term_def in saved_qty_expr_def.terms
    ]
    for submodel_id, submodel_desc in model_desc.submodels.items():
        # Models are registered under their unique id (see create_models), so the
        # submodel must be looked up with it, not with its local submodel id.
        submodel_unique_id = __get_submodel_unique_id(model_id, submodel_id)
        expressions.extend(
            _get_model_saved_quantities_expressions(
                submodel_unique_id, submodel_desc, models
            )
        )

    return expressions
def build_saved_quantities(
    net: Net, models: dict[str, ModelComponent]
) -> SavedQuantities:
    """
    Create the saved quantities register for the simulation.

    Every saved quantity term of every block (and of its submodels) is registered
    with its expression, owning model, size and index.

    :param net: the simulation net
    :type net: Net
    :param models: the models in the simulations
    :type models: dict[str, ModelComponent]
    :return: the simulation saved quantities
    :rtype: SavedQuantities
    """
    register = SavedQuantities()
    for block_id, block_desc in net.blocks.items():
        block_entries = _get_model_saved_quantities_expressions(
            block_id, block_desc, models
        )
        for term_id, expr, owner, size, index in block_entries:
            register.register(term_id, expr, owner, size, index)
    return register
class SimulationFactory:
    """
    Factory for **Simulation** objects

    :param simulation_type: The simulation type to create
    :type simulation_type: type[AbstractSimulation]

    :param solver: the solver the simulation will use. Defaults to a new
      ``NewtonSolver()`` instance.
    :type solver: AbstractSolver

    :param net: the net to initialize the simulation parameters. Defaults to a new
      ``Net()`` instance.
    :type net: Net

    :param simulation_options: additional simulation options depending on the
      simulation type. Defaults to an empty dict.
    :type simulation_options: dict[str, Any]
    """

    def __init__(
        self,
        simulation_type: type[AbstractSimulation],
        solver: AbstractSolver | None = None,
        net: Net | None = None,
        simulation_options: dict[str, Any] | None = None,
    ):
        # The simulation type is only validated in create_simulation.
        self.simulation_type = simulation_type
        # None defaults are replaced by new instances, so factories do not share a
        # default solver, net or options dict.
        self.solver = solver if solver is not None else NewtonSolver()
        self.net = net if net is not None else Net()
        self.simulation_options = (
            simulation_options if simulation_options is not None else {}
        )

    def create_simulation(self) -> AbstractSimulation:
        """
        Create a **Simulation** instance.

        Builds, from the factory net, the objects passed to the simulation type
        constructor along with the simulation options:

        * the time manager;
        * the state and the parameters;
        * the models and the saved quantities;
        * the equation system.

        :return: a simulation instance.
        :rtype: AbstractSimulation

        :raises TypeError: if the simulation type is not a sub-class of
          ``AbstractSimulation``
        """

        if not issubclass(self.simulation_type, AbstractSimulation):
            raise TypeError(
                str.format(
                    "{0} is not a {1} sub-class.",
                    self.simulation_type.__name__,
                    AbstractSimulation.__name__,
                )
            )

        time_manager = TimeManager()
        state = build_state(self.net)
        parameters = build_parameters(self.net, state)
        all_quantities = _build_quantities(parameters, state, time_manager)
        models = build_blocks(self.net, all_quantities)
        saved_quantities = build_saved_quantities(self.net, models)
        expressions = _build_net_expressions(self.net, state, models, all_quantities)
        eq_system = build_eq_system(expressions, state)

        # Log simulation informations.
        # Lazy %-style arguments: the net, state and system are only converted to
        # strings when INFO records are actually emitted.
        _logger.info("Net:%s%s", linesep, self.net)
        _logger.info("State:%s%s", linesep, state)
        _logger.info("System:%s%s", linesep, eq_system)

        return self.simulation_type(
            factory=self,
            time_manager=time_manager,
            state=state,
            parameters=parameters,
            saved_quantities=saved_quantities,
            models=models,
            solver=self.solver,
            eq_system=eq_system,
            **self.simulation_options,
        )