Coverage for suppy\projections\_projection_methods.py: 67%
199 statements
« prev ^ index » next coverage.py v7.6.4, created at 2026-05-08 13:56 +0200
« prev ^ index » next coverage.py v7.6.4, created at 2026-05-08 13:56 +0200
1"""
2General implementation for sequential, simultaneous, block iterative and
3string averaged projection methods.
4"""
5from abc import ABC
6from typing import List, Callable
7import numpy as np
8import numpy.typing as npt
10try:
11 import cupy as cp
13 NO_GPU = False
14except ImportError:
15 cp = np
16 NO_GPU = True
18from suppy.projections._projections import Projection, BasicProjection
19from suppy.utils import ensure_float_array
class ProjectionMethod(Projection, ABC):
    """
    A class used to represent methods for projecting a point onto multiple
    sets.

    Parameters
    ----------
    projections : List[Projection]
        A list of Projection objects to be used in the projection method.
    relaxation : float, optional
        A relaxation parameter for the projection method (default is 1).
    proximity_flag : bool
        Flag to indicate whether to take this object into account when calculating proximity, by default True.

    Attributes
    ----------
    projections : List[Projection]
        The list of Projection objects used in the projection method.
    all_x : array-like or None
        Storage for all x values if storage is enabled during solve.
    proximities : list
        A list to store proximity values during the solve process.
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when calculating proximity.
    """

    def __init__(
        self, projections: List[Projection], relaxation: float = 1, proximity_flag: bool = True
    ):
        super().__init__(relaxation, proximity_flag)
        self.projections = projections
        self.all_x = None
        self.proximities = []
        # Number of consecutive iterations whose proximity change stayed below
        # the tolerance; reset at the start of every solve() call.
        self._n_tol = 0

    def visualize(self, ax):
        """
        Visualizes all projection objects (if applicable) on the given
        matplotlib axis.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            The matplotlib axis on which to visualize the projections.
        """
        for proj in self.projections:
            proj.visualize(ax)

    @staticmethod
    def _to_host_copy(x: npt.NDArray) -> np.ndarray:
        """Return an independent NumPy copy of `x` (GPU arrays are fetched via ``.get()``)."""
        if isinstance(x, np.ndarray):
            return np.array(x)  # np.array copies by default
        return x.get()

    @ensure_float_array
    def solve(
        self,
        x: npt.NDArray,
        max_iter: int = 500,
        prox_tol: float = 1e-6,
        del_prox_tol: float = 1e-8,
        del_prox_n: int = 5,
        proximity_measures: List | None = None,
        storage: bool = False,
        storage_iters: List[int] | int | None = None,
        alternative_stopping_criterion: Callable | None = None,
        alternative_stopping_criterion_initial_call: Callable | None = None,
    ) -> np.ndarray:
        """
        Solves the optimization problem using an iterative approach.

        Parameters
        ----------
        x : npt.NDArray
            Starting point for the algorithm.
        max_iter : int, optional
            Maximum number of iterations to perform, by default 500.
        prox_tol : float, optional
            The tolerance for the proximity on the constraints, by default 1e-6.
        del_prox_tol : float, optional
            The tolerance for the change in proximity over the last del_prox_n iterations, by default 1e-8.
        del_prox_n : int, optional
            The number of iterations that del_prox_tol needs to be met in a row, by default 5.
        proximity_measures : List, optional
            The proximity measures to calculate, by default a l2 norm measure is used. Right now only the first in the list is used to check the feasibility.
        storage : bool, optional
            Flag indicating whether to store intermediate solutions, by default False.
            When False, `all_x` is left as it was before the call.
        storage_iters : List[int] or int, optional
            Controls which iterations are stored (when storage=True). If None, all iterations are stored.
            If a list of ints, only those iteration indices are stored (0 = initial point).
            If an int, storage occurs every that many iterations.
        alternative_stopping_criterion : callable, optional
            Alternative stopping criterion. Called as ``criterion(x, self)`` after every
            iteration; replaces the built-in proximity based criterion.
        alternative_stopping_criterion_initial_call : callable, optional
            Initial call for an alternative stopping criterion. Called once as
            ``criterion(x, self)`` before the first iteration.

        Returns
        -------
        npt.NDArray
            The solution after the iterative process.

        Raises
        ------
        ValueError
            If storage is enabled and `storage_iters` is the integer 0.
        """
        xp = cp if isinstance(x, cp.ndarray) else np

        if proximity_measures is None:
            proximity_measures = [("p_norm", 2)]
        # TODO: Check if the proximity measures are valid

        stride_given = isinstance(storage_iters, (int, np.integer))
        if storage is True and stride_given and storage_iters == 0:
            # A zero stride would otherwise surface as a ZeroDivisionError below.
            raise ValueError("storage_iters must be a non-zero integer")

        def _should_store(idx):
            if storage_iters is None:
                return True
            if stride_given:
                return idx % storage_iters == 0
            return idx in storage_iters

        self.proximities = [self.proximity(x, proximity_measures)]

        if storage is True:
            self.all_x = []
            if _should_store(0):
                self.all_x.append(self._to_host_copy(x))

        if alternative_stopping_criterion_initial_call is not None:
            stop = alternative_stopping_criterion_initial_call(x, self)
        else:
            stop = False  # criterion for stopping the algorithm

        self._n_tol = 0
        i = 0
        while i < max_iter and not stop:
            x = self.project(x)
            if storage is True and _should_store(i + 1):
                self.all_x.append(self._to_host_copy(x))

            self.proximities.append(self.proximity(x, proximity_measures))

            # TODO: If proximity changes x some potential issues!
            if alternative_stopping_criterion is not None:
                stop = alternative_stopping_criterion(x, self)
            else:
                stop = self._stopping_criterion(prox_tol, del_prox_tol, del_prox_n)

            i += 1

        if self.all_x is not None:
            self.all_x = np.array(self.all_x)

        self.proximities = xp.array(self.proximities)
        return x

    def _stopping_criterion(self, prox_tol: float, del_prox_tol: float, del_prox_n: int) -> bool:
        """
        Returns True when convergence is detected, False otherwise.

        Only the first proximity measure is inspected. Convergence means that
        either the latest proximity is below `prox_tol`, or the (signed)
        decrease between the last two proximities stayed below `del_prox_tol`
        for `del_prox_n` consecutive calls.
        """
        if self.proximities[-1][0] < prox_tol:
            return True
        # The difference is signed, so an increase in proximity also counts
        # towards the stagnation counter.
        if self.proximities[-2][0] - self.proximities[-1][0] < del_prox_tol:
            self._n_tol += 1
        else:
            self._n_tol = 0
        return self._n_tol >= del_prox_n

    def _proximity(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """
        Combine the proximities of all projections for every requested measure.

        ``("p_norm", ...)`` tuples are averaged over the projections and
        ``"max_norm"`` takes the maximum; anything else raises ValueError.
        """
        xp = cp if isinstance(x, cp.ndarray) else np
        proxs = xp.array(
            [xp.array(proj.proximity(x, proximity_measures)) for proj in self.projections]
        )
        measures = []
        for i, measure in enumerate(proximity_measures):
            if isinstance(measure, tuple) and measure[0] == "p_norm":
                measures.append((proxs[:, i]).mean())
            elif isinstance(measure, str) and measure == "max_norm":
                measures.append(proxs[:, i].max())
            else:
                raise ValueError("Invalid proximity measure")
        return measures
class SequentialProjection(ProjectionMethod):
    """
    Projection method that applies its projections one after another.

    Parameters
    ----------
    projections : List[Projection]
        The projections to be applied sequentially.
    relaxation : float, optional
        Relaxation parameter of the method, by default 1.
    control_seq : None, numpy.typing.ArrayLike, or List[int], optional
        Sequence of indices into `projections` giving the order of application.
        If None, the projections are applied in the order they were given.
    proximity_flag : bool
        Whether this object is taken into account when calculating proximity, by default True.

    Attributes
    ----------
    projections : List[Projection]
        The list of Projection objects used in the projection method.
    all_x : array-like or None
        Storage for all x values if storage is enabled during solve.
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when calculating proximity.
    control_seq : npt.NDArray or List[int]
        The sequence in which the projections are applied.
    """

    def __init__(
        self,
        projections: List[Projection],
        relaxation: float = 1,
        control_seq: None | npt.NDArray | List[int] = None,
        proximity_flag: bool = True,
    ):
        super().__init__(projections, relaxation, proximity_flag)
        # Default order: every projection once, in the order given.
        self.control_seq = np.arange(len(projections)) if control_seq is None else control_seq

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Apply the projections selected by the control sequence, in order.

        Parameters
        ----------
        x : npt.NDArray
            The point to be projected.

        Returns
        -------
        npt.NDArray
            The point obtained after the last projection of the control sequence.
        """
        current = x
        for position in self.control_seq:
            current = self.projections[position].project(current)
        return current
class SimultaneousProjection(ProjectionMethod):
    """
    Projection method that averages the results of all projections.

    Parameters
    ----------
    projections : List[Projection]
        The projections to be applied.
    weights : npt.NDArray or None, optional
        One weight per projection. If None, all projections are weighted
        equally. The weights are normalized to sum up to 1. Default is None.
    relaxation : float, optional
        Relaxation parameter of the method. Default is 1.
    proximity_flag : bool, optional
        Whether this object is taken into account when calculating proximity.
        Default is True.

    Attributes
    ----------
    projections : List[Projection]
        The list of Projection objects used in the projection method.
    all_x : array-like or None
        Storage for all x values if storage is enabled during solve.
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when calculating proximity.
    weights : npt.NDArray
        The (normalized) weights assigned to each projection.

    Notes
    -----
    While the simultaneous projection is performed simultaneously mathematically, the actual computation right now is sequential.
    """

    def __init__(
        self,
        projections: List[Projection],
        weights: npt.NDArray | None = None,
        relaxation: float = 1,
        proximity_flag: bool = True,
    ):
        super().__init__(projections, relaxation, proximity_flag)
        n_projections = len(projections)
        raw_weights = np.ones(n_projections) / n_projections if weights is None else weights
        # Normalize so that the weights form a convex combination.
        self.weights = raw_weights / raw_weights.sum()

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Return the weighted sum of the individual projections of `x`.

        Parameters
        ----------
        x : npt.NDArray
            The point to be projected.

        Returns
        -------
        npt.NDArray
            The weighted combination of all projections.
        """
        # Each projection receives its own copy of x.
        return sum(
            weight * proj.project(x.copy()) for proj, weight in zip(self.projections, self.weights)
        )

    def _proximity(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """
        Combine the proximities of all projections for every requested measure.

        ``("p_norm", ...)`` tuples are combined with the projection weights and
        ``"max_norm"`` takes the maximum; anything else raises ValueError.
        """
        xp = cp if isinstance(x, cp.ndarray) else np
        per_projection = xp.array(
            [xp.array(proj.proximity(x, proximity_measures)) for proj in self.projections]
        )
        combined = []
        for column, measure in enumerate(proximity_measures):
            if isinstance(measure, tuple) and measure[0] == "p_norm":
                combined.append(self.weights @ (per_projection[:, column]))
            elif isinstance(measure, str) and measure == "max_norm":
                combined.append(per_projection[:, column].max())
            else:
                raise ValueError("Invalid proximity measure")
        return combined
class StringAveragedProjection(ProjectionMethod):
    """
    Projection method that averages the end points of several strings.

    Parameters
    ----------
    projections : List[Projection]
        The projections to be applied.
    strings : List[List]
        A list of strings, where each string is a list of indices into `projections`
        that are applied one after another.
    weights : npt.NDArray or None, optional
        One weight per string. If None, all strings are weighted equally.
        Weights are normalized to sum up to 1. Default is None.
    relaxation : float, optional
        Relaxation parameter of the method. Default is 1.
    proximity_flag : bool, optional
        Whether this object is taken into account when calculating proximity.
        Default is True.

    Attributes
    ----------
    projections : List[Projection]
        The list of Projection objects used in the projection method.
    all_x : array-like or None
        Storage for all x values if storage is enabled during solve.
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when calculating proximity.
    strings : List[List]
        A list of strings, where each string is a list of indices of the projection methods to be applied.
    weights : npt.NDArray
        The weights assigned to each string.

    Notes
    -----
    While the string projections are performed simultaneously mathematically, the actual computation right now is sequential.
    """

    def __init__(
        self,
        projections: List[Projection],
        strings: List[List],
        weights: npt.NDArray | None = None,
        relaxation: float = 1,
        proximity_flag: bool = True,
    ):
        super().__init__(projections, relaxation, proximity_flag)
        self.strings = strings
        n_strings = len(strings)
        self.weights = np.ones(n_strings) / n_strings if weights is None else weights / weights.sum()

    def _run_string(self, x: npt.NDArray, string: List) -> np.ndarray:
        """Apply the projections referenced by `string` to `x`, one after another."""
        for index in string:
            x = self.projections[index].project(x)
        return x

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Return the weighted sum of the end points of all strings started at `x`.

        Parameters
        ----------
        x : npt.NDArray
            The point to be projected.

        Returns
        -------
        npt.NDArray
            The weighted combination of the string end points.
        """
        # TODO: Can this be parallelized?
        # Every string starts from its own copy of x.
        return sum(
            weight * self._run_string(x.copy(), string)
            for weight, string in zip(self.weights, self.strings)
        )
class BlockIterativeProjection(ProjectionMethod):
    """
    Class to represent a block iterative projection.

    Parameters
    ----------
    projections : List[Projection]
        A list of projection methods to be applied.
    weights : List[List[float]] | List[npt.NDArray]
        A List of weights for each block of projection methods. Every entry
        must have one weight per projection and sum up to 1; projections
        with a weight of zero (or less) are not part of that block.
    relaxation : float, optional
        A relaxation parameter for the projection methods. Default is 1.
    proximity_flag : bool, optional
        A flag indicating whether to use proximity in the projection methods.
        Default is True.

    Attributes
    ----------
    projections : List[Projection]
        The list of Projection objects used in the projection method.
    all_x : array-like or None
        Storage for all x values if storage is enabled during solve.
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when calculating proximity.
    weights : List[npt.NDArray]
        The positive weights of each block of projection methods.
    block_idxs : List[npt.NDArray]
        For each block, the indices of the projections with a positive weight.
    total_weights : npt.NDArray
        Per-projection weights averaged over all blocks; used for the proximity.

    Raises
    ------
    ValueError
        If no block is given, if a block does not have one weight per
        projection, or if the weights of a block do not add up to 1.

    Notes
    -----
    While the individual block projections are performed simultaneously mathematically, the actual computation right now is sequential.
    """

    def __init__(
        self,
        projections: List[Projection],
        weights: List[List[float]] | List[npt.NDArray],
        relaxation: float = 1,
        proximity_flag: bool = True,
    ):
        super().__init__(projections, relaxation, proximity_flag)
        xp = cp if self._use_gpu else np

        if len(weights) == 0:
            raise ValueError("At least one block of weights is required!")

        # Convert every block to an array first so that plain Python lists are
        # validated the same way as arrays.
        blocks = []
        for el in weights:
            block = xp.asarray(el)
            if len(block) != len(projections):
                raise ValueError("Weights do not match the number of projections!")
            if abs(block.sum() - 1) > 1e-10:
                raise ValueError("Weights do not add up to 1!")
            blocks.append(block)

        # Only projections with a positive weight take part in a block.
        self.block_idxs = [xp.where(block > 0)[0] for block in blocks]
        self.weights = [block[block > 0] for block in blocks]

        # Average the per-projection weights over all blocks. True division
        # yields a floating point result even for integer weights.
        n_blocks = len(blocks)
        self.total_weights = sum(block / n_blocks for block in blocks)

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Apply the blocks one after another; within a block the projections of
        the current point are combined with the block weights.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected.

        Returns
        -------
        npt.NDArray
            The point obtained after the last block.
        """
        # TODO: Can this be parallelized?
        for weight, block_idx in zip(self.weights, self.block_idxs):
            x_new = 0
            for i, el in enumerate(block_idx):
                # every projection of a block starts from its own copy of x
                x_new += weight[i] * self.projections[el].project(x.copy())
            x = x_new
        return x

    def _proximity(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """
        Combine the proximities of all projections for every requested measure.

        ``("p_norm", ...)`` tuples are combined with `total_weights` and
        ``"max_norm"`` takes the maximum; anything else raises ValueError.
        """
        xp = cp if isinstance(x, cp.ndarray) else np
        proxs = xp.array(
            [xp.array(proj.proximity(x, proximity_measures)) for proj in self.projections]
        )
        measures = []
        for i, measure in enumerate(proximity_measures):
            if isinstance(measure, tuple) and measure[0] == "p_norm":
                measures.append(self.total_weights @ (proxs[:, i]))
            elif isinstance(measure, str) and measure == "max_norm":
                measures.append(proxs[:, i].max())
            else:
                raise ValueError("Invalid proximity measure")
        return measures
class MultiBallProjection(BasicProjection, ABC):
    """
    Projection onto multiple balls.

    Parameters
    ----------
    centers : npt.NDArray
        The centers of the balls.
    radii : npt.NDArray
        The radii of the balls.
    relaxation : float, optional
        Relaxation parameter passed on to the base class, by default 1.
    idx : npt.NDArray or None, optional
        Index passed on to the base class, by default None.
    proximity_flag : bool, optional
        Proximity flag passed on to the base class, by default True.

    Attributes
    ----------
    centers : npt.NDArray
        The centers of the balls.
    radii : npt.NDArray
        The radii of the balls.

    Raises
    ------
    ValueError
        If exactly one of `centers` and `radii` is a cupy array.
    """

    def __init__(
        self,
        centers: npt.NDArray,
        radii: npt.NDArray,
        relaxation: float = 1,
        idx: npt.NDArray | None = None,
        proximity_flag=True,
    ):
        # When cupy is not installed `cp` is an alias of numpy, so an
        # isinstance check against cp.ndarray alone would flag plain NumPy
        # inputs as GPU data. Only trust the check if cupy is really there.
        gpu_available = cp is not np
        centers_on_gpu = gpu_available and isinstance(centers, cp.ndarray)
        radii_on_gpu = gpu_available and isinstance(radii, cp.ndarray)
        if centers_on_gpu != radii_on_gpu:
            raise ValueError("Mismatch between input types of centers and radii")
        _use_gpu = centers_on_gpu

        super().__init__(relaxation, idx, proximity_flag, _use_gpu)
        self.centers = centers
        self.radii = radii
class SequentialMultiBallProjection(MultiBallProjection):
    """Projects onto the balls one after another, in the order of `centers`."""

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Visit every ball in turn and, if the selected entries of `x` lie
        outside of it, move them onto its surface. `x` is modified in place
        and returned.
        """
        xp = cp if self._use_gpu else np
        n_balls = len(self.centers)
        for ball in range(n_balls):
            offset = x[self.idx] - self.centers[ball]
            distance = xp.linalg.norm(offset)
            if distance > self.radii[ball]:
                # rescale the offset so that the point ends up on the sphere
                scaled_offset = self.radii[ball] * offset / distance
                x[self.idx] = self.centers[ball] + scaled_offset
        return x
class SimultaneousMultiBallProjection(MultiBallProjection):
    """Projects onto all violated balls at once, using one weight per ball."""

    def __init__(
        self,
        centers: npt.NDArray,
        radii: npt.NDArray,
        weights: npt.NDArray,
        relaxation: float = 1,
        idx: npt.NDArray | None = None,
        proximity_flag=True,
    ):
        super().__init__(centers, radii, relaxation, idx, proximity_flag)
        # stored as given; no normalization is applied here
        self.weights = weights

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Move the selected entries of `x` by the weighted sum of the steps
        towards every ball they currently lie outside of. `x` is modified in
        place and returned.
        """
        xp = cp if self._use_gpu else np
        point = x[self.idx]
        dists = xp.linalg.norm(point - self.centers, axis=1)
        violated = (dists - self.radii) > 0
        # per-ball step length: weight times the relative overshoot
        steps = self.weights[violated] * (1 - self.radii[violated] / dists[violated])
        x[self.idx] = point - steps @ (point - self.centers[violated])
        return x