Coverage for sentimatrix / providers / base.py: 88%

252 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-01-28 09:30 +0000

1""" 

2Sentimatrix Base Provider Interfaces 

3 

4Defines abstract base classes and protocols for all provider types: 

5- BaseLLMProvider: For LLM API providers 

6- BaseScraperProvider: For web scraping providers 

7- BaseModelProvider: For ML model providers 

8 

9Also includes provider registry for dynamic provider discovery. 

10""" 

11 

12from __future__ import annotations 

13 

14from abc import ABC, abstractmethod 

15from dataclasses import dataclass, field 

16from datetime import datetime 

17from enum import Enum 

18from typing import ( 

19 Any, 

20 AsyncIterator, 

21 Callable, 

22 Dict, 

23 Generic, 

24 List, 

25 Optional, 

26 Type, 

27 TypeVar, 

28 Union, 

29) 

30 

31from sentimatrix.core.config import LLMConfig, ModelConfig, ScraperConfig 

32from sentimatrix.core.exceptions import ( 

33 ProviderInitializationError, 

34 ProviderNotFoundError, 

35) 

36 

37 

class ProviderType(str, Enum):
    """Category a provider belongs to.

    Members mix in ``str``, so each one compares equal to its plain string
    value (for example ``ProviderType.LLM == "llm"``).
    """

    LLM = "llm"  # LLM API providers
    SCRAPER = "scraper"  # web scraping providers
    MODEL = "model"  # ML model providers

44 

45 

@dataclass
class ProviderCapabilities:
    """Feature flags and limits advertised by a provider.

    Fields are grouped by the provider family they describe. Every provider
    carries the full set: boolean flags default to ``False`` and the token
    limits default to 4096.
    """

    # --- LLM features ---
    streaming: bool = False
    function_calling: bool = False
    vision: bool = False
    json_mode: bool = False
    embeddings: bool = False

    # --- Token limits ---
    max_context_tokens: int = 4096
    max_output_tokens: int = 4096

    # --- Scraper features ---
    javascript_rendering: bool = False
    screenshots: bool = False
    pdf_generation: bool = False
    proxy_support: bool = False

    # --- ML model features ---
    batch_processing: bool = False
    gpu_support: bool = False
    quantization: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return every capability as a plain ``{name: value}`` mapping."""
        # Explicit key list (in declaration order) so the output is fixed
        # even if subclasses add extra attributes.
        keys = (
            "streaming",
            "function_calling",
            "vision",
            "json_mode",
            "embeddings",
            "max_context_tokens",
            "max_output_tokens",
            "javascript_rendering",
            "screenshots",
            "pdf_generation",
            "proxy_support",
            "batch_processing",
            "gpu_support",
            "quantization",
        )
        return {key: getattr(self, key) for key in keys}

90 

91 

@dataclass
class ProviderInfo:
    """Descriptive metadata for a provider implementation."""

    name: str
    provider_type: ProviderType
    version: str = "1.0.0"
    description: str = ""
    capabilities: ProviderCapabilities = field(default_factory=ProviderCapabilities)
    supported_models: List[str] = field(default_factory=list)
    website: Optional[str] = None
    documentation: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize to a plain dictionary.

        ``provider_type`` is reduced to its string value and ``capabilities``
        is expanded via ``ProviderCapabilities.to_dict()``.

        Returns:
            Dictionary of all metadata fields. ``supported_models`` is a copy,
            so mutating the returned dict does not alter this instance.
        """
        return {
            "name": self.name,
            "provider_type": self.provider_type.value,
            "version": self.version,
            "description": self.description,
            "capabilities": self.capabilities.to_dict(),
            # Copy: returning the field itself would let callers mutate it.
            "supported_models": list(self.supported_models),
            "website": self.website,
            "documentation": self.documentation,
        }

117 

118 

