Coverage for agentos/guardrails/engine.py: 51%

99 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Guardrail engine — rule registry, evaluation, and result aggregation. 

3""" 

4 

5from dataclasses import dataclass, field 

6from enum import Enum, auto 

7from typing import Any, Callable, Dict, List, Optional, Sequence 

8 

9 

class GuardrailAction(str, Enum):
    """Disposition attached to a guardrail rule and to an aggregate result.

    Members mix in ``str``, so each one compares equal to its lowercase
    string value (for example ``GuardrailAction.BLOCK == "block"``).
    """

    BLOCK = "block"
    FLAG = "flag"
    SANITIZE = "sanitize"
    PASS = "pass"

17 

18 

class GuardrailCategory(str, Enum):
    """Semantic category a guardrail rule belongs to.

    Members mix in ``str``, so each one compares equal to its lowercase
    string value (for example ``GuardrailCategory.PII == "pii"``).
    """

    PII = "pii"
    TOXICITY = "toxicity"
    INJECTION = "injection"
    KEYWORD = "keyword"
    LENGTH = "length"
    CUSTOM = "custom"

28 

29 

@dataclass
class GuardrailResult:
    """Aggregate outcome produced once a set of guardrail rules has run.

    Attributes:
        passed: Overall verdict for the evaluated text.
        action: The action reported for the evaluation.
        violations: One human-readable entry per rule that matched.
        sanitized_text: Rewritten text, or ``None`` when none was produced.
        metadata: Free-form extra data attached to the result.
    """

    passed: bool
    action: GuardrailAction
    violations: list[str] = field(default_factory=list)
    sanitized_text: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def blocked(self) -> bool:
        """Whether the reported action is ``GuardrailAction.BLOCK``."""
        return self.action == GuardrailAction.BLOCK

43 

44 

@dataclass
class GuardrailRule:
    """Definition of a single guardrail rule.

    Attributes:
        name: Identifier of the rule; guardrail registries key rules by it.
        category: Semantic category the rule belongs to.
        action: Disposition to report when the rule matches.
        check: Predicate called with the text; a truthy return is a match.
        sanitize: Optional rewriter applied to the text when the rule matches.
        description: Optional human-readable text used in violation messages;
            when empty, the category value is used instead.
        enabled: Disabled rules are skipped during evaluation.
        metadata: Free-form extra data attached to the rule.
    """

    name: str
    category: GuardrailCategory
    action: GuardrailAction
    check: Callable[[str], bool]
    sanitize: Callable[[str], str] | None = None
    description: str = ""
    enabled: bool = True
    # Built-in generic, matching the ``dict[str, Any]`` style used by the
    # other annotations in this module.
    metadata: dict[str, Any] = field(default_factory=dict)

57 

58 

class _RulePipeline:
    """Rule registry and evaluation loop shared by the guardrail pipelines.

    Rules are stored in a dict keyed by ``rule.name``.
    """

    def __init__(self, rules: Optional[list[GuardrailRule]] = None):
        """Create the pipeline, registering ``rules`` in order when given."""
        self._rules: dict[str, GuardrailRule] = {}
        if rules:
            for r in rules:
                self.add_rule(r)

    def add_rule(self, rule: GuardrailRule) -> None:
        """Register ``rule`` under its name, replacing any rule of that name."""
        self._rules[rule.name] = rule

    def remove_rule(self, name: str) -> None:
        """Remove the rule called ``name``; unknown names are ignored."""
        self._rules.pop(name, None)

    def evaluate(self, text: str) -> GuardrailResult:
        """Run all enabled rules against the text and aggregate the outcome.

        Each rule checks the text as rewritten by the sanitizers of the
        matching rules visited before it. Every match adds a violation entry,
        and the reported action is the highest-priority action among the
        matching rules (``PASS`` when nothing matched).

        Args:
            text: The text to evaluate.

        Returns:
            A ``GuardrailResult`` whose ``passed`` is ``False`` only when the
            reported action is ``BLOCK``. ``sanitized_text`` is ``None`` when
            the sanitizers left the text unchanged.
        """
        violations: list[str] = []
        worst_action = GuardrailAction.PASS
        # Track the rank alongside the action so it is not recomputed for
        # every matching rule.
        worst_priority = _action_priority(worst_action)
        sanitized = text
        for rule in self._rules.values():
            if not rule.enabled or not rule.check(sanitized):
                continue
            violations.append(f"{rule.name}: {rule.description or rule.category.value}")
            if rule.sanitize:
                sanitized = rule.sanitize(sanitized)
            priority = _action_priority(rule.action)
            if priority > worst_priority:
                worst_action, worst_priority = rule.action, priority

        return GuardrailResult(
            passed=worst_action != GuardrailAction.BLOCK,
            action=worst_action,
            violations=violations,
            sanitized_text=sanitized if sanitized != text else None,
        )


class InputGuardrail(_RulePipeline):
    """Validates user prompts before they reach the LLM."""


class OutputGuardrail(_RulePipeline):
    """Validates LLM outputs before they reach the user."""

135 

136 

class GuardrailEngine:
    """Facade bundling one input guardrail and one output guardrail.

    Attributes:
        input: ``InputGuardrail`` built from ``input_rules``.
        output: ``OutputGuardrail`` built from ``output_rules``.
    """

    def __init__(
        self,
        input_rules: Optional[list[GuardrailRule]] = None,
        output_rules: Optional[list[GuardrailRule]] = None,
    ):
        """Build both pipelines from the optional rule lists."""
        self.input = InputGuardrail(input_rules)
        self.output = OutputGuardrail(output_rules)

    def check_input(self, prompt: str) -> GuardrailResult:
        """Evaluate ``prompt`` with the input guardrail."""
        return self.input.evaluate(prompt)

    def check_output(self, response: str) -> GuardrailResult:
        """Evaluate ``response`` with the output guardrail."""
        return self.output.evaluate(response)

    def check(self, prompt: str, response: str = "") -> tuple[GuardrailResult, GuardrailResult]:
        """Evaluate the prompt and, when present, the response.

        A falsy ``response`` (such as the default empty string) skips the
        output guardrail; a passing ``PASS`` result is returned in its place.

        Returns:
            ``(input_result, output_result)``.
        """
        input_result = self.input.evaluate(prompt)
        if response:
            output_result = self.output.evaluate(response)
        else:
            output_result = GuardrailResult(passed=True, action=GuardrailAction.PASS)
        return input_result, output_result

161 

162 

# Rank of each action, used to pick the highest-priority action among the
# rules that matched. Built once at import time rather than on every call.
_ACTION_PRIORITY: dict[GuardrailAction, int] = {
    GuardrailAction.PASS: 0,
    GuardrailAction.FLAG: 1,
    GuardrailAction.SANITIZE: 2,
    GuardrailAction.BLOCK: 3,
}


def _action_priority(action: GuardrailAction) -> int:
    """Return the rank of ``action``; a higher number wins over a lower one.

    Raises:
        KeyError: If ``action`` is not one of the ranked actions.
    """
    return _ACTION_PRIORITY[action]