# Coverage for sentimatrix/core/config.py: 98% (219 statements)
# Report generated by coverage.py v7.13.2, created at 2026-01-28 09:30 +0000
"""
Sentimatrix Configuration Module

Provides centralized configuration management using Pydantic V2.
Supports YAML/JSON file loading, environment variables, and runtime overrides.

Example:
    >>> config = SentimatrixConfig.from_file("config.yaml")
    >>> config = SentimatrixConfig.from_env()
    >>> config.llm.provider.value
    'openai'
"""
14from __future__ import annotations
16import os
17from enum import Enum
18from pathlib import Path
19from typing import Any, Dict, List, Literal, Optional, Union
21import yaml
22from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
23from pydantic_settings import BaseSettings, SettingsConfigDict
class LogLevel(str, Enum):
    """Log severity names accepted by the logging configuration.

    Every member's value is identical to its name, and the ``str`` mixin
    lets a member compare equal to the corresponding plain string.
    """

    # Listed from least to most severe.
    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"
class CacheBackend(str, Enum):
    """Storage backends that can be selected for the cache.

    Values are the lowercase identifiers used in configuration data; the
    ``str`` mixin lets members compare equal to those plain strings.
    """

    MEMORY = "memory"
    REDIS = "redis"
    SQLITE = "sqlite"
