Coverage for physioblocks / simulation / solvers.py: 99%
85 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-09 16:40 +0100
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-09 16:40 +0100
1# SPDX-FileCopyrightText: Copyright INRIA
2#
3# SPDX-License-Identifier: LGPL-3.0-only
4#
5# Copyright INRIA
6#
7# This file is part of PhysioBlocks, a library mostly developed by the
8# [Ananke project-team](https://team.inria.fr/ananke) at INRIA.
9#
10# Authors:
11# - Colin Drieu
12# - Dominique Chapelle
13# - François Kimmig
14# - Philippe Moireau
15#
16# PhysioBlocks is free software: you can redistribute it and/or modify it under the
17# terms of the GNU Lesser General Public License as published by the Free Software
18# Foundation, version 3 of the License.
19#
20# PhysioBlocks is distributed in the hope that it will be useful, but WITHOUT ANY
21# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
22# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
23#
24# You should have received a copy of the GNU Lesser General Public License along with
25# PhysioBlocks. If not, see <https://www.gnu.org/licenses/>.
27"""Declare a generic **Solver** class and solver implementations"""
29from __future__ import annotations
31import logging
32from abc import ABC, abstractmethod
33from dataclasses import dataclass
35import numpy as np
36from numpy.typing import NDArray
38from physioblocks.computing.assembling import EqSystem
39from physioblocks.registers.type_register import register_type
40from physioblocks.simulation.state import State
41from physioblocks.utils.exceptions_utils import log_exception
# Logger named after this module, used to report solver failures.
_logger: logging.Logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class Solution:
    """
    Immutable result returned by a solver.
    """

    x: NDArray[np.float64]
    """the solution vector computed by the solver"""

    converged: bool
    """``True`` when the solver converged, ``False`` otherwise."""
class ConvergenceError(Exception):
    """
    Raised when a solver fails to converge.
    """
class AbstractSolver(ABC):
    """
    Base class for solvers.

    Child classes implement :meth:`solve`.
    """

    iteration_max: int
    """the solver maximum allowed number of iterations"""

    tolerance: float
    """the solver tolerance"""

    def __init__(
        self,
        tolerance: float = 1e-9,
        iteration_max: int = 10,
    ) -> None:
        """
        :param tolerance: the solver tolerance
        :type tolerance: float
        :param iteration_max: the solver maximum allowed number of iterations
        :type iteration_max: int
        """
        self.tolerance = tolerance
        self.iteration_max = iteration_max

    def _get_state_magnitude(
        self, state: State, magnitudes: dict[str, float] | None = None
    ) -> NDArray[np.float64]:
        """
        Build the vector of magnitudes of the state vector entries.

        Every entry defaults to ``1.0``. The entries of the variables listed
        in ``magnitudes`` are then set to the provided values, so variables
        that are not listed keep the default magnitude.

        :param state: the state whose vector is scaled
        :param magnitudes: magnitude for each state variable name, or ``None``
          to use ``1.0`` everywhere
        :return: a float vector with the same size as the state vector
        """
        # Start from the default magnitude so the result always matches the
        # state vector size, even when ``magnitudes`` is partial or empty.
        state_mag = np.ones(state.size)
        if magnitudes is None:
            return state_mag

        for var_name, var_mag in magnitudes.items():
            state_mag[state.get_variable_index(var_name)] = var_mag
        return state_mag

    @abstractmethod
    def solve(
        self,
        state: State,
        system: EqSystem,
        magnitudes: dict[str, float] | None = None,
    ) -> Solution:
        """
        Solve the equation system. Child classes have to override this method.

        :param state: the state of the system
        :param system: the equation system to solve
        :param magnitudes: optional magnitude for each state variable name
        :return: the solution of the solver
        :rtype: Solution
        """
