Coverage for agentos/tools/snapshot_manager.py: 0%

78 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-03 07:53 +0800

1""" 

SnapshotManager — generic state snapshot + rollback.

Supports:
 - Registering objects that implement get_state() / restore_state(state)
 - Taking named snapshots of all registered objects
 - Restoring (rolling back) to any snapshot
 - Listing / deleting snapshots
 - A maximum snapshot count, with the oldest snapshots evicted automatically
 - Thread-safe access (all manager operations hold an internal lock)

11""" 

12 

13from __future__ import annotations 

14 

15import threading 

16import time 

17from typing import Any, Callable, Dict, List, Optional, Protocol, runtime_checkable 

18 

19 

20# ============================================================================ 

21# Snapshottable protocol 

22# ============================================================================ 

23 

@runtime_checkable
class Snapshottable(Protocol):
    """Structural interface for anything a SnapshotManager can capture.

    No inheritance is required: any object providing both methods below
    qualifies. Because the protocol is runtime-checkable, ``isinstance``
    verifies only that both methods are present, not their signatures.
    """

    def get_state(self) -> Any:
        """Return a value describing the object's current state."""

    def restore_state(self, state: Any) -> None:
        """Reset the object from a value previously produced by ``get_state``."""

33 

34 

35# ============================================================================ 

36# Snapshot 

37# ============================================================================ 

38 

class Snapshot:
    """A named point-in-time capture of several objects' states.

    Attributes:
        name: Label the snapshot was taken under.
        states: Mapping of registered-object name -> captured state.
        timestamp: Wall-clock creation time, from ``time.time()``.
    """

    # Many snapshots may be retained; slots keep each instance small.
    __slots__ = ("name", "states", "timestamp")

    def __init__(self, name: str, states: Dict[str, Any]):
        self.name = name
        self.states = states  # {object_name: state}
        self.timestamp = time.time()

    def __repr__(self) -> str:
        # List only the object names: states may be large or unprintable.
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"objects={list(self.states)!r}, timestamp={self.timestamp!r})"
        )

48 

49 

50# ============================================================================ 

51# SnapshotManager 

52# ============================================================================ 

53 

class SnapshotManager:
    """Manages named snapshots of Snapshottable objects.

    A snapshot records ``get_state()`` of every registered object at that
    moment; a rollback hands those recorded states back through
    ``restore_state()``.

    Usage:
        class ConfigStore:
            def get_state(self): return {"version": self.version}
            def restore_state(self, state): self.version = state["version"]

        manager = SnapshotManager(max_snapshots=10)
        manager.register("config", ConfigStore())

        # Take snapshot
        manager.snapshot("before_update")

        # ... make changes ...

        # Roll back
        manager.rollback("before_update")

    Every public operation holds an internal re-entrant lock. Restoring is
    not transactional: if one object's ``restore_state`` raises, objects
    restored earlier in the same rollback stay restored.
    """

    def __init__(self, max_snapshots: int = 20, *, copy_states: bool = False):
        """Create an empty manager.

        Args:
            max_snapshots: Maximum number of retained snapshots; when it is
                exceeded, the oldest snapshots are evicted.
            copy_states: If True, states are deep-copied when captured and
                again when handed to ``restore_state``. Use this when
                ``get_state`` returns live mutable internals; otherwise later
                mutations of the object silently alter the stored snapshot.
                Defaults to False (states are stored by reference).

        Raises:
            ValueError: If ``max_snapshots`` is less than 1.
        """
        import copy  # only needed for the opt-in copy mode

        if max_snapshots < 1:
            raise ValueError("max_snapshots must be >= 1")
        self._max_snapshots = max_snapshots
        # Applied to every state on capture and on restore.
        self._copy: Callable[[Any], Any] = (
            copy.deepcopy if copy_states else (lambda state: state)
        )
        self._objects: Dict[str, "Snapshottable"] = {}
        self._snapshots: List["Snapshot"] = []  # oldest first
        self._lock = threading.RLock()

    # ---------- Registration ----------

    def register(self, name: str, obj: Any) -> None:
        """Register ``obj`` under ``name``, replacing any previous object.

        Raises:
            TypeError: If ``obj`` lacks ``get_state`` / ``restore_state``.
        """
        if not isinstance(obj, Snapshottable):
            raise TypeError(
                f"Object '{name}' does not implement Snapshottable "
                f"(needs get_state() / restore_state(state))"
            )
        with self._lock:
            self._objects[name] = obj

    def unregister(self, name: str) -> bool:
        """Remove the object registered under ``name``; True if one existed."""
        with self._lock:
            return self._objects.pop(name, None) is not None

    @property
    def registered(self) -> List[str]:
        """Sorted names of all currently registered objects."""
        with self._lock:
            return sorted(self._objects.keys())

    # ---------- Snapshots ----------

    def snapshot(self, name: str) -> "Snapshot":
        """Capture the current state of all registered objects under ``name``.

        Names need not be unique; lookups use the most recent match. If the
        snapshot cap is exceeded, the oldest snapshots are evicted.
        """
        with self._lock:
            states = {
                obj_name: self._copy(obj.get_state())
                for obj_name, obj in self._objects.items()
            }
            snap = Snapshot(name=name, states=states)
            self._snapshots.append(snap)

            # Evict oldest if over limit (the new snapshot is last, and
            # max_snapshots >= 1, so it is never evicted itself).
            excess = len(self._snapshots) - self._max_snapshots
            if excess > 0:
                del self._snapshots[:excess]

            return snap

    def rollback(self, name: str, raise_on_missing: bool = True) -> bool:
        """Restore registered objects to the most recent snapshot named ``name``.

        Only objects present in the snapshot AND still registered are
        restored; others are skipped.

        Returns:
            True if at least one object was restored. False if none were, or
            if the snapshot is missing and ``raise_on_missing`` is False.

        Raises:
            KeyError: If the snapshot is missing and ``raise_on_missing`` is
                True.
        """
        # Lookup and restore share one lock acquisition, so a snapshot that
        # another thread deletes, clears or evicts cannot still be restored
        # after its removal.
        with self._lock:
            snap = self._find_snapshot(name)
            if snap is None:
                if raise_on_missing:
                    raise KeyError(f"Snapshot '{name}' not found")
                return False

            restored = 0
            for obj_name, state in snap.states.items():
                obj = self._objects.get(obj_name)
                if obj is None:
                    continue
                # Copy (when enabled) so restore_state cannot alias the
                # stored snapshot and corrupt it for later rollbacks.
                obj.restore_state(self._copy(state))
                restored += 1
            return restored > 0

    def _find_snapshot(self, name: str) -> Optional["Snapshot"]:
        """Return the newest snapshot called ``name``, or None."""
        with self._lock:
            return next(
                (snap for snap in reversed(self._snapshots) if snap.name == name),
                None,
            )

    # ---------- Management ----------

    def list_snapshots(self) -> List[str]:
        """Snapshot names, oldest first (duplicates included)."""
        with self._lock:
            return [s.name for s in self._snapshots]

    def delete_snapshot(self, name: str) -> bool:
        """Delete every snapshot called ``name``; True if any were removed."""
        with self._lock:
            before = len(self._snapshots)
            self._snapshots = [s for s in self._snapshots if s.name != name]
            return len(self._snapshots) < before

    def clear(self) -> None:
        """Drop all snapshots; registered objects are kept."""
        with self._lock:
            self._snapshots.clear()

    @property
    def snapshot_count(self) -> int:
        """Number of snapshots currently retained."""
        with self._lock:
            return len(self._snapshots)