Coverage for sentimatrix / core / config.py: 98%

219 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-01-28 09:30 +0000

"""
Sentimatrix Configuration Module

Provides centralized configuration management using Pydantic V2.
Supports YAML/JSON file loading, environment variables, and runtime overrides.

Example:
    >>> config = SentimatrixConfig.from_file("config.yaml")
    >>> config = SentimatrixConfig.from_env()
    >>> config.llm.provider.value
    'openai'
"""

13 

14from __future__ import annotations 

15 

16import os 

17from enum import Enum 

18from pathlib import Path 

19from typing import Any, Dict, List, Literal, Optional, Union 

20 

21import yaml 

22from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator 

23from pydantic_settings import BaseSettings, SettingsConfigDict 

24 

25 

class LogLevel(str, Enum):
    """Severity names accepted for the logging level setting.

    Members are also ``str`` instances, so each one compares equal to its
    plain-string value.
    """

    # Listed from least to most severe.
    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"

34 

35 

class CacheBackend(str, Enum):
    """Identifiers of the cache storage backends that can be selected.

    Members are also ``str`` instances, so each one compares equal to its
    plain-string value.
    """

    MEMORY = "memory"
    REDIS = "redis"
    SQLITE = "sqlite"

42 

43 

class LLMProvider(str, Enum):
    """Identifiers of the LLM providers that can be selected.

    Members are also ``str`` instances, so each one compares equal to its
    plain-string value.
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GROQ = "groq"
    GEMINI = "gemini"
    MISTRAL = "mistral"
    COHERE = "cohere"
    TOGETHER = "together"
    FIREWORKS = "fireworks"
    CEREBRAS = "cerebras"
    DEEPSEEK = "deepseek"
    OLLAMA = "ollama"
    VLLM = "vllm"
    HUGGINGFACE = "huggingface"

60 

61 

class ScraperProvider(str, Enum):
    """Identifiers of the scraper providers that can be selected.

    Members are also ``str`` instances, so each one compares equal to its
    plain-string value.
    """

    PLAYWRIGHT = "playwright"
    SELENIUM = "selenium"
    HTTPX = "httpx"
    REQUESTS = "requests"
    SCRAPERAPI = "scraperapi"
    BRIGHTDATA = "brightdata"
    OXYLABS = "oxylabs"
    APIFY = "apify"
    ZYTE = "zyte"
    FIRECRAWL = "firecrawl"

75 

76 

class RetryConfig(BaseModel):
    """Immutable settings describing how failed operations are retried.

    Instances are frozen; every numeric field is range-checked by its
    ``Field`` constraints.
    """

    model_config = ConfigDict(frozen=True)

    max_retries: int = Field(
        default=3,
        ge=0,
        le=10,
        description="Maximum number of retry attempts",
    )
    initial_delay: float = Field(
        default=1.0,
        ge=0.1,
        le=60.0,
        description="Initial delay between retries in seconds",
    )
    max_delay: float = Field(
        default=60.0,
        ge=1.0,
        le=300.0,
        description="Maximum delay between retries in seconds",
    )
    exponential_base: float = Field(
        default=2.0,
        ge=1.0,
        le=5.0,
        description="Base for exponential backoff",
    )
    jitter: bool = Field(
        default=True,
        description="Add random jitter to delays",
    )

93 

94 

class RateLimitConfig(BaseModel):
    """Immutable settings describing request throttling limits.

    Instances are frozen; every field is range-checked by its ``Field``
    constraints.
    """

    model_config = ConfigDict(frozen=True)

    requests_per_second: float = Field(
        default=1.0,
        ge=0.1,
        le=100.0,
        description="Maximum requests per second",
    )
    requests_per_minute: int = Field(
        default=60,
        ge=1,
        le=6000,
        description="Maximum requests per minute",
    )
    concurrent_requests: int = Field(
        default=5,
        ge=1,
        le=100,
        description="Maximum concurrent requests",
    )
    backoff_factor: float = Field(
        default=2.0,
        ge=1.0,
        le=10.0,
        description="Backoff multiplier on rate limit hit",
    )

112 

113 

class ProxyConfig(BaseModel):
    """Immutable proxy settings.

    When ``enabled`` is true, at least one of ``url`` or ``provider`` must be
    given; this is enforced after field validation.
    """

    model_config = ConfigDict(frozen=True)

    enabled: bool = Field(
        default=False,
        description="Enable proxy usage",
    )
    provider: Optional[str] = Field(
        default=None,
        description="Proxy provider (brightdata, oxylabs, custom)",
    )
    url: Optional[str] = Field(
        default=None,
        description="Proxy URL",
    )
    username: Optional[str] = Field(
        default=None,
        description="Proxy username",
    )
    password: Optional[str] = Field(
        default=None,
        description="Proxy password",
    )
    rotation: bool = Field(
        default=True,
        description="Enable proxy rotation",
    )
    country: Optional[str] = Field(
        default=None,
        description="Target country code",
    )

    @model_validator(mode="after")
    def validate_proxy_config(self) -> "ProxyConfig":
        """Reject an enabled proxy that names neither a URL nor a provider.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If ``enabled`` is true while both ``url`` and
                ``provider`` are empty or unset.
        """
        # A disabled proxy needs no further checks.
        if not self.enabled:
            return self
        # Either source of proxy information is sufficient.
        if self.url or self.provider:
            return self
        raise ValueError("Either 'url' or 'provider' must be specified when proxy is enabled")

135 

136 

class LLMConfig(BaseModel):
    """Immutable settings for the LLM provider.

    ``api_key`` may be given literally or as ``"env:VAR_NAME"``, in which
    case the value is read from the environment variable ``VAR_NAME`` when
    the model is validated.
    """

    model_config = ConfigDict(frozen=True)

    provider: LLMProvider = Field(default=LLMProvider.OPENAI, description="LLM provider to use")
    model: str = Field(default="gpt-4o-mini", description="Model name/identifier")
    api_key: Optional[str] = Field(default=None, description="API key (can use env var)")
    api_base: Optional[str] = Field(default=None, description="Custom API base URL")
    organization: Optional[str] = Field(default=None, description="Organization ID")
    timeout: int = Field(default=30, ge=5, le=300, description="Request timeout in seconds")
    max_tokens: int = Field(default=1024, ge=1, le=128000, description="Maximum tokens to generate")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: float = Field(default=1.0, ge=0.0, le=1.0, description="Top-p sampling parameter")
    retry: RetryConfig = Field(default_factory=RetryConfig, description="Retry configuration")
    rate_limit: RateLimitConfig = Field(
        default_factory=RateLimitConfig, description="Rate limit configuration"
    )

    @field_validator("api_key", mode="before")
    @classmethod
    def resolve_api_key_from_env(cls, v: Any) -> Any:
        """Resolve an ``env:``-prefixed API key from the environment.

        Args:
            v: Raw, not-yet-validated input for ``api_key``.

        Returns:
            The value of the named environment variable (``None`` when it is
            not set) if ``v`` is a string starting with ``"env:"``; otherwise
            ``v`` unchanged.
        """
        # This runs before type validation, so ``v`` can be any object. Only
        # strings are inspected; anything else is passed through so the
        # regular field validation reports it instead of an AttributeError
        # escaping from ``startswith``.
        if isinstance(v, str) and v.startswith("env:"):
            return os.environ.get(v[len("env:"):])
        return v

164 

165 

class ScraperConfig(BaseModel):
    """Immutable settings for web scraping.

    Groups the scraper provider, browser options, and the nested proxy,
    rate-limit and retry settings.
    """

    model_config = ConfigDict(frozen=True)

    provider: ScraperProvider = Field(
        default=ScraperProvider.PLAYWRIGHT,
        description="Scraper provider to use",
    )
    headless: bool = Field(
        default=True,
        description="Run browser in headless mode",
    )
    timeout: int = Field(
        default=30,
        ge=5,
        le=120,
        description="Page load timeout in seconds",
    )
    wait_for_selector: Optional[str] = Field(
        default=None,
        description="CSS selector to wait for before scraping",
    )
    user_agent: Optional[str] = Field(
        default=None,
        description="Custom user agent string",
    )
    viewport_width: int = Field(
        default=1920,
        ge=320,
        le=3840,
        description="Browser viewport width",
    )
    viewport_height: int = Field(
        default=1080,
        ge=240,
        le=2160,
        description="Browser viewport height",
    )

    # Nested settings; each defaults to a freshly built instance.
    proxy: ProxyConfig = Field(
        default_factory=ProxyConfig,
        description="Proxy configuration",
    )
    rate_limit: RateLimitConfig = Field(
        default_factory=RateLimitConfig,
        description="Rate limit configuration",
    )
    retry: RetryConfig = Field(
        default_factory=RetryConfig,
        description="Retry configuration",
    )

    screenshots: bool = Field(
        default=False,
        description="Capture screenshots during scraping",
    )
    screenshot_dir: Optional[str] = Field(
        default=None,
        description="Directory for screenshots",
    )

193 

194 

class ModelConfig(BaseModel):
    """Immutable settings for the ML models used for inference."""

    model_config = ConfigDict(frozen=True)

    sentiment_model: str = Field(
        default="cardiffnlp/twitter-roberta-base-sentiment-latest",
        description="HuggingFace model for sentiment analysis",
    )
    emotion_model: str = Field(
        default="SamLowe/roberta-base-go_emotions",
        description="HuggingFace model for emotion detection",
    )
    device: Literal["auto", "cpu", "cuda", "mps"] = Field(
        default="auto",
        description="Device for model inference",
    )
    batch_size: int = Field(
        default=32,
        ge=1,
        le=512,
        description="Batch size for inference",
    )
    max_length: int = Field(
        default=512,
        ge=32,
        le=4096,
        description="Maximum sequence length for tokenization",
    )
    use_quantization: bool = Field(
        default=False,
        description="Enable model quantization",
    )
    cache_models: bool = Field(
        default=True,
        description="Cache loaded models in memory",
    )

217 

218 

class CacheConfig(BaseModel):
    """Immutable cache settings.

    The Redis backend requires ``redis_url`` and the SQLite backend requires
    ``sqlite_path``; both rules are enforced after field validation.
    """

    model_config = ConfigDict(frozen=True)

    enabled: bool = Field(
        default=True,
        description="Enable caching",
    )
    backend: CacheBackend = Field(
        default=CacheBackend.MEMORY,
        description="Cache backend",
    )
    ttl: int = Field(
        default=3600,
        ge=0,
        le=86400,
        description="Default TTL in seconds (0=no expiry)",
    )
    max_size: int = Field(
        default=1000,
        ge=10,
        le=100000,
        description="Maximum cache entries",
    )
    namespace: str = Field(
        default="sentimatrix",
        description="Cache key namespace",
    )
    redis_url: Optional[str] = Field(
        default=None,
        description="Redis connection URL",
    )
    sqlite_path: Optional[str] = Field(
        default=None,
        description="SQLite database path",
    )
    compression: bool = Field(
        default=False,
        description="Enable cache value compression",
    )

    @model_validator(mode="after")
    def validate_backend_config(self) -> "CacheConfig":
        """Check that the selected backend has its required setting.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If the Redis backend is selected without
                ``redis_url``, or the SQLite backend without ``sqlite_path``.
        """
        selected = self.backend
        if selected == CacheBackend.REDIS:
            if not self.redis_url:
                raise ValueError("redis_url is required when using Redis backend")
        elif selected == CacheBackend.SQLITE:
            if not self.sqlite_path:
                raise ValueError("sqlite_path is required when using SQLite backend")
        return self

241 

242 

class LogConfig(BaseModel):
    """Immutable logging settings."""

    model_config = ConfigDict(frozen=True)

    level: LogLevel = Field(
        default=LogLevel.INFO,
        description="Log level",
    )
    format: Literal["json", "text"] = Field(
        default="json",
        description="Log format",
    )

    # File output options.
    file_path: Optional[str] = Field(
        default=None,
        description="Log file path",
    )
    max_file_size_mb: int = Field(
        default=10,
        ge=1,
        le=100,
        description="Maximum log file size in MB",
    )
    backup_count: int = Field(
        default=5,
        ge=0,
        le=20,
        description="Number of backup log files",
    )

    # Record content and console options.
    include_timestamp: bool = Field(
        default=True,
        description="Include timestamp in logs",
    )
    include_caller: bool = Field(
        default=True,
        description="Include caller info in logs",
    )
    console_output: bool = Field(
        default=True,
        description="Output logs to console",
    )
    colorize: bool = Field(
        default=True,
        description="Colorize console output",
    )

259 

260 

class OutputConfig(BaseModel):
    """Immutable settings for exported output."""

    model_config = ConfigDict(frozen=True)

    default_format: Literal["json", "csv", "xlsx"] = Field(
        default="json",
        description="Default export format",
    )
    include_metadata: bool = Field(
        default=True,
        description="Include metadata in exports",
    )
    include_raw_data: bool = Field(
        default=False,
        description="Include raw data in exports",
    )
    pretty_print: bool = Field(
        default=True,
        description="Pretty print JSON output",
    )
    datetime_format: str = Field(
        default="%Y-%m-%dT%H:%M:%SZ",
        description="Datetime format string",
    )

275 

276 

class FallbackConfig(BaseModel):
    """Immutable settings for the provider fallback chain."""

    model_config = ConfigDict(frozen=True)

    enabled: bool = Field(
        default=True,
        description="Enable fallback chain",
    )
    # The factory builds a new list for every instance, so the default is
    # never shared between configurations.
    providers: List[LLMProvider] = Field(
        default_factory=lambda: [LLMProvider.OPENAI, LLMProvider.ANTHROPIC, LLMProvider.GROQ],
        description="Ordered list of fallback providers",
    )
    max_attempts: int = Field(
        default=3,
        ge=1,
        le=10,
        description="Maximum fallback attempts",
    )

288 

289 

class SentimatrixConfig(BaseSettings):
    """
    Main configuration class for Sentimatrix.

    Supports loading from:
    - YAML/JSON files
    - Environment variables (prefixed with SENTIMATRIX_)
    - Runtime overrides

    Example:
        >>> config = SentimatrixConfig.from_file("config.yaml")
        >>> config = SentimatrixConfig()  # Uses defaults + env vars
    """

    model_config = SettingsConfigDict(
        env_prefix="SENTIMATRIX_",
        env_nested_delimiter="__",
        case_sensitive=False,
        extra="ignore",
    )

    # Sub-configurations
    llm: LLMConfig = Field(default_factory=LLMConfig, description="LLM provider configuration")
    scrapers: ScraperConfig = Field(
        default_factory=ScraperConfig, description="Scraper configuration"
    )
    models: ModelConfig = Field(default_factory=ModelConfig, description="ML model configuration")
    cache: CacheConfig = Field(default_factory=CacheConfig, description="Cache configuration")
    logging: LogConfig = Field(default_factory=LogConfig, description="Logging configuration")
    output: OutputConfig = Field(default_factory=OutputConfig, description="Output configuration")
    fallback: FallbackConfig = Field(
        default_factory=FallbackConfig, description="Fallback configuration"
    )

    # Global settings
    debug: bool = Field(default=False, description="Enable debug mode")
    dry_run: bool = Field(default=False, description="Enable dry run mode (no API calls)")

    @classmethod
    def from_file(cls, path: Union[str, Path], **overrides: Any) -> "SentimatrixConfig":
        """
        Load configuration from a YAML or JSON file.

        The format is chosen from the file extension (case-insensitive).
        An empty document is treated as an empty configuration.

        Args:
            path: Path to configuration file
            **overrides: Top-level values that replace the corresponding
                keys loaded from the file (shallow replacement)

        Returns:
            SentimatrixConfig instance

        Raises:
            ConfigurationError: If the file does not exist, has an
                unsupported extension, cannot be read or parsed, or its
                top-level value is not a mapping
        """
        # Imported here rather than at module level, as the original code did.
        from sentimatrix.core.exceptions import ConfigurationError

        path = Path(path)

        if not path.exists():
            raise ConfigurationError(f"Configuration file not found: {path}")

        # Check the extension before entering the try block; otherwise the
        # broad ``except Exception`` below would catch and re-wrap this error.
        suffix = path.suffix.lower()
        if suffix not in (".yaml", ".yml", ".json"):
            raise ConfigurationError(
                f"Unsupported config file format: {path.suffix}. "
                "Use .yaml, .yml, or .json"
            )

        try:
            with open(path, "r", encoding="utf-8") as f:
                if suffix == ".json":
                    import json

                    data = json.load(f)
                else:
                    data = yaml.safe_load(f)
        except yaml.YAMLError as e:
            raise ConfigurationError(f"Failed to parse YAML configuration: {e}") from e
        except Exception as e:
            raise ConfigurationError(f"Failed to load configuration file: {e}") from e

        # An empty or falsy document (e.g. empty YAML file, JSON ``null``)
        # means "no settings".
        if not data:
            data = {}
        # Anything else must be a mapping of setting names to values.
        if not isinstance(data, dict):
            raise ConfigurationError(
                f"Failed to load configuration file: top-level value in {path} "
                f"must be a mapping, got {type(data).__name__}"
            )

        # Merge overrides
        data.update(overrides)

        return cls(**data)

    @classmethod
    def from_env(cls, **overrides: Any) -> "SentimatrixConfig":
        """
        Load configuration from environment variables.

        Environment variables should be prefixed with SENTIMATRIX_.
        Nested values use double underscore as delimiter.

        Example:
            SENTIMATRIX_LLM__PROVIDER=openai
            SENTIMATRIX_LLM__MODEL=gpt-4
            SENTIMATRIX_DEBUG=true

        Args:
            **overrides: Additional configuration overrides

        Returns:
            SentimatrixConfig instance
        """
        return cls(**overrides)

    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to dictionary (the ``model_dump()`` output)."""
        return self.model_dump()

    def to_yaml(self) -> str:
        """Convert configuration to YAML string.

        Values are dumped in JSON mode so that enum members are written as
        plain strings. Enum objects passed to ``yaml.dump`` are emitted as
        Python-specific tags, which ``yaml.safe_load`` cannot read back.
        """
        return yaml.dump(
            self.model_dump(mode="json"), default_flow_style=False, sort_keys=False
        )

    def save(self, path: Union[str, Path]) -> None:
        """
        Save configuration to a file.

        The written file contains every field value, including
        ``llm.api_key`` in clear text.

        Args:
            path: Output file path (format determined by extension,
                case-insensitive)

        Raises:
            ConfigurationError: If the extension is not .yaml, .yml, or .json
        """
        path = Path(path)
        suffix = path.suffix.lower()

        # Reject unsupported formats before touching the file system, so an
        # existing file is never truncated by a call that is going to fail.
        if suffix not in (".yaml", ".yml", ".json"):
            from sentimatrix.core.exceptions import ConfigurationError

            raise ConfigurationError(
                f"Unsupported output format: {path.suffix}. Use .yaml, .yml, or .json"
            )

        # JSON mode turns enum members into plain strings, so the output can
        # be loaded again by ``from_file``.
        data = self.model_dump(mode="json")

        path.parent.mkdir(parents=True, exist_ok=True)

        with open(path, "w", encoding="utf-8") as f:
            if suffix == ".json":
                import json

                json.dump(data, f, indent=2)
            else:
                yaml.dump(data, f, default_flow_style=False, sort_keys=False)

    def with_overrides(self, **overrides: Any) -> "SentimatrixConfig":
        """
        Create a new configuration with the specified overrides.

        Nested dictionaries in ``overrides`` are merged key by key into the
        current values; ``self`` is left untouched.

        Args:
            **overrides: Configuration overrides

        Returns:
            New instance of the same class as ``self`` with overrides applied
        """
        data = self.to_dict()
        self._deep_merge(data, overrides)
        # Use the runtime class so subclasses keep their own type.
        return type(self)(**data)

    @staticmethod
    def _deep_merge(base: Dict[str, Any], updates: Dict[str, Any]) -> None:
        """Deep merge updates into base dictionary (modifies ``base`` in place)."""
        for key, value in updates.items():
            if key in base and isinstance(base[key], dict) and isinstance(value, dict):
                SentimatrixConfig._deep_merge(base[key], value)
            else:
                base[key] = value

454 

455 

# Convenience function for quick access
def get_config(
    config_path: Optional[Union[str, Path]] = None, **overrides: Any
) -> SentimatrixConfig:
    """
    Build a configuration instance from a file or from the environment.

    A truthy ``config_path`` is loaded with ``SentimatrixConfig.from_file``;
    otherwise ``SentimatrixConfig.from_env`` is used.

    Args:
        config_path: Optional path to configuration file
        **overrides: Configuration overrides forwarded to the chosen loader

    Returns:
        SentimatrixConfig instance
    """
    # No path (None or an empty string): fall back to environment variables.
    if not config_path:
        return SentimatrixConfig.from_env(**overrides)
    return SentimatrixConfig.from_file(config_path, **overrides)