Coverage for session_buddy/llm/providers/openai_provider.py: 25.00%
56 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 00:43 -0800
"""OpenAI API provider implementation.

This module provides the OpenAI provider implementation using the official
OpenAI Python SDK for chat completions and streaming.
"""

from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Any

from session_buddy.llm.base import LLMProvider
from session_buddy.llm.models import LLMMessage, LLMResponse

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator
class OpenAIProvider(LLMProvider):
    """OpenAI API provider.

    Wraps the official ``openai`` async SDK for chat completions and
    streaming.  Configuration keys read from ``config``:

    - ``api_key``: OpenAI API key (provider is unavailable without it).
    - ``base_url``: API endpoint; defaults to the public OpenAI URL.
    - ``default_model``: model used when the caller passes none.
    """

    def __init__(self, config: dict[str, Any]) -> None:
        super().__init__(config)
        self.api_key = config.get("api_key")
        self.base_url = config.get("base_url", "https://api.openai.com/v1")
        self.default_model = config.get("default_model", "gpt-4")
        # Lazily-created openai.AsyncOpenAI client; see _get_client().
        self._client: Any = None

    async def _get_client(self) -> Any:
        """Get or create the shared OpenAI async client.

        Returns:
            The cached ``openai.AsyncOpenAI`` instance.

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        if self._client is None:
            # Keep only the import inside the try so a failure while
            # constructing the client is not mislabeled as a missing package.
            try:
                import openai
            except ImportError as e:
                msg = "OpenAI package not installed. Install with: pip install openai"
                raise ImportError(msg) from e
            self._client = openai.AsyncOpenAI(
                api_key=self.api_key,
                base_url=self.base_url,
            )
        return self._client

    def _convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, str]]:
        """Convert LLMMessage objects to the OpenAI chat-message format."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    @staticmethod
    def _usage_dict(response: Any) -> dict[str, int]:
        """Extract a token-usage dict from a completion response.

        Returns all zeros when the API omitted usage information.
        """
        usage = response.usage
        if usage is None:
            return {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        return {
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens,
        }

    async def generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Generate a single (non-streaming) response using the OpenAI API.

        Args:
            messages: Conversation history to send.
            model: Model name; falls back to ``default_model`` when ``None``.
            temperature: Sampling temperature.
            max_tokens: Optional cap on completion length.
            **kwargs: Extra arguments forwarded to
                ``chat.completions.create``.

        Returns:
            An ``LLMResponse`` with content, usage, and response metadata.

        Raises:
            RuntimeError: If the provider is not available (no API key or
                the API cannot be reached).
        """
        # NOTE(review): is_available() performs a live models.list() probe,
        # so every generate() pays an extra round-trip — consider caching
        # availability upstream if this is hot.
        if not await self.is_available():
            msg = "OpenAI provider not available"
            raise RuntimeError(msg)

        client = await self._get_client()
        model_name = model or self.default_model

        try:
            response = await client.chat.completions.create(
                model=model_name,
                messages=self._convert_messages(messages),
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs,
            )
            choice = response.choices[0]
            return LLMResponse(
                content=choice.message.content,
                model=model_name,
                provider="openai",
                usage=self._usage_dict(response),
                finish_reason=choice.finish_reason,
                # Naive local timestamp — presumably display-only; switch to
                # timezone-aware UTC if consumers ever compare timestamps.
                timestamp=datetime.now().isoformat(),
                metadata={"response_id": response.id},
            )
        except Exception:
            # logger.exception already records the traceback and exception
            # text; no need to format the message eagerly with an f-string.
            self.logger.exception("OpenAI generation failed")
            raise

    async def stream_generate(  # type: ignore[override]
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[str]:
        """Stream response text from the OpenAI API.

        Yields:
            Non-empty content deltas as they arrive from the stream.

        Raises:
            RuntimeError: If the provider is not available.
        """
        if not await self.is_available():
            msg = "OpenAI provider not available"
            raise RuntimeError(msg)

        client = await self._get_client()
        model_name = model or self.default_model

        try:
            response = await client.chat.completions.create(
                model=model_name,
                messages=self._convert_messages(messages),
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True,
                **kwargs,
            )
            async for chunk in response:
                # Some stream frames carry no choices (e.g. usage-only
                # frames) — guard before indexing to avoid IndexError.
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        except Exception:
            self.logger.exception("OpenAI streaming failed")
            raise

    async def is_available(self) -> bool:
        """Check whether the OpenAI API is reachable with the configured key.

        Performs a lightweight ``models.list()`` probe; any failure
        (missing key, bad credentials, network error) yields ``False``.
        """
        if not self.api_key:
            return False
        try:
            client = await self._get_client()
            # Test with a simple request
            await client.models.list()
        except Exception:
            return False
        return True

    def get_models(self) -> list[str]:
        """Return the static list of supported OpenAI model names."""
        return [
            "gpt-4",
            "gpt-4-turbo",
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-16k",
        ]