Coverage for src / tracekit / filtering / introspection.py: 100%

212 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Filter introspection and visualization for TraceKit. 

2 

3Provides filter analysis tools including Bode plots, impulse response, 

4step response, and pole-zero diagrams. 

5 

6 

7Example: 

8 >>> from tracekit.filtering import LowPassFilter, plot_bode 

9 >>> filt = LowPassFilter(cutoff=1e6, sample_rate=10e6, order=4) 

10 >>> fig = plot_bode(filt) 

11 >>> plt.show() 

12""" 

13 

14from __future__ import annotations 

15 

16from typing import TYPE_CHECKING 

17 

18import numpy as np 

19 

20if TYPE_CHECKING: 

21 from matplotlib.figure import Figure 

22 from numpy.typing import NDArray 

23 

24from tracekit.filtering.base import Filter, IIRFilter 

25 

26 

class FilterIntrospection:
    """Mixin class providing filter introspection methods.

    Wraps a filter object and derives frequency-domain metrics from its
    ``get_transfer_function`` and ``get_group_delay`` methods: magnitude and
    phase response, group delay, passband ripple, stopband attenuation and
    cutoff frequency.
    """

    def __init__(self, filter_obj: Filter) -> None:
        """Initialize with a filter object.

        Args:
            filter_obj: Filter to introspect.
        """
        self._filter = filter_obj

    @property
    def filter(self) -> Filter:
        """The wrapped filter object.

        Returns:
            The filter being introspected.
        """
        return self._filter

    def _default_freqs(self, n_points: int = 512) -> NDArray[np.float64]:
        """Return ``n_points`` linearly spaced frequencies from DC to Nyquist.

        Args:
            n_points: Number of frequency points.

        Returns:
            Frequencies in Hz, from 0 to ``sample_rate / 2`` inclusive.

        Raises:
            ValueError: If the filter has no sample_rate.
        """
        sample_rate = self._filter.sample_rate
        if sample_rate is None:
            raise ValueError(
                "Either freqs must be provided or filter must have sample_rate set"
            )
        return np.linspace(0, sample_rate / 2, n_points)

    def magnitude_response(
        self,
        freqs: NDArray[np.float64] | None = None,
        db: bool = True,
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Get magnitude response.

        Args:
            freqs: Frequencies in Hz. If None, 512 points from DC to Nyquist.
            db: If True, return magnitude in dB (floored at -240 dB).

        Returns:
            Tuple of (frequencies, magnitude).

        Raises:
            ValueError: If freqs is None and filter has no sample_rate.
        """
        if freqs is None:
            freqs = self._default_freqs()

        mag = np.abs(self._filter.get_transfer_function(freqs))

        if db:
            # Clamp at 1e-12 (-240 dB) so exact zeros do not produce -inf.
            mag = 20 * np.log10(np.maximum(mag, 1e-12))

        return freqs, mag

    def phase_response(
        self,
        freqs: NDArray[np.float64] | None = None,
        unwrap: bool = True,
        degrees: bool = True,
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Get phase response.

        Args:
            freqs: Frequencies in Hz. If None, 512 points from DC to Nyquist.
            unwrap: If True, unwrap phase to remove 2*pi discontinuities.
            degrees: If True, return phase in degrees (otherwise radians).

        Returns:
            Tuple of (frequencies, phase).

        Raises:
            ValueError: If freqs is None and filter has no sample_rate.
        """
        if freqs is None:
            freqs = self._default_freqs()

        phase = np.angle(self._filter.get_transfer_function(freqs))

        if unwrap:
            phase = np.unwrap(phase)

        if degrees:
            phase = np.degrees(phase)

        return freqs, phase

    def group_delay_hz(
        self,
        freqs: NDArray[np.float64] | None = None,
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Get group delay in seconds.

        The delay is taken from the filter's ``get_group_delay()``, which
        returns ``(w, delay_in_samples)``. When the filter has a sample_rate,
        ``w`` is converted from radians/sample to Hz and the delay from samples
        to seconds. Without a sample_rate, both are returned unconverted.

        Args:
            freqs: Frequencies at which to report the delay, in the same units
                as the returned frequency axis (Hz when sample_rate is set).
                If None, the filter's own frequency grid is returned as-is.
                Otherwise the delay is linearly interpolated onto ``freqs``.
                Values outside the native grid are held at the end values.

        Returns:
            Tuple of (frequencies, group delay). With a sample_rate these are
            in Hz and seconds.
        """
        w, gd_samples = self._filter.get_group_delay()

        sample_rate = self._filter.sample_rate
        if sample_rate is not None:
            freqs_out = w * sample_rate / (2 * np.pi)
            gd_seconds = gd_samples / sample_rate
        else:
            freqs_out = w
            gd_seconds = gd_samples

        if freqs is None:
            return freqs_out, gd_seconds

        # get_group_delay() takes no frequency argument here, so resample the
        # native grid onto the requested points. np.interp requires freqs_out
        # to be ascending; presumably true for a DC-to-Nyquist grid (verify).
        freqs = np.asarray(freqs, dtype=np.float64)
        return freqs, np.interp(freqs, freqs_out, gd_seconds)

    def passband_ripple(
        self,
        passband_edge: float,
    ) -> float:
        """Calculate passband ripple in dB.

        Args:
            passband_edge: Passband edge frequency in Hz. The passband is
                sampled at 256 points from DC to this edge.

        Returns:
            Peak-to-peak ripple in dB within passband.

        Raises:
            ValueError: If filter sample_rate is not set.
        """
        if self._filter.sample_rate is None:
            raise ValueError("Sample rate must be set")

        freqs = np.linspace(0, passband_edge, 256)
        _, mag_db = self.magnitude_response(freqs, db=True)

        return float(np.max(mag_db) - np.min(mag_db))

    def stopband_attenuation(
        self,
        stopband_edge: float,
    ) -> float:
        """Calculate minimum stopband attenuation in dB.

        Args:
            stopband_edge: Stopband edge frequency in Hz. The stopband is
                sampled at 256 points from this edge to Nyquist.

        Returns:
            Minimum attenuation in stopband in dB (positive value), i.e. the
            negated peak magnitude in the stopband.

        Raises:
            ValueError: If filter sample_rate is not set.
        """
        sample_rate = self._filter.sample_rate
        if sample_rate is None:
            raise ValueError("Sample rate must be set")

        freqs = np.linspace(stopband_edge, sample_rate / 2, 256)
        _, mag_db = self.magnitude_response(freqs, db=True)

        return float(-np.max(mag_db))

    def cutoff_frequency(
        self,
        threshold_db: float = -3.0,
    ) -> float:
        """Find the frequency where the response first drops below a threshold.

        The magnitude is normalized to 0 dB at DC, so this is meaningful for
        responses referenced to DC (e.g. lowpass). For a response whose DC
        bin is floored at -240 dB (e.g. highpass), typically no crossing is
        found and Nyquist is returned.

        Args:
            threshold_db: Threshold in dB relative to DC (default -3dB).

        Returns:
            Cutoff frequency in Hz. It is linearly interpolated between the two
            grid points bracketing the first crossing. Returns Nyquist when the
            response never drops below the threshold.

        Raises:
            ValueError: If filter sample_rate is not set.
        """
        sample_rate = self._filter.sample_rate
        if sample_rate is None:
            raise ValueError("Sample rate must be set")

        freqs = np.linspace(0, sample_rate / 2, 1000)
        _, mag_db = self.magnitude_response(freqs, db=True)

        # Normalize to 0dB at DC
        mag_db = mag_db - mag_db[0]

        below = np.flatnonzero(mag_db < threshold_db)
        if below.size == 0:
            return float(freqs[-1])

        idx = int(below[0])
        if idx == 0:
            # Only possible for a positive threshold: DC itself is "below" it.
            return float(freqs[0])

        # Interpolate between the last point at/above the threshold and the
        # first point below it, rather than snapping to the grid.
        # m0 >= threshold_db > m1, so the denominator is never zero.
        f0, f1 = freqs[idx - 1], freqs[idx]
        m0, m1 = mag_db[idx - 1], mag_db[idx]
        return float(f0 + (threshold_db - m0) * (f1 - f0) / (m1 - m0))

219 

220 

def plot_bode(
    filt: Filter,
    *,
    figsize: tuple[float, float] = (10, 8),
    freq_range: tuple[float, float] | None = None,
    n_points: int = 512,
    title: str | None = None,
) -> Figure:
    """Draw a two-panel Bode diagram: magnitude (dB) above phase (degrees).

    Both panels share a logarithmically spaced frequency axis. The magnitude
    panel carries a dashed -3 dB reference line.

    Args:
        filt: Filter to plot.
        figsize: Figure size in inches.
        freq_range: Frequency range (min, max) in Hz. None means
            (1 Hz, Nyquist).
        n_points: Number of log-spaced frequency points.
        title: Plot title. If empty/None, a title naming the filter order is used.

    Returns:
        Matplotlib Figure object.

    Raises:
        ValueError: If filter sample_rate is not set.

    Example:
        >>> fig = plot_bode(filt)
        >>> plt.show()
    """
    import matplotlib.pyplot as plt

    fs = filt.sample_rate
    if fs is None:
        raise ValueError("Filter sample rate must be set for plotting")

    span = freq_range if freq_range is not None else (1, fs / 2)
    grid = np.geomspace(span[0], span[1], n_points)

    analyzer = FilterIntrospection(filt)
    mag_db = analyzer.magnitude_response(grid, db=True)[1]
    phase_deg = analyzer.phase_response(grid, degrees=True)[1]

    fig, (mag_ax, phase_ax) = plt.subplots(2, 1, figsize=figsize, sharex=True)

    # Upper panel: magnitude with a -3 dB marker.
    mag_ax.semilogx(grid, mag_db)
    mag_ax.set_ylabel("Magnitude (dB)")
    mag_ax.grid(True, which="both", alpha=0.3)
    mag_ax.axhline(-3, color="r", linestyle="--", alpha=0.5, label="-3 dB")
    mag_ax.legend()

    # Lower panel: phase, owning the shared x-axis label.
    phase_ax.semilogx(grid, phase_deg)
    phase_ax.set_xlabel("Frequency (Hz)")
    phase_ax.set_ylabel("Phase (degrees)")
    phase_ax.grid(True, which="both", alpha=0.3)

    fig.suptitle(title if title else f"Bode Plot - Order {filt.order} Filter")

    plt.tight_layout()
    return fig

284 

285 

def plot_impulse(
    filt: Filter,
    *,
    n_samples: int = 256,
    figsize: tuple[float, float] = (10, 4),
    title: str | None = None,
) -> Figure:
    """Render the filter's impulse response as a single-axis line plot.

    The x-axis is time in microseconds when the filter has a sample rate,
    otherwise the sample index. A thin black line marks zero amplitude.

    Args:
        filt: Filter to plot.
        n_samples: Number of samples requested from the filter.
        figsize: Figure size in inches.
        title: Plot title. If empty/None, "Impulse Response" is used.

    Returns:
        Matplotlib Figure object.
    """
    import matplotlib.pyplot as plt

    response = filt.get_impulse_response(n_samples)
    fig, ax = plt.subplots(figsize=figsize)

    fs = filt.sample_rate
    if fs is None:
        ax.plot(response)
        ax.set_xlabel("Samples")
    else:
        time_us = np.arange(n_samples) / fs * 1e6
        ax.plot(time_us, response)
        ax.set_xlabel("Time (us)")

    ax.set_ylabel("Amplitude")
    ax.grid(True, alpha=0.3)
    ax.axhline(0, color="k", linewidth=0.5)
    ax.set_title(title or "Impulse Response")

    plt.tight_layout()
    return fig

329 

330 

def plot_step(
    filt: Filter,
    *,
    n_samples: int = 256,
    figsize: tuple[float, float] = (10, 4),
    title: str | None = None,
) -> Figure:
    """Render the filter's step response as a single-axis line plot.

    The x-axis is time in microseconds when the filter has a sample rate,
    otherwise the sample index. A dashed red line at amplitude 1 is labelled
    "Final value".

    Args:
        filt: Filter to plot.
        n_samples: Number of samples requested from the filter.
        figsize: Figure size in inches.
        title: Plot title. If empty/None, "Step Response" is used.

    Returns:
        Matplotlib Figure object.
    """
    import matplotlib.pyplot as plt

    response = filt.get_step_response(n_samples)
    fig, ax = plt.subplots(figsize=figsize)

    fs = filt.sample_rate
    if fs is None:
        ax.plot(response)
        ax.set_xlabel("Samples")
    else:
        time_us = np.arange(n_samples) / fs * 1e6
        ax.plot(time_us, response)
        ax.set_xlabel("Time (us)")

    ax.set_ylabel("Amplitude")
    ax.grid(True, alpha=0.3)
    ax.axhline(1, color="r", linestyle="--", alpha=0.5, label="Final value")
    ax.legend()
    ax.set_title(title or "Step Response")

    plt.tight_layout()
    return fig

375 

376 

def plot_poles_zeros(
    filt: Filter,
    *,
    figsize: tuple[float, float] = (8, 8),
    title: str | None = None,
) -> Figure:
    """Draw the z-plane pole-zero diagram of an IIR filter.

    Zeros are drawn as hollow blue circles and poles as red crosses, over a
    dashed unit circle. A STABLE/UNSTABLE badge in the top-right corner
    reports whether every pole lies strictly inside the unit circle.

    Args:
        filt: IIR filter to plot.
        figsize: Figure size in inches.
        title: Plot title. If empty/None, a title naming the filter order is used.

    Returns:
        Matplotlib Figure object.

    Raises:
        ValueError: If filter is not an IIRFilter.
    """
    import matplotlib.pyplot as plt

    if not isinstance(filt, IIRFilter):
        raise ValueError("Pole-zero plot only available for IIR filters")

    poles = filt.poles
    zeros = filt.zeros

    fig, ax = plt.subplots(figsize=figsize)

    angles = np.linspace(0, 2 * np.pi, 100)
    ax.plot(np.cos(angles), np.sin(angles), "k--", alpha=0.3, label="Unit circle")

    # Zeros first, then poles, so poles are drawn on top.
    marker_specs = (
        (
            zeros,
            {
                "marker": "o",
                "s": 100,
                "facecolors": "none",
                "edgecolors": "b",
                "linewidths": 2,
                "label": "Zeros",
            },
        ),
        (
            poles,
            {
                "marker": "x",
                "s": 100,
                "c": "r",
                "linewidths": 2,
                "label": "Poles",
            },
        ),
    )
    for roots, spec in marker_specs:
        ax.scatter(np.real(roots), np.imag(roots), **spec)

    ax.set_xlabel("Real")
    ax.set_ylabel("Imaginary")
    ax.set_aspect("equal")
    ax.grid(True, alpha=0.3)
    ax.legend()

    # A discrete-time system is stable when all poles are inside |z| = 1.
    stable = np.all(np.abs(poles) < 1.0)
    badge, badge_color = ("STABLE", "green") if stable else ("UNSTABLE", "red")
    ax.text(
        0.95,
        0.95,
        badge,
        transform=ax.transAxes,
        fontsize=12,
        fontweight="bold",
        color=badge_color,
        ha="right",
        va="top",
    )

    ax.set_title(title or f"Pole-Zero Plot (Order {filt.order})")

    plt.tight_layout()
    return fig

460 

461 

def plot_group_delay(
    filt: Filter,
    *,
    figsize: tuple[float, float] = (10, 4),
    freq_range: tuple[float, float] | None = None,
    n_points: int = 512,
    title: str | None = None,
) -> Figure:
    """Plot group delay (in microseconds) on a logarithmic frequency axis.

    Args:
        filt: Filter to plot.
        figsize: Figure size in inches.
        freq_range: Frequency range (min, max) in Hz. None means
            (1 Hz, Nyquist), matching plot_bode.
        n_points: Number of log-spaced frequency points.
        title: Plot title. If empty/None, "Group Delay" is used.

    Returns:
        Matplotlib Figure object.

    Raises:
        ValueError: If filter sample_rate is not set.
    """
    import matplotlib.pyplot as plt

    if filt.sample_rate is None:
        raise ValueError("Filter sample rate must be set for plotting")

    if freq_range is None:
        freq_range = (1, filt.sample_rate / 2)

    # Honour freq_range/n_points with a log-spaced grid like plot_bode. This
    # also keeps DC off the grid, since it cannot be drawn on a semilogx axis.
    freqs = np.geomspace(freq_range[0], freq_range[1], n_points)

    # group_delay_hz() with no argument returns delay on the filter's native
    # grid (Hz / seconds, since sample_rate is set). Resample it onto the
    # requested grid. np.interp assumes the native grid is ascending and
    # holds the end values outside it.
    native_freqs, native_gd = FilterIntrospection(filt).group_delay_hz()
    gd = np.interp(freqs, native_freqs, native_gd)

    fig, ax = plt.subplots(figsize=figsize)

    ax.semilogx(freqs, gd * 1e6)  # Convert seconds to microseconds
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Group Delay (us)")
    ax.grid(True, which="both", alpha=0.3)

    if title:
        ax.set_title(title)
    else:
        ax.set_title("Group Delay")

    plt.tight_layout()
    return fig

507 

508 

def compare_filters(
    filters: list[Filter],
    labels: list[str] | None = None,
    *,
    figsize: tuple[float, float] = (12, 10),
    freq_range: tuple[float, float] | None = None,
    n_points: int = 512,
    n_samples: int = 256,
) -> Figure:
    """Compare multiple filters on the same plots.

    Draws a 2x2 grid with one trace per filter:
    - top-left: magnitude response (dB, log frequency axis);
    - top-right: phase response (degrees, log frequency axis);
    - bottom-left: impulse response;
    - bottom-right: step response.

    Impulse and step responses are plotted against sample index.

    Args:
        filters: Non-empty list of filters to compare.
        labels: Labels for each filter. If None, uses "Filter 1", etc.
        figsize: Figure size in inches.
        freq_range: Frequency range (min, max) in Hz. None means
            (1 Hz, Nyquist of the first filter).
        n_points: Number of frequency points.
        n_samples: Number of samples in the impulse and step responses.

    Returns:
        Matplotlib Figure object with comparison plots.

    Raises:
        ValueError: If ``filters`` is empty, if the number of labels doesn't
            match the number of filters, or if the first filter's
            sample_rate is not set.
    """
    # Validate before importing matplotlib so bad input fails fast with a
    # clear ValueError (an empty list used to surface as an IndexError).
    if not filters:
        raise ValueError("At least one filter is required for comparison")

    if labels is None:
        labels = [f"Filter {i + 1}" for i in range(len(filters))]

    if len(labels) != len(filters):
        raise ValueError("Number of labels must match number of filters")

    # Use first filter's sample rate for frequency axis
    sample_rate = filters[0].sample_rate
    if sample_rate is None:
        raise ValueError("Filter sample rate must be set for plotting")

    import matplotlib.pyplot as plt

    if freq_range is None:
        freq_range = (1, sample_rate / 2)

    freqs = np.geomspace(freq_range[0], freq_range[1], n_points)

    fig, axes = plt.subplots(2, 2, figsize=figsize)
    ax_mag, ax_phase = axes[0, 0], axes[0, 1]
    ax_imp, ax_step = axes[1, 0], axes[1, 1]

    # Lengths are checked above; strict=True documents that invariant.
    for filt, label in zip(filters, labels, strict=True):
        introspect = FilterIntrospection(filt)
        _, mag_db = introspect.magnitude_response(freqs, db=True)
        _, phase_deg = introspect.phase_response(freqs, degrees=True)

        ax_mag.semilogx(freqs, mag_db, label=label)
        ax_phase.semilogx(freqs, phase_deg, label=label)
        ax_imp.plot(filt.get_impulse_response(n_samples), label=label)
        ax_step.plot(filt.get_step_response(n_samples), label=label)

    ax_mag.set_ylabel("Magnitude (dB)")
    ax_mag.set_title("Magnitude Response")
    ax_mag.grid(True, which="both", alpha=0.3)
    ax_mag.axhline(-3, color="k", linestyle="--", alpha=0.3)
    ax_mag.legend()

    ax_phase.set_ylabel("Phase (degrees)")
    ax_phase.set_title("Phase Response")
    ax_phase.grid(True, which="both", alpha=0.3)
    ax_phase.legend()

    ax_imp.set_xlabel("Samples")
    ax_imp.set_ylabel("Amplitude")
    ax_imp.set_title("Impulse Response")
    ax_imp.grid(True, alpha=0.3)
    ax_imp.legend()

    ax_step.set_xlabel("Samples")
    ax_step.set_ylabel("Amplitude")
    ax_step.set_title("Step Response")
    ax_step.grid(True, alpha=0.3)
    ax_step.axhline(1, color="k", linestyle="--", alpha=0.3)
    ax_step.legend()

    fig.suptitle("Filter Comparison")
    plt.tight_layout()
    return fig

595 

596 

# Public API of this module: the FilterIntrospection helper class plus the
# matplotlib-based plotting and comparison functions defined above.
__all__ = [
    "FilterIntrospection",
    "compare_filters",
    "plot_bode",
    "plot_group_delay",
    "plot_impulse",
    "plot_poles_zeros",
    "plot_step",
]