Coverage for intelligence_toolkit/tests/unit/AI/test_client.py: 100%

126 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4from unittest.mock import AsyncMock, MagicMock, patch 

5 

6import pytest 

7 

8from intelligence_toolkit.AI.classes import LLMCallback 

9from intelligence_toolkit.AI.client import OpenAIClient 

10from intelligence_toolkit.AI.defaults import DEFAULT_EMBEDDING_MODEL 

11from intelligence_toolkit.AI.openai_configuration import OpenAIConfiguration 

12 

13 

@pytest.fixture
def openai_config():
    """Provide a plain (non-Azure) OpenAI configuration using a dummy key and gpt-4."""
    settings = {
        "api_key": "test_key",
        "model": "gpt-4",
        "api_type": "OpenAI",
    }
    return OpenAIConfiguration(settings)

21 

22 

@pytest.fixture
def azure_openai_config():
    """Provide an Azure OpenAI configuration that authenticates with an API key."""
    settings = {
        "api_key": "test_key",
        "model": "gpt-4",
        "api_type": "Azure OpenAI",
        "api_base": "https://test.openai.azure.com",
        "api_version": "2024-02-01",
        "az_auth_type": "Azure Key",
    }
    return OpenAIConfiguration(settings)

33 

34 

def test_openai_client_initialization_openai(openai_config):
    """The "OpenAI" api_type builds both SDK clients from the API key alone."""
    with patch("intelligence_toolkit.AI.client.OpenAI") as sync_cls:
        with patch("intelligence_toolkit.AI.client.AsyncOpenAI") as async_cls:
            client = OpenAIClient(openai_config)

            sync_cls.assert_called_once_with(api_key="test_key")
            async_cls.assert_called_once_with(api_key="test_key")
            assert client.configuration == openai_config

44 

45 

def test_openai_client_initialization_azure_with_key(azure_openai_config):
    """Azure key auth passes version, endpoint and key to both Azure SDK clients."""
    with patch("intelligence_toolkit.AI.client.AzureOpenAI") as mock_azure, \
            patch("intelligence_toolkit.AI.client.AsyncAzureOpenAI") as mock_async_azure:

        client = OpenAIClient(azure_openai_config)

        # The sync and async clients are expected to receive identical settings,
        # so a single expectation is shared by both assertions.
        expected_kwargs = {
            "api_version": "2024-02-01",
            "azure_endpoint": "https://test.openai.azure.com",
            "api_key": "test_key",
        }
        mock_azure.assert_called_once_with(**expected_kwargs)
        mock_async_azure.assert_called_once_with(**expected_kwargs)

        # Mirror the plain-OpenAI initialization test: the client keeps the
        # configuration object it was constructed with.
        assert client.configuration == azure_openai_config

62 

63 

def test_openai_client_initialization_azure_without_api_base():
    """Building an Azure client with no api_base raises a ValueError."""
    missing_base = OpenAIConfiguration(
        {
            "api_key": "test_key",
            "model": "gpt-4",
            "api_type": "Azure OpenAI",
            "api_base": None,
        }
    )

    expected_message = "api_base is required for Azure OpenAI client"
    with pytest.raises(ValueError, match=expected_message):
        OpenAIClient(missing_base)

74 

75 

def test_openai_client_initialization_azure_managed_identity():
    """Managed Identity auth builds the Azure client from a bearer-token provider."""
    config = OpenAIConfiguration({
        "api_key": "test_key",
        "model": "gpt-4",
        "api_type": "Azure OpenAI",
        "api_base": "https://test.openai.azure.com",
        "api_version": "2024-02-01",
        "az_auth_type": "Managed Identity",
    })

    with patch("intelligence_toolkit.AI.client.DefaultAzureCredential") as mock_cred, \
            patch("intelligence_toolkit.AI.client.get_bearer_token_provider") as mock_token, \
            patch("intelligence_toolkit.AI.client.AzureOpenAI") as mock_azure, \
            patch("intelligence_toolkit.AI.client.AsyncAzureOpenAI"):

        mock_token.return_value = "mock_token_provider"

        client = OpenAIClient(config)

        mock_cred.assert_called_once()
        # The token provider is built from the credential instance and scoped
        # to the Cognitive Services resource.
        mock_token.assert_called_once_with(
            mock_cred.return_value,
            "https://cognitiveservices.azure.com/.default",
        )

        # assert_called_once_with requires an exact match, so this also checks
        # that the configured api_key is NOT forwarded in this auth mode.
        mock_azure.assert_called_once_with(
            api_version="2024-02-01",
            azure_ad_token_provider="mock_token_provider",
            azure_endpoint="https://test.openai.azure.com",
        )
        assert client.configuration == config
        # NOTE(review): AsyncAzureOpenAI construction is not asserted here; it
        # presumably mirrors the sync call above. Confirm against client.py.

106 

107 

def test_openai_client_generate_chat_non_streaming(openai_config):
    """A non-streaming chat call returns choices[0].message.content from the SDK."""
    choice = MagicMock()
    choice.message.content = "Test response"
    completion = MagicMock()
    completion.choices = [choice]

    sync_client = MagicMock()
    sync_client.chat.completions.create.return_value = completion

    with patch("intelligence_toolkit.AI.client.OpenAI", return_value=sync_client), \
            patch("intelligence_toolkit.AI.client.AsyncOpenAI"):
        client = OpenAIClient(openai_config)
        reply = client.generate_chat(
            [{"role": "user", "content": "Hello"}], stream=False
        )

    assert reply == "Test response"
    sync_client.chat.completions.create.assert_called_once()

