Coverage for intelligence_toolkit/tests/unit/AI/test_client.py: 100%
126 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
4from unittest.mock import AsyncMock, MagicMock, patch
6import pytest
8from intelligence_toolkit.AI.classes import LLMCallback
9from intelligence_toolkit.AI.client import OpenAIClient
10from intelligence_toolkit.AI.defaults import DEFAULT_EMBEDDING_MODEL
11from intelligence_toolkit.AI.openai_configuration import OpenAIConfiguration
@pytest.fixture
def openai_config():
    """Provide a plain (non-Azure) OpenAI configuration for the client tests."""
    settings = {
        "api_key": "test_key",
        "model": "gpt-4",
        "api_type": "OpenAI",
    }
    return OpenAIConfiguration(settings)
@pytest.fixture
def azure_openai_config():
    """Provide an Azure OpenAI configuration that authenticates with an API key."""
    settings = {
        "api_key": "test_key",
        "model": "gpt-4",
        "api_type": "Azure OpenAI",
        "api_base": "https://test.openai.azure.com",
        "api_version": "2024-02-01",
        "az_auth_type": "Azure Key",
    }
    return OpenAIConfiguration(settings)
def test_openai_client_initialization_openai(openai_config):
    """A plain OpenAI configuration builds both SDK clients from the API key."""
    with patch("intelligence_toolkit.AI.client.OpenAI") as sync_cls:
        with patch("intelligence_toolkit.AI.client.AsyncOpenAI") as async_cls:
            built = OpenAIClient(openai_config)

            # Sync and async SDK clients are each constructed once with just the key.
            for sdk_cls in (sync_cls, async_cls):
                sdk_cls.assert_called_once_with(api_key="test_key")
            assert built.configuration == openai_config
def test_openai_client_initialization_azure_with_key(azure_openai_config):
    """Azure key auth passes version, endpoint and key to both Azure SDK clients."""
    expected = {
        "api_version": "2024-02-01",
        "azure_endpoint": "https://test.openai.azure.com",
        "api_key": "test_key",
    }
    with patch("intelligence_toolkit.AI.client.AzureOpenAI") as azure_cls:
        with patch("intelligence_toolkit.AI.client.AsyncAzureOpenAI") as async_azure_cls:
            OpenAIClient(azure_openai_config)

            azure_cls.assert_called_once_with(**expected)
            async_azure_cls.assert_called_once_with(**expected)
def test_openai_client_initialization_azure_without_api_base():
    """An Azure configuration with no api_base is rejected at construction time."""
    missing_base = OpenAIConfiguration(
        dict(
            api_key="test_key",
            model="gpt-4",
            api_type="Azure OpenAI",
            api_base=None,
        )
    )
    expected_error = "api_base is required for Azure OpenAI client"
    with pytest.raises(ValueError, match=expected_error):
        OpenAIClient(missing_base)
def test_openai_client_initialization_azure_managed_identity():
    """Managed Identity auth wires one bearer-token provider into both Azure clients.

    Verifies that a DefaultAzureCredential is created, that it is turned into a
    token provider scoped to Cognitive Services, and that the resulting provider
    (not an API key) is handed to both the sync and async Azure SDK clients.
    """
    config = OpenAIConfiguration({
        "api_key": "test_key",
        "model": "gpt-4",
        "api_type": "Azure OpenAI",
        "api_base": "https://test.openai.azure.com",
        "api_version": "2024-02-01",
        "az_auth_type": "Managed Identity",
    })

    with patch("intelligence_toolkit.AI.client.DefaultAzureCredential") as mock_cred, \
         patch("intelligence_toolkit.AI.client.get_bearer_token_provider") as mock_token, \
         patch("intelligence_toolkit.AI.client.AzureOpenAI") as mock_azure, \
         patch("intelligence_toolkit.AI.client.AsyncAzureOpenAI") as mock_async_azure:

        mock_token.return_value = "mock_token_provider"

        OpenAIClient(config)

        mock_cred.assert_called_once()
        mock_token.assert_called_once_with(
            mock_cred.return_value,
            "https://cognitiveservices.azure.com/.default",
        )

        expected_kwargs = {
            "api_version": "2024-02-01",
            "azure_ad_token_provider": "mock_token_provider",
            "azure_endpoint": "https://test.openai.azure.com",
        }
        mock_azure.assert_called_once_with(**expected_kwargs)
        # The async client was patched but previously never checked; it must be
        # built with the same (single) token provider as the sync client.
        mock_async_azure.assert_called_once_with(**expected_kwargs)
def test_openai_client_generate_chat_non_streaming(openai_config):
    """With stream=False the client returns the first choice's message content."""
    with patch("intelligence_toolkit.AI.client.OpenAI") as openai_cls:
        with patch("intelligence_toolkit.AI.client.AsyncOpenAI"):
            sdk = MagicMock()
            openai_cls.return_value = sdk

            first_choice = MagicMock()
            first_choice.message.content = "Test response"
            sdk.chat.completions.create.return_value = MagicMock(choices=[first_choice])

            reply = OpenAIClient(openai_config).generate_chat(
                [{"role": "user", "content": "Hello"}], stream=False
            )

            assert reply == "Test response"
            sdk.chat.completions.create.assert_called_once()
