Coverage for tests / tests_simulation / test_setup.py: 99%

98 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-09 16:40 +0100

1# SPDX-FileCopyrightText: Copyright INRIA 

2# 

3# SPDX-License-Identifier: LGPL-3.0-only 

4# 

5# Copyright INRIA 

6# 

7# This file is part of PhysioBlocks, a library mostly developed by the 

8# [Ananke project-team](https://team.inria.fr/ananke) at INRIA. 

9# 

10# Authors: 

11# - Colin Drieu 

12# - Dominique Chapelle 

13# - François Kimmig 

14# - Philippe Moireau 

15# 

16# PhysioBlocks is free software: you can redistribute it and/or modify it under the 

17# terms of the GNU Lesser General Public License as published by the Free Software 

18# Foundation, version 3 of the License. 

19# 

20# PhysioBlocks is distributed in the hope that it will be useful, but WITHOUT ANY 

21# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 

22# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. 

23# 

24# You should have received a copy of the GNU Lesser General Public License along with 

25# PhysioBlocks. If not, see <https://www.gnu.org/licenses/>. 

26 

import re
from dataclasses import dataclass
from typing import Any
from unittest.mock import patch

30 

31import pytest 

32 

33import physioblocks.simulation.setup as setup 

34from physioblocks.computing.models import Block, Expression 

35from physioblocks.computing.quantities import Quantity 

36from physioblocks.description.blocks import ID_SEPARATOR, BlockDescription 

37from physioblocks.description.nets import ( 

38 Net, 

39) 

40from physioblocks.simulation.runtime import AbstractSimulation 

41from physioblocks.simulation.solvers import AbstractSolver 

42from physioblocks.simulation.state import State 

43from physioblocks.simulation.time_manager import TIME_QUANTITY_ID 

44 

# --- Identifiers of the test net topology -------------------------------
NODE_0_ID = "node_0"
NODE_1_ID = "node_1"
BLOCK_ID = "block"
# NOTE(review): not referenced by the tests visible in this module.
SUBBLOCK_ID = "subblock"

# --- Flux / DOF types (the registers are patched with these in the tests) -
FLUX_TYPE_ID = "flux_type"
DOF_TYPE_ID = "dof_type"
# Global id given to the flux boundary condition imposed on node_0.
FLUX_ID = "flux"

# --- Local quantity ids of ``BlockTest`` --------------------------------
DOF_P1_ID = "p1"
DOF_P2_ID = "p2"
INTERNAL_VARIABLE_LOCAL_ID = "internal_variable"

# --- Global ids the block's local quantities are mapped to --------------
INTERNAL_VAR_ID = "var_id"
# DOF of each node: "<node id><ID_SEPARATOR><dof type>".
DOF_0_ID = ID_SEPARATOR.join((NODE_0_ID, DOF_TYPE_ID))
DOF_1_ID = ID_SEPARATOR.join((NODE_1_ID, DOF_TYPE_ID))

58 

59 

@dataclass
class BlockTest(Block):
    """Minimal block used to exercise the simulation setup helpers.

    Its flux and internal expressions are declared on the class at module
    level, right after ``empty_func`` is defined.
    """

    # Local quantity "p1": target of the flux expression declared at index 0.
    p1: Quantity[Any]
    # Local quantity "p2": target of the flux expression declared at index 1.
    p2: Quantity[Any]
    # Local quantity of the internal expression (INTERNAL_VARIABLE_LOCAL_ID).
    internal_variable: Quantity[Any]

65 

66 

def empty_func(model: BlockTest):
    """Placeholder expression function: ignore ``model`` and return ``None``.

    Used for every expression of ``BlockTest``; the setup tests only build
    systems from these expressions.
    """
    return None

69 

70 

# Declare the expressions of BlockTest on the class, all backed by the no-op
# ``empty_func``:
#   - a flux expression on "p1" at index 0 and on "p2" at index 1 (presumably
#     the same indices the net maps to node_0 / node_1 in ``add_block`` — confirm);
#   - an internal expression on "internal_variable".
# NOTE(review): the leading ``1`` of each Expression is presumably its size — confirm.
BlockTest.declares_flux_expression(0, DOF_P1_ID, Expression(1, empty_func))
BlockTest.declares_flux_expression(1, DOF_P2_ID, Expression(1, empty_func))
BlockTest.declares_internal_expression(
    INTERNAL_VARIABLE_LOCAL_ID, Expression(1, empty_func)
)

76 

77 

