Coverage for suppy\superiorization\_split_sup.py: 86%
173 statements
« prev ^ index » next coverage.py v7.6.4, created at 2026-05-08 13:56 +0200
1from typing import List, Callable
3import numpy as np
4import numpy.typing as npt
5from suppy.feasibility._bands._ams_extrapolations import AdaptiveStepLandweberHyperslab
6from suppy.utils import ensure_float_array
7from suppy.perturbations import Perturbation, DummyPerturbation
8from ._sup import FeasibilityPerturbation
# Optional GPU support: use CuPy when it is importable, otherwise alias the
# ``cp`` name to NumPy so the rest of the module can reference it uniformly.
try:
    import cupy as cp
except ImportError:
    cp = np
    NO_GPU = True
else:
    NO_GPU = False
class SplitSuperiorization(FeasibilityPerturbation):
    """
    Superiorization for split problems.

    Alternates perturbation (objective-reduction) steps in the input space
    (``x``) and in the target space (``y``) with the basic step of a split
    problem.

    Parameters
    ----------
    basic : object
        An instance of a split problem. ``solve`` uses its ``map(x)``,
        ``step(x, y) -> (x, y)`` and ``proximity(x, proximity_measures)``
        methods.
    input_perturbation_scheme : Perturbation or None, optional
        Perturbation scheme for the input ``x``, by default None (replaced
        by a ``DummyPerturbation``).
    target_perturbation_scheme : Perturbation or None, optional
        Perturbation scheme for the target ``y``, by default None (replaced
        by a ``DummyPerturbation``).

    Raises
    ------
    ValueError
        If neither perturbation scheme is provided.

    Attributes
    ----------
    input_perturbation_scheme : Perturbation
        Perturbation scheme for the input.
    target_perturbation_scheme : Perturbation
        Perturbation scheme for the target.
    input_f_k, target_f_k, p_k : None
        Initialized to None; not updated by this class.
    _k : int
        The current iteration number.
    all_x : list or np.ndarray
        Stored points (only when ``storage=True`` in ``solve``), after both
        the function reduction and the basic step of each iteration.
    all_function_values : list or np.ndarray
        ``[input objective, target objective]`` for the initial point and
        after every function reduction and basic step.
    proximities : list or np.ndarray
        Proximity values from ``basic.proximity`` for the same points as
        ``all_function_values``.
    all_x_function_reduction, all_function_values_function_reduction, proximities_function_reduction
        The same quantities restricted to the function reduction steps.
    all_x_basic, all_function_values_basic, proximities_basic
        The same quantities restricted to the basic steps.

    All history lists are converted to ``np.ndarray`` at the end of ``solve``.
    """

    def __init__(
        self,
        basic,  # needs to be a split problem
        input_perturbation_scheme: Perturbation | None = None,
        target_perturbation_scheme: Perturbation | None = None,
    ):
        super().__init__(basic)
        if input_perturbation_scheme is None and target_perturbation_scheme is None:
            raise ValueError(
                "At least one perturbation scheme must be provided for SplitSuperiorization."
            )

        # A missing scheme is replaced by a DummyPerturbation (presumably a
        # no-op) so both spaces can be handled uniformly in ``solve``.
        self.input_perturbation_scheme = (
            input_perturbation_scheme
            if input_perturbation_scheme is not None
            else DummyPerturbation()
        )
        self.target_perturbation_scheme = (
            target_perturbation_scheme
            if target_perturbation_scheme is not None
            else DummyPerturbation()
        )

        # initialize some variables for the algorithms
        self.input_f_k = None
        self.target_f_k = None
        self.p_k = None
        self._k = 0

        self._reset_history()

    @ensure_float_array
    def solve(
        self,
        x: npt.NDArray,
        max_iter: int = 10,
        prox_tol: float = 1e-6,
        del_prox_tol: float = 1e-8,
        del_prox_n: int = 5,
        proximity_measures: List | None = None,
        del_input_objective_tol: float = 1e-6,
        del_input_objective_n: int = 5,
        del_target_objective_tol: float = 1e-6,
        del_target_objective_n: int = 5,
        storage: bool = False,
        storage_iters: List[int] | int | None = None,
        alternative_stopping_criterion: Callable | None = None,
        alternative_stopping_criterion_initial_call: Callable | None = None,
    ) -> np.ndarray:
        """
        Solve the split problem using the superiorization method.

        Each iteration performs, in order: the input perturbation step on
        ``x``, the target perturbation step on ``y``, the basic step
        ``basic.step(x, y)`` and the evaluation of the stopping criterion.

        Parameters
        ----------
        x : npt.NDArray
            Starting point for the algorithm.
        max_iter : int, optional
            Maximum number of iterations to perform, by default 10.
        prox_tol : float, optional
            The tolerance for the proximity on the constraints, by default 1e-6.
        del_prox_tol : float, optional
            The tolerance for the relative change in proximity between
            consecutive iterations, by default 1e-8.
        del_prox_n : int, optional
            The number of consecutive iterations for which del_prox_tol needs
            to be met, by default 5.
        proximity_measures : List, optional
            The proximity measures to calculate, by default ``[("p_norm", 2)]``.
            Right now only the first in the list is used to check the
            feasibility.
        del_input_objective_tol : float, optional
            The tolerance for the relative change in the input objective
            function between consecutive iterations, by default 1e-6.
        del_input_objective_n : int, optional
            The number of consecutive iterations for which
            del_input_objective_tol needs to be met, by default 5.
        del_target_objective_tol : float, optional
            The tolerance for the relative change in the target objective
            function between consecutive iterations, by default 1e-6.
        del_target_objective_n : int, optional
            The number of consecutive iterations for which
            del_target_objective_tol needs to be met, by default 5.
        storage : bool, optional
            Flag indicating whether to store intermediate solutions, by default False.
        storage_iters : List[int] or int, optional
            Controls which iterations are stored (when storage=True). If None,
            all iterations are stored. If a list of ints, only those iteration
            indices are stored (0 = initial point). If an int, storage occurs
            every that many iterations.
        alternative_stopping_criterion : callable, optional
            Called as ``alternative_stopping_criterion(x, self)`` after every
            iteration; its return value replaces the built-in criterion.
        alternative_stopping_criterion_initial_call : callable, optional
            Called once as ``alternative_stopping_criterion_initial_call(x, self)``
            before the first iteration; a truthy result skips the loop.

        Returns
        -------
        npt.NDArray
            The superiorized solution after performing the superiorization method.
        """
        # Membership tests for an explicit list of iterations: build the set once.
        if storage_iters is None or isinstance(storage_iters, int):
            selected_iters = None
        else:
            selected_iters = set(storage_iters)

        def _should_store(idx: int) -> bool:
            """Return whether iteration ``idx`` (0 = initial point) is stored."""
            if storage_iters is None:
                return True
            if isinstance(storage_iters, int):
                return idx % storage_iters == 0
            return idx in selected_iters

        if proximity_measures is None:
            proximity_measures = [("p_norm", 2)]
        # TODO: check that user-supplied proximity measures are valid

        self.input_perturbation_scheme.reset()
        self.target_perturbation_scheme.reset()

        # consecutive iterations with relative changes below their thresholds
        self._n_tol_input_objective = 0
        self._n_tol_target_objective = 0
        self._n_tol_prox = 0

        # NOTE(review): not updated anywhere in this class; kept for compatibility.
        self.t = [0]
        self.l = []

        x_0 = x.copy()

        self._initial_storage(
            x,
            storage and _should_store(0),
            self._objective_values(x_0, self.basic.map(x_0)),
            self.basic.proximity(x_0, proximity_measures),
        )

        self._k = 0  # reset counter if necessary

        if alternative_stopping_criterion_initial_call is not None:
            stop = alternative_stopping_criterion_initial_call(x, self)
        else:
            stop = False  # criterion for stopping the algorithm

        y = None  # target-space point; set on the first iteration

        while self._k < max_iter and not stop:
            # perturbation (function reduction) steps: pre steps and perturbations
            self.input_perturbation_scheme.pre_step(x, **self._last_values(0))
            x = self.input_perturbation_scheme.perturbation_step(x)

            if y is None:
                # The target scheme needs a real target-space point; start it
                # from the image of the (perturbed) input instead of None.
                y = self.basic.map(x)

            self.target_perturbation_scheme.pre_step(y, **self._last_values(1))
            y = self.target_perturbation_scheme.perturbation_step(y)

            # post steps
            self.input_perturbation_scheme.post_step(x, **self._last_values(0))
            self.target_perturbation_scheme.post_step(y, **self._last_values(1))

            store_this_iter = storage and _should_store(self._k + 1)

            self.storage(
                x,
                kind="function_reduction",
                storage=store_this_iter,
                f=self._objective_values(x, y),
                p=self.basic.proximity(x, proximity_measures),
            )

            # perform basic step
            x, y = self.basic.step(x, y)

            self.storage(
                x,
                kind="basic",
                storage=store_this_iter,
                f=self._objective_values(x, y),
                p=self.basic.proximity(x, proximity_measures),
            )

            self._k += 1

            # enable different stopping criteria for different superiorization algorithms
            if alternative_stopping_criterion is not None:
                stop = alternative_stopping_criterion(x, self)
            else:
                stop = self._stopping_criterion(
                    del_input_objective_tol,
                    del_input_objective_n,
                    del_target_objective_tol,
                    del_target_objective_n,
                    prox_tol,
                    del_prox_tol,
                    del_prox_n,
                )

            self._additional_action(x, y)

        self._post_step(x)

        return x

    def _last_values(self, index: int) -> dict:
        """
        Collect the latest stored values for one space as hook keyword arguments.

        Parameters
        ----------
        index : int
            0 for the input scheme, 1 for the target scheme.

        Returns
        -------
        dict
            Keyword arguments for the perturbation ``pre_step``/``post_step``.
        """
        return {
            "last_proximity": self.proximities[-1][index],
            "last_proximity_basic": self.proximities_basic[-1][index],
            "last_proximity_function_reduction": self.proximities_function_reduction[-1][
                index
            ],
            "last_function_value": self.all_function_values[-1][index],
            "last_function_value_basic": self.all_function_values_basic[-1][index],
        }

    def _objective_values(self, x: npt.NDArray, y) -> list:
        """Return ``[input objective at x, target objective at y]``."""
        return [
            self.input_perturbation_scheme.func(x),
            self.target_perturbation_scheme.func(y),
        ]

    @staticmethod
    def _relative_change(old, new):
        """
        Return ``|old - new| / max(1, |old|)``.

        This is the relative change for ``|old| > 1`` and the absolute change
        otherwise. Using ``|old|`` keeps it relative for negative values too.
        """
        return abs(old - new) / max(1, abs(old))

    def _stopping_criterion(
        self,
        del_input_objective_tol: float,
        del_input_objective_n: int,
        del_target_objective_tol: float,
        del_target_objective_n: int,
        prox_tol: float,
        del_prox_tol: float,
        del_prox_n: int,
    ) -> bool:
        """
        Evaluate the built-in stopping criterion after an iteration.

        Each iteration appends two history entries (function reduction, then
        basic step), so index ``-3`` is the previous iteration's basic-step
        entry and ``-1`` the current one.

        The objective criterion holds when the relative changes of the input
        AND the target objective have both been below their tolerances for at
        least ``del_input_objective_n`` / ``del_target_objective_n``
        consecutive iterations. The proximity criterion holds when the first
        target proximity measure is below ``prox_tol``, or its relative change
        has been below ``del_prox_tol`` for ``del_prox_n`` consecutive
        iterations.

        Returns
        -------
        bool
            True if both the objective and the proximity criteria are met.
        """
        previous_f = self.all_function_values[-3]
        current_f = self.all_function_values[-1]

        if self._relative_change(previous_f[0], current_f[0]) < del_input_objective_tol:
            self._n_tol_input_objective += 1
        else:
            self._n_tol_input_objective = 0

        if self._relative_change(previous_f[1], current_f[1]) < del_target_objective_tol:
            self._n_tol_target_objective += 1
        else:
            self._n_tol_target_objective = 0

        # n objective function changes in input AND target space below threshold
        stop_objective = (
            self._n_tol_input_objective >= del_input_objective_n
            and self._n_tol_target_objective >= del_target_objective_n
        )

        # first proximity measure of entry 1 (the value given to the target scheme)
        previous_p = self.proximities[-3][1][0]
        current_p = self.proximities[-1][1][0]

        if self._relative_change(previous_p, current_p) < del_prox_tol:
            self._n_tol_prox += 1
        else:
            self._n_tol_prox = 0

        stop_prox = current_p < prox_tol or self._n_tol_prox >= del_prox_n

        return bool(stop_objective and stop_prox)

    def _additional_action(self, x: npt.NDArray, y: npt.NDArray):
        """
        Per-iteration hook called at the end of every ``solve`` iteration.

        Does nothing here; subclasses may override it.

        Parameters
        ----------
        x : np.ndarray
            The current input-space point.
        y : np.ndarray
            The current target-space point.

        Returns
        -------
        None
        """

    @staticmethod
    def _to_host(values) -> list:
        """Return ``values`` as a list, copying any CuPy array to host memory."""
        return [
            el.get() if not NO_GPU and isinstance(el, cp.ndarray) else el
            for el in values
        ]

    @staticmethod
    def _x_to_host(x: npt.NDArray) -> np.ndarray:
        """Return a host (NumPy) copy of ``x``; non-NumPy arrays use ``.get()``."""
        return x.copy() if isinstance(x, np.ndarray) else x.get()

    def _reset_history(self):
        """Empty all stored points, objective values and proximity values."""
        self.all_x = []
        self.all_function_values = []
        self.proximities = []

        self.all_x_function_reduction = []
        self.all_function_values_function_reduction = []
        self.proximities_function_reduction = []

        self.all_x_basic = []
        self.all_function_values_basic = []
        self.proximities_basic = []

    def _initial_storage(self, x: npt.NDArray, storage: bool, f: list, p: list):
        """
        Reset all histories and record the values at the initial point.

        Parameters
        ----------
        x : np.ndarray
            Initial values of the variables.
        storage : bool
            Whether to store a copy of ``x``.
        f : list
            Initial ``[input, target]`` objective function values.
        p : list
            Initial proximity function values.
        """
        self._reset_history()

        f_host = self._to_host(f)
        p_host = self._to_host(p)

        # the initial point is the first entry of every history
        self.all_function_values.append(f_host)
        self.all_function_values_basic.append(f_host)
        self.all_function_values_function_reduction.append(f_host)

        self.proximities.append(p_host)
        self.proximities_basic.append(p_host)
        self.proximities_function_reduction.append(p_host)

        if storage:
            # separate copies so the three histories never share a buffer
            self.all_x.append(self._x_to_host(x))
            self.all_x_basic.append(self._x_to_host(x))
            self.all_x_function_reduction.append(self._x_to_host(x))

    def storage(
        self,
        x: npt.NDArray,
        kind: str,
        storage: bool = True,
        f: list | None = None,
        p: list | None = None,
    ):
        """
        Store the given values into the overall and the per-kind histories.

        Function and proximity values are always stored; ``x`` only when
        ``storage`` is True.

        Parameters
        ----------
        x : npt.NDArray
            The current value of the variable x to be stored.
        kind : str
            The type of storage to be used, either "function_reduction" or "basic".
        storage : bool, optional
            If True, store the values of x.
        f : list
            Objective function values; required despite the None default.
        p : list
            Proximity values; required despite the None default.

        Raises
        ------
        ValueError
            If ``kind`` is invalid (checked before anything is stored).
        """
        if kind == "function_reduction":
            x_history = self.all_x_function_reduction
            f_history = self.all_function_values_function_reduction
            p_history = self.proximities_function_reduction
        elif kind == "basic":
            x_history = self.all_x_basic
            f_history = self.all_function_values_basic
            p_history = self.proximities_basic
        else:
            raise ValueError("Invalid storage type. Use 'function_reduction' or 'basic'.")

        f_host = self._to_host(f)
        p_host = self._to_host(p)

        self.all_function_values.append(f_host)
        self.proximities.append(p_host)
        f_history.append(f_host)
        p_history.append(p_host)

        if storage:
            self.all_x.append(self._x_to_host(x))
            x_history.append(self._x_to_host(x))

    def _post_step(self, x: npt.NDArray):
        """
        Convert all history lists to ``np.ndarray`` once ``solve`` has finished.

        Parameters
        ----------
        x : array-like
            The final value of the variable x (unused).

        Returns
        -------
        None
        """
        for name in (
            "all_x",
            "all_x_function_reduction",
            "all_x_basic",
            "all_function_values",
            "all_function_values_function_reduction",
            "all_function_values_basic",
            "proximities",
            "proximities_function_reduction",
            "proximities_basic",
        ):
            setattr(self, name, np.array(getattr(self, name)))