Coverage for suppy\projections\_projections.py: 87%
46 statements
« prev ^ index » next coverage.py v7.6.4, created at 2026-05-08 13:56 +0200
1"""Base classes for all projection objects."""
2from abc import ABC, abstractmethod
3import numpy as np
4import numpy.typing as npt
# Optional GPU support: use CuPy when it can be imported. Otherwise alias
# ``cp`` to NumPy, so module code that refers to ``cp`` (for example
# ``isinstance(x, cp.ndarray)``) keeps working on CPU-only installs.
# ``NO_GPU`` records which of the two cases applied.
try:
    import cupy as cp
except ImportError:
    cp = np
    NO_GPU = True
else:
    NO_GPU = False
class Projection(ABC):
    """
    Abstract base class for projections used in feasibility algorithms.

    Subclasses implement ``_project`` (the projection onto the set) and
    ``_proximity`` (the proximity measures of a point to the set). This base
    class adds optional relaxation and the ``proximity_flag`` switch on top.

    Parameters
    ----------
    relaxation : float, optional
        Relaxation parameter for the projection, by default 1.
    proximity_flag : bool, optional
        Flag to indicate whether to take this object into account when
        calculating proximity, by default True.
    _use_gpu : bool, optional
        Private flag stored as ``self._use_gpu``, by default False. It is not
        read by this base class.

    Attributes
    ----------
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when
        calculating proximity.
    """

    def __init__(self, relaxation=1, proximity_flag=True, _use_gpu=False):
        self.relaxation = relaxation
        self.proximity_flag = proximity_flag
        self._use_gpu = _use_gpu

    # NOTE: an @ensure_float_array decorator used to wrap this method; it was
    # removed because it led to unwanted behavior.
    def step(self, x: npt.NDArray) -> npt.NDArray:
        """
        Perform the (possibly relaxed) projection of input array 'x' onto
        the constraint.

        This currently delegates directly to :meth:`project`.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected.

        Returns
        -------
        npt.NDArray
            The (possibly relaxed) projection of 'x' onto the constraint.
        """
        return self.project(x)

    def project(self, x: npt.NDArray) -> npt.NDArray:
        """
        Perform the (possibly relaxed) projection of input array 'x' onto
        the constraint.

        With relaxation ``r`` and ``P = self._project`` the result is
        ``(1 - r) * x + r * P(x)``. For ``r == 1`` the result of ``P(x)`` is
        returned directly. If a subclass's ``_project`` modifies its input in
        place, ``x`` is modified as well.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected.

        Returns
        -------
        npt.NDArray
            The (possibly relaxed) projection of 'x' onto the constraint.
        """
        if self.relaxation == 1:
            return self._project(x)
        # Build the (1 - r) * x term BEFORE calling _project: a subclass may
        # project in place, and this term must use the unprojected x. The
        # multiplication already allocates a new array, so no explicit copy
        # is needed.
        unprojected_part = x * (1 - self.relaxation)
        return unprojected_part + self.relaxation * self._project(x)

    @abstractmethod
    def _project(self, x: npt.NDArray) -> npt.NDArray:
        """Internal method to project the point x onto the set."""

    def proximity(self, x: npt.NDArray, proximity_measures: list) -> npt.NDArray:
        """
        Calculate proximity measures of point `x` to the set.

        Parameters
        ----------
        x : npt.NDArray
            Input array for which the proximity measures are to be calculated.
        proximity_measures : list
            Proximity measures to calculate; passed through to
            ``_proximity``.

        Returns
        -------
        npt.NDArray
            One value per entry of `proximity_measures`. If
            ``proximity_flag`` is False, an array of zeros of that length is
            returned instead. The result is a CuPy array when `x` is a CuPy
            array, and a NumPy array otherwise.
        """
        xp = cp if isinstance(x, cp.ndarray) else np
        if not self.proximity_flag:
            # Excluded from proximity calculations: report zero for every
            # requested measure.
            return xp.zeros(len(proximity_measures))
        return xp.array(self._proximity(x, proximity_measures))

    @abstractmethod
    def _proximity(self, x: npt.NDArray, proximity_measures: list) -> list[float]:
        """
        Calculate proximity measures of point `x` to set.

        Parameters
        ----------
        x : npt.NDArray
            Input array for which the proximity measures are to be calculated.
        proximity_measures : list
            list of proximity measures to calculate.

        Returns
        -------
        list[float]
            The proximity measures of the input array `x`.
        """
class BasicProjection(Projection, ABC):
    """
    Abstract base class for projections acting on a subset of the input.

    Extends :class:`Projection` with an index ``idx`` that selects the
    entries of the input array the projection is applied on.

    Parameters
    ----------
    relaxation : float, optional
        Relaxation parameter for the projection, by default 1.
    idx : npt.NDArray, slice or None, optional
        Indices to apply the projection on, by default None (all entries).
    proximity_flag : bool, optional
        Flag to indicate whether to take this object into account when
        calculating proximity, by default True.
    _use_gpu : bool, optional
        Private flag forwarded to :class:`Projection`, by default False.

    Attributes
    ----------
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when
        calculating proximity.
    idx : npt.NDArray or slice
        Subset of the input vector to apply the projection on. When no
        ``idx`` is given this is ``np.s_[:]``, i.e. every entry.

    Notes
    -----
    ``project`` is inherited unchanged from :class:`Projection`. Its relaxed
    update is applied to the whole array, not only to the entries selected
    by ``idx``.
    """

    def __init__(
        self,
        relaxation=1,
        idx: npt.NDArray | slice | None = None,
        proximity_flag=True,
        _use_gpu=False,
    ):
        super().__init__(relaxation, proximity_flag, _use_gpu)
        # Without explicit indices, select the whole array.
        self.idx = idx if idx is not None else np.s_[:]

    def _proximity(self, x: npt.NDArray, proximity_measures: list) -> list[float]:
        """
        Calculate proximity measures of `x` based on its distance to the set.

        The distance is the Euclidean norm of ``x - P(x)`` restricted to
        ``self.idx``, where ``P`` is ``self._project``.

        Parameters
        ----------
        x : npt.NDArray
            Input array for which the proximity measures are to be calculated.
        proximity_measures : list
            Requested measures. Each entry is either a tuple
            ``("p_norm", p)``, which yields ``distance ** p``, or the string
            ``"max_norm"``, which yields ``distance`` itself.

        Returns
        -------
        list[float]
            One value per entry of `proximity_measures`, in the same order.

        Raises
        ------
        ValueError
            If an entry of `proximity_measures` is not a recognized measure.
        """
        # _project receives a copy so that x itself is left untouched.
        residual = x[self.idx] - self._project(x.copy())[self.idx]
        # Euclidean distance over the selected entries.
        # TODO: allow choosing the distance used here.
        dist = (residual**2).sum() ** (1 / 2)
        measures = []
        for measure in proximity_measures:
            if isinstance(measure, tuple) and measure[0] == "p_norm":
                measures.append(dist ** measure[1])
            elif isinstance(measure, str) and measure == "max_norm":
                measures.append(dist)
            else:
                raise ValueError("Invalid proximity measure")
        return measures