Coverage for physioblocks / computing / models.py: 98%
152 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-09 16:40 +0100
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-09 16:40 +0100
1# SPDX-FileCopyrightText: Copyright INRIA
2#
3# SPDX-License-Identifier: LGPL-3.0-only
4#
5# Copyright INRIA
6#
7# This file is part of PhysioBlocks, a library mostly developed by the
8# [Ananke project-team](https://team.inria.fr/ananke) at INRIA.
9#
10# Authors:
11# - Colin Drieu
12# - Dominique Chapelle
13# - François Kimmig
14# - Philippe Moireau
15#
16# PhysioBlocks is free software: you can redistribute it and/or modify it under the
17# terms of the GNU Lesser General Public License as published by the Free Software
18# Foundation, version 3 of the License.
19#
20# PhysioBlocks is distributed in the hope that it will be useful, but WITHOUT ANY
21# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
22# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
23#
24# You should have received a copy of the GNU Lesser General Public License along with
25# PhysioBlocks. If not, see <https://www.gnu.org/licenses/>.
"""
Declares the **ModelComponent** and **Block** base classes (with their
meta-classes) along with the objects holding the **Flux**, **Internal
Expressions** and **Saved Quantities** functions.
"""
33from __future__ import annotations
35from collections.abc import Callable
36from dataclasses import dataclass, field
37from inspect import get_annotations
38from typing import Any, TypeAlias, get_origin
40import numpy as np
41from numpy.typing import NDArray
43from physioblocks.computing.quantities import Quantity
SystemFunction: TypeAlias = Callable[..., np.float64 | NDArray[np.float64]]
"""
Type alias for functions composing the system: callables accepting any
arguments and returning either a scalar ``np.float64`` or an ``NDArray`` of
``np.float64``.
"""
@dataclass(frozen=True)
class Expression:
    """Expression(size: int, expr_func: SystemFunction, expr_gradients: dict[str, SystemFunction] = {})

    Pairs a function computing the numerical value of a model term with the
    size of that function's result.

    Functions computing the partial derivatives of the expression can
    optionally be given, keyed by the id of the variable they derive for.

    Example
    ^^^^^^^

    .. code:: python

        def f1(x1, x2):
            return 0.5 * x1 + 0.8 * x2

        def df1_dx1(x1, x2):
            return 0.5

        def df1_dx2(x1, x2):
            return 0.8

        expression_f1 = Expression(
            1,  # size
            f1,  # expression function
            {
                "x1": df1_dx1,
                "x2": df1_dx2,
            },  # partial derivatives of the expression
        )
    """  # noqa: E501

    size: int
    """Size of the result of the function"""

    expr_func: SystemFunction
    """Function to compute the numerical value"""

    expr_gradients: dict[str, SystemFunction] = field(default_factory=dict)
    """
    Collection of functions to compute the derivatives of the expression for
    variables
    """

    def __eq__(self, value: Any) -> bool:
        # Only another Expression can be equal to this one.
        if not isinstance(value, Expression):
            return False
        # Equal when size, value function and gradient functions all match;
        # all() stops at the first mismatching field.
        compared_fields = ("size", "expr_func", "expr_gradients")
        return all(
            getattr(self, name) == getattr(value, name) for name in compared_fields
        )
@dataclass(frozen=True)
class TermDefinition:
    """Describe Terms defined in an :class:`~physioblocks.computing.models.Expression`.

    An :class:`~physioblocks.computing.models.Expression` object can define several
    **Terms**.

    Term definitions are identified by their ``term_id`` only: two definitions
    with the same id compare equal and have the same hash, whatever their size
    and index.

    Example
    ^^^^^^^

    .. code:: python

        def vector_3d(x1, x2, x3):
            return [x1, x2, x3]

        expression_vector = Expression(
            3,  # size
            vector_3d  # expression function
        )

        x1_term = TermDefinition(
            "x1",  # id
            1,  # term size
            0  # starting index in vector expression
        )

        x2_term = TermDefinition(
            "x2",  # id
            1,  # term size
            1  # starting index in vector expression
        )

        x3_term = TermDefinition(
            "x3",  # id
            1,  # term size
            2  # starting index in vector expression
        )
    """

    term_id: str
    """Term id"""

    size: int
    """Term size"""

    index: int = 0
    """Starting line of the term index in its expression"""

    def __eq__(self, value: Any) -> bool:
        return isinstance(value, TermDefinition) and self.term_id == value.term_id

    def __hash__(self) -> int:
        # Must agree with __eq__: the dataclass-generated hash would combine
        # every field, so two "equal" terms (same id, different size/index)
        # could hash differently and break sets and dict keys.
        return hash(self.term_id)
