Coverage for tests / tests_computing / test_models.py: 99%
158 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-09 16:40 +0100
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-09 16:40 +0100
1# SPDX-FileCopyrightText: Copyright INRIA
2#
3# SPDX-License-Identifier: LGPL-3.0-only
4#
5# Copyright INRIA
6#
7# This file is part of PhysioBlocks, a library mostly developed by the
8# [Ananke project-team](https://team.inria.fr/ananke) at INRIA.
9#
10# Authors:
11# - Colin Drieu
12# - Dominique Chapelle
13# - François Kimmig
14# - Philippe Moireau
15#
16# PhysioBlocks is free software: you can redistribute it and/or modify it under the
17# terms of the GNU Lesser General Public License as published by the Free Software
18# Foundation, version 3 of the License.
19#
20# PhysioBlocks is distributed in the hope that it will be useful, but WITHOUT ANY
21# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
22# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
23#
24# You should have received a copy of the GNU Lesser General Public License along with
25# PhysioBlocks. If not, see <https://www.gnu.org/licenses/>.
27from typing import Any
28from unittest.mock import patch
30import pytest
32from physioblocks.computing.models import (
33 BlockMetaClass,
34 Expression,
35 ExpressionDefinition,
36 ModelComponentMetaClass,
37 TermDefinition,
38)
39from physioblocks.computing.quantities import Quantity
# Identifiers of the quantities annotated on the test model classes below.
TERM_A_ID = "a"
TERM_B_ID = "b"
TERM_X_ID, TERM_Y_ID, TERM_Z_ID = "x", "y", "z"
# Identifier that is never declared, used to probe negative lookups.
UNDEFINED_TERM_ID = "undefined"
FLUX_TYPE = "flux"  # NOTE(review): not referenced by the visible tests
# Identifier of the variable used as a term / flux target.
DOF_ID = "dof"
def func():
    """Dummy expression callable used by the fixtures; always yields ``0``."""
    result = 0
    return result
@pytest.fixture
def grads():
    """Gradient mapping with a single ``"var"`` entry pointing at ``func``."""
    return dict(var=func)
@pytest.fixture
def term_definition():
    """Term ``DOF_ID`` spanning 2 entries from the default start index."""
    definition = TermDefinition(DOF_ID, 2)
    return definition
@pytest.fixture
def expression():
    """Size-2 expression evaluated by ``func``, with no gradients supplied."""
    expr = Expression(2, func)
    return expr
@pytest.fixture
def other_expression():
    """Second size-2 expression on ``func``, built like ``expression``.

    It compares equal to ``expression`` but is a separate instance, so tests
    can tell the two apart with ``is`` checks.
    """
    expr = Expression(2, func)
    return expr
@pytest.fixture
def expression_def(expression: Expression, term_definition: TermDefinition):
    """Definition binding ``expression`` to the single ``term_definition`` term."""
    terms = [term_definition]
    return ExpressionDefinition(expression, terms)
class TestExpression:
    """Tests for construction, immutability and equality of ``Expression``."""

    def test_constructor(self, grads):
        expr = Expression(1, func, grads)
        assert expr.size == 1
        assert expr.expr_func == func
        assert expr.expr_gradients == grads

    def test_set(self, grads):
        expr = Expression(1, func, grads)
        # Snapshot taken before any mutation: comparing against ``grads``
        # itself would pass trivially if ``expr_gradients`` handed back the
        # very same dict, since both sides would then be mutated together.
        expected_gradients = dict(grads)

        # Public attributes are read-only.
        with pytest.raises(AttributeError):
            expr.size = 3

        with pytest.raises(AttributeError):
            expr.expr_func = None

        with pytest.raises(AttributeError):
            expr.expr_gradients = {}

        # Mutating the returned mapping must not alter the expression.
        expr.expr_gradients["var"] = None
        assert expr.expr_gradients == expected_gradients

    def test_eq(self, grads):
        expr_1 = Expression(1, func, grads)
        expr_2 = Expression(1, func, grads)
        expr_3 = Expression(1, func)
        expr_4 = Expression(2, func, grads)

        assert expr_1 == expr_1
        assert expr_1 == expr_2
        # Differing gradients or size break equality.
        assert expr_1 != expr_3
        assert expr_1 != expr_4
class TestTermDefinition:
    """Equality semantics of ``TermDefinition``."""

    def test_eq(self):
        # Terms sharing an identifier compare equal even when sizes differ;
        # different identifiers never compare equal.
        first = TermDefinition(TERM_A_ID, 1)
        other_id = TermDefinition(TERM_B_ID, 1)
        same_id_other_size = TermDefinition(TERM_A_ID, 3)

        assert first != other_id
        assert first == same_id_other_size
