Coverage for physioblocks / simulation / solvers.py: 99%

85 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-09 16:40 +0100

1# SPDX-FileCopyrightText: Copyright INRIA 

2# 

3# SPDX-License-Identifier: LGPL-3.0-only 

4# 

5# Copyright INRIA 

6# 

7# This file is part of PhysioBlocks, a library mostly developed by the 

8# [Ananke project-team](https://team.inria.fr/ananke) at INRIA. 

9# 

10# Authors: 

11# - Colin Drieu 

12# - Dominique Chapelle 

13# - François Kimmig 

14# - Philippe Moireau 

15# 

16# PhysioBlocks is free software: you can redistribute it and/or modify it under the 

17# terms of the GNU Lesser General Public License as published by the Free Software 

18# Foundation, version 3 of the License. 

19# 

20# PhysioBlocks is distributed in the hope that it will be useful, but WITHOUT ANY 

21# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 

22# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. 

23# 

24# You should have received a copy of the GNU Lesser General Public License along with 

25# PhysioBlocks. If not, see <https://www.gnu.org/licenses/>. 

26 

27"""Declare a generic **Solver** class and solver implementations""" 

28 

29from __future__ import annotations 

30 

31import logging 

32from abc import ABC, abstractmethod 

33from dataclasses import dataclass 

34 

35import numpy as np 

36from numpy.typing import NDArray 

37 

38from physioblocks.computing.assembling import EqSystem 

39from physioblocks.registers.type_register import register_type 

40from physioblocks.simulation.state import State 

41from physioblocks.utils.exceptions_utils import log_exception 

42 

43_logger = logging.getLogger(__name__) 

44 

45 

@dataclass(frozen=True)
class Solution:
    """
    Immutable result of a solver run: the computed vector and whether the
    solver reported convergence.
    """

    x: NDArray[np.float64]
    """the solution vector computed by the solver"""

    converged: bool
    """whether the solver reported convergence"""

57 

58 

class ConvergenceError(Exception):
    """
    Raised to signal that a solver failed to converge.
    """

63 

64 

class AbstractSolver(ABC):
    """
    Base class for solvers.

    Stores the convergence parameters shared by the solvers and provides a
    helper building the per-entry magnitude vector used to rescale the state.
    """

    iteration_max: int
    """the solver maximum allowed number of iterations"""

    tolerance: float
    """the solver tolerance"""

    def __init__(
        self,
        tolerance: float = 1e-9,
        iteration_max: int = 10,
    ) -> None:
        self.tolerance = tolerance
        self.iteration_max = iteration_max

    def _get_state_magnitude(
        self, state: State, magnitudes: dict[str, float] | None = None
    ) -> NDArray[np.float64]:
        """
        Build the magnitude vector aligned with the state vector.

        :param state: the state whose entries are scaled
        :param magnitudes: optional mapping from a variable name to its
          magnitude. Variables that are not listed keep a magnitude of ``1.0``.
        :return: an array of length ``state.size`` holding one magnitude per
          state entry (all ones when ``magnitudes`` is ``None``)
        """
        state_mag = np.ones(state.size)
        if magnitudes is None:
            return state_mag

        # Start from ones and place each provided magnitude at its variable's
        # index, so the result always has the state's length. Stacking only the
        # provided values would give a shorter array when just some variables
        # are listed: it would then be silently broadcast (a single value
        # scaling every variable) or fail with a shape mismatch.
        # NOTE(review): assumes get_variable_index returns a position (or slice)
        # into the state vector — confirm against State.
        for var_name, var_mag in magnitudes.items():
            state_mag[state.get_variable_index(var_name)] = var_mag
        return state_mag

    @abstractmethod
    def solve(
        self,
        state: State,
        system: EqSystem,
        magnitudes: dict[str, float] | None = None,
    ) -> Solution:
        """
        Solve the equation system. Child classes have to override this method.

        :param state: the state holding the unknowns of the system
        :param system: the equation system to solve
        :param magnitudes: optional magnitude of each state variable, keyed by
          variable name
        :return: the solution of the solver
        :rtype: Solution
        """

115 

116 

# Identifier under which the Newton solver class is registered.
NEWTON_SOLVER_TYPE_ID: str = "newton_solver"

119 

120 