def test_openai_client_generate_chat_streaming_with_callbacks(openai_config):
    """Streaming concatenates chunk deltas and reports progress to callbacks."""

    def make_chunk(text):
        # Mimic one streamed chunk whose choices[0].delta.content is `text`.
        choice = MagicMock()
        choice.delta.content = text
        return MagicMock(choices=[choice])

    with patch("intelligence_toolkit.AI.client.OpenAI") as openai_cls:
        with patch("intelligence_toolkit.AI.client.AsyncOpenAI"):
            sdk = MagicMock()
            openai_cls.return_value = sdk
            sdk.chat.completions.create.return_value = [
                make_chunk("Hello"),
                make_chunk(" World"),
            ]

            listener = LLMCallback()
            streamed = OpenAIClient(openai_config).generate_chat(
                [{"role": "user", "content": "Hello"}],
                stream=True,
                callbacks=[listener],
            )

            assert streamed == "Hello World"
            assert len(listener.response) > 0
def test_openai_client_generate_chat_exception_handling(openai_config):
    """An SDK failure is re-raised with a message matching 'Problem in OpenAI response'."""
    with patch("intelligence_toolkit.AI.client.OpenAI") as openai_cls:
        with patch("intelligence_toolkit.AI.client.AsyncOpenAI"):
            sdk = MagicMock()
            openai_cls.return_value = sdk
            sdk.chat.completions.create.side_effect = Exception("API Error")

            client = OpenAIClient(openai_config)
            prompt = [{"role": "user", "content": "Hello"}]

            with pytest.raises(Exception, match="Problem in OpenAI response"):
                client.generate_chat(prompt, stream=False)
def test_openai_client_generate_embedding(openai_config):
    """generate_embedding returns the first vector and requests the default model."""
    with patch("intelligence_toolkit.AI.client.OpenAI") as openai_cls:
        with patch("intelligence_toolkit.AI.client.AsyncOpenAI"):
            sdk = MagicMock()
            openai_cls.return_value = sdk

            vector = MagicMock()
            vector.embedding = [0.1, 0.2, 0.3]
            sdk.embeddings.create.return_value = MagicMock(data=[vector])

            embedding = OpenAIClient(openai_config).generate_embedding("test text")

            assert embedding == [0.1, 0.2, 0.3]
            sdk.embeddings.create.assert_called_once_with(
                input="test text",
                model=DEFAULT_EMBEDDING_MODEL,
            )
def test_openai_client_generate_chat_custom_params(openai_config):
    """Explicit max_tokens and temperature are forwarded to the SDK call as kwargs."""
    with patch("intelligence_toolkit.AI.client.OpenAI") as openai_cls:
        with patch("intelligence_toolkit.AI.client.AsyncOpenAI"):
            sdk = MagicMock()
            openai_cls.return_value = sdk

            first_choice = MagicMock()
            first_choice.message.content = "Test response"
            sdk.chat.completions.create.return_value = MagicMock(choices=[first_choice])

            reply = OpenAIClient(openai_config).generate_chat(
                [{"role": "user", "content": "Hello"}],
                stream=False,
                max_tokens=1000,
                temperature=0.5,
            )

            assert reply == "Test response"
            # call_args unpacks into (positional args, keyword args).
            _, forwarded = sdk.chat.completions.create.call_args
            assert forwarded["max_tokens"] == 1000
            assert forwarded["temperature"] == 0.5
@pytest.mark.asyncio
async def test_openai_client_generate_chat_async(openai_config):
    """generate_chat_async awaits the async SDK client and returns the message content.

    Mirrors the sync non-streaming test: besides the returned text, it checks
    that the async client's ``chat.completions.create`` was actually awaited
    exactly once, so a result that did not come from the async path cannot pass.
    """
    with patch("intelligence_toolkit.AI.client.OpenAI"), \
         patch("intelligence_toolkit.AI.client.AsyncOpenAI") as mock_async_class:

        mock_async_client = MagicMock()
        mock_async_class.return_value = mock_async_client

        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Async response"
        mock_async_client.chat.completions.create = AsyncMock(return_value=mock_response)

        client = OpenAIClient(openai_config)
        messages = [{"role": "user", "content": "Hello"}]
        result = await client.generate_chat_async(messages, stream=False)

        assert result == "Async response"
        # Previously unchecked: the coroutine must have been awaited, once.
        mock_async_client.chat.completions.create.assert_awaited_once()
@pytest.mark.asyncio
async def test_openai_client_generate_embedding_async(openai_config):
    """generate_embedding_async awaits the async SDK client and returns its vector.

    Like the sync embedding test, this verifies the SDK call itself, not just the
    returned value: the async ``embeddings.create`` must be awaited exactly once.
    """
    with patch("intelligence_toolkit.AI.client.OpenAI"), \
         patch("intelligence_toolkit.AI.client.AsyncOpenAI") as mock_async_class:

        mock_async_client = MagicMock()
        mock_async_class.return_value = mock_async_client

        mock_embedding_response = MagicMock()
        mock_embedding_response.data = [MagicMock()]
        mock_embedding_response.data[0].embedding = [0.4, 0.5, 0.6]
        mock_async_client.embeddings.create = AsyncMock(return_value=mock_embedding_response)

        client = OpenAIClient(openai_config)
        result = await client.generate_embedding_async("test text")

        assert result == [0.4, 0.5, 0.6]
        # Previously unchecked: the async embeddings call must have been awaited once.
        mock_async_client.embeddings.create.assert_awaited_once()