Coverage for agentos/tools/circuit_breaker.py: 0%
138 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 07:15 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 07:15 +0800
1"""
2Circuit Breaker for AgentOS.
4Protects against cascading failures with three states:
5- CLOSED: normal operation, track failures
6- OPEN: circuit tripped, fast-fail all calls
7- HALF_OPEN: probe with limited calls to test recovery
9Supports failure/success thresholds, recovery timeout, and callbacks.
10"""
12import threading
13import time
14from dataclasses import dataclass, field
15from enum import Enum, auto
16from typing import Any, Callable, Dict, Optional, TypeVar
T = TypeVar("T")


# ============================================================================
# Enums & Types
# ============================================================================

class CircuitState(Enum):
    """The three states a CircuitBreaker can be in."""

    CLOSED = auto()  # Normal operation
    OPEN = auto()  # Fast-fail, no calls allowed
    HALF_OPEN = auto()  # Probe mode, limited calls allowed


# State-change hook signature: hook(breaker, old_state, new_state) -> None.
CircuitCallback = Callable[["CircuitBreaker", CircuitState, CircuitState], None]


# ============================================================================
# CircuitBreaker
# ============================================================================

class CircuitBreaker:
    """Thread-safe circuit breaker.

    Parameters:
        name: identifier used in error messages and stats.
        failure_threshold: consecutive failures while CLOSED (any success
            resets the count) before the circuit trips to OPEN.
        recovery_timeout: seconds an OPEN circuit waits before admitting a
            probe call and moving to HALF_OPEN.
        half_open_max_calls: max probe calls admitted while HALF_OPEN
            (the first probe, which triggers OPEN -> HALF_OPEN, is always
            admitted).
        success_threshold: successes needed in HALF_OPEN to close the
            circuit. A value larger than the number of probes that can be
            admitted is capped at that number, since it could never be met.
        on_state_change: optional hook, called as hook(breaker, old, new)
            on every state change while the breaker's lock is held.
            Exceptions it raises are logged and otherwise ignored.
    """

    def __init__(
        self,
        name: str = "default",
        failure_threshold: int = 5,
        recovery_timeout: float = 30.0,
        half_open_max_calls: int = 3,
        success_threshold: int = 2,
        on_state_change: Optional[CircuitCallback] = None,
    ):
        self.name = name
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_max_calls = half_open_max_calls
        self.success_threshold = success_threshold
        self.on_state_change = on_state_change

        # Re-entrant so call() can hold it while _check_state re-acquires it.
        self._lock = threading.RLock()
        self._state: CircuitState = CircuitState.CLOSED
        self._failure_count: int = 0
        self._success_count: int = 0
        self._half_open_calls: int = 0
        self._last_failure_time: float = 0.0
        self._last_success_time: float = 0.0
        self._total_calls: int = 0
        self._total_failures: int = 0
        self._total_successes: int = 0
        # Wall-clock time of the last trip; reported via stats only.
        self._opened_at: float = 0.0
        # Monotonic time of the last trip; drives the recovery timeout so
        # wall-clock adjustments cannot shorten or stretch it.
        self._opened_at_monotonic: float = 0.0

    # ---------- state management ----------

    def _transition(self, new_state: CircuitState) -> None:
        """Switch to new_state, reset per-state counters, and fire the hook.

        Callers must hold self._lock. No-op if already in new_state.
        """
        old = self._state
        if old == new_state:
            return
        self._state = new_state
        if new_state == CircuitState.OPEN:
            self._opened_at = time.time()
            self._opened_at_monotonic = time.monotonic()
        elif new_state == CircuitState.HALF_OPEN:
            self._success_count = 0
            self._half_open_calls = 0
        elif new_state == CircuitState.CLOSED:
            self._failure_count = 0
        if self.on_state_change is not None:
            try:
                self.on_state_change(self, old, new_state)
            except Exception:
                # Best effort: a broken hook must not break the breaker, but
                # the failure should be visible rather than silently dropped.
                import logging

                logging.getLogger(__name__).exception(
                    "on_state_change callback failed for circuit %r (%s -> %s)",
                    self.name,
                    old.name,
                    new_state.name,
                )

    @property
    def state(self) -> CircuitState:
        """Current state, read under the lock."""
        with self._lock:
            return self._state

    # ---------- call execution ----------

    def call(self, fn: Callable[..., T], *args, **kwargs) -> T:
        """Execute fn(*args, **kwargs) through the circuit breaker.

        Returns fn's result. Raises CircuitOpenError, without calling fn, when
        the circuit is OPEN and the recovery timeout has not elapsed, or when
        the HALF_OPEN probe limit has been reached. Any Exception raised by fn
        is recorded as a failure and re-raised unchanged.
        """
        # Admission and the call counter are updated atomically.
        with self._lock:
            self._check_state()
            self._total_calls += 1
        # fn runs outside the lock so slow calls do not serialize callers.
        try:
            result = fn(*args, **kwargs)
        except Exception:
            self._on_failure()
            raise
        # Kept outside the try so a bookkeeping error is never counted as a
        # failure of fn.
        self._on_success()
        return result

    def _check_state(self) -> None:
        """Admit one call or raise CircuitOpenError; may move OPEN -> HALF_OPEN."""
        with self._lock:
            if self._state == CircuitState.CLOSED:
                return
            if self._state == CircuitState.OPEN:
                elapsed = time.monotonic() - self._opened_at_monotonic
                if elapsed >= self.recovery_timeout:
                    self._transition(CircuitState.HALF_OPEN)
                    self._half_open_calls += 1  # count this probe
                    return
                raise CircuitOpenError(
                    f"Circuit '{self.name}' is OPEN "
                    f"(recovery in {self.recovery_timeout - elapsed:.1f}s)"
                )
            if self._state == CircuitState.HALF_OPEN:
                if self._half_open_calls >= self.half_open_max_calls:
                    raise CircuitOpenError(
                        f"Circuit '{self.name}' HALF_OPEN limit reached "
                        f"({self._half_open_calls}/{self.half_open_max_calls})"
                    )
                self._half_open_calls += 1

    def _effective_success_threshold(self) -> int:
        """Return the number of HALF_OPEN successes required to close.

        At most max(half_open_max_calls, 1) probes are admitted per HALF_OPEN
        period, so a larger success_threshold could never be reached and the
        circuit would reject every call forever. Cap it at that count.
        """
        return min(self.success_threshold, max(self.half_open_max_calls, 1))

    def _on_success(self) -> None:
        """Record a successful call; may close a HALF_OPEN circuit."""
        with self._lock:
            self._total_successes += 1
            self._last_success_time = time.time()
            if self._state == CircuitState.HALF_OPEN:
                self._success_count += 1
                if self._success_count >= self._effective_success_threshold():
                    self._transition(CircuitState.CLOSED)
            elif self._state == CircuitState.CLOSED:
                self._failure_count = 0

    def _on_failure(self) -> None:
        """Record a failed call; may trip the circuit to OPEN."""
        with self._lock:
            self._total_failures += 1
            self._last_failure_time = time.time()
            self._failure_count += 1
            if self._state == CircuitState.HALF_OPEN:
                # Any failed probe reopens the circuit immediately.
                self._transition(CircuitState.OPEN)
            elif (
                self._state == CircuitState.CLOSED
                and self._failure_count >= self.failure_threshold
            ):
                self._transition(CircuitState.OPEN)

    # ---------- manual control ----------

    def reset(self) -> None:
        """Force circuit back to CLOSED and clear the per-state counters."""
        with self._lock:
            self._failure_count = 0
            self._success_count = 0
            self._half_open_calls = 0
            self._transition(CircuitState.CLOSED)

    def trip(self) -> None:
        """Force circuit OPEN (no-op on the timer if it is already OPEN)."""
        with self._lock:
            self._failure_count = self.failure_threshold
            self._transition(CircuitState.OPEN)

    # ---------- stats ----------

    @property
    def stats(self) -> Dict[str, Any]:
        """Return a consistent snapshot of counters and timestamps.

        Timestamps are time.time() values, or 0.0 if the event never happened.
        """
        with self._lock:
            return {
                "name": self.name,
                "state": self._state.name,
                "failure_count": self._failure_count,
                "half_open_calls": self._half_open_calls,
                "total_calls": self._total_calls,
                "total_successes": self._total_successes,
                "total_failures": self._total_failures,
                "last_failure": self._last_failure_time,
                "last_success": self._last_success_time,
                "opened_at": self._opened_at,
            }