# Identifier under which the Newton solver type is registered.
NEWTON_SOLVER_TYPE_ID: str = "newton_solver"
@register_type(NEWTON_SOLVER_TYPE_ID)
class NewtonSolver(AbstractSolver):
    """
    Implementation of the :class:`~.AbstractSolver` class using a **Newton method**.

    Before each linear solve, the residual and the gradient are rescaled. The
    scaling factors are computed once, from the gradient at the initial state
    and the state magnitudes, then reused for every following iteration.
    """

    def _compute_residual_and_gradient(
        self, system: EqSystem
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """
        Evaluate the residual and the gradient of the system.

        :param system: the equation system to evaluate
        :return: the residual and the gradient, in that order
        """
        res = system.compute_residual()
        grad = system.compute_gradient()
        return res, grad

    def _compute_new_state(
        self,
        state: State,
        res: NDArray[np.float64],
        grad: NDArray[np.float64],
        state_mag: NDArray[np.float64],
    ) -> NDArray[np.float64]:
        """
        Compute the state vector after one Newton step.

        The step is obtained by solving ``grad @ dx = res`` and multiplying
        ``dx`` by the state magnitudes, which undoes the column scaling applied
        to the gradient by :meth:`_compute_res_grad_mag`.

        :param state: the current state
        :param res: the rescaled residual
        :param grad: the rescaled gradient
        :param state_mag: the state magnitudes
        :return: the new state vector
        """
        res_grad_sol = np.linalg.solve(grad, res)
        return state.state_vector - res_grad_sol * state_mag

    def _compute_res_grad_mag(
        self, gradient: NDArray[np.float64], state_mag: NDArray[np.float64]
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """
        Compute the scaling factors of the residual and of the gradient.

        The magnitude of each equation residual is estimated as
        ``|gradient @ state_mag|``.

        :param gradient: the system gradient
        :param state_mag: the state magnitudes
        :return: the inverse residual magnitudes (flattened) and the gradient
          scaling matrix, whose entry ``(i, j)`` is
          ``state_mag[j] / res_mag[i]``
        """
        state_mag_line = np.atleast_2d(state_mag)
        state_mag_col = state_mag_line.T

        abs_res_mag = np.abs(gradient @ state_mag_col)
        # A zero residual magnitude raises a FloatingPointError here when
        # called under ``np.errstate(all="raise")``, as done in :meth:`solve`.
        res_mag_inv = 1.0 / abs_res_mag

        # Outer product: (n, 1) @ (1, n) -> (n, n)
        grad_mag_inv = res_mag_inv @ state_mag_line
        return res_mag_inv.flatten(), grad_mag_inv

    def _rescale_res_grad(
        self,
        residual: NDArray[np.float64],
        res_mag_inv: NDArray[np.float64],
        gradient: NDArray[np.float64],
        grad_mag_inv: NDArray[np.float64],
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """
        Apply the scaling factors to the residual and to the gradient.

        :param residual: the residual to rescale
        :param res_mag_inv: the residual scaling factors
        :param gradient: the gradient to rescale
        :param grad_mag_inv: the gradient scaling matrix
        :return: the rescaled residual and gradient (element-wise products)
        """
        res_rescaled = residual * res_mag_inv
        grad_rescaled = gradient * grad_mag_inv
        return res_rescaled, grad_rescaled

    def solve(
        self,
        state: State,
        system: EqSystem,
        magnitudes: dict[str, float] | None = None,
    ) -> Solution:
        """
        Solve the equation system using the Newton method.

        The state vector is updated at every iteration. Iterations stop when the
        infinity norm of the rescaled residual is lower than or equal to the
        tolerance, or after ``iteration_max`` iterations.

        :param state: the state holding the initial guess, updated in place
        :param system: the equation system to solve
        :param magnitudes: optional magnitude for each state variable name,
          used to rescale the problem
        :return: the solution. When a floating point error occurs, the
          solution is flagged as not converged and its vector is filled with
          NaN.
        :rtype: Solution
        """
        i = 0
        with np.errstate(all="raise"):
            try:
                state_mag = self._get_state_magnitude(state, magnitudes)

                # Step 0 is done outside of the loop: the residual and gradient
                # scaling factors are computed from the initial gradient and
                # reused for all the following iterations.
                res, grad = self._compute_residual_and_gradient(system)
                res_mag_inv, grad_mag_inv = self._compute_res_grad_mag(grad, state_mag)
                res, grad = self._rescale_res_grad(res, res_mag_inv, grad, grad_mag_inv)
                x = self._compute_new_state(state, res, grad, state_mag)
                state.update_state_vector(x)

                # Begin loop at iteration 1 (0 already done)
                i = 1
                while (
                    np.linalg.norm(res, ord=np.inf) > self.tolerance
                    and i < self.iteration_max
                ):
                    res, grad = self._compute_residual_and_gradient(system)
                    res, grad = self._rescale_res_grad(
                        res, res_mag_inv, grad, grad_mag_inv
                    )
                    x = self._compute_new_state(state, res, grad, state_mag)
                    state.update_state_vector(x)
                    i += 1

                # Use the same norm as the stopping criterion, so that a loop
                # that stopped on the tolerance is reported as converged.
                converged = bool(
                    np.linalg.norm(res, ord=np.inf) <= self.tolerance
                ) and bool(np.all(np.isfinite(x)))
                sol = Solution(state.state_vector, converged)
            except FloatingPointError as exception:
                _logger.debug(
                    "Solver did not converge at step %s due to floating "
                    "point error. The solved property is set to False.",
                    i,
                )
                log_exception(
                    _logger,
                    FloatingPointError,
                    exception,
                    exception.__traceback__,
                    logging.DEBUG,
                )
                # NaN instead of uninitialized memory, so the vector cannot be
                # mistaken for a valid result.
                return Solution(np.full(state.size, np.nan), False)

        return sol