Coverage for sentimatrix / providers / base.py: 88%
252 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-01-28 09:30 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-01-28 09:30 +0000
1"""
2Sentimatrix Base Provider Interfaces
4Defines abstract base classes and protocols for all provider types:
5- BaseLLMProvider: For LLM API providers
6- BaseScraperProvider: For web scraping providers
7- BaseModelProvider: For ML model providers
9Also includes provider registry for dynamic provider discovery.
10"""
12from __future__ import annotations
14from abc import ABC, abstractmethod
15from dataclasses import dataclass, field
16from datetime import datetime
17from enum import Enum
18from typing import (
19 Any,
20 AsyncIterator,
21 Callable,
22 Dict,
23 Generic,
24 List,
25 Optional,
26 Type,
27 TypeVar,
28 Union,
29)
31from sentimatrix.core.config import LLMConfig, ModelConfig, ScraperConfig
32from sentimatrix.core.exceptions import (
33 ProviderInitializationError,
34 ProviderNotFoundError,
35)
class ProviderType(str, Enum):
    """Closed set of provider families known to the framework.

    Members are also ``str`` instances, so each one compares equal to its
    plain-string value (for example ``ProviderType.LLM == "llm"``).
    """

    # LLM API providers.
    LLM = "llm"
    # Web scraping providers.
    SCRAPER = "scraper"
    # ML model providers.
    MODEL = "model"
@dataclass
class ProviderCapabilities:
    """Feature flags and limits advertised by a provider.

    Every boolean flag defaults to ``False`` and both token limits default
    to 4096, so a provider only has to set what it actually supports.
    """

    # --- LLM features ---
    streaming: bool = False
    function_calling: bool = False
    vision: bool = False
    json_mode: bool = False
    embeddings: bool = False

    # --- Context window limits (in tokens) ---
    max_context_tokens: int = 4096
    max_output_tokens: int = 4096

    # --- Scraper features ---
    javascript_rendering: bool = False
    screenshots: bool = False
    pdf_generation: bool = False
    proxy_support: bool = False

    # --- Model features ---
    batch_processing: bool = False
    gpu_support: bool = False
    quantization: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return every capability as a plain ``{field_name: value}`` dict.

        Keys appear in the same order as the fields are declared above.
        """
        exported = (
            "streaming",
            "function_calling",
            "vision",
            "json_mode",
            "embeddings",
            "max_context_tokens",
            "max_output_tokens",
            "javascript_rendering",
            "screenshots",
            "pdf_generation",
            "proxy_support",
            "batch_processing",
            "gpu_support",
            "quantization",
        )
        return {key: getattr(self, key) for key in exported}
@dataclass
class ProviderInfo:
    """Descriptive metadata for a single provider.

    Only ``name`` and ``provider_type`` are required; everything else has a
    default. ``capabilities`` and ``supported_models`` use default factories
    so each instance gets its own object.
    """

    name: str
    provider_type: ProviderType
    version: str = "1.0.0"
    description: str = ""
    capabilities: ProviderCapabilities = field(default_factory=ProviderCapabilities)
    supported_models: List[str] = field(default_factory=list)
    website: Optional[str] = None
    documentation: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the metadata as a dict of plain values.

        ``provider_type`` is emitted as the enum's string value and
        ``capabilities`` as the nested dict from its own ``to_dict()``.
        """
        type_value = self.provider_type.value
        capability_map = self.capabilities.to_dict()
        return {
            "name": self.name,
            "provider_type": type_value,
            "version": self.version,
            "description": self.description,
            "capabilities": capability_map,
            "supported_models": self.supported_models,
            "website": self.website,
            "documentation": self.documentation,
        }
