Coverage for suppy\superiorization\_split_sup.py: 86%

173 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2026-05-08 13:56 +0200

1from typing import List, Callable 

2 

3import numpy as np 

4import numpy.typing as npt 

5from suppy.feasibility._bands._ams_extrapolations import AdaptiveStepLandweberHyperslab 

6from suppy.utils import ensure_float_array 

7from suppy.perturbations import Perturbation, DummyPerturbation 

8from ._sup import FeasibilityPerturbation 

9 

10try: 

11 import cupy as cp 

12 

13 NO_GPU = False 

14except ImportError: 

15 cp = np 

16 NO_GPU = True 

17 

18 

class SplitSuperiorization(FeasibilityPerturbation):
    """
    Superiorization for split feasibility problems.

    Each iteration perturbs the input-space iterate ``x`` and the target-space
    iterate ``y`` with their own perturbation schemes and records objective and
    proximity values (the "function_reduction" stage). It then performs one
    basic step of the split problem, ``(x, y) = basic.step(x, y)``, and records
    the values again (the "basic" stage).

    Parameters
    ----------
    basic : object
        An instance of a split problem. This class calls ``basic.map(x)``,
        ``basic.proximity(x, proximity_measures)`` and
        ``basic.step(x, y) -> (x, y)`` on it.
    input_perturbation_scheme : Perturbation or None, optional
        Perturbation scheme for the input space, by default None. If None, a
        ``DummyPerturbation`` is used instead.
    target_perturbation_scheme : Perturbation or None, optional
        Perturbation scheme for the target space, by default None. If None, a
        ``DummyPerturbation`` is used instead.

    Raises
    ------
    ValueError
        If both perturbation schemes are None.

    Attributes
    ----------
    input_perturbation_scheme : Perturbation
        Perturbation scheme for the input space.
    target_perturbation_scheme : Perturbation
        Perturbation scheme for the target space.
    input_f_k : None
        Initialised to None; not updated by the methods of this class.
    target_f_k : None
        Initialised to None; not updated by the methods of this class.
    p_k : None
        Initialised to None; not updated by the methods of this class.
    _k : int
        The current iteration number.
    all_x : list or np.ndarray
        Host copies of the iterate after both stages. Only filled when
        ``solve(..., storage=True)`` is used.
    all_function_values : list or np.ndarray
        ``[input objective, target objective]`` for the initial point and
        after every stage.
    proximities : list or np.ndarray
        Output of ``basic.proximity`` for the initial point and after every
        stage. Index 0 is handed to the input scheme, index 1 to the target
        scheme.
    all_x_function_reduction : list or np.ndarray
        Like ``all_x``, restricted to the initial point and perturbation stages.
    all_function_values_function_reduction : list or np.ndarray
        Like ``all_function_values``, restricted to the initial point and
        perturbation stages.
    proximities_function_reduction : list or np.ndarray
        Like ``proximities``, restricted to the initial point and perturbation
        stages.
    all_x_basic : list or np.ndarray
        Like ``all_x``, restricted to the initial point and basic steps.
    all_function_values_basic : list or np.ndarray
        Like ``all_function_values``, restricted to the initial point and
        basic steps.
    proximities_basic : list or np.ndarray
        Like ``proximities``, restricted to the initial point and basic steps.

    All ``all_*`` and ``proximities*`` lists are converted to NumPy arrays
    once ``solve`` finishes.
    """

    def __init__(
        self,
        basic,  # needs to be a split problem
        input_perturbation_scheme: Perturbation | None = None,
        target_perturbation_scheme: Perturbation | None = None,
    ):
        super().__init__(basic)
        if input_perturbation_scheme is None and target_perturbation_scheme is None:
            raise ValueError(
                "At least one perturbation scheme must be provided for SplitSuperiorization."
            )

        # a missing scheme is replaced by a no-op so the loop can treat both spaces alike
        self.input_perturbation_scheme = (
            input_perturbation_scheme
            if input_perturbation_scheme is not None
            else DummyPerturbation()
        )
        self.target_perturbation_scheme = (
            target_perturbation_scheme
            if target_perturbation_scheme is not None
            else DummyPerturbation()
        )

        # initialize some variables for the algorithms
        self.input_f_k = None
        self.target_f_k = None
        self.p_k = None
        self._k = 0

        self.all_x = []
        self.all_function_values = []
        self.proximities = []

        self.all_x_function_reduction = []
        self.all_function_values_function_reduction = []
        self.proximities_function_reduction = []

        self.all_x_basic = []
        self.all_function_values_basic = []
        self.proximities_basic = []

    @ensure_float_array
    def solve(
        self,
        x: npt.NDArray,
        max_iter: int = 10,
        prox_tol: float = 1e-6,
        del_prox_tol: float = 1e-8,
        del_prox_n: int = 5,
        proximity_measures: List | None = None,
        del_input_objective_tol: float = 1e-6,
        del_input_objective_n: int = 5,
        del_target_objective_tol: float = 1e-6,
        del_target_objective_n: int = 5,
        storage: bool = False,
        storage_iters: List[int] | int | None = None,
        alternative_stopping_criterion: Callable | None = None,
        alternative_stopping_criterion_initial_call: Callable | None = None,
    ) -> np.ndarray:
        """
        Solve the split problem using the superiorization method.

        Parameters
        ----------
        x : npt.NDArray
            Starting point for the algorithm (input space).
        max_iter : int, optional
            Maximum number of iterations to perform, by default 10.
        prox_tol : float, optional
            Tolerance for the target-space proximity (first measure), by
            default 1e-6.
        del_prox_tol : float, optional
            Tolerance for the relative change in proximity between two
            iterations, by default 1e-8.
        del_prox_n : int, optional
            Number of consecutive iterations ``del_prox_tol`` must be met,
            by default 5.
        proximity_measures : List, optional
            Proximity measures to calculate, by default ``[("p_norm", 2)]``
            (an l2 norm). The default stopping criterion only uses the
            target-space value of the first measure.
        del_input_objective_tol : float, optional
            Tolerance for the relative change of the input objective between
            two iterations, by default 1e-6.
        del_input_objective_n : int, optional
            Number of consecutive iterations ``del_input_objective_tol`` must
            be met, by default 5.
        del_target_objective_tol : float, optional
            Tolerance for the relative change of the target objective between
            two iterations, by default 1e-6.
        del_target_objective_n : int, optional
            Number of consecutive iterations ``del_target_objective_tol`` must
            be met, by default 5.
        storage : bool, optional
            Whether to store intermediate iterates, by default False.
            Objective and proximity values are always recorded.
        storage_iters : List[int] or int, optional
            Controls which iterations are stored when ``storage=True``.

            - If None, all iterations are stored.
            - If a list of ints, only those iteration indices are stored
              (0 = initial point).
            - If an int, iteration indices divisible by it are stored.
        alternative_stopping_criterion : callable, optional
            Called as ``criterion(x, self)`` after every iteration in place
            of the default criterion. A truthy return value stops the loop.
        alternative_stopping_criterion_initial_call : callable, optional
            Called as ``criterion(x, self)`` once before the first iteration.
            A truthy return value skips the loop.

        Returns
        -------
        npt.NDArray
            The superiorized input-space point ``x``.
        """

        def _should_store(idx: int) -> bool:
            # None -> every iteration; int -> every storage_iters-th; list -> explicit indices
            if storage_iters is None:
                return True
            if isinstance(storage_iters, int):
                return idx % storage_iters == 0
            return idx in storage_iters

        if proximity_measures is None:
            proximity_measures = [("p_norm", 2)]
        # TODO: check that user supplied proximity measures are valid

        self.input_perturbation_scheme.reset()
        self.target_perturbation_scheme.reset()

        # consecutive-iteration counters consumed by _stopping_criterion
        self._n_tol_input_objective = 0
        self._n_tol_target_objective = 0
        self._n_tol_prox = 0

        self.t = [0]  # iteration times; only initialised here, not updated by this class
        self.l = []  # not used by this class, kept for backward compatibility

        x_0 = x.copy()

        # Start the target-space iterate at the image of the starting point, the
        # same point the initial target objective is evaluated at. Previously y
        # started as None and was passed unchanged to the target scheme's
        # pre_step/perturbation_step/func on the first iteration.
        y = self.basic.map(x_0)

        self._initial_storage(
            x,
            storage and _should_store(0),
            [
                self.input_perturbation_scheme.func(x_0),
                self.target_perturbation_scheme.func(y),
            ],
            self.basic.proximity(x_0, proximity_measures),
        )

        self._k = 0  # reset counter if necessary

        if alternative_stopping_criterion_initial_call is not None:
            stop = alternative_stopping_criterion_initial_call(x, self)
        else:
            stop = False

        while self._k < max_iter and not stop:
            # perturbation steps; pre/post steps see the last recorded history of their own space
            self.input_perturbation_scheme.pre_step(x, **self._scheme_kwargs(0))
            x = self.input_perturbation_scheme.perturbation_step(x)

            self.target_perturbation_scheme.pre_step(y, **self._scheme_kwargs(1))
            y = self.target_perturbation_scheme.perturbation_step(y)

            self.input_perturbation_scheme.post_step(x, **self._scheme_kwargs(0))
            self.target_perturbation_scheme.post_step(y, **self._scheme_kwargs(1))

            self.storage(
                x,
                kind="function_reduction",
                storage=storage and _should_store(self._k + 1),
                f=[
                    self.input_perturbation_scheme.func(x),
                    self.target_perturbation_scheme.func(y),
                ],
                p=self.basic.proximity(x, proximity_measures),
            )

            # basic step of the split problem
            x, y = self.basic.step(x, y)

            self.storage(
                x,
                kind="basic",
                storage=storage and _should_store(self._k + 1),
                f=[
                    self.input_perturbation_scheme.func(x),
                    self.target_perturbation_scheme.func(y),
                ],
                p=self.basic.proximity(x, proximity_measures),
            )

            self._k += 1

            # enable different stopping criteria for different superiorization algorithms
            if alternative_stopping_criterion is not None:
                stop = alternative_stopping_criterion(x, self)
            else:
                stop = self._stopping_criterion(
                    del_input_objective_tol,
                    del_input_objective_n,
                    del_target_objective_tol,
                    del_target_objective_n,
                    prox_tol,
                    del_prox_tol,
                    del_prox_n,
                )

            self._additional_action(x, y)

        self._post_step(x)

        return x

    def _scheme_kwargs(self, space: int) -> dict:
        """
        Build the history keyword arguments for a perturbation scheme's
        ``pre_step``/``post_step``.

        Parameters
        ----------
        space : int
            0 for the input space, 1 for the target space.

        Returns
        -------
        dict
            The latest recorded proximity and function values of that space.
        """
        return {
            "last_proximity": self.proximities[-1][space],
            "last_proximity_basic": self.proximities_basic[-1][space],
            "last_proximity_function_reduction": self.proximities_function_reduction[-1][space],
            "last_function_value": self.all_function_values[-1][space],
            "last_function_value_basic": self.all_function_values_basic[-1][space],
        }

    @staticmethod
    def _relative_change(previous, current):
        """Return ``|previous - current|``, divided by ``|previous|`` once that exceeds 1."""
        # abs() in the denominator: negative values must not collapse the test to an absolute one
        return abs(previous - current) / max(1, abs(previous))

    def _stopping_criterion(
        self,
        del_input_objective_tol: float,
        del_input_objective_n: int,
        del_target_objective_tol: float,
        del_target_objective_n: int,
        prox_tol: float,
        del_prox_tol: float,
        del_prox_n: int,
    ) -> bool:
        """
        Evaluate the default stopping rule after an iteration.

        The latest values are compared with those of the previous iteration.
        Every iteration appends two entries (function_reduction and basic),
        so the previous iteration's final entry is at index -3.

        Returns
        -------
        bool
            True when both of the following hold:

            - The relative changes of the input AND target objectives have
              stayed below their tolerances for at least
              ``del_input_objective_n`` / ``del_target_objective_n``
              consecutive iterations.
            - The target-space proximity of the first measure is below
              ``prox_tol``, or its relative change has stayed below
              ``del_prox_tol`` for ``del_prox_n`` consecutive iterations.
        """
        previous_f = self.all_function_values[-3]
        current_f = self.all_function_values[-1]

        if self._relative_change(previous_f[0], current_f[0]) < del_input_objective_tol:
            self._n_tol_input_objective += 1
        else:
            self._n_tol_input_objective = 0

        if self._relative_change(previous_f[1], current_f[1]) < del_target_objective_tol:
            self._n_tol_target_objective += 1
        else:
            self._n_tol_target_objective = 0

        stop_objective = (
            self._n_tol_input_objective >= del_input_objective_n
            and self._n_tol_target_objective >= del_target_objective_n
        )

        # target-space ([1]) value of the first proximity measure ([0])
        previous_p = self.proximities[-3][1][0]
        current_p = self.proximities[-1][1][0]

        # the stagnation counter is updated on every call, even if the tolerance is already met
        if self._relative_change(previous_p, current_p) < del_prox_tol:
            self._n_tol_prox += 1
        else:
            self._n_tol_prox = 0

        stop_prox = current_p < prox_tol or self._n_tol_prox >= del_prox_n

        return bool(stop_objective and stop_prox)

    def _additional_action(self, x: npt.NDArray, y: npt.NDArray):
        """
        Run an extra action at the end of every iteration.

        Does nothing here. It is presumably intended as an override point for
        subclasses.

        Parameters
        ----------
        x : np.ndarray
            Current input-space iterate.
        y : np.ndarray
            Current target-space iterate.

        Returns
        -------
        None
        """

    @staticmethod
    def _to_host_list(values) -> list:
        """Return ``values`` as a list, with CuPy arrays copied to host memory."""
        return [
            value.get() if not NO_GPU and isinstance(value, cp.ndarray) else value
            for value in values
        ]

    @staticmethod
    def _host_copy(x):
        """Return a host-side copy of ``x`` (``.copy()`` for NumPy, ``.get()`` otherwise)."""
        return x.copy() if isinstance(x, np.ndarray) else x.get()

    def _initial_storage(self, x: npt.NDArray, storage: bool, f: list, p: list):
        """
        Reset all histories and record the values of the initial point.

        Parameters
        ----------
        x : np.ndarray
            Initial iterate.
        storage : bool
            If True, a host copy of ``x`` is stored in the ``all_x*`` lists.
        f : list
            Initial objective values ``[input objective, target objective]``.
        p : list
            Initial proximity values as returned by ``basic.proximity``.
        """
        self.all_x = []
        self.all_function_values = []
        self.proximities = []

        self.all_x_function_reduction = []
        self.all_function_values_function_reduction = []
        self.proximities_function_reduction = []

        self.all_x_basic = []
        self.all_function_values_basic = []
        self.proximities_basic = []

        f_host = self._to_host_list(f)
        p_host = self._to_host_list(p)

        # the initial point seeds the combined and both per-stage histories
        self.all_function_values.append(f_host)
        self.all_function_values_basic.append(f_host)
        self.all_function_values_function_reduction.append(f_host)

        self.proximities.append(p_host)
        self.proximities_basic.append(p_host)
        self.proximities_function_reduction.append(p_host)

        if storage:
            # one host snapshot shared by all lists; this class never mutates stored snapshots
            x_host = self._host_copy(x)
            self.all_x.append(x_host)
            self.all_x_basic.append(x_host)
            self.all_x_function_reduction.append(x_host)

    def storage(
        self,
        x: npt.NDArray,
        kind: str,
        storage: bool = True,
        f: list | None = None,
        p: list | None = None,
    ):
        """
        Record the values of one stage of an iteration.

        Function and proximity values are always recorded. The iterate is only
        recorded when ``storage`` is True.

        Parameters
        ----------
        x : npt.NDArray
            The current value of the variable x.
        kind : str
            The stage being recorded, either "function_reduction" or "basic".
        storage : bool, optional
            If True, also store a host copy of ``x``.
        f : list
            Objective values ``[input objective, target objective]``. Required
            despite the None default; iterating None raises TypeError.
        p : list
            Proximity values as returned by ``basic.proximity``. Required
            despite the None default.

        Raises
        ------
        ValueError
            If ``kind`` is invalid. This is checked before any history is
            modified.
        """
        if kind not in ("function_reduction", "basic"):
            raise ValueError("Invalid storage type. Use 'function_reduction' or 'basic'.")

        f_host = self._to_host_list(f)
        p_host = self._to_host_list(p)
        self.all_function_values.append(f_host)
        self.proximities.append(p_host)

        x_host = self._host_copy(x) if storage else None
        if storage:
            self.all_x.append(x_host)

        if kind == "function_reduction":
            self.all_function_values_function_reduction.append(f_host)
            self.proximities_function_reduction.append(p_host)
            if storage:
                self.all_x_function_reduction.append(x_host)
        else:
            self.all_function_values_basic.append(f_host)
            self.proximities_basic.append(p_host)
            if storage:
                self.all_x_basic.append(x_host)

    def _post_step(self, x: npt.NDArray):
        """
        Convert all recorded histories to NumPy arrays after ``solve`` finishes.

        Parameters
        ----------
        x : array-like
            The final value of the variable x (unused here).

        Returns
        -------
        None
        """
        self.all_x = np.array(self.all_x)
        self.all_x_function_reduction = np.array(self.all_x_function_reduction)
        self.all_x_basic = np.array(self.all_x_basic)

        self.all_function_values = np.array(self.all_function_values)
        self.all_function_values_function_reduction = np.array(
            self.all_function_values_function_reduction
        )
        self.all_function_values_basic = np.array(self.all_function_values_basic)
        self.proximities = np.array(self.proximities)
        self.proximities_function_reduction = np.array(self.proximities_function_reduction)
        self.proximities_basic = np.array(self.proximities_basic)