Coverage for src / documint_mcp / utils / validators.py: 0%

111 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 22:30 -0400

1""" 

2Input validation utilities to prevent security vulnerabilities. 

3 

4Bug Fix #3: Input Validation and Sanitization 

5- Validates file paths to prevent directory traversal attacks 

6- Sanitizes user input to prevent injection attacks 

7- Validates file types and content 

8""" 

9 

10import re 

11from pathlib import Path 

12 

13from pydantic import BaseModel, Field, field_validator 

14 

15 

16class FilePathValidator: 

17 """ 

18 Secure file path validation to prevent directory traversal attacks. 

19 

20 Bug Fix #3: Prevents directory traversal and path injection vulnerabilities. 

21 """ 

22 

23 ALLOWED_EXTENSIONS = {".txt", ".md", ".rst", ".pdf", ".docx", ".html", ".json"} 

24 BLOCKED_PATTERNS = [ 

25 r"\.\.[\\/]", # Directory traversal 

26 r"^[\\/]", # Absolute paths 

27 r"[<>:\"|?*]", # Invalid filename characters 

28 r"^(CON|PRN|AUX|NUL|COM[1-9]|LPT[1-9])$", # Windows reserved names 

29 ] 

30 

31 @classmethod 

32 def validate_path(cls, file_path: str | Path) -> Path: 

33 """ 

34 Validate file path for security issues. 

35 

36 Args: 

37 file_path: Path to validate 

38 

39 Returns: 

40 Validated Path object 

41 

42 Raises: 

43 ValueError: If path is invalid or insecure 

44 """ 

45 if not file_path: 

46 raise ValueError("File path cannot be empty") 

47 

48 path_str = str(file_path) 

49 

50 # Check for blocked patterns 

51 for pattern in cls.BLOCKED_PATTERNS: 

52 if re.search(pattern, path_str, re.IGNORECASE): 

53 raise ValueError("Invalid file path: contains blocked pattern") 

54 

55 # Convert to Path and resolve 

56 path = Path(path_str) 

57 

58 # Check file extension 

59 if path.suffix.lower() not in cls.ALLOWED_EXTENSIONS: 

60 raise ValueError(f"File extension {path.suffix} not allowed") 

61 

62 # Ensure path is within allowed directory 

63 try: 

64 resolved_path = path.resolve() 

65 # This prevents symlink attacks 

66 if not resolved_path.is_file() and not resolved_path.parent.exists(): 

67 resolved_path.parent.mkdir(parents=True, exist_ok=True) 

68 except (OSError, ValueError) as e: 

69 raise ValueError(f"Invalid file path: {e}") from e 

70 

71 return resolved_path 

72 

73 @classmethod 

74 def validate_filename(cls, filename: str) -> str: 

75 """ 

76 Validate filename for security issues. 

77 

78 Args: 

79 filename: Filename to validate 

80 

81 Returns: 

82 Validated filename 

83 

84 Raises: 

85 ValueError: If filename is invalid 

86 """ 

87 if not filename or not filename.strip(): 

88 raise ValueError("Filename cannot be empty") 

89 

90 filename = filename.strip() 

91 

92 # Check length 

93 if len(filename) > 255: 

94 raise ValueError("Filename too long") 

95 

96 # Check for invalid characters 

97 for pattern in cls.BLOCKED_PATTERNS[2:]: # Skip path-specific patterns 

98 if re.search(pattern, filename, re.IGNORECASE): 

99 raise ValueError("Invalid filename: contains blocked characters") 

100 

101 return filename 

102 

103 

