Coverage for intelligence_toolkit/tests/unit/match_entity_records/test_api.py: 100%
154 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
5import io
6from pathlib import Path
7from unittest.mock import AsyncMock, MagicMock, patch
9import numpy as np
10import polars as pl
11import pytest
13from intelligence_toolkit.AI.classes import LLMCallback
14from intelligence_toolkit.match_entity_records.api import MatchEntityRecords
15from intelligence_toolkit.match_entity_records.classes import (
16 AttributeToMatch,
17 RecordsModel,
18)
class TestMatchEntityRecords:
    """Shared fixtures for the ``MatchEntityRecords`` API test suites.

    The test classes in this module subclass this one to inherit the fixtures.
    """

    @pytest.fixture()
    def api_instance(self) -> MatchEntityRecords:
        """Return a freshly constructed ``MatchEntityRecords`` object."""
        instance = MatchEntityRecords()
        return instance

    @pytest.fixture()
    def sample_dataframe(self) -> pl.DataFrame:
        """Return a three-row frame with an id, a name and two attribute columns."""
        row_numbers = range(1, 4)
        columns = {
            "id": list(row_numbers),
            "name": [f"Entity {letter}" for letter in "ABC"],
            "attribute1": [f"value{n}" for n in row_numbers],
            "attribute2": [f"val{n}" for n in row_numbers],
        }
        return pl.DataFrame(columns)

    @pytest.fixture()
    def sample_model(self, sample_dataframe) -> RecordsModel:
        """Wrap ``sample_dataframe`` in a ``RecordsModel`` named ``test_dataset``."""
        attribute_columns = ["attribute1", "attribute2"]
        return RecordsModel(
            dataframe_name="test_dataset",
            dataframe=sample_dataframe,
            name_column="name",
            id_column="id",
            columns=attribute_columns,
        )

    @pytest.fixture()
    def populated_api(self, api_instance, sample_model) -> MatchEntityRecords:
        """Return ``api_instance`` after ``sample_model`` has been added to it."""
        api_instance.add_df_to_model(sample_model)
        return api_instance
class TestTotalRecords(TestMatchEntityRecords):
    """Tests for the ``total_records`` property."""

    def test_total_records_empty(self, api_instance) -> None:
        """A new instance reports zero records."""
        total = api_instance.total_records
        assert total == 0

    def test_total_records_single_dataset(self, populated_api) -> None:
        """A single three-row dataset yields three records."""
        total = populated_api.total_records
        assert total == 3

    def test_total_records_multiple_datasets(self, api_instance, sample_dataframe) -> None:
        """Record counts are summed across every added dataset."""
        # Start from an empty model before adding the datasets.
        api_instance.clear_model_dfs()

        frames_by_name = {
            "dataset1": sample_dataframe,
            "dataset2": sample_dataframe.head(2),
        }
        models = [
            RecordsModel(
                dataframe=frame,
                dataframe_name=label,
                id_column="id",
                name_column="name",
                columns=["attribute1"],
            )
            for label, frame in frames_by_name.items()
        ]
        for model in models:
            api_instance.add_df_to_model(model)

        assert api_instance.total_records == 5  # 3 rows + 2 rows
class TestAttributeOptions(TestMatchEntityRecords):
    """Tests for the ``attribute_options`` property."""

    def test_attribute_options_empty(self, api_instance) -> None:
        """With no datasets loaded, the option list is empty."""
        api_instance.clear_model_dfs()
        options = api_instance.attribute_options
        assert isinstance(options, list)
        assert options == []

    def test_attribute_options_populated(self, populated_api) -> None:
        """With data loaded, every option is formatted as ``column::dataset``."""
        options = populated_api.attribute_options
        assert isinstance(options, list)
        assert options
        assert all("::" in option for option in options)
