Coverage for suppy\projections\_basic_projections.py: 65%

455 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2026-05-08 15:16 +0200

1"""Simple projection objects.""" 

2from abc import ABC, abstractmethod 

3import math 

4from typing import List 

5import numpy as np 

6import numpy.typing as npt 

7import matplotlib.pyplot as plt 

8from matplotlib import patches 

9 

10from suppy.projections._projections import BasicProjection 

11 

12try: 

13 import cupy as cp 

14 

15 NO_GPU = False 

16except ImportError: 

17 NO_GPU = True 

18 cp = np 

19 

20# from suppy.utils.decorators import ensure_float_array 

21 

22 

23# Class for basic projections 

24 

25 

class BoxProjection(BasicProjection):
    """
    BoxProjection class for projecting points onto a box defined by lower
    and upper bounds.

    Parameters
    ----------
    lb : npt.NDArray
        Lower bounds of the box.
    ub : npt.NDArray
        Upper bounds of the box.
    relaxation : float, optional
        Relaxation parameter for the projection, by default 1.
    idx : npt.NDArray or None
        Subset of the input vector to apply the projection on.
    proximity_flag : bool, optional
        Flag to indicate whether to take this object into account when calculating proximity,
        by default True.
    use_gpu : bool, optional
        Flag to indicate if GPU should be used, by default False.

    Attributes
    ----------
    lb : npt.NDArray
        Lower bounds of the box.
    ub : npt.NDArray
        Upper bounds of the box.
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when calculating proximity.
    idx : npt.NDArray
        Subset of the input vector to apply the projection on.
    """

    def __init__(
        self,
        lb: npt.NDArray,
        ub: npt.NDArray,
        relaxation: float = 1,
        idx: npt.NDArray | None = None,
        proximity_flag=True,
        use_gpu=False,
    ):
        super().__init__(relaxation, idx, proximity_flag, use_gpu)
        self.lb = lb
        self.ub = ub

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Clip the selected entries of `x` to the interval [`self.lb`, `self.ub`].

        Parameters
        ----------
        x : npt.NDArray
            Input array to be projected. Can be a NumPy array or a CuPy array.

        Returns
        -------
        npt.NDArray
            The same array object, with ``x[self.idx]`` clipped to the bounds.

        Notes
        -----
        This method modifies the input array `x` in place.
        """
        xp = cp if isinstance(x, cp.ndarray) else np
        x[self.idx] = xp.maximum(self.lb, xp.minimum(self.ub, x[self.idx]))
        return x

    def _proximity(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """
        Evaluate the distance of `x` to the box for each requested measure.

        Parameters
        ----------
        x : npt.NDArray
            Point to evaluate; it is not modified (a copy is projected).
        proximity_measures : list
            Entries are either ``("p_norm", p)`` or the string ``"max_norm"``.

        Returns
        -------
        list of float
            One value per measure: the mean of ``|x_i - P(x)_i| ** p`` over the
            selected entries for ``("p_norm", p)``, the largest residual for
            ``"max_norm"``.

        Raises
        ------
        ValueError
            If a measure is not recognized.
        """
        res = abs(x[self.idx] - self._project(x.copy())[self.idx])
        measures = []
        for measure in proximity_measures:
            if isinstance(measure, tuple):
                if measure[0] == "p_norm":
                    measures.append(1 / len(res) * (res ** measure[1]).sum())
                else:
                    raise ValueError("Invalid proximity measure")
            elif isinstance(measure, str) and measure == "max_norm":
                measures.append(res.max())
            else:
                raise ValueError("Invalid proximity measure")
        return measures

    def visualize(self, ax: plt.Axes | None = None, color=None):
        """
        Visualize the box if it is 2D on a given matplotlib Axes.

        Parameters
        ----------
        ax : plt.Axes, optional
            The matplotlib Axes to plot on. If None, a new figure and axes are created.
        color : str or None, optional
            The color to fill the box with. If None, the box will be filled with the default color.

        Raises
        ------
        ValueError
            If the box is not 2-dimensional.
        """
        if len(self.lb) != 2:
            raise ValueError("Visualization only possible for 2D boxes")

        if ax is None:
            _, ax = plt.subplots()
        box = patches.Rectangle(
            (self.lb[0], self.lb[1]),
            self.ub[0] - self.lb[0],
            self.ub[1] - self.lb[1],
            linewidth=1,
            edgecolor="black",
            facecolor=color,
            alpha=0.5,
        )
        ax.add_patch(box)

    def get_xy(self):
        """
        Generate the coordinates for the edges of a box if it is 2D.

        The four edges (bottom, right, top, left) are sampled with 100 points
        each and concatenated so that they trace the box outline in order.

        Returns
        -------
        npt.NDArray
            A 2D array of shape (2, 400) containing the concatenated coordinates of the four edges.

        Raises
        ------
        ValueError
            If the box is not 2-dimensional.
        """
        if len(self.lb) != 2:
            raise ValueError("Visualization only possible for 2D boxes")
        edge_1 = np.array([np.linspace(self.lb[0], self.ub[0], 100), np.ones(100) * self.lb[1]])
        edge_2 = np.array([np.ones(100) * self.ub[0], np.linspace(self.lb[1], self.ub[1], 100)])
        edge_3 = np.array([np.linspace(self.lb[0], self.ub[0], 100), np.ones(100) * self.ub[1]])
        edge_4 = np.array([np.ones(100) * self.lb[0], np.linspace(self.lb[1], self.ub[1], 100)])
        # Top and left edges are reversed so the outline is traversed continuously.
        return np.concatenate((edge_1, edge_2, edge_3[:, ::-1], edge_4[:, ::-1]), axis=1)

167 

168 

class WeightedBoxProjection(BasicProjection):
    """
    WeightedBoxProjection applies a weighted projection on a box defined by
    lower and upper bounds.
    The idea is a "simultaneous" variant to the "sequential" BoxProjection.

    Parameters
    ----------
    lb : npt.NDArray
        Lower bounds of the box.
    ub : npt.NDArray
        Upper bounds of the box.
    weights : npt.NDArray
        Weights for the projection; they are normalized to sum to 1.
    relaxation : float, optional
        Relaxation parameter, by default 1.
    idx : npt.NDArray or None
        Subset of the input vector to apply the projection on.
    proximity_flag : bool, optional
        Flag to indicate if proximity should be calculated, by default True.
    use_gpu : bool, optional
        Flag to indicate if GPU should be used, by default False.

    Attributes
    ----------
    lb : npt.NDArray
        Lower bounds of the box.
    ub : npt.NDArray
        Upper bounds of the box.
    weights : npt.NDArray
        The normalized weights (``weights / weights.sum()``).
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when calculating proximity.
    idx : npt.NDArray
        Subset of the input vector to apply the projection on.
    """

    def __init__(
        self,
        lb: npt.NDArray,
        ub: npt.NDArray,
        weights: npt.NDArray,
        relaxation: float = 1,
        idx: npt.NDArray | None = None,
        proximity_flag=True,
        use_gpu=False,
    ):
        super().__init__(relaxation, idx, proximity_flag, use_gpu)
        self.lb = lb
        self.ub = ub
        self.weights = weights / weights.sum()

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Move the selected entries of `x` toward the box by the weighted step
        ``weights * (clip(x) - x)``.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected.

        Returns
        -------
        npt.NDArray
            The same array object after the weighted step.

        Notes
        -----
        This method modifies the input array `x` in place.
        """
        xp = cp if isinstance(x, cp.ndarray) else np
        x[self.idx] += self.weights * (
            xp.maximum(self.lb, xp.minimum(self.ub, x[self.idx])) - x[self.idx]
        )
        return x

    def _full_project(self, x: npt.NDArray) -> np.ndarray:
        """
        Clip the selected entries of `x` to the bounds (unweighted projection).

        Parameters
        ----------
        x : npt.NDArray
            Input array to be projected.

        Returns
        -------
        npt.NDArray
            The same array object with elements constrained within the bounds.

        Notes
        -----
        This method modifies the input array `x` in place.
        """
        xp = cp if isinstance(x, cp.ndarray) else np
        x[self.idx] = xp.maximum(self.lb, xp.minimum(self.ub, x[self.idx]))

        return x

    def _proximity(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """
        Evaluate the weighted distance of `x` to the box for each measure.

        The residual is measured against the exact box (``_full_project``),
        not against the weighted step taken by ``_project``; otherwise the
        residuals would already be scaled by the weights and the ``p_norm``
        measure would apply the weights twice.

        Parameters
        ----------
        x : npt.NDArray
            Point to evaluate; it is not modified (a copy is projected).
        proximity_measures : list
            Entries are either ``("p_norm", p)`` or the string ``"max_norm"``.

        Returns
        -------
        list of float
            ``weights @ |x - P(x)| ** p`` for ``("p_norm", p)``, the largest
            residual for ``"max_norm"``.

        Raises
        ------
        ValueError
            If a measure is not recognized.
        """
        res = abs(x[self.idx] - self._full_project(x.copy())[self.idx])
        measures = []
        for measure in proximity_measures:
            if isinstance(measure, tuple):
                if measure[0] == "p_norm":
                    measures.append(self.weights @ (res ** measure[1]))
                else:
                    raise ValueError("Invalid proximity measure")
            elif isinstance(measure, str) and measure == "max_norm":
                measures.append(res.max())
            else:
                raise ValueError("Invalid proximity measure")
        return measures

    def visualize(self, ax: plt.Axes | None = None, color=None):
        """
        Visualize the box if it is 2D on a given matplotlib Axes.

        Parameters
        ----------
        ax : plt.Axes, optional
            The matplotlib Axes to plot on. If None, a new figure and axes are created.
        color : str or None, optional
            The color to fill the box with. If None, the box will be filled with the default color.

        Raises
        ------
        ValueError
            If the box is not 2-dimensional.
        """
        if len(self.lb) != 2:
            raise ValueError("Visualization only possible for 2D boxes")

        if ax is None:
            _, ax = plt.subplots()
        box = patches.Rectangle(
            (self.lb[0], self.lb[1]),
            self.ub[0] - self.lb[0],
            self.ub[1] - self.lb[1],
            linewidth=1,
            edgecolor="black",
            facecolor=color,
            alpha=0.5,
        )
        ax.add_patch(box)

    def get_xy(self):
        """
        Generate the coordinates for the edges of a box if it is 2D.

        The four edges (bottom, right, top, left) are sampled with 100 points
        each and concatenated so that they trace the box outline in order.

        Returns
        -------
        np.ndarray
            A 2D array of shape (2, 400) containing the concatenated coordinates of the four edges.

        Raises
        ------
        ValueError
            If the box is not 2-dimensional.
        """
        if len(self.lb) != 2:
            raise ValueError("Visualization only possible for 2D boxes")
        edge_1 = np.array([np.linspace(self.lb[0], self.ub[0], 100), np.ones(100) * self.lb[1]])
        edge_2 = np.array([np.ones(100) * self.ub[0], np.linspace(self.lb[1], self.ub[1], 100)])
        edge_3 = np.array([np.linspace(self.lb[0], self.ub[0], 100), np.ones(100) * self.ub[1]])
        edge_4 = np.array([np.ones(100) * self.lb[0], np.linspace(self.lb[1], self.ub[1], 100)])
        # Top and left edges are reversed so the outline is traversed continuously.
        return np.concatenate((edge_1, edge_2, edge_3[:, ::-1], edge_4[:, ::-1]), axis=1)

337 

338 

339# Projection onto a single halfspace 

class HalfspaceProjection(BasicProjection):
    """
    A class used to represent a projection onto the halfspace
    ``{x : a @ x <= b}``.

    Parameters
    ----------
    a : npt.NDArray
        The normal vector defining the halfspace.
    b : float
        The offset value defining the halfspace.
    relaxation : float, optional
        The relaxation parameter, by default 1.
    idx : npt.NDArray or None
        Subset of the input vector to apply the projection on.
    proximity_flag : bool, optional
        Flag to indicate whether to take this object into account when calculating proximity, by default True.
    use_gpu : bool, optional
        Flag to indicate if GPU should be used, by default False.

    Attributes
    ----------
    a : npt.NDArray
        The normal vector defining the halfspace.
    a_norm : npt.NDArray
        ``a / (a @ a)``; a step of ``(a @ x - b) * a_norm`` lands on the boundary.
    b : float
        The offset value defining the halfspace.
    relaxation : float
        The relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when calculating proximity.
    idx : npt.NDArray
        Subset of the input vector to apply the projection on.
    """

    def __init__(
        self,
        a: npt.NDArray,
        b: float,
        relaxation: float = 1,
        idx: npt.NDArray | None = None,
        proximity_flag=True,
        use_gpu=False,
    ):
        super().__init__(relaxation, idx, proximity_flag, use_gpu)
        self.a = a
        self.a_norm = self.a / (self.a @ self.a)
        self.b = b

    def _linear_map(self, x):
        """Return ``a @ x``."""
        return self.a @ x

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Project the selected entries of `x` onto the halfspace.

        Points already satisfying ``a @ x <= b`` are left unchanged.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected.

        Returns
        -------
        npt.NDArray
            The same array object after projection.

        Notes
        -----
        This method modifies the input array `x` in place.
        """

        # TODO: dtype check!
        y = self._linear_map(x[self.idx])

        if y > self.b:
            x[self.idx] -= (y - self.b) * self.a_norm

        return x

    def get_xy(self, x: npt.NDArray | None = None):
        """
        Generate coordinates of the boundary line ``a @ p = b`` of a 2D halfspace.

        Parameters
        ----------
        x : npt.NDArray or None, optional
            The x-coordinates for which to compute the corresponding y-coordinates.
            If None, a default range of x values from -10 to 10 (100 points) is used.
            For a vertical boundary (``a[1] == 0``) these values are used as the
            y-coordinates along the line instead.

        Returns
        -------
        np.ndarray
            A 2 x len(x) array whose first row contains x-coordinates and whose
            second row contains the corresponding y-coordinates.

        Raises
        ------
        ValueError
            If the halfspace is not 2-dimensional.
        """
        if len(self.a) != 2:
            raise ValueError("Visualization only possible for 2D halfspaces")

        x = np.linspace(-10, 10, 100) if x is None else np.asarray(x)

        if self.a[1] == 0:
            # Vertical boundary a[0] * p_x = b: constant x-coordinate, the
            # supplied values parametrize the line along the y-axis.
            return np.array([np.full(x.shape, self.b / self.a[0]), x])

        y = (self.b - self.a[0] * x) / self.a[1]
        return np.array([x, y])

    def visualize(
        self,
        ax: plt.Axes | None = None,
        x: npt.NDArray | None = None,
        y_fill: npt.NDArray | None = None,
        color=None,
    ):
        """
        Visualize the halfspace if it is 2D on a given matplotlib Axes.

        Parameters
        ----------
        ax : plt.Axes, optional
            The matplotlib Axes to plot on. If None, a new figure and axes are created.
        x : npt.NDArray or None, optional
            Sample points along the boundary (x-coordinates, or y-coordinates for
            a vertical boundary). Defaults to 100 points from -10 to 10.
        y_fill : npt.NDArray or float or None, optional
            Second boundary of the filled region for non-vertical halfspaces.
            If None, the fill extends from the boundary line to the lower
            (``a[1] > 0``) or upper (``a[1] < 0``) axis limit.
        color : str or None, optional
            The color to fill the halfspace with. If None, the halfspace will be filled with the default color.

        Raises
        ------
        ValueError
            If the halfspace is not 2-dimensional.
        """

        if len(self.a) != 2:
            raise ValueError("Visualization only possible for 2D halfspaces")

        if ax is None:
            _, ax = plt.subplots()

        if x is None:
            x = np.linspace(-10, 10, 100)

        if self.a[1] == 0:
            # Vertical boundary located at p_x = b / a[0].
            x_line = self.b / self.a[0]
            ax.axvline(x=x_line, label="Halfspace", color=color)
            if self.a[0] > 0:
                # a[0] > 0: feasible side is p_x <= b / a[0] (left of the line).
                ax.fill_betweenx(
                    x,
                    ax.get_xlim()[0],
                    x_line,
                    color=color,
                    label="Halfspace",
                    alpha=0.5,
                )
            else:
                ax.fill_betweenx(
                    x,
                    x_line,
                    ax.get_xlim()[1],
                    color=color,
                    label="Halfspace",
                    alpha=0.5,
                )

        else:
            y = (self.b - self.a[0] * x) / self.a[1]
            ax.plot(x, y, color="xkcd:black")
            if y_fill is None:
                # Extend to the axis limit so that flat boundaries (a[0] == 0)
                # still produce a visible filled region.
                y_low, y_high = ax.get_ylim()
                y_fill = min(np.min(y), y_low) if self.a[1] > 0 else max(np.max(y), y_high)

            ax.fill_between(x, y, y_fill, color=color, label="Halfspace", alpha=0.5)

513 

514 

class BandProjection(BasicProjection):
    """
    A class used to represent a projection onto the band
    ``{x : lb <= a @ x <= ub}``.

    Parameters
    ----------
    a : npt.NDArray
        The normal vector defining the band.
    lb : float
        The lower bound of the band.
    ub : float
        The upper bound of the band.
    relaxation : float, optional
        The relaxation parameter, by default 1.
    idx : npt.NDArray or None
        Subset of the input vector to apply the projection on.
    proximity_flag : bool, optional
        Flag to indicate whether to take this object into account when calculating proximity, by default True.
    use_gpu : bool, optional
        Flag to indicate if GPU should be used, by default False.

    Attributes
    ----------
    a : npt.NDArray
        The normal vector defining the band.
    a_norm : npt.NDArray
        ``a / (a @ a)``; used to step exactly onto a violated bound.
    lb : float
        The lower bound of the band.
    ub : float
        The upper bound of the band.
    relaxation : float
        The relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when calculating proximity.
    idx : npt.NDArray
        Subset of the input vector to apply the projection on.
    """

    def __init__(
        self,
        a: npt.NDArray,
        lb: float,
        ub: float,
        relaxation: float = 1,
        idx: npt.NDArray | None = None,
        proximity_flag=True,
        use_gpu=False,
    ):
        super().__init__(relaxation, idx, proximity_flag, use_gpu)
        self.a = a
        self.a_norm = self.a / (self.a @ self.a)
        self.lb = lb
        self.ub = ub

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Project the selected entries of `x` onto the band.

        If ``a @ x`` exceeds `ub` (or falls below `lb`) the point is moved
        along `a` onto that bound; points inside the band are unchanged.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected.

        Returns
        -------
        npt.NDArray
            The same array object after projection.

        Notes
        -----
        This method modifies the input array `x` in place.
        """
        y = self.a @ x[self.idx]

        if y > self.ub:
            x[self.idx] -= (y - self.ub) * self.a_norm
        elif y < self.lb:
            x[self.idx] -= (y - self.lb) * self.a_norm

        return x

    def get_xy(self, x: npt.NDArray | None = None):
        """
        Calculate the x and y coordinates of the two boundary lines of a 2D band.

        Parameters
        ----------
        x : npt.NDArray or None, optional
            The x-coordinates at which to evaluate the bounds. If None, a default range
            from -10 to 10 with 100 points is used. For a vertical band
            (``a[1] == 0``) these values are used as y-coordinates along the lines.

        Returns
        -------
        tuple of np.ndarray
            A tuple containing two 2 x len(x) arrays:
            - The first array holds the x and y coordinates of the line ``a @ p = lb``.
            - The second array holds the x and y coordinates of the line ``a @ p = ub``.

        Raises
        ------
        ValueError
            If the band is not 2-dimensional.
        """

        if len(self.a) != 2:
            raise ValueError("Visualization only possible for 2D bands")

        x = np.linspace(-10, 10, 100) if x is None else np.asarray(x)

        if self.a[1] == 0:
            # Vertical boundaries a[0] * p_x = lb / ub: constant x-coordinates,
            # the supplied values parametrize the lines along the y-axis.
            return (
                np.array([np.full(x.shape, self.lb / self.a[0]), x]),
                np.array([np.full(x.shape, self.ub / self.a[0]), x]),
            )

        y_lb = (self.lb - self.a[0] * x) / self.a[1]
        y_ub = (self.ub - self.a[0] * x) / self.a[1]
        return np.array([x, y_lb]), np.array([x, y_ub])

    def visualize(self, ax: plt.Axes | None = None, x: npt.NDArray | None = None, color=None):
        """
        Visualize the band if it is 2D on a given matplotlib Axes.

        Parameters
        ----------
        ax : plt.Axes, optional
            The matplotlib Axes to plot on. If None, a new figure and axes are created.
        x : npt.NDArray or None, optional
            Sample points along the boundaries (x-coordinates, or y-coordinates for
            a vertical band). Defaults to 100 points from -10 to 10.
        color : str or None, optional
            The color to fill the band with. If None, the band will be filled with the default color.

        Raises
        ------
        ValueError
            If the band is not 2-dimensional.
        """

        if len(self.a) != 2:
            raise ValueError("Visualization only possible for 2D bands")

        if ax is None:
            _, ax = plt.subplots()

        x = np.linspace(-10, 10, 100) if x is None else np.asarray(x)

        if self.a[1] == 0:
            # Vertical band: boundaries at p_x = lb / a[0] and p_x = ub / a[0].
            x_lb = self.lb / self.a[0]
            x_ub = self.ub / self.a[0]
            ax.plot(np.full(x.shape, x_lb), x, color="xkcd:black")
            ax.plot(np.full(x.shape, x_ub), x, color="xkcd:black")
            # fill_betweenx accepts its bounds in either order, so the sign of
            # a[0] needs no special handling.
            ax.fill_betweenx(x, x_lb, x_ub, color=color, label="Band", alpha=0.5)
        else:
            y_lb = (self.lb - self.a[0] * x) / self.a[1]
            y_ub = (self.ub - self.a[0] * x) / self.a[1]
            ax.plot(x, y_lb, color="xkcd:black")
            ax.plot(x, y_ub, color="xkcd:black")
            ax.fill_between(x, y_lb, y_ub, color=color, label="Band", alpha=0.5)

673 

674 

class BallProjection(BasicProjection):
    """
    A class used to represent a projection onto a (closed) Euclidean ball.

    Parameters
    ----------
    center : npt.NDArray
        The center of the ball.
    radius : float
        The radius of the ball.
    relaxation : float, optional
        The relaxation parameter (default is 1).
    idx : npt.NDArray or None
        Subset of the input vector to apply the projection on.
    proximity_flag : bool, optional
        Flag to indicate whether to take this object into account when calculating proximity, by default True.
    use_gpu : bool, optional
        Flag to indicate if GPU should be used, by default False.

    Attributes
    ----------
    center : npt.NDArray
        The center of the ball.
    radius : float
        The radius of the ball.
    relaxation : float
        The relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when calculating proximity.
    idx : npt.NDArray
        Subset of the input vector to apply the projection on.
    """

    def __init__(
        self,
        center: npt.NDArray,
        radius: float,
        relaxation: float = 1,
        idx: npt.NDArray | None = None,
        proximity_flag=True,
        use_gpu=False,
    ):
        super().__init__(relaxation, idx, proximity_flag, use_gpu)
        self.center = center
        self.radius = radius

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Project the selected entries of `x` onto the ball.

        Points whose distance to the center is at most `radius` are left
        unchanged; points outside are moved radially onto the sphere.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected.

        Returns
        -------
        npt.NDArray
            The same array object after projection.

        Notes
        -----
        This method modifies the input array `x` in place.
        """
        xp = cp if isinstance(x, cp.ndarray) else np
        offset = x[self.idx] - self.center
        # Computed once and reused for both the test and the scaling factor.
        dist = xp.linalg.norm(offset)
        if dist > self.radius:
            x[self.idx] -= offset * (1 - self.radius / dist)

        return x

    def visualize(self, ax: plt.Axes | None = None, color=None, edgecolor=None):
        """
        Visualize the ball if it is 2D on a given matplotlib Axes.

        Parameters
        ----------
        ax : plt.Axes, optional
            The matplotlib Axes to plot on. If None, a new figure and axes are created.
        color : str or None, optional
            The color to fill the ball with. If None, the default color is used.
        edgecolor : str or None, optional
            The color of the circle's outline. If None, the default color is used.

        Raises
        ------
        ValueError
            If the ball is not 2-dimensional.
        """

        if len(self.center) != 2:
            raise ValueError("Visualization only possible for 2D balls")

        if ax is None:
            _, ax = plt.subplots()

        circle = patches.Circle(
            (self.center[0], self.center[1]),
            self.radius,
            facecolor=color,
            alpha=0.5,
            edgecolor=edgecolor,
        )
        # add_patch (unlike add_artist) updates the data limits, so autoscaling
        # keeps the circle in view.
        ax.add_patch(circle)

    def get_xy(self):
        """
        Generate x and y coordinates for a 2D ball visualization.

        Returns
        -------
        np.ndarray
            A 2x50 array where the first row contains the x coordinates and the
            second row contains the y coordinates of the points on the circumference
            of the 2D ball.

        Raises
        ------
        ValueError
            If the center does not have exactly 2 dimensions.
        """
        if len(self.center) != 2:
            raise ValueError("Visualization only possible for 2D balls")

        theta = np.linspace(0, 2 * np.pi, 50)
        x = self.center[0] + self.radius * np.cos(theta)
        y = self.center[1] + self.radius * np.sin(theta)
        return np.array([x, y])

799 

800 

class MaxDVHProjection(BasicProjection):
    """
    Class for max dose-volume histogram projections.

    Enforces that at most ``floor(max_percentage * n)`` of the ``n`` selected
    elements exceed `d_max`. When more elements exceed it, the smallest of the
    offending values are set to `d_max` so that exactly the allowed number of
    largest values remain above it.

    Parameters
    ----------
    d_max : float
        The maximum dose value.
    max_percentage : float
        The maximum fraction of elements allowed to exceed d_max
        (it is multiplied by the element count, e.g. 0.2 for 20 %).
    idx : npt.NDArray or None
        Subset of the input vector to apply the projection on. Boolean masks
        are not supported.
    relaxation : float, optional
        The relaxation parameter, by default 1.0.
    proximity_flag : bool, optional
        Flag to indicate whether to take this object into account when calculating proximity, by default True.
    use_gpu : bool, optional
        Flag to indicate if GPU should be used, by default False.

    Attributes
    ----------
    d_max : float
        The maximum dose value.
    max_percentage : float
        The maximum fraction of elements allowed to exceed d_max.

    Raises
    ------
    ValueError
        If `idx` is a boolean mask.
    """

    def __init__(
        self,
        d_max: float,
        max_percentage: float,
        idx: npt.NDArray | None = None,
        relaxation: float = 1.0,
        proximity_flag=True,
        use_gpu=False,
    ):
        super().__init__(
            relaxation=relaxation, idx=idx, proximity_flag=proximity_flag, _use_gpu=use_gpu
        )

        # max percentage of elements that are allowed to exceed d_max
        self.max_percentage = max_percentage
        self.d_max = d_max

        if isinstance(self.idx, slice):
            self._idx_indices = None
        elif self.idx.dtype == bool:
            raise ValueError("Boolean indexing is not supported for this projection.")
        elif self._use_gpu:
            self._idx_indices = cp.asarray(self.idx, dtype=cp.int32)
        else:
            self._idx_indices = self.idx

    def _allowed_count(self, n: int) -> int:
        """Return how many of `n` elements may exceed `d_max`."""
        # Round away floating-point noise before flooring: e.g. 0.29 * 100
        # evaluates to 28.999999999999996 and would otherwise floor to 28.
        return math.floor(round(self.max_percentage * n, 9))

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Projects the input array `x` onto the DVH constraint.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected.

        Returns
        -------
        npt.NDArray
            The same array object after projection.

        Notes
        -----
        This method modifies the input array `x` in place.
        """
        if isinstance(self.idx, slice):
            return self._project_all(x)

        return self._project_subset(x)

    def _project_all(self, x: npt.NDArray) -> np.ndarray:
        """Apply the projection to every element of `x` (in place)."""
        n = len(x)
        n_allowed = self._allowed_count(n)
        n_over = (x > self.d_max).sum()

        if n_over > n_allowed:
            # argsort is ascending, so the n_over violators occupy the last
            # n_over positions; keep the n_allowed largest, clip the rest.
            x[x.argsort()[n - n_over : n - n_allowed]] = self.d_max
        return x

    def _project_subset(self, x: npt.NDArray) -> np.ndarray:
        """Apply the projection to the entries ``x[self.idx]`` (in place)."""
        # Boolean masks are rejected in __init__, so idx is an index array.
        n = len(self.idx)
        n_allowed = self._allowed_count(n)

        x_sub = x[self.idx]
        n_over = (x_sub > self.d_max).sum()

        if n_over > n_allowed:
            # Map positions within the subset back to positions in x.
            x[self._idx_indices[x_sub.argsort()[n - n_over : n - n_allowed]]] = self.d_max

        return x

    def _proximity(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """
        Evaluate how far `x` is from satisfying the max-DVH constraint.

        Only the elements that ``_project`` would clip contribute; their
        excess over `d_max` is aggregated per measure.

        Parameters
        ----------
        x : npt.NDArray
            Input array to be evaluated; it is not modified.
        proximity_measures : list
            Entries are either ``("p_norm", p)`` or the string ``"max_norm"``.

        Returns
        -------
        list of float
            One value per measure: ``sum(excess ** p) / n`` for
            ``("p_norm", p)`` (n = number of considered elements) and the
            largest excess for ``"max_norm"``. All zeros when the constraint
            is satisfied.

        Raises
        ------
        ValueError
            If a measure is not recognized (only checked when the constraint
            is violated).
        """
        # TODO: Find appropriate proximity measure

        if isinstance(self.idx, slice):
            return self._proximity_all(x, proximity_measures)

        return self._proximity_subset(x, proximity_measures)

    def _proximity_all(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """Proximity when the constraint applies to every element of `x`."""
        n = len(x)
        n_allowed = self._allowed_count(n)
        n_over = (x > self.d_max).sum()

        if n_over <= n_allowed:
            return [0 for _ in proximity_measures]

        excess = x[x.argsort()[n - n_over : n - n_allowed]] - self.d_max
        return self._dvh_measures(excess, n, proximity_measures)

    def _proximity_subset(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """Proximity when the constraint applies to ``x[self.idx]``."""
        n = len(self.idx)
        n_allowed = self._allowed_count(n)

        x_sub = x[self.idx]
        n_over = (x_sub > self.d_max).sum()

        if n_over <= n_allowed:
            return [0 for _ in proximity_measures]

        excess = x[self._idx_indices[x_sub.argsort()[n - n_over : n - n_allowed]]] - self.d_max
        return self._dvh_measures(excess, n, proximity_measures)

    @staticmethod
    def _dvh_measures(excess, n: int, proximity_measures: List) -> List[float]:
        """Aggregate the excess values into one number per requested measure."""
        measures = []
        for measure in proximity_measures:
            if isinstance(measure, tuple):
                if measure[0] == "p_norm":
                    measures.append((excess ** measure[1]).sum() / n)
                else:
                    raise ValueError("Invalid proximity measure")
            elif isinstance(measure, str) and measure == "max_norm":
                measures.append(excess.max())
            else:
                raise ValueError("Invalid proximity measure")
        return measures

1002 

1003 

class MinDVHProjection(BasicProjection):
    """
    DVH-style lower constraint: at least a given fraction of the affected
    elements must be greater than or equal to ``d_min``.

    With ``n`` affected elements (all of ``x`` when ``idx`` is a slice,
    otherwise ``x[idx]``), at least ``ceil(min_percentage * n)`` of them must
    satisfy ``x >= d_min``. If too few do, the projection raises the largest
    elements that are still below ``d_min`` to exactly ``d_min`` until the
    required count is reached. All other elements are left unchanged.

    Parameters
    ----------
    d_min : float
        Threshold value the elements are compared against.
    min_percentage : float
        Minimum fraction, in ``[0, 1]``, of the affected elements that must
        be at least ``d_min``.
    idx : npt.NDArray or None, optional
        Integer indices of the affected elements, by default None (handled by
        ``BasicProjection``). Boolean masks are rejected.
    relaxation : float, optional
        Relaxation parameter forwarded to ``BasicProjection``, by default 1.0.
    proximity_flag : bool, optional
        Whether this object is taken into account for proximity
        calculations, by default True.
    use_gpu : bool, optional
        Forwarded to ``BasicProjection`` as ``_use_gpu``, by default False.

    Raises
    ------
    ValueError
        If ``idx`` is a boolean mask or ``min_percentage`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        d_min: float,
        min_percentage: float,
        idx: npt.NDArray | None = None,
        relaxation: float = 1.0,
        proximity_flag=True,
        use_gpu=False,
    ):
        super().__init__(
            relaxation=relaxation, idx=idx, proximity_flag=proximity_flag, _use_gpu=use_gpu
        )
        # A fraction outside [0, 1] would make ``n - required`` negative and the
        # slice start would silently wrap around, selecting the wrong elements.
        if not 0 <= min_percentage <= 1:
            raise ValueError(f"min_percentage must lie in [0, 1], got {min_percentage}.")

        # fraction of elements that need to have at least d_min
        self.min_percentage = min_percentage
        self.d_min = d_min
        if isinstance(self.idx, slice):
            self._idx_indices = None
        elif self.idx.dtype == bool:
            raise ValueError("Boolean indexing is not supported for this projection.")
        elif self._use_gpu:
            # integer index array used to map positions inside x[idx] back to x
            self._idx_indices = cp.asarray(self.idx, dtype=cp.int32)
        else:
            self._idx_indices = self.idx

    def _violating_positions(self, values: npt.NDArray):
        """
        Return the positions in ``values`` that must be raised to ``d_min``.

        Returns None when the constraint already holds. Otherwise the result
        holds the positions of the largest elements below ``d_min``, exactly
        as many as are missing to reach the required count.
        """
        n = len(values)
        n_required = math.ceil(self.min_percentage * n)
        # int(): use a plain host-side integer as the slice bound
        n_below = int((values < self.d_min).sum())

        # number of elements that still have to be raised to d_min
        if n_below - n + n_required <= 0:
            return None

        # In ascending order the first n_below entries are below d_min; the
        # last (n_below - (n - n_required)) of those are the closest to d_min.
        return values.argsort()[n - n_required : n_below]

    @staticmethod
    def _measures_from_violation(
        x_under: npt.NDArray, n: int, proximity_measures: List
    ) -> List[float]:
        """
        Evaluate the requested proximity measures on the violation amounts.

        ``("p_norm", p)`` yields ``sum(x_under ** p) / n`` and ``"max_norm"``
        yields ``max(x_under)``. Any other measure raises ValueError.
        """
        measures = []
        for measure in proximity_measures:
            if isinstance(measure, tuple):
                if measure[0] == "p_norm":
                    measures.append((x_under ** measure[1]).sum() / n)
                else:
                    raise ValueError("Invalid proximity measure")
            elif isinstance(measure, str) and measure == "max_norm":
                measures.append(x_under.max())
            else:
                raise ValueError("Invalid proximity measure")
        return measures

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Projects the input array `x` onto the DVH constraint.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected. It is modified in place.

        Returns
        -------
        npt.NDArray
            The projected array.
        """
        if isinstance(self.idx, slice):
            return self._project_all(x)

        return self._project_subset(x)

    def _project_all(self, x: npt.NDArray) -> np.ndarray:
        """Apply the projection to every element of ``x`` (in place)."""
        positions = self._violating_positions(x)
        if positions is not None:
            x[positions] = self.d_min
        return x

    def _project_subset(self, x: npt.NDArray) -> np.ndarray:
        """Apply the projection to ``x[idx]`` only (in place)."""
        positions = self._violating_positions(x[self.idx])
        if positions is not None:
            # map positions inside the subset back to positions in x
            x[self._idx_indices[positions]] = self.d_min
        return x

    def _proximity(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """
        Calculate the proximity of the given array to the minimum-percentage
        DVH constraint.

        Parameters
        ----------
        x : npt.NDArray
            Input array to be evaluated.
        proximity_measures : List
            Requested measures: ``("p_norm", p)`` tuples or ``"max_norm"``.

        Returns
        -------
        List[float]
            One proximity value per requested measure.
        """
        # TODO: Find appropriate proximity measure
        if isinstance(self.idx, slice):
            return self._proximity_all(x, proximity_measures)

        return self._proximity_subset(x, proximity_measures)

    def _proximity_all(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """
        Proximity of the whole array ``x`` to the minimum-percentage
        constraint.

        The violation of each element the projection would raise is
        ``d_min - x``. All measures are 0 if the constraint already holds.

        Parameters
        ----------
        x : npt.NDArray
            Input array to be evaluated.
        proximity_measures : List
            Requested measures: ``("p_norm", p)`` tuples or ``"max_norm"``.

        Returns
        -------
        List[float]
            One proximity value per requested measure.
        """
        # TODO: Find appropriate proximity measure
        positions = self._violating_positions(x)
        if positions is None:
            return [0 for _ in proximity_measures]
        return self._measures_from_violation(
            self.d_min - x[positions], len(x), proximity_measures
        )

    def _proximity_subset(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """
        Proximity of ``x[idx]`` to the minimum-percentage constraint.

        Same as ``_proximity_all``, but only the indexed elements are
        considered and ``n`` is the number of indices.

        Parameters
        ----------
        x : npt.NDArray
            Input array to be evaluated.
        proximity_measures : List
            Requested measures: ``("p_norm", p)`` tuples or ``"max_norm"``.

        Returns
        -------
        List[float]
            One proximity value per requested measure.
        """
        values = x[self.idx]
        positions = self._violating_positions(values)
        if positions is None:
            return [0 for _ in proximity_measures]
        return self._measures_from_violation(
            self.d_min - values[positions], len(values), proximity_measures
        )

1182 

1183 

class CustomProjection(BasicProjection):
    """
    Projection object that delegates to user-supplied callables.

    Parameters
    ----------
    projection_function : callable
        Called as ``projection_function(x)``. Its return value is the
        projection result.
    proximity_function : callable, optional
        Called as ``proximity_function(x, proximity_measures)``. When None
        (the default), the proximity calculation inherited from
        ``BasicProjection`` is used, which is based on P(x)-x residuals.
    relaxation : float, optional
        Relaxation parameter forwarded to ``BasicProjection``, by default 1.
    idx : npt.NDArray or None, optional
        Index subset forwarded to ``BasicProjection``, by default None.
    proximity_flag : bool, optional
        Forwarded to ``BasicProjection``, by default True.
    _use_gpu : bool, optional
        Forwarded to ``BasicProjection``, by default False.
    """

    def __init__(
        self,
        projection_function,
        proximity_function=None,
        relaxation=1,
        idx: npt.NDArray | None = None,
        proximity_flag=True,
        _use_gpu=False,
    ):
        super().__init__(relaxation, idx, proximity_flag, _use_gpu)
        self.proximity_function = proximity_function
        self.projection_function = projection_function

    def _project(self, x: npt.NDArray) -> npt.NDArray:
        """Return whatever the user projection callable returns for ``x``."""
        user_projection = self.projection_function
        return user_projection(x)

    def _proximity(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """Use the user proximity callable if given, else the inherited one."""
        user_proximity = self.proximity_function
        if user_proximity is not None:
            return user_proximity(x, proximity_measures)
        return super()._proximity(x, proximity_measures)

1218 

1219 

class MMUProjection(BasicProjection, ABC):
    """
    Abstract base for minimum-monitoring-unit (MMU) projections.

    Affected elements may only be 0 or at least the threshold ``mmu``. How
    infeasible values are mapped is left to the subclasses' ``_project``.

    Parameters
    ----------
    mmu : float or npt.NDArray
        MMU threshold, either scalar or per element.
    relaxation : float, optional
        Relaxation parameter forwarded to ``BasicProjection``, by default 1.0.
    idx : npt.NDArray or None, optional
        Integer indices of the affected elements, by default None (handled by
        ``BasicProjection``). Boolean masks are rejected.
    proximity_flag : bool, optional
        Forwarded to ``BasicProjection``, by default True.
    _use_gpu : bool, optional
        Forwarded to ``BasicProjection``, by default False.

    Raises
    ------
    ValueError
        If ``idx`` is a boolean mask.
    """

    def __init__(
        self,
        mmu: float | npt.NDArray,
        relaxation: float = 1.0,
        idx: npt.NDArray | None = None,
        proximity_flag=True,
        _use_gpu=False,
    ):
        super().__init__(
            relaxation=relaxation, idx=idx, proximity_flag=proximity_flag, _use_gpu=_use_gpu
        )

        # _idx_indices: integer indices that subclasses use to map positions
        # inside x[idx] back to x; None when idx is a slice.
        if isinstance(self.idx, slice):
            indices = None
        elif self.idx.dtype == bool:
            raise ValueError("Boolean indexing is not supported for this projection.")
        elif self._use_gpu:
            indices = cp.asarray(self.idx, dtype=cp.int32)
        else:
            indices = self.idx
        self._idx_indices = indices

        self.mmu = mmu

    @abstractmethod
    def _project(self, x: npt.NDArray) -> npt.NDArray:
        """Map the affected elements of ``x`` onto the MMU-feasible set."""

1253 

1254 

class MMUProjection1(MMUProjection):
    """MMU variant that zeroes every affected element below ``mmu``."""

    def _project(self, x: npt.NDArray) -> npt.NDArray:
        """
        Set each affected element with ``x[i] < mmu`` to 0 and leave all
        other elements unchanged. ``x`` is modified in place and returned.
        """
        backend = np if not isinstance(x, cp.ndarray) else cp
        selected = x[self.idx]
        too_small = selected < self.mmu
        x[self.idx] = backend.where(too_small, 0, selected)
        return x

1264 

1265 

class MMUProjection2(MMUProjection):
    """MMU variant that snaps each affected element to the nearer feasible value."""

    def _project(self, x: npt.NDArray) -> npt.NDArray:
        """
        Snap each affected element individually:

        * ``x[i] <= mmu / 2`` becomes 0,
        * ``mmu / 2 < x[i] <= mmu`` becomes ``mmu``,
        * larger values are kept.

        ``x`` is modified in place and returned.
        """
        backend = np if not isinstance(x, cp.ndarray) else cp
        selected = x[self.idx]
        lifted = backend.where(selected <= self.mmu, self.mmu, selected)
        closer_to_zero = selected <= self.mmu / 2
        x[self.idx] = backend.where(closer_to_zero, 0, lifted)
        return x

1280 

1281 

class MMUProjectionMinMMUPercentage(MMUProjection):
    """
    MMU projection that also requires a minimum fraction of the affected
    elements to be at or above the MMU.

    With ``n`` affected elements (all of ``x`` when ``idx`` is a slice,
    otherwise ``x[idx]``), at least ``ceil(min_percentage * n)`` of them must
    be ``>= mmu``. If too few are, the elements closest below the MMU are
    raised to the MMU until the count is reached. Afterwards every affected
    element is snapped: values ``<= mmu / 2`` become 0, values
    ``<= mmu`` become ``mmu`` and larger values are kept.

    Parameters
    ----------
    mmu : float or npt.NDArray
        MMU threshold. An array must broadcast against the affected elements
        (``x`` when ``idx`` is a slice, ``x[idx]`` otherwise) and live on the
        same device as ``x``.
    min_percentage : float
        Minimum fraction, in ``[0, 1]``, of affected elements that must be at
        least ``mmu``.
    idx : npt.NDArray or None, optional
        Integer indices of the affected elements, by default None.
    relaxation : float, optional
        Relaxation parameter, by default 1.0.
    proximity_flag : bool, optional
        By default True.
    use_gpu : bool, optional
        Forwarded as ``_use_gpu``, by default False.

    Raises
    ------
    ValueError
        If ``min_percentage`` is outside ``[0, 1]`` (or, from the base class,
        if ``idx`` is a boolean mask).
    """

    def __init__(
        self,
        mmu: float | npt.NDArray,
        min_percentage: float,
        idx: npt.NDArray | None = None,
        relaxation: float = 1.0,
        proximity_flag=True,
        use_gpu=False,
    ):
        super().__init__(
            mmu, relaxation=relaxation, idx=idx, proximity_flag=proximity_flag, _use_gpu=use_gpu
        )
        # A fraction above 1 would make ``n - required`` negative and the slice
        # start would wrap around, selecting the wrong elements.
        if not 0 <= min_percentage <= 1:
            raise ValueError(f"min_percentage must lie in [0, 1], got {min_percentage}.")
        self.min_percentage = min_percentage

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Project `x` onto the minimum-MMU-percentage constraint.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected. It is modified in place.

        Returns
        -------
        npt.NDArray
            The projected array.
        """
        if isinstance(self.idx, slice):
            return self._project_all(x)

        return self._project_subset(x)

    def _positions_to_raise(self, values: npt.NDArray):
        """
        Return the positions in ``values`` that must be raised to the MMU.

        Returns None when enough elements are already at or above the MMU.
        """
        n = len(values)
        n_required = math.ceil(self.min_percentage * n)
        # int(): use a plain host-side integer as the slice bound
        n_below = int((values < self.mmu).sum())

        # number of elements that still have to be raised to the MMU
        if n_below - n + n_required <= 0:
            return None

        # Scalar MMU: rank by value. Element-wise MMU: rank by the distance to
        # each element's own MMU, so that the below-MMU elements come first.
        key = values if np.ndim(self.mmu) == 0 else values - self.mmu
        return key.argsort()[n - n_required : n_below]

    def _mmu_at(self, positions, shape, xp):
        """MMU value(s) to assign at ``positions`` of an array with ``shape``."""
        if np.ndim(self.mmu) == 0:
            return self.mmu
        return xp.broadcast_to(xp.asarray(self.mmu), shape)[positions]

    def _snap(self, values: npt.NDArray, xp) -> npt.NDArray:
        """Send ``values <= mmu/2`` to 0, ``values <= mmu`` to mmu, keep the rest."""
        return xp.where(
            values <= self.mmu / 2,
            0,
            xp.where(values <= self.mmu, self.mmu, values),
        )

    def _project_all(self, x: npt.NDArray) -> np.ndarray:
        """Apply the projection to every element of ``x`` (in place)."""
        xp = cp if isinstance(x, cp.ndarray) else np
        positions = self._positions_to_raise(x)
        if positions is not None:
            x[positions] = self._mmu_at(positions, x.shape, xp)

        x[self.idx] = self._snap(x[self.idx], xp)

        return x

    def _project_subset(self, x: npt.NDArray) -> np.ndarray:
        """Apply the projection to ``x[idx]`` only (in place)."""
        xp = cp if isinstance(x, cp.ndarray) else np
        values = x[self.idx]
        positions = self._positions_to_raise(values)
        if positions is not None:
            # map positions inside the subset back to positions in x
            x[self._idx_indices[positions]] = self._mmu_at(positions, values.shape, xp)

        x[self.idx] = self._snap(x[self.idx], xp)

        return x

1365 

1366 

class MMUProjectionMinZeroPercentage(MMUProjection):
    """
    MMU projection that also requires a minimum fraction of the affected
    elements to be exactly 0.

    With ``n`` affected elements (all of ``x`` when ``idx`` is a slice,
    otherwise ``x[idx]``), at least ``ceil(min_percentage * n)`` of them must
    be 0. If too many are positive, the ``ceil(min_percentage * n)`` smallest
    affected elements are set to 0. Afterwards every affected element is
    snapped: values ``<= mmu / 2`` become 0, values ``<= mmu`` become
    ``mmu`` and larger values are kept.

    Parameters
    ----------
    mmu : float or npt.NDArray
        MMU threshold, broadcast against the affected elements.
    min_percentage : float
        Minimum fraction, in ``[0, 1]``, of affected elements that must be 0.
    idx : npt.NDArray or None, optional
        Integer indices of the affected elements, by default None.
    relaxation : float, optional
        Relaxation parameter, by default 1.0.
    proximity_flag : bool, optional
        By default True.
    use_gpu : bool, optional
        Forwarded as ``_use_gpu``, by default False.

    Raises
    ------
    ValueError
        If ``min_percentage`` is outside ``[0, 1]`` (or, from the base class,
        if ``idx`` is a boolean mask).
    """

    def __init__(
        self,
        mmu: float | npt.NDArray,
        min_percentage: float,
        idx: npt.NDArray | None = None,
        relaxation: float = 1.0,
        proximity_flag=True,
        use_gpu=False,
    ):
        super().__init__(
            mmu, relaxation=relaxation, idx=idx, proximity_flag=proximity_flag, _use_gpu=use_gpu
        )
        # More than 100 % zeros is unsatisfiable; previously it silently zeroed
        # every affected element.
        if not 0 <= min_percentage <= 1:
            raise ValueError(f"min_percentage must lie in [0, 1], got {min_percentage}.")
        self.min_percentage = min_percentage
        # NOTE(review): not used by this class; kept in case external code reads it.
        self.idxs = []

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Project `x` onto the minimum-zero-percentage MMU constraint.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected. It is modified in place.

        Returns
        -------
        npt.NDArray
            The projected array.
        """
        if isinstance(self.idx, slice):
            return self._project_all(x)

        return self._project_subset(x)

    def _zero_positions(self, values: npt.NDArray):
        """
        Return the positions in ``values`` to set to 0, or None when enough
        elements are already ``<= 0``.
        """
        n = len(values)
        n_zero_required = math.ceil(self.min_percentage * n)
        n_positive = int((values > 0).sum())

        # number of positive elements that have to be brought down to 0
        if n_positive - n + n_zero_required <= 0:
            return None

        # The smallest values: all non-positive ones plus the positives
        # closest to 0.
        return values.argsort()[:n_zero_required]

    def _snap(self, values: npt.NDArray, xp) -> npt.NDArray:
        """Send ``values <= mmu/2`` to 0, ``values <= mmu`` to mmu, keep the rest."""
        return xp.where(
            values <= self.mmu / 2,
            0,
            xp.where(values <= self.mmu, self.mmu, values),
        )

    def _project_all(self, x: npt.NDArray) -> np.ndarray:
        """Apply the projection to every element of ``x`` (in place)."""
        xp = cp if isinstance(x, cp.ndarray) else np
        positions = self._zero_positions(x)
        if positions is not None:
            x[positions] = 0

        x[self.idx] = self._snap(x[self.idx], xp)

        return x

    def _project_subset(self, x: npt.NDArray) -> np.ndarray:
        """Apply the projection to ``x[idx]`` only (in place)."""
        xp = cp if isinstance(x, cp.ndarray) else np
        positions = self._zero_positions(x[self.idx])
        if positions is not None:
            # map positions inside the subset back to positions in x
            x[self._idx_indices[positions]] = 0

        x[self.idx] = self._snap(x[self.idx], xp)

        return x

1448 

1449 

class VariableMMUProjectionMinMaxZeroPercentage(MMUProjection):
    """
    MMU projection with per-element MMUs and bounds on the fraction of zeros.

    With ``n`` affected elements (all of ``x`` when ``idx`` is a slice,
    otherwise ``x[idx]``), the elements are ranked in ascending order of
    ``x - max(mmu, x)``, which is 0 at or above the MMU and negative below
    it. Elements farthest below their MMU come first. Then:

    * the first ``ceil(min_percentage * n)`` ranked elements are set to 0;
    * the ranked elements from position ``ceil(max_percentage * n)`` onward
      are raised to at least their MMU;
    * the elements in between are snapped individually: 0 if
      ``x <= mmu / 2``, otherwise ``max(mmu, x)``.

    For positive MMUs this leaves between ``ceil(min_percentage * n)`` and
    ``ceil(max_percentage * n)`` zeros, and every other affected element at
    or above its MMU. Each element can have a different MMU.

    Parameters
    ----------
    mmu : float or npt.NDArray
        MMU threshold(s), broadcast against the affected elements.
    min_percentage : float
        Minimum fraction of affected elements forced to 0.
    max_percentage : float
        Maximum fraction of affected elements that may end up 0; the rest are
        forced to at least their MMU. Requires
        ``0 <= min_percentage <= max_percentage <= 1``.
    idx : npt.NDArray or None, optional
        Integer indices of the affected elements, by default None.
    relaxation : float, optional
        Relaxation parameter, by default 1.0.
    proximity_flag : bool, optional
        By default True.
    use_gpu : bool, optional
        Forwarded as ``_use_gpu``, by default False.

    Raises
    ------
    ValueError
        If the percentages are not ordered within ``[0, 1]`` (or, from the
        base class, if ``idx`` is a boolean mask).
    """

    def __init__(
        self,
        mmu: float | npt.NDArray,
        min_percentage: float,
        max_percentage: float,
        idx: npt.NDArray | None = None,
        relaxation: float = 1.0,
        proximity_flag=True,
        use_gpu=False,
    ):
        super().__init__(
            mmu, relaxation=relaxation, idx=idx, proximity_flag=proximity_flag, _use_gpu=use_gpu
        )
        # min > max would make the "forced zero" and "forced MMU" ranges
        # overlap, so the constraint would be contradictory.
        if not 0 <= min_percentage <= max_percentage <= 1:
            raise ValueError(
                "Expected 0 <= min_percentage <= max_percentage <= 1, got "
                f"min_percentage={min_percentage}, max_percentage={max_percentage}."
            )
        self.min_percentage = min_percentage
        # upper bound on the fraction of affected elements that may be 0
        self.max_percentage = max_percentage

    def _project(self, x: npt.NDArray) -> np.ndarray:
        """
        Project `x` onto the variable-MMU min/max zero-percentage constraint.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected. It is modified in place.

        Returns
        -------
        npt.NDArray
            The projected array.
        """
        if isinstance(self.idx, slice):
            return self._project_all(x)

        return self._project_subset(x)

    def _project_all(self, x: npt.NDArray) -> np.ndarray:
        """Apply the projection to every element of ``x`` (in place)."""
        return self._project_values(x)

    def _project_subset(self, x: npt.NDArray) -> np.ndarray:
        """
        Apply the projection to ``x[idx]`` only.

        ``mmu`` is aligned with the subset ``x[idx]``, as in the sibling
        MMU projections.
        """
        x[self.idx] = self._project_values(x[self.idx])
        return x

    def _project_values(self, values: npt.NDArray) -> npt.NDArray:
        """Project a 1-D array of affected values in place and return it."""
        xp = cp if isinstance(values, cp.ndarray) else np
        n = len(values)
        num_min = math.ceil(self.min_percentage * n)  # elements forced to 0
        num_max = math.ceil(self.max_percentage * n)  # at most this many may be 0

        # Per-element MMU (a scalar MMU is broadcast to every element).
        mmu = xp.broadcast_to(xp.asarray(self.mmu), values.shape)

        # Cost of lifting each element to its MMU: 0 if already at/above it,
        # negative otherwise. xp.maximum is element-wise; xp.max(a, b) would
        # treat b as an axis argument.
        cost = values - xp.maximum(mmu, values)
        order = cost.argsort()

        forced_zero = order[:num_min]
        values[forced_zero] = 0

        forced_mmu = order[num_max:]
        values[forced_mmu] = xp.where(
            values[forced_mmu] < mmu[forced_mmu],
            mmu[forced_mmu],
            values[forced_mmu],
        )

        free = order[num_min:num_max]
        values[free] = xp.where(
            values[free] <= mmu[free] / 2, 0, xp.maximum(mmu[free], values[free])
        )

        return values