Coverage for muutils\interval.py: 98%

278 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-09 01:48 -0600

1"represents a mathematical `Interval` over the real numbers" 

2 

3from __future__ import annotations 

4 

5import math 

6import typing 

7from typing import Optional, Iterable, Sequence, Union, Any 

8 

9from muutils.misc import str_to_numeric 

10 

_EPSILON: float = 1e-10  # default margin `Interval.clamp` keeps away from an open bound

Number = Union[float, int]  # numeric type accepted for interval bounds

# Field values, in the order (lower, upper, closed_L, closed_R, singleton_set),
# that describe the empty interval: NaN bounds, both ends open, no members.
_EMPTY_INTERVAL_ARGS: tuple[Number, Number, bool, bool, set[Number]] = (
    math.nan,  # lower
    math.nan,  # upper
    False,  # closed_L
    False,  # closed_R
    set(),  # singleton_set
)


class Interval:
    """
    Represents a mathematical interval, open by default.

    The Interval class can represent both open and closed intervals, as well as half-open intervals.
    It supports various initialization methods and provides containment checks.

    Examples:

        >>> i1 = Interval(1, 5) # Default open interval (1, 5)
        >>> 3 in i1
        True
        >>> 1 in i1
        False
        >>> i2 = Interval([1, 5]) # Closed interval [1, 5]
        >>> 1 in i2
        True
        >>> i3 = Interval(1, 5, closed_L=True) # Half-open interval [1, 5)
        >>> str(i3)
        '[1, 5)'
        >>> i4 = ClosedInterval(1, 5) # Closed interval [1, 5]
        >>> i5 = OpenInterval(1, 5) # Open interval (1, 5)

    """

    # Instance state, set in `__init__`:
    #   lower, upper       -- the bounds (both NaN for the empty interval)
    #   closed_L, closed_R -- whether the lower/upper bound is included
    #   singleton_set      -- `None` for a regular interval, an empty set for
    #                         the empty interval, `{x}` for the singleton `{x}`
    # NOTE: the class docstring is embedded in `__init__` error messages via
    # `self.__doc__`, so documentation additions live in comments instead.

    def __init__(
        self,
        *args: Union[Sequence[Number], Number],
        is_closed: Optional[bool] = None,
        closed_L: Optional[bool] = None,
        closed_R: Optional[bool] = None,
    ):
        """
        Create an interval from bounds and closure flags.

        # Parameters:

         - `*args`
            - nothing: the empty interval
            - one 2-element sequence: a `list` defaults to closed, any other
              sequence to open
            - one number `x`: degenerate bounds `x, x`, which give the empty
              interval unless a closure flag is set (then the singleton `{x}`)
            - two numbers `lower, upper`: open by default
         - `is_closed : Optional[bool]`
            closure of both ends; cannot be combined with `closed_L`/`closed_R`
         - `closed_L : Optional[bool]`, `closed_R : Optional[bool]`
            closure of the lower/upper end, overriding the default above

        # Raises:

         - `ValueError` : malformed arguments, `lower > upper`, conflicting
           closure flags, or exactly one bound being NaN/None (both NaN/None
           gives the empty interval)
        """
        self.lower: Number
        self.upper: Number
        self.closed_L: bool
        self.closed_R: bool
        self.singleton_set: Optional[set[Number]] = None
        try:
            if len(args) == 0:
                self._set_empty()
                return

            # Handle different types of input arguments
            default_closed: bool
            if len(args) == 1 and isinstance(
                args[0], (list, tuple, Sequence, Iterable)
            ):
                # explicit check rather than `assert`, which `python -O` strips
                if len(args[0]) != 2:  # type: ignore[arg-type]
                    raise ValueError(
                        "if arg is a list or tuple, it must have length 2"
                    )
                self.lower = args[0][0]  # type: ignore[index]
                self.upper = args[0][1]  # type: ignore[index]
                # Determine closure type based on the container type
                default_closed = isinstance(args[0], list)
            elif len(args) == 1 and isinstance(
                args[0], (int, float, typing.SupportsFloat, typing.SupportsInt)
            ):
                # degenerate bounds; resolved to empty/singleton further down
                self.lower = args[0]
                self.upper = args[0]
                default_closed = False
            elif len(args) == 2:
                self.lower, self.upper = args  # type: ignore[assignment]
                default_closed = False  # Default to open interval if two args
            else:
                raise ValueError(f"Invalid input arguments: {args}")

            # NaN/None bounds: both missing -> empty interval, one missing -> error.
            # `is None` is tested first so `math.isnan` never receives `None`
            # (that TypeError would otherwise escape the handler below).
            lower_missing: bool = self.lower is None or math.isnan(self.lower)
            upper_missing: bool = self.upper is None or math.isnan(self.upper)
            if lower_missing and upper_missing:
                self._set_empty()
                return
            if lower_missing or upper_missing:
                raise ValueError(
                    "Both bounds must be NaN or None to create an empty interval. Also, just use `Interval.get_empty()` instead."
                )

            # Ensure lower bound is less than upper bound
            if self.lower > self.upper:
                raise ValueError("Lower bound must be less than upper bound")

            # Determine closure properties
            if is_closed is not None:
                # can't specify both is_closed and closed_L/R
                if (closed_L is not None) or (closed_R is not None):
                    raise ValueError("Cannot specify both is_closed and closed_L/R")
                self.closed_L = is_closed
                self.closed_R = is_closed
            else:
                self.closed_L = closed_L if closed_L is not None else default_closed
                self.closed_R = closed_R if closed_R is not None else default_closed

            # degenerate bounds: `(x, x)` is empty, while `[x, x]`, `[x, x)` and
            # `(x, x]` all collapse to the closed singleton `{x}`
            if self.lower == self.upper:
                if self.closed_L or self.closed_R:
                    self.singleton_set = {self.lower}
                    self.closed_L = True
                    self.closed_R = True
                else:
                    self._set_empty()
            # otherwise `singleton_set` stays `None`

        except (AssertionError, ValueError) as e:
            raise ValueError(
                f"Invalid input arguments to Interval: {args = }, {is_closed = }, {closed_L = }, {closed_R = }\n{e}\nUsage:\n{self.__doc__}"
            ) from e

    def _set_empty(self) -> None:
        """put this instance into the canonical empty-interval state"""
        (
            self.lower,
            self.upper,
            self.closed_L,
            self.closed_R,
            _,
        ) = _EMPTY_INTERVAL_ARGS
        # fresh set each time, so empty intervals never share one mutable object
        self.singleton_set = set()

    @property
    def is_closed(self) -> bool:
        """both ends closed; the empty and singleton intervals count as closed"""
        if self.is_empty:
            return True
        if self.is_singleton:
            return True
        return self.closed_L and self.closed_R

    @property
    def is_open(self) -> bool:
        """both ends open; the empty interval counts as open, a singleton does not"""
        if self.is_empty:
            return True
        if self.is_singleton:
            return False
        return not self.closed_L and not self.closed_R

    @property
    def is_half_open(self) -> bool:
        """exactly one of `closed_L` / `closed_R` is set"""
        return (self.closed_L and not self.closed_R) or (
            not self.closed_L and self.closed_R
        )

    @property
    def is_singleton(self) -> bool:
        """the interval contains exactly one point"""
        return self.singleton_set is not None and len(self.singleton_set) == 1

    @property
    def is_empty(self) -> bool:
        """the interval contains no points"""
        return self.singleton_set is not None and len(self.singleton_set) == 0

    @property
    def is_finite(self) -> bool:
        """neither bound is infinite"""
        return not math.isinf(self.lower) and not math.isinf(self.upper)

    @property
    def singleton(self) -> Number:
        """the single member of a singleton interval; `ValueError` otherwise"""
        if not self.is_singleton:
            raise ValueError("Interval is not a singleton")
        return next(iter(self.singleton_set))  # type: ignore[arg-type]

    @staticmethod
    def get_empty() -> Interval:
        """a new empty interval"""
        return Interval(math.nan, math.nan, closed_L=None, closed_R=None)

    @staticmethod
    def get_singleton(value: Number) -> Interval:
        """a new singleton interval `{value}`; empty if `value` is None or NaN"""
        # check `None` first: `math.isnan(None)` raises TypeError
        if value is None or math.isnan(value):
            return Interval.get_empty()
        return Interval(value, value, closed_L=True, closed_R=True)

    def numerical_contained(self, item: Number) -> bool:
        """
        whether the number `item` lies in this interval

        always `False` for the empty interval; otherwise raises `ValueError`
        if `item` is NaN
        """
        if self.is_empty:
            return False
        if math.isnan(item):
            raise ValueError("NaN cannot be checked for containment in an interval")
        if self.is_singleton:
            return item in self.singleton_set  # type: ignore[operator]
        return ((self.closed_L and item >= self.lower) or item > self.lower) and (
            (self.closed_R and item <= self.upper) or item < self.upper
        )

    def interval_contained(self, item: Interval) -> bool:
        """whether `item` is a subset of this interval (the empty set is a subset of everything)"""
        if item.is_empty:
            return True
        if self.is_empty:
            return False
        if item.is_singleton:
            return self.numerical_contained(item.singleton)
        if self.is_singleton:
            # `item` has more than one point here, so it cannot fit
            return False

        lower_contained: bool = (
            # either strictly wider bound
            self.lower < item.lower
            # if same, then self must be closed if item is closed
            or (self.lower == item.lower and self.closed_L >= item.closed_L)
        )

        upper_contained: bool = (
            # either strictly wider bound
            self.upper > item.upper
            # if same, then self must be closed if item is closed
            or (self.upper == item.upper and self.closed_R >= item.closed_R)
        )

        return lower_contained and upper_contained

    def __contains__(self, item: Any) -> bool:
        """`x in interval`: subset test for intervals, membership test for numbers"""
        if isinstance(item, Interval):
            return self.interval_contained(item)
        else:
            return self.numerical_contained(item)

    def __repr__(self) -> str:
        """`∅`, `{x}`, or bracket notation such as `[1, 5)`"""
        if self.is_empty:
            return r"∅"
        if self.is_singleton:
            return "{" + str(self.singleton) + "}"
        left: str = "[" if self.closed_L else "("
        right: str = "]" if self.closed_R else ")"
        return f"{left}{self.lower}, {self.upper}{right}"

    def __str__(self) -> str:
        return repr(self)

    @classmethod
    def from_str(cls, input_str: str) -> Interval:
        """
        parse the notation produced by `__repr__` (`∅`, `{x}`, `{}`, `[a, b)`, ...)

        raises `AssertionError` for a comma-free string that is not `∅` or
        braced, and `ValueError` for other malformed input
        """
        input_str = input_str.strip()
        # empty and singleton
        if input_str.count(",") == 0:
            # empty set
            if input_str == "∅":
                return cls.get_empty()
            # NOTE(review): `assert` is stripped under `python -O`; kept because
            # callers may rely on the AssertionError type
            assert input_str.startswith("{") and input_str.endswith(
                "}"
            ), "Invalid input string"
            input_str_set_interior: str = input_str.strip("{}").strip()
            if len(input_str_set_interior) == 0:
                return cls.get_empty()
            # singleton set
            return cls.get_singleton(str_to_numeric(input_str_set_interior))

        # expect commas
        if not input_str.count(",") == 1:
            raise ValueError("Invalid input string")

        # get bounds
        lower: str
        upper: str
        lower, upper = input_str.strip("[]()").split(",")
        lower = lower.strip()
        upper = upper.strip()

        lower_num: Number = str_to_numeric(lower)
        upper_num: Number = str_to_numeric(upper)

        # figure out closure
        closed_L: bool
        closed_R: bool
        if input_str[0] == "[":
            closed_L = True
        elif input_str[0] == "(":
            closed_L = False
        else:
            raise ValueError("Invalid input string")

        if input_str[-1] == "]":
            closed_R = True
        elif input_str[-1] == ")":
            closed_R = False
        else:
            raise ValueError("Invalid input string")

        return cls(lower_num, upper_num, closed_L=closed_L, closed_R=closed_R)

    def __eq__(self, other: object) -> bool:
        """set equality: all empty intervals are equal, singletons compare by member"""
        if not isinstance(other, Interval):
            return False
        if self.is_empty and other.is_empty:
            return True
        if self.is_singleton and other.is_singleton:
            return self.singleton == other.singleton
        return (self.lower, self.upper, self.closed_L, self.closed_R) == (
            other.lower,
            other.upper,
            other.closed_L,
            other.closed_R,
        )

    def __iter__(self):
        """yield nothing (empty), the member (singleton), or `lower` then `upper`"""
        if self.is_empty:
            return
        elif self.is_singleton:
            yield self.singleton
            return
        else:
            yield self.lower
            yield self.upper

    def __getitem__(self, index: int) -> float:
        """bound access in the same order as `__iter__`; `IndexError` when out of range"""
        if self.is_empty:
            raise IndexError("Empty interval has no bounds")
        if self.is_singleton:
            if index == 0:
                return self.singleton
            else:
                raise IndexError("Singleton interval has only one bound")
        if index == 0:
            return self.lower
        elif index == 1:
            return self.upper
        else:
            raise IndexError("Interval index out of range")

    def __len__(self) -> int:
        """number of distinct bounds: 0 (empty), 1 (singleton), or 2"""
        return 0 if self.is_empty else 1 if self.is_singleton else 2

    def copy(self) -> Interval:
        """a new `Interval` with the same bounds and closure"""
        if self.is_empty:
            return Interval.get_empty()
        if self.is_singleton:
            return Interval.get_singleton(self.singleton)
        return Interval(
            self.lower, self.upper, closed_L=self.closed_L, closed_R=self.closed_R
        )

    def size(self) -> float:
        """
        Returns the size of the interval.

        # Returns:

         - `float`
            the size of the interval (`upper - lower`; `0` for empty or singleton)
        """
        if self.is_empty or self.is_singleton:
            return 0
        else:
            return self.upper - self.lower

    def clamp(self, value: Union[int, float], epsilon: float = _EPSILON) -> float:
        """
        Clamp the given value to the interval bounds.

        For open bounds, the clamped value will be slightly inside the interval (by epsilon).

        # Parameters:

         - `value : Union[int, float]`
            the value to clamp.
         - `epsilon : float`
            margin for open bounds
            (defaults to `_EPSILON`)

        # Returns:

         - `float`
            the clamped value (the member itself for a singleton)

        # Raises:

         - `ValueError` : If the input value is NaN, `epsilon` is NaN or
           negative, the interval is empty, or `epsilon` exceeds `size()`.
        """

        if math.isnan(value):
            raise ValueError("Cannot clamp NaN value")

        if math.isnan(epsilon):
            raise ValueError("Epsilon cannot be NaN")

        if epsilon < 0:
            raise ValueError(f"Epsilon must be non-negative: {epsilon = }")

        if self.is_empty:
            raise ValueError("Cannot clamp to an empty interval")

        if self.is_singleton:
            return self.singleton

        if epsilon > self.size():
            raise ValueError(
                f"epsilon is greater than the size of the interval: {epsilon = }, {self.size() = }, {self = }"
            )

        # make type work with decimals and stuff
        if not isinstance(value, (int, float)):
            epsilon = value.__class__(epsilon)

        clamped_min: Number
        if self.closed_L:
            clamped_min = self.lower
        else:
            clamped_min = self.lower + epsilon

        clamped_max: Number
        if self.closed_R:
            clamped_max = self.upper
        else:
            clamped_max = self.upper - epsilon

        return max(clamped_min, min(value, clamped_max))

    def intersection(self, other: Interval) -> Optional[Interval]:
        """
        set intersection of two intervals (always an `Interval`, possibly empty)

        # Raises:

         - `TypeError` : if `other` is not an `Interval`
        """
        if not isinstance(other, Interval):
            raise TypeError("Can only intersect with another Interval")

        if self.is_empty or other.is_empty:
            return Interval.get_empty()

        if self.is_singleton:
            if other.numerical_contained(self.singleton):
                return self.copy()
            else:
                return Interval.get_empty()

        if other.is_singleton:
            if self.numerical_contained(other.singleton):
                return other.copy()
            else:
                return Interval.get_empty()

        if self.upper < other.lower or other.upper < self.lower:
            return Interval.get_empty()

        lower: Number = max(self.lower, other.lower)
        upper: Number = min(self.upper, other.upper)

        # the tighter bound decides closure; on a shared bound the point is in
        # the intersection only if BOTH intervals include it
        closed_L: bool
        if self.lower == other.lower:
            closed_L = self.closed_L and other.closed_L
        else:
            closed_L = self.closed_L if self.lower > other.lower else other.closed_L
        closed_R: bool
        if self.upper == other.upper:
            closed_R = self.closed_R and other.closed_R
        else:
            closed_R = self.closed_R if self.upper < other.upper else other.closed_R

        if lower == upper:
            # the intervals only touch at one point: it is shared only if both
            # include it. Handled here because the constructor would promote a
            # half-open degenerate interval to a singleton.
            if closed_L and closed_R:
                return Interval.get_singleton(lower)
            return Interval.get_empty()

        return Interval(lower, upper, closed_L=closed_L, closed_R=closed_R)

    def union(self, other: Interval) -> Interval:
        """
        set union of two intervals

        # Raises:

         - `TypeError` : if `other` is not an `Interval`
         - `NotImplementedError` : if the union is not a single interval
           (nonempty, disjoint, non-touching inputs)
        """
        if not isinstance(other, Interval):
            raise TypeError("Can only union with another Interval")

        # empty set case
        if self.is_empty:
            return other.copy()
        if other.is_empty:
            return self.copy()

        # special case where the intersection is empty but the intervals are contiguous
        if self.upper == other.lower:
            if self.closed_R or other.closed_L:
                return Interval(
                    self.lower,
                    other.upper,
                    closed_L=self.closed_L,
                    closed_R=other.closed_R,
                )
        elif other.upper == self.lower:
            if other.closed_R or self.closed_L:
                return Interval(
                    other.lower,
                    self.upper,
                    closed_L=other.closed_L,
                    closed_R=self.closed_R,
                )

        # non-intersecting nonempty and non-contiguous intervals
        if self.intersection(other) == Interval.get_empty():
            raise NotImplementedError(
                "Union of non-intersecting nonempty non-contiguous intervals is not implemented "
                + f"{self = }, {other = }, {self.intersection(other) = }"
            )

        # singleton case
        if self.is_singleton:
            return other.copy()
        if other.is_singleton:
            return self.copy()

        # regular case: the looser bound decides closure; on a shared bound the
        # point is in the union if EITHER interval includes it
        lower: Number = min(self.lower, other.lower)
        upper: Number = max(self.upper, other.upper)
        closed_L: bool
        if self.lower == other.lower:
            closed_L = self.closed_L or other.closed_L
        else:
            closed_L = self.closed_L if self.lower < other.lower else other.closed_L
        closed_R: bool
        if self.upper == other.upper:
            closed_R = self.closed_R or other.closed_R
        else:
            closed_R = self.closed_R if self.upper > other.upper else other.closed_R

        return Interval(lower, upper, closed_L=closed_L, closed_R=closed_R)


