Coverage for src / tracekit / session / session.py: 68%

130 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Analysis session management. 

2 

3This module provides session save/restore functionality for TraceKit. 

4 

5 

6Example: 

7 >>> session = Session() 

8 >>> session.load_trace('capture.wfm') 

9 >>> session.save('debug_session.tks') 

10 >>> 

11 >>> # Later... 

12 >>> session = load_session('debug_session.tks') 

13""" 

14 

15from __future__ import annotations 

16 

17import gzip 

18import pickle 

19from dataclasses import dataclass, field 

20from datetime import datetime 

21from pathlib import Path 

22from typing import Any 

23 

24import numpy as np 

25 

26from tracekit.session.annotations import AnnotationLayer 

27from tracekit.session.history import OperationHistory 

28 

29 

@dataclass
class Session:
    """Analysis session container.

    Manages traces, annotations, measurements, and history for a complete
    analysis session. Sessions can be saved with :meth:`save` and restored
    with :func:`load_session`.

    Attributes:
        name: Session name.
        traces: Dictionary of loaded traces (name -> trace).
        annotation_layers: Annotation layers keyed by layer name. A layer
            named ``"default"`` is created in ``__post_init__`` if absent.
        measurements: Recorded measurements (name -> measurement record).
        history: Operation history.
        metadata: Free-form session metadata.
        created_at: Creation timestamp (naive local time).
        modified_at: Last modification timestamp (naive local time).
    """

    name: str = "Untitled Session"
    traces: dict[str, Any] = field(default_factory=dict)
    annotation_layers: dict[str, AnnotationLayer] = field(default_factory=dict)
    measurements: dict[str, Any] = field(default_factory=dict)
    history: OperationHistory = field(default_factory=OperationHistory)
    metadata: dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.now)
    modified_at: datetime = field(default_factory=datetime.now)
    # File this session was last saved to / loaded from; save() reuses it
    # when called without an explicit path.
    _file_path: Path | None = None

    def __post_init__(self) -> None:
        """Ensure the ``"default"`` annotation layer exists."""
        if "default" not in self.annotation_layers:
            self.annotation_layers["default"] = AnnotationLayer("Default")

    def load_trace(
        self,
        path: str | Path,
        name: str | None = None,
        **load_kwargs: Any,
    ) -> Any:
        """Load a trace from disk into the session.

        Args:
            path: Path to trace file.
            name: Name for the trace in the session (default: the file name
                without its extension). An existing trace with this name is
                replaced.
            **load_kwargs: Additional arguments forwarded to
                ``tracekit.loaders.load()``.

        Returns:
            Loaded trace.
        """
        from tracekit.loaders import load

        path = Path(path)
        trace = load(str(path), **load_kwargs)

        if name is None:
            name = path.stem

        self.traces[name] = trace
        self._mark_modified()

        self.history.record(
            "load_trace",
            {"path": str(path), "name": name},
            result=f"Loaded {name}",
        )

        return trace

    def add_trace(
        self,
        name: str,
        trace: Any,
    ) -> None:
        """Add an in-memory trace to the session.

        This method allows adding traces that were created programmatically
        or loaded separately, rather than loading from a file. As with
        :meth:`load_trace`, an existing trace with the same name is replaced.

        Args:
            name: Name for the trace in the session.
            trace: Trace object (WaveformTrace, DigitalTrace, etc.). Only the
                presence of a ``data`` attribute is checked.

        Raises:
            ValueError: If name is empty.
            TypeError: If trace has no ``data`` attribute.

        Example:
            >>> session = Session()
            >>> data = np.sin(np.linspace(0, 2*np.pi, 1000))
            >>> trace = tk.WaveformTrace(data=data, metadata=tk.TraceMetadata(sample_rate=1e6))
            >>> session.add_trace("my_trace", trace)
        """
        if not name:
            raise ValueError("Trace name cannot be empty")

        if not hasattr(trace, "data"):
            raise TypeError("Trace must have a 'data' attribute")

        self.traces[name] = trace
        self._mark_modified()

        self.history.record(
            "add_trace",
            {"name": name, "type": type(trace).__name__},
            result=f"Added {name}",
        )

    def remove_trace(self, name: str) -> None:
        """Remove a trace from the session.

        Args:
            name: Name of the trace to remove.

        Raises:
            KeyError: If trace not found.
        """
        if name not in self.traces:
            raise KeyError(f"Trace '{name}' not found in session")

        del self.traces[name]
        self._mark_modified()

        self.history.record(
            "remove_trace",
            {"name": name},
            result=f"Removed {name}",
        )

    def get_trace(self, name: str) -> Any:
        """Get trace by name.

        Args:
            name: Trace name.

        Returns:
            Trace object.

        Raises:
            KeyError: If no trace with this name exists.
        """
        return self.traces[name]

    def list_traces(self) -> list[str]:
        """List all trace names, in insertion order."""
        return list(self.traces)

    def annotate(
        self,
        text: str,
        *,
        time: float | None = None,
        time_range: tuple[float, float] | None = None,
        layer: str = "default",
        **kwargs: Any,
    ) -> None:
        """Add an annotation to the session.

        The target layer is created on demand if it does not exist yet.

        Args:
            text: Annotation text.
            time: Time point for annotation.
            time_range: Time range for annotation.
            layer: Annotation layer name.
            **kwargs: Additional parameters passed to ``AnnotationLayer.add``.
        """
        if layer not in self.annotation_layers:
            self.annotation_layers[layer] = AnnotationLayer(layer)

        self.annotation_layers[layer].add(
            text=text,
            time=time,
            time_range=time_range,
            **kwargs,
        )
        self._mark_modified()

        self.history.record(
            "annotate",
            # time_range is recorded too, so range annotations are not
            # logged as if they had no position.
            {"text": text, "time": time, "time_range": time_range, "layer": layer},
        )

    def get_annotations(
        self,
        layer: str | None = None,
        time_range: tuple[float, float] | None = None,
    ) -> list[Any]:
        """Get annotations, optionally filtered by layer and time range.

        Args:
            layer: Filter by layer name (None for all layers).
            time_range: ``(start, end)`` filter, delegated to
                ``AnnotationLayer.find_in_range`` (None for no filtering).

        Returns:
            List of annotations.

        Raises:
            KeyError: If ``layer`` names a layer that does not exist.
        """
        # Compare against None explicitly so an empty-string layer name is
        # looked up rather than silently treated as "all layers".
        if layer is not None:
            layers = [self.annotation_layers[layer]]
        else:
            layers = list(self.annotation_layers.values())

        annotations: list[Any] = []
        for ann_layer in layers:
            if time_range is not None:
                annotations.extend(ann_layer.find_in_range(time_range[0], time_range[1]))
            else:
                annotations.extend(ann_layer.annotations)

        return annotations

    def record_measurement(
        self,
        name: str,
        value: Any,
        unit: str = "",
        trace_name: str | None = None,
        **metadata: Any,
    ) -> None:
        """Record a measurement result.

        A measurement recorded under an existing name replaces the old one.

        Args:
            name: Measurement name (e.g., 'rise_time').
            value: Measurement value.
            unit: Unit of measurement.
            trace_name: Associated trace name.
            **metadata: Additional metadata merged into the record. These keys
                are applied last, so they can override ``trace``/``timestamp``.
        """
        self.measurements[name] = {
            "value": value,
            "unit": unit,
            "trace": trace_name,
            "timestamp": datetime.now().isoformat(),
            **metadata,
        }
        self._mark_modified()

        self.history.record(
            f"measure_{name}",
            {"trace": trace_name},
            result=f"{value} {unit}".strip(),
        )

    def get_measurements(self) -> dict[str, Any]:
        """Get all recorded measurements.

        Returns:
            A shallow copy of the measurements dict. The per-measurement
            record dicts are shared with the session.
        """
        return self.measurements.copy()

    def save(
        self,
        path: str | Path | None = None,
        *,
        include_traces: bool = True,
        compress: bool = True,
    ) -> Path:
        """Save session to file.

        Args:
            path: Output path (default: the last save/load path, else
                ``<name with spaces replaced by underscores>.tks``).
            include_traces: Include trace data in session file.
            compress: Compress session file with gzip.

        Returns:
            Path to saved file.

        Example:
            >>> session.save('analysis.tks')

        Security Note:
            Session files use pickle serialization for flexibility. Share
            session files only with trusted parties. For secure data exchange
            with untrusted parties, use JSON or HDF5 export formats instead.
        """
        if path is None:
            path = self._file_path or Path(f"{self.name.replace(' ', '_')}.tks")
        else:
            path = Path(path)

        self._file_path = path
        self._mark_modified()

        # Build session data
        data = self._to_dict(include_traces=include_traces)

        # Serialize
        if compress:
            with gzip.open(path, "wb") as f:
                pickle.dump(data, f)
        else:
            with open(path, "wb") as f:
                pickle.dump(data, f)

        # Recorded after writing, so this entry is not part of the saved file.
        self.history.record("save", {"path": str(path)})

        return path

    def _to_dict(self, include_traces: bool = True) -> dict[str, Any]:
        """Convert the session to a picklable dictionary.

        Traces are reduced to their type name, sample data and sample rate.
        Other trace metadata is not preserved.
        """
        traces: dict[str, Any] = {}
        if include_traces:
            for name, trace in self.traces.items():
                samples = getattr(trace, "data", None)
                trace_meta = getattr(trace, "metadata", None)
                traces[name] = {
                    "type": type(trace).__name__,
                    # Keep the ndarray itself: pickle stores it compactly and
                    # preserves dtype, whereas .tolist() boxed every sample.
                    "data": None if samples is None else np.asarray(samples),
                    "sample_rate": getattr(trace_meta, "sample_rate", None),
                }

        return {
            "version": "1.0",
            "name": self.name,
            "created_at": self.created_at.isoformat(),
            "modified_at": self.modified_at.isoformat(),
            "annotation_layers": {
                name: layer.to_dict() for name, layer in self.annotation_layers.items()
            },
            "measurements": self.measurements,
            "history": self.history.to_dict(),
            "metadata": self.metadata,
            "traces": traces,
        }

    @classmethod
    def _from_dict(cls, data: dict[str, Any]) -> Session:
        """Create a session from a dictionary produced by :meth:`_to_dict`.

        Every stored trace with sample data is restored as a ``WaveformTrace``,
        regardless of its original type. Trace data may be a list (older
        files) or an ndarray.
        """
        session = cls(
            name=data.get("name", "Untitled Session"),
            metadata=data.get("metadata", {}),
        )

        if "created_at" in data:
            session.created_at = datetime.fromisoformat(data["created_at"])
        if "modified_at" in data:
            session.modified_at = datetime.fromisoformat(data["modified_at"])

        # Restore annotation layers
        for name, layer_data in data.get("annotation_layers", {}).items():
            session.annotation_layers[name] = AnnotationLayer.from_dict(layer_data)

        # Restore measurements
        session.measurements = data.get("measurements", {})

        # Restore history
        if "history" in data:
            session.history = OperationHistory.from_dict(data["history"])

        # Restore traces (if included)
        if "traces" in data:
            # NOTE(review): TraceMetadata assumed to live next to WaveformTrace
            # in tracekit.core.types (both are exposed as tk.* in add_trace's
            # example) — confirm.
            from tracekit.core.types import TraceMetadata, WaveformTrace

            for name, trace_data in data["traces"].items():
                samples = trace_data.get("data")
                if samples is None:
                    continue
                # The key is always written, possibly with a None value, so
                # .get(key, default) alone would not supply the fallback.
                sample_rate = trace_data.get("sample_rate")
                session.traces[name] = WaveformTrace(
                    data=np.asarray(samples),
                    metadata=TraceMetadata(
                        sample_rate=1.0 if sample_rate is None else sample_rate,
                    ),
                )

        return session

    def _mark_modified(self) -> None:
        """Update modification timestamp."""
        self.modified_at = datetime.now()

    def summary(self) -> str:
        """Return a multi-line, human-readable summary of the session."""
        annotation_count = sum(
            len(layer.annotations) for layer in self.annotation_layers.values()
        )
        lines = [
            f"Session: {self.name}",
            f"Created: {self.created_at.strftime('%Y-%m-%d %H:%M')}",
            f"Modified: {self.modified_at.strftime('%Y-%m-%d %H:%M')}",
            f"Traces: {len(self.traces)}",
            f"Annotations: {annotation_count}",
            f"Measurements: {len(self.measurements)}",
            f"History entries: {len(self.history.entries)}",
        ]
        return "\n".join(lines)

402 

403 

def load_session(path: str | Path) -> Session:
    """Restore a :class:`Session` previously written by ``Session.save``.

    Gzip-compressed files (the ``save`` default) are tried first. If gzip
    rejects the file, it is re-read as an uncompressed pickle.

    Args:
        path: Path to session file (.tks).

    Returns:
        Loaded Session object, with ``path`` remembered as its file location.

    Example:
        >>> session = load_session('debug_session.tks')
        >>> print(session.list_traces())

    Security Warning:
        Session files use pickle serialization. Only load session files from
        trusted sources. Loading a malicious .tks file could execute arbitrary
        code. Never load session files from untrusted or unknown sources.

        For secure data exchange, consider exporting to JSON or HDF5 formats
        instead of using pickle-based session files.
    """
    source = Path(path)

    def _unpickle(opener: Any) -> Any:
        # SECURITY: pickle.load can execute code embedded in the file.
        with opener(source, "rb") as stream:
            return pickle.load(stream)

    try:
        payload = _unpickle(gzip.open)
    except gzip.BadGzipFile:
        # Not gzip-framed; presumably written with save(compress=False).
        payload = _unpickle(open)

    restored = Session._from_dict(payload)
    restored._file_path = source
    return restored

440 

441 

# Public API of this module.
__all__ = ["Session", "load_session"]