class TestIntegratedResults(TestMatchEntityRecords):
    """Tests for the ``integrated_results`` property."""

    def test_integrated_results_empty(self, api_instance) -> None:
        """Empty matches and evaluations produce an empty dataframe."""
        # Explicitly typed, zero-row frames are set up front to avoid a panic.
        evaluation_schema = {
            "Group ID": pl.Int64,
            "Relatedness": pl.Float64,
            "Explanation": pl.Utf8,
        }
        match_schema = {
            "Group ID": pl.Int64,
            "Entity name": pl.Utf8,
            "Dataset": pl.Utf8,
        }
        api_instance.evaluations_df = pl.DataFrame(schema=evaluation_schema)
        api_instance.matches_df = pl.DataFrame(schema=match_schema)

        result = api_instance.integrated_results
        assert isinstance(result, pl.DataFrame)
        assert result.is_empty()

    def test_integrated_results_with_data(self, api_instance) -> None:
        """Matches joined with evaluations yield a non-empty combined frame."""
        group_ids = [1, 2]
        api_instance.matches_df = pl.DataFrame(
            {
                "Group ID": group_ids,
                "Entity name": ["A", "B"],
                "Dataset": ["test"] * 2,
            }
        )
        api_instance.evaluations_df = pl.DataFrame(
            {
                "Group ID": group_ids,
                "Relatedness": [8, 6],
                "Explanation": ["Similar", "Somewhat similar"],
            }
        )

        result = api_instance.integrated_results
        assert isinstance(result, pl.DataFrame)
        assert not result.is_empty()
        assert {"Group ID", "Relatedness"} <= set(result.columns)
class TestAddDfToModel(TestMatchEntityRecords):
    """Tests for ``MatchEntityRecords.add_df_to_model``."""

    def test_add_df_to_model_basic(self, api_instance, sample_model) -> None:
        """Adding a named model returns data and registers it under its name."""
        added = api_instance.add_df_to_model(sample_model)
        assert isinstance(added, pl.DataFrame)
        assert not added.is_empty()
        assert "test_dataset" in api_instance.model_dfs

    def test_add_df_to_model_auto_name(self, api_instance, sample_dataframe) -> None:
        """A model with an empty ``dataframe_name`` is still stored exactly once."""
        api_instance.clear_model_dfs()

        unnamed_model = RecordsModel(
            dataframe=sample_dataframe,
            dataframe_name="",
            id_column="id",
            name_column="name",
            columns=["attribute1"],
        )
        added = api_instance.add_df_to_model(unnamed_model)

        assert isinstance(added, pl.DataFrame)
        assert len(api_instance.model_dfs) == 1

    def test_add_df_to_model_with_max_rows(self, api_instance, sample_model) -> None:
        """``max_rows_to_process`` limits the number of rows returned."""
        row_limit = 2
        api_instance.max_rows_to_process = row_limit
        added = api_instance.add_df_to_model(sample_model)
        assert len(added) <= row_limit
class TestBuildModelDf(TestMatchEntityRecords):
    """Tests for ``MatchEntityRecords.build_model_df``."""

    @pytest.fixture()
    def attributes_list(self) -> list[AttributeToMatch]:
        """Create sample attributes list.

        Column references use the ``column::dataset`` form; ``test_dataset``
        is the ``dataframe_name`` given by the ``sample_model`` fixture.
        """
        return [
            {"label": "Attribute 1", "columns": ["attribute1::test_dataset"]},
            {"label": "Attribute 2", "columns": ["attribute2::test_dataset"]},
        ]

    def test_build_model_df_basic(self, populated_api, attributes_list) -> None:
        """Test building the model dataframe and storing it on the instance."""
        result = populated_api.build_model_df(attributes_list)
        assert isinstance(result, pl.DataFrame)
        assert not result.is_empty()
        assert "Dataset" in result.columns
        # A hasattr check alone cannot show that the attribute holds the frame
        # just built, so compare the stored model_df with the returned frame.
        assert isinstance(populated_api.model_df, pl.DataFrame)
        assert populated_api.model_df.equals(result)

    def test_build_model_df_creates_sentences(
        self, populated_api, attributes_list
    ) -> None:
        """Test that building the model dataframe creates sentence vector data."""
        populated_api.build_model_df(attributes_list)
        assert hasattr(populated_api, "sentences_vector_data")
        assert isinstance(populated_api.sentences_vector_data, list)
class TestEmbedSentences(TestMatchEntityRecords):
    """Tests for ``MatchEntityRecords.embed_sentences``."""

    @pytest.mark.asyncio()
    async def test_embed_sentences(self, populated_api) -> None:
        """Embedding yields one sentence and one numpy vector per input text."""
        texts = ["sentence1", "sentence2"]
        vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

        # Stand-in embedder whose embed_store_many resolves to canned vectors.
        fake_embedder = AsyncMock()
        fake_embedder.embed_store_many = AsyncMock(
            return_value=[
                {"text": text, "vector": vector}
                for text, vector in zip(texts, vectors)
            ]
        )
        populated_api.embedder = fake_embedder
        populated_api.cache_embeddings = True

        # Sentences awaiting embedding.
        populated_api.sentences_vector_data = [{"text": text} for text in texts]

        await populated_api.embed_sentences()

        for attribute in ("all_sentences", "embeddings"):
            assert hasattr(populated_api, attribute)
        assert len(populated_api.all_sentences) == len(texts)
        assert len(populated_api.embeddings) == len(texts)
        assert isinstance(populated_api.embeddings[0], np.ndarray)
        fake_embedder.embed_store_many.assert_called_once()
