Coverage for physioblocks / simulation / state.py: 97%

96 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-09 16:40 +0100

1# SPDX-FileCopyrightText: Copyright INRIA 

2# 

3# SPDX-License-Identifier: LGPL-3.0-only 

4# 

5# Copyright INRIA 

6# 

7# This file is part of PhysioBlocks, a library mostly developed by the 

8# [Ananke project-team](https://team.inria.fr/ananke) at INRIA. 

9# 

10# Authors: 

11# - Colin Drieu 

12# - Dominique Chapelle 

13# - François Kimmig 

14# - Philippe Moireau 

15# 

16# PhysioBlocks is free software: you can redistribute it and/or modify it under the 

17# terms of the GNU Lesser General Public License as published by the Free Software 

18# Foundation, version 3 of the License. 

19# 

20# PhysioBlocks is distributed in the hope that it will be useful, but WITHOUT ANY 

21# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 

22# PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. 

23# 

24# You should have received a copy of the GNU Lesser General Public License along with 

25# PhysioBlocks. If not, see <https://www.gnu.org/licenses/>. 

26 

27""" 

28Define the **State** that holds simulation variables. 

29""" 

30 

31from collections.abc import Callable, Generator, Mapping 

32from pprint import pformat 

33from typing import Any 

34 

35import numpy as np 

36from numpy.typing import NDArray 

37 

38from physioblocks.computing.quantities import Quantity 

39 

# Constant used to identify the state in the simulation.
STATE_NAME_ID = "state"

42 

43 

