Coverage for tests / tests_computing / test_models.py: 99%

158 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-09 16:40 +0100

1# SPDX-FileCopyrightText: Copyright INRIA 

2# 

3# SPDX-License-Identifier: LGPL-3.0-only 

4# 

5# Copyright INRIA 

6# 

7# This file is part of PhysioBlocks, a library mostly developed by the 

8# [Ananke project-team](https://team.inria.fr/ananke) at INRIA. 

9# 

10# Authors: 

11# - Colin Drieu 

12# - Dominique Chapelle 

13# - François Kimmig 

14# - Philippe Moireau 

15# 

16# PhysioBlocks is free software: you can redistribute it and/or modify it under the 

17# terms of the GNU Lesser General Public License as published by the Free Software 

18# Foundation, version 3 of the License. 

19# 

20# PhysioBlocks is distributed in the hope that it will be useful, but WITHOUT ANY 

21# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 

22# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. 

23# 

24# You should have received a copy of the GNU Lesser General Public License along with 

25# PhysioBlocks. If not, see <https://www.gnu.org/licenses/>. 

26 

27from typing import Any 

28from unittest.mock import patch 

29 

30import pytest 

31 

32from physioblocks.computing.models import ( 

33 BlockMetaClass, 

34 Expression, 

35 ExpressionDefinition, 

36 ModelComponentMetaClass, 

37 TermDefinition, 

38) 

39from physioblocks.computing.quantities import Quantity 

40 

# Identifiers of the terms used on the test model components.
TERM_A_ID: str = "a"
TERM_B_ID: str = "b"
TERM_X_ID: str = "x"
TERM_Y_ID: str = "y"
TERM_Z_ID: str = "z"

# Identifier that is never declared, used for negative lookups.
UNDEFINED_TERM_ID: str = "undefined"

# Flux type label (not referenced by the tests shown here) and the
# external variable id used for block flux declarations.
FLUX_TYPE: str = "flux"
DOF_ID: str = "dof"

49 

50 

def func():
    """Dummy expression function returning 0.

    Kept as a named ``def``: its ``__name__`` ("func") is part of an expected
    error message in the block tests.
    """
    result = 0
    return result

53 

54 

@pytest.fixture
def grads():
    """Gradient mapping with a single ``"var"`` entry pointing to ``func``."""
    gradients = {"var": func}
    return gradients

58 

59 

@pytest.fixture
def term_definition():
    """Term ``DOF_ID`` of size 2, using the default starting index."""
    size = 2
    return TermDefinition(DOF_ID, size)

63 

64 

@pytest.fixture
def expression():
    """Fresh size-2 expression computed by ``func`` (no gradients)."""
    size = 2
    return Expression(size, func)

68 

69 

@pytest.fixture
def other_expression():
    """Second, distinct size-2 expression computed by ``func``."""
    size = 2
    return Expression(size, func)

73 

74 

@pytest.fixture
def expression_def(expression: Expression, term_definition: TermDefinition):
    """Definition pairing ``expression`` with the single ``term_definition``."""
    terms = [term_definition]
    return ExpressionDefinition(expression, terms)

78 

79 

class TestExpression:
    """Construction, read-only attributes and equality of ``Expression``."""

    def test_constructor(self, grads):
        expr = Expression(1, func, grads)
        assert expr.size == 1
        assert expr.expr_func is func
        assert expr.expr_gradients == grads

    def test_set(self, grads):
        expr = Expression(1, func, grads)

        # The public attributes cannot be reassigned.
        with pytest.raises(AttributeError):
            expr.size = 3

        with pytest.raises(AttributeError):
            expr.expr_func = None

        with pytest.raises(AttributeError):
            expr.expr_gradients = {}

        # Mutating the returned mapping must not alter the expression.
        # Compare against a fresh literal rather than ``grads``: if the
        # expression handed back the very dict it was given, the mutation
        # would also change ``grads`` and that comparison would pass trivially.
        expr.expr_gradients["var"] = None
        assert expr.expr_gradients == {"var": func}

    def test_eq(self, grads):
        expr_1 = Expression(1, func, grads)
        expr_2 = Expression(1, func, grads)
        expr_3 = Expression(1, func)  # same size/func, no gradients
        expr_4 = Expression(2, func, grads)  # different size

        assert expr_1 == expr_1
        assert expr_1 == expr_2
        assert expr_1 != expr_3
        assert expr_1 != expr_4

112 

113 

class TestTermDefinition:
    """Equality semantics of ``TermDefinition``."""

    def test_eq(self):
        reference = TermDefinition(TERM_A_ID, 1)
        other_id = TermDefinition(TERM_B_ID, 1)
        same_id_other_size = TermDefinition(TERM_A_ID, 3)

        # A different identifier makes the terms unequal...
        assert reference != other_id
        # ...while a different size alone does not.
        assert reference == same_id_other_size