@dataclass
class TokenUsage:
    """Token counters for one or more LLM calls.

    Instances can be combined with ``+`` and aggregated with the built-in
    ``sum()``.
    """

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

    def __add__(self, other: object) -> "TokenUsage":
        """Return a new ``TokenUsage`` with each counter summed field-wise.

        Args:
            other: The usage to add to this one.

        Returns:
            A new ``TokenUsage``; neither operand is modified. For operands
            that are not ``TokenUsage`` the method returns ``NotImplemented``
            so Python can try the reflected operation or raise ``TypeError``.
        """
        if not isinstance(other, TokenUsage):
            return NotImplemented
        return TokenUsage(
            prompt_tokens=self.prompt_tokens + other.prompt_tokens,
            completion_tokens=self.completion_tokens + other.completion_tokens,
            total_tokens=self.total_tokens + other.total_tokens,
        )

    def __radd__(self, other: object) -> "TokenUsage":
        """Support ``sum(usages)``, which starts from the integer ``0``.

        Args:
            other: The left-hand operand; only the integer ``0`` is accepted.

        Returns:
            A copy of this usage when ``other`` is ``0``; otherwise
            ``NotImplemented``.
        """
        if isinstance(other, int) and other == 0:
            # Return a copy so the aggregate never aliases an input object.
            return TokenUsage(
                prompt_tokens=self.prompt_tokens,
                completion_tokens=self.completion_tokens,
                total_tokens=self.total_tokens,
            )
        return NotImplemented
@dataclass
class LLMResponse:
    """A single completion returned by an LLM provider."""

    content: str
    model: str
    provider: str
    usage: TokenUsage
    finish_reason: str = "stop"
    response_time_ms: float = 0.0
    raw_response: Optional[Dict[str, Any]] = None

    # Tool/function calls requested by the model, if any.
    tool_calls: Optional[List[Dict[str, Any]]] = None

    # Set when the response object is constructed (naive local time from
    # ``datetime.now``).
    created_at: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> Dict[str, Any]:
        """Return the response as a dict of plain values.

        ``usage`` is flattened to a nested dict of its three counters and
        ``created_at`` is rendered with ``isoformat()``. ``raw_response`` is
        not part of the output.
        """
        usage_map = {
            "prompt_tokens": self.usage.prompt_tokens,
            "completion_tokens": self.usage.completion_tokens,
            "total_tokens": self.usage.total_tokens,
        }
        created_iso = self.created_at.isoformat()
        return {
            "content": self.content,
            "model": self.model,
            "provider": self.provider,
            "usage": usage_map,
            "finish_reason": self.finish_reason,
            "response_time_ms": self.response_time_ms,
            "tool_calls": self.tool_calls,
            "created_at": created_iso,
        }
@dataclass
class ScrapedContent:
    """The result of fetching one URL through a scraper provider."""

    url: str
    title: Optional[str] = None
    content: str = ""
    html: Optional[str] = None
    status_code: int = 200
    response_time_ms: float = 0.0

    # Request/response metadata; each instance gets its own containers.
    headers: Dict[str, str] = field(default_factory=dict)
    cookies: Dict[str, str] = field(default_factory=dict)
    scraped_at: datetime = field(default_factory=datetime.now)

    # Details about how the page was fetched.
    provider: str = ""
    proxy_used: Optional[str] = None
    user_agent: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the page data as a dict of plain values.

        Only the page fields, ``provider`` and ``scraped_at`` (as an
        ISO-8601 string) are included; ``headers``, ``cookies``,
        ``proxy_used`` and ``user_agent`` are left out.
        """
        snapshot: Dict[str, Any] = {
            "url": self.url,
            "title": self.title,
            "content": self.content,
            "html": self.html,
            "status_code": self.status_code,
            "response_time_ms": self.response_time_ms,
            "provider": self.provider,
        }
        snapshot["scraped_at"] = self.scraped_at.isoformat()
        return snapshot
@dataclass
class Review:
    """One review extracted by a scraper provider."""

    id: str
    text: str
    source: str
    platform: str
    author: Optional[str] = None
    rating: Optional[float] = None
    timestamp: Optional[datetime] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the review as a dict of plain values.

        ``timestamp`` becomes an ISO-8601 string, or ``None`` when no
        timestamp is set.
        """
        stamp = None
        if self.timestamp:
            stamp = self.timestamp.isoformat()
        return {
            "id": self.id,
            "text": self.text,
            "source": self.source,
            "platform": self.platform,
            "author": self.author,
            "rating": self.rating,
            "timestamp": stamp,
            "metadata": self.metadata,
        }
