Coverage for intelligence_toolkit/tests/unit/AI/test_local_embedder.py: 100%
75 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
4import tempfile
5from unittest.mock import MagicMock, patch
7import pytest
9from intelligence_toolkit.AI.local_embedder import LocalEmbedder
@pytest.fixture
def temp_db_path():
    """Yield the path of a throwaway directory for the embedder database.

    The directory is created before the test body runs and removed again
    once the requesting test has finished.
    """
    scratch_dir = tempfile.TemporaryDirectory()
    try:
        yield scratch_dir.name
    finally:
        # Mirror the context-manager exit: always delete the directory.
        scratch_dir.cleanup()
def test_local_embedder_initialization(temp_db_path):
    """An explicitly named model is handed to SentenceTransformer."""
    requested_model = "all-distilroberta-v1"
    fake_model = MagicMock()

    with patch(
        "intelligence_toolkit.AI.local_embedder.SentenceTransformer",
        return_value=fake_model,
    ) as transformer_cls:
        subject = LocalEmbedder(
            db_name="test_embeddings",
            db_path=temp_db_path,
            model=requested_model,
        )

        # The embedder must expose a client and must have built it from
        # exactly the model name it was given.
        assert subject.local_client is not None
        transformer_cls.assert_called_once_with(requested_model)
def test_local_embedder_initialization_default_model(temp_db_path):
    """Omitting ``model`` loads DEFAULT_LOCAL_EMBEDDING_MODEL."""
    with patch("intelligence_toolkit.AI.local_embedder.SentenceTransformer") as mock_st:
        from intelligence_toolkit.AI.defaults import DEFAULT_LOCAL_EMBEDDING_MODEL

        mock_st.return_value = MagicMock()

        # Only the constructor call is under test; the instance itself is
        # never inspected, so it is not bound to a name.
        LocalEmbedder(
            db_name="test_embeddings",
            db_path=temp_db_path,
        )

        mock_st.assert_called_once_with(DEFAULT_LOCAL_EMBEDDING_MODEL)
def test_local_embedder_initialization_none_model(temp_db_path):
    """Passing ``model=None`` falls back to DEFAULT_LOCAL_EMBEDDING_MODEL."""
    with patch("intelligence_toolkit.AI.local_embedder.SentenceTransformer") as mock_st:
        from intelligence_toolkit.AI.defaults import DEFAULT_LOCAL_EMBEDDING_MODEL

        mock_st.return_value = MagicMock()

        # Only the constructor call is under test; the instance itself is
        # never inspected, so it is not bound to a name.
        LocalEmbedder(
            db_name="test_embeddings",
            db_path=temp_db_path,
            model=None,
        )

        mock_st.assert_called_once_with(DEFAULT_LOCAL_EMBEDDING_MODEL)
def test_local_embedder_initialization_model_error(temp_db_path):
    """A SentenceTransformer failure is reported as a model-load error."""
    transformer_path = "intelligence_toolkit.AI.local_embedder.SentenceTransformer"

    # Make every attempt to construct the model raise.
    with patch(transformer_path, side_effect=Exception("Model not found")):
        with pytest.raises(Exception, match="Failed to load local embedding model"):
            LocalEmbedder(
                db_name="test_embeddings",
                db_path=temp_db_path,
                model="invalid-model",
            )
def test_local_embedder_generate_embedding(temp_db_path):
    """``_generate_embedding`` returns the encoded vector as a plain list."""
    import numpy as np

    # Stand-in model whose ``encode`` yields a fixed vector.
    fake_model = MagicMock()
    fake_model.encode.return_value = np.array([0.1, 0.2, 0.3])

    with patch(
        "intelligence_toolkit.AI.local_embedder.SentenceTransformer",
        return_value=fake_model,
    ):
        subject = LocalEmbedder(
            db_name="test_embeddings",
            db_path=temp_db_path,
        )

        vector = subject._generate_embedding("test text")

        assert vector == [0.1, 0.2, 0.3]
        fake_model.encode.assert_called_once_with("test text")
@pytest.mark.asyncio
async def test_local_embedder_generate_embedding_async(temp_db_path):
    """``_generate_embedding_async`` returns the encoded vector as a list."""
    import numpy as np

    # Stand-in model whose ``encode`` yields a fixed vector.
    fake_model = MagicMock()
    fake_model.encode.return_value = np.array([0.4, 0.5, 0.6])

    with patch(
        "intelligence_toolkit.AI.local_embedder.SentenceTransformer",
        return_value=fake_model,
    ):
        subject = LocalEmbedder(
            db_name="test_embeddings",
            db_path=temp_db_path,
        )

        vector = await subject._generate_embedding_async("test text")

        assert vector == [0.4, 0.5, 0.6]
        fake_model.encode.assert_called_once_with("test text")
def test_local_embedder_generate_embedding_list(temp_db_path):
    """A list of texts produces one embedding list per input text."""
    import numpy as np

    first_vector = [0.1, 0.2, 0.3]
    second_vector = [0.4, 0.5, 0.6]

    # Stand-in model whose ``encode`` yields a fixed 2-row matrix.
    fake_model = MagicMock()
    fake_model.encode.return_value = np.array([first_vector, second_vector])

    with patch(
        "intelligence_toolkit.AI.local_embedder.SentenceTransformer",
        return_value=fake_model,
    ):
        subject = LocalEmbedder(
            db_name="test_embeddings",
            db_path=temp_db_path,
        )

        vectors = subject._generate_embedding(["text1", "text2"])

        assert len(vectors) == 2
        assert vectors[0] == first_vector
        assert vectors[1] == second_vector
def test_local_embedder_check_token_count_disabled(temp_db_path):
    """A freshly built LocalEmbedder has ``check_token_count`` set to False."""
    with patch("intelligence_toolkit.AI.local_embedder.SentenceTransformer") as mock_st:
        mock_st.return_value = MagicMock()

        embedder = LocalEmbedder(
            db_name="test_embeddings",
            db_path=temp_db_path,
        )

        # Identity comparison against the singleton: the flag must be the
        # boolean False itself, not merely a value that compares equal to it.
        assert embedder.check_token_count is False