Coverage for src / tracekit / visualization / digital.py: 69%

102 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Digital timing diagram visualization. 

2 

3This module provides timing diagrams for digital signals with 

4protocol decode overlay support. 

5 

6 

7Example: 

8 >>> from tracekit.visualization.digital import plot_timing 

9 >>> fig = plot_timing([clk, data, cs], names=["CLK", "DATA", "CS"]) 

10 >>> plt.show() 

11 

12References: 

13 matplotlib best practices for digital waveform visualization 

14""" 

15 

16from __future__ import annotations 

17 

18from typing import TYPE_CHECKING 

19 

20import numpy as np 

21 

22try: 

23 import matplotlib.pyplot as plt 

24 from matplotlib.patches import Rectangle 

25 

26 HAS_MATPLOTLIB = True 

27except ImportError: 

28 HAS_MATPLOTLIB = False 

29 

30from tracekit.core.types import DigitalTrace, WaveformTrace 

31 

32if TYPE_CHECKING: 

33 from collections.abc import Sequence 

34 

35 from matplotlib.axes import Axes 

36 from matplotlib.figure import Figure 

37 

38 from tracekit.analyzers.protocols.base import Annotation 

39 

40 

def plot_timing(
    traces: Sequence[WaveformTrace | DigitalTrace],
    *,
    names: list[str] | None = None,
    annotations: list[list[Annotation]] | None = None,
    time_unit: str = "auto",
    show_grid: bool = True,
    figsize: tuple[float, float] | None = None,
    title: str | None = None,
    time_range: tuple[float, float] | None = None,
    threshold: float | str = "auto",
) -> Figure:
    """Plot digital timing diagram with protocol decode overlay.

    Creates a stacked timing diagram (one subplot per channel, shared time
    axis) showing digital waveforms with optional protocol decode annotations.

    Args:
        traces: List of traces to plot (analog or digital). ``WaveformTrace``
            inputs are converted with ``to_digital`` using ``threshold``.
        names: Channel names for labels. If None, uses CH1, CH2, etc.
        annotations: List of protocol annotations per channel (optional).
            Missing or empty per-channel entries are skipped. Annotation
            ``start_sample``/``end_sample`` values are treated as sample
            indices into the corresponding trace.
        time_unit: Time unit ("s", "ms", "us", "ns", "auto"). "auto" picks a
            unit from the first trace's duration; unrecognized units fall back
            to a multiplier of 1.0 (seconds).
        show_grid: Show vertical grid lines at time intervals.
        figsize: Figure size (width, height) in inches. Defaults to
            ``(12, 1.5 * n_channels)``.
        title: Overall figure title.
        time_range: Optional (start, end) time range to display in seconds.
            Samples are filtered to this window and the x-axis is fixed to it.
        threshold: Threshold for analog-to-digital conversion ("auto" or float).

    Returns:
        Matplotlib Figure object.

    Raises:
        ImportError: If matplotlib is not available.
        ValueError: If traces list is empty, or ``names`` length does not
            match the number of traces.

    Example:
        >>> fig = plot_timing(
        ...     [clk_trace, data_trace, cs_trace],
        ...     names=["CLK", "DATA", "CS"],
        ...     annotations=[[], uart_annotations, []]
        ... )
        >>> plt.savefig("timing.png")

    References:
        IEEE 181-2011: Standard for Transitional Waveform Definitions
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for visualization")

    if len(traces) == 0:
        raise ValueError("traces list cannot be empty")

    n_channels = len(traces)

    if names is None:
        names = [f"CH{i + 1}" for i in range(n_channels)]

    if len(names) != n_channels:
        raise ValueError(f"names length ({len(names)}) must match traces ({n_channels})")

    if figsize is None:
        figsize = (12, 1.5 * n_channels)

    # Convert analog traces to digital (local import, as in the original module).
    from tracekit.analyzers.digital.extraction import to_digital

    digital_traces: list[DigitalTrace] = []
    for trace in traces:
        if isinstance(trace, WaveformTrace):
            digital_traces.append(to_digital(trace, threshold=threshold))  # type: ignore[arg-type]
        else:
            digital_traces.append(trace)

    # Auto-select time unit from the first trace's duration (samples * sample period).
    # digital_traces is non-empty here because empty input was rejected above.
    if time_unit == "auto":
        ref_trace = digital_traces[0]
        duration = len(ref_trace.data) * ref_trace.metadata.time_base
        if duration < 1e-6:
            time_unit = "ns"
        elif duration < 1e-3:
            time_unit = "us"
        elif duration < 1:
            time_unit = "ms"
        else:
            time_unit = "s"

    time_multipliers = {"s": 1.0, "ms": 1e3, "us": 1e6, "ns": 1e9}
    multiplier = time_multipliers.get(time_unit, 1.0)

    # squeeze=False always returns a 2-D array, so one channel needs no special case.
    fig, axes_grid = plt.subplots(
        n_channels,
        1,
        figsize=figsize,
        sharex=True,
        squeeze=False,
    )
    axes = list(axes_grid[:, 0])

    for i, (trace, name, ax) in enumerate(zip(digital_traces, names, axes, strict=True)):
        time = trace.time_vector * multiplier
        data_slice = trace.data

        # Filter to the requested window (bounds are in seconds, like time_vector).
        if time_range is not None:
            start_time, end_time = time_range
            start_idx = int(np.searchsorted(trace.time_vector, start_time))
            end_idx = int(np.searchsorted(trace.time_vector, end_time))
            time = time[start_idx:end_idx]
            data_slice = data_slice[start_idx:end_idx]

        # Plot digital waveform as step function
        ax.step(
            time,
            data_slice.astype(int),
            where="post",
            color=f"C{i}",
            linewidth=1.5,
        )

        # Set up digital signal display
        ax.set_ylim(-0.2, 1.2)
        ax.set_yticks([0, 1])
        ax.set_yticklabels(["0", "1"])
        ax.set_ylabel(name, rotation=0, ha="right", va="center", fontweight="bold")

        if show_grid:
            ax.grid(True, alpha=0.2, axis="x")

        # Add protocol annotations if provided
        if annotations is not None and i < len(annotations) and annotations[i]:
            # Annotations carry sample indices. Convert sample -> seconds via the
            # sample period, then seconds -> display units. Assumes sample 0 sits
            # at t=0 on this trace's time axis (TODO confirm for offset traces).
            sample_scale = trace.metadata.time_base * multiplier
            _add_protocol_annotations(ax, annotations[i], sample_scale, time_unit)

        # Hide x tick labels except on the bottom plot. Use tick_params rather
        # than set_xticklabels([]): the latter installs a formatter on the major
        # ticker, which sharex=True shares with the bottom axes and would blank
        # its labels as well.
        if i < n_channels - 1:
            ax.tick_params(axis="x", labelbottom=False)

    # Set x-label on bottom plot
    axes[-1].set_xlabel(f"Time ({time_unit})")

    if time_range is not None:
        # Pin the shared x-axis to the requested window. Annotation patches
        # outside it would otherwise widen the autoscaled limits.
        axes[-1].set_xlim(time_range[0] * multiplier, time_range[1] * multiplier)

    if title:
        fig.suptitle(title, fontsize=14, fontweight="bold")

    fig.tight_layout()
    return fig