@dataclass
class PredictionResult:
    """Outcome of a single model prediction."""

    label: str
    score: float
    confidence: float = 0.0
    all_scores: Dict[str, float] = field(default_factory=dict)
    model_name: str = ""
    processing_time_ms: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return all six fields as a ``{field_name: value}`` dict."""
        field_names = (
            "label",
            "score",
            "confidence",
            "all_scores",
            "model_name",
            "processing_time_ms",
        )
        return {name: getattr(self, name) for name in field_names}
class BaseProvider(ABC):
    """Root of the provider hierarchy.

    Stores the configuration object, tracks an ``_initialized`` flag, and
    lets a provider be used as an async context manager that calls
    ``initialize()`` on entry and ``close()`` on exit.
    """

    def __init__(self, config: Any = None) -> None:
        """
        Store the configuration and mark the provider as not yet initialized.

        Args:
            config: Provider-specific configuration
        """
        self._initialized = False
        self._config = config

    @property
    @abstractmethod
    def info(self) -> ProviderInfo:
        """Metadata describing this provider (must be supplied by subclasses)."""

    @property
    def name(self) -> str:
        """Provider name, taken from ``info.name``."""
        provider_info = self.info
        return provider_info.name

    @property
    def is_initialized(self) -> bool:
        """Current value of the internal ``_initialized`` flag."""
        return self._initialized

    @abstractmethod
    async def initialize(self) -> None:
        """Prepare the provider for use (load resources, verify credentials, etc.)."""

    @abstractmethod
    async def close(self) -> None:
        """Release whatever resources the provider holds."""

    async def __aenter__(self) -> "BaseProvider":
        """Initialize the provider and hand it to the ``async with`` body."""
        await self.initialize()
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Close the provider; returns ``None`` so exceptions propagate."""
        await self.close()

    def _ensure_initialized(self) -> None:
        """
        Guard for operations that need an initialized provider.

        Raises:
            ProviderInitializationError: If ``_initialized`` is false
        """
        if self._initialized:
            return
        raise ProviderInitializationError(
            self.name,
            "Provider not initialized. Call initialize() first.",
        )