# ============================================================================
# Errors
# ============================================================================

class CircuitOpenError(Exception):
    """Raised by CircuitBreaker.call when a call is rejected without running.

    This happens while the circuit is OPEN and its recovery timeout has not
    elapsed, or while it is HALF_OPEN and the probe limit has been reached.
    """
202# ============================================================================
203# CircuitRegistry — manage multiple breakers by name
204# ============================================================================
class CircuitRegistry:
    """Thread-safe registry of named circuit breakers."""

    def __init__(self):
        self._breakers: Dict[str, CircuitBreaker] = {}
        # Non-reentrant: it must never be held while calling into a breaker,
        # because breaker state-change hooks may call back into the registry.
        self._lock = threading.Lock()

    def get(self, name: str, **kwargs) -> CircuitBreaker:
        """Return the breaker called name, creating it on first use.

        kwargs are forwarded to CircuitBreaker only when the breaker is
        created. For an already-registered breaker they are ignored.
        """
        with self._lock:
            breaker = self._breakers.get(name)
            if breaker is None:
                breaker = CircuitBreaker(name=name, **kwargs)
                self._breakers[name] = breaker
            return breaker

    def remove(self, name: str) -> bool:
        """Unregister the breaker called name; return True if one existed."""
        with self._lock:
            return self._breakers.pop(name, None) is not None

    def _snapshot(self) -> Dict[str, CircuitBreaker]:
        """Copy the name -> breaker mapping under the registry lock."""
        with self._lock:
            return dict(self._breakers)

    def list_breakers(self) -> Dict[str, str]:
        """Return {name: state name} for every registered breaker."""
        # Each breaker's lock is taken outside the registry lock to avoid a
        # lock-order inversion with hooks that call back into the registry.
        return {n: b.state.name for n, b in self._snapshot().items()}

    def reset_all(self) -> None:
        """Reset every registered breaker to CLOSED."""
        # reset() may fire on_state_change hooks. Running them while holding
        # the non-reentrant registry lock would deadlock any hook that calls
        # get()/list_breakers() on this registry.
        for b in self._snapshot().values():
            b.reset()
_default_registry: Optional[CircuitRegistry] = None
_registry_lock = threading.Lock()


def get_circuit_registry() -> CircuitRegistry:
    """Return the module-level CircuitRegistry, creating it on first use.

    The unlocked read is a fast path. Creation happens under _registry_lock
    with a re-check, so concurrent first callers still share one instance.
    """
    global _default_registry
    existing = _default_registry
    if existing is not None:
        return existing
    with _registry_lock:
        if _default_registry is None:
            _default_registry = CircuitRegistry()
        return _default_registry