Coverage for agentos/guardrails/policy.py: 40%
67 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2Guardrail policy enforcement — cumulative violation tracking, rate limiting,
3and session-scoped policy decisions.
4"""
6from dataclasses import dataclass, field
7from enum import Enum, auto
8from typing import Any, Callable, Dict, List, Optional
10from agentos.guardrails.engine import GuardrailAction, GuardrailCategory, GuardrailResult
class PolicyViolation(str, Enum):
    """Policy-level violation reasons.

    Members also subclass ``str``, so each one compares equal to its plain
    string value (for example ``PolicyViolation.RATE_LIMITED == "rate_limited"``).
    """

    SESSION_BLOCKED = "session_blocked"
    RATE_LIMITED = "rate_limited"
    CUMULATIVE_VIOLATIONS = "cumulative_violations"
    CATEGORY_BANNED = "category_banned"
@dataclass
class GuardrailPolicy:
    """Session-scoped policy configuration.

    When ``max_violations_per_category`` is empty after construction,
    ``__post_init__`` fills in default limits for the injection, toxicity
    and keyword categories.
    """

    # Cumulative violation count at which the enforcer blocks the session.
    max_total_violations: int = 5
    # Per-category limits, keyed by the category's string value.
    max_violations_per_category: dict[str, int] = field(default_factory=dict)
    # Age (in seconds) past which entries are pruned from the violation log.
    window_seconds: int = 300
    # NOTE(review): not read anywhere in this module's visible code; presumably
    # categories that block the session immediately -- confirm against callers.
    auto_block_on: set[GuardrailCategory] = field(default_factory=set)
    # NOTE(review): not read in this module's visible code -- confirm usage.
    on_session_block: str = "reject"  # reject / warn
    # Optional hook invoked as ``callback(event_name, data_dict)``.
    monitoring_callback: Callable[[str, Dict[str, Any]], None] | None = None

    def __post_init__(self):
        """Install the default per-category limits when none were supplied.

        The check is a truthiness test, so an explicitly passed empty dict is
        replaced by the defaults as well.
        """
        if not self.max_violations_per_category:
            self.max_violations_per_category = {
                GuardrailCategory.INJECTION.value: 2,
                GuardrailCategory.TOXICITY.value: 3,
                GuardrailCategory.KEYWORD.value: 3,
            }
class PolicyEnforcer:
    """Tracks violations per session and enforces cumulative policy."""

    def __init__(self, policy: Optional[GuardrailPolicy] = None):
        """Create an enforcer with all counters zeroed.

        Args:
            policy: Policy to enforce; a default ``GuardrailPolicy`` is
                created when this is omitted (or otherwise falsy).
        """
        self.policy = policy or GuardrailPolicy()
        self._violation_count: int = 0
        self._category_counts: Dict[str, int] = {}
        self._session_blocked: bool = False
        # Each entry is (timestamp, category, action value).
        self._violation_log: list[tuple[float, str, str]] = []

    def evaluate(self, result: GuardrailResult, category: str = "") -> PolicyViolation | None:
        """Evaluate a guardrail result against the current policy.

        Args:
            result: Guardrail outcome; only ``result.action`` is inspected.
            category: Category key used for per-category counting. An empty
                string skips per-category counting and limits.

        Returns:
            ``PolicyViolation.SESSION_BLOCKED`` if the session was already
            blocked; ``None`` if the action is ``PASS``;
            ``PolicyViolation.CUMULATIVE_VIOLATIONS`` when the total count
            reaches ``max_total_violations`` (this also blocks the session);
            ``PolicyViolation.CATEGORY_BANNED`` when the category count
            reaches its configured limit; otherwise ``None``.
        """
        if self._session_blocked:
            return PolicyViolation.SESSION_BLOCKED

        if result.action == GuardrailAction.PASS:
            return None

        import time

        now = time.time()

        self._violation_count += 1
        if category:
            self._category_counts[category] = self._category_counts.get(category, 0) + 1
        self._violation_log.append((now, category, result.action.value))

        # Clean old entries outside window.
        # NOTE(review): only the log is windowed; the thresholds below use the
        # un-windowed counters, so ``window_seconds`` does not affect decisions.
        cutoff = now - self.policy.window_seconds
        self._violation_log = [(t, c, a) for t, c, a in self._violation_log if t > cutoff]

        # Check cumulative violations
        if self._violation_count >= self.policy.max_total_violations:
            self._session_blocked = True
            self._emit("session_blocked", {"total_violations": self._violation_count})
            return PolicyViolation.CUMULATIVE_VIOLATIONS

        # Check per-category limits
        if category and category in self.policy.max_violations_per_category:
            limit = self.policy.max_violations_per_category[category]
            if self._category_counts[category] >= limit:
                self._emit("category_blocked", {"category": category, "count": self._category_counts[category]})
                return PolicyViolation.CATEGORY_BANNED

        return None

    def reset(self) -> None:
        """Reset all violation counters for a new session."""
        self._violation_count = 0
        self._category_counts.clear()
        self._session_blocked = False
        self._violation_log.clear()

    @property
    def is_blocked(self) -> bool:
        """Whether the session has been blocked."""
        return self._session_blocked

    @property
    def total_violations(self) -> int:
        """Number of violations counted since construction or the last reset."""
        return self._violation_count

    def _emit(self, event: str, data: Dict[str, Any]) -> None:
        """Forward an event to the policy's monitoring callback, if one is set.

        Monitoring is best-effort: a failing callback is logged and never
        propagates into policy evaluation.
        """
        callback = self.policy.monitoring_callback
        if not callback:
            return
        try:
            callback(event, data)
        except Exception:
            # Record the failure instead of discarding it silently.
            import logging

            logging.getLogger(__name__).exception(
                "Guardrail monitoring callback failed for event %r", event
            )