Coverage for session_buddy / llm / providers / openai_provider.py: 25.00%

56 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-04 00:43 -0800

1"""OpenAI API provider implementation. 

2 

3This module provides the OpenAI provider implementation using the official 

4OpenAI Python SDK for chat completions and streaming. 

5""" 

6 

7from __future__ import annotations 

8 

9from datetime import datetime 

10from typing import TYPE_CHECKING, Any 

11 

12from session_buddy.llm.base import LLMProvider 

13from session_buddy.llm.models import LLMMessage, LLMResponse 

14 

15if TYPE_CHECKING: 

16 from collections.abc import AsyncGenerator 

17 

18 

class OpenAIProvider(LLMProvider):
    """LLM provider backed by the OpenAI Chat Completions API.

    The official ``openai`` SDK is imported lazily (on first client use), so the
    package is only required when this provider is actually exercised. Because
    ``base_url`` is configurable, any OpenAI-compatible endpoint can be targeted.

    Config keys read:
        api_key: key passed to ``openai.AsyncOpenAI``. When missing or empty the
            provider reports itself as unavailable.
        base_url: API root, default ``"https://api.openai.com/v1"``.
        default_model: model used when a call omits ``model``, default ``"gpt-4"``.
    """

    # Token counters copied from the SDK's ``usage`` object into LLMResponse.usage.
    _USAGE_KEYS = ("prompt_tokens", "completion_tokens", "total_tokens")

    def __init__(self, config: dict[str, Any]) -> None:
        super().__init__(config)
        self.api_key = config.get("api_key")
        self.base_url = config.get("base_url", "https://api.openai.com/v1")
        self.default_model = config.get("default_model", "gpt-4")
        # Created lazily by _get_client(); typed Any because openai is an optional import.
        self._client: Any = None

    async def _get_client(self) -> Any:
        """Return the cached ``openai.AsyncOpenAI`` client, creating it on first call.

        Raises:
            ImportError: if the ``openai`` package is not installed.
        """
        if self._client is None:
            try:
                import openai
            except ImportError as e:
                msg = "OpenAI package not installed. Install with: pip install openai"
                raise ImportError(msg) from e

            self._client = openai.AsyncOpenAI(
                api_key=self.api_key,
                base_url=self.base_url,
            )
        return self._client

    def _convert_messages(self, messages: list[LLMMessage]) -> list[dict[str, str]]:
        """Convert LLMMessage objects to OpenAI ``{"role", "content"}`` dicts."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    @classmethod
    def _extract_usage(cls, usage: Any) -> dict[str, int]:
        """Map an SDK ``usage`` object (or ``None``) to a plain token-count dict.

        A missing ``usage`` object, or an individual counter that is absent or
        ``None`` (as some OpenAI-compatible servers return), is reported as 0.
        """
        return {key: getattr(usage, key, None) or 0 for key in cls._USAGE_KEYS}

    async def generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Generate a single (non-streaming) chat completion.

        Args:
            messages: conversation to send, converted via ``_convert_messages``.
            model: model name; falls back to ``self.default_model``.
            temperature: sampling temperature forwarded to the API.
            max_tokens: completion token limit forwarded to the API (``None`` is
                passed through as-is).
            **kwargs: extra arguments forwarded to ``chat.completions.create``.

        Returns:
            LLMResponse with the first choice's text (``""`` when the API returns
            no text content, e.g. for tool-call-only replies), token usage,
            finish reason, a local ``datetime.now()`` ISO timestamp and the
            response id in ``metadata``.

        Raises:
            RuntimeError: if ``is_available()`` reports the provider unusable.
            Exception: any error from the SDK call is logged and re-raised.
        """
        # NOTE(review): is_available() issues a models.list() request, so every
        # generation pays an extra API round-trip; consider caching the result.
        if not await self.is_available():
            msg = "OpenAI provider not available"
            raise RuntimeError(msg)

        client = await self._get_client()
        model_name = model or self.default_model

        try:
            response = await client.chat.completions.create(
                model=model_name,
                messages=self._convert_messages(messages),
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs,
            )

            choice = response.choices[0]
            return LLMResponse(
                # content is None when the model replies only with tool calls.
                content=choice.message.content or "",
                model=model_name,
                provider="openai",
                usage=self._extract_usage(response.usage),
                finish_reason=choice.finish_reason,
                timestamp=datetime.now().isoformat(),
                metadata={"response_id": response.id},
            )

        except Exception as e:
            self.logger.exception(f"OpenAI generation failed: {e}")
            raise

    async def stream_generate(  # type: ignore[override]
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[str]:
        """Stream a chat completion, yielding non-empty text deltas in order.

        Arguments match ``generate``; ``stream=True`` is always set.

        Because this is an async generator, the availability check and the API
        call run when iteration starts, not when the method is called.

        Raises:
            RuntimeError: if ``is_available()`` reports the provider unusable.
            Exception: any error from the SDK call or stream is logged and re-raised.
        """
        if not await self.is_available():
            msg = "OpenAI provider not available"
            raise RuntimeError(msg)

        client = await self._get_client()
        model_name = model or self.default_model

        try:
            response = await client.chat.completions.create(
                model=model_name,
                messages=self._convert_messages(messages),
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True,
                **kwargs,
            )

            async for chunk in response:
                # Some chunks carry no choices at all (e.g. the trailing usage
                # chunk with stream_options={"include_usage": True}, or filter
                # metadata from OpenAI-compatible servers); skip them.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content:
                    yield content

        except Exception as e:
            self.logger.exception(f"OpenAI streaming failed: {e}")
            raise

    async def is_available(self) -> bool:
        """Return True if an API key is set and a ``models.list()`` probe succeeds.

        Any failure, including a missing ``openai`` package, is deliberately
        reported as ``False`` rather than raised.
        """
        if not self.api_key:
            return False

        try:
            client = await self._get_client()
            # Cheap authenticated request to confirm the key and endpoint work.
            await client.models.list()
            return True
        except Exception:
            return False

    def get_models(self) -> list[str]:
        """Return a static list of known OpenAI chat model names (no API call)."""
        return [
            "gpt-4",
            "gpt-4-turbo",
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-16k",
        ]