@register_type(NEWTON_SOLVER_TYPE_ID)
class NewtonSolver(AbstractSolver):
    """
    Implementation of the :class:`~.AbstractSolver` class using a **Newton method**.

    The residual and gradient are rescaled with the state magnitudes. The
    convergence test is therefore applied to the rescaled residual. The scaling
    factors are computed once, from the gradient of the first iteration, and
    reused for all following iterations.
    """

    def _compute_residual_and_gradient(
        self, system: EqSystem
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Evaluate the system residual and its gradient."""
        res = system.compute_residual()
        grad = system.compute_gradient()
        return res, grad

    def _compute_new_state(
        self,
        state: State,
        res: NDArray[np.float64],
        grad: NDArray[np.float64],
        state_mag: NDArray[np.float64],
    ) -> NDArray[np.float64]:
        """
        Compute the next Newton iterate from the rescaled residual and gradient.

        The increment is solved in rescaled units, then converted back to state
        units by multiplying it with the state magnitudes.
        """
        res_grad_sol = np.linalg.solve(grad, res)
        return state.state_vector - res_grad_sol * state_mag

    def _compute_res_grad_mag(
        self, gradient: NDArray[np.float64], state_mag: NDArray[np.float64]
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """
        Compute the scaling factors of the residual and of the gradient.

        With ``m`` the state magnitudes and ``r = |G @ m|`` (one value per
        equation), the residual entry ``i`` is scaled by ``1 / r_i``. The
        gradient entry ``(i, j)`` is scaled by ``m_j / r_i``.

        :return: the residual scaling (1-D) and the gradient scaling (2-D)
        """
        state_mag_line = np.atleast_2d(state_mag)
        state_mag_col = state_mag_line.T

        res_mag = gradient @ state_mag_col
        abs_res_mag = np.abs(res_mag)
        res_mag_inv = 1.0 / abs_res_mag

        # Outer product: grad_mag_inv[i, j] = m_j / r_i
        grad_mag_inv = res_mag_inv @ state_mag_line
        return res_mag_inv.flatten(), grad_mag_inv

    def _rescale_res_grad(
        self,
        residual: NDArray[np.float64],
        res_mag_inv: NDArray[np.float64],
        gradient: NDArray[np.float64],
        grad_mag_inv: NDArray[np.float64],
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Apply the scaling factors element-wise to the residual and gradient."""
        res_rescaled = residual * res_mag_inv
        grad_rescaled = gradient * grad_mag_inv
        return res_rescaled, grad_rescaled

    def _residual_norm(self, residual: NDArray[np.float64]) -> float:
        """
        Max-norm of the rescaled residual.

        This single norm is used both to stop the iterations and to set the
        ``converged`` flag, so that the two tests cannot disagree.
        """
        return float(np.linalg.norm(residual, ord=np.inf))

    def solve(
        self,
        state: State,
        system: EqSystem,
        magnitudes: dict[str, float] | None = None,
    ) -> Solution:
        """
        Solve the equation system using the Newton method.

        The state vector is updated in place at each iteration. Iterations stop
        when the max-norm of the rescaled residual is at most :attr:`tolerance`,
        or when :attr:`iteration_max` iterations have been performed. At least
        one iteration is always performed.

        :param state: the state holding the unknowns, updated in place
        :param system: the equation system providing residual and gradient
        :param magnitudes: optional magnitude of each state variable, keyed by
          variable name (missing variables use ``1.0``)
        :return: the solution. ``converged`` is ``False`` in three cases: the
          tolerance was not reached, the solution has non-finite entries, or a
          floating point error occurred. In the last case ``x`` is NaN-filled.
        :rtype: Solution
        """
        i = 0
        with np.errstate(all="raise"):
            try:
                state_mag = self._get_state_magnitude(state, magnitudes)

                # Step 0 is done outside of the loop: the residual and
                # gradient scaling factors are computed from the initial
                # gradient, then reused for every following iteration.
                res, grad = self._compute_residual_and_gradient(system)
                res_mag_inv, grad_mag_inv = self._compute_res_grad_mag(
                    grad, state_mag
                )
                res, grad = self._rescale_res_grad(
                    res, res_mag_inv, grad, grad_mag_inv
                )
                x = self._compute_new_state(state, res, grad, state_mag)
                state.update_state_vector(x)

                # Begin loop at iteration 1 (0 already done)
                i = 1
                while (
                    self._residual_norm(res) > self.tolerance
                    and i < self.iteration_max
                ):
                    res, grad = self._compute_residual_and_gradient(system)
                    res, grad = self._rescale_res_grad(
                        res, res_mag_inv, grad, grad_mag_inv
                    )
                    x = self._compute_new_state(state, res, grad, state_mag)
                    state.update_state_vector(x)
                    i += 1

                # Same norm as the stopping test above: a loop that stopped
                # because the tolerance was met must report convergence.
                converged = self._residual_norm(res) <= self.tolerance and bool(
                    np.all(np.isfinite(x))
                )
                sol = Solution(state.state_vector, converged)
            except FloatingPointError as exception:
                _logger.debug(
                    "Solver did not converge at step %s due to floating "
                    "point error. The solved property is set to False.",
                    i,
                )
                log_exception(
                    _logger,
                    FloatingPointError,
                    exception,
                    exception.__traceback__,
                    logging.DEBUG,
                )
                # NaN rather than uninitialised memory, so that an accidental
                # use of this vector is immediately visible.
                return Solution(np.full(state.size, np.nan), False)

        return sol