class ContentValidator:
    """
    Content validation to prevent injection attacks and ensure data integrity.

    Bug Fix #4: Prevents content-based security vulnerabilities.

    HTML handling is a best-effort regex blocklist. It removes only
    ``<script>`` elements and ``on*`` event-handler attributes. It does not
    neutralise ``javascript:`` URLs, ``<iframe>``/``<object>``/``<svg>``
    payloads, or entity-encoded tricks. HTML that will be rendered to users
    still needs an allowlist sanitizer or output escaping.
    """

    MAX_CONTENT_LENGTH = 10 * 1024 * 1024  # 10MB, measured in UTF-8 bytes
    ALLOWED_MIME_TYPES = {
        "text/plain",
        "text/markdown",
        "text/html",
        "application/json",
        "application/pdf",
    }

    # Removing one fragment can splice its neighbours into a new tag
    # ("<scr<script></script>ipt>" becomes "<script>"), so sanitizing repeats
    # until nothing changes. HTML without such splices settles in two passes
    # (one that removes, one that confirms). Needing more than this bound is
    # treated as hostile input.
    _MAX_SANITIZE_PASSES = 8
    _SCRIPT_START = re.compile(r"<script\b", re.IGNORECASE)
    # A script end tag: "</script" followed by ">", or by whitespace or "/"
    # and optional junk up to the next ">".
    _SCRIPT_END = re.compile(r"</script(?:>|[\s/][^<>]*>?)", re.IGNORECASE)
    # A start tag. Quoted attribute values are consumed whole, so a ">" inside
    # one does not end the tag early. An unquoted "<" ends the candidate, so
    # each scan stops at the next tag opener.
    _START_TAG = re.compile(r"""<[a-zA-Z](?:"[^"]*"|'[^']*'|[^"'<>])*>""")
    # Any on* event-handler attribute, with a quoted or unquoted value and
    # optional whitespace around "=". The look-behind requires a separator,
    # so words such as "button" or "content" are not touched.
    _EVENT_HANDLER_ATTR = re.compile(
        r"""(?<=[\s"'/])on[a-z]+\s*=\s*(?:"[^"]*"|'[^']*'|[^\s"'>]+)""",
        re.IGNORECASE,
    )

    @classmethod
    def validate_content(cls, content: str, content_type: str = "text/plain") -> str:
        """
        Validate content for security issues.

        Args:
            content: Content to validate.
            content_type: MIME type of content. Must be in ALLOWED_MIME_TYPES.

        Returns:
            The content, sanitized when content_type is "text/html".

        Raises:
            ValueError: If content is not a string, is larger than
                MAX_CONTENT_LENGTH UTF-8 bytes, has a disallowed type, or is
                HTML that cannot be sanitized to a stable result.
        """
        if not isinstance(content, str):
            raise ValueError("Content must be a string")

        # len(str) counts code points, but the limit is in bytes. Non-ASCII
        # text takes more than one byte per code point in UTF-8.
        size = len(content.encode("utf-8", "surrogatepass"))
        if size > cls.MAX_CONTENT_LENGTH:
            raise ValueError(f"Content too large: {size} bytes")

        if content_type not in cls.ALLOWED_MIME_TYPES:
            raise ValueError(f"Content type {content_type} not allowed")

        # Basic XSS prevention for HTML content
        if content_type == "text/html":
            content = cls._sanitize_html(content)

        return content

    @classmethod
    def _sanitize_html(cls, html_content: str) -> str:
        """
        Basic HTML sanitization to prevent XSS.

        Each pass removes every ``<script>`` element and then every ``on*``
        attribute inside start tags. An unterminated script is removed through
        to the end of the input. Passes repeat until the result is stable.

        Args:
            html_content: HTML content to sanitize.

        Returns:
            Sanitized HTML content.

        Raises:
            ValueError: If the content is still changing after
                _MAX_SANITIZE_PASSES passes.
        """
        for _ in range(cls._MAX_SANITIZE_PASSES):
            cleaned = cls._START_TAG.sub(
                cls._strip_event_handlers, cls._remove_scripts(html_content)
            )
            if cleaned == html_content:
                return cleaned
            html_content = cleaned
        raise ValueError("HTML content is too deeply nested to sanitize")

    @classmethod
    def _remove_scripts(cls, html_content: str) -> str:
        """Drop every <script> element; an unterminated one runs to end of input."""
        kept: list[str] = []
        pos = 0
        while (start := cls._SCRIPT_START.search(html_content, pos)) is not None:
            kept.append(html_content[pos : start.start()])
            end = cls._SCRIPT_END.search(html_content, start.end())
            if end is None:
                # Browsers treat everything after an unclosed <script> as
                # script text, so none of it is kept.
                return "".join(kept)
            pos = end.end()
        kept.append(html_content[pos:])
        return "".join(kept)

    @classmethod
    def _strip_event_handlers(cls, tag: re.Match[str]) -> str:
        """re.sub callback: remove on* attributes from one matched start tag."""
        return cls._EVENT_HANDLER_ATTR.sub("", tag.group(0))

