Coverage for src / tracekit / dsl / commands.py: 100%

98 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""TraceKit DSL Commands. 

2 

3Built-in command implementations for DSL. 

4""" 

5 

6import sys 

7from pathlib import Path 

8from typing import Any 

9 

10from tracekit.core.exceptions import TraceKitError 

11 

12 

def cmd_load(filename: str) -> Any:
    """Load a trace file, choosing the loader from the file extension.

    The extension is compared case-insensitively: ``.csv`` uses
    ``tracekit.loaders.csv.load_csv``, ``.bin`` uses
    ``tracekit.loaders.binary.load_binary`` and ``.h5`` / ``.hdf5`` use
    ``tracekit.loaders.hdf5.load_hdf5``. Each loader module is imported
    lazily, only when its format is actually requested.

    Args:
        filename: Path to trace file

    Returns:
        Whatever the selected loader returns for the path

    Raises:
        TraceKitError: If the path does not exist, the extension is not
            supported, or an ImportError occurs while importing or running
            the loader (the ImportError is attached as ``__cause__``)
    """
    path = Path(filename)

    if not path.exists():
        raise TraceKitError(f"File not found: {filename}")

    # Determine loader based on extension
    ext = path.suffix.lower()

    try:
        if ext == ".csv":
            from tracekit.loaders.csv import (  # type: ignore[import-not-found]
                load_csv,  # type: ignore[import-not-found]
            )

            return load_csv(str(path))

        if ext == ".bin":
            from tracekit.loaders.binary import (  # type: ignore[import-not-found]
                load_binary,  # type: ignore[import-not-found]
            )

            return load_binary(str(path))

        if ext in (".h5", ".hdf5"):
            from tracekit.loaders.hdf5 import (  # type: ignore[import-not-found]
                load_hdf5,  # type: ignore[import-not-found]
            )

            return load_hdf5(str(path))

        raise TraceKitError(f"Unsupported file format: {ext}")

    except ImportError as e:
        # Chain explicitly so the original import failure stays attached to
        # the TraceKitError as its direct cause.
        raise TraceKitError(f"Loader not available for {ext}: {e}") from e

57 

58 

def cmd_filter(trace: Any, filter_type: str, *args: Any, **kwargs: Any) -> Any:
    """Apply a filter to a trace.

    ``filter_type`` is matched case-insensitively. The filter implementations
    are looked up on ``tracekit.filtering.filters``, which is imported lazily.

    Args:
        trace: Input trace
        filter_type: Filter type (lowpass, highpass, bandpass, bandstop)
        *args: Filter parameters. lowpass/highpass use ``args[0]`` as the
            cutoff; bandpass/bandstop use ``args[0]`` and ``args[1]`` as the
            low and high cutoffs. Any further positional values are ignored.
        **kwargs: Additional filter options, forwarded unchanged

    Returns:
        Filtered trace, as returned by the underlying filter function

    Raises:
        TraceKitError: If the filter type is unknown, too few positional
            parameters are given, or an ImportError occurs while importing
            or running the filter (attached as ``__cause__``)
    """
    try:
        from tracekit.filtering import filters  # type: ignore[attr-defined]

        # Normalise once instead of lower-casing in every branch.
        kind = filter_type.lower()

        if kind == "lowpass":
            if len(args) < 1:
                raise TraceKitError("lowpass filter requires cutoff frequency")
            return filters.low_pass(trace, cutoff=args[0], **kwargs)

        if kind == "highpass":
            if len(args) < 1:
                raise TraceKitError("highpass filter requires cutoff frequency")
            return filters.high_pass(trace, cutoff=args[0], **kwargs)

        if kind == "bandpass":
            if len(args) < 2:
                raise TraceKitError("bandpass filter requires low and high cutoff frequencies")
            return filters.band_pass(trace, low=args[0], high=args[1], **kwargs)

        if kind == "bandstop":
            if len(args) < 2:
                raise TraceKitError("bandstop filter requires low and high cutoff frequencies")
            return filters.band_stop(trace, low=args[0], high=args[1], **kwargs)

        # The message reports the type exactly as the caller spelled it.
        raise TraceKitError(f"Unknown filter type: {filter_type}")

    except ImportError as e:
        # Keep the underlying import failure as the explicit cause.
        raise TraceKitError("Filtering module not available") from e

102 

103 

def cmd_measure(trace: Any, *measurements: str) -> Any:
    """Measure properties of a trace.

    Names are matched case-insensitively and processed in the order given.
    Each known name is dispatched to the same-named function on
    ``tracekit.analyzers.measurements``. The special name ``all`` replaces
    anything collected so far with ``measure_all(trace)`` and stops
    processing, so names listed after it are not examined.

    Args:
        trace: Input trace
        *measurements: Measurement names (rise_time, fall_time, period,
            frequency, amplitude, mean, rms, or all)

    Returns:
        The bare value when the result holds exactly one entry, otherwise
        the full results mapping

    Raises:
        TraceKitError: If no names are given, a name is unknown, or an
            ImportError occurs while importing or running a measurement
            (attached as ``__cause__``)
    """
    try:
        from tracekit.analyzers import (  # type: ignore[attr-defined]
            measurements as meas,  # type: ignore[attr-defined]
        )

        if not measurements:
            raise TraceKitError("measure command requires at least one measurement name")

        # Names that map one-to-one onto a function of the same name in the
        # measurements module.
        known = (
            "rise_time",
            "fall_time",
            "period",
            "frequency",
            "amplitude",
            "mean",
            "rms",
        )

        results = {}

        for measurement in measurements:
            meas_name = measurement.lower()

            if meas_name == "all":
                # Measure all available measurements
                results = meas.measure_all(trace)
                break

            if meas_name not in known:
                # Report the name as the caller spelled it.
                raise TraceKitError(f"Unknown measurement: {measurement}")

            results[meas_name] = getattr(meas, meas_name)(trace)

        # Return single value if only one measurement
        if len(results) == 1:
            return next(iter(results.values()))

        return results

    except ImportError as e:
        # Keep the underlying import failure as the explicit cause.
        raise TraceKitError("Measurements module not available") from e

159 

160 

def cmd_plot(trace: Any, **options: Any) -> None:
    """Plot a trace and show the figure.

    Uses ``tracekit.visualization.plot``, imported lazily. The trace is drawn
    with ``plot_trace``, an optional annotation is added with
    ``add_annotation``, and ``show`` is called last.

    Args:
        trace: Input trace
        **options: Plot options. ``title`` defaults to ``"Trace Plot"``;
            ``annotate`` is passed to ``add_annotation`` only when truthy.
            Other keys are ignored.

    Raises:
        TraceKitError: If an ImportError occurs while importing or using the
            visualization module (attached as ``__cause__``)
    """
    try:
        from tracekit.visualization import (  # type: ignore[attr-defined]
            plot as plot_module,  # type: ignore[attr-defined]
        )

        title = options.get("title", "Trace Plot")
        annotate = options.get("annotate")

        plot_module.plot_trace(trace, title=title)

        if annotate:
            plot_module.add_annotation(annotate)

        plot_module.show()

    except ImportError as e:
        # Keep the underlying import failure as the explicit cause.
        raise TraceKitError("Visualization module not available") from e

188 

189 

def cmd_export(data: Any, format_type: str, filename: str | None = None) -> None:
    """Export data to a file and report the destination on stderr.

    The format is matched case-insensitively and dispatched to
    ``exporters.json``, ``exporters.csv`` or ``exporters.hdf5`` from
    ``tracekit.exporters``, which is imported lazily.

    Args:
        data: Data to export (trace, measurements, etc.)
        format_type: Export format (json, csv, h5, hdf5)
        filename: Output filename. When None, ``export.<format_type>`` is
            used, with ``format_type`` exactly as passed in.

    Raises:
        TraceKitError: If the format is unknown, or an ImportError occurs
            while importing or running the exporter (attached as
            ``__cause__``)
    """
    try:
        from tracekit.exporters import exporters  # type: ignore[attr-defined]

        if filename is None:
            # Uses the caller's spelling of the format, not the lower-cased one.
            filename = f"export.{format_type}"

        fmt = format_type.lower()

        if fmt == "json":
            exporters.json(data, filename)
        elif fmt == "csv":
            exporters.csv(data, filename)
        elif fmt in ("h5", "hdf5"):
            exporters.hdf5(data, filename)
        else:
            raise TraceKitError(f"Unknown export format: {format_type}")

        # Status goes to stderr, leaving stdout untouched.
        print(f"Exported to {filename}", file=sys.stderr)

    except ImportError as e:
        # Keep the underlying import failure as the explicit cause.
        raise TraceKitError("Export module not available") from e

222 

223 

def cmd_glob(pattern: str) -> list[str]:
    """Return the paths matching a glob pattern.

    Args:
        pattern: Glob pattern (*.csv, etc.)

    Returns:
        List of matching filenames, in the order ``glob.glob`` produces them
        (no sorting is applied here). Empty when nothing matches.
    """
    from glob import glob as glob_func

    # glob() already returns a fresh list, so no extra copy is needed.
    return glob_func(pattern)  # noqa: PTH207

236 

237 

# Registry of built-in DSL commands: each command name maps to the function
# in this module that implements it.
BUILTIN_COMMANDS = {
    "load": cmd_load,
    "filter": cmd_filter,
    "measure": cmd_measure,
    "plot": cmd_plot,
    "export": cmd_export,
    "glob": cmd_glob,
}