Coverage for intelligence_toolkit/AI/client.py: 81%

90 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4import logging 

5 

6from azure.identity import DefaultAzureCredential, get_bearer_token_provider 

7from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI 

8 

9from intelligence_toolkit.AI.classes import LLMCallback 

10 

11from .defaults import API_BASE_REQUIRED_FOR_AZURE, DEFAULT_EMBEDDING_MODEL 

12from .openai_configuration import OpenAIConfiguration 

13 

# Module-level logger, named after this module; OpenAIClient logs which SDK
# client it creates through it.
log = logging.getLogger(name=__name__)

15 

16 

class OpenAIClient:
    """Wrapper that builds OpenAI or Azure OpenAI SDK clients from a configuration.

    One synchronous and one asynchronous SDK client are created at construction
    time; the chat and embedding helpers below delegate to them.
    """

    # Synchronous SDK client (OpenAI or AzureOpenAI); set by _create_openai_client.
    _client = None
    # Asynchronous SDK client (AsyncOpenAI or AsyncAzureOpenAI); set alongside _client.
    _async_client = None

    def __init__(self, configuration: OpenAIConfiguration | None = None) -> None:
        """Store the configuration and create the SDK clients.

        Args:
            configuration: Connection and model settings. When omitted (or
                falsy), a default ``OpenAIConfiguration()`` is used.

        Raises:
            ValueError: If the configuration selects Azure OpenAI without an
                ``api_base``.
        """
        self.configuration = configuration or OpenAIConfiguration()
        self._create_openai_client()

    def _create_openai_client(self) -> None:
        """Create the sync and async SDK clients described by ``self.configuration``.

        ``api_type == "Azure OpenAI"`` builds Azure clients against ``api_base``,
        authenticating with a bearer token from ``DefaultAzureCredential`` when
        ``az_auth_type == "Managed Identity"`` and with ``api_key`` otherwise.
        Any other ``api_type`` builds plain OpenAI clients using ``api_key``.

        Raises:
            ValueError: If Azure OpenAI is selected but ``api_base`` is None.
        """
        config = self.configuration
        if config.api_type != "Azure OpenAI":
            log.info("Creating OpenAI client")
            self._client = OpenAI(api_key=config.api_key)
            self._async_client = AsyncOpenAI(api_key=config.api_key)
            return

        api_base = config.api_base
        if api_base is None:
            raise ValueError(API_BASE_REQUIRED_FOR_AZURE)
        log.info("Creating Azure OpenAI client api_base=%s", api_base)

        # Both Azure clients receive identical settings, including one shared
        # token provider instance when token-based auth is used.
        client_kwargs = {
            "api_version": config.api_version,
            "azure_endpoint": api_base,
        }
        if config.az_auth_type == "Managed Identity":
            client_kwargs["azure_ad_token_provider"] = get_bearer_token_provider(
                DefaultAzureCredential(),
                "https://cognitiveservices.azure.com/.default",
            )
        else:
            client_kwargs["api_key"] = config.api_key

        self._client = AzureOpenAI(**client_kwargs)
        self._async_client = AsyncAzureOpenAI(**client_kwargs)

    def _pop_generation_params(self, kwargs: dict) -> tuple:
        """Remove ``max_tokens``/``temperature`` from *kwargs*, defaulting to config.

        Returns:
            ``(max_tokens, temperature)``; an explicitly passed value (even
            None) wins over the configured default.
        """
        max_tokens = kwargs.pop("max_tokens", self.configuration.max_tokens)
        temperature = kwargs.pop("temperature", self.configuration.temperature)
        return max_tokens, temperature

    @staticmethod
    def _consume_chunk(
        full_response: str, chunk, callbacks: list[LLMCallback] | None
    ) -> str:
        """Append one streamed chunk's text to *full_response* and notify callbacks.

        Returns:
            The updated accumulated text.
        """
        if not chunk.choices:
            # Some streamed chunks carry no choices (e.g. Azure content-filter
            # results or usage-only chunks); they contribute no text.
            return full_response
        delta = chunk.choices[0].delta.content or ""
        full_response += delta
        if callbacks:
            # Show a cursor glyph after the text whenever this chunk added text.
            show = (full_response + "▌") if delta else full_response
            for callback in callbacks:
                callback.on_llm_new_token(show)
        return full_response

    @staticmethod
    def _message_content(response) -> str:
        """Return the first choice's message text of a non-streamed response, or ""."""
        if not response.choices:
            return ""
        return response.choices[0].message.content or ""

    def generate_chat(
        self,
        messages: list[dict],
        stream: bool = True,
        callbacks: list[LLMCallback] | None = None,
        **kwargs,
    ) -> str:
        """Run a chat completion with the synchronous client and return its text.

        Args:
            messages: Chat messages in the format expected by the SDK's
                ``chat.completions.create``.
            stream: When True the response is streamed and accumulated chunk
                by chunk; otherwise a single response is read.
            callbacks: Optional callbacks. While streaming, each receives the
                text accumulated so far via ``on_llm_new_token`` (with a
                trailing "▌" when the latest chunk added text).
            **kwargs: Extra arguments for ``chat.completions.create``;
                ``max_tokens`` and ``temperature`` override the configured
                defaults.

        Returns:
            The full response text, or "" when the response has no content.

        Raises:
            Exception: Wrapping (and chained to) any error raised while calling
                the API or reading its response.
        """
        try:
            max_tokens, temperature = self._pop_generation_params(kwargs)
            response = self._client.chat.completions.create(
                model=self.configuration.model,
                temperature=temperature,
                max_tokens=max_tokens,
                messages=messages,
                stream=stream,
                **kwargs,
            )
            if stream:
                # A streamed response must always be iterated; it has no
                # ``.choices`` of its own, whether or not callbacks were given.
                full_response = ""
                for chunk in response:
                    full_response = self._consume_chunk(full_response, chunk, callbacks)
                return full_response
            return self._message_content(response)
        except Exception as e:
            log.exception("Problem in OpenAI response")
            msg = f"Problem in OpenAI response. {e}"
            raise Exception(msg) from e

    async def generate_chat_async(
        self,
        messages: list[dict],
        stream: bool = True,
        callbacks: list[LLMCallback] | None = None,
        **kwargs,
    ) -> str:
        """Run a chat completion with the asynchronous client and return its text.

        Same arguments and return value as ``generate_chat``. Unlike the sync
        variant, errors from the SDK are propagated unwrapped.
        """
        max_tokens, temperature = self._pop_generation_params(kwargs)
        response = await self._async_client.chat.completions.create(
            model=self.configuration.model,
            temperature=temperature,
            max_tokens=max_tokens,
            messages=messages,
            stream=stream,
            **kwargs,
        )
        if stream:
            full_response = ""
            async for chunk in response:
                full_response = self._consume_chunk(full_response, chunk, callbacks)
            return full_response
        return self._message_content(response)

    def generate_embedding(
        self, text: str, model: str = DEFAULT_EMBEDDING_MODEL
    ) -> list[float]:
        """Embed a single *text* and return its vector (``data[0].embedding``)."""
        response = self._client.embeddings.create(input=text, model=model)
        return response.data[0].embedding

    def generate_embeddings(
        self, text: list[str], model: str = DEFAULT_EMBEDDING_MODEL
    ):
        """Embed several texts in one request.

        Returns:
            The SDK's embeddings response object as-is (vectors are under
            ``.data``), not a plain list of floats.
        """
        return self._client.embeddings.create(input=text, model=model)

    async def generate_embedding_async(
        self, text: str | list[str], model: str = DEFAULT_EMBEDDING_MODEL
    ) -> list[float]:
        """Asynchronously embed *text* and return ``data[0].embedding``.

        If a list of texts is passed, only the first vector of the response
        is returned.
        """
        response = await self._async_client.embeddings.create(input=text, model=model)
        return response.data[0].embedding