Coverage for intelligence_toolkit/tests/unit/AI/test_openai_embedder.py: 100%

55 statements  

« prev | ^ index | » next — generated by coverage.py v7.10.7 at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4import tempfile 

5from unittest.mock import AsyncMock, MagicMock, patch 

6 

7import pytest 

8 

9from intelligence_toolkit.AI.openai_configuration import OpenAIConfiguration 

10from intelligence_toolkit.AI.openai_embedder import OpenAIEmbedder 

11 

12 

@pytest.fixture
def openai_config():
    """Provide an OpenAI-type configuration for the embedder tests.

    The API key is a dummy value; the tests in this module patch the
    OpenAI client classes, so the key is never sent anywhere.
    """
    settings = {
        "api_key": "test_key",
        "model": "gpt-4",
        "api_type": "OpenAI",
        "embedding_model": "text-embedding-3-small",
    }
    return OpenAIConfiguration(settings)

21 

22 

@pytest.fixture
def temp_db_path():
    """Yield a throwaway directory path to use as the embedder's db_path.

    The directory and everything written into it are deleted when the
    TemporaryDirectory context exits at fixture teardown.
    """
    scratch = tempfile.TemporaryDirectory()
    with scratch as directory:
        yield directory

27 

28 

def test_openai_embedder_initialization(openai_config, temp_db_path):
    """Constructing an embedder stores the config and sets up client and store."""
    # Patch both SDK client classes so construction never builds a real client.
    with patch("intelligence_toolkit.AI.client.OpenAI"):
        with patch("intelligence_toolkit.AI.client.AsyncOpenAI"):
            instance = OpenAIEmbedder(
                openai_config,
                db_name="test_embeddings",
                db_path=temp_db_path,
            )

            assert instance.configuration == openai_config
            assert instance.openai_client is not None
            assert instance.vector_store is not None

42 

43 

def test_openai_embedder_generate_embedding(openai_config, temp_db_path):
    """The synchronous path returns the vector from the first response item."""
    with patch("intelligence_toolkit.AI.client.OpenAI") as openai_cls:
        with patch("intelligence_toolkit.AI.client.AsyncOpenAI"):
            # Every OpenAI(...) instantiation hands back this fake client.
            fake_client = MagicMock()
            openai_cls.return_value = fake_client

            # The test expects the embedder to read `.data[0].embedding`.
            fake_response = MagicMock(
                data=[MagicMock(embedding=[0.1, 0.2, 0.3])]
            )
            fake_client.embeddings.create.return_value = fake_response

            embedder = OpenAIEmbedder(
                openai_config,
                db_name="test_embeddings",
                db_path=temp_db_path,
            )

            vector = embedder._generate_embedding("test text")

            assert vector == [0.1, 0.2, 0.3]
            fake_client.embeddings.create.assert_called_once()

66 

67 

@pytest.mark.asyncio
async def test_openai_embedder_generate_embedding_async(openai_config, temp_db_path):
    """The async path awaits the client and returns the first response vector."""
    with patch("intelligence_toolkit.AI.client.OpenAI"):
        with patch("intelligence_toolkit.AI.client.AsyncOpenAI") as async_openai_cls:
            # Every AsyncOpenAI(...) instantiation hands back this fake client.
            fake_async_client = MagicMock()
            async_openai_cls.return_value = fake_async_client

            fake_response = MagicMock(
                data=[MagicMock(embedding=[0.4, 0.5, 0.6])]
            )
            # AsyncMock makes `create(...)` awaitable, resolving to fake_response.
            fake_async_client.embeddings.create = AsyncMock(
                return_value=fake_response
            )

            embedder = OpenAIEmbedder(
                openai_config,
                db_name="test_embeddings",
                db_path=temp_db_path,
            )

            vector = await embedder._generate_embedding_async("test text")

            assert vector == [0.4, 0.5, 0.6]
            fake_async_client.embeddings.create.assert_called_once()

91 

92 

def test_openai_embedder_uses_configured_model(openai_config, temp_db_path):
    """The embeddings request is sent with the configured embedding model.

    The ``openai_config`` fixture sets ``embedding_model`` to
    ``"text-embedding-3-small"``; this test checks that value reaches the
    client's ``embeddings.create`` call as the ``model`` keyword argument.
    """
    with patch("intelligence_toolkit.AI.client.OpenAI") as mock_openai, \
         patch("intelligence_toolkit.AI.client.AsyncOpenAI"):
        mock_client = MagicMock()
        mock_openai.return_value = mock_client

        mock_embedding_response = MagicMock()
        mock_embedding_response.data = [MagicMock()]
        mock_embedding_response.data[0].embedding = [0.1, 0.2, 0.3]
        mock_client.embeddings.create.return_value = mock_embedding_response

        embedder = OpenAIEmbedder(
            openai_config,
            db_name="test_embeddings",
            db_path=temp_db_path,
        )

        embedder._generate_embedding("test text")

        create = mock_client.embeddings.create
        # Check the call happened first: if it did not, `call_args` is None
        # and subscripting it would fail with an unhelpful TypeError.
        create.assert_called_once()
        call_kwargs = create.call_args.kwargs
        # `.get` turns a missing `model` kwarg into an assertion diff
        # instead of a KeyError.
        assert call_kwargs.get("model") == "text-embedding-3-small"