@dataclass
class TokenUsage:
    """Token counts for one or more LLM calls.

    ``total_tokens`` is stored as given; it is not recomputed from the other
    two fields.
    """

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

    def __add__(self, other: "TokenUsage") -> "TokenUsage":
        """
        Return the field-wise sum of two usages as a new object.

        Non-``TokenUsage`` operands yield ``NotImplemented``, so Python tries
        the reflected operation and then raises ``TypeError``, rather than
        failing with an ``AttributeError``.
        """
        if not isinstance(other, TokenUsage):
            return NotImplemented
        return TokenUsage(
            prompt_tokens=self.prompt_tokens + other.prompt_tokens,
            completion_tokens=self.completion_tokens + other.completion_tokens,
            total_tokens=self.total_tokens + other.total_tokens,
        )

    def __radd__(self, other: Any) -> "TokenUsage":
        """
        Support ``sum(usages)``, whose implicit start value is the int ``0``.

        Returns a new object (never ``self``) so the sum does not alias an input.
        """
        if isinstance(other, int) and other == 0:
            return self + TokenUsage()
        return NotImplemented

134 

135 

@dataclass
class LLMResponse:
    """Response from an LLM provider."""

    content: str
    model: str
    provider: str
    usage: TokenUsage
    finish_reason: str = "stop"
    response_time_ms: float = 0.0
    raw_response: Optional[Dict[str, Any]] = None

    # Function calling
    tool_calls: Optional[List[Dict[str, Any]]] = None

    # Metadata; created_at defaults to the local naive time at construction.
    created_at: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize to a plain dictionary.

        ``raw_response`` is intentionally omitted, ``usage`` is flattened into
        a nested dict, and ``created_at`` is rendered with ``isoformat()``.

        Returns:
            Dictionary of response fields. ``tool_calls`` is a shallow copy of
            the list (or ``None``), so mutating the result does not alter this
            instance's list.
        """
        return {
            "content": self.content,
            "model": self.model,
            "provider": self.provider,
            "usage": {
                "prompt_tokens": self.usage.prompt_tokens,
                "completion_tokens": self.usage.completion_tokens,
                "total_tokens": self.usage.total_tokens,
            },
            "finish_reason": self.finish_reason,
            "response_time_ms": self.response_time_ms,
            "tool_calls": (
                list(self.tool_calls) if self.tool_calls is not None else None
            ),
            "created_at": self.created_at.isoformat(),
        }

170 

171 

@dataclass
class ScrapedContent:
    """Result of fetching a single URL with a scraper provider."""

    url: str
    title: Optional[str] = None
    content: str = ""
    html: Optional[str] = None
    status_code: int = 200
    response_time_ms: float = 0.0

    # Response metadata; scraped_at defaults to local naive time at creation.
    headers: Dict[str, str] = field(default_factory=dict)
    cookies: Dict[str, str] = field(default_factory=dict)
    scraped_at: datetime = field(default_factory=datetime.now)

    # Which provider fetched the page, and with what proxy / user agent.
    provider: str = ""
    proxy_used: Optional[str] = None
    user_agent: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the main fields to a plain dictionary.

        ``scraped_at`` is rendered with ``isoformat()``. ``headers``,
        ``cookies``, ``proxy_used`` and ``user_agent`` are not included.
        """
        return dict(
            url=self.url,
            title=self.title,
            content=self.content,
            html=self.html,
            status_code=self.status_code,
            response_time_ms=self.response_time_ms,
            provider=self.provider,
            scraped_at=self.scraped_at.isoformat(),
        )

205 

206 

@dataclass
class Review:
    """A single review extracted by a scraper provider."""

    id: str
    text: str
    source: str
    platform: str
    author: Optional[str] = None
    rating: Optional[float] = None
    timestamp: Optional[datetime] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize to a plain dictionary.

        Returns:
            Dictionary of review fields. ``timestamp`` is an ISO-8601 string,
            or ``None`` when unset. ``metadata`` is a shallow copy, so adding
            or removing keys in the result does not alter this instance.
        """
        return {
            "id": self.id,
            "text": self.text,
            "source": self.source,
            "platform": self.platform,
            "author": self.author,
            "rating": self.rating,
            "timestamp": (
                self.timestamp.isoformat() if self.timestamp is not None else None
            ),
            # Copy: returning the field itself would let callers mutate it.
            "metadata": dict(self.metadata),
        }

232 

233 

@dataclass
class PredictionResult:
    """Output of a single model prediction (label plus scores)."""

    label: str
    score: float
    confidence: float = 0.0
    all_scores: Dict[str, float] = field(default_factory=dict)
    model_name: str = ""
    processing_time_ms: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize to a plain dictionary.

        Returns:
            Dictionary of prediction fields. ``all_scores`` is a copy, so
            mutating the result does not alter this instance.
        """
        return {
            "label": self.label,
            "score": self.score,
            "confidence": self.confidence,
            # Copy: returning the field itself would let callers mutate it.
            "all_scores": dict(self.all_scores),
            "model_name": self.model_name,
            "processing_time_ms": self.processing_time_ms,
        }