@dataclass(frozen=True)
class ExpressionDefinition:
    """ExpressionDefinition(expression: Expression, terms: list[TermDefinition] = [])

    Holds an :class:`~physioblocks.computing.models.Expression` together with the
    :class:`~physioblocks.computing.models.TermDefinition` objects it expresses.

    Example
    ^^^^^^^

    .. code:: python

        # Expression Definition for the TermDefinition example:
        >>> definition = ExpressionDefinition(
                expression_vector,
                [
                    x1_term,
                    x2_term,
                    x3_term
                ]
            )
    """

    expression: Expression
    """The expression"""

    terms: list[TermDefinition] = field(default_factory=list)
    """The expressed Terms"""

    @property
    def valid(self) -> bool:
        """Check if the definition is complete, meaning the terms cover every
        line of the expression exactly once: terms do not overlap and no term
        lies outside the expression.

        A definition covering no line at all is never valid.

        :return: True if the definition is valid, False otherwise
        :rtype: bool

        .. code ::

            # From example above
            >>> definition = ExpressionDefinition(
                    expression_vector,
                    [
                        x1_term,
                        x2_term,
                        x3_term
                    ]
                )
            >>> definition.valid  # True

            >>> overlapping_definition = ExpressionDefinition(
                    expression_vector,
                    [
                        x1_term,
                        x2_term,
                        x3_term,
                        x1_term  # overlapping term on index 0
                    ]
                )
            >>> overlapping_definition.valid  # False

            >>> incomplete_definition = ExpressionDefinition(
                    expression_vector,
                    [
                        x1_term,
                        # missing index 1 term
                        x3_term
                    ]
                )
            >>> incomplete_definition.valid  # False

            >>> out_of_range_definition = ExpressionDefinition(
                    expression_vector,
                    [
                        x1_term,
                        x2_term,
                        TermDefinition("x4", 1, 3)  # past the expression end
                    ]
                )
            >>> out_of_range_definition.valid  # False
        """
        covered: set[int] = set()
        for term in self.terms:
            term_lines = range(term.index, term.index + term.size)
            # A line expressed by two terms makes the definition invalid.
            if any(line in covered for line in term_lines):
                return False
            covered.update(term_lines)

        # Complete only when exactly the lines 0 .. size - 1 are covered:
        # counting covered lines alone would let a term lying past the end of
        # the expression stand in for a missing line.
        return bool(covered) and covered == set(range(self.expression.size))

    def get_term(self, index: int) -> TermDefinition:
        """Get the term starting in the expression at the given index.

        :param index: the first index of the term in the expression.
        :type index: int

        :return: the term definition
        :rtype: TermDefinition

        :raise KeyError: if no term starts at the given index

        .. code ::

            # From example above
            >>> definition = ExpressionDefinition(
                    expression_vector,
                    [
                        x1_term,
                        x2_term,
                        x3_term
                    ]
                )
            >>> definition.get_term(0)  # x1_term
            >>> definition.get_term(1)  # x2_term
            >>> definition.get_term(2)  # x3_term
        """
        for term in self.terms:
            if term.index == index:
                return term

        raise KeyError(
            str.format("No term starts at index {0} in expression, {1}", index, self)
        )