class LLMProvider(str, Enum):
    """Identifiers of the LLM providers that configuration may name.

    Values are the lowercase identifiers used in configuration data; the
    ``str`` mixin lets members compare equal to those plain strings.
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GROQ = "groq"
    GEMINI = "gemini"
    MISTRAL = "mistral"
    COHERE = "cohere"
    TOGETHER = "together"
    FIREWORKS = "fireworks"
    CEREBRAS = "cerebras"
    DEEPSEEK = "deepseek"
    OLLAMA = "ollama"
    VLLM = "vllm"
    HUGGINGFACE = "huggingface"
class ScraperProvider(str, Enum):
    """Identifiers of the scraping providers that configuration may name.

    Values are the lowercase identifiers used in configuration data; the
    ``str`` mixin lets members compare equal to those plain strings.
    """

    PLAYWRIGHT = "playwright"
    SELENIUM = "selenium"
    HTTPX = "httpx"
    REQUESTS = "requests"
    SCRAPERAPI = "scraperapi"
    BRIGHTDATA = "brightdata"
    OXYLABS = "oxylabs"
    APIFY = "apify"
    ZYTE = "zyte"
    FIRECRAWL = "firecrawl"
class RetryConfig(BaseModel):
    """Configuration for retry behavior."""

    # frozen=True: pydantic rejects attribute assignment on built instances.
    model_config = ConfigDict(frozen=True)

    # Retry count, constrained to 0..10.
    max_retries: int = Field(
        default=3,
        ge=0,
        le=10,
        description="Maximum number of retry attempts",
    )
    # First wait between attempts, constrained to 0.1..60 seconds.
    initial_delay: float = Field(
        default=1.0,
        ge=0.1,
        le=60.0,
        description="Initial delay between retries in seconds",
    )
    # Upper bound for any single wait, constrained to 1..300 seconds.
    # NOTE(review): nothing here checks that initial_delay <= max_delay.
    max_delay: float = Field(
        default=60.0,
        ge=1.0,
        le=300.0,
        description="Maximum delay between retries in seconds",
    )
    # Exponential backoff base, constrained to 1..5.
    exponential_base: float = Field(
        default=2.0,
        ge=1.0,
        le=5.0,
        description="Base for exponential backoff",
    )
    jitter: bool = Field(
        default=True,
        description="Add random jitter to delays",
    )
class RateLimitConfig(BaseModel):
    """Configuration for rate limiting."""

    # frozen=True: pydantic rejects attribute assignment on built instances.
    model_config = ConfigDict(frozen=True)

    # Per-second ceiling, constrained to 0.1..100.
    requests_per_second: float = Field(
        default=1.0,
        ge=0.1,
        le=100.0,
        description="Maximum requests per second",
    )
    # Per-minute ceiling, constrained to 1..6000.
    requests_per_minute: int = Field(
        default=60,
        ge=1,
        le=6000,
        description="Maximum requests per minute",
    )
    # Number of in-flight requests allowed, constrained to 1..100.
    concurrent_requests: int = Field(
        default=5,
        ge=1,
        le=100,
        description="Maximum concurrent requests",
    )
    # Multiplier applied when a rate limit is hit, constrained to 1..10.
    backoff_factor: float = Field(
        default=2.0,
        ge=1.0,
        le=10.0,
        description="Backoff multiplier on rate limit hit",
    )
class ProxyConfig(BaseModel):
    """Configuration for proxy settings."""

    # frozen=True: pydantic rejects attribute assignment on built instances.
    model_config = ConfigDict(frozen=True)

    enabled: bool = Field(
        default=False,
        description="Enable proxy usage",
    )
    provider: Optional[str] = Field(
        default=None,
        description="Proxy provider (brightdata, oxylabs, custom)",
    )
    url: Optional[str] = Field(
        default=None,
        description="Proxy URL",
    )
    username: Optional[str] = Field(
        default=None,
        description="Proxy username",
    )
    password: Optional[str] = Field(
        default=None,
        description="Proxy password",
    )
    rotation: bool = Field(
        default=True,
        description="Enable proxy rotation",
    )
    country: Optional[str] = Field(
        default=None,
        description="Target country code",
    )

    @model_validator(mode="after")
    def validate_proxy_config(self) -> "ProxyConfig":
        """Require a proxy target whenever the proxy is switched on.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If ``enabled`` is true while both ``url`` and
                ``provider`` are empty.
        """
        # A disabled proxy needs no further settings.
        if not self.enabled:
            return self
        # Either a URL or a provider name is enough to identify the proxy.
        if self.url or self.provider:
            return self
        raise ValueError("Either 'url' or 'provider' must be specified when proxy is enabled")
class LLMConfig(BaseModel):
    """Configuration for LLM provider."""

    # frozen=True: pydantic rejects attribute assignment on built instances.
    model_config = ConfigDict(frozen=True)

    provider: LLMProvider = Field(default=LLMProvider.OPENAI, description="LLM provider to use")
    model: str = Field(default="gpt-4o-mini", description="Model name/identifier")
    # May be given literally or as "env:VAR_NAME"; see resolve_api_key_from_env.
    api_key: Optional[str] = Field(default=None, description="API key (can use env var)")
    api_base: Optional[str] = Field(default=None, description="Custom API base URL")
    organization: Optional[str] = Field(default=None, description="Organization ID")
    timeout: int = Field(default=30, ge=5, le=300, description="Request timeout in seconds")
    max_tokens: int = Field(default=1024, ge=1, le=128000, description="Maximum tokens to generate")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: float = Field(default=1.0, ge=0.0, le=1.0, description="Top-p sampling parameter")
    retry: RetryConfig = Field(default_factory=RetryConfig, description="Retry configuration")
    rate_limit: RateLimitConfig = Field(
        default_factory=RateLimitConfig, description="Rate limit configuration"
    )

    @field_validator("api_key", mode="before")
    @classmethod
    def resolve_api_key_from_env(cls, v: Any) -> Any:
        """Resolve API key from environment variable if prefixed with 'env:'.

        Args:
            v: Raw ``api_key`` input, before pydantic's own type validation.

        Returns:
            The value of the named environment variable (``None`` when it is
            not set) for ``"env:NAME"`` strings; otherwise ``v`` unchanged.
        """
        # This runs in "before" mode, so ``v`` is unvalidated and may not be a
        # string (e.g. an unquoted number in YAML). Only strings can carry the
        # "env:" prefix; anything else is handed back so pydantic's regular
        # validation reports it instead of an AttributeError escaping here.
        if isinstance(v, str) and v.startswith("env:"):
            env_var = v[4:]
            return os.environ.get(env_var)
        return v
class ScraperConfig(BaseModel):
    """Configuration for web scraping."""

    # frozen=True: pydantic rejects attribute assignment on built instances.
    model_config = ConfigDict(frozen=True)

    provider: ScraperProvider = Field(
        default=ScraperProvider.PLAYWRIGHT,
        description="Scraper provider to use",
    )
    headless: bool = Field(
        default=True,
        description="Run browser in headless mode",
    )
    # Constrained to 5..120 seconds.
    timeout: int = Field(
        default=30,
        ge=5,
        le=120,
        description="Page load timeout in seconds",
    )
    wait_for_selector: Optional[str] = Field(
        default=None,
        description="CSS selector to wait for before scraping",
    )
    user_agent: Optional[str] = Field(
        default=None,
        description="Custom user agent string",
    )
    # Viewport bounds: width 320..3840, height 240..2160.
    viewport_width: int = Field(
        default=1920,
        ge=320,
        le=3840,
        description="Browser viewport width",
    )
    viewport_height: int = Field(
        default=1080,
        ge=240,
        le=2160,
        description="Browser viewport height",
    )
    # Nested sections; each defaults to that section's own defaults.
    proxy: ProxyConfig = Field(
        default_factory=ProxyConfig,
        description="Proxy configuration",
    )
    rate_limit: RateLimitConfig = Field(
        default_factory=RateLimitConfig,
        description="Rate limit configuration",
    )
    retry: RetryConfig = Field(
        default_factory=RetryConfig,
        description="Retry configuration",
    )
    screenshots: bool = Field(
        default=False,
        description="Capture screenshots during scraping",
    )
    screenshot_dir: Optional[str] = Field(
        default=None,
        description="Directory for screenshots",
    )
class ModelConfig(BaseModel):
    """Configuration for ML models."""

    # frozen=True: pydantic rejects attribute assignment on built instances.
    model_config = ConfigDict(frozen=True)

    sentiment_model: str = Field(
        default="cardiffnlp/twitter-roberta-base-sentiment-latest",
        description="HuggingFace model for sentiment analysis",
    )
    emotion_model: str = Field(
        default="SamLowe/roberta-base-go_emotions",
        description="HuggingFace model for emotion detection",
    )
    # Only these four literal values pass validation.
    device: Literal["auto", "cpu", "cuda", "mps"] = Field(
        default="auto",
        description="Device for model inference",
    )
    # Constrained to 1..512.
    batch_size: int = Field(
        default=32,
        ge=1,
        le=512,
        description="Batch size for inference",
    )
    # Constrained to 32..4096.
    max_length: int = Field(
        default=512,
        ge=32,
        le=4096,
        description="Maximum sequence length for tokenization",
    )
    use_quantization: bool = Field(
        default=False,
        description="Enable model quantization",
    )
    cache_models: bool = Field(
        default=True,
        description="Cache loaded models in memory",
    )
class CacheConfig(BaseModel):
    """Configuration for caching."""

    # frozen=True: pydantic rejects attribute assignment on built instances.
    model_config = ConfigDict(frozen=True)

    enabled: bool = Field(
        default=True,
        description="Enable caching",
    )
    backend: CacheBackend = Field(
        default=CacheBackend.MEMORY,
        description="Cache backend",
    )
    # Constrained to 0..86400 seconds.
    ttl: int = Field(
        default=3600,
        ge=0,
        le=86400,
        description="Default TTL in seconds (0=no expiry)",
    )
    # Constrained to 10..100000 entries.
    max_size: int = Field(
        default=1000,
        ge=10,
        le=100000,
        description="Maximum cache entries",
    )
    namespace: str = Field(
        default="sentimatrix",
        description="Cache key namespace",
    )
    # Backend-specific settings; validate_backend_config enforces the one
    # that matches the selected backend.
    redis_url: Optional[str] = Field(
        default=None,
        description="Redis connection URL",
    )
    sqlite_path: Optional[str] = Field(
        default=None,
        description="SQLite database path",
    )
    compression: bool = Field(
        default=False,
        description="Enable cache value compression",
    )

    @model_validator(mode="after")
    def validate_backend_config(self) -> "CacheConfig":
        """Check that the selected backend has its connection setting.

        The check runs regardless of the ``enabled`` flag.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If the backend is Redis without ``redis_url``, or
                SQLite without ``sqlite_path``.
        """
        selected = self.backend
        if selected == CacheBackend.REDIS:
            if not self.redis_url:
                raise ValueError("redis_url is required when using Redis backend")
        elif selected == CacheBackend.SQLITE:
            if not self.sqlite_path:
                raise ValueError("sqlite_path is required when using SQLite backend")
        return self
class LogConfig(BaseModel):
    """Configuration for logging."""

    # frozen=True: pydantic rejects attribute assignment on built instances.
    model_config = ConfigDict(frozen=True)

    level: LogLevel = Field(
        default=LogLevel.INFO,
        description="Log level",
    )
    # Only "json" and "text" pass validation.
    format: Literal["json", "text"] = Field(
        default="json",
        description="Log format",
    )
    file_path: Optional[str] = Field(
        default=None,
        description="Log file path",
    )
    # Constrained to 1..100 MB.
    max_file_size_mb: int = Field(
        default=10,
        ge=1,
        le=100,
        description="Maximum log file size in MB",
    )
    # Constrained to 0..20 files.
    backup_count: int = Field(
        default=5,
        ge=0,
        le=20,
        description="Number of backup log files",
    )
    include_timestamp: bool = Field(
        default=True,
        description="Include timestamp in logs",
    )
    include_caller: bool = Field(
        default=True,
        description="Include caller info in logs",
    )
    console_output: bool = Field(
        default=True,
        description="Output logs to console",
    )
    colorize: bool = Field(
        default=True,
        description="Colorize console output",
    )
class OutputConfig(BaseModel):
    """Configuration for output handling."""

    # frozen=True: pydantic rejects attribute assignment on built instances.
    model_config = ConfigDict(frozen=True)

    # Only "json", "csv" and "xlsx" pass validation.
    default_format: Literal["json", "csv", "xlsx"] = Field(
        default="json",
        description="Default export format",
    )
    include_metadata: bool = Field(
        default=True,
        description="Include metadata in exports",
    )
    include_raw_data: bool = Field(
        default=False,
        description="Include raw data in exports",
    )
    pretty_print: bool = Field(
        default=True,
        description="Pretty print JSON output",
    )
    # The default looks like a strftime pattern; the field itself is an
    # unconstrained string.
    datetime_format: str = Field(
        default="%Y-%m-%dT%H:%M:%SZ",
        description="Datetime format string",
    )
class FallbackConfig(BaseModel):
    """Configuration for provider fallback chain."""

    # frozen=True: pydantic rejects attribute assignment on built instances.
    model_config = ConfigDict(frozen=True)

    enabled: bool = Field(
        default=True,
        description="Enable fallback chain",
    )
    # A factory is used so each instance gets its own list object.
    providers: List[LLMProvider] = Field(
        default_factory=lambda: [
            LLMProvider.OPENAI,
            LLMProvider.ANTHROPIC,
            LLMProvider.GROQ,
        ],
        description="Ordered list of fallback providers",
    )
    # Constrained to 1..10.
    max_attempts: int = Field(
        default=3,
        ge=1,
        le=10,
        description="Maximum fallback attempts",
    )
class SentimatrixConfig(BaseSettings):
    """
    Main configuration class for Sentimatrix.

    Supports loading from:
    - YAML/JSON files
    - Environment variables (prefixed with SENTIMATRIX_)
    - Runtime overrides

    Example:
        >>> config = SentimatrixConfig.from_file("config.yaml")
        >>> config = SentimatrixConfig()  # Uses defaults + env vars
    """

    model_config = SettingsConfigDict(
        env_prefix="SENTIMATRIX_",
        env_nested_delimiter="__",
        case_sensitive=False,
        extra="ignore",
    )

    # Sub-configurations
    llm: LLMConfig = Field(default_factory=LLMConfig, description="LLM provider configuration")
    scrapers: ScraperConfig = Field(
        default_factory=ScraperConfig, description="Scraper configuration"
    )
    models: ModelConfig = Field(default_factory=ModelConfig, description="ML model configuration")
    cache: CacheConfig = Field(default_factory=CacheConfig, description="Cache configuration")
    logging: LogConfig = Field(default_factory=LogConfig, description="Logging configuration")
    output: OutputConfig = Field(default_factory=OutputConfig, description="Output configuration")
    fallback: FallbackConfig = Field(
        default_factory=FallbackConfig, description="Fallback configuration"
    )

    # Global settings
    debug: bool = Field(default=False, description="Enable debug mode")
    dry_run: bool = Field(default=False, description="Enable dry run mode (no API calls)")

    @classmethod
    def from_file(cls, path: Union[str, Path], **overrides: Any) -> "SentimatrixConfig":
        """
        Load configuration from a YAML or JSON file.

        The format is chosen from the file extension (case-insensitive).
        An empty document is treated as an empty configuration. Overrides
        replace top-level keys from the file wholesale (shallow merge).

        Args:
            path: Path to configuration file
            **overrides: Additional configuration overrides

        Returns:
            SentimatrixConfig instance

        Raises:
            ConfigurationError: If the file is missing, has an unsupported
                extension, cannot be read or parsed, or does not contain a
                mapping at the top level. Field validation errors raised
                while building the instance are not wrapped.
        """
        # Imported at call time, as everywhere else in this module.
        from sentimatrix.core.exceptions import ConfigurationError

        path = Path(path)

        if not path.exists():
            raise ConfigurationError(f"Configuration file not found: {path}")

        # Reject unknown formats before entering the try block: raising the
        # error inside it would be caught by the generic handler below and
        # re-wrapped with a misleading "Failed to load" prefix.
        suffix = path.suffix.lower()
        if suffix not in (".yaml", ".yml", ".json"):
            raise ConfigurationError(
                f"Unsupported config file format: {path.suffix}. "
                "Use .yaml, .yml, or .json"
            )

        try:
            with open(path, "r", encoding="utf-8") as f:
                if suffix == ".json":
                    import json

                    data = json.load(f)
                else:
                    data = yaml.safe_load(f)
        except yaml.YAMLError as e:
            raise ConfigurationError(f"Failed to parse YAML configuration: {e}") from e
        except Exception as e:
            # Boundary handler: any read/decode/parse failure surfaces as the
            # package's own error type, with the cause chained.
            raise ConfigurationError(f"Failed to load configuration file: {e}") from e

        # An empty document parses to a falsy value (None for YAML and for
        # JSON "null"); treat it as "no settings".
        data = data or {}
        if not isinstance(data, dict):
            # A top-level list or scalar cannot be expanded into keyword
            # arguments; report it instead of failing on data.update().
            raise ConfigurationError(
                "Failed to load configuration file: expected a mapping at the "
                f"top level, got {type(data).__name__}"
            )

        # Merge overrides
        data.update(overrides)

        return cls(**data)

    @classmethod
    def from_env(cls, **overrides: Any) -> "SentimatrixConfig":
        """
        Load configuration from environment variables.

        Environment variables should be prefixed with SENTIMATRIX_.
        Nested values use double underscore as delimiter.

        Example:
            SENTIMATRIX_LLM__PROVIDER=openai
            SENTIMATRIX_LLM__MODEL=gpt-4
            SENTIMATRIX_DEBUG=true

        Args:
            **overrides: Additional configuration overrides

        Returns:
            SentimatrixConfig instance
        """
        return cls(**overrides)

    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to dictionary.

        Enum-typed fields keep their Enum members; use ``to_yaml`` or
        ``save`` for a serialized form.
        """
        return self.model_dump()

    def _to_plain_dict(self) -> Dict[str, Any]:
        """Return the configuration as JSON-compatible builtins only."""
        # mode="json" converts Enum members to their values so serializers
        # never see project-specific types.
        return self.model_dump(mode="json")

    def to_yaml(self) -> str:
        """Convert configuration to YAML string.

        The output contains only plain YAML types, so it can be read back
        with ``yaml.safe_load`` (and therefore with ``from_file``).
        """
        # yaml.dump on the raw model_dump() would emit
        # "!!python/object/apply:" tags for the str-Enum fields, which
        # safe_load refuses to parse.
        return yaml.safe_dump(self._to_plain_dict(), default_flow_style=False, sort_keys=False)

    def save(self, path: Union[str, Path]) -> None:
        """
        Save configuration to a file.

        Args:
            path: Output file path (format determined by extension,
                case-insensitive)

        Raises:
            ConfigurationError: If the extension is not .yaml, .yml or
                .json. Nothing is created or modified on disk in that case.
        """
        path = Path(path)

        # Validate the format first: opening the file in "w" mode before the
        # check would create or truncate it even when the call then fails.
        suffix = path.suffix.lower()
        if suffix not in (".yaml", ".yml", ".json"):
            from sentimatrix.core.exceptions import ConfigurationError

            raise ConfigurationError(
                f"Unsupported output format: {path.suffix}. Use .yaml, .yml, or .json"
            )

        # Serialize before touching the filesystem so a serialization error
        # cannot leave a truncated file behind.
        if suffix == ".json":
            import json

            text = json.dumps(self._to_plain_dict(), indent=2)
        else:
            text = self.to_yaml()

        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)

    def with_overrides(self, **overrides: Any) -> "SentimatrixConfig":
        """
        Create a new configuration with the specified overrides.

        Nested dictionaries in ``overrides`` are merged key by key into the
        current values; any other value replaces the existing one.

        Args:
            **overrides: Configuration overrides

        Returns:
            New instance of the same class with overrides applied
        """
        data = self.to_dict()
        self._deep_merge(data, overrides)
        # type(self) keeps subclass instances as that subclass instead of
        # silently downgrading them to the base class.
        return type(self)(**data)

    @staticmethod
    def _deep_merge(base: Dict[str, Any], updates: Dict[str, Any]) -> None:
        """Deep merge updates into base dictionary.

        ``base`` is modified in place. Where both sides hold a dict under the
        same key the merge recurses; otherwise the update value wins.
        """
        for key, value in updates.items():
            if key in base and isinstance(base[key], dict) and isinstance(value, dict):
                SentimatrixConfig._deep_merge(base[key], value)
            else:
                base[key] = value
# Convenience function for quick access
def get_config(
    config_path: Optional[Union[str, Path]] = None, **overrides: Any
) -> SentimatrixConfig:
    """
    Build a configuration instance from a file or from the environment.

    A truthy ``config_path`` selects ``SentimatrixConfig.from_file``; a
    missing or empty one selects ``SentimatrixConfig.from_env``.

    Args:
        config_path: Optional path to configuration file
        **overrides: Configuration overrides, forwarded unchanged

    Returns:
        SentimatrixConfig instance
    """
    # No usable path: fall back to environment-based loading.
    if not config_path:
        return SentimatrixConfig.from_env(**overrides)
    return SentimatrixConfig.from_file(config_path, **overrides)