122 

123 

class TestExpressionDefinition:
    """Validation and term lookup of ``ExpressionDefinition``."""

    def test_valid(
        self,
        expression_def: ExpressionDefinition,
        term_definition: TermDefinition,
        expression: Expression,
    ):
        # valid expression with one term
        assert expression_def.valid is True

        # valid expression with two terms
        valid_expr = ExpressionDefinition(
            expression, [TermDefinition("a", 1, 0), TermDefinition("b", 1, 1)]
        )
        assert valid_expr.valid is True

        # valid expression with unsorted terms
        valid_expr = ExpressionDefinition(
            expression, [TermDefinition("a", 1, 1), TermDefinition("b", 1, 0)]
        )
        assert valid_expr.valid is True

        # expression with too many terms
        invalid_expr = ExpressionDefinition(
            expression,
            [term_definition, term_definition],
        )
        assert invalid_expr.valid is False

        # expression with no terms
        invalid_expr = ExpressionDefinition(
            expression,
            [],
        )
        assert invalid_expr.valid is False

        # expression with terms too small
        invalid_expr = ExpressionDefinition(
            expression,
            [TermDefinition(DOF_ID, 1)],
        )
        assert invalid_expr.valid is False

        # invalid expression with repeating indexes (both use the default index)
        invalid_expr = ExpressionDefinition(
            expression, [TermDefinition("a", 1), TermDefinition("b", 1)]
        )
        assert invalid_expr.valid is False

        # invalid expression with indexes not starting at zero
        invalid_expr = ExpressionDefinition(
            expression, [TermDefinition("a", 1, 1), TermDefinition("b", 1, 2)]
        )
        assert invalid_expr.valid is False

    def test_get_term(
        self,
        expression_def: ExpressionDefinition,
        term_definition: TermDefinition,
        expression: Expression,
    ):
        # one term expression definition
        assert expression_def.get_term(0) == term_definition
        missing_index = 1
        with pytest.raises(
            KeyError, match=f"No term starts at index {missing_index} in expression"
        ):
            expression_def.get_term(missing_index)

        # two terms expression definition
        term_a = TermDefinition("a", 1, 0)
        term_b = TermDefinition("b", 1, 1)
        two_term_expr_definition = ExpressionDefinition(expression, [term_a, term_b])
        assert two_term_expr_definition.get_term(0) == term_a
        assert two_term_expr_definition.get_term(1) == term_b

        missing_index = 2
        with pytest.raises(
            KeyError, match=f"No term starts at index {missing_index} in expression"
        ):
            two_term_expr_definition.get_term(missing_index)

205 

206 

class ModelComponentTest(metaclass=ModelComponentMetaClass):
    # Quantity-annotated attributes. TestModelComponentMetaClass expects
    # local_ids to start with them in this exact order (a, x, y, z), so the
    # declaration order must not change.
    a: Quantity
    x: Quantity
    y: Quantity[Any]
    z: Quantity
    # Attributes of other types are not collected as local ids.
    constant: float
    parameter: str

214 

215 

