Coverage for physioblocks / simulation / state.py: 97%
96 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-09 16:40 +0100
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-09 16:40 +0100
1# SPDX-FileCopyrightText: Copyright INRIA
2#
3# SPDX-License-Identifier: LGPL-3.0-only
4#
5# Copyright INRIA
6#
7# This file is part of PhysioBlocks, a library mostly developed by the
8# [Ananke project-team](https://team.inria.fr/ananke) at INRIA.
9#
10# Authors:
11# - Colin Drieu
12# - Dominique Chapelle
13# - François Kimmig
14# - Philippe Moireau
15#
16# PhysioBlocks is free software: you can redistribute it and/or modify it under the
17# terms of the GNU Lesser General Public License as published by the Free Software
18# Foundation, version 3 of the License.
19#
20# PhysioBlocks is distributed in the hope that it will be useful, but WITHOUT ANY
21# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
22# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
23#
24# You should have received a copy of the GNU Lesser General Public License along with
25# PhysioBlocks. If not, see <https://www.gnu.org/licenses/>.
27"""
28Define the **State** that holds simulation variables.
29"""
31from collections.abc import Callable, Generator, Mapping
32from pprint import pformat
33from typing import Any
35import numpy as np
36from numpy.typing import NDArray
38from physioblocks.computing.quantities import Quantity
# Constant used to identify the state in the simulation.
STATE_NAME_ID = "state"
class State:
    """
    The **State** holds the variables names, quantities and indexes during the
    simulation.

    Variables quantity values can be accessed individually with their names or
    index, or altogether through the **State Vector**.

    The index of a variable is its offset in the state vector: the sum of the
    sizes of the variables registered before it.
    """

    _variables: dict[str, Quantity[Any]]
    """The variables ids and quantities values"""

    def __init__(self) -> None:
        self._variables = {}

    def _iter_offsets(self) -> Generator[tuple[str, int, Quantity[Any]], None, None]:
        """
        Iterate on the variables with their offset in the state vector.

        The offsets are accumulated in a single pass over the variables, in
        registration order.

        :return: tuples of variable id, offset in the state vector and quantity
        :rtype: Generator[tuple[str, int, Quantity], None, None]
        """
        offset = 0
        for var_id, quantity in self._variables.items():
            yield var_id, offset, quantity
            offset += quantity.size

    @property
    def size(self) -> int:
        """
        Get the total size of the state.

        :return: the sum of the sizes of every variable quantity
        :rtype: int
        """
        return sum(quantity.size for quantity in self._variables.values())

    @property
    def variables(self) -> dict[str, Quantity[Any]]:
        """
        Get a mapping of variables names and quantities.

        The returned dictionary is a shallow copy: adding or removing keys in it
        does not affect the state, but the quantities are shared.

        :return: the variables names and quantities.
        :rtype: dict[str, Quantity]
        """
        return self._variables.copy()

    @property
    def state_vector(self) -> NDArray[np.float64]:
        """
        Get the vector of the ``new`` values of the state variable quantities.

        :return: the state vector, or an empty array when the state holds no
            variable
        :rtype: NDArray[np.float64]
        """
        if not self._variables:
            return np.array([])

        # axis=None flattens every quantity before joining them.
        return np.concatenate(
            [quantity.new for quantity in self._variables.values()], axis=None
        )

    def __array__(self, dtype: Any = None, copy: bool | None = None) -> NDArray[Any]:
        """
        Convert the state to an array holding the state vector.

        The signature follows the NumPy array protocol, so that
        ``np.asarray(state)`` works without deprecation warning.

        :param dtype: the requested data type, or ``None`` to keep the state
            vector one.
        :type dtype: Any

        :param copy: the copy policy requested by NumPy.
        :type copy: bool | None

        :raises ValueError: Raises a ValueError when ``copy`` is ``False``:
            the state vector is assembled on demand, so a copy can not be avoided.

        :return: the state vector
        :rtype: NDArray[Any]
        """
        if copy is False:
            raise ValueError(
                "The state vector is assembled on demand: a copy can not be avoided."
            )

        vector = self.state_vector
        if dtype is not None:
            # The vector is already a fresh array, no additional copy is needed
            # when the data type matches.
            vector = vector.astype(dtype, copy=False)
        return vector

    def __getitem__(self, var_id: str) -> Quantity[Any]:
        """
        Get the variable quantity.

        :param var_id: the variable id.
        :type var_id: str

        :raises KeyError: Raises a KeyError when no variable has the given id.

        :return: the variable quantity
        :rtype: Quantity
        """
        if var_id in self._variables:
            return self._variables[var_id]

        raise KeyError(f"State has no variable variable named {var_id}.")

    def get(self, key: str) -> Quantity[Any] | None:
        """
        Get the variable quantity with the given key,
        or ``None`` if it is not registered.

        :param key: the variable key
        :type key: str

        :return: the variable or None
        :rtype: Quantity | None
        """
        return self._variables.get(key)

    def update(self, mapping: Mapping[str, Any]) -> None:
        """
        Update the state variable quantities with the values provided in the
        mapping.

        .. note::

            New variables in the mapping are added to the state while existing
            variables quantities are initialised to the given value.

        :param mapping: the values to update the state.
        :type mapping: Mapping[str, Any]
        """
        for key, val in mapping.items():
            self[key] = val

    def __setitem__(self, var_id: str, value: Any) -> None:
        """
        Set the variable quantity value.

        An unknown variable is first added to the state, then the quantity is
        initialised with the value.

        :param var_id: the variable name.
        :type var_id: str

        :param value: the variable value.
        :type value: Any

        :raises ValueError: Raises a ValueError if the size of the value does
            not match the size of the variable quantity.
        """
        if var_id not in self._variables:
            self.add_variable(var_id, value)

        # NOTE(review): a variable that was just added is initialised a second
        # time below. This is kept as is; confirm against Quantity.initialize
        # whether it is intended.
        quantity = self._variables[var_id]
        value_size = np.asarray(value).size
        if value_size != quantity.size:
            raise ValueError(
                f"Expected size {quantity.size} for variable {var_id}, "
                f"got {value_size}."
            )
        quantity.initialize(value)

    def __str__(self) -> str:
        state_dict: dict[str, Any] = {
            "Variables": {
                offset: (var_id, "size " + str(quantity.size))
                for var_id, offset, quantity in self._iter_offsets()
            }
        }
        return pformat(state_dict, indent=2, compact=False)

    @property
    def indexes(self) -> dict[str, int]:
        """
        Get a mapping of the variables names with their indexes.

        :return: the variables indexes by variables ids
        :rtype: dict[str, int]
        """
        return {var_id: offset for var_id, offset, _ in self._iter_offsets()}

    def get_variable_index(self, variable_id: str) -> int:
        """
        Get the index of the variable with the given name

        :param variable_id: the variable id
        :type variable_id: str

        :raises KeyError: Raises a KeyError when no variable has the given id.

        :return: the variable index
        :rtype: int
        """
        for var_id, offset, _ in self._iter_offsets():
            if var_id == variable_id:
                return offset

        raise KeyError(f"State has no variable variable named {variable_id}.")

    def get_variable_size(self, var_id: str) -> int:
        """
        Get the size of the variable with the given name.

        :param var_id: the variable id
        :type var_id: str

        :raises KeyError: Raises a KeyError when no variable has the given id.

        :return: the size of the variable
        :rtype: int
        """
        return self._variables[var_id].size

    def get_variable_id(self, var_index: int) -> str:
        """
        Get the variable name with the given index.

        Only the index where a variable starts in the state vector matches.

        :param var_index: the variable index
        :type var_index: int

        :raises KeyError: Raises a KeyError when no variable starts at the
            given index.

        :return: the variable id
        :rtype: str
        """
        for var_id, offset, _ in self._iter_offsets():
            if offset == var_index:
                return var_id

        raise KeyError(f"No variable at index {var_index}")

    def __iter__(self) -> Generator[str, None, None]:
        """
        Iterate on the variables names in the state.

        :return: the variable ids
        :rtype: Generator[str, None, None]
        """
        yield from self._variables

    def __contains__(self, key: str) -> bool:
        """
        Checks if the key is in the variables names.

        :param key: The key to test
        :type key: str

        :return: True if a variable is registered with the key, False otherwise
        :rtype: bool
        """
        return key in self._variables

    def update_state_vector(self, x: NDArray[np.float64]) -> None:
        """
        Update the ``new`` values of the state vector quantities,
        with the given vector.

        :param x: the vector to set.
        :type x: NDArray[np.float64]

        :raise ValueError:
            Raise a ValueError when x and the state vector sizes don't match.
        """
        self.__change_state_vector(Quantity.update, x)

    def reset_state_vector(self) -> None:
        """
        Set the new values to the current value of the state vector quantities.
        """
        for variable in self._variables.values():
            variable.initialize(variable.current)

    def set_state_vector(self, x: NDArray[np.float64]) -> None:
        """
        Set the ``new`` and ``current`` values of the state vector quantities
        with the given vector.

        :param x: the vector to set.
        :type x: NDArray[np.float64]

        :raise ValueError:
            Raise a ValueError when x and the state vector sizes don't match.
        """
        self.__change_state_vector(Quantity.initialize, x)

    def __change_state_vector(
        self, func: Callable[[Quantity[Any], Any], None], x: NDArray[np.float64]
    ) -> None:
        """
        Apply the function to every quantity with its part of the vector.

        :param func: the function called with the quantity and its value.
        :type func: Callable[[Quantity, Any], None]

        :param x: the vector holding the values of all the quantities.
        :type x: NDArray[np.float64]

        :raise ValueError:
            Raise a ValueError when x and the state vector sizes don't match.
        """
        # Checks x and state vector have the same size.
        if x.size != self.size:
            raise ValueError("State vector size does not match state size.")

        for _, offset, quantity in self._iter_offsets():
            if quantity.size == 1:
                # assign scalar value
                func(quantity, x[offset])
            else:
                # assign vector value
                func(quantity, x[offset : offset + quantity.size])

    def add_variable(self, var_id: str, var_value: Any) -> None:
        """
        Add a variable to the state.

        :param var_id: the name of the variable
        :type var_id: str

        :param var_value: the initial value of the variable. A Quantity is
            registered as is, any other value is wrapped in a new Quantity.
        :type var_value: Any

        :raises KeyError: Raises a KeyError when the variable is already
            registered.
        """
        if var_id in self:
            raise KeyError(f"{var_id} is already registered.")

        quantity = var_value if isinstance(var_value, Quantity) else Quantity(var_value)
        self._variables[var_id] = quantity

    def remove_variable(self, var_id: str) -> None:
        """
        Remove a variable from the state.

        Nothing happens when the variable is not registered.

        :param var_id: the name of the variable to remove.
        :type var_id: str
        """
        self._variables.pop(var_id, None)