Coverage for tests / tests_simulation / test_solvers.py: 100%

60 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-09 16:40 +0100

1# SPDX-FileCopyrightText: Copyright INRIA 

2# 

3# SPDX-License-Identifier: LGPL-3.0-only 

4# 

5# Copyright INRIA 

6# 

7# This file is part of PhysioBlocks, a library mostly developed by the 

8# [Ananke project-team](https://team.inria.fr/ananke) at INRIA. 

9# 

10# Authors: 

11# - Colin Drieu 

12# - Dominique Chapelle 

13# - François Kimmig 

14# - Philippe Moireau 

15# 

16# PhysioBlocks is free software: you can redistribute it and/or modify it under the 

17# terms of the GNU Lesser General Public License as published by the Free Software 

18# Foundation, version 3 of the License. 

19# 

20# PhysioBlocks is distributed in the hope that it will be useful, but WITHOUT ANY 

21# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 

22# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. 

23# 

24# You should have received a copy of the GNU Lesser General Public License along with 

25# PhysioBlocks. If not, see <https://www.gnu.org/licenses/>. 

26 

27from unittest.mock import Mock, patch 

28 

29import numpy as np 

30import pytest 

31 

32from physioblocks.computing.assembling import EqSystem 

33from physioblocks.simulation.solvers import AbstractSolver, NewtonSolver 

34from physioblocks.simulation.state import State 

35 

36 

def mock_residual_converge():
    """Return a residual of two zeros, i.e. an already-converged system."""
    residual = np.full(2, 0.0)
    return residual

41 

42 

def mock_residual_dont_converge():
    """Return a residual of two ones, which never drops below a small tolerance."""
    residual = np.full(2, 1.0)
    return residual

47 

48 

def mock_residual_nan():
    """Return a residual of two NaN values, used to check divergence handling.

    Returns:
        A float64 array of shape ``(2,)`` filled with ``np.nan``.
    """
    # np.full allocates and fills in one step instead of np.empty + slice assignment.
    return np.full(2, np.nan)

55 

56 

def mock_residual_inf():
    """Return a residual of two +inf values, used to check divergence handling.

    Returns:
        A float64 array of shape ``(2,)`` filled with ``np.inf``.
    """
    # np.full allocates and fills in one step instead of np.empty + slice assignment.
    return np.full(2, np.inf)

63 

64 

def mock_gradient():
    """Return the 2x2 integer identity matrix used as a stub gradient."""
    identity = np.identity(2, dtype=int)
    return identity

69 

70 

@pytest.fixture
def system() -> EqSystem:
    """Provide a fresh ``EqSystem(2)`` for each test.

    The argument 2 presumably matches the two variables of the ``state``
    fixture -- confirm against ``EqSystem``'s signature.
    """
    size = 2
    return EqSystem(size)

74 

75 

@pytest.fixture
def state() -> State:
    """Provide a ``State`` holding variables ``x0`` and ``x1``, both set to 0.0."""
    # A distinct local name avoids shadowing the fixture function itself.
    initial_state = State()
    for variable_name in ("x0", "x1"):
        initial_state.add_variable(variable_name, 0.0)
    return initial_state

82 

83 

@pytest.fixture
def magnitudes() -> dict[str, float]:
    """Map each state variable name to the magnitude handed to ``solve``."""
    return dict(x0=1.0, x1=2.0)

87 

88 

class TestAbstractSolver:
    """Check the attributes stored by ``AbstractSolver``'s constructor."""

    def test_constructor(self):
        # Emptying __abstractmethods__ lets the abstract base class be
        # instantiated directly for the duration of the patch.
        with patch.multiple(AbstractSolver, __abstractmethods__=set()):
            solver = AbstractSolver(1e-12, 2)
            assert solver.tolerance == pytest.approx(1e-12)
            assert solver.iteration_max == 2

95 

96 

class TestNewtonSolver:
    """Check ``NewtonSolver.solve`` with ``EqSystem``'s residual and gradient stubbed."""

    _TOLERANCE = 1e-9
    _MAX_ITERATIONS = 10

    @staticmethod
    def _patched_system(residual):
        """Return a patch context stubbing ``EqSystem``'s residual and gradient.

        ``compute_residual`` always returns ``residual`` and
        ``compute_gradient`` always returns the identity from ``mock_gradient``.
        """
        return patch.multiple(
            EqSystem,
            compute_residual=Mock(return_value=residual),
            compute_gradient=Mock(return_value=mock_gradient()),
        )

    def _new_solver(self):
        """Build the solver configuration shared by every test in this class."""
        return NewtonSolver(self._TOLERANCE, self._MAX_ITERATIONS)

    def test_converge(self, state, system, magnitudes):
        with self._patched_system(mock_residual_converge()):
            sol = self._new_solver().solve(state, system, magnitudes)
            assert sol.converged is True
            assert sol.x == pytest.approx([0, 0])

    def test_dont_converge_sol(self, state, system, magnitudes):
        # A constant non-zero residual can never satisfy the tolerance.
        with self._patched_system(mock_residual_dont_converge()):
            sol = self._new_solver().solve(state, system, magnitudes)
            assert sol.converged is False

    def test_dont_converge_nan(self, state, system, magnitudes):
        with self._patched_system(mock_residual_nan()):
            sol = self._new_solver().solve(state, system, magnitudes)
            assert sol.converged is False

    def test_dont_converge_inf(self, state, system, magnitudes):
        with self._patched_system(mock_residual_inf()):
            sol = self._new_solver().solve(state, system, magnitudes)
            assert sol.converged is False