Coverage for agentos/tools/snapshot_manager.py: 0%
78 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 07:53 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 07:53 +0800
1"""
2SnapshotManager — generic state snapshot + rollback.
4Supports:
5 - Register objects that implement get_state() / restore_state(state)
6 - Take named snapshots of all registered objects
7 - Restore to any snapshot
8 - List / delete snapshots
9 - Maximum snapshot count with automatic eviction
10 - Thread-safe
11"""
from __future__ import annotations

import copy
import threading
import time
from typing import Any, Callable, Dict, List, Optional, Protocol, runtime_checkable
20# ============================================================================
21# Snapshottable protocol
22# ============================================================================
@runtime_checkable
class Snapshottable(Protocol):
    """Structural interface for anything whose state can be captured and restored.

    No inheritance is needed: any object that exposes both ``get_state`` and
    ``restore_state`` satisfies the protocol. Because it is runtime-checkable,
    ``isinstance(obj, Snapshottable)`` reports whether both methods are present;
    it does not check their signatures.
    """

    def get_state(self) -> Any:
        """Return a value describing the object's current state."""

    def restore_state(self, state: Any) -> None:
        """Reset the object to a state previously produced by ``get_state``."""
35# ============================================================================
36# Snapshot
37# ============================================================================
class Snapshot:
    """A named point-in-time capture of registered objects' states.

    Attributes:
        name: Caller-supplied label; several snapshots may share a name.
        states: Mapping of registered object name -> the state captured for it.
        timestamp: Creation time in seconds since the epoch, from ``time.time()``.
    """

    __slots__ = ("name", "states", "timestamp")

    def __init__(self, name: str, states: Dict[str, Any]):
        self.name = name
        self.states = states  # {object_name: state}
        self.timestamp = time.time()

    def __repr__(self) -> str:
        # Show object names only; states may be large or have noisy reprs.
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"objects={list(self.states)!r}, timestamp={self.timestamp:.3f})"
        )
50# ============================================================================
51# SnapshotManager
52# ============================================================================
class SnapshotManager:
    """Manages named snapshots of registered Snapshottable objects.

    Usage:
        class ConfigStore:
            def get_state(self): return {"version": self.version}
            def restore_state(self, state): self.version = state["version"]

        manager = SnapshotManager(max_snapshots=10)
        manager.register("config", ConfigStore())

        # Take snapshot
        manager.snapshot("before_update")

        # ... make changes ...

        # Roll back
        manager.rollback("before_update")

    Snapshot names need not be unique. ``rollback`` uses the most recent
    snapshot with a given name, while ``delete_snapshot`` removes all of them.

    Every public method holds an internal ``threading.RLock`` while it reads or
    modifies the registry or the snapshot list. Registered objects'
    ``get_state`` / ``restore_state`` are called while that lock is held.
    """

    def __init__(self, max_snapshots: int = 20, *, copy_states: bool = False):
        """Create an empty manager.

        Args:
            max_snapshots: Maximum number of snapshots retained; once exceeded,
                the oldest snapshots are evicted.
            copy_states: If True, states are passed through ``copy.deepcopy``
                both when captured and when restored. Later mutation of a
                registered object then cannot alter a stored snapshot, and the
                same snapshot can be rolled back to repeatedly. Defaults to
                False, which stores and restores the exact objects returned by
                ``get_state``.

        Raises:
            ValueError: If ``max_snapshots`` is less than 1.
        """
        if max_snapshots < 1:
            raise ValueError("max_snapshots must be >= 1")
        self._max_snapshots = max_snapshots
        self._copy_states = copy_states
        self._objects: Dict[str, Snapshottable] = {}
        self._snapshots: List[Snapshot] = []  # oldest first
        self._lock = threading.RLock()

    # ---------- Registration ----------

    def register(self, name: str, obj: Any) -> None:
        """Register ``obj`` under ``name``, replacing any previous registration.

        Raises:
            TypeError: If ``obj`` lacks ``get_state`` or ``restore_state``.
        """
        if not isinstance(obj, Snapshottable):
            raise TypeError(
                f"Object '{name}' does not implement Snapshottable "
                f"(needs get_state() / restore_state(state))"
            )
        with self._lock:
            self._objects[name] = obj

    def unregister(self, name: str) -> bool:
        """Remove the object registered under ``name``; return True if one existed."""
        with self._lock:
            return self._objects.pop(name, None) is not None

    @property
    def registered(self) -> List[str]:
        """Sorted names of all currently registered objects."""
        with self._lock:
            return sorted(self._objects.keys())

    # ---------- Snapshots ----------

    def _maybe_copy(self, state: Any) -> Any:
        """Return a deep copy of ``state`` when ``copy_states`` is enabled, else ``state``."""
        return copy.deepcopy(state) if self._copy_states else state

    def snapshot(self, name: str) -> Snapshot:
        """Capture the current state of every registered object under ``name``.

        If more than ``max_snapshots`` snapshots are then held, the oldest
        ones are evicted.

        Returns:
            The newly created Snapshot.
        """
        with self._lock:
            states = {
                obj_name: self._maybe_copy(obj.get_state())
                for obj_name, obj in self._objects.items()
            }
            snap = Snapshot(name=name, states=states)
            self._snapshots.append(snap)

            # Evict oldest if over limit.
            excess = len(self._snapshots) - self._max_snapshots
            if excess > 0:
                del self._snapshots[:excess]
            return snap

    def rollback(self, name: str, raise_on_missing: bool = True) -> bool:
        """Restore registered objects to the most recent snapshot named ``name``.

        Only objects that were captured in the snapshot AND are still
        registered are restored. Objects registered after the snapshot was
        taken are left untouched. Lookup and restore happen under a single
        lock acquisition, so a concurrent delete/clear cannot interleave.
        If a ``restore_state`` call raises, the exception propagates and
        objects restored before it are not reverted.

        Returns:
            True if at least one object was restored. False if no object was
            restored, or if the snapshot is missing and ``raise_on_missing``
            is False.

        Raises:
            KeyError: If no snapshot is named ``name`` and ``raise_on_missing``
                is True.
        """
        with self._lock:
            snap = self._find_snapshot(name)
            if snap is None:
                if raise_on_missing:
                    raise KeyError(f"Snapshot '{name}' not found")
                return False

            restored = 0
            for obj_name, state in snap.states.items():
                obj = self._objects.get(obj_name)
                if obj is None:
                    continue  # captured, but unregistered since
                obj.restore_state(self._maybe_copy(state))
                restored += 1
            return restored > 0

    def _find_snapshot(self, name: str) -> Optional[Snapshot]:
        """Return the most recent snapshot named ``name``, or None."""
        with self._lock:
            for snap in reversed(self._snapshots):
                if snap.name == name:
                    return snap
            return None

    # ---------- Management ----------

    def list_snapshots(self) -> List[str]:
        """Names of all held snapshots, oldest first (duplicates included)."""
        with self._lock:
            return [s.name for s in self._snapshots]

    def delete_snapshot(self, name: str) -> bool:
        """Delete every snapshot named ``name``; return True if any was removed."""
        with self._lock:
            before = len(self._snapshots)
            self._snapshots = [s for s in self._snapshots if s.name != name]
            return len(self._snapshots) < before

    def clear(self) -> None:
        """Delete all snapshots; registered objects are kept."""
        with self._lock:
            self._snapshots.clear()

    @property
    def snapshot_count(self) -> int:
        """Number of snapshots currently held."""
        with self._lock:
            return len(self._snapshots)