Coverage for src / tracekit / session / session.py: 68%
130 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
"""Analysis session management.

This module provides session save/restore functionality for TraceKit.

Example:
    >>> session = Session()
    >>> session.load_trace('capture.wfm')
    >>> session.save('debug_session.tks')
    >>>
    >>> # Later...
    >>> session = load_session('debug_session.tks')
"""
15from __future__ import annotations
17import gzip
18import pickle
19from dataclasses import dataclass, field
20from datetime import datetime
21from pathlib import Path
22from typing import Any
24import numpy as np
26from tracekit.session.annotations import AnnotationLayer
27from tracekit.session.history import OperationHistory
@dataclass
class Session:
    """Analysis session container.

    Manages traces, annotations, measurements, and history for a complete
    analysis session. Sessions can be saved with :meth:`save` and restored
    with :func:`load_session`.

    Attributes:
        name: Session name.
        traces: Dictionary of loaded traces (name -> trace).
        annotation_layers: Annotation layers keyed by layer name; a
            ``"default"`` layer is always present after construction.
        measurements: Recorded measurements keyed by measurement name.
        history: Operation history.
        metadata: Free-form session metadata.
        created_at: Creation timestamp (naive local time).
        modified_at: Last modification timestamp (naive local time).
    """

    name: str = "Untitled Session"
    traces: dict[str, Any] = field(default_factory=dict)
    annotation_layers: dict[str, AnnotationLayer] = field(default_factory=dict)
    measurements: dict[str, Any] = field(default_factory=dict)
    history: OperationHistory = field(default_factory=OperationHistory)
    metadata: dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.now)
    modified_at: datetime = field(default_factory=datetime.now)
    # Location of the last successful save / load; reused by save() when no
    # explicit path is given.
    _file_path: Path | None = None

    def __post_init__(self) -> None:
        """Ensure a ``"default"`` annotation layer exists."""
        if "default" not in self.annotation_layers:
            self.annotation_layers["default"] = AnnotationLayer("Default")

    def load_trace(
        self,
        path: str | Path,
        name: str | None = None,
        **load_kwargs: Any,
    ) -> Any:
        """Load a trace file into the session.

        Args:
            path: Path to trace file.
            name: Name for the trace in the session (default: the file stem).
                An existing trace with the same name is replaced.
            **load_kwargs: Additional keyword arguments forwarded to
                ``tracekit.loaders.load``.

        Returns:
            Loaded trace.
        """
        # Function-scope import; NOTE(review): presumably to avoid an import
        # cycle or loader import cost at module load — confirm.
        from tracekit.loaders import load

        path = Path(path)
        trace = load(str(path), **load_kwargs)

        if name is None:
            name = path.stem

        self.traces[name] = trace
        self._mark_modified()

        self.history.record(
            "load_trace",
            {"path": str(path), "name": name},
            result=f"Loaded {name}",
        )

        return trace

    def add_trace(
        self,
        name: str,
        trace: Any,
    ) -> None:
        """Add an in-memory trace to the session.

        This method allows adding traces that were created programmatically
        or loaded separately, rather than loading from a file. An existing
        trace with the same name is replaced.

        Args:
            name: Name for the trace in the session.
            trace: Trace object (WaveformTrace, DigitalTrace, etc.).

        Raises:
            ValueError: If ``name`` is empty.
            TypeError: If ``trace`` has no ``data`` attribute.

        Example:
            >>> session = Session()
            >>> data = np.sin(np.linspace(0, 2*np.pi, 1000))
            >>> trace = tk.WaveformTrace(data=data, metadata=tk.TraceMetadata(sample_rate=1e6))
            >>> session.add_trace("my_trace", trace)
        """
        if not name:
            raise ValueError("Trace name cannot be empty")

        if not hasattr(trace, "data"):
            raise TypeError("Trace must have a 'data' attribute")

        self.traces[name] = trace
        self._mark_modified()

        self.history.record(
            "add_trace",
            {"name": name, "type": type(trace).__name__},
            result=f"Added {name}",
        )

    def remove_trace(self, name: str) -> None:
        """Remove a trace from the session.

        Args:
            name: Name of the trace to remove.

        Raises:
            KeyError: If the trace is not found.
        """
        if name not in self.traces:
            raise KeyError(f"Trace '{name}' not found in session")

        del self.traces[name]
        self._mark_modified()

        self.history.record(
            "remove_trace",
            {"name": name},
            result=f"Removed {name}",
        )

    def get_trace(self, name: str) -> Any:
        """Get a trace by name.

        Args:
            name: Trace name.

        Returns:
            Trace object.

        Raises:
            KeyError: If no trace with that name exists.
        """
        return self.traces[name]

    def list_traces(self) -> list[str]:
        """List all trace names in insertion order."""
        return list(self.traces)

    def annotate(
        self,
        text: str,
        *,
        time: float | None = None,
        time_range: tuple[float, float] | None = None,
        layer: str = "default",
        **kwargs: Any,
    ) -> None:
        """Add an annotation to the session.

        The target layer is created on demand if it does not exist yet.

        Args:
            text: Annotation text.
            time: Time point for annotation.
            time_range: Time range for annotation.
            layer: Annotation layer name.
            **kwargs: Additional parameters forwarded to ``AnnotationLayer.add``.
        """
        if layer not in self.annotation_layers:
            self.annotation_layers[layer] = AnnotationLayer(layer)

        self.annotation_layers[layer].add(
            text=text,
            time=time,
            time_range=time_range,
            **kwargs,
        )
        self._mark_modified()

        self.history.record(
            "annotate",
            {"text": text, "time": time, "layer": layer},
        )

    def get_annotations(
        self,
        layer: str | None = None,
        time_range: tuple[float, float] | None = None,
    ) -> list[Any]:
        """Get annotations.

        Args:
            layer: Filter by layer name (None for all layers).
            time_range: ``(start, end)`` filter, delegated to
                ``AnnotationLayer.find_in_range``; None for no filtering.

        Returns:
            List of annotations.

        Raises:
            KeyError: If ``layer`` names a layer that does not exist.
        """
        # Compare against None (not truthiness) so a layer literally named ""
        # is still selectable, matching the documented contract.
        layers = (
            [self.annotation_layers[layer]]
            if layer is not None
            else self.annotation_layers.values()
        )

        annotations: list[Any] = []
        for ann_layer in layers:
            if time_range is not None:
                annotations.extend(ann_layer.find_in_range(time_range[0], time_range[1]))
            else:
                annotations.extend(ann_layer.annotations)

        return annotations

    def record_measurement(
        self,
        name: str,
        value: Any,
        unit: str = "",
        trace_name: str | None = None,
        **metadata: Any,
    ) -> None:
        """Record a measurement result.

        A later measurement with the same ``name`` replaces the earlier one.

        Args:
            name: Measurement name (e.g., 'rise_time').
            value: Measurement value.
            unit: Unit of measurement.
            trace_name: Associated trace name.
            **metadata: Additional metadata stored alongside the value.
        """
        self.measurements[name] = {
            "value": value,
            "unit": unit,
            "trace": trace_name,
            "timestamp": datetime.now().isoformat(),
            **metadata,
        }
        self._mark_modified()

        self.history.record(
            f"measure_{name}",
            {"trace": trace_name},
            result=f"{value} {unit}".strip(),
        )

    def get_measurements(self) -> dict[str, Any]:
        """Get a shallow copy of all recorded measurements."""
        return self.measurements.copy()

    def save(
        self,
        path: str | Path | None = None,
        *,
        include_traces: bool = True,
        compress: bool = True,
    ) -> Path:
        """Save the session to a file.

        Args:
            path: Output path. Defaults to the last save/load location, or
                ``<name with spaces replaced by underscores>.tks``.
            include_traces: Include trace data in the session file.
            compress: Compress the session file with gzip.

        Returns:
            Path to the saved file.

        Example:
            >>> session.save('analysis.tks')

        Security Note:
            Session files use pickle serialization for flexibility. Share
            session files only with trusted parties. For secure data exchange
            with untrusted parties, use JSON or HDF5 export formats instead.
        """
        if path is None:
            path = self._file_path or Path(f"{self.name.replace(' ', '_')}.tks")
        else:
            path = Path(path)

        # Update modified_at before building the payload so the file records
        # the save time.
        self._mark_modified()
        data = self._to_dict(include_traces=include_traces)

        if compress:
            with gzip.open(path, "wb") as f:
                pickle.dump(data, f)
        else:
            with open(path, "wb") as f:
                pickle.dump(data, f)

        # Only remember the location once the write has succeeded.
        self._file_path = path
        # Recorded after serialization, so the written file does not contain
        # this "save" entry.
        self.history.record("save", {"path": str(path)})

        return path

    @staticmethod
    def _trace_to_dict(trace: Any) -> dict[str, Any]:
        """Serialize one trace into its type name, samples, and sample rate."""
        raw = getattr(trace, "data", None)
        # np.asarray accepts ndarrays as well as plain sequences, so traces
        # added with list data serialize too; tolist() yields plain Python
        # values for the samples.
        samples = np.asarray(raw).tolist() if raw is not None else None
        # Tolerate traces without metadata, or metadata without sample_rate.
        sample_rate = getattr(getattr(trace, "metadata", None), "sample_rate", None)
        return {
            "type": type(trace).__name__,
            "data": samples,
            "sample_rate": sample_rate,
        }

    def _to_dict(self, include_traces: bool = True) -> dict[str, Any]:
        """Convert the session to a dictionary for serialization.

        Timestamps are stored as ISO-8601 strings and trace samples as plain
        Python lists.

        Args:
            include_traces: When False, ``"traces"`` is an empty dict.

        Returns:
            Dictionary accepted by :meth:`_from_dict`.
        """
        data: dict[str, Any] = {
            "version": "1.0",
            "name": self.name,
            "created_at": self.created_at.isoformat(),
            "modified_at": self.modified_at.isoformat(),
            "annotation_layers": {
                name: layer.to_dict() for name, layer in self.annotation_layers.items()
            },
            "measurements": self.measurements,
            "history": self.history.to_dict(),
            "metadata": self.metadata,
        }

        data["traces"] = (
            {name: self._trace_to_dict(trace) for name, trace in self.traces.items()}
            if include_traces
            else {}
        )

        return data

    @classmethod
    def _from_dict(cls, data: dict[str, Any]) -> Session:
        """Create a session from a dictionary produced by :meth:`_to_dict`.

        Missing keys fall back to defaults.

        Args:
            data: Serialized session dictionary.

        Returns:
            New Session instance.
        """
        session = cls(
            name=data.get("name", "Untitled Session"),
            metadata=data.get("metadata", {}),
        )

        if "created_at" in data:
            session.created_at = datetime.fromisoformat(data["created_at"])
        if "modified_at" in data:
            session.modified_at = datetime.fromisoformat(data["modified_at"])

        # Restore annotation layers (may override the default layer)
        for name, layer_data in data.get("annotation_layers", {}).items():
            session.annotation_layers[name] = AnnotationLayer.from_dict(layer_data)

        session.measurements = data.get("measurements", {})

        if "history" in data:
            session.history = OperationHistory.from_dict(data["history"])

        # Entries without sample data (traces lacking a data attribute, or
        # sessions saved with include_traces=False) are skipped.
        stored_traces = {
            name: trace_data
            for name, trace_data in (data.get("traces") or {}).items()
            if trace_data.get("data") is not None
        }
        if stored_traces:
            # Imported only when there is something to restore.
            from tracekit.core.types import TraceMetadata, WaveformTrace

            for name, trace_data in stored_traces.items():
                sample_rate = trace_data.get("sample_rate")
                if sample_rate is None:
                    # Saved from a trace without metadata; use a neutral rate.
                    sample_rate = 1.0
                # WaveformTrace takes its sample rate via TraceMetadata, not a
                # ``sample_rate`` keyword (see the add_trace example).
                # NOTE: every entry is restored as a WaveformTrace; the stored
                # "type" name is not consulted.
                session.traces[name] = WaveformTrace(
                    data=np.array(trace_data["data"]),
                    metadata=TraceMetadata(sample_rate=sample_rate),
                )

        return session

    def _mark_modified(self) -> None:
        """Update the modification timestamp to now."""
        self.modified_at = datetime.now()

    def summary(self) -> str:
        """Return a multi-line, human-readable session summary."""
        annotation_count = sum(
            len(ann_layer.annotations) for ann_layer in self.annotation_layers.values()
        )
        lines = [
            f"Session: {self.name}",
            f"Created: {self.created_at.strftime('%Y-%m-%d %H:%M')}",
            f"Modified: {self.modified_at.strftime('%Y-%m-%d %H:%M')}",
            f"Traces: {len(self.traces)}",
            f"Annotations: {annotation_count}",
            f"Measurements: {len(self.measurements)}",
            f"History entries: {len(self.history.entries)}",
        ]
        return "\n".join(lines)
def load_session(path: str | Path) -> Session:
    """Restore a Session previously written by ``Session.save``.

    The file is first read as gzip-compressed data. If it is not a gzip file,
    it is re-read as a plain (uncompressed) pickle.

    Args:
        path: Location of the session file (.tks).

    Returns:
        The reconstructed Session. Its remembered file location is set to
        ``path``.

    Example:
        >>> session = load_session('debug_session.tks')
        >>> print(session.list_traces())

    Security Warning:
        Session files use pickle serialization. Only load session files from
        trusted sources. Loading a malicious .tks file could execute arbitrary
        code. Never load session files from untrusted or unknown sources.

        For secure data exchange, consider exporting to JSON or HDF5 formats
        instead of using pickle-based session files.
    """
    session_path = Path(path)

    try:
        # Compressed output is the default of Session.save().
        with gzip.open(session_path, "rb") as stream:
            payload = pickle.load(stream)
    except gzip.BadGzipFile:
        # Not gzip data: presumably saved with compress=False.
        with open(session_path, "rb") as stream:
            payload = pickle.load(stream)

    restored = Session._from_dict(payload)
    restored._file_path = session_path
    return restored
# Public API of this module.
__all__ = ["Session", "load_session"]