Coverage for src/kemi/sanitize.py: 100%

62 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

1"""Prompt injection detection and sanitization with audit logging. 

2 

3Security layer for protecting AI agents from prompt injection attacks. 

4Provides detection, sanitization, and comprehensive audit logging. 

5""" 

6 

7import hashlib 

8import logging 

9import re 

10from re import Pattern 

11from typing import Any 

12 

# Module logger; every audit event from this module goes through it.
logger = logging.getLogger(__name__)

# Detectable injection patterns, as (compiled regex, category name) pairs.
# Order matters: sanitize() applies them sequentially, so an earlier pattern
# can consume text that a later, overlapping one would otherwise match.
_SUSPICIOUS_PATTERNS: list[tuple[Pattern[str], str]] = [
    (re.compile(source, flags), label)
    for source, flags, label in (
        # Instruction override attempts
        (r"(?i)\bignore\s+(all\s+)?previous\s+instructions\b", 0, "instruction_override"),
        (r"(?i)\byou\s+are\s+now\b", 0, "role_override"),
        (r"(?i)\bignore\s+all\b", 0, "ignore_all"),
        # Role playing / jailbreak attempts (line-anchored ones need MULTILINE)
        (r"(?i)^\s*system\s*:", re.MULTILINE, "system_prefix"),
        (r"(?i)^\s*assistant\s*:", re.MULTILINE, "assistant_prefix"),
        (r"(?i)\[INST\]", 0, "inst_token"),
        (r"(?i)^\s*###\s*instruction", re.MULTILINE, "markdown_instruction"),
    )
]

# Line-leading chat role prefixes, neutralized only in strict mode.
# Each regex also swallows whitespace after the colon.
_ROLE_PREFIXES: list[tuple[Pattern[str], str]] = [
    (re.compile(source, re.MULTILINE), label)
    for source, label in (
        (r"(?i)^\s*user\s*:\s*", "user_role"),
        (r"(?i)^\s*assistant\s*:\s*", "assistant_role"),
        (r"(?i)^\s*system\s*:\s*", "system_role"),
        (r"(?i)^\s*bot\s*:\s*", "bot_role"),
    )
]

34 

35 

def _log_detection(
    content_length: int,
    content_hash: str,
    pattern_name: str,
    action: str,
    details: dict[str, Any] | None = None,
) -> None:
    """Emit one audit log record describing a prompt injection detection.

    The record holds metadata only; the inspected content is never logged.

    Args:
        content_length: Length of the inspected content.
        content_hash: Hex SHA-256 digest of the content; only its first 16
            characters are recorded.
        pattern_name: Name of the matching pattern, or several names joined
            with ", ".
        action: What was done, e.g. "detected", "sanitized" or
            "sanitized_with_rejection". "rejected" is logged at ERROR level;
            every other action at WARNING.
        details: Optional extra fields merged into the record. Keys here
            override the standard fields of the same name.
    """
    record: dict[str, Any] = {
        "event": "prompt_injection_detection",
        "pattern": pattern_name,
        "action": action,
        "content_length": content_length,
        "content_hash": content_hash[:16],  # truncated digest prefix only
    }
    record.update(details or {})

    if action != "rejected":
        logger.warning("Prompt injection pattern detected: %s", record)
        return
    logger.error("Prompt injection attempt detected and rejected: %s", record)

67 

68 

def _get_content_hash(content: str) -> str:
    """Return the hex SHA-256 digest of ``content`` for audit logging.

    Lets log records identify a piece of content without exposing it.

    The text is encoded as UTF-8 with the ``surrogatepass`` error handler.
    Strings containing lone surrogates (for example produced by ``json.loads``
    on an unpaired surrogate escape) are then hashed instead of raising
    ``UnicodeEncodeError`` in the middle of detection or sanitization. For
    every valid string the bytes, and therefore the digest, are identical to
    a plain ``content.encode()``.

    Args:
        content: Text to hash.

    Returns:
        64-character lowercase hexadecimal SHA-256 digest.
    """
    return hashlib.sha256(content.encode("utf-8", "surrogatepass")).hexdigest()

72 

73 