126 

127 

def test_openai_client_generate_chat_streaming_with_callbacks(openai_config):
    """Streaming joins the delta chunks and populates the callback's response."""

    def make_chunk(text):
        # One streamed chunk whose first choice carries `text` as its delta.
        chunk = MagicMock()
        chunk.choices = [MagicMock()]
        chunk.choices[0].delta.content = text
        return chunk

    stream = [make_chunk(piece) for piece in ("Hello", " World")]

    sync_client = MagicMock()
    sync_client.chat.completions.create.return_value = stream

    with patch("intelligence_toolkit.AI.client.OpenAI", return_value=sync_client), \
            patch("intelligence_toolkit.AI.client.AsyncOpenAI"):
        callback = LLMCallback()
        client = OpenAIClient(openai_config)
        output = client.generate_chat(
            [{"role": "user", "content": "Hello"}],
            stream=True,
            callbacks=[callback],
        )

    assert output == "Hello World"
    assert len(callback.response) > 0

153 

154 

def test_openai_client_generate_chat_exception_handling(openai_config):
    """An SDK error surfaces as an exception matching "Problem in OpenAI response"."""
    failing_client = MagicMock()
    failing_client.chat.completions.create.side_effect = Exception("API Error")

    with patch("intelligence_toolkit.AI.client.OpenAI", return_value=failing_client), \
            patch("intelligence_toolkit.AI.client.AsyncOpenAI"):
        client = OpenAIClient(openai_config)

        with pytest.raises(Exception, match="Problem in OpenAI response"):
            client.generate_chat(
                [{"role": "user", "content": "Hello"}], stream=False
            )

168 

169 

def test_openai_client_generate_embedding(openai_config):
    """generate_embedding returns data[0].embedding using DEFAULT_EMBEDDING_MODEL."""
    item = MagicMock()
    item.embedding = [0.1, 0.2, 0.3]
    embedding_response = MagicMock()
    embedding_response.data = [item]

    sync_client = MagicMock()
    sync_client.embeddings.create.return_value = embedding_response

    with patch("intelligence_toolkit.AI.client.OpenAI", return_value=sync_client), \
            patch("intelligence_toolkit.AI.client.AsyncOpenAI"):
        client = OpenAIClient(openai_config)
        vector = client.generate_embedding("test text")

    assert vector == [0.1, 0.2, 0.3]
    sync_client.embeddings.create.assert_called_once_with(
        input="test text",
        model=DEFAULT_EMBEDDING_MODEL,
    )

190 

191 

def test_openai_client_generate_chat_custom_params(openai_config):
    """max_tokens and temperature are forwarded as keyword arguments to the SDK."""
    choice = MagicMock()
    choice.message.content = "Test response"
    completion = MagicMock()
    completion.choices = [choice]

    sync_client = MagicMock()
    sync_client.chat.completions.create.return_value = completion

    with patch("intelligence_toolkit.AI.client.OpenAI", return_value=sync_client), \
            patch("intelligence_toolkit.AI.client.AsyncOpenAI"):
        client = OpenAIClient(openai_config)
        reply = client.generate_chat(
            [{"role": "user", "content": "Hello"}],
            stream=False,
            max_tokens=1000,
            temperature=0.5,
        )

    assert reply == "Test response"
    # call_args.kwargs is the keyword-argument dict of the most recent call.
    sent = sync_client.chat.completions.create.call_args.kwargs
    assert sent["max_tokens"] == 1000
    assert sent["temperature"] == 0.5

217 

218 

@pytest.mark.asyncio
async def test_openai_client_generate_chat_async(openai_config):
    """generate_chat_async returns the content produced by the async SDK client."""
    choice = MagicMock()
    choice.message.content = "Async response"
    completion = MagicMock()
    completion.choices = [choice]

    async_client = MagicMock()
    # The completion call is awaited by the client, so it must be an AsyncMock.
    async_client.chat.completions.create = AsyncMock(return_value=completion)

    with patch("intelligence_toolkit.AI.client.OpenAI"), \
            patch("intelligence_toolkit.AI.client.AsyncOpenAI", return_value=async_client):
        client = OpenAIClient(openai_config)
        reply = await client.generate_chat_async(
            [{"role": "user", "content": "Hello"}], stream=False
        )

    assert reply == "Async response"

237 

238 

@pytest.mark.asyncio
async def test_openai_client_generate_embedding_async(openai_config):
    """generate_embedding_async returns data[0].embedding from the async SDK client."""
    item = MagicMock()
    item.embedding = [0.4, 0.5, 0.6]
    embedding_response = MagicMock()
    embedding_response.data = [item]

    async_client = MagicMock()
    # The embeddings call is awaited by the client, so it must be an AsyncMock.
    async_client.embeddings.create = AsyncMock(return_value=embedding_response)

    with patch("intelligence_toolkit.AI.client.OpenAI"), \
            patch("intelligence_toolkit.AI.client.AsyncOpenAI", return_value=async_client):
        client = OpenAIClient(openai_config)
        vector = await client.generate_embedding_async("test text")

    assert vector == [0.4, 0.5, 0.6]