255 

256 

class BaseProvider(ABC):
    """
    Lifecycle shared by every provider type.

    Subclasses supply ``info``, ``initialize`` and ``close``. Instances work
    as async context managers: entering awaits ``initialize()`` and exiting
    awaits ``close()``.
    """

    def __init__(self, config: Any = None) -> None:
        """
        Store the configuration and start in the not-initialized state.

        Args:
            config: Provider-specific configuration object (may be ``None``)
        """
        self._initialized = False
        self._config = config

    @property
    @abstractmethod
    def info(self) -> ProviderInfo:
        """Metadata describing this provider (name, type, capabilities, ...)."""
        ...

    @property
    def name(self) -> str:
        """Provider name, read from ``info.name``."""
        provider_info = self.info
        return provider_info.name

    @property
    def is_initialized(self) -> bool:
        """Current value of the ``_initialized`` flag.

        This base class never sets the flag to ``True`` itself; presumably
        subclasses do so in ``initialize()``.
        """
        return self._initialized

    @abstractmethod
    async def initialize(self) -> None:
        """Prepare the provider for use (load resources, verify credentials, ...)."""
        ...

    @abstractmethod
    async def close(self) -> None:
        """Release any resources held by the provider."""
        ...

    async def __aenter__(self) -> "BaseProvider":
        """On ``async with`` entry, initialize and hand back the provider."""
        await self.initialize()
        return self

    async def __aexit__(self, *exc_info: Any) -> None:
        """On ``async with`` exit, close the provider; exceptions still propagate."""
        await self.close()

    def _ensure_initialized(self) -> None:
        """Raise ProviderInitializationError unless ``_initialized`` is set."""
        if self._initialized:
            return
        raise ProviderInitializationError(
            self.name,
            "Provider not initialized. Call initialize() first.",
        )

312 

313 