180 

181 

class SearchQueryValidator:
    """
    Search query validation to prevent injection attacks.

    Bug Fix #5: Prevents search injection vulnerabilities.
    """

    MAX_QUERY_LENGTH = 1000
    BLOCKED_CHARS = ["<", ">", '"', "'", "&", ";", "|", "`", "$"]
    # Keywords are matched as whole words only. The earlier substring test
    # rejected ordinary queries such as "selected", "executive", "created" or
    # "dropdown". EXECUTE is listed explicitly because a whole-word EXEC no
    # longer matches it.
    SQL_KEYWORDS = (
        "DROP",
        "DELETE",
        "UPDATE",
        "INSERT",
        "CREATE",
        "ALTER",
        "EXEC",
        "EXECUTE",
        "UNION",
        "SELECT",
    )

    @classmethod
    def validate_query(cls, query: str) -> str:
        """
        Validate search query for security issues.

        Args:
            query: Search query to validate.

        Returns:
            The query with surrounding whitespace stripped.

        Raises:
            ValueError: If the query is empty, longer than MAX_QUERY_LENGTH,
                contains a BLOCKED_CHARS character, or contains an SQL keyword
                as a whole word.
        """
        if not query or not query.strip():
            raise ValueError("Query cannot be empty")

        query = query.strip()

        if len(query) > cls.MAX_QUERY_LENGTH:
            raise ValueError(f"Query too long: {len(query)} characters")

        # Check for blocked characters
        for char in cls.BLOCKED_CHARS:
            if char in query:
                raise ValueError(f"Query contains blocked character: {char}")

        # Basic SQL injection prevention. Keywords are checked in list order,
        # so the reported keyword is the first listed one that is present.
        query_upper = query.upper()
        for keyword in cls.SQL_KEYWORDS:
            if re.search(rf"\b{keyword}\b", query_upper):
                raise ValueError(f"Query contains blocked keyword: {keyword}")

        return query

237 

238 

class DocumentMetadata(BaseModel):
    """
    Validated metadata describing a single document.

    Performance Optimization: Uses Pydantic for fast validation and serialization.

    Length and range limits are declared through ``Field``. The
    ``field_validator`` hooks below normalise the title and tags, and reuse
    this module's path and MIME-type checks for ``file_path`` and
    ``content_type``.
    """

    title: str = Field(..., max_length=200)
    description: str | None = Field(None, max_length=1000)
    tags: list[str] = Field(default_factory=list, max_length=10)
    author: str | None = Field(None, max_length=100)
    file_path: str = Field(..., min_length=1)
    content_type: str = Field(default="text/plain")
    file_size: int = Field(..., ge=0)

    @field_validator("title")
    @classmethod
    def validate_title(cls, v: str) -> str:
        """Return the title without surrounding whitespace; reject a blank one."""
        cleaned = v.strip() if v else ""
        if cleaned:
            return cleaned
        raise ValueError("Title cannot be empty")

    @field_validator("tags")
    @classmethod
    def validate_tags(cls, v: list[str]) -> list[str]:
        """Strip and lower-case tags, dropping blanks, tags over 50 chars and repeats."""
        normalised = (tag.strip().lower() for tag in v if tag and tag.strip())
        # dict.fromkeys keeps the first occurrence of each tag, in order.
        return list(dict.fromkeys(t for t in normalised if len(t) <= 50))

    @field_validator("file_path")
    @classmethod
    def validate_file_path(cls, v: str) -> str:
        """Apply the shared path checks but keep the caller's original string.

        NOTE(review): FilePathValidator.validate_path may create missing parent
        directories, so constructing this model can touch the filesystem.
        """
        FilePathValidator.validate_path(v)
        return v

    @field_validator("content_type")
    @classmethod
    def validate_content_type(cls, v: str) -> str:
        """Accept only MIME types listed in ContentValidator.ALLOWED_MIME_TYPES."""
        if v in ContentValidator.ALLOWED_MIME_TYPES:
            return v
        raise ValueError(f"Content type {v} not allowed")