class TestDetectRecordGroups(TestMatchEntityRecords):
    """Tests for ``MatchEntityRecords.detect_record_groups``."""

    def test_detect_record_groups(self, populated_api) -> None:
        """Test detecting record groups on a reproducible synthetic dataset.

        Embeddings come from a seeded generator, so the input (and therefore
        the groups detected) is identical on every run.
        """
        # Enough embeddings for the default n_neighbors (50); 60 leaves headroom.
        n_records = 60
        rng = np.random.default_rng(seed=0)
        populated_api.embeddings = [rng.random(3) for _ in range(n_records)]
        populated_api.all_sentences = [f"sent{i}" for i in range(n_records)]

        # model_df gets one row per embedding.
        populated_api.model_df = pl.DataFrame(
            {
                "Entity ID": [str(i) for i in range(n_records)],
                "Entity name": [f"Entity{i}" for i in range(n_records)],
                "Dataset": ["test"] * n_records,
                "Unique ID": [f"{i}::test" for i in range(n_records)],
            }
        )

        result = populated_api.detect_record_groups(
            pair_embedding_threshold=80, pair_jaccard_threshold=50
        )

        assert isinstance(result, pl.DataFrame)
        assert hasattr(populated_api, "matches_df")
class TestEvaluateGroups(TestMatchEntityRecords):
    """Tests for ``MatchEntityRecords.evaluate_groups`` with the OpenAI client mocked out."""

    # Dotted path patched so that no real OpenAI client is constructed.
    CLIENT_PATH = "intelligence_toolkit.match_entity_records.api.OpenAIClient"

    @staticmethod
    def _fake_client(reply: str) -> MagicMock:
        """Return a client double whose ``generate_chat_async`` resolves to *reply*."""
        client = MagicMock()
        client.generate_chat_async = AsyncMock(return_value=reply)
        return client

    @pytest.mark.asyncio()
    async def test_evaluate_groups_basic(self, populated_api) -> None:
        """Evaluating groups returns a string and leaves a dataframe in ``evaluations_df``."""
        populated_api.model_df = pl.DataFrame(
            {
                "Entity ID": ["1", "2"],
                "Entity name": ["A", "B"],
                "Dataset": ["test"] * 2,
                "Name similarity": [0.9, 0.8],
                "Attribute 1": ["val1", "val2"],
            }
        )
        populated_api.ai_configuration = MagicMock()
        reply = "1,9,Very similar entities\n2,7,Somewhat related"

        with patch(self.CLIENT_PATH) as client_cls:
            client_cls.return_value = self._fake_client(reply)
            result = await populated_api.evaluate_groups()

            assert isinstance(result, str)
            assert hasattr(populated_api, "evaluations_df")
            assert isinstance(populated_api.evaluations_df, pl.DataFrame)

    @pytest.mark.asyncio()
    async def test_evaluate_groups_with_callbacks(self, populated_api) -> None:
        """Evaluating groups with LLM callbacks supplied still returns a string."""
        populated_api.model_df = pl.DataFrame(
            {
                "Entity ID": ["1"],
                "Entity name": ["A"],
                "Dataset": ["test"],
                "Name similarity": [0.9],
            }
        )
        populated_api.ai_configuration = MagicMock()
        reply = "1,9,Similar"
        callbacks = [MagicMock(spec=LLMCallback)]

        with patch(self.CLIENT_PATH) as client_cls:
            client_cls.return_value = self._fake_client(reply)
            result = await populated_api.evaluate_groups(callbacks=callbacks)
            assert isinstance(result, str)
class TestClearModelDfs(TestMatchEntityRecords):
    """Tests for ``MatchEntityRecords.clear_model_dfs``."""

    @staticmethod
    def _stored_count(api) -> int:
        """Return how many datasets *api* currently holds in ``model_dfs``."""
        return len(api.model_dfs)

    def test_clear_model_dfs_empty(self, api_instance) -> None:
        """Clearing an instance with no datasets leaves it empty."""
        api_instance.clear_model_dfs()
        assert self._stored_count(api_instance) == 0

    def test_clear_model_dfs_populated(self, populated_api) -> None:
        """Clearing an instance that holds datasets removes all of them."""
        assert self._stored_count(populated_api) > 0
        populated_api.clear_model_dfs()
        assert self._stored_count(populated_api) == 0