class BaseLLMProvider(BaseProvider):
    """
    Abstract base for LLM API providers.

    Required overrides:
    - generate: return a full completion for a prompt
    - generate_stream: yield a completion incrementally

    Optional overrides (the defaults raise NotImplementedError or estimate):
    - generate_with_functions: completion with function/tool calling
    - embed: embedding vectors for one or more texts
    - count_tokens: exact token counting (the default is a length heuristic)
    """

    def __init__(self, config: Optional[LLMConfig] = None) -> None:
        """
        Set up the provider with an LLM configuration.

        Args:
            config: LLM settings. A fresh ``LLMConfig()`` is used when the
                argument is falsy (including ``None``).
        """
        super().__init__(config)
        self._config: LLMConfig = config or LLMConfig()

    @abstractmethod
    async def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Produce a complete response for ``prompt``.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature (takes precedence over config)
            max_tokens: Generation cap in tokens (takes precedence over config)
            stop: Stop sequences
            **kwargs: Extra provider-specific options

        Returns:
            LLMResponse holding the generated content
        """
        ...

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """
        Produce a response for ``prompt`` incrementally.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature (takes precedence over config)
            max_tokens: Generation cap in tokens (takes precedence over config)
            stop: Stop sequences
            **kwargs: Extra provider-specific options

        Yields:
            Text fragments in generation order
        """
        ...

    async def generate_with_functions(
        self,
        prompt: str,
        functions: List[Dict[str, Any]],
        system_prompt: Optional[str] = None,
        function_call: Union[str, Dict[str, str]] = "auto",
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Produce a response that may request function/tool calls.

        Args:
            prompt: User prompt
            functions: Function definitions offered to the model
            system_prompt: Optional system prompt
            function_call: Call policy ("auto", "none", or a specific function)
            **kwargs: Extra options

        Returns:
            LLMResponse, possibly with ``tool_calls`` populated

        Raises:
            NotImplementedError: Always in this base class. The message says
                whether the capability is missing or just not implemented here.
        """
        if self.info.capabilities.function_calling:
            raise NotImplementedError("Subclass must implement generate_with_functions")
        raise NotImplementedError(
            f"{self.name} does not support function calling"
        )

    async def embed(self, text: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
        """
        Compute embedding vector(s).

        Args:
            text: One string or a list of strings

        Returns:
            One vector, or one vector per input string

        Raises:
            NotImplementedError: Always in this base class. The message says
                whether the capability is missing or just not implemented here.
        """
        if self.info.capabilities.embeddings:
            raise NotImplementedError("Subclass must implement embed")
        raise NotImplementedError(f"{self.name} does not support embeddings")

    def count_tokens(self, text: str) -> int:
        """
        Estimate how many tokens ``text`` contains.

        Args:
            text: Text to measure

        Returns:
            ``len(text) // 4``, a coarse heuristic. Override for a real
            tokenizer-based count.
        """
        chars_per_token = 4  # rough average for English text
        return len(text) // chars_per_token

    @property
    def model(self) -> str:
        """Model name from the LLM configuration."""
        return self._config.model

    @property
    def supports_streaming(self) -> bool:
        """Whether ``info.capabilities`` advertises streaming output."""
        capabilities = self.info.capabilities
        return capabilities.streaming

    @property
    def supports_vision(self) -> bool:
        """Whether ``info.capabilities`` advertises image inputs."""
        capabilities = self.info.capabilities
        return capabilities.vision

    @property
    def supports_function_calling(self) -> bool:
        """Whether ``info.capabilities`` advertises function calling."""
        capabilities = self.info.capabilities
        return capabilities.function_calling

472 

473 

class BaseScraperProvider(BaseProvider):
    """
    Abstract base for web-scraping providers.

    Required overrides:
    - scrape: fetch a URL and return its content
    - scrape_reviews: extract Review objects from a URL

    Optional overrides: screenshot, get_supported_platforms.
    """

    def __init__(self, config: Optional[ScraperConfig] = None) -> None:
        """
        Set up the provider with a scraper configuration.

        Args:
            config: Scraper settings. A fresh ``ScraperConfig()`` is used when
                the argument is falsy (including ``None``).
        """
        super().__init__(config)
        self._config: ScraperConfig = config or ScraperConfig()

    @abstractmethod
    async def scrape(
        self,
        url: str,
        wait_for: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> ScrapedContent:
        """
        Fetch ``url`` and return its content.

        Args:
            url: Page to fetch
            wait_for: CSS selector to wait for before capturing content
            timeout: Request timeout (takes precedence over config)
            **kwargs: Extra provider-specific options

        Returns:
            ScrapedContent describing the fetched page
        """
        ...

    @abstractmethod
    async def scrape_reviews(
        self,
        url: str,
        limit: int = 100,
        sort_by: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Review]:
        """
        Collect reviews published at ``url``.

        Args:
            url: Page to collect reviews from
            limit: Upper bound on the number of reviews returned
            sort_by: Sort order (meaning depends on the implementation)
            **kwargs: Extra options

        Returns:
            The extracted Review objects
        """
        ...

    async def screenshot(
        self,
        url: str,
        path: str,
        full_page: bool = True,
        **kwargs: Any,
    ) -> str:
        """
        Capture an image of ``url`` and save it to ``path``.

        Args:
            url: Page to capture
            path: Destination file path
            full_page: Whole page (True) or just the viewport (False)
            **kwargs: Extra options

        Returns:
            Path of the saved screenshot

        Raises:
            NotImplementedError: Always in this base class. The message says
                whether the capability is missing or just not implemented here.
        """
        if self.info.capabilities.screenshots:
            raise NotImplementedError("Subclass must implement screenshot")
        raise NotImplementedError(f"{self.name} does not support screenshots")

    def get_supported_platforms(self) -> List[str]:
        """
        Platform/domain names this scraper handles.

        Returns:
            An empty list in this base class. ``supports_platform`` treats an
            empty list as "no restriction".
        """
        return []

    def supports_platform(self, platform: str) -> bool:
        """
        Decide whether this scraper handles ``platform``.

        Args:
            platform: Platform name (compared case-insensitively)

        Returns:
            True when no platforms are declared, or when ``platform`` matches
            one of them.
        """
        platforms = self.get_supported_platforms()
        if not platforms:
            # No declared platforms means "supports everything".
            return True
        wanted = platform.lower()
        return wanted in {candidate.lower() for candidate in platforms}

586 

587 

class BaseModelProvider(BaseProvider):
    """
    Abstract base for ML model providers (sentiment, emotion, etc.).

    Required overrides:
    - predict: score a single text
    - predict_batch: score several texts
    - get_model_info: describe the loaded model
    """

    def __init__(self, config: Optional[ModelConfig] = None) -> None:
        """
        Set up the provider with a model configuration.

        Args:
            config: Model settings. A fresh ``ModelConfig()`` is used when the
                argument is falsy (including ``None``).
        """
        super().__init__(config)
        self._config: ModelConfig = config or ModelConfig()
        # Model handle; starts as None and is never populated by this base
        # class (presumably set by subclasses when they load a model).
        self._model: Any = None

    @abstractmethod
    async def predict(self, text: str, **kwargs: Any) -> PredictionResult:
        """
        Score a single input text.

        Args:
            text: Text to classify
            **kwargs: Extra options

        Returns:
            PredictionResult carrying the label and scores
        """
        ...

    @abstractmethod
    async def predict_batch(
        self, texts: List[str], **kwargs: Any
    ) -> List[PredictionResult]:
        """
        Score several input texts.

        Args:
            texts: Texts to classify
            **kwargs: Extra options

        Returns:
            One PredictionResult per input text
        """
        ...

    @abstractmethod
    def get_model_info(self) -> Dict[str, Any]:
        """
        Describe the loaded model.

        Returns:
            Mapping of model metadata
        """
        ...

    @property
    def model_name(self) -> str:
        """``sentiment_model`` from the config, or ``"unknown"`` if absent."""
        return getattr(self._config, "sentiment_model", "unknown")

    @property
    def device(self) -> str:
        """Device name taken from the model configuration."""
        return self._config.device

657 

658 

# Type variable bound to BaseProvider (or any subclass) for generic
# provider-typed code.
# NOTE(review): T is not referenced anywhere else in this module; confirm
# external importers rely on it before removing.
T = TypeVar("T", bound=BaseProvider)

661 

662 

class ProviderRegistry:
    """
    Registry for provider discovery and instantiation.

    Providers are registered either as classes (instantiated as
    ``cls(config, **kwargs)``) or as factory callables (called as
    ``factory(config, **kwargs)``). Factories take precedence in ``get``.

    Provider names and provider-type keys are both matched
    case-insensitively. The class is a singleton: every ``ProviderRegistry()``
    call returns the same instance, and registrations live on the class.
    """

    _instance: Optional["ProviderRegistry"] = None
    _providers: Dict[str, Dict[str, Type[BaseProvider]]] = {}
    _factories: Dict[str, Dict[str, Callable[..., BaseProvider]]] = {}

    # Types that always exist and appear first (in this order) in list_providers().
    _BUILTIN_TYPES = ("llm", "scraper", "model")

    def __new__(cls) -> "ProviderRegistry":
        """Singleton: create the instance and empty buckets on first call only."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._providers = {pt: {} for pt in cls._BUILTIN_TYPES}
            cls._factories = {pt: {} for pt in cls._BUILTIN_TYPES}
        return cls._instance

    @staticmethod
    def _normalize_type(provider_type: Union[str, ProviderType]) -> str:
        """Reduce an enum member (e.g. ProviderType) or a string to a lowercase key."""
        if isinstance(provider_type, Enum):
            provider_type = provider_type.value
        return str(provider_type).lower()

    def _names_for(self, key: str) -> List[str]:
        """Names registered under one type: classes first, then factories, no duplicates."""
        names = [*self._providers.get(key, {}), *self._factories.get(key, {})]
        return list(dict.fromkeys(names))

    def register(
        self,
        name: str,
        provider_type: Union[str, ProviderType],
        provider_class: Type[BaseProvider],
    ) -> None:
        """
        Register a provider class.

        Args:
            name: Provider name (e.g., "openai", "playwright"); stored lowercased
            provider_type: Type of provider ("llm", "scraper", "model", or a
                custom type, which gets its own bucket); stored lowercased
            provider_class: Provider class to register
        """
        key = self._normalize_type(provider_type)
        self._providers.setdefault(key, {})[name.lower()] = provider_class

    def register_factory(
        self,
        name: str,
        provider_type: Union[str, ProviderType],
        factory: Callable[..., BaseProvider],
    ) -> None:
        """
        Register a factory function for creating providers.

        Args:
            name: Provider name; stored lowercased
            provider_type: Type of provider; stored lowercased
            factory: Callable invoked as ``factory(config, **kwargs)``
        """
        key = self._normalize_type(provider_type)
        self._factories.setdefault(key, {})[name.lower()] = factory

    def get(
        self,
        name: str,
        provider_type: Union[str, ProviderType],
        config: Any = None,
        **kwargs: Any,
    ) -> BaseProvider:
        """
        Get a provider instance by name.

        Args:
            name: Provider name (case-insensitive)
            provider_type: Type of provider (case-insensitive)
            config: Provider configuration
            **kwargs: Additional arguments for the factory / constructor

        Returns:
            Provider instance. A registered factory wins over a registered
            class of the same name.

        Raises:
            ProviderNotFoundError: If neither a factory nor a class is registered
        """
        key = self._normalize_type(provider_type)
        name_lower = name.lower()

        factories = self._factories.get(key, {})
        if name_lower in factories:
            return factories[name_lower](config, **kwargs)

        classes = self._providers.get(key, {})
        if name_lower in classes:
            return classes[name_lower](config, **kwargs)

        raise ProviderNotFoundError(f"{key}:{name}")

    def list_providers(
        self, provider_type: Optional[Union[str, ProviderType]] = None
    ) -> Dict[str, List[str]]:
        """
        List registered providers.

        Args:
            provider_type: Optional type filter (case-insensitive)

        Returns:
            Mapping of lowercase type key to registered names, with classes
            listed before factories and each name listed once. Without a
            filter, the built-in types come first, followed by any custom
            types that have registrations.
        """
        if provider_type is not None:
            key = self._normalize_type(provider_type)
            return {key: self._names_for(key)}

        all_types = dict.fromkeys(
            [*self._BUILTIN_TYPES, *self._providers, *self._factories]
        )
        return {pt: self._names_for(pt) for pt in all_types}

    def is_registered(
        self, name: str, provider_type: Union[str, ProviderType]
    ) -> bool:
        """
        Check if a provider is registered.

        Args:
            name: Provider name (case-insensitive)
            provider_type: Type of provider (case-insensitive)

        Returns:
            True if a class or factory is registered under that name and type
        """
        key = self._normalize_type(provider_type)
        name_lower = name.lower()
        return (
            name_lower in self._providers.get(key, {})
            or name_lower in self._factories.get(key, {})
        )

821 

822 

823# Module-level convenience functions 

824 

825 

def get_provider(
    name: str,
    provider_type: Union[str, ProviderType],
    config: Any = None,
    **kwargs: Any,
) -> BaseProvider:
    """
    Build a provider instance from the shared ProviderRegistry.

    Args:
        name: Registered provider name (e.g., "openai", "playwright")
        provider_type: Provider category ("llm", "scraper", "model")
        config: Configuration passed through to the provider
        **kwargs: Extra arguments passed through to the provider

    Returns:
        The constructed provider (not yet initialized)

    Example:
        >>> provider = get_provider("openai", "llm", config=llm_config)
        >>> await provider.initialize()
    """
    return ProviderRegistry().get(name, provider_type, config, **kwargs)

850 

851 

def register_provider(
    name: str,
    provider_type: Union[str, ProviderType],
    provider_class: Type[BaseProvider],
) -> None:
    """
    Add a provider class to the shared ProviderRegistry.

    Args:
        name: Name to register the provider under
        provider_type: Provider category
        provider_class: Class to instantiate when the provider is requested

    Example:
        >>> register_provider("custom", "llm", CustomLLMProvider)
    """
    ProviderRegistry().register(name, provider_type, provider_class)

870 

871 

def list_providers(
    provider_type: Optional[Union[str, ProviderType]] = None
) -> Dict[str, List[str]]:
    """
    Report the provider names held by the shared ProviderRegistry.

    Args:
        provider_type: Restrict the report to this category (optional)

    Returns:
        Mapping of provider category to registered provider names
    """
    return ProviderRegistry().list_providers(provider_type)