class TestExpressionDefinition:
    """Tests for validity checks and term lookup of ``ExpressionDefinition``."""

    def test_valid(
        self,
        expression_def: ExpressionDefinition,
        term_definition: TermDefinition,
        expression: Expression,
    ):
        # ``expression`` has size 2. Together, the cases below only accept
        # terms that, in any order, cover the expression exactly once
        # starting from index 0.

        # valid expression with one term
        assert expression_def.valid is True

        # valid expression with two terms
        valid_expr = ExpressionDefinition(
            expression, [TermDefinition("a", 1, 0), TermDefinition("b", 1, 1)]
        )
        assert valid_expr.valid is True

        # valid expression with unsorted terms
        valid_expr = ExpressionDefinition(
            expression, [TermDefinition("a", 1, 1), TermDefinition("b", 1, 0)]
        )
        assert valid_expr.valid is True

        # invalid: too many terms
        invalid_expr = ExpressionDefinition(
            expression,
            [term_definition, term_definition],
        )
        assert invalid_expr.valid is False

        # invalid: no terms
        invalid_expr = ExpressionDefinition(
            expression,
            [],
        )
        assert invalid_expr.valid is False

        # invalid: terms too small to cover the expression
        invalid_expr = ExpressionDefinition(
            expression,
            [TermDefinition(DOF_ID, 1)],
        )
        assert invalid_expr.valid is False

        # invalid: repeating start indexes
        invalid_expr = ExpressionDefinition(
            expression, [TermDefinition("a", 1), TermDefinition("b", 1)]
        )
        assert invalid_expr.valid is False

        # invalid: indexes not starting at zero
        invalid_expr = ExpressionDefinition(
            expression, [TermDefinition("a", 1, 1), TermDefinition("b", 1, 2)]
        )
        assert invalid_expr.valid is False

    def test_get_term(
        self,
        expression_def: ExpressionDefinition,
        term_definition: TermDefinition,
        expression: Expression,
    ):
        # ``match`` performs a regex search, so this message prefix is enough
        # even if the actual error goes on to describe the expression.
        # one term expression definition
        assert expression_def.get_term(0) == term_definition
        err_mess = str.format("No term starts at index {0} in expression", 1)
        with pytest.raises(KeyError, match=err_mess):
            expression_def.get_term(1)

        # two terms expression definition
        term_a = TermDefinition("a", 1, 0)
        term_b = TermDefinition("b", 1, 1)
        two_term_expr_definition = ExpressionDefinition(expression, [term_a, term_b])
        assert two_term_expr_definition.get_term(0) == term_a
        assert two_term_expr_definition.get_term(1) == term_b

        err_mess = str.format("No term starts at index {0} in expression", 2)
        with pytest.raises(KeyError, match=err_mess):
            two_term_expr_definition.get_term(2)
# Component class shared by the ``TestModelComponentMetaClass`` tests.
# ``ModelComponentMetaClass`` derives ``local_ids`` from the ``Quantity``
# annotations below, in declaration order; annotations of other types are not
# collected. The tests mutate this class, so its state persists between them.
class ModelComponentTest(metaclass=ModelComponentMetaClass):
    a: Quantity
    x: Quantity
    y: Quantity[Any]  # parametrized Quantity annotations are collected too
    z: Quantity
    constant: float  # not a local id
    parameter: str  # not a local id