class BaseLLMProvider(BaseProvider):
    """
    Common contract for providers that wrap a large language model.

    Subclasses must override:
    - generate: produce a full completion for a prompt
    - generate_stream: produce a completion incrementally

    Subclasses may override (the versions here raise or approximate):
    - generate_with_functions: completion with function calling
    - embed: embedding vectors for text
    - count_tokens: token count for text
    """

    def __init__(self, config: Optional[LLMConfig] = None) -> None:
        """
        Set up the provider with an LLM configuration.

        Args:
            config: LLM configuration; a default ``LLMConfig()`` is created
                when this is missing or falsy
        """
        super().__init__(config)
        if not config:
            config = LLMConfig()
        self._config: LLMConfig = config

    @abstractmethod
    async def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Produce a complete response for a prompt.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature (overrides config)
            max_tokens: Maximum tokens to generate (overrides config)
            stop: Stop sequences
            **kwargs: Additional provider-specific parameters

        Returns:
            LLMResponse with generated content
        """

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """
        Produce a response for a prompt incrementally.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature (overrides config)
            max_tokens: Maximum tokens to generate (overrides config)
            stop: Stop sequences
            **kwargs: Additional provider-specific parameters

        Yields:
            Text chunks as they're generated
        """
        # NOTE(review): this stub is ``async def`` without a ``yield``, so it
        # is a coroutine function rather than an async generator; overriding
        # implementations presumably use ``yield`` - confirm the typing intent.

    async def generate_with_functions(
        self,
        prompt: str,
        functions: List[Dict[str, Any]],
        system_prompt: Optional[str] = None,
        function_call: Union[str, Dict[str, str]] = "auto",
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Produce a response with function calling enabled.

        This base version never succeeds: it raises in both branches so a
        subclass that advertises the capability is forced to override it.

        Args:
            prompt: User prompt
            functions: List of function definitions
            system_prompt: Optional system prompt
            function_call: How to handle function calls ("auto", "none", or specific function)
            **kwargs: Additional parameters

        Returns:
            LLMResponse, potentially with tool_calls

        Raises:
            NotImplementedError: If provider doesn't support function calling,
                or supports it but has not overridden this method
        """
        capabilities = self.info.capabilities
        if capabilities.function_calling:
            raise NotImplementedError("Subclass must implement generate_with_functions")
        raise NotImplementedError(
            f"{self.name} does not support function calling"
        )

    async def embed(self, text: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
        """
        Produce embedding vectors for text.

        This base version never succeeds; see ``Raises``.

        Args:
            text: Single text or list of texts

        Returns:
            Embedding vector(s)

        Raises:
            NotImplementedError: If provider doesn't support embeddings,
                or supports them but has not overridden this method
        """
        capabilities = self.info.capabilities
        if capabilities.embeddings:
            raise NotImplementedError("Subclass must implement embed")
        raise NotImplementedError(f"{self.name} does not support embeddings")

    def count_tokens(self, text: str) -> int:
        """
        Estimate how many tokens a text contains.

        Args:
            text: Text to count tokens in

        Returns:
            Number of tokens

        Note:
            This default is only a rough estimate (character count floor-divided
            by four). Override for accurate counting.
        """
        # Rough heuristic: ~4 characters per token for English
        chars_per_token = 4
        return len(text) // chars_per_token

    @property
    def model(self) -> str:
        """Model name taken from the stored configuration."""
        return self._config.model

    @property
    def supports_streaming(self) -> bool:
        """Whether ``info.capabilities`` advertises streaming."""
        capabilities = self.info.capabilities
        return capabilities.streaming

    @property
    def supports_vision(self) -> bool:
        """Whether ``info.capabilities`` advertises vision/image inputs."""
        capabilities = self.info.capabilities
        return capabilities.vision

    @property
    def supports_function_calling(self) -> bool:
        """Whether ``info.capabilities`` advertises function calling."""
        capabilities = self.info.capabilities
        return capabilities.function_calling
class BaseScraperProvider(BaseProvider):
    """
    Common contract for providers that fetch web pages.

    Subclasses must override:
    - scrape: fetch the content of a URL
    - scrape_reviews: pull reviews out of a URL
    """

    def __init__(self, config: Optional[ScraperConfig] = None) -> None:
        """
        Set up the provider with a scraper configuration.

        Args:
            config: Scraper configuration; a default ``ScraperConfig()`` is
                created when this is missing or falsy
        """
        super().__init__(config)
        if not config:
            config = ScraperConfig()
        self._config: ScraperConfig = config

    @abstractmethod
    async def scrape(
        self,
        url: str,
        wait_for: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> ScrapedContent:
        """
        Fetch the content of a URL.

        Args:
            url: URL to scrape
            wait_for: CSS selector to wait for before scraping
            timeout: Request timeout (overrides config)
            **kwargs: Additional provider-specific parameters

        Returns:
            ScrapedContent with page content
        """

    @abstractmethod
    async def scrape_reviews(
        self,
        url: str,
        limit: int = 100,
        sort_by: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Review]:
        """
        Pull reviews out of a URL.

        Args:
            url: URL to scrape reviews from
            limit: Maximum number of reviews to extract
            sort_by: Sort order (implementation-specific)
            **kwargs: Additional parameters

        Returns:
            List of Review objects
        """

    async def screenshot(
        self,
        url: str,
        path: str,
        full_page: bool = True,
        **kwargs: Any,
    ) -> str:
        """
        Capture a screenshot of a page.

        This base version never succeeds; see ``Raises``.

        Args:
            url: URL to screenshot
            path: Output file path
            full_page: Capture full page or viewport only
            **kwargs: Additional parameters

        Returns:
            Path to saved screenshot

        Raises:
            NotImplementedError: If provider doesn't support screenshots,
                or supports them but has not overridden this method
        """
        capabilities = self.info.capabilities
        if capabilities.screenshots:
            raise NotImplementedError("Subclass must implement screenshot")
        raise NotImplementedError(f"{self.name} does not support screenshots")

    def get_supported_platforms(self) -> List[str]:
        """
        Report which platforms/domains this scraper handles.

        Returns:
            List of platform names this scraper supports; the base version
            returns an empty list
        """
        return []

    def supports_platform(self, platform: str) -> bool:
        """
        Tell whether a given platform is handled by this scraper.

        The comparison is case-insensitive.

        Args:
            platform: Platform name to check

        Returns:
            True if platform is supported
        """
        platforms = self.get_supported_platforms()
        if platforms:
            wanted = platform.lower()
            known = [entry.lower() for entry in platforms]
            return wanted in known
        # An empty list means the scraper accepts every platform.
        return True
class BaseModelProvider(BaseProvider):
    """
    Common contract for ML model providers (sentiment, emotion, etc.).

    Subclasses must override:
    - predict: classify a single input
    - predict_batch: classify several inputs
    - get_model_info: describe the loaded model
    """

    def __init__(self, config: Optional[ModelConfig] = None) -> None:
        """
        Set up the provider with a model configuration.

        Args:
            config: Model configuration; a default ``ModelConfig()`` is
                created when this is missing or falsy
        """
        super().__init__(config)
        if not config:
            config = ModelConfig()
        self._config: ModelConfig = config
        # Slot for the underlying model object; starts out empty.
        self._model: Any = None

    @abstractmethod
    async def predict(self, text: str, **kwargs: Any) -> PredictionResult:
        """
        Classify a single piece of text.

        Args:
            text: Input text
            **kwargs: Additional parameters

        Returns:
            PredictionResult with label and scores
        """

    @abstractmethod
    async def predict_batch(
        self, texts: List[str], **kwargs: Any
    ) -> List[PredictionResult]:
        """
        Classify several pieces of text.

        Args:
            texts: List of input texts
            **kwargs: Additional parameters

        Returns:
            List of PredictionResult objects
        """

    @abstractmethod
    def get_model_info(self) -> Dict[str, Any]:
        """
        Describe the loaded model.

        Returns:
            Dictionary with model metadata
        """

    @property
    def model_name(self) -> str:
        """Config's ``sentiment_model`` value, or ``"unknown"`` if it has none."""
        fallback = "unknown"
        return getattr(self._config, "sentiment_model", fallback)

    @property
    def device(self) -> str:
        """Device name taken from the stored configuration."""
        return self._config.device
# Type variable constrained to BaseProvider and its subclasses, for generic
# signatures that accept or return a specific provider type.
# NOTE(review): not referenced anywhere in the visible source - confirm it is
# still used by importers before relying on it.
T = TypeVar("T", bound=BaseProvider)
class ProviderRegistry:
    """
    Registry for provider discovery and instantiation.

    Supports registration of provider classes and factory functions, grouped
    by provider type. The class is a singleton: every ``ProviderRegistry()``
    call returns the same object, and the lookup tables are stored on the
    class itself.
    """

    # Provider types whose buckets always exist and are always listed first.
    _BUILTIN_TYPES = ("llm", "scraper", "model")

    _instance: Optional["ProviderRegistry"] = None
    _providers: Dict[str, Dict[str, Type[BaseProvider]]] = {}
    _factories: Dict[str, Dict[str, Callable[..., BaseProvider]]] = {}

    def __new__(cls) -> "ProviderRegistry":
        """Create the shared instance on first use, then keep returning it."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Fresh tables with one empty bucket per built-in type.
            cls._providers = {key: {} for key in cls._BUILTIN_TYPES}
            cls._factories = {key: {} for key in cls._BUILTIN_TYPES}
        return cls._instance

    @staticmethod
    def _type_key(provider_type: Union[str, ProviderType]) -> str:
        """Return the string bucket key for an enum member or plain string."""
        if isinstance(provider_type, ProviderType):
            return provider_type.value
        return provider_type

    def _names_for(self, type_key: str) -> List[str]:
        """List the names registered under one type, classes first, no repeats."""
        combined = [
            *self._providers.get(type_key, {}),
            *self._factories.get(type_key, {}),
        ]
        # A name may be registered both as a class and as a factory; report
        # it once, keeping the position of its first appearance.
        return list(dict.fromkeys(combined))

    def register(
        self,
        name: str,
        provider_type: Union[str, ProviderType],
        provider_class: Type[BaseProvider],
    ) -> None:
        """
        Register a provider class.

        The name is stored lower-cased; registering the same name again
        replaces the earlier class. An unknown ``provider_type`` string gets
        its own new bucket.

        Args:
            name: Provider name (e.g., "openai", "playwright")
            provider_type: Type of provider ("llm", "scraper", "model")
            provider_class: Provider class to register
        """
        bucket = self._providers.setdefault(self._type_key(provider_type), {})
        bucket[name.lower()] = provider_class

    def register_factory(
        self,
        name: str,
        provider_type: Union[str, ProviderType],
        factory: Callable[..., BaseProvider],
    ) -> None:
        """
        Register a factory function for creating providers.

        The name is stored lower-cased; registering the same name again
        replaces the earlier factory. An unknown ``provider_type`` string
        gets its own new bucket.

        Args:
            name: Provider name
            provider_type: Type of provider
            factory: Factory function that creates provider instances
        """
        bucket = self._factories.setdefault(self._type_key(provider_type), {})
        bucket[name.lower()] = factory

    def get(
        self,
        name: str,
        provider_type: Union[str, ProviderType],
        config: Any = None,
        **kwargs: Any,
    ) -> BaseProvider:
        """
        Get a provider instance by name.

        A registered factory takes precedence over a registered class of the
        same name. Either one is called as ``callable(config, **kwargs)``.

        Args:
            name: Provider name
            provider_type: Type of provider
            config: Provider configuration
            **kwargs: Additional arguments for provider constructor

        Returns:
            Provider instance

        Raises:
            ProviderNotFoundError: If provider is not registered
        """
        provider_type = self._type_key(provider_type)
        name_lower = name.lower()

        # Factories win over plain classes.
        factories = self._factories.get(provider_type, {})
        if name_lower in factories:
            return factories[name_lower](config, **kwargs)

        classes = self._providers.get(provider_type, {})
        if name_lower in classes:
            return classes[name_lower](config, **kwargs)

        raise ProviderNotFoundError(f"{provider_type}:{name}")

    def list_providers(
        self, provider_type: Optional[Union[str, ProviderType]] = None
    ) -> Dict[str, List[str]]:
        """
        List registered providers.

        Without a filter the result contains the three built-in types first
        (even when empty), followed by any custom types that have been
        registered. Each name appears once per type, even if it is registered
        both as a class and as a factory.

        Args:
            provider_type: Optional type filter

        Returns:
            Dictionary of provider types to provider names
        """
        if provider_type is not None:
            type_key = self._type_key(provider_type)
            return {type_key: self._names_for(type_key)}

        # dict.fromkeys gives an ordered, duplicate-free sequence of types.
        type_keys = dict.fromkeys(
            [*self._BUILTIN_TYPES, *self._providers, *self._factories]
        )
        return {type_key: self._names_for(type_key) for type_key in type_keys}

    def is_registered(
        self, name: str, provider_type: Union[str, ProviderType]
    ) -> bool:
        """
        Check if a provider is registered.

        The name is matched case-insensitively against both registered
        classes and registered factories.

        Args:
            name: Provider name
            provider_type: Type of provider

        Returns:
            True if provider is registered
        """
        type_key = self._type_key(provider_type)
        name_lower = name.lower()
        return (
            name_lower in self._providers.get(type_key, {})
            or name_lower in self._factories.get(type_key, {})
        )
823# Module-level convenience functions
def get_provider(
    name: str,
    provider_type: Union[str, ProviderType],
    config: Any = None,
    **kwargs: Any,
) -> BaseProvider:
    """
    Build a provider through the shared ``ProviderRegistry``.

    Thin shortcut for ``ProviderRegistry().get(...)``.

    Args:
        name: Provider name (e.g., "openai", "playwright")
        provider_type: Type of provider ("llm", "scraper", "model")
        config: Provider configuration
        **kwargs: Additional arguments

    Returns:
        Provider instance

    Example:
        >>> provider = get_provider("openai", "llm", config=llm_config)
        >>> await provider.initialize()
    """
    return ProviderRegistry().get(name, provider_type, config, **kwargs)
def register_provider(
    name: str,
    provider_type: Union[str, ProviderType],
    provider_class: Type[BaseProvider],
) -> None:
    """
    Add a provider class to the shared ``ProviderRegistry``.

    Thin shortcut for ``ProviderRegistry().register(...)``.

    Args:
        name: Provider name
        provider_type: Type of provider
        provider_class: Provider class to register

    Example:
        >>> register_provider("custom", "llm", CustomLLMProvider)
    """
    ProviderRegistry().register(name, provider_type, provider_class)
def list_providers(
    provider_type: Optional[Union[str, ProviderType]] = None
) -> Dict[str, List[str]]:
    """
    Report what is registered in the shared ``ProviderRegistry``.

    Thin shortcut for ``ProviderRegistry().list_providers(...)``.

    Args:
        provider_type: Optional type filter

    Returns:
        Dictionary of provider types to provider names
    """
    return ProviderRegistry().list_providers(provider_type)