def is_suspicious(content: str) -> bool:
    """Check whether ``content`` contains a potential prompt injection pattern.

    The content is not modified. Patterns in ``_SUSPICIOUS_PATTERNS`` are
    tried in order. On the first match, a single ``"detected"`` audit event
    is logged (length and hash prefix only, never the content) and ``True``
    is returned.

    Args:
        content: Untrusted text to inspect.

    Returns:
        True if any suspicious pattern matches, otherwise False.
    """
    # Deliberately no length short-circuit: some patterns match very short
    # inputs ("[INST]" is 6 characters, "system:" is 7), and those must still
    # be reported.
    for pattern, pattern_name in _SUSPICIOUS_PATTERNS:
        if pattern.search(content):
            _log_detection(
                len(content),
                _get_content_hash(content),
                pattern_name,
                "detected",
            )
            return True

    return False

94 

95 

def sanitize(content: str, strict: bool = False) -> str:
    """Remove or neutralize potential prompt injection patterns.

    Every match of a pattern in ``_SUSPICIOUS_PATTERNS`` is replaced with
    ``[SANITIZED]``. With ``strict=True``, a role prefix at the start of a
    line (``user:``, ``assistant:``, ``system:``, ``bot:``) is additionally
    replaced with ``[ROLE]``; the rest of that line is kept.

    Content with fewer than 8 whitespace-separated words that matches no
    suspicious pattern is returned unchanged, even in strict mode, so short
    legitimate statements are not altered.

    When anything is replaced, exactly one ``"sanitized"`` audit event is
    logged. It holds metadata only (length, hash prefix, pattern names),
    never the content.

    Args:
        content: Untrusted text to sanitize.
        strict: Also neutralize line-leading role prefixes.

    Returns:
        The sanitized text; unchanged text when nothing matched.
    """
    # Non-logging check for the short-content exemption, so that one call
    # produces at most one audit event. It tests every pattern regardless of
    # input length, so short payloads such as "[INST]" are still caught.
    if len(content.split()) < 8 and not any(
        pattern.search(content) for pattern, _ in _SUSPICIOUS_PATTERNS
    ):
        return content

    result = content
    detected_patterns: list[str] = []

    # Patterns are applied sequentially to the progressively sanitized text,
    # so an earlier replacement can consume text a later pattern would match
    # (e.g. "ignore all previous instructions" is fully replaced by
    # instruction_override before ignore_all runs).
    for pattern, pattern_name in _SUSPICIOUS_PATTERNS:
        result, count = pattern.subn("[SANITIZED]", result)
        if count:
            detected_patterns.append(pattern_name)

    if strict:
        for pattern, pattern_name in _ROLE_PREFIXES:
            result, count = pattern.subn("[ROLE]", result)
            if count:
                detected_patterns.append(pattern_name)

    if detected_patterns:
        _log_detection(
            len(content),
            _get_content_hash(content),
            ", ".join(detected_patterns),
            "sanitized",
            {"strict_mode": strict, "patterns_found": len(detected_patterns)},
        )

    return result

136 

137 

def sanitize_with_rejection(content: str, strict: bool = False) -> tuple[str, bool]:
    """Sanitize content and report whether it matched a suspicious pattern.

    ``was_suspicious`` reflects only ``_SUSPICIOUS_PATTERNS`` matched against
    the original content. Role prefixes neutralized in strict mode do not
    set it.

    Content with fewer than 8 words and no suspicious match is returned
    unchanged with ``False``, and nothing is logged. Otherwise:

    * if a suspicious pattern matched, one ``"sanitized_with_rejection"``
      audit event naming the first matching pattern is logged;
    * the text is passed to :func:`sanitize`, which performs the
      replacements and may log its own audit event(s) for them.

    Args:
        content: The content to sanitize.
        strict: Whether to use strict mode (also neutralize role prefixes).

    Returns:
        Tuple of (sanitized_content, was_suspicious).
    """
    # A single scan of the original content: the first matching pattern's
    # name, or None when nothing matches.
    first_match = next(
        (name for pattern, name in _SUSPICIOUS_PATTERNS if pattern.search(content)),
        None,
    )

    # Fast path: short, non-suspicious content is returned untouched and unlogged.
    if first_match is None and len(content.split()) < 8:
        return content, False

    if first_match is not None:
        # Hash only when it is actually needed for the audit record.
        _log_detection(
            len(content),
            _get_content_hash(content),
            first_match,
            "sanitized_with_rejection",
            {"strict_mode": strict},
        )

    return sanitize(content, strict), first_match is not None