@pytest.fixture
@patch.multiple(
    "physioblocks.description.nets._flux_type_register",
    create=True,
    _fluxes_types={FLUX_TYPE_ID: DOF_TYPE_ID},
    _dof_types={DOF_TYPE_ID: FLUX_TYPE_ID},
)
def net():
    """Build a two-node net holding a single ``BlockTest`` block.

    The block's local quantities "p1" / "p2" are mapped to the DOFs of
    node_0 / node_1, and its internal variable to ``INTERNAL_VAR_ID``. The flux
    type register of ``physioblocks.description.nets`` is patched while the
    net is being built. No boundary condition is set here; tests add their own.
    """
    test_net = Net()
    for node_id in (NODE_0_ID, NODE_1_ID):
        test_net.add_node(node_id)

    block_description = BlockDescription(
        BLOCK_ID,
        BlockTest,
        FLUX_TYPE_ID,
        global_ids={
            DOF_P1_ID: DOF_0_ID,
            DOF_P2_ID: DOF_1_ID,
            INTERNAL_VARIABLE_LOCAL_ID: INTERNAL_VAR_ID,
        },
    )
    # Local node indices 0 / 1 of the block are attached to node_0 / node_1.
    test_net.add_block(BLOCK_ID, block_description, {0: NODE_0_ID, 1: NODE_1_ID})
    return test_net

107 

108 

@patch.multiple(
    "physioblocks.simulation.setup._flux_dof_register",
    create=True,
    _fluxes_types={FLUX_TYPE_ID: DOF_TYPE_ID},
    _dof_types={DOF_TYPE_ID: FLUX_TYPE_ID},
)
class TestSetupMethods:
    """Tests of the ``build_*`` helpers of ``physioblocks.simulation.setup``.

    The flux/DOF register of the setup module is patched for every ``test*``
    method of the class.
    """

    @staticmethod
    def _set_boundaries(test_net: Net) -> None:
        # Impose the DOF on node_1 and the flux on node_0.
        test_net.set_boundary(NODE_1_ID, DOF_TYPE_ID, DOF_1_ID)
        test_net.set_boundary(NODE_0_ID, FLUX_TYPE_ID, FLUX_ID)

    def test_build_state(self, net: Net):
        # Without boundary conditions, every mapped quantity is expected in the state.
        initial_state = setup.build_state(net)
        for variable_id in (INTERNAL_VAR_ID, DOF_0_ID, DOF_1_ID):
            assert variable_id in initial_state

        # Once node_1's DOF is imposed, it must leave the state.
        self._set_boundaries(net)
        bounded_state = setup.build_state(net)
        assert INTERNAL_VAR_ID in bounded_state
        assert DOF_0_ID in bounded_state
        assert DOF_1_ID not in bounded_state

    def test_build_parameters(self, net: Net):
        self._set_boundaries(net)

        parameters = setup.build_parameters(net, setup.build_state(net))

        # Imposed boundary values are parameters...
        assert DOF_1_ID in parameters
        assert FLUX_ID in parameters
        # ...while state variables and the time quantity are not.
        for excluded_id in (DOF_0_ID, TIME_QUANTITY_ID, INTERNAL_VAR_ID):
            assert excluded_id not in parameters

    def test_build_eq_system(self):
        expression = Expression(1, empty_func, {INTERNAL_VAR_ID: empty_func})
        state = State()
        state.add_variable(INTERNAL_VAR_ID, 0.0)

        eq_system = setup.build_eq_system([(0, expression, None)], state)
        assert eq_system.system_size == 1

151 

152 

class TestSimulationFactory:
    """Tests of ``physioblocks.simulation.setup.SimulationFactory``."""

    @patch.multiple(
        "physioblocks.simulation.setup._flux_dof_register",
        create=True,
        _fluxes_types={FLUX_TYPE_ID: DOF_TYPE_ID},
        _dof_types={DOF_TYPE_ID: FLUX_TYPE_ID},
    )
    @patch.multiple(AbstractSolver, __abstractmethods__=set())
    @patch.multiple(AbstractSimulation, __abstractmethods__=set())
    def test_create_simulation(self, net: Net):
        # The shared ``net`` fixture builds the same two-node net this test
        # used to rebuild inline.
        net.set_boundary(NODE_1_ID, DOF_TYPE_ID, DOF_1_ID)
        net.set_boundary(NODE_0_ID, FLUX_TYPE_ID, FLUX_ID)

        sim_factory = setup.SimulationFactory(
            AbstractSimulation,
            AbstractSolver(),
            net,
        )
        sim = sim_factory.create_simulation()
        # With node_1's DOF imposed, the state presumably keeps DOF_0_ID and
        # INTERNAL_VAR_ID (see TestSetupMethods.test_build_state).
        assert sim.state.size == 2

    @patch.multiple(AbstractSolver, __abstractmethods__=set())
    def test_wrong_simulation_type(self):
        err_message = (
            f"{object.__name__} is not a {AbstractSimulation.__name__} sub-class."
        )
        # ``match`` is applied with re.search, so the literal message is escaped
        # (its trailing "." would otherwise match any character). A single
        # statement is used so that no line is left unreachable, whether the
        # TypeError comes from the constructor or from create_simulation().
        with pytest.raises(TypeError, match=re.escape(err_message)):
            setup.SimulationFactory(object, AbstractSolver()).create_simulation()