ExpressionsCollection: TypeAlias = dict[str, list[ExpressionDefinition]]
"""
Type alias for a collection of expressions.
Keys are the expression types as strings.
Values are lists of :class:`~physioblocks.computing.models.ExpressionDefinition`,
each holding an expression and the Term Definitions it expresses.
"""
283class ModelComponentMetaClass(type):
284 """Meta-class for :class:`~physioblocks.computing.models.ModelComponent`.
286 Defines the model **Internal Equations** and **Saved Quantities**
287 using :class:`~physioblocks.computing.models.Expression` objects.
289 * **Internal Equations** are expressing **Internal Variables** with a residual
290 equation.
291 * **Saved Quantities** are given with a direct relation
293 """
295 __INTERNAL_EXPRESSION_KEY = "internals"
296 __SAVED_QUANTITIES_EXPRESSION_KEY = "saved_quantities"
298 _expressions: ExpressionsCollection
300 def __init__(cls, *args: Any, **kwargs: Any) -> None:
301 super().__init__(*args, **kwargs)
302 cls._expressions = {
303 cls.__INTERNAL_EXPRESSION_KEY: [],
304 cls.__SAVED_QUANTITIES_EXPRESSION_KEY: [],
305 }
307 @staticmethod
308 def __is_quantity_type(type_to_test: Any) -> bool:
309 if isinstance(type_to_test, type) is False:
310 type_to_test = get_origin(type_to_test)
311 return issubclass(type_to_test, Quantity) is True
313 @property
314 def local_ids(cls) -> list[str]:
315 """
316 Get local parameters ids of the model.
318 Every member of the :class:`~physioblocks.computing.models.ModelComponent`
319 annotated with :class:`~physioblocks.computing.quantities.Quantity` type
320 has a local id.
322 :return: the local ids of the parameters
323 :rtype: list[str]
325 Example
326 ^^^^^^^
328 .. code:: python
330 @dataclass
331 class SimpleModel(metaclass=ModelComponentMetaClass):
333 x1: Quantity
334 x2: Quantity
336 SimpleModel.local_ids # ["x1", "x2"]
337 """
338 annotations = get_annotations(cls)
340 # get the quantities local ids
341 local_ids = [
342 key
343 for key, item in annotations.items()
344 if ModelComponentMetaClass.__is_quantity_type(item)
345 ]
347 # add the saved quantities local ids
348 local_ids.extend(
349 [
350 saved_quantity_expr.term_id
351 for saved_quantity_expr in cls.saved_quantities
352 ]
353 )
355 return local_ids
357 def _get_all_terms(cls, expr_type: str) -> list[TermDefinition]:
358 if expr_type not in cls._expressions:
359 return []
360 # get terms local id and size for all expressions of the given type
361 return [
362 term_def
363 for expression_def in cls._expressions[expr_type]
364 for term_def in expression_def.terms
365 ]
367 def _has_term_defined(cls, tested_id: str, expr_type: str) -> bool:
368 """Get if the given id is defined as the given espresstion type
370 :param variable_id: the id to test
371 :type variable_id: str
373 :param expr_type: the expression type to test
374 :type expr_type: str
376 :return: True if the id defines an expresseion of the given expression type,
377 false otherwise
378 :rtype: bool
379 """
380 return any(term.term_id == tested_id for term in cls._get_all_terms(expr_type))
382 def _get_all_expressions(cls, expr_type: str) -> list[ExpressionDefinition]:
383 # get all expressions of a type with the matching defined terms
384 if expr_type not in cls._expressions:
385 return []
387 return cls._expressions[expr_type].copy()
389 def _get_term_expression(
390 cls, term_id: str, expression_type: str
391 ) -> tuple[Expression, int, int]:
392 """Get the expression, the size and the line index in the expression
393 of the given given local term id.
395 :param term_id: the term id
396 :type term_id: str
398 :param expression_type: the type of expression of the term
399 :type term_id: str
401 :return: the expression, the size and line of the term in the expression.
402 :rtype: tuple[Expression, int, int]
403 """
404 if expression_type not in cls._expressions:
405 raise KeyError(str.format("No expressions of type {0}.", expression_type))
407 for expr_def in cls._expressions[expression_type]:
408 for term_def in expr_def.terms:
409 if term_def.term_id == term_id:
410 return (expr_def.expression, term_def.size, term_def.index)
412 raise KeyError(str.format("No expression defined for {0}.", term_id))
414 def _get_all_terms_ids(cls) -> list[str]:
415 # get all expressed terms local ids
416 return [
417 term_def.term_id
418 for expr_type in cls._expressions
419 for expr_def in cls._expressions[expr_type]
420 for term_def in expr_def.terms
421 ]
423 def _declares_term_expression(
424 cls,
425 term_id: str,
426 expr: Expression,
427 expr_type: str,
428 size: int | None = None,
429 index: int = 0,
430 ) -> None:
431 """
432 Add a term expression to the model definition.
434 :param term_id: the local id of the term.
435 :type term_id: str
437 :param expr: the associated expression
438 :type expr: Expression
440 :param expr_type: the expression type
441 :type expr_type: str
443 :param size: the term size
444 :type size: int
446 :param index: the starting line index of the term in the expression.
447 :type index: str
448 """
449 if size is None:
450 size = expr.size
452 if index + size > expr.size:
453 raise ValueError(
454 str.format(
455 "{0} definition of size {1} starting at index {2} exceed "
456 "expression size {3}",
457 term_id,
458 size,
459 index,
460 expr.size,
461 )
462 )
464 # check if a term with the same id is already used
465 if term_id in cls._get_all_terms_ids():
466 raise KeyError(
467 str.format("An expression is already defined for {0}.", term_id)
468 )
470 # get existing expression definition
471 expression_def = None
472 for expr_def in cls._expressions[expr_type]:
473 if expr_def.expression is expr:
474 expression_def = expr_def
475 break
477 # if not found, create a new one
478 if expression_def is None:
479 expression_def = ExpressionDefinition(expr, [])
480 cls._expressions[expr_type].append(expression_def)
482 # Add the term definition to the expression definition
483 expression_def.terms.append(TermDefinition(term_id, size, index))
485 def declares_internal_expression(
486 cls,
487 variable_id: str,
488 expr: Expression,
489 size: int | None = None,
490 index: int = 0,
491 ) -> None:
492 """
493 Declares a :class:`~physioblocks.computing.models.Expression` object
494 for an **Internal Equation** of the model.
496 :param term_id: the local id of the variable associated with the expression
497 :type term_id: str
499 :param expr: the associated expression
500 :type expr: Expression
502 :param size: the term size
503 :type size: int
505 :param index: the starting line index of the term in the expression.
506 :type index: str
508 Example
509 ^^^^^^^
511 .. code:: python
513 @dataclass
514 class SimpleModel(metaclass=ModelComponentMetaClass):
516 x1: Quantity
517 a: Quantity
518 b: Quantity
520 def x1_residual(self):
521 return self.a.current * self.x1.new - b.current
523 def dx1_residual_dx1(self):
524 return self.a
526 x1_expression = Expression(
527 1,
528 SimpleModel.x1_residual,
529 {
530 "x1": SimpleModel.dx1_residual_dx1
531 }
532 )
533 SimpleModel.declare_internal_expression(
534 "x1", # term id
535 x1_expression, # term expression
536 1, # term size
537 0 # Term index in the expression
538 )
540 """
541 cls._declares_term_expression(
542 variable_id, expr, cls.__INTERNAL_EXPRESSION_KEY, size, index
543 )
545 @property
546 def internal_variables(cls) -> list[TermDefinition]:
547 """Get the :class:`~physioblocks.computing.models.TermDefinition`
548 object describing **internal Variables**
550 :return: the internal variables term definitions
551 :rtype: list[TermDefinition]
552 """
553 return cls._get_all_terms(cls.__INTERNAL_EXPRESSION_KEY)
555 @property
556 def internal_expressions(cls) -> list[ExpressionDefinition]:
557 """Get the all :class:`~physioblocks.computing.models.Expression` object
558 describing **Internal Equations** of the model component.
560 :return: the internal equation expressions
561 :rtype: list[ExpressionDefinition]
562 """
563 return cls._get_all_expressions(cls.__INTERNAL_EXPRESSION_KEY)
565 def get_internal_variable_expression(
566 cls, term_id: str
567 ) -> tuple[Expression, int, int]:
568 """Get the :class:`~physioblocks.computing.models.Expression` for the given
569 **Internal Variable** local name.
571 :param term_id: the term id
572 :type term_id: str
574 :return: the expression, its size and the starting index of the
575 term in the expression.
576 :rtype: tuple[Expression, int, int]
577 """
578 return cls._get_term_expression(term_id, cls.__INTERNAL_EXPRESSION_KEY)
580 def has_internal_variable(cls, variable_id: str) -> bool:
581 """Get if the given name match an **Internal Variable** of the
582 model component
584 :param variable_id: the id to test
585 :type variable_id: str
587 :return: True if the id defines an **Internal Variable**, False otherwise
588 :rtype: bool
589 """
590 return cls._has_term_defined(variable_id, cls.__INTERNAL_EXPRESSION_KEY)
592 def declares_saved_quantity_expression(
593 cls, term_id: str, expr: Expression, size: int | None = None, index: int = 0
594 ) -> None:
595 """
596 Add a **Saved Quantity** :class:`~physioblocks.computing.models.Expression`
597 object to the model definition.
599 :param term_id: the local id of the term.
600 :type term_id: str
602 :param expr: the associated expression
603 :type expr: Expression
605 :param size: the term size
606 :type size: int
608 :param index: the starting line index of the term in the expression.
609 :type index: str
611 Example
612 ^^^^^^^
614 .. code:: python
616 @dataclass
617 class SimpleModel(metaclass=ModelComponentMetaClass):
619 x1: Quantity
621 def x1_squared(self):
622 return x1.current * x1.current
624 x1_squared_expression = Expression(
625 1,
626 SimpleModel.x1_squared
627 )
628 SimpleModel.declares_saved_quantity_expression(
629 "x1_squared", # term id
630 x1_squared_expression, # term expression
631 1, # term size
632 0 # Term index in the expression
633 )
634 """
635 cls._declares_term_expression(
636 term_id, expr, cls.__SAVED_QUANTITIES_EXPRESSION_KEY, size, index
637 )
639 @property
640 def saved_quantities(cls) -> list[TermDefinition]:
641 """Get the saved quantities expressed by the model
643 :return: the saved quantities local id and size.
644 :rtype: list[tuple[str, int]]
645 """
646 return cls._get_all_terms(cls.__SAVED_QUANTITIES_EXPRESSION_KEY)
648 def has_saved_quantity(cls, saved_quantity_id: str) -> bool:
649 """Get if the given id is a saved quantity
651 :param saved_quantity_id: the id to test
652 :type saved_quantity_id: str
654 :return: True if the id defines a saved quantity, false otherwise
655 :rtype: bool
656 """
657 return cls._has_term_defined(
658 saved_quantity_id, cls.__SAVED_QUANTITIES_EXPRESSION_KEY
659 )
661 @property
662 def saved_quantities_expressions(cls) -> list[ExpressionDefinition]:
663 """Get the all saved quantities expressions
665 :return: the saved quantities expressions
666 :rtype: list[ExpressionDefinition]
667 """
668 return cls._get_all_expressions(cls.__SAVED_QUANTITIES_EXPRESSION_KEY)
670 def get_saved_quantity_expression(cls, term_id: str) -> tuple[Expression, int, int]:
671 """Get the expression for the given saved quantity local id.
673 :param term_id: the term id
674 :type term_id: str
676 :return: the expression, the starting index of the term in the expression
677 and its size.
678 :rtype: tuple[Expression, int, int]
679 """
680 return cls._get_term_expression(term_id, cls.__SAVED_QUANTITIES_EXPRESSION_KEY)
class ModelComponent(metaclass=ModelComponentMetaClass):
    """
    Holds parameters and defines functions to compute
    **Internal Equations** and **Saved Quantities**.
    """

    def initialize(self) -> None:
        """Initialize the model component.

        The base implementation does nothing. Override this method to run
        model-specific initialization.
        """
