Coverage for tests / tests_simulation / test_setup.py: 99%
98 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-09 16:40 +0100
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-09 16:40 +0100
1# SPDX-FileCopyrightText: Copyright INRIA
2#
3# SPDX-License-Identifier: LGPL-3.0-only
4#
5# Copyright INRIA
6#
7# This file is part of PhysioBlocks, a library mostly developed by the
8# [Ananke project-team](https://team.inria.fr/ananke) at INRIA.
9#
10# Authors:
11# - Colin Drieu
12# - Dominique Chapelle
13# - François Kimmig
14# - Philippe Moireau
15#
16# PhysioBlocks is free software: you can redistribute it and/or modify it under the
17# terms of the GNU Lesser General Public License as published by the Free Software
18# Foundation, version 3 of the License.
19#
20# PhysioBlocks is distributed in the hope that it will be useful, but WITHOUT ANY
21# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
22# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
23#
24# You should have received a copy of the GNU Lesser General Public License along with
25# PhysioBlocks. If not, see <https://www.gnu.org/licenses/>.
27from dataclasses import dataclass
28from typing import Any
29from unittest.mock import patch
31import pytest
33import physioblocks.simulation.setup as setup
34from physioblocks.computing.models import Block, Expression
35from physioblocks.computing.quantities import Quantity
36from physioblocks.description.blocks import ID_SEPARATOR, BlockDescription
37from physioblocks.description.nets import (
38 Net,
39)
40from physioblocks.simulation.runtime import AbstractSimulation
41from physioblocks.simulation.solvers import AbstractSolver
42from physioblocks.simulation.state import State
43from physioblocks.simulation.time_manager import TIME_QUANTITY_ID
# Identifiers of the net elements used by the tests.
BLOCK_ID = "block"
SUBBLOCK_ID = "subblock"  # not referenced by the tests visible in this file
NODE_0_ID = "node_0"
NODE_1_ID = "node_1"

# Flux / DOF type names installed in the patched registers, and the id given
# to the flux boundary set on node 0.
FLUX_TYPE_ID = "flux_type"
DOF_TYPE_ID = "dof_type"
FLUX_ID = "flux"

# Local (block-side) names of the quantities declared on the test block.
DOF_P1_ID = "p1"
DOF_P2_ID = "p2"
INTERNAL_VARIABLE_LOCAL_ID = "internal_variable"

# Global (net-side) names the local quantities are mapped to.
INTERNAL_VAR_ID = "var_id"
DOF_0_ID = ID_SEPARATOR.join((NODE_0_ID, DOF_TYPE_ID))
DOF_1_ID = ID_SEPARATOR.join((NODE_1_ID, DOF_TYPE_ID))
@dataclass
class BlockTest(Block):
    """Minimal ``Block`` subclass used as the model in these tests.

    It only holds three quantities. Expressions for them are declared at
    module level, right after ``empty_func`` is defined.
    """

    # Quantities declared with a flux expression (local ids "p1" and "p2").
    p1: Quantity[Any]
    p2: Quantity[Any]
    # Quantity declared with an internal expression.
    internal_variable: Quantity[Any]
def empty_func(model: BlockTest):
    """Do nothing.

    Placeholder callable handed to ``Expression`` wherever the tests need
    an expression function but never rely on its result.

    :param model: the block instance (ignored)
    """
    return None
# Declare one placeholder flux expression per local DOF: index 0 goes with
# "p1" and index 1 with "p2". A fresh Expression is built for each call.
for _index, _dof_local_id in enumerate((DOF_P1_ID, DOF_P2_ID)):
    BlockTest.declares_flux_expression(
        _index, _dof_local_id, Expression(1, empty_func)
    )

# Declare the placeholder expression attached to the internal variable.
_internal_expression = Expression(1, empty_func)
BlockTest.declares_internal_expression(
    INTERNAL_VARIABLE_LOCAL_ID, _internal_expression
)
@pytest.fixture
@patch.multiple(
    "physioblocks.description.nets._flux_type_register",
    create=True,
    _fluxes_types={FLUX_TYPE_ID: DOF_TYPE_ID},
    _dof_types={DOF_TYPE_ID: FLUX_TYPE_ID},
)
def net():
    """Build a two-node net holding a single ``BlockTest`` block.

    The ``_flux_type_register`` of ``physioblocks.description.nets`` is
    patched only while this function runs, so the test flux and DOF types
    map to each other during the net assembly.

    :return: the assembled net
    """
    test_net = Net()
    for node_id in (NODE_0_ID, NODE_1_ID):
        test_net.add_node(node_id)

    # Map every local quantity name of the block to its global name.
    description = BlockDescription(
        BLOCK_ID,
        BlockTest,
        FLUX_TYPE_ID,
        global_ids={
            DOF_P1_ID: DOF_0_ID,
            DOF_P2_ID: DOF_1_ID,
            INTERNAL_VARIABLE_LOCAL_ID: INTERNAL_VAR_ID,
        },
    )
    # NOTE(review): presumably local node index -> node id; confirm against
    # Net.add_block.
    nodes_by_index = {0: NODE_0_ID, 1: NODE_1_ID}
    test_net.add_block(BLOCK_ID, description, nodes_by_index)

    return test_net
