Coverage for agentos/tools/circuit_breaker.py: 25%

138 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-03 13:55 +0800

1""" 

2Circuit Breaker for AgentOS. 

3 

4Protects against cascading failures with three states: 

5- CLOSED: normal operation, track failures 

6- OPEN: circuit tripped, fast-fail all calls 

7- HALF_OPEN: probe with limited calls to test recovery 

8 

9Supports failure/success thresholds, recovery timeout, and callbacks. 

10""" 

11 

import logging
import threading
import time
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Callable, Dict, Optional, TypeVar

17 

# Return type of the callable passed to CircuitBreaker.call(); call() returns the same type.
T = TypeVar("T")

19 

20 

21# ============================================================================ 

22# Enums & Types 

23# ============================================================================ 

24 

class CircuitState(Enum):
    """The three states a circuit breaker can be in."""

    # Normal operation
    CLOSED = 1
    # Fast-fail, no calls allowed
    OPEN = 2
    # Probe mode, limited calls allowed
    HALF_OPEN = 3

29 

30 

# Signature of the state-change hook: called as (breaker, old_state, new_state); return value unused.
CircuitCallback = Callable[["CircuitBreaker", CircuitState, CircuitState], None]

32 

33 

34# ============================================================================ 

35# CircuitBreaker 

36# ============================================================================ 

37 

class CircuitBreaker:
    """Thread-safe circuit breaker.

    State machine:
        CLOSED    -> OPEN       after ``failure_threshold`` consecutive failures
        OPEN      -> HALF_OPEN  on the first call made once ``recovery_timeout``
                                seconds have passed since the circuit opened
        HALF_OPEN -> CLOSED     after enough successful probe calls
        HALF_OPEN -> OPEN       on any failed probe call

    Parameters:
        name: label used in error messages and stats
        failure_threshold: consecutive failures (while CLOSED) before tripping
        recovery_timeout: seconds before transitioning OPEN → HALF_OPEN
        half_open_max_calls: max probe calls in HALF_OPEN before deciding
        success_threshold: successes needed in HALF_OPEN to close circuit
            (effectively capped at ``half_open_max_calls`` so the circuit can
            always recover)
        on_state_change: optional hook called as
            ``(breaker, old_state, new_state)`` on every state change, while
            the breaker's lock is held; exceptions it raises are logged and
            otherwise ignored
    """

    def __init__(
        self,
        name: str = "default",
        failure_threshold: int = 5,
        recovery_timeout: float = 30.0,
        half_open_max_calls: int = 3,
        success_threshold: int = 2,
        on_state_change: Optional[CircuitCallback] = None,
    ):
        self.name = name
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_max_calls = half_open_max_calls
        self.success_threshold = success_threshold
        self.on_state_change = on_state_change

        # Re-entrant so that methods holding the lock can call each other
        # (e.g. call() -> _check_state() -> _transition()).
        self._lock = threading.RLock()
        self._state: CircuitState = CircuitState.CLOSED
        self._failure_count: int = 0      # failures since last reset/success in CLOSED
        self._success_count: int = 0      # successful probes in the current HALF_OPEN phase
        self._half_open_calls: int = 0    # probes admitted in the current HALF_OPEN phase
        self._last_failure_time: float = 0.0
        self._last_success_time: float = 0.0
        self._total_calls: int = 0        # calls admitted (rejected calls are not counted)
        self._total_failures: int = 0
        self._total_successes: int = 0
        self._opened_at: float = 0.0      # time.time() of the last transition to OPEN

    # ---------- state management ----------

    def _transition(self, new_state: CircuitState) -> None:
        """Switch to ``new_state``, reset the relevant counters and fire the hook.

        Every caller in this class holds ``self._lock``. Does nothing when the
        breaker is already in ``new_state``.
        """
        old = self._state
        if old == new_state:
            return
        self._state = new_state
        if new_state == CircuitState.OPEN:
            self._opened_at = time.time()
        elif new_state == CircuitState.HALF_OPEN:
            self._success_count = 0
            self._half_open_calls = 0
        elif new_state == CircuitState.CLOSED:
            self._failure_count = 0
        if self.on_state_change:
            try:
                self.on_state_change(self, old, new_state)
            except Exception:
                # The hook is best-effort: a broken observer must not break the
                # breaker, but the failure should still be visible somewhere.
                logging.getLogger(__name__).exception(
                    "Circuit '%s': on_state_change callback failed (%s -> %s)",
                    self.name,
                    old.name,
                    new_state.name,
                )

    @property
    def state(self) -> CircuitState:
        """Current state, read under the lock."""
        with self._lock:
            return self._state

    # ---------- call execution ----------

    def call(self, fn: Callable[..., T], *args, **kwargs) -> T:
        """Execute fn through the circuit breaker. Raises CircuitOpenError if open.

        ``fn(*args, **kwargs)`` runs outside the lock. Its return value is
        passed through; any ``Exception`` it raises is recorded as a failure
        and re-raised unchanged.
        """
        # Admission check and call counter are updated together under the lock
        # so concurrent callers cannot lose increments of _total_calls.
        with self._lock:
            self._check_state()
            self._total_calls += 1
        try:
            result = fn(*args, **kwargs)
        except Exception:
            self._on_failure()
            raise
        else:
            self._on_success()
            return result

    def _check_state(self) -> None:
        """Admit or reject the next call based on the current state.

        Returns normally when the call may proceed. A call admitted in
        HALF_OPEN (including the one that triggers OPEN → HALF_OPEN) is
        counted as a probe.

        Raises:
            CircuitOpenError: the circuit is OPEN and ``recovery_timeout`` has
                not yet elapsed, or it is HALF_OPEN and all probe slots are
                already used.
        """
        with self._lock:
            if self._state == CircuitState.CLOSED:
                return
            if self._state == CircuitState.OPEN:
                elapsed = time.time() - self._opened_at
                if elapsed >= self.recovery_timeout:
                    self._transition(CircuitState.HALF_OPEN)
                    self._half_open_calls += 1  # count this probe
                    return
                raise CircuitOpenError(
                    f"Circuit '{self.name}' is OPEN "
                    f"(recovery in {self.recovery_timeout - elapsed:.1f}s)"
                )
            if self._state == CircuitState.HALF_OPEN:
                if self._half_open_calls >= self.half_open_max_calls:
                    raise CircuitOpenError(
                        f"Circuit '{self.name}' HALF_OPEN limit reached "
                        f"({self._half_open_calls}/{self.half_open_max_calls})"
                    )
                self._half_open_calls += 1

    def _on_success(self) -> None:
        """Record a successful call and close the circuit if probing succeeded."""
        with self._lock:
            self._total_successes += 1
            self._last_success_time = time.time()
            if self._state == CircuitState.HALF_OPEN:
                self._success_count += 1
                # Never require more successes than probes are allowed;
                # otherwise a config with success_threshold > half_open_max_calls
                # would leave the circuit stuck in HALF_OPEN rejecting every call.
                required = min(self.success_threshold, self.half_open_max_calls)
                if self._success_count >= required:
                    self._transition(CircuitState.CLOSED)
            elif self._state == CircuitState.CLOSED:
                # Only consecutive failures count towards tripping.
                self._failure_count = 0

    def _on_failure(self) -> None:
        """Record a failed call and open the circuit when warranted."""
        with self._lock:
            self._total_failures += 1
            self._last_failure_time = time.time()
            self._failure_count += 1
            if self._state == CircuitState.HALF_OPEN:
                # A single failed probe sends the circuit straight back to OPEN.
                self._transition(CircuitState.OPEN)
            elif self._state == CircuitState.CLOSED and self._failure_count >= self.failure_threshold:
                self._transition(CircuitState.OPEN)

    # ---------- manual control ----------

    def reset(self) -> None:
        """Force circuit back to CLOSED."""
        with self._lock:
            self._failure_count = 0
            self._success_count = 0
            self._half_open_calls = 0
            self._transition(CircuitState.CLOSED)

    def trip(self) -> None:
        """Force circuit OPEN."""
        with self._lock:
            self._failure_count = self.failure_threshold
            self._transition(CircuitState.OPEN)

    # ---------- stats ----------

    @property
    def stats(self) -> Dict[str, Any]:
        """Snapshot of the breaker's counters and timestamps, taken under the lock."""
        with self._lock:
            return {
                "name": self.name,
                "state": self._state.name,
                "failure_count": self._failure_count,
                "half_open_calls": self._half_open_calls,
                "total_calls": self._total_calls,
                "total_successes": self._total_successes,
                "total_failures": self._total_failures,
                "last_failure": self._last_failure_time,
                "last_success": self._last_success_time,
                "opened_at": self._opened_at,
            }