class ClosedInterval(Interval):
    # An `Interval` whose two ends are always closed: `ClosedInterval(1, 5)` is `[1, 5]`.
    # Deliberately no docstring: `Interval.__init__` embeds `self.__doc__` in its
    # error messages, so adding one would change those messages.

    def __init__(self, *args: Union[Sequence[float], float], **kwargs: Any):
        # closure is fixed by the class, so reject any attempt to override it;
        # any other keyword arguments are not forwarded
        closure_keys = {"is_closed", "closed_L", "closed_R"}
        if not closure_keys.isdisjoint(kwargs):
            raise ValueError("Cannot specify closure properties for ClosedInterval")
        super().__init__(*args, is_closed=True)


class OpenInterval(Interval):
    # An `Interval` whose two ends are always open: `OpenInterval(1, 5)` is `(1, 5)`.
    # Deliberately no docstring: `Interval.__init__` embeds `self.__doc__` in its
    # error messages, so adding one would change those messages.

    def __init__(self, *args: Union[Sequence[float], float], **kwargs: Any):
        # closure is fixed by the class, so reject any attempt to override it;
        # any other keyword arguments are not forwarded
        closure_keys = {"is_closed", "closed_L", "closed_R"}
        if not closure_keys.isdisjoint(kwargs):
            raise ValueError("Cannot specify closure properties for OpenInterval")
        super().__init__(*args, is_closed=False)