Coverage for intelligence_toolkit/AI/client.py: 81%
90 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
# Licensed under the MIT license. See LICENSE file in the project.
#
import logging

from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from openai.types import CreateEmbeddingResponse

from intelligence_toolkit.AI.classes import LLMCallback

from .defaults import API_BASE_REQUIRED_FOR_AZURE, DEFAULT_EMBEDDING_MODEL
from .openai_configuration import OpenAIConfiguration
# Module-level logger named after this module; used by OpenAIClient for its log output.
log = logging.getLogger(__name__)
class OpenAIClient:
    """Wrapper around the synchronous and asynchronous OpenAI / Azure OpenAI SDK clients.

    Both SDK clients are built once, at construction time, from the supplied
    configuration and are then used by the chat and embedding helpers below.
    """

    # Replaced with real SDK clients by _create_openai_client() during __init__.
    _client = None
    _async_client = None

    def __init__(self, configuration: OpenAIConfiguration | None = None) -> None:
        """Store the configuration and build the SDK clients.

        Args:
            configuration: Settings to use. When omitted (or falsy), a default
                ``OpenAIConfiguration()`` is created.

        Raises:
            ValueError: If the API type is "Azure OpenAI" and ``api_base`` is None.
        """
        self.configuration = configuration or OpenAIConfiguration()
        self._create_openai_client()

    def _create_openai_client(self) -> None:
        """Create the sync and async SDK clients matching ``configuration.api_type``.

        Raises:
            ValueError: If the API type is "Azure OpenAI" and ``api_base`` is None.
        """
        if self.configuration.api_type != "Azure OpenAI":
            log.info("Creating OpenAI client")
            self._client = OpenAI(
                api_key=self.configuration.api_key,
            )
            self._async_client = AsyncOpenAI(
                api_key=self.configuration.api_key,
            )
            return

        api_base = self.configuration.api_base
        if api_base is None:
            raise ValueError(API_BASE_REQUIRED_FOR_AZURE)
        log.info(
            "Creating Azure OpenAI client api_base=%s",
            api_base,
        )

        # The sync and async clients take exactly the same arguments; only the
        # credential differs between the two authentication modes.
        azure_kwargs = {
            "api_version": self.configuration.api_version,
            "azure_endpoint": api_base,
        }
        if self.configuration.az_auth_type == "Managed Identity":
            azure_kwargs["azure_ad_token_provider"] = get_bearer_token_provider(
                DefaultAzureCredential(),
                "https://cognitiveservices.azure.com/.default",
            )
        else:
            azure_kwargs["api_key"] = self.configuration.api_key

        self._client = AzureOpenAI(**azure_kwargs)
        self._async_client = AsyncAzureOpenAI(**azure_kwargs)

    def _pop_sampling_options(self, kwargs: dict) -> tuple:
        """Remove ``max_tokens`` / ``temperature`` from *kwargs*, defaulting to the configuration.

        The two keys are popped so they are not passed twice when the remaining
        *kwargs* are forwarded to the SDK call.

        Returns:
            A ``(max_tokens, temperature)`` tuple.
        """
        max_tokens = kwargs.pop("max_tokens", self.configuration.max_tokens)
        temperature = kwargs.pop("temperature", self.configuration.temperature)
        return max_tokens, temperature

    @staticmethod
    def _should_consume_stream(stream: bool, callbacks, response) -> bool:
        """Tell whether *response* has to be iterated chunk by chunk.

        A streamed request with callbacks is always iterated. Without callbacks
        the response is iterated only when it does not look like a complete
        response (it has no ``choices`` attribute); previously that case failed
        with an attribute error instead of returning the streamed text.
        """
        if not stream:
            return False
        return callbacks is not None or not hasattr(response, "choices")

    @staticmethod
    def _emit_progress(callbacks, full_response: str, delta: str) -> None:
        """Send the text accumulated so far to every callback.

        A trailing "▌" cursor is appended while the latest chunk still carried
        text; a chunk without text reports the accumulated text as-is.
        """
        if not callbacks:
            return
        show = (full_response + "▌") if len(delta) > 0 else full_response
        for callback in callbacks:
            callback.on_llm_new_token(show)

    def generate_chat(
        self,
        messages: list[str],
        stream: bool = True,
        callbacks: list[LLMCallback] | None = None,
        **kwargs,
    ) -> str:
        """Run a chat completion with the synchronous client and return the text.

        Args:
            messages: Conversation forwarded unchanged to
                ``chat.completions.create``.
                NOTE(review): annotated ``list[str]``; the SDK call presumably
                expects message dicts - confirm against callers.
            stream: Whether to request a streamed response.
            callbacks: Objects whose ``on_llm_new_token`` receives the text
                accumulated so far after each streamed chunk.
            **kwargs: Extra arguments for ``chat.completions.create``.
                ``max_tokens`` and ``temperature`` override the configured values.

        Returns:
            The full response text, or "" when the response carries no choices
            or no content.

        Raises:
            Exception: Any failure is logged and re-raised as
                ``Exception("Problem in OpenAI response. <cause>")`` chained to
                the original error.
        """
        try:
            max_tokens, temperature = self._pop_sampling_options(kwargs)
            response = self._client.chat.completions.create(
                model=self.configuration.model,
                temperature=temperature,
                max_tokens=max_tokens,
                messages=messages,
                stream=stream,
                **kwargs,
            )

            if self._should_consume_stream(stream, callbacks, response):
                full_response = ""
                for chunk in response:
                    # Chunks without choices carry no text; skip them.
                    if len(chunk.choices) > 0:
                        delta = chunk.choices[0].delta.content or ""  # type: ignore
                        full_response += delta
                        self._emit_progress(callbacks, full_response, delta)
                return full_response

            if len(response.choices) > 0:
                return response.choices[0].message.content or ""  # type: ignore
            return ""
        except Exception as e:
            log.exception("Problem in OpenAI response")
            msg = f"Problem in OpenAI response. {e}"
            raise Exception(msg) from e

    async def generate_chat_async(
        self,
        messages: list[str],
        stream: bool = True,
        callbacks: list[LLMCallback] | None = None,
        **kwargs,
    ) -> str:
        """Run a chat completion with the asynchronous client and return the text.

        Unlike ``generate_chat``, errors raised by the SDK call propagate
        unchanged; they are not wrapped in a generic ``Exception``.

        Args:
            messages: Conversation forwarded unchanged to
                ``chat.completions.create``.
                NOTE(review): annotated ``list[str]``; the SDK call presumably
                expects message dicts - confirm against callers.
            stream: Whether to request a streamed response.
            callbacks: Objects whose ``on_llm_new_token`` receives the text
                accumulated so far after each streamed chunk.
            **kwargs: Extra arguments for ``chat.completions.create``.
                ``max_tokens`` and ``temperature`` override the configured values.

        Returns:
            The full response text, or "" when the response carries no choices
            or no content.
        """
        max_tokens, temperature = self._pop_sampling_options(kwargs)
        response = await self._async_client.chat.completions.create(
            model=self.configuration.model,
            temperature=temperature,
            max_tokens=max_tokens,
            messages=messages,
            stream=stream,
            **kwargs,
        )

        if self._should_consume_stream(stream, callbacks, response):
            full_response = ""
            async for chunk in response:
                # Guard against chunks with no choices, which would otherwise
                # raise IndexError on choices[0].
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta.content or ""  # type: ignore
                full_response += delta
                self._emit_progress(callbacks, full_response, delta)
            return full_response

        if not response.choices:
            return ""
        return response.choices[0].message.content or ""  # type: ignore

    def generate_embedding(
        self, text: str, model: str = DEFAULT_EMBEDDING_MODEL
    ) -> list[float]:
        """Embed *text* with the synchronous client.

        Args:
            text: Input passed as ``input`` to ``embeddings.create``.
            model: Embedding model name.

        Returns:
            The embedding of the first item in the response data.
        """
        embedding = self._client.embeddings.create(input=text, model=model)
        return embedding.data[0].embedding

    def generate_embeddings(
        self, text: list[str], model: str = DEFAULT_EMBEDDING_MODEL
    ) -> CreateEmbeddingResponse:
        """Embed several texts with the synchronous client.

        Args:
            text: Inputs passed as ``input`` to ``embeddings.create``.
            model: Embedding model name.

        Returns:
            The unmodified response object of ``embeddings.create``; no vectors
            are extracted from it here.
        """
        return self._client.embeddings.create(input=text, model=model)

    async def generate_embedding_async(
        self, text: str | list[str], model: str = DEFAULT_EMBEDDING_MODEL
    ) -> list[float]:
        """Embed *text* with the asynchronous client.

        Args:
            text: Input passed as ``input`` to ``embeddings.create``.
            model: Embedding model name.

        Returns:
            The embedding of the first item in the response data only, even
            when *text* is a list with several entries.
        """
        embedding = await self._async_client.embeddings.create(input=text, model=model)
        return embedding.data[0].embedding