Coverage for intelligence_toolkit/tests/unit/AI/test_base_chat.py: 100%

74 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4from unittest.mock import AsyncMock, MagicMock, patch 

5 

6import pytest 

7 

8from intelligence_toolkit.AI.base_chat import BaseChat 

9from intelligence_toolkit.AI.openai_configuration import OpenAIConfiguration 

10 

11 

@pytest.fixture
def base_chat_config():
    """Provide the OpenAI configuration shared by the tests in this module."""
    settings = {
        "api_key": "test_key",
        "model": "gpt-4",
        "api_type": "OpenAI",
    }
    return OpenAIConfiguration(settings)

19 

20 

@pytest.fixture
def base_chat(base_chat_config):
    """Build a ``BaseChat`` while the OpenAI client classes are patched.

    The patches are only active while the instance is constructed; they are
    undone before the fixture value reaches the test.
    """
    sync_client = patch("intelligence_toolkit.AI.client.OpenAI")
    async_client = patch("intelligence_toolkit.AI.client.AsyncOpenAI")
    with sync_client, async_client:
        chat = BaseChat(base_chat_config)
    return chat

26 

27 

def test_base_chat_initialization(base_chat_config):
    """An explicit ``concurrent_coroutines`` value sizes the semaphore."""
    sync_client = patch("intelligence_toolkit.AI.client.OpenAI")
    async_client = patch("intelligence_toolkit.AI.client.AsyncOpenAI")
    requested_limit = 10

    with sync_client, async_client:
        chat = BaseChat(base_chat_config, concurrent_coroutines=requested_limit)

        assert chat.configuration is not None
        # Reads the semaphore's private counter to check the configured limit.
        assert chat.semaphore._value == requested_limit

36 

37 

def test_base_chat_initialization_default_coroutines(base_chat_config):
    """Without an explicit limit the semaphore uses the package default."""
    sync_client = patch("intelligence_toolkit.AI.client.OpenAI")
    async_client = patch("intelligence_toolkit.AI.client.AsyncOpenAI")

    with sync_client, async_client:
        from intelligence_toolkit.AI.defaults import (
            DEFAULT_CONCURRENT_COROUTINES as expected_limit,
        )

        chat = BaseChat(base_chat_config)
        assert chat.semaphore._value == expected_limit

46 

47 

@pytest.mark.asyncio
async def test_generate_text_async_success(base_chat):
    """``generate_text_async`` hands back what ``generate_chat_async`` returns."""
    prompt = [{"role": "user", "content": "Hello"}]
    chat_stub = AsyncMock(return_value="Test response")

    with patch.object(base_chat, "generate_chat_async", new=chat_stub):
        reply = await base_chat.generate_text_async(prompt, None, False)

    assert reply == "Test response"
    chat_stub.assert_called_once_with(messages=prompt, stream=False)

59 

60 

@pytest.mark.asyncio
async def test_generate_text_async_with_callbacks(base_chat):
    """Passing callbacks makes ``generate_text_async`` report progress once."""
    prompt = [{"role": "user", "content": "Hello"}]
    listener = MagicMock()

    with patch.object(
        base_chat, "generate_chat_async", new_callable=AsyncMock
    ) as chat_stub:
        with patch.object(base_chat, "progress_callback") as progress_stub:
            chat_stub.return_value = "Test response"

            reply = await base_chat.generate_text_async(prompt, [listener], False)

            assert reply == "Test response"
            progress_stub.assert_called_once()

75 

76 

@pytest.mark.asyncio
async def test_generate_text_async_exception(base_chat):
    """A failing chat call surfaces as a "Problem in OpenAI response" error."""
    prompt = [{"role": "user", "content": "Hello"}]
    failing_chat = AsyncMock(side_effect=Exception("API Error"))

    with patch.object(base_chat, "generate_chat_async", new=failing_chat):
        with pytest.raises(Exception, match="Problem in OpenAI response"):
            await base_chat.generate_text_async(prompt, None, False)

86 

87 

@pytest.mark.asyncio
async def test_generate_texts_async_multiple_messages(base_chat):
    """Each conversation yields one result, in order, and tasks are counted."""
    messages_list = [
        [{"role": "user", "content": text}] for text in ("Hello", "World", "Test")
    ]
    expected = ["Response 1", "Response 2", "Response 3"]
    text_stub = AsyncMock(side_effect=expected)

    with patch.object(base_chat, "generate_text_async", new=text_stub):
        results = await base_chat.generate_texts_async(messages_list)

    assert len(results) == 3
    for position, wanted in enumerate(expected):
        assert results[position] == wanted
    assert base_chat.total_tasks == 3

106 

107 

@pytest.mark.asyncio
async def test_generate_texts_async_with_callbacks(base_chat):
    """Supplying callbacks makes ``generate_texts_async`` track progress once."""
    messages_list = [
        [{"role": "user", "content": text}] for text in ("Hello", "World")
    ]
    listener = MagicMock()

    with patch.object(
        base_chat, "generate_text_async", new_callable=AsyncMock
    ) as text_stub:
        with patch.object(
            base_chat, "track_progress", new_callable=AsyncMock
        ) as tracker_stub:
            text_stub.side_effect = ["Response 1", "Response 2"]

            results = await base_chat.generate_texts_async(
                messages_list, callbacks=[listener]
            )

            assert len(results) == 2
            tracker_stub.assert_called_once()

126 

127 

@pytest.mark.asyncio
async def test_generate_texts_async_with_kwargs(base_chat):
    """Extra keyword arguments are forwarded to ``generate_text_async``."""
    messages_list = [[{"role": "user", "content": "Hello"}]]
    text_stub = AsyncMock(return_value="Response")

    with patch.object(base_chat, "generate_text_async", new=text_stub):
        await base_chat.generate_texts_async(
            messages_list, temperature=0.5, max_tokens=100
        )

    # Inspect the keyword arguments of the most recent call to the stub.
    forwarded = text_stub.call_args.kwargs
    assert forwarded["temperature"] == 0.5
    assert forwarded["max_tokens"] == 100