Coverage for intelligence_toolkit/tests/unit/AI/test_base_chat.py: 100%
74 statements
Report generated by coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
# Licensed under the MIT license. See LICENSE file in the project.
#
4from unittest.mock import AsyncMock, MagicMock, patch
6import pytest
8from intelligence_toolkit.AI.base_chat import BaseChat
9from intelligence_toolkit.AI.openai_configuration import OpenAIConfiguration
@pytest.fixture
def base_chat_config():
    """Provide a minimal OpenAI configuration used to construct ``BaseChat``."""
    settings = {
        "api_key": "test_key",
        "model": "gpt-4",
        "api_type": "OpenAI",
    }
    return OpenAIConfiguration(settings)


@pytest.fixture
def base_chat(base_chat_config):
    """Yield a ``BaseChat`` built while the OpenAI client classes are patched.

    The fixture yields from *inside* the ``with`` block so that both
    ``intelligence_toolkit.AI.client.OpenAI`` and ``AsyncOpenAI`` remain
    patched for the entire test. A plain ``return`` here would exit the
    context managers, and undo the patches, before the test body ran.
    """
    with patch("intelligence_toolkit.AI.client.OpenAI"), \
            patch("intelligence_toolkit.AI.client.AsyncOpenAI"):
        yield BaseChat(base_chat_config)


def test_base_chat_initialization(base_chat_config):
    """An explicit ``concurrent_coroutines`` value sizes the chat's semaphore."""
    openai_patch = patch("intelligence_toolkit.AI.client.OpenAI")
    async_openai_patch = patch("intelligence_toolkit.AI.client.AsyncOpenAI")

    with openai_patch, async_openai_patch:
        chat = BaseChat(base_chat_config, concurrent_coroutines=10)

    assert chat.configuration is not None
    assert chat.semaphore._value == 10


def test_base_chat_initialization_default_coroutines(base_chat_config):
    """Omitting ``concurrent_coroutines`` sizes the semaphore with the package default."""
    from intelligence_toolkit.AI.defaults import DEFAULT_CONCURRENT_COROUTINES

    with patch("intelligence_toolkit.AI.client.OpenAI"), \
            patch("intelligence_toolkit.AI.client.AsyncOpenAI"):
        chat = BaseChat(base_chat_config)

    assert chat.semaphore._value == DEFAULT_CONCURRENT_COROUTINES


@pytest.mark.asyncio
async def test_generate_text_async_success(base_chat):
    """The chat result is returned as-is and the messages are sent with ``stream=False``."""
    prompt = [{"role": "user", "content": "Hello"}]

    with patch.object(
        base_chat,
        'generate_chat_async',
        new_callable=AsyncMock,
        return_value="Test response",
    ) as fake_chat:
        reply = await base_chat.generate_text_async(prompt, None, False)

    assert reply == "Test response"
    fake_chat.assert_called_once_with(messages=prompt, stream=False)


@pytest.mark.asyncio
async def test_generate_text_async_with_callbacks(base_chat):
    """Supplying callbacks leads to exactly one ``progress_callback`` invocation."""
    prompt = [{"role": "user", "content": "Hello"}]
    on_progress = MagicMock()

    with patch.object(base_chat, 'progress_callback') as fake_progress, \
            patch.object(base_chat, 'generate_chat_async', new_callable=AsyncMock,
                         return_value="Test response"):
        reply = await base_chat.generate_text_async(prompt, [on_progress], False)

    assert reply == "Test response"
    fake_progress.assert_called_once()


@pytest.mark.asyncio
async def test_generate_text_async_exception(base_chat):
    """A failing chat call surfaces as an exception mentioning 'Problem in OpenAI response'."""
    prompt = [{"role": "user", "content": "Hello"}]
    failure = Exception("API Error")

    with patch.object(base_chat, 'generate_chat_async', new_callable=AsyncMock,
                      side_effect=failure), \
            pytest.raises(Exception, match="Problem in OpenAI response"):
        await base_chat.generate_text_async(prompt, None, False)


@pytest.mark.asyncio
async def test_generate_texts_async_multiple_messages(base_chat):
    """Each conversation yields its response in order, and ``total_tasks`` counts them."""
    conversations = [
        [{"role": "user", "content": turn}]
        for turn in ("Hello", "World", "Test")
    ]
    canned = ["Response 1", "Response 2", "Response 3"]

    with patch.object(base_chat, 'generate_text_async', new_callable=AsyncMock,
                      side_effect=canned):
        results = await base_chat.generate_texts_async(conversations)

    assert len(results) == 3
    for position, expected_text in enumerate(canned):
        assert results[position] == expected_text
    assert base_chat.total_tasks == len(conversations)


@pytest.mark.asyncio
async def test_generate_texts_async_with_callbacks(base_chat):
    """Passing callbacks to ``generate_texts_async`` triggers a single ``track_progress`` call."""
    conversations = [
        [{"role": "user", "content": "Hello"}],
        [{"role": "user", "content": "World"}],
    ]
    listener = MagicMock()

    with patch.object(base_chat, 'generate_text_async', new_callable=AsyncMock,
                      side_effect=["Response 1", "Response 2"]), \
            patch.object(base_chat, 'track_progress', new_callable=AsyncMock) as fake_tracker:
        outputs = await base_chat.generate_texts_async(conversations, callbacks=[listener])

    assert len(outputs) == 2
    fake_tracker.assert_called_once()


@pytest.mark.asyncio
async def test_generate_texts_async_with_kwargs(base_chat):
    """Extra keyword arguments are forwarded to ``generate_text_async``.

    Only one conversation is submitted, so a single underlying call is
    expected. Asserting that first guarantees the inspected ``call_args``
    belongs to that call rather than silently being the last of several.
    """
    messages_list = [
        [{"role": "user", "content": "Hello"}],
    ]

    with patch.object(base_chat, 'generate_text_async', new_callable=AsyncMock) as mock_generate:
        mock_generate.return_value = "Response"

        await base_chat.generate_texts_async(
            messages_list,
            temperature=0.5,
            max_tokens=100,
        )

    # Verify kwargs were passed through on the one and only call.
    mock_generate.assert_called_once()
    # ``call_args.kwargs`` is the named accessor for the keyword-argument dict
    # (same value as the positional ``call_args[1]`` index).
    call_kwargs = mock_generate.call_args.kwargs
    assert call_kwargs['temperature'] == 0.5
    assert call_kwargs['max_tokens'] == 100