Coverage for intelligence_toolkit/tests/unit/AI/test_local_embedder.py: 100%

75 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4import tempfile 

5from unittest.mock import MagicMock, patch 

6 

7import pytest 

8 

9from intelligence_toolkit.AI.local_embedder import LocalEmbedder 

10 

11 

@pytest.fixture
def temp_db_path():
    """Yield the path of a fresh temporary directory for a single test.

    The directory is created by ``tempfile.TemporaryDirectory`` and is
    cleaned up when the context manager exits after the test has run.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        yield scratch_dir

16 

17 

def test_local_embedder_initialization(temp_db_path):
    """An explicitly requested model name is handed to SentenceTransformer."""
    model_name = "all-distilroberta-v1"
    patch_target = "intelligence_toolkit.AI.local_embedder.SentenceTransformer"

    with patch(patch_target) as mock_st:
        # Stand-in for the loaded model so no real weights are fetched.
        mock_st.return_value = MagicMock()

        embedder = LocalEmbedder(
            db_name="test_embeddings",
            db_path=temp_db_path,
            model=model_name,
        )

        assert embedder.local_client is not None
        mock_st.assert_called_once_with(model_name)

31 

32 

def test_local_embedder_initialization_default_model(temp_db_path):
    """Omitting ``model`` loads DEFAULT_LOCAL_EMBEDDING_MODEL."""
    with patch("intelligence_toolkit.AI.local_embedder.SentenceTransformer") as mock_st:
        from intelligence_toolkit.AI.defaults import DEFAULT_LOCAL_EMBEDDING_MODEL

        mock_st.return_value = MagicMock()

        # Only the constructor's call into SentenceTransformer is under test,
        # so the instance itself is not kept.
        LocalEmbedder(
            db_name="test_embeddings",
            db_path=temp_db_path,
        )

        mock_st.assert_called_once_with(DEFAULT_LOCAL_EMBEDDING_MODEL)

46 

47 

def test_local_embedder_initialization_none_model(temp_db_path):
    """Passing ``model=None`` loads DEFAULT_LOCAL_EMBEDDING_MODEL."""
    with patch("intelligence_toolkit.AI.local_embedder.SentenceTransformer") as mock_st:
        from intelligence_toolkit.AI.defaults import DEFAULT_LOCAL_EMBEDDING_MODEL

        mock_st.return_value = MagicMock()

        # Only the constructor's call into SentenceTransformer is under test,
        # so the instance itself is not kept.
        LocalEmbedder(
            db_name="test_embeddings",
            db_path=temp_db_path,
            model=None,
        )

        mock_st.assert_called_once_with(DEFAULT_LOCAL_EMBEDDING_MODEL)

62 

63 

def test_local_embedder_initialization_model_error(temp_db_path):
    """A failure inside SentenceTransformer is reported as a load failure."""
    load_failure = Exception("Model not found")

    # The patched constructor raises as soon as LocalEmbedder calls it.
    with patch(
        "intelligence_toolkit.AI.local_embedder.SentenceTransformer",
        side_effect=load_failure,
    ), pytest.raises(Exception, match="Failed to load local embedding model"):
        LocalEmbedder(
            db_name="test_embeddings",
            db_path=temp_db_path,
            model="invalid-model",
        )

74 

75 

def test_local_embedder_generate_embedding(temp_db_path):
    """Embedding one string encodes it once and yields the expected floats."""
    import numpy as np

    expected = [0.1, 0.2, 0.3]

    # Fake model whose encode() returns a fixed vector.
    fake_model = MagicMock()
    fake_model.encode.return_value = np.array(expected)

    with patch(
        "intelligence_toolkit.AI.local_embedder.SentenceTransformer",
        return_value=fake_model,
    ):
        embedder = LocalEmbedder(
            db_name="test_embeddings",
            db_path=temp_db_path,
        )

        result = embedder._generate_embedding("test text")

        assert result == expected
        fake_model.encode.assert_called_once_with("test text")

95 

96 

@pytest.mark.asyncio
async def test_local_embedder_generate_embedding_async(temp_db_path):
    """The async embedding path encodes once and yields the expected floats."""
    import numpy as np

    expected = [0.4, 0.5, 0.6]

    # Fake model whose encode() returns a fixed vector.
    fake_model = MagicMock()
    fake_model.encode.return_value = np.array(expected)

    with patch(
        "intelligence_toolkit.AI.local_embedder.SentenceTransformer",
        return_value=fake_model,
    ):
        embedder = LocalEmbedder(
            db_name="test_embeddings",
            db_path=temp_db_path,
        )

        result = await embedder._generate_embedding_async("test text")

        assert result == expected
        fake_model.encode.assert_called_once_with("test text")

116 

117 

def test_local_embedder_generate_embedding_list(temp_db_path):
    """Embedding a list of two strings yields one vector per input."""
    import numpy as np

    expected_rows = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

    # Fake model whose encode() returns a fixed 2x3 array.
    fake_model = MagicMock()
    fake_model.encode.return_value = np.array(expected_rows)

    with patch(
        "intelligence_toolkit.AI.local_embedder.SentenceTransformer",
        return_value=fake_model,
    ):
        embedder = LocalEmbedder(
            db_name="test_embeddings",
            db_path=temp_db_path,
        )

        result = embedder._generate_embedding(["text1", "text2"])

        assert len(result) == 2
        assert result[0] == expected_rows[0]
        assert result[1] == expected_rows[1]

137 

138 

def test_local_embedder_check_token_count_disabled(temp_db_path):
    """A freshly built LocalEmbedder has ``check_token_count`` set to False."""
    with patch("intelligence_toolkit.AI.local_embedder.SentenceTransformer") as mock_st:
        mock_st.return_value = MagicMock()

        embedder = LocalEmbedder(
            db_name="test_embeddings",
            db_path=temp_db_path,
        )

        # Identity check rather than ``== False``: a falsy non-boolean such
        # as 0 must not satisfy this assertion.
        assert embedder.check_token_count is False