194 

195 

def _add_protocol_annotations(
    ax: Axes,
    annotations: list[Annotation],
    multiplier: float,
    time_unit: str,
) -> None:
    """Add protocol decode annotations to timing diagram.

    Each annotation is drawn as a colored box spanning y = 1.05 .. 1.20 (just
    above the logic-high level) with its label centered inside.

    Args:
        ax: Matplotlib axes to annotate.
        annotations: List of protocol annotations. Attributes are read
            defensively with ``hasattr``: ``start_sample``/``end_sample`` for
            position, ``data`` or ``value`` for the label (falling back to
            ``str(ann)``), and ``level`` for the box color.
        multiplier: Scale factor applied to ``start_sample``/``end_sample`` to
            get x positions in the axes' display units.
        time_unit: Time unit string (not used when drawing; kept for API
            stability).
    """
    for ann in annotations:
        # Annotation span in display units; a missing end collapses to the start.
        start_time = ann.start_sample * multiplier if hasattr(ann, "start_sample") else 0
        end_time = ann.end_sample * multiplier if hasattr(ann, "end_sample") else start_time

        # Label text: prefer decoded data, then value, then the object itself.
        if hasattr(ann, "data"):
            text = str(ann.data)
        elif hasattr(ann, "value"):
            text = str(ann.value)
        else:
            text = str(ann)

        # Color by substring match on the annotation level (first match wins).
        color = "lightblue"
        if hasattr(ann, "level"):
            level_str = str(ann.level).lower()
            if "error" in level_str or "warn" in level_str:
                color = "lightcoral"
            elif "data" in level_str or "byte" in level_str:
                color = "lightgreen"
            elif "start" in level_str or "stop" in level_str:
                color = "lightyellow"

        # Zero/negative-length spans still get a visible box, 10 * multiplier wide.
        width = end_time - start_time if end_time > start_time else multiplier * 10
        rect = Rectangle(
            (start_time, 1.05),
            width,
            0.15,
            facecolor=color,
            edgecolor="black",
            linewidth=0.5,
            alpha=0.7,
        )
        ax.add_patch(rect)

        # Label centered in the box. Text is not clipped by default, so clip it
        # explicitly: labels outside the visible x-range would otherwise spill
        # into the figure margins and neighboring subplots.
        mid_time = start_time + width / 2
        ax.text(
            mid_time,
            1.125,
            text,
            ha="center",
            va="center",
            fontsize=7,
            fontfamily="monospace",
            clip_on=True,
        )