@patch.multiple(
    "physioblocks.simulation.setup._flux_dof_register",
    create=True,
    _fluxes_types={FLUX_TYPE_ID: DOF_TYPE_ID},
    _dof_types={DOF_TYPE_ID: FLUX_TYPE_ID},
)
class TestSetupMethods:
    """Tests of the module-level builders of ``physioblocks.simulation.setup``.

    The class-level ``patch.multiple`` wraps every ``test*`` method, so the
    ``_flux_dof_register`` of the setup module holds the test flux and DOF
    types while each test runs.
    """

    def test_build_state(self, net: Net):
        """Check the state content before and after boundaries are set."""
        # Without boundaries, both DOFs and the internal variable are in the
        # state.
        free_state = setup.build_state(net)
        for variable_id in (INTERNAL_VAR_ID, DOF_0_ID, DOF_1_ID):
            assert variable_id in free_state

        # With a DOF boundary on node 1, its DOF is no longer in the state.
        net.set_boundary(NODE_1_ID, DOF_TYPE_ID, DOF_1_ID)
        net.set_boundary(NODE_0_ID, FLUX_TYPE_ID, FLUX_ID)
        bounded_state = setup.build_state(net)
        for variable_id in (INTERNAL_VAR_ID, DOF_0_ID):
            assert variable_id in bounded_state
        assert DOF_1_ID not in bounded_state

    def test_build_parameters(self, net: Net):
        """Check which ids end up in the parameters built for the net."""
        net.set_boundary(NODE_1_ID, DOF_TYPE_ID, DOF_1_ID)
        net.set_boundary(NODE_0_ID, FLUX_TYPE_ID, FLUX_ID)

        parameters = setup.build_parameters(net, setup.build_state(net))

        # Both boundary ids are parameters ...
        for boundary_id in (DOF_1_ID, FLUX_ID):
            assert boundary_id in parameters
        # ... while the state variables and the time quantity are not.
        for excluded_id in (DOF_0_ID, TIME_QUANTITY_ID, INTERNAL_VAR_ID):
            assert excluded_id not in parameters

    def test_build_eq_system(self):
        """Check the size of a system built from a single expression."""
        state = State()
        state.add_variable(INTERNAL_VAR_ID, 0.0)

        lone_expression = Expression(1, empty_func, {INTERNAL_VAR_ID: empty_func})
        eq_system = setup.build_eq_system([(0, lone_expression, None)], state)

        assert eq_system.system_size == 1
class TestSimulationFactory:
    """Tests of ``setup.SimulationFactory``."""

    @patch.multiple(
        "physioblocks.simulation.setup._flux_dof_register",
        create=True,
        _fluxes_types={FLUX_TYPE_ID: DOF_TYPE_ID},
        _dof_types={DOF_TYPE_ID: FLUX_TYPE_ID},
    )
    @patch.multiple(AbstractSolver, __abstractmethods__=set())
    @patch.multiple(AbstractSimulation, __abstractmethods__=set())
    def test_create_simulation(self):
        """Create a simulation from a two-node net and check its state size.

        ``__abstractmethods__`` is emptied on both abstract classes so they
        can be instantiated directly during the test.
        """
        net = Net()
        net.add_node(NODE_0_ID)
        net.add_node(NODE_1_ID)

        net.add_block(
            BLOCK_ID,
            BlockDescription(
                BLOCK_ID,
                BlockTest,
                FLUX_TYPE_ID,
                global_ids={
                    DOF_P1_ID: DOF_0_ID,
                    DOF_P2_ID: DOF_1_ID,
                    INTERNAL_VARIABLE_LOCAL_ID: INTERNAL_VAR_ID,
                },
            ),
            {0: NODE_0_ID, 1: NODE_1_ID},
        )

        net.set_boundary(NODE_1_ID, DOF_TYPE_ID, DOF_1_ID)
        net.set_boundary(NODE_0_ID, FLUX_TYPE_ID, FLUX_ID)

        sim_factory = setup.SimulationFactory(
            AbstractSimulation,
            AbstractSolver(),
            net,
        )
        sim = sim_factory.create_simulation()
        # NOTE(review): presumably the node-0 DOF and the internal variable
        # are the two remaining state entries; confirm against build_state.
        assert sim.state.size == 2

    @patch.multiple(AbstractSolver, __abstractmethods__=set())
    def test_wrong_simulation_type(self):
        """Expect a ``TypeError`` for a non ``AbstractSimulation`` type.

        :raises AssertionError: if no ``TypeError`` carrying the expected
          message is raised inside the ``pytest.raises`` block
        """
        import re

        err_message = (
            f"{object.__name__} is not a {AbstractSimulation.__name__} sub-class."
        )
        # ``match`` is a regular expression searched in the error text.
        # Escape the message so "." and "-" are compared literally rather
        # than acting as pattern syntax.
        with pytest.raises(TypeError, match=re.escape(err_message)):
            sim_factory = setup.SimulationFactory(object, AbstractSolver())
            sim_factory.create_simulation()