class State:
    """
    The **State** holds the variables names, quantities and indexes during the
    simulation.

    Variables quantity values can be accessed individually with their names or index,
    or altogether through the **State Vector**.

    Variables are kept in insertion order: the index of a variable is the sum of
    the sizes of all the variables registered before it.
    """

    # The variables ids and quantities values, in insertion order.
    _variables: dict[str, Quantity[Any]]

    def __init__(self) -> None:
        self._variables = {}

    @property
    def size(self) -> int:
        """
        Get the total size of the state.

        :return: the sum of the sizes of all the variable quantities
        :rtype: int
        """
        return sum(var_qty.size for var_qty in self._variables.values())

    @property
    def variables(self) -> dict[str, Quantity[Any]]:
        """
        Get a mapping of variables names and quantities.

        The returned dict is a shallow copy: adding or removing keys does not
        modify the state, but the quantities themselves are shared.

        :return: the variables names and quantities.
        :rtype: dict[str, Quantity]
        """
        return self._variables.copy()

    @property
    def state_vector(self) -> NDArray[np.float64]:
        """
        Get the vector of the ``new`` values of the state variable quantities.

        The ``new`` values of every variable are flattened and concatenated in
        insertion order. A new array is built on each access.

        :return: the state vector (an empty array when the state has no variable)
        :rtype: NDArray[np.float64]
        """
        if self._variables:
            return np.concatenate(
                [var_qty.new for var_qty in self._variables.values()], axis=None
            )
        return np.array([])

    def __array__(self, dtype: Any = None, copy: bool | None = None) -> NDArray[Any]:
        """
        Support ``np.asarray(state)`` / ``np.array(state)``.

        :param dtype: optional dtype requested by NumPy.
        :param copy: accepted for compatibility with the NumPy 2 ``__array__``
          protocol. ``state_vector`` always builds a fresh array, so any copy
          request is already satisfied.

        :return: the state vector, cast to ``dtype`` when one is requested.
        :rtype: NDArray
        """
        vector = self.state_vector
        if dtype is not None:
            vector = vector.astype(dtype, copy=False)
        return vector

    def __getitem__(self, var_id: str) -> Quantity[Any]:
        """
        Get the variable quantity.

        :param var_id: the variable id.
        :type var_id: str

        :return: the variable quantity
        :rtype: Quantity

        :raises KeyError: if no variable is registered under ``var_id``.
        """
        try:
            return self._variables[var_id]
        except KeyError:
            raise KeyError(f"State has no variable named {var_id}.") from None

    def get(self, key: str) -> Quantity[Any] | None:
        """
        Get the variable quantity with the given key,
        or ``None`` if it is not registered.

        :param key: the variable key
        :type key: str

        :return: the variable or None
        :rtype: Quantity | None
        """
        return self._variables.get(key)

    def update(self, mapping: Mapping[str, Any]) -> None:
        """
        Update the state variable quantities with the values provided in the
        mapping.

        .. note::

            New variables in the mapping are added to the state while existing
            variables quantities are initialised to the given value.

        :param mapping: the values to update the state.
        :type mapping: Mapping[str, Any]

        :raises ValueError: if a value size does not match its variable size.
        """
        for key, val in mapping.items():
            self[key] = val

    def __setitem__(self, var_id: str, value: Any) -> None:
        """
        Set the variable quantity value.

        If ``var_id`` is not registered yet, it is first added with
        :meth:`add_variable`. The variable quantity is then initialised with
        ``value``.

        :param var_id: the variable name.
        :type var_id: str

        :param value: the variable value, or a Quantity when adding a new variable.
        :type value: Any

        :raises ValueError: Raises a ValueError if the size of ``value`` does
          not match the variable quantity size.
        """
        if var_id not in self._variables:
            self.add_variable(var_id, value)

        quantity = self._variables[var_id]
        value_size = np.asarray(value).size
        if value_size != quantity.size:
            raise ValueError(
                f"Expected size {quantity.size} for variable {var_id}, "
                f"got {value_size}."
            )
        quantity.initialize(value)

    def __str__(self) -> str:
        # Compute all offsets in a single pass instead of one scan per variable.
        indexes = self.indexes
        state_dict: dict[str, Any] = {
            "Variables": {
                indexes[var_id]: (var_id, "size " + str(var_qty.size))
                for var_id, var_qty in self._variables.items()
            }
        }
        return pformat(state_dict, indent=2, compact=False)

    @property
    def indexes(self) -> dict[str, int]:
        """
        Get a mapping of the variables indexes with their names.

        The index of a variable is the position of its first component in the
        state vector.

        :return: the variables indexes keyed by variables ids, in insertion order
        :rtype: dict[str, int]
        """
        indexes: dict[str, int] = {}
        offset = 0
        for var_id, var_qty in self._variables.items():
            indexes[var_id] = offset
            offset += var_qty.size
        return indexes

    def get_variable_index(self, variable_id: str) -> int:
        """
        Get the index of the variable with the given name.

        :param variable_id: the variable id
        :type variable_id: str

        :return: the index of the first component of the variable in the
          state vector
        :rtype: int

        :raises KeyError: if no variable is registered under ``variable_id``.
        """
        offset = 0
        for key, quantity in self._variables.items():
            if key == variable_id:
                return offset
            offset += quantity.size

        raise KeyError(f"State has no variable named {variable_id}.")

    def get_variable_size(self, var_id: str) -> int:
        """
        Get the size of the variable with the given name.

        :param var_id: the variable id
        :type var_id: str

        :return: the size of the variable
        :rtype: int

        :raises KeyError: if no variable is registered under ``var_id``.
        """
        return self[var_id].size

    def get_variable_id(self, var_index: int) -> str:
        """
        Get the variable name with the given index.

        Only the starting index of a variable resolves. An index that falls
        inside a vector variable, or past the end of the state, raises.

        :param var_index: the variable index
        :type var_index: int

        :return: the variable id
        :rtype: str

        :raises KeyError: if no variable starts at ``var_index``.
        """
        offset = 0
        for var_id, quantity in self._variables.items():
            if offset == var_index:
                return var_id
            offset += quantity.size

        raise KeyError(f"No variable at index {var_index}")

    def __iter__(self) -> Generator[str, None, None]:
        """
        Iterate on the variables names in the state, in insertion order.

        :return: the variable ids
        :rtype: Generator[str, None, None]
        """
        yield from self._variables

    def __contains__(self, key: str) -> bool:
        """
        Checks if the key is in the variables names.

        :param key: The key to test
        :type key: str

        :return: ``True`` if a variable is registered under ``key``
        :rtype: bool
        """
        return key in self._variables

    def update_state_vector(self, x: NDArray[np.float64]) -> None:
        """
        Update the ``new`` values of the state vector quantities,
        with the given vector.

        :param x: the vector to set.
        :type x: NDArray[np.float64]

        :raise ValueError:
          Raise a ValueError when x and the state vector sizes don't match.
        """
        self.__change_state_vector(Quantity.update, x)

    def reset_state_vector(self) -> None:
        """
        Set the new values to the current value of the state vector quantities.

        Each variable quantity is re-initialised with its own ``current`` value.
        """
        for variable in self._variables.values():
            variable.initialize(variable.current)

    def set_state_vector(self, x: NDArray[np.float64]) -> None:
        """
        Set the ``new`` and ``current`` values of the state vector quantities
        with the given vector.

        :param x: the vector to set.
        :type x: NDArray[np.float64]

        :raise ValueError:
          Raise a ValueError when x and the state vector sizes don't match.
        """
        self.__change_state_vector(Quantity.initialize, x)

    def __change_state_vector(
        self, func: Callable[[Quantity[Any], Any], None], x: NDArray[np.float64]
    ) -> None:
        """
        Apply ``func`` to every variable quantity with its slice of ``x``.

        Size-1 variables receive the scalar element; the others receive the
        contiguous slice matching their position in the state vector.
        """
        # Accept any array-like; an ndarray is returned unchanged.
        x = np.asarray(x)
        if x.size != self.size:
            raise ValueError("State vector size does not match state size.")

        offset = 0
        for quantity in self._variables.values():
            if quantity.size == 1:
                # assign scalar value
                func(quantity, x[offset])
            else:
                # assign vector value
                func(quantity, x[offset : offset + quantity.size])
            offset += quantity.size

    def add_variable(self, var_id: str, var_value: Any) -> None:
        """
        Add a variable to the state.

        :param var_id: the name of the variable
        :type var_id: str

        :param var_value: the initial value of the variable. A Quantity is
          stored as-is; any other value is wrapped in a new Quantity.
        :type var_value: Any

        :raises KeyError: if a variable is already registered under ``var_id``.
        """
        if var_id in self._variables:
            raise KeyError(f"{var_id} is already registered.")

        quantity = var_value if isinstance(var_value, Quantity) else Quantity(var_value)
        self._variables[var_id] = quantity

    def remove_variable(self, var_id: str) -> None:
        """
        Remove a variable from the state.

        Removing a variable that is not registered is a no-op.

        :param var_id: the name of the variable to remove.
        :type var_id: str
        """
        self._variables.pop(var_id, None)