Coverage for suppy\projections\_projection_methods.py: 67%

199 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2026-05-08 13:56 +0200

1""" 

2General implementation for sequential, simultaneous, block iterative and 

3string averaged projection methods. 

4""" 

5from abc import ABC 

6from typing import List, Callable 

7import numpy as np 

8import numpy.typing as npt 

9 

10try: 

11 import cupy as cp 

12 

13 NO_GPU = False 

14except ImportError: 

15 cp = np 

16 NO_GPU = True 

17 

18from suppy.projections._projections import Projection, BasicProjection 

19from suppy.utils import ensure_float_array 

20 

21 

class ProjectionMethod(Projection, ABC):
    """
    A class used to represent methods for projecting a point onto multiple
    sets.

    Parameters
    ----------
    projections : List[Projection]
        A list of Projection objects to be used in the projection method.
    relaxation : float, optional
        A relaxation parameter for the projection method (default is 1).
    proximity_flag : bool
        Flag to indicate whether to take this object into account when calculating proximity, by default True.

    Attributes
    ----------
    projections : List[Projection]
        The list of Projection objects used in the projection method.
    all_x : array-like or None
        Host (NumPy) copies of the iterates stored by the most recent call to
        `solve` with ``storage=True``; None if the most recent `solve` did not
        store iterates.
    proximities : list or array-like
        Proximity values recorded by `solve` (starting point plus one entry per
        iteration); converted to an array when `solve` returns.
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when calculating proximity.
    """

    def __init__(
        self, projections: List[Projection], relaxation: float = 1, proximity_flag: bool = True
    ):
        super().__init__(relaxation, proximity_flag)
        self.projections = projections
        self.all_x = None
        self.proximities = []

    def visualize(self, ax):
        """
        Visualizes all projection objects (if applicable) on the given
        matplotlib axis.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            The matplotlib axis on which to visualize the projections.
        """
        for proj in self.projections:
            proj.visualize(ax)

    @ensure_float_array
    def solve(
        self,
        x: npt.NDArray,
        max_iter: int = 500,
        prox_tol: float = 1e-6,
        del_prox_tol: float = 1e-8,
        del_prox_n: int = 5,
        proximity_measures: List | None = None,
        storage: bool = False,
        storage_iters: List[int] | int | None = None,
        alternative_stopping_criterion: Callable | None = None,
        alternative_stopping_criterion_initial_call: Callable | None = None,
    ) -> np.ndarray:
        """
        Solves the optimization problem using an iterative approach.

        Parameters
        ----------
        x : npt.NDArray
            Starting point for the algorithm.
        max_iter : int, optional
            Maximum number of iterations to perform, by default 500.
        prox_tol : float, optional
            The tolerance for the proximity on the constraints, by default 1e-6.
        del_prox_tol : float, optional
            The tolerance for the change in proximity over the last del_prox_n iterations, by default 1e-8.
        del_prox_n : int, optional
            The number of iterations that del_prox_tol needs to be met in a row, by default 5.
        proximity_measures : List, optional
            The proximity measures to calculate, by default a l2 norm measure is used. Right now only the first in the list is used to check the feasibility.
        storage : bool, optional
            Flag indicating whether to store intermediate solutions, by default False.
            Only the literal ``True`` enables storage.
        storage_iters : List[int] or int, optional
            Controls which iterations are stored (when storage=True). If None, all iterations are stored.
            If a list of ints, only those iteration indices are stored (0 = initial point).
            If an int (Python or NumPy integer), storage occurs every that many iterations.
        alternative_stopping_criterion : callable, optional
            Alternative stopping criterion, called as ``f(x, self)`` after every
            iteration; a truthy result stops the loop. Replaces the built-in
            proximity-based criterion.
        alternative_stopping_criterion_initial_call : callable, optional
            Initial call for an alternative stopping criterion, called as
            ``f(x, self)`` on the starting point before the first iteration.

        Returns
        -------
        npt.NDArray
            The solution after the iterative process.
        """
        xp = cp if isinstance(x, cp.ndarray) else np

        if proximity_measures is None:
            proximity_measures = [("p_norm", 2)]
        # TODO: Check if the proximity measures are valid

        # Resolve the storage schedule once. An explicit index list becomes a set
        # so each membership test is O(1) instead of a linear scan per iteration.
        # Nothing is built when storage is off, so an unusual storage_iters value
        # cannot break a run that does not store.
        store_stride = None
        store_indices = None
        if storage is True and storage_iters is not None:
            if isinstance(storage_iters, (int, np.integer)):
                store_stride = int(storage_iters)
            else:
                store_indices = set(storage_iters)

        def _should_store(idx: int) -> bool:
            # Only called when storage is True.
            if store_stride is not None:
                return idx % store_stride == 0
            if store_indices is not None:
                return idx in store_indices
            return True  # storage_iters is None: store every iterate

        def _host_copy(arr):
            # NumPy arrays are copied; anything else (a CuPy array) is
            # transferred to the host with .get().
            return np.array(arr) if isinstance(arr, np.ndarray) else arr.get()

        self.proximities = [self.proximity(x, proximity_measures)]

        # Reset storage so iterates from a previous solve() call never survive
        # into a run that did not ask for storage.
        self.all_x = None
        if storage is True:
            self.all_x = []
            if _should_store(0):
                self.all_x.append(_host_copy(x))

        if alternative_stopping_criterion_initial_call is not None:
            stop = alternative_stopping_criterion_initial_call(x, self)
        else:
            stop = False  # criterion for stopping the algorithm

        # Counter of consecutive "small change" iterations, read and updated by
        # _stopping_criterion.
        self._n_tol = 0

        i = 0
        while i < max_iter and not stop:
            x = self.project(x)
            if storage is True and _should_store(i + 1):
                self.all_x.append(_host_copy(x))

            self.proximities.append(self.proximity(x, proximity_measures))

            # TODO: If proximity changes x some potential issues!
            if alternative_stopping_criterion is not None:
                stop = alternative_stopping_criterion(x, self)
            else:
                stop = self._stopping_criterion(prox_tol, del_prox_tol, del_prox_n)

            i += 1

        if self.all_x is not None:
            self.all_x = np.array(self.all_x)

        self.proximities = xp.array(self.proximities)
        return x

    def _stopping_criterion(self, prox_tol: float, del_prox_tol: float, del_prox_n: int) -> bool:
        """
        Returns True when convergence is detected, False otherwise.

        Only the first proximity measure is inspected. Convergence holds if:

        - the latest value is below ``prox_tol``, or
        - the decrease from the previous value has been below ``del_prox_tol``
          for ``del_prox_n`` consecutive calls.

        An increase also counts as a change below the threshold.
        """
        current = self.proximities[-1][0]
        if current < prox_tol:
            return True
        # Check that the last n proximity changes are below a threshold.
        if self.proximities[-2][0] - current < del_prox_tol:
            self._n_tol += 1
        else:
            self._n_tol = 0
        return self._n_tol >= del_prox_n

    def _proximity(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """
        Aggregate the proximities reported by all sub-projections.

        A ``("p_norm", ...)`` measure is aggregated by the unweighted mean over
        the sub-projections. ``"max_norm"`` is aggregated by the maximum.

        Raises
        ------
        ValueError
            If a measure is neither a ``("p_norm", ...)`` tuple nor ``"max_norm"``.
        """
        xp = cp if isinstance(x, cp.ndarray) else np
        # Row j holds the measures of self.projections[j]; column i is measure i.
        proxs = xp.array(
            [xp.array(proj.proximity(x, proximity_measures)) for proj in self.projections]
        )
        measures = []
        for i, measure in enumerate(proximity_measures):
            if isinstance(measure, tuple):
                if measure[0] == "p_norm":
                    measures.append((proxs[:, i]).mean())
                else:
                    raise ValueError("Invalid proximity measure")
            elif isinstance(measure, str) and measure == "max_norm":
                measures.append(proxs[:, i].max())
            else:
                raise ValueError("Invalid proximity measure")
        return measures

205 

206 

class SequentialProjection(ProjectionMethod):
    """
    Sequential projection: sub-projections are applied one after another, each
    acting on the output of the previous one.

    Parameters
    ----------
    projections : List[Projection]
        The projections to chain together.
    relaxation : float, optional
        Relaxation parameter for the projection method, by default 1.
    control_seq : None, numpy.typing.ArrayLike, or List[int], optional
        Indices into `projections` giving the order of application. When None,
        the projections are applied in the order they were given, by default None.
    proximity_flag : bool
        Whether this object is taken into account when calculating proximity, by default True.

    Attributes
    ----------
    projections : List[Projection]
        The projections used by this method.
    all_x : array-like or None
        Stored iterates, filled only when `solve` is called with storage enabled.
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Whether this object is taken into account when calculating proximity.
    control_seq : npt.NDArray or List[int]
        Order (as indices into `projections`) in which the projections are applied.
    """

    def __init__(
        self,
        projections: List[Projection],
        relaxation: float = 1,
        control_seq: None | npt.NDArray | List[int] = None,
        proximity_flag: bool = True,
    ):
        super().__init__(projections, relaxation, proximity_flag)
        # Default order: every projection once, in list order.
        self.control_seq = (
            np.arange(len(projections)) if control_seq is None else control_seq
        )

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Push `x` through the projections listed in the control sequence.

        Parameters
        ----------
        x : npt.NDArray
            Point to project.

        Returns
        -------
        npt.NDArray
            Output of the last projection in the control sequence; `x` itself
            when the control sequence is empty.
        """
        current = x
        for position in self.control_seq:
            current = self.projections[position].project(current)
        return current

269 

270 

class SimultaneousProjection(ProjectionMethod):
    """
    Class to represent a simultaneous projection.

    Parameters
    ----------
    projections : List[Projection]
        A list of projection methods to be applied.
    weights : npt.NDArray, sequence of float, or None, optional
        Weights for each projection method, one per projection. If None, equal
        weights are assigned to each projection. Weights are normalized to sum
        up to 1. Default is None.
    relaxation : float, optional
        A relaxation parameter for the projection methods. Default is 1.
    proximity_flag : bool, optional
        A flag indicating whether to use proximity in the projection methods.
        Default is True.

    Attributes
    ----------
    projections : List[Projection]
        The list of Projection objects used in the projection method.
    all_x : array-like or None
        Storage for all x values if storage is enabled during solve.
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when calculating proximity.
    weights : npt.NDArray
        The normalized weights assigned to each projection method.

    Raises
    ------
    ValueError
        If the number of weights differs from the number of projections.

    Notes
    -----
    While the simultaneous projection is performed simultaneously mathematically, the actual computation right now is sequential.
    """

    def __init__(
        self,
        projections: List[Projection],
        weights: npt.NDArray | None = None,
        relaxation: float = 1,
        proximity_flag: bool = True,
    ):
        super().__init__(projections, relaxation, proximity_flag)
        if weights is None:
            weights = np.ones(len(projections)) / len(projections)
        elif not isinstance(weights, (np.ndarray, cp.ndarray)):
            # Accept plain sequences (e.g. lists); existing arrays are kept on
            # their device untouched.
            weights = np.asarray(weights, dtype=float)
        # zip() in _project would otherwise silently drop surplus projections
        # or weights.
        if len(weights) != len(projections):
            raise ValueError("Weights do not match the number of projections!")
        self.weights = weights / weights.sum()

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Simultaneously projects the input array `x`.

        Every projection receives its own copy of `x`. The results are combined
        with `self.weights`.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected.

        Returns
        -------
        npt.NDArray
            The projected array.
        """
        x_new = 0
        for proj, weight in zip(self.projections, self.weights):
            x_new = x_new + weight * proj.project(x.copy())
        return x_new

    def _proximity(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """
        Aggregate the proximities reported by all sub-projections.

        A ``("p_norm", ...)`` measure is the `self.weights`-weighted sum over
        sub-projections. ``"max_norm"`` is the maximum.

        Raises
        ------
        ValueError
            If a measure is neither a ``("p_norm", ...)`` tuple nor ``"max_norm"``.
        """
        xp = cp if isinstance(x, cp.ndarray) else np
        # Row j holds the measures of self.projections[j]; column i is measure i.
        proxs = xp.array(
            [xp.array(proj.proximity(x, proximity_measures)) for proj in self.projections]
        )
        measures = []
        for i, measure in enumerate(proximity_measures):
            if isinstance(measure, tuple):
                if measure[0] == "p_norm":
                    measures.append(self.weights @ (proxs[:, i]))
                else:
                    raise ValueError("Invalid proximity measure")
            elif isinstance(measure, str) and measure == "max_norm":
                measures.append(proxs[:, i].max())
            else:
                raise ValueError("Invalid proximity measure")
        return measures

354 

355 

class StringAveragedProjection(ProjectionMethod):
    """
    Class to represent a string averaged projection.

    Parameters
    ----------
    projections : List[Projection]
        A list of projection methods to be applied.
    strings : List[List]
        A list of strings, where each string is a list of indices of the projection methods to be applied.
    weights : npt.NDArray, sequence of float, or None, optional
        Weights for each string, one per string. If None, equal weights are
        assigned to each string. Weights are normalized to sum up to 1.
        Default is None.
    relaxation : float, optional
        A relaxation parameter for the projection methods. Default is 1.
    proximity_flag : bool, optional
        A flag indicating whether to use proximity in the projection methods.
        Default is True.

    Attributes
    ----------
    projections : List[Projection]
        The list of Projection objects used in the projection method.
    all_x : array-like or None
        Storage for all x values if storage is enabled during solve.
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when calculating proximity.
    strings : List[List]
        A list of strings, where each string is a list of indices of the projection methods to be applied.
    weights : npt.NDArray
        The normalized weights assigned to each string.

    Raises
    ------
    ValueError
        If explicit weights are given whose count differs from the number of strings.

    Notes
    -----
    While the string projections are performed simultaneously mathematically, the actual computation right now is sequential.
    """

    def __init__(
        self,
        projections: List[Projection],
        strings: List[List],
        weights: npt.NDArray | None = None,
        relaxation: float = 1,
        proximity_flag: bool = True,
    ):
        super().__init__(projections, relaxation, proximity_flag)
        if weights is None:
            self.weights = np.ones(len(strings)) / len(strings)
        else:
            if not isinstance(weights, (np.ndarray, cp.ndarray)):
                # Accept plain sequences (e.g. lists); existing arrays are kept
                # on their device untouched.
                weights = np.asarray(weights, dtype=float)
            # zip() in _project would otherwise silently drop strings or weights.
            if len(weights) != len(strings):
                raise ValueError("Weights do not match the number of strings!")
            self.weights = weights / weights.sum()
        self.strings = strings

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        String averaged projection of the input array `x`.

        Each string starts from its own copy of `x` and applies its projections
        in order. The end points of all strings are combined with
        `self.weights`.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected.

        Returns
        -------
        npt.NDArray
            The weighted combination of the end points of all strings.
        """
        x_new = 0
        # TODO: Can this be parallelized?
        for weight, string in zip(self.weights, self.strings):
            # Each string starts from an independent copy so strings don't interact.
            x_s = x.copy()
            for el in string:  # run over all elements in the string sequentially
                x_s = self.projections[el].project(x_s)
            x_new += weight * x_s
        return x_new

433 

434 

class BlockIterativeProjection(ProjectionMethod):
    """
    Class to represent a block iterative projection.

    Each block is a weight vector over all projections. The blocks are processed
    in order. For each block, the current point is replaced by the weighted
    combination of the projections that have a positive weight in that block.

    Parameters
    ----------
    projections : List[Projection]
        A list of projection methods to be applied.
    weights : List[List[float]] | List[npt.NDArray]
        A list of weight vectors, one per block. Every vector needs one entry per
        projection and must sum to 1. Projections with a non-positive weight are
        not part of that block.
    relaxation : float, optional
        A relaxation parameter for the projection methods. Default is 1.
    proximity_flag : bool, optional
        A flag indicating whether to use proximity in the projection methods.
        Default is True.

    Attributes
    ----------
    projections : List[Projection]
        The list of Projection objects used in the projection method.
    all_x : array-like or None
        Storage for all x values if storage is enabled during solve.
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when calculating proximity.
    weights : List[npt.NDArray]
        For each block, the positive weights of that block.
    block_idxs : List[npt.NDArray]
        For each block, the indices of the projections with a positive weight.
    total_weights : npt.NDArray
        Mean of all block weight vectors; used to aggregate p-norm proximities.

    Raises
    ------
    ValueError
        If no blocks are given, a block's length differs from the number of
        projections, or a block's weights do not add up to 1.

    Notes
    -----
    While the individual block projections are performed simultaneously mathematically, the actual computation right now is sequential.
    """

    def __init__(
        self,
        projections: List[Projection],
        weights: List[List[float]] | List[npt.NDArray],
        relaxation: float = 1,
        proximity_flag: bool = True,
    ):
        super().__init__(projections, relaxation, proximity_flag)
        xp = cp if self._use_gpu else np

        if len(weights) == 0:
            raise ValueError("At least one block of weights is required!")

        # Convert every block to an array up front: the type hint allows plain
        # lists, which have neither .sum() nor boolean masking. Integer weights
        # are promoted to float so that total_weights (built with zeros_like
        # below) can hold the fractional averages.
        blocks = []
        for el in weights:
            el = xp.asarray(el)
            if not np.issubdtype(el.dtype, np.floating):
                el = el.astype(float)
            if len(el) != len(projections):
                raise ValueError("Weights do not match the number of projections!")
            if abs(float(el.sum()) - 1) > 1e-10:
                raise ValueError("Weights do not add up to 1!")
            blocks.append(el)

        # Keep only projections with a positive weight in each block
        # (zero weights are dropped).
        self.block_idxs = [xp.where(el > 0)[0] for el in blocks]
        self.weights = [el[el > 0] for el in blocks]

        # Average weight of every projection over all blocks.
        self.total_weights = xp.zeros_like(blocks[0])
        for el in blocks:
            self.total_weights += el / len(blocks)

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Apply all blocks in order.

        For each block, every member projection receives its own copy of the
        current point. The weighted combination becomes the starting point for
        the next block.
        """
        # TODO: Can this be parallelized?
        for block_weights, block_idx in zip(self.weights, self.block_idxs):
            x_new = 0
            for weight, proj_idx in zip(block_weights, block_idx):
                x_new += weight * self.projections[proj_idx].project(x.copy())
            x = x_new
        return x

    def _proximity(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """
        Aggregate the proximities reported by all sub-projections.

        A ``("p_norm", ...)`` measure is the `self.total_weights`-weighted sum
        over sub-projections. ``"max_norm"`` is the maximum.

        Raises
        ------
        ValueError
            If a measure is neither a ``("p_norm", ...)`` tuple nor ``"max_norm"``.
        """
        xp = cp if isinstance(x, cp.ndarray) else np
        # Row j holds the measures of self.projections[j]; column i is measure i.
        proxs = xp.array(
            [xp.array(proj.proximity(x, proximity_measures)) for proj in self.projections]
        )
        measures = []
        for i, measure in enumerate(proximity_measures):
            if isinstance(measure, tuple):
                if measure[0] == "p_norm":
                    measures.append(self.total_weights @ (proxs[:, i]))
                else:
                    raise ValueError("Invalid proximity measure")
            elif isinstance(measure, str) and measure == "max_norm":
                measures.append(proxs[:, i].max())
            else:
                raise ValueError("Invalid proximity measure")
        return measures

524 

525 

class MultiBallProjection(BasicProjection, ABC):
    """
    Projection onto multiple balls.

    Parameters
    ----------
    centers : npt.NDArray
        Centers of the balls (presumably shape (n_balls, dim) — TODO confirm).
    radii : npt.NDArray
        Radii of the balls, indexed like `centers`.
    relaxation : float, optional
        Relaxation parameter, by default 1.
    idx : npt.NDArray or None, optional
        Index forwarded to BasicProjection. Subclasses use it to select the
        components of x (``x[self.idx]``) that the balls act on.
    proximity_flag : bool, optional
        Whether to take this object into account when calculating proximity, by default True.

    Raises
    ------
    ValueError
        If exactly one of `centers` and `radii` is a CuPy (GPU) array.
    """

    def __init__(
        self,
        centers: npt.NDArray,
        radii: npt.NDArray,
        relaxation: float = 1,
        idx: npt.NDArray | None = None,
        proximity_flag=True,
    ):
        # Without CuPy, ``cp`` is an alias of NumPy, so a bare isinstance check
        # against ``cp.ndarray`` would classify NumPy arrays as GPU arrays.
        # Guarding with NO_GPU makes the result independent of whether CuPy is
        # installed.
        centers_on_gpu = not NO_GPU and isinstance(centers, cp.ndarray)
        radii_on_gpu = not NO_GPU and isinstance(radii, cp.ndarray)
        if centers_on_gpu != radii_on_gpu:
            raise ValueError("Mismatch between input types of centers and radii")
        _use_gpu = centers_on_gpu

        super().__init__(relaxation, idx, proximity_flag, _use_gpu)
        self.centers = centers
        self.radii = radii

550 

551 

class SequentialMultiBallProjection(MultiBallProjection):
    """
    Sequential projection onto multiple balls.

    The balls are visited in order. Whenever the selected components
    ``x[self.idx]`` lie outside the current ball, they are moved radially onto
    its boundary.
    """

    def _project(self, x: npt.NDArray) -> np.ndarray:
        xp = cp if self._use_gpu else np
        for ball, center in enumerate(self.centers):
            radius = self.radii[ball]
            offset = x[self.idx] - center
            dist_to_center = xp.linalg.norm(offset)
            # Points already inside (or on) the ball are left untouched.
            if dist_to_center > radius:
                x[self.idx] = center + radius * offset / dist_to_center
        return x

563 

564 

class SimultaneousMultiBallProjection(MultiBallProjection):
    """
    Simultaneous projection onto multiple balls.

    Each violated ball contributes a weighted step from ``x[self.idx]`` towards
    its projection onto that ball. Balls that are already satisfied contribute
    nothing.

    Parameters
    ----------
    centers, radii, relaxation, idx, proximity_flag
        As for MultiBallProjection.
    weights : npt.NDArray or sequence of float
        One weight per ball. The weights are used as given (not normalized);
        presumably they should sum to 1 for a convex combination — confirm with
        callers.
    """

    def __init__(
        self,
        centers: npt.NDArray,
        radii: npt.NDArray,
        weights: npt.NDArray,
        relaxation: float = 1,
        idx: npt.NDArray | None = None,
        proximity_flag=True,
    ):

        super().__init__(centers, radii, relaxation, idx, proximity_flag)
        xp = cp if self._use_gpu else np
        # Keep the weights as an array on the same device as centers/radii, so
        # the boolean mask built in _project can index them. This also lets
        # callers pass a plain list.
        self.weights = xp.asarray(weights)

    def _project(self, x: npt.NDArray) -> np.ndarray:
        xp = cp if self._use_gpu else np
        # Offsets from every center to the selected components of x; computed
        # once and reused for both the distances and the update step.
        offsets = x[self.idx] - self.centers
        dists = xp.linalg.norm(offsets, axis=1)
        violated = (dists - self.radii) > 0
        # For a violated ball, x - P(x) = (x - c) * (1 - r / ||x - c||).
        step_scale = self.weights[violated] * (1 - self.radii[violated] / dists[violated])
        x[self.idx] = x[self.idx] - step_scale @ offsets[violated]
        return x