class TestModelComponentMetaClass:
    """Tests for expression declarations made through ``ModelComponentMetaClass``.

    NOTE(review): both tests mutate the module-level ``ModelComponentTest``
    class. ``test_exceptions`` declares an expression for ``TERM_A_ID``, which
    ``test_declarations`` does not expect, and re-declaring a term raises.
    The tests therefore rely on ``test_declarations`` running first and on
    each test running only once per session.
    """

    def test_declarations(self, expression: Expression, other_expression: Expression):
        # One saved quantity (b) plus three internal variables: x and y share
        # ``expression`` at indices 0 and 1; z covers ``other_expression``
        # with its size left to the default.
        ModelComponentTest.declares_saved_quantity_expression(
            TERM_B_ID, expression, 1, 0
        )
        ModelComponentTest.declares_internal_expression(TERM_X_ID, expression, 1, 0)
        ModelComponentTest.declares_internal_expression(TERM_Y_ID, expression, 1, 1)
        ModelComponentTest.declares_internal_expression(
            TERM_Z_ID, other_expression, index=0
        )

        # Annotated quantities first (declaration order), then the saved
        # quantity; ``constant`` and ``parameter`` are not local ids.
        assert ModelComponentTest.local_ids == [
            TERM_A_ID,
            TERM_X_ID,
            TERM_Y_ID,
            TERM_Z_ID,
            TERM_B_ID,
        ]

        # z was declared without a size and spans the whole expression (2).
        assert ModelComponentTest.internal_variables == [
            TermDefinition(TERM_X_ID, 1, 0),
            TermDefinition(TERM_Y_ID, 1, 1),
            TermDefinition(TERM_Z_ID, 2, 0),
        ]
        assert ModelComponentTest.has_saved_quantity(TERM_B_ID) is True
        assert ModelComponentTest.has_internal_variable(TERM_X_ID) is True
        assert ModelComponentTest.has_internal_variable(TERM_Y_ID) is True
        assert ModelComponentTest.has_internal_variable(TERM_Z_ID) is True
        assert ModelComponentTest.has_internal_variable(UNDEFINED_TERM_ID) is False
        assert ModelComponentTest.has_saved_quantity(UNDEFINED_TERM_ID) is False

        assert ModelComponentTest.saved_quantities == [TermDefinition(TERM_B_ID, 1, 0)]

        # x and y are grouped under ``expression``; z, although declared on an
        # equal-valued expression, gets its own entry for ``other_expression``.
        assert len(ModelComponentTest.internal_expressions) == 2
        assert ModelComponentTest.internal_expressions[0].expression is expression
        assert ModelComponentTest.internal_expressions[0].terms == [
            TermDefinition(TERM_X_ID, 1, 0),
            TermDefinition(TERM_Y_ID, 1, 1),
        ]
        assert ModelComponentTest.internal_expressions[1].expression is other_expression
        assert ModelComponentTest.internal_expressions[1].terms == [
            TermDefinition(TERM_Z_ID, 2, 0),
        ]

        assert len(ModelComponentTest.saved_quantities_expressions) == 1
        assert (
            ModelComponentTest.saved_quantities_expressions[0].expression == expression
        )
        # Same term as the one listed in ``saved_quantities`` above.
        assert ModelComponentTest.saved_quantities_expressions[0].terms == [
            TermDefinition(TERM_B_ID, 1, 0)
        ]

        # Lookups return (expression, size, start index).
        assert ModelComponentTest.get_internal_variable_expression(TERM_X_ID) == (
            expression,
            1,
            0,
        )
        assert ModelComponentTest.get_internal_variable_expression(TERM_Y_ID) == (
            expression,
            1,
            1,
        )
        assert ModelComponentTest.get_internal_variable_expression(TERM_Z_ID) == (
            other_expression,
            2,
            0,
        )
        assert ModelComponentTest.get_saved_quantity_expression(TERM_B_ID) == (
            expression,
            1,
            0,
        )

    def test_exceptions(
        self,
        expression: Expression,
    ):
        # Looking up a variable with no declared expression fails.
        error_msg = str.format("No expression defined for {0}.", TERM_A_ID)
        with pytest.raises(KeyError, match=error_msg):
            ModelComponentTest.get_internal_variable_expression(TERM_A_ID)

        # The first declaration must succeed. Keeping it outside
        # ``pytest.raises`` prevents a failure here from being mistaken for
        # the expected duplicate-declaration error below.
        ModelComponentTest.declares_internal_expression(TERM_A_ID, expression, 1, 0)
        error_msg = str.format("An expression is already defined for {0}.", TERM_A_ID)
        with pytest.raises(KeyError, match=error_msg):
            ModelComponentTest.declares_internal_expression(TERM_A_ID, expression, 1, 0)

        # A term may not extend past the end of the (size-2) expression:
        # too large, starting out of range, or starting in range but
        # overflowing.
        for size, index in ((3, 0), (1, 2), (2, 1)):
            error_msg = str.format(
                "{0} definition of size {1} starting at index {2} exceed expression "
                "size {3}",
                TERM_X_ID,
                size,
                index,
                expression.size,
            )
            with pytest.raises(ValueError, match=error_msg):
                ModelComponentTest.declares_internal_expression(
                    TERM_X_ID, expression, size, index
                )
# Block class shared by the ``TestBlockMetaClass`` tests. Its annotated
# quantities become ``local_ids`` in declaration order (x, then a).
class BlockTest(metaclass=BlockMetaClass):
    x: Quantity
    a: Quantity
class TestBlockMetaClass:
    """Tests for flux declarations made through ``BlockMetaClass``."""

    def test_declares_flux(
        self, expression: Expression, expression_def: ExpressionDefinition
    ):
        # Declaring a flux on node 0 registers the node, records ``DOF_ID`` as
        # an external variable and stores a definition equal to
        # ``expression_def`` (the expression with a single DOF_ID term).
        # NOTE(review): this mutates the module-level ``BlockTest`` class.
        BlockTest.declares_flux_expression(0, DOF_ID, expression)

        assert BlockTest.nodes == [0]
        assert BlockTest.local_ids == [TERM_X_ID, TERM_A_ID]
        assert BlockTest.external_variables_ids == [DOF_ID]
        assert BlockTest.fluxes_expressions == {0: expression_def}

    def test_exceptions(
        self, expression: Expression, expression_def: ExpressionDefinition
    ):
        # A flux is patched into ``_fluxes`` for node 0 so that this test does
        # not depend on ``test_declares_flux``; declaring another flux on the
        # same node must then be rejected.
        error_message = str.format(
            "Flux {0} is already defined for the block node at index {1}.",
            func.__name__,
            0,
        )
        with (
            pytest.raises(ValueError, match=error_message),
            patch.object(BlockTest, attribute="_fluxes", new={0: expression_def}),
        ):
            BlockTest.declares_flux_expression(0, DOF_ID, expression)