class BlockMetaClass(ModelComponentMetaClass):
    """Meta-class for :class:`~physioblocks.computing.models.Block`.

    Extends :class:`~physioblocks.computing.models.ModelComponentMetaClass` type adding
    **Flux** :class:`~physioblocks.computing.models.Expression` to the model definition.

    * Every **Flux** is expressed toward the outside of the **Block**.
    * Every **Local Node** index of the **Block** defines one **Flux**.

    .. note::

        :class:`~physioblocks.computing.models.BlockMetaClass` can also define
        **Internal Equations** and **Saved Quantities**
    """

    _fluxes: dict[int, ExpressionDefinition]
    """Stores the flux expression definitions, keyed by local node index"""

    def __init__(cls, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # A fresh flux registry for every created block class.
        cls._fluxes = {}

    def declares_flux_expression(
        cls, node_index: int, variable_id: str, expr: Expression
    ) -> None:
        """
        Add a flux expression defining a block external relation.

        The flux defines a single term, ``variable_id``, spanning the whole
        expression.

        :param node_index: the local node index where the flux is shared
        :type node_index: int

        :param variable_id: the local id of the variable associated to the node.
        :type variable_id: str

        :param expr: the associated expression
        :type expr: Expression

        :raise ValueError: if a flux is already defined at ``node_index``

        Example
        ^^^^^^^

        .. code:: python

            @dataclass
            class SimpleBlock(metaclass=BlockMetaClass):

                q0: Quantity

                def flux_0(self):
                    return self.q0.new

                def dflux_0_dq0(self):
                    return 1.0

            flux_0_expression = Expression(
                1,
                SimpleBlock.flux_0,
                {
                    "q0": SimpleBlock.dflux_0_dq0
                }
            )

            SimpleBlock.declares_flux_expression(
                0,  # Local Node index,
                "potential_0",  # Associated DOF id
                flux_0_expression,  # flux expression
            )
        """
        if node_index in cls._fluxes:
            existing_func = cls._fluxes[node_index].expression.expr_func
            raise ValueError(
                str.format(
                    "Flux {0} is already defined for the block node at index {1}.",
                    # Not every callable has a __name__ (functools.partial,
                    # callable instances): fall back to its repr instead of
                    # masking this error with an AttributeError.
                    getattr(existing_func, "__name__", repr(existing_func)),
                    node_index,
                )
            )

        cls._fluxes[node_index] = ExpressionDefinition(
            expr, [TermDefinition(variable_id, expr.size)]
        )

    @property
    def nodes(cls) -> list[int]:
        """Get all the local nodes indexes, in declaration order.

        :return: The list of indexes.
        :rtype: list[int]
        """
        return list(cls._fluxes)

    @property
    def fluxes_expressions(cls) -> dict[int, ExpressionDefinition]:
        """Get all the fluxes expressions in the block with the local node where they
        are shared.

        :return: a shallow copy of the fluxes expressions keyed by node index.
        :rtype: dict[int, ExpressionDefinition]
        """
        return cls._fluxes.copy()

    @property
    def external_variables_ids(cls) -> list[str]:
        """
        Get local id of variables defined by the flux connecting to a node in the block.

        :return: a list of all local external variables ids.
        :rtype: list[str]
        """
        return [
            term.term_id for expr_def in cls._fluxes.values() for term in expr_def.terms
        ]
class Block(ModelComponent, metaclass=BlockMetaClass):
    """
    Base class for blocks: a :class:`~physioblocks.computing.models.ModelComponent`
    that can additionally declare functions computing **Flux**.
    """