258 

259 

def plot_logic_analyzer(
    traces: Sequence[DigitalTrace],
    *,
    names: list[str] | None = None,
    bus_groups: dict[str, list[int]] | None = None,
    time_unit: str = "auto",
    show_grid: bool = True,
    figsize: tuple[float, float] | None = None,
    title: str | None = None,
) -> Figure:
    """Plot logic analyzer style multi-channel display with bus grouping.

    Creates a timing diagram for logic analyzer visualization by delegating
    to ``plot_timing``.

    Args:
        traces: List of digital traces.
        names: Channel names.
        bus_groups: Dictionary mapping bus names to channel indices.
            Example: {"DATA": [0, 1, 2, 3], "ADDR": [4, 5, 6, 7]}
            Bus grouping is not implemented yet. If a non-empty mapping is
            given, a ``UserWarning`` is emitted and every channel is drawn
            individually.
        time_unit: Time unit for display.
        show_grid: Show vertical grid lines.
        figsize: Figure size.
        title: Plot title.

    Returns:
        Matplotlib Figure object.

    Raises:
        ImportError: If matplotlib is not available.
        ValueError: If traces list is empty (or ``names`` length mismatches,
            as raised by ``plot_timing``).

    Example:
        >>> fig = plot_logic_analyzer(
        ...     traces,
        ...     names=[f"D{i}" for i in range(8)],
        ...     bus_groups={"DATA": [0, 1, 2, 3, 4, 5, 6, 7]}
        ... )

    References:
        Logic analyzer display conventions
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for visualization")

    if len(traces) == 0:
        raise ValueError("traces list cannot be empty")

    if bus_groups:
        # Previously the request was silently ignored; tell the caller instead.
        import warnings

        warnings.warn(
            "bus_groups is not implemented yet; channels are plotted individually",
            UserWarning,
            stacklevel=2,
        )

    # plot_timing accepts mixed analog/digital traces; pass a plain list.
    traces_list: list[WaveformTrace | DigitalTrace] = list(traces)

    return plot_timing(
        traces_list,
        names=names,
        time_unit=time_unit,
        show_grid=show_grid,
        figsize=figsize,
        title=title,
    )

332 

333 

# Public API of this module (private helpers are intentionally excluded).
__all__ = ["plot_logic_analyzer", "plot_timing"]