191 

192 

193# ============================================================================ 

194# Errors 

195# ============================================================================ 

196 

class CircuitOpenError(Exception):
    """Signals that a circuit breaker refused a call instead of executing it."""

200 

201 

202# ============================================================================ 

203# CircuitRegistry — manage multiple breakers by name 

204# ============================================================================ 

205 

class CircuitRegistry:
    """Registry of circuit breakers keyed by name.

    The registry's own lock only protects the name → breaker mapping. It is
    never held while calling into a breaker, because breaker operations can
    invoke user ``on_state_change`` callbacks; running those under this
    non-reentrant lock would deadlock any callback that touches the registry.
    """

    def __init__(self):
        self._breakers: Dict[str, CircuitBreaker] = {}
        self._lock = threading.Lock()

    def get(self, name: str, **kwargs) -> CircuitBreaker:
        """Return the breaker registered as ``name``, creating it on first use.

        ``kwargs`` are forwarded to ``CircuitBreaker`` only when a new breaker
        is created; they are ignored if ``name`` is already registered.
        """
        with self._lock:
            if name not in self._breakers:
                self._breakers[name] = CircuitBreaker(name=name, **kwargs)
            return self._breakers[name]

    def remove(self, name: str) -> bool:
        """Unregister ``name``; return True if a breaker was removed."""
        with self._lock:
            return self._breakers.pop(name, None) is not None

    def list_breakers(self) -> Dict[str, str]:
        """Return a mapping of breaker name to its current state name."""
        # Snapshot under the lock, then read each breaker's state outside it
        # so the registry lock is never held while taking a breaker's lock.
        with self._lock:
            snapshot = list(self._breakers.items())
        return {n: b.state.name for n, b in snapshot}

    def reset_all(self) -> None:
        """Force every registered breaker back to CLOSED."""
        # reset() may fire on_state_change callbacks; run them without the
        # registry lock so a callback that uses the registry cannot deadlock.
        with self._lock:
            breakers = list(self._breakers.values())
        for b in breakers:
            b.reset()

231 

232 

# Module-level registry instance; stays None until get_circuit_registry() creates it.
_default_registry: Optional[CircuitRegistry] = None
# Serializes first-time creation of _default_registry.
_registry_lock = threading.Lock()

235 

236 

def get_circuit_registry() -> CircuitRegistry:
    """Return the module-level CircuitRegistry, creating it on first use.

    The fast path returns the existing instance without locking. Otherwise the
    instance is re-checked under ``_registry_lock`` so that only one registry
    is ever created.
    """
    global _default_registry
    if _default_registry is not None:
        return _default_registry
    with _registry_lock:
        # Another thread may have created it while we waited for the lock.
        if _default_registry is None:
            _default_registry = CircuitRegistry()
        return _default_registry