class TestModelComponentMetaClass:
    """Declarations and lookups on a ``ModelComponentMetaClass`` class.

    NOTE(review): both tests mutate class-level state of ``ModelComponentTest``
    and rely on pytest running them in file order.
    """

    def test_declarations(self, expression: Expression, other_expression: Expression):
        ModelComponentTest.declares_saved_quantity_expression(
            TERM_B_ID, expression, 1, 0
        )
        ModelComponentTest.declares_internal_expression(TERM_X_ID, expression, 1, 0)
        ModelComponentTest.declares_internal_expression(TERM_Y_ID, expression, 1, 1)
        # Size omitted: the term covers the whole expression (size 2).
        ModelComponentTest.declares_internal_expression(
            TERM_Z_ID, other_expression, index=0
        )

        # ``b`` is not annotated on the class; declaring its saved-quantity
        # expression adds it after the annotated ids.
        assert ModelComponentTest.local_ids == [
            TERM_A_ID,
            TERM_X_ID,
            TERM_Y_ID,
            TERM_Z_ID,
            TERM_B_ID,
        ]

        assert ModelComponentTest.internal_variables == [
            TermDefinition(TERM_X_ID, 1, 0),
            TermDefinition(TERM_Y_ID, 1, 1),
            TermDefinition(TERM_Z_ID, 2, 0),
        ]
        assert ModelComponentTest.has_saved_quantity(TERM_B_ID) is True
        assert ModelComponentTest.has_internal_variable(TERM_X_ID) is True
        assert ModelComponentTest.has_internal_variable(TERM_Y_ID) is True
        assert ModelComponentTest.has_internal_variable(TERM_Z_ID) is True
        assert ModelComponentTest.has_internal_variable(UNDEFINED_TERM_ID) is False
        assert ModelComponentTest.has_saved_quantity(UNDEFINED_TERM_ID) is False

        assert ModelComponentTest.saved_quantities == [TermDefinition(TERM_B_ID, 1, 0)]

        # Terms sharing one expression object are grouped together.
        assert len(ModelComponentTest.internal_expressions) == 2
        assert ModelComponentTest.internal_expressions[0].expression is expression
        assert ModelComponentTest.internal_expressions[0].terms == [
            TermDefinition(TERM_X_ID, 1, 0),
            TermDefinition(TERM_Y_ID, 1, 1),
        ]
        assert ModelComponentTest.internal_expressions[1].expression is other_expression
        assert ModelComponentTest.internal_expressions[1].terms == [
            TermDefinition(TERM_Z_ID, 2, 0),
        ]

        assert len(ModelComponentTest.saved_quantities_expressions) == 1
        assert (
            ModelComponentTest.saved_quantities_expressions[0].expression == expression
        )
        # Same term as declared above: size 1 starting at index 0.
        assert ModelComponentTest.saved_quantities_expressions[0].terms == [
            TermDefinition(TERM_B_ID, 1, 0)
        ]

        # Lookups return (expression, size, index).
        assert ModelComponentTest.get_internal_variable_expression(TERM_X_ID) == (
            expression,
            1,
            0,
        )
        assert ModelComponentTest.get_internal_variable_expression(TERM_Y_ID) == (
            expression,
            1,
            1,
        )
        assert ModelComponentTest.get_internal_variable_expression(TERM_Z_ID) == (
            other_expression,
            2,
            0,
        )
        assert ModelComponentTest.get_saved_quantity_expression(TERM_B_ID) == (
            expression,
            1,
            0,
        )

    def test_exceptions(
        self,
        expression: Expression,
    ):
        # No expression has been declared for ``a`` yet.
        with pytest.raises(KeyError, match=f"No expression defined for {TERM_A_ID}."):
            ModelComponentTest.get_internal_variable_expression(TERM_A_ID)

        # Declaring twice for the same id is rejected. The first declaration is
        # kept outside ``pytest.raises`` so it is required to succeed; only the
        # second one may raise.
        ModelComponentTest.declares_internal_expression(TERM_A_ID, expression, 1, 0)
        with pytest.raises(
            KeyError, match=f"An expression is already defined for {TERM_A_ID}."
        ):
            ModelComponentTest.declares_internal_expression(TERM_A_ID, expression, 1, 0)

        # Terms that do not fit inside the expression are rejected:
        # too large from index 0, and starting past the end.
        size_error = (
            "{0} definition of size {1} starting at index {2} exceed expression "
            "size {3}"
        )
        for size, index in ((3, 0), (1, 2)):
            with pytest.raises(
                ValueError,
                match=size_error.format(TERM_X_ID, size, index, expression.size),
            ):
                ModelComponentTest.declares_internal_expression(
                    TERM_X_ID, expression, size, index
                )

333 

334 

class BlockTest(metaclass=BlockMetaClass):
    # Declaration order matters: TestBlockMetaClass expects
    # local_ids == [x, a].
    x: Quantity
    a: Quantity

338 

339 

class TestBlockMetaClass:
    """Flux declarations on a ``BlockMetaClass`` class."""

    def test_declares_flux(
        self, expression: Expression, expression_def: ExpressionDefinition
    ):
        BlockTest.declares_flux_expression(0, DOF_ID, expression)

        assert BlockTest.nodes == [0]
        assert BlockTest.local_ids == [TERM_X_ID, TERM_A_ID]
        # The flux's variable id is recorded as an external variable.
        assert BlockTest.external_variables_ids == [DOF_ID]
        assert BlockTest.fluxes_expressions == {0: expression_def}
        assert BlockTest.fluxes_expressions[0] == expression_def

    def test_exceptions(
        self, expression: Expression, expression_def: ExpressionDefinition
    ):
        error_message = str.format(
            "Flux {0} is already defined for the block node at index {1}.",
            func.__name__,
            0,
        )
        # Pre-populate the private flux registry so that node 0 already has a
        # flux, without depending on state left by other tests.
        with (
            pytest.raises(ValueError, match=error_message),
            patch.object(BlockTest, attribute="_fluxes", new={0: expression_def}),
        ):
            BlockTest.declares_flux_expression(0, DOF_ID, expression)