Coverage for src / tracekit / reporting / plots.py: 8%
306 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Plot generation for comprehensive analysis reports.
3This module provides intelligent plot generation for different analysis domains,
4using the existing visualization library and returning figures for the OutputManager
5to save.
6"""
8from __future__ import annotations
10import logging
11from collections.abc import Callable
12from pathlib import Path
13from typing import TYPE_CHECKING, Any
15import numpy as np
17logger = logging.getLogger(__name__)
19try:
20 import matplotlib
21 import matplotlib.pyplot as plt
23 # Use non-interactive backend for automated plot generation
24 matplotlib.use("Agg")
25 HAS_MATPLOTLIB = True
26except ImportError:
27 HAS_MATPLOTLIB = False
29if TYPE_CHECKING:
30 from matplotlib.figure import Figure
32 from tracekit.reporting.config import AnalysisConfig, AnalysisDomain
33 from tracekit.reporting.output import OutputManager
class PlotGenerator:
    """Generates visualization plots from analysis results.

    Creates the plots appropriate for each analysis domain based on the data
    available in the results dictionary. Figures are built with matplotlib,
    handed to the OutputManager for saving, and always closed afterwards so
    long-running report generation does not accumulate open figures.

    Attributes:
        config: Analysis configuration (optional, for plot settings).

    Example:
        >>> config = AnalysisConfig(plot_format="png", plot_dpi=150)
        >>> generator = PlotGenerator(config)
        >>> paths = generator.generate_plots(
        ...     AnalysisDomain.SPECTRAL,
        ...     {"fft": {"frequencies": freq, "magnitude_db": mag}},
        ...     output_manager
        ... )
    """

    def __init__(self, config: AnalysisConfig | None = None) -> None:
        """Initialize plot generator.

        Args:
            config: Analysis configuration for plot settings (format, DPI, etc.).
                If None, uses defaults.

        Raises:
            ImportError: If matplotlib is not installed.
        """
        if not HAS_MATPLOTLIB:
            raise ImportError("matplotlib is required for plot generation")

        self.config = config

    def generate_plots(
        self,
        domain: AnalysisDomain,
        results: dict[str, Any],
        output_manager: OutputManager,
    ) -> list[Path]:
        """Generate all appropriate plots for an analysis domain.

        First renders every dict-valued result that has a plot function
        registered under ``(domain, analysis_name)`` in ``PLOT_REGISTRY``,
        then runs the built-in domain-level plots. Failures are logged as
        warnings and never abort the remaining plots.

        Args:
            domain: Analysis domain (e.g., SPECTRAL, WAVEFORM, DIGITAL).
            results: Dictionary of analysis results for this domain.
            output_manager: OutputManager instance for saving plots.

        Returns:
            List of paths to saved plot files.

        Example:
            >>> results = {
            ...     "fft": {"frequencies": freq_array, "magnitude_db": mag_array},
            ...     "psd": {"frequencies": freq_array, "psd": psd_array}
            ... }
            >>> paths = generator.generate_plots(
            ...     AnalysisDomain.SPECTRAL,
            ...     results,
            ...     output_manager
            ... )
        """
        # Imported here rather than at module level (module level only imports
        # it under TYPE_CHECKING).
        from tracekit.reporting.config import AnalysisDomain

        # Get plot format and DPI from config
        plot_format = self.config.plot_format if self.config is not None else "png"
        plot_dpi = self.config.plot_dpi if self.config is not None else 150

        saved_paths: list[Path] = []

        # NOTE(review): only (domain, analysis_name) keys are looked up here;
        # functions registered for a whole domain (bare-domain key) are not
        # consulted by this method.
        for analysis_name, result_data in results.items():
            # Skip non-dict results
            if not isinstance(result_data, dict):
                continue

            plot_func = PLOT_REGISTRY.get((domain, analysis_name))
            if plot_func is None:
                continue

            try:
                fig = plot_func(result_data)
                if fig is None:
                    continue
                try:
                    saved_paths.append(
                        output_manager.save_plot(
                            domain,
                            analysis_name,
                            fig,
                            format=plot_format,
                            dpi=plot_dpi,
                        )
                    )
                finally:
                    # Close even when saving fails, to prevent memory leaks
                    plt.close(fig)
            except Exception as e:
                # Log error but continue with other plots
                logger.warning("Failed to generate %s plot: %s", analysis_name, e)

        # Also try generic domain-level plots
        try:
            domain_plotters = {
                AnalysisDomain.SPECTRAL: self._generate_spectral_plots,
                AnalysisDomain.WAVEFORM: self._generate_waveform_plots,
                AnalysisDomain.DIGITAL: self._generate_digital_plots,
                AnalysisDomain.STATISTICS: self._generate_statistics_plots,
                AnalysisDomain.JITTER: self._generate_jitter_plots,
                AnalysisDomain.EYE: self._generate_eye_plots,
                AnalysisDomain.PATTERNS: self._generate_pattern_plots,
                AnalysisDomain.POWER: self._generate_power_plots,
            }
            plotter = domain_plotters.get(domain)
            if plotter is not None:
                saved_paths.extend(
                    plotter(results, domain, output_manager, plot_format, plot_dpi)
                )
        except Exception as e:
            logger.warning(
                "Error in domain-level plot generation for %s: %s",
                getattr(domain, "value", domain),
                e,
            )

        return saved_paths

    def _render_and_save(
        self,
        plot_method: Callable[..., Figure],
        plot_data: dict[str, Any],
        plot_name: str,
        domain: AnalysisDomain,
        output_manager: OutputManager,
        plot_format: str,
        plot_dpi: int,
        **plot_kwargs: Any,
    ) -> list[Path]:
        """Build one figure, save it through the output manager, and close it.

        Args:
            plot_method: Callable that builds the figure from ``plot_data``.
            plot_data: Data dictionary passed to ``plot_method``.
            plot_name: Name passed to ``output_manager.save_plot``.
            domain: Analysis domain passed to ``output_manager.save_plot``.
            output_manager: OutputManager instance for saving plots.
            plot_format: Output format passed to ``save_plot``.
            plot_dpi: Resolution passed to ``save_plot``.
            **plot_kwargs: Extra keyword arguments for ``plot_method``.

        Returns:
            A one-element list with the saved path, or an empty list when
            building or saving the plot failed (the failure is logged).
        """
        try:
            fig = plot_method(plot_data, **plot_kwargs)
            try:
                path = output_manager.save_plot(
                    domain, plot_name, fig, format=plot_format, dpi=plot_dpi
                )
            finally:
                # Always release the figure, even if saving raised
                plt.close(fig)
        except Exception as e:
            # Best-effort: report the failure instead of silently dropping it
            logger.warning("Failed to generate %s plot: %s", plot_name, e)
            return []
        return [path]

    def _generate_spectral_plots(
        self,
        results: dict[str, Any],
        domain: AnalysisDomain,
        output_manager: OutputManager,
        plot_format: str,
        plot_dpi: int,
    ) -> list[Path]:
        """Generate spectral analysis plots (FFT, PSD, spectrogram)."""
        target = (domain, output_manager, plot_format, plot_dpi)
        paths: list[Path] = []

        # FFT plot
        fft_data = results.get("fft")
        if (
            isinstance(fft_data, dict)
            and "frequencies" in fft_data
            and "magnitude_db" in fft_data
        ):
            paths.extend(
                self._render_and_save(
                    self._plot_spectrum,
                    fft_data,
                    "fft_spectrum",
                    *target,
                    title="FFT Magnitude Spectrum",
                )
            )

        # PSD plot
        psd_data = results.get("psd")
        if isinstance(psd_data, dict) and "frequencies" in psd_data and "psd" in psd_data:
            paths.extend(
                self._render_and_save(
                    self._plot_spectrum,
                    psd_data,
                    "psd_spectrum",
                    *target,
                    title="Power Spectral Density",
                    ylabel="PSD (dB/Hz)",
                )
            )

        # Spectrogram
        spec_data = results.get("spectrogram")
        if (
            isinstance(spec_data, dict)
            and "times" in spec_data
            and "frequencies" in spec_data
            and "Sxx_db" in spec_data
        ):
            paths.extend(
                self._render_and_save(
                    self._plot_spectrogram, spec_data, "spectrogram", *target
                )
            )

        return paths

    def _generate_waveform_plots(
        self,
        results: dict[str, Any],
        domain: AnalysisDomain,
        output_manager: OutputManager,
        plot_format: str,
        plot_dpi: int,
    ) -> list[Path]:
        """Generate waveform analysis plots (time series, histograms)."""
        target = (domain, output_manager, plot_format, plot_dpi)
        paths: list[Path] = []

        # Time series: use the first candidate key that plots successfully
        for key in ("amplitude", "voltage", "signal", "data"):
            if not isinstance(results.get(key), (np.ndarray, list)):
                continue
            saved = self._render_and_save(
                self._plot_time_series,
                {"data": results[key]},
                f"{key}_timeseries",
                *target,
                title=f"{key.title()} vs Time",
            )
            paths.extend(saved)
            if saved:
                break

        # Histogram of amplitudes: same first-success rule
        for key in ("amplitude", "voltage", "data"):
            if not isinstance(results.get(key), (np.ndarray, list)):
                continue
            saved = self._render_and_save(
                self._plot_histogram,
                {"data": results[key]},
                f"{key}_histogram",
                *target,
                title=f"{key.title()} Distribution",
            )
            paths.extend(saved)
            if saved:
                break

        return paths

    def _generate_digital_plots(
        self,
        results: dict[str, Any],
        domain: AnalysisDomain,
        output_manager: OutputManager,
        plot_format: str,
        plot_dpi: int,
    ) -> list[Path]:
        """Generate digital signal analysis plots (edges, timing)."""
        # Edge histogram (rise time distribution)
        edges_data = results.get("edges")
        if isinstance(edges_data, dict) and "rise_times" in edges_data:
            return self._render_and_save(
                self._plot_histogram,
                {"data": edges_data["rise_times"]},
                "rise_time_hist",
                domain,
                output_manager,
                plot_format,
                plot_dpi,
                title="Rise Time Distribution",
            )
        return []

    def _generate_statistics_plots(
        self,
        results: dict[str, Any],
        domain: AnalysisDomain,
        output_manager: OutputManager,
        plot_format: str,
        plot_dpi: int,
    ) -> list[Path]:
        """Generate statistical analysis plots (distribution histogram)."""
        dist_data = results.get("distribution")
        if isinstance(dist_data, dict) and "data" in dist_data:
            return self._render_and_save(
                self._plot_histogram,
                dist_data,
                "distribution",
                domain,
                output_manager,
                plot_format,
                plot_dpi,
                title="Statistical Distribution",
            )
        return []

    def _generate_jitter_plots(
        self,
        results: dict[str, Any],
        domain: AnalysisDomain,
        output_manager: OutputManager,
        plot_format: str,
        plot_dpi: int,
    ) -> list[Path]:
        """Generate jitter analysis plots (TIE histogram)."""
        # TIE (Time Interval Error) histogram
        if isinstance(results.get("tie"), (np.ndarray, list)):
            return self._render_and_save(
                self._plot_histogram,
                {"data": results["tie"]},
                "tie_histogram",
                domain,
                output_manager,
                plot_format,
                plot_dpi,
                title="Time Interval Error (TIE)",
            )
        return []

    def _generate_eye_plots(
        self,
        results: dict[str, Any],
        domain: AnalysisDomain,
        output_manager: OutputManager,
        plot_format: str,
        plot_dpi: int,
    ) -> list[Path]:
        """Generate eye diagram plots (currently a placeholder, returns [])."""
        # Eye diagrams are typically generated by the analyzer itself
        # This is a placeholder for future enhancements
        return []

    def _generate_pattern_plots(
        self,
        results: dict[str, Any],
        domain: AnalysisDomain,
        output_manager: OutputManager,
        plot_format: str,
        plot_dpi: int,
    ) -> list[Path]:
        """Generate pattern analysis plots (occurrence histogram)."""
        pattern_data = results.get("patterns")
        if isinstance(pattern_data, dict) and "occurrences" in pattern_data:
            return self._render_and_save(
                self._plot_histogram,
                {"data": pattern_data["occurrences"]},
                "pattern_occurrences",
                domain,
                output_manager,
                plot_format,
                plot_dpi,
                title="Pattern Occurrences",
            )
        return []

    def _generate_power_plots(
        self,
        results: dict[str, Any],
        domain: AnalysisDomain,
        output_manager: OutputManager,
        plot_format: str,
        plot_dpi: int,
    ) -> list[Path]:
        """Generate power analysis plots (power vs time)."""
        if isinstance(results.get("power"), (np.ndarray, list)):
            return self._render_and_save(
                self._plot_time_series,
                {"data": results["power"]},
                "power_timeseries",
                domain,
                output_manager,
                plot_format,
                plot_dpi,
                title="Power vs Time",
                ylabel="Power (W)",
            )
        return []

    # ============================================================================
    # Unit selection helpers
    # ============================================================================

    @staticmethod
    def _peak_abs(values: Any, default: float = 1.0) -> float:
        """Return the largest finite absolute value, or ``default`` if none."""
        arr = np.atleast_1d(np.asarray(values))
        finite = arr[np.isfinite(arr)]
        if finite.size == 0:
            return default
        return float(np.max(np.abs(finite)))

    @staticmethod
    def _frequency_unit(max_freq: float) -> tuple[str, float]:
        """Pick a frequency unit label and the divisor that converts Hz to it."""
        for threshold, unit in ((1e9, "GHz"), (1e6, "MHz"), (1e3, "kHz")):
            if max_freq >= threshold:
                return unit, threshold
        return "Hz", 1.0

    @staticmethod
    def _time_unit(max_time: float) -> tuple[str, float]:
        """Pick a time unit label and the multiplier that converts seconds to it."""
        if max_time < 1e-6:
            return "ns", 1e9
        if max_time < 1e-3:
            return "us", 1e6
        if max_time < 1:
            return "ms", 1e3
        return "s", 1.0

    # ============================================================================
    # Individual plot methods
    # ============================================================================

    def _plot_spectrum(
        self,
        data: dict[str, Any],
        title: str = "Spectrum",
        ylabel: str = "Magnitude (dB)",
    ) -> Figure:
        """Plot frequency spectrum (FFT, PSD, etc.).

        Args:
            data: Dictionary with 'frequencies' and magnitude data. The
                magnitude is read from the first present key among
                'magnitude_db', 'psd', 'magnitude', 'power_db'.
            title: Plot title.
            ylabel: Y-axis label.

        Returns:
            Matplotlib Figure object.

        Raises:
            ValueError: If frequency/magnitude data is missing, empty, or of
                mismatched length.
        """
        # Validate before creating a figure so bad input cannot leak one
        frequencies = np.atleast_1d(np.asarray(data.get("frequencies", [])))
        magnitude = None
        for key in ("magnitude_db", "psd", "magnitude", "power_db"):
            if key in data:
                magnitude = np.atleast_1d(np.asarray(data[key]))
                break

        if magnitude is None or frequencies.size == 0 or magnitude.size == 0:
            raise ValueError("Missing or empty frequency/magnitude data")
        if frequencies.shape[0] != magnitude.shape[0]:
            raise ValueError(
                "Frequency and magnitude lengths differ: "
                f"{frequencies.shape[0]} != {magnitude.shape[0]}"
            )

        # Auto-select frequency unit from the largest magnitude on the axis,
        # so unsorted or two-sided frequency arrays still get a sensible unit
        freq_unit, freq_scale = self._frequency_unit(self._peak_abs(frequencies))

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(frequencies / freq_scale, magnitude, linewidth=0.8)
            ax.set_xlabel(f"Frequency ({freq_unit})")
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.grid(True, alpha=0.3, which="both")
            ax.set_xscale("log")

            fig.tight_layout()
        except Exception:
            plt.close(fig)
            raise
        return fig

    def _plot_histogram(
        self,
        data: dict[str, Any],
        title: str = "Histogram",
        xlabel: str = "Value",
    ) -> Figure:
        """Plot histogram of data distribution.

        Args:
            data: Dictionary with 'data' array.
            title: Plot title.
            xlabel: X-axis label.

        Returns:
            Matplotlib Figure object.

        Raises:
            ValueError: If data array is empty or contains no finite values.
        """
        # Validate before creating a figure so bad input cannot leak one
        values = np.atleast_1d(np.asarray(data.get("data", [])))
        if values.size == 0:
            raise ValueError("Empty data array for histogram")

        # Remove NaN/Inf values (boolean indexing also flattens the array)
        values = values[np.isfinite(values)]
        if values.size == 0:
            raise ValueError("No finite values for histogram")

        # Auto-select number of bins (Sturges' rule with limits)
        n_bins = min(50, max(10, int(np.ceil(np.log2(len(values)) + 1))))

        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            ax.hist(values, bins=n_bins, alpha=0.7, edgecolor="black")
            ax.set_xlabel(xlabel)
            ax.set_ylabel("Count")
            ax.set_title(title)
            ax.grid(True, alpha=0.3, axis="y")

            fig.tight_layout()
        except Exception:
            plt.close(fig)
            raise
        return fig

    def _plot_time_series(
        self,
        data: dict[str, Any],
        title: str = "Time Series",
        ylabel: str = "Amplitude",
    ) -> Figure:
        """Plot time-domain data.

        Args:
            data: Dictionary with 'data' and optionally 'time' arrays. When
                'time' is absent or None, the sample index is used as the
                x-axis.
            title: Plot title.
            ylabel: Y-axis label.

        Returns:
            Matplotlib Figure object.

        Raises:
            ValueError: If data array is empty.
        """
        # Validate before creating a figure so bad input cannot leak one
        values = np.atleast_1d(np.asarray(data.get("data", [])))
        if values.size == 0:
            raise ValueError("Empty data array for time series")

        time = data.get("time")
        if time is None:
            time = np.arange(values.shape[0])
        else:
            time = np.atleast_1d(np.asarray(time))

        # Auto-select time unit
        time_unit, time_scale = self._time_unit(self._peak_abs(time))

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(time * time_scale, values, linewidth=0.8)
            ax.set_xlabel(f"Time ({time_unit})")
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.grid(True, alpha=0.3)

            fig.tight_layout()
        except Exception:
            plt.close(fig)
            raise
        return fig

    def _plot_spectrogram(self, data: dict[str, Any]) -> Figure:
        """Plot spectrogram (time-frequency heatmap).

        Args:
            data: Dictionary with 'times', 'frequencies', and 'Sxx_db' arrays.

        Returns:
            Matplotlib Figure object.

        Raises:
            ValueError: If spectrogram data is missing or empty.
        """
        # Validate before creating a figure so bad input cannot leak one
        times = np.atleast_1d(np.asarray(data.get("times", [])))
        frequencies = np.atleast_1d(np.asarray(data.get("frequencies", [])))
        Sxx_db = np.asarray(data.get("Sxx_db", []))

        if times.size == 0 or frequencies.size == 0 or Sxx_db.size == 0:
            raise ValueError("Missing spectrogram data")

        # Auto-select units
        time_unit, time_scale = self._time_unit(self._peak_abs(times))
        freq_unit, freq_scale = self._frequency_unit(self._peak_abs(frequencies))

        # Auto color limits: show at most 80 below the finite peak
        valid_db = Sxx_db[np.isfinite(Sxx_db)]
        if valid_db.size > 0:
            vmax = np.max(valid_db)
            vmin = max(np.min(valid_db), vmax - 80)
        else:
            vmin, vmax = None, None

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            pcm = ax.pcolormesh(
                times * time_scale,
                frequencies / freq_scale,
                Sxx_db,
                shading="auto",
                cmap="viridis",
                vmin=vmin,
                vmax=vmax,
            )

            ax.set_xlabel(f"Time ({time_unit})")
            ax.set_ylabel(f"Frequency ({freq_unit})")
            ax.set_title("Spectrogram")

            cbar = fig.colorbar(pcm, ax=ax)
            cbar.set_label("Magnitude (dB)")

            fig.tight_layout()
        except Exception:
            plt.close(fig)
            raise
        return fig
# ============================================================================
# Plot Registry
# ============================================================================

# Maps registry keys to plot generation functions, so custom plot functions
# can be registered for specific analyses. A key is either a
# (domain, analysis_name) tuple or a bare domain; both forms are written by
# register_plot. Each value takes a result dictionary and returns a Figure.
# NOTE(review): within this file only (domain, analysis_name) keys are looked
# up (in PlotGenerator.generate_plots); bare-domain entries are stored but no
# lookup of them is visible here - confirm whether that is intended.
PLOT_REGISTRY: dict[
    tuple[AnalysisDomain, str] | AnalysisDomain, Callable[[dict[str, Any]], Figure]
] = {}
def register_plot(
    domain: AnalysisDomain,
    analysis_name: str | None = None,
) -> Callable[[Callable[[dict[str, Any]], Figure]], Callable[[dict[str, Any]], Figure]]:
    """Return a decorator that stores a plot function in ``PLOT_REGISTRY``.

    The decorated function is returned unchanged; registration is the only
    side effect.

    Args:
        domain: Analysis domain the plot function belongs to.
        analysis_name: Specific analysis name (optional). When omitted (or
            otherwise falsy), the function is registered under the bare
            domain instead of a ``(domain, analysis_name)`` pair.

    Returns:
        Decorator function.

    Example:
        >>> @register_plot(AnalysisDomain.SPECTRAL, "custom_fft")
        ... def plot_custom_fft(data: dict[str, Any]) -> Figure:
        ...     fig, ax = plt.subplots()
        ...     # Custom plotting code
        ...     return fig
    """
    # The key depends only on the decorator arguments, so work it out once
    registry_key = (domain, analysis_name) if analysis_name else domain

    def decorator(func: Callable[[dict[str, Any]], Figure]) -> Callable[[dict[str, Any]], Figure]:
        PLOT_REGISTRY[registry_key] = func
        return func

    return decorator
# Public API of this module: the registry, the generator class, and the
# decorator used to add entries to the registry.
__all__ = [
    "PLOT_REGISTRY",
    "PlotGenerator",
    "register_plot",
]