Coverage for intelligence_toolkit/tests/unit/match_entity_records/test_api.py: 100%

154 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4 

5import io 

6from pathlib import Path 

7from unittest.mock import AsyncMock, MagicMock, patch 

8 

9import numpy as np 

10import polars as pl 

11import pytest 

12 

13from intelligence_toolkit.AI.classes import LLMCallback 

14from intelligence_toolkit.match_entity_records.api import MatchEntityRecords 

15from intelligence_toolkit.match_entity_records.classes import ( 

16 AttributeToMatch, 

17 RecordsModel, 

18) 

19 

20 

class TestMatchEntityRecords:
    """Shared pytest fixtures for the MatchEntityRecords API tests.

    This class defines fixtures only; it contains no test methods itself.
    """

    @pytest.fixture()
    def api_instance(self) -> MatchEntityRecords:
        """Provide a freshly constructed MatchEntityRecords object."""
        instance = MatchEntityRecords()
        return instance

    @pytest.fixture()
    def sample_dataframe(self) -> pl.DataFrame:
        """Provide a three-row frame with id, name and two attribute columns."""
        column_data = {
            "id": [1, 2, 3],
            "name": ["Entity A", "Entity B", "Entity C"],
            "attribute1": ["value1", "value2", "value3"],
            "attribute2": ["val1", "val2", "val3"],
        }
        return pl.DataFrame(column_data)

    @pytest.fixture()
    def sample_model(self, sample_dataframe: pl.DataFrame) -> RecordsModel:
        """Wrap the sample frame in a RecordsModel named "test_dataset"."""
        matched_columns = ["attribute1", "attribute2"]
        return RecordsModel(
            dataframe=sample_dataframe,
            dataframe_name="test_dataset",
            id_column="id",
            name_column="name",
            columns=matched_columns,
        )

    @pytest.fixture()
    def populated_api(
        self, api_instance: MatchEntityRecords, sample_model: RecordsModel
    ) -> MatchEntityRecords:
        """Provide an API object that already has the sample model added."""
        # The return value of add_df_to_model is not needed here; only the
        # state it leaves on the API object matters to the dependent tests.
        api_instance.add_df_to_model(sample_model)
        return api_instance

55 

56 

class TestTotalRecords(TestMatchEntityRecords):
    """Tests for the ``total_records`` property."""

    def test_total_records_empty(self, api_instance) -> None:
        """An API object with nothing added reports zero records."""
        total = api_instance.total_records
        assert total == 0

    def test_total_records_single_dataset(self, populated_api) -> None:
        """One three-row dataset yields a total of three records."""
        total = populated_api.total_records
        assert total == 3

    def test_total_records_multiple_datasets(
        self, api_instance, sample_dataframe
    ) -> None:
        """Totals are summed across every dataset that has been added."""
        # Start from a clean slate before registering the two datasets.
        api_instance.clear_model_dfs()

        # A full three-row frame plus a two-row slice of the same frame.
        frames_by_name = {
            "dataset1": sample_dataframe,
            "dataset2": sample_dataframe.head(2),
        }
        models = [
            RecordsModel(
                dataframe=frame,
                dataframe_name=dataset_name,
                id_column="id",
                name_column="name",
                columns=["attribute1"],
            )
            for dataset_name, frame in frames_by_name.items()
        ]
        for model in models:
            api_instance.add_df_to_model(model)

        assert api_instance.total_records == 5  # 3 + 2

88 

89 

class TestAttributeOptions(TestMatchEntityRecords):
    """Tests for the ``attribute_options`` property."""

    def test_attribute_options_empty(self, api_instance) -> None:
        """With no datasets, the options are an empty list."""
        # Drop anything that may already be registered on the instance.
        api_instance.clear_model_dfs()
        available = api_instance.attribute_options
        assert isinstance(available, list)
        assert len(available) == 0

    def test_attribute_options_populated(self, populated_api) -> None:
        """With a dataset added, options exist and each contains "::"."""
        available = populated_api.attribute_options
        assert isinstance(available, list)
        assert len(available) > 0
        # Each option is expected to look like "column::dataset".
        assert all("::" in entry for entry in available)

107 

108 

class TestIntegratedResults(TestMatchEntityRecords):
    """Tests for the ``integrated_results`` property."""

    def test_integrated_results_empty(self, api_instance) -> None:
        """Empty, typed input frames produce an empty result frame."""
        # Typed-but-empty frames are supplied explicitly; the original test
        # notes this avoids a panic when the property is evaluated.
        evaluations_schema = {
            "Group ID": pl.Int64,
            "Relatedness": pl.Float64,
            "Explanation": pl.Utf8,
        }
        matches_schema = {
            "Group ID": pl.Int64,
            "Entity name": pl.Utf8,
            "Dataset": pl.Utf8,
        }
        api_instance.evaluations_df = pl.DataFrame(schema=evaluations_schema)
        api_instance.matches_df = pl.DataFrame(schema=matches_schema)

        combined = api_instance.integrated_results

        assert isinstance(combined, pl.DataFrame)
        assert combined.is_empty()

    def test_integrated_results_with_data(self, api_instance) -> None:
        """Matches and evaluations combine into a non-empty frame."""
        # Two matched groups ...
        match_rows = {
            "Group ID": [1, 2],
            "Entity name": ["A", "B"],
            "Dataset": ["test", "test"],
        }
        api_instance.matches_df = pl.DataFrame(match_rows)
        # ... and one evaluation per group.
        evaluation_rows = {
            "Group ID": [1, 2],
            "Relatedness": [8, 6],
            "Explanation": ["Similar", "Somewhat similar"],
        }
        api_instance.evaluations_df = pl.DataFrame(evaluation_rows)

        combined = api_instance.integrated_results

        assert isinstance(combined, pl.DataFrame)
        assert not combined.is_empty()
        assert "Group ID" in combined.columns
        assert "Relatedness" in combined.columns

146 

147 

class TestAddDfToModel(TestMatchEntityRecords):
    """Tests for ``add_df_to_model``."""

    def test_add_df_to_model_basic(self, api_instance, sample_model) -> None:
        """Adding a model returns a non-empty frame and registers its name."""
        added = api_instance.add_df_to_model(sample_model)
        assert isinstance(added, pl.DataFrame)
        assert not added.is_empty()
        assert "test_dataset" in api_instance.model_dfs

    def test_add_df_to_model_auto_name(
        self, api_instance, sample_dataframe
    ) -> None:
        """A model with an empty dataframe_name is still registered once."""
        # Start from a clean slate so the final count is meaningful.
        api_instance.clear_model_dfs()

        unnamed_model = RecordsModel(
            dataframe=sample_dataframe,
            dataframe_name="",
            id_column="id",
            name_column="name",
            columns=["attribute1"],
        )
        added = api_instance.add_df_to_model(unnamed_model)

        assert isinstance(added, pl.DataFrame)
        assert len(api_instance.model_dfs) == 1

    def test_add_df_to_model_with_max_rows(
        self, api_instance, sample_model
    ) -> None:
        """The returned frame has no more rows than max_rows_to_process."""
        api_instance.max_rows_to_process = 2
        added = api_instance.add_df_to_model(sample_model)
        row_count = len(added)
        assert row_count <= 2

177 

178 

class TestBuildModelDf(TestMatchEntityRecords):
    """Tests for ``build_model_df``."""

    @pytest.fixture()
    def attributes_list(self) -> list[AttributeToMatch]:
        """Provide two labelled attributes, each mapped to one dataset column."""
        first_attribute = {
            "label": "Attribute 1",
            "columns": ["attribute1::test_dataset"],
        }
        second_attribute = {
            "label": "Attribute 2",
            "columns": ["attribute2::test_dataset"],
        }
        return [first_attribute, second_attribute]

    def test_build_model_df_basic(self, populated_api, attributes_list) -> None:
        """Building returns a non-empty frame that has a "Dataset" column."""
        built = populated_api.build_model_df(attributes_list)
        assert isinstance(built, pl.DataFrame)
        assert not built.is_empty()
        # The API object should also expose the frame as an attribute.
        assert hasattr(populated_api, "model_df")
        assert "Dataset" in built.columns

    def test_build_model_df_creates_sentences(
        self, populated_api, attributes_list
    ) -> None:
        """Building also leaves a list in ``sentences_vector_data``."""
        populated_api.build_model_df(attributes_list)
        assert hasattr(populated_api, "sentences_vector_data")
        sentence_data = populated_api.sentences_vector_data
        assert isinstance(sentence_data, list)

204 

205 

class TestEmbedSentences(TestMatchEntityRecords):
    """Tests for ``embed_sentences``."""

    @pytest.mark.asyncio()
    async def test_embed_sentences(self, populated_api) -> None:
        """Embedding two sentences stores two sentences and two vectors."""
        texts = ["sentence1", "sentence2"]
        vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

        # Stand-in embedder whose embed_store_many returns canned records.
        embedded_records = [
            {"text": text, "vector": vector}
            for text, vector in zip(texts, vectors)
        ]
        fake_embedder = AsyncMock()
        fake_embedder.embed_store_many = AsyncMock(return_value=embedded_records)
        populated_api.embedder = fake_embedder
        populated_api.cache_embeddings = True

        # Input sentences the API object is asked to embed.
        populated_api.sentences_vector_data = [{"text": text} for text in texts]

        await populated_api.embed_sentences()

        assert hasattr(populated_api, "all_sentences")
        assert hasattr(populated_api, "embeddings")
        assert len(populated_api.all_sentences) == 2
        assert len(populated_api.embeddings) == 2
        first_embedding = populated_api.embeddings[0]
        assert isinstance(first_embedding, np.ndarray)
        fake_embedder.embed_store_many.assert_called_once()

235 

236 

class TestDetectRecordGroups(TestMatchEntityRecords):
    """Tests for ``detect_record_groups``."""

    def test_detect_record_groups(self, populated_api) -> None:
        """Group detection returns a frame and sets ``matches_df``.

        The API object is primed with synthetic embeddings, sentences and a
        model frame of equal length before ``detect_record_groups`` is called.
        """
        # The original test notes that the default n_neighbors is 50, so 60
        # records are used to leave some headroom.
        record_count = 60

        # A seeded generator keeps the synthetic embeddings identical on every
        # run, so any failure of this test is reproducible.
        rng = np.random.default_rng(0)
        populated_api.embeddings = [rng.random(3) for _ in range(record_count)]
        populated_api.all_sentences = [f"sent{i}" for i in range(record_count)]

        # Model frame with one row per embedding.
        populated_api.model_df = pl.DataFrame(
            {
                "Entity ID": [str(i) for i in range(record_count)],
                "Entity name": [f"Entity{i}" for i in range(record_count)],
                "Dataset": ["test"] * record_count,
                "Unique ID": [f"{i}::test" for i in range(record_count)],
            }
        )

        result = populated_api.detect_record_groups(
            pair_embedding_threshold=80, pair_jaccard_threshold=50
        )

        assert isinstance(result, pl.DataFrame)
        assert hasattr(populated_api, "matches_df")

267 

268 

class TestEvaluateGroups(TestMatchEntityRecords):
    """Tests for ``evaluate_groups`` with the OpenAI client patched out."""

    @pytest.mark.asyncio()
    async def test_evaluate_groups_basic(self, populated_api) -> None:
        """Evaluation returns a string and stores an evaluations frame."""
        # Minimal two-row model frame for the evaluation to work on.
        model_rows = {
            "Entity ID": ["1", "2"],
            "Entity name": ["A", "B"],
            "Dataset": ["test", "test"],
            "Name similarity": [0.9, 0.8],
            "Attribute 1": ["val1", "val2"],
        }
        populated_api.model_df = pl.DataFrame(model_rows)

        # Fake configuration plus a canned reply from the patched client.
        populated_api.ai_configuration = MagicMock()
        canned_reply = "1,9,Very similar entities\n2,7,Somewhat related"

        client_path = "intelligence_toolkit.match_entity_records.api.OpenAIClient"
        with patch(client_path) as client_cls:
            client_cls.return_value = MagicMock(
                generate_chat_async=AsyncMock(return_value=canned_reply)
            )
            outcome = await populated_api.evaluate_groups()

        assert isinstance(outcome, str)
        assert hasattr(populated_api, "evaluations_df")
        assert isinstance(populated_api.evaluations_df, pl.DataFrame)

    @pytest.mark.asyncio()
    async def test_evaluate_groups_with_callbacks(self, populated_api) -> None:
        """Evaluation still returns a string when callbacks are supplied."""
        model_rows = {
            "Entity ID": ["1"],
            "Entity name": ["A"],
            "Dataset": ["test"],
            "Name similarity": [0.9],
        }
        populated_api.model_df = pl.DataFrame(model_rows)

        populated_api.ai_configuration = MagicMock()
        canned_reply = "1,9,Similar"
        # One mock callback restricted to the LLMCallback interface.
        llm_callbacks = [MagicMock(spec=LLMCallback)]

        client_path = "intelligence_toolkit.match_entity_records.api.OpenAIClient"
        with patch(client_path) as client_cls:
            client_cls.return_value = MagicMock(
                generate_chat_async=AsyncMock(return_value=canned_reply)
            )
            outcome = await populated_api.evaluate_groups(callbacks=llm_callbacks)

        assert isinstance(outcome, str)

326 

327 

class TestClearModelDfs(TestMatchEntityRecords):
    """Tests for ``clear_model_dfs``."""

    def test_clear_model_dfs_empty(self, api_instance) -> None:
        """Clearing an API object with nothing added leaves model_dfs empty."""
        api_instance.clear_model_dfs()
        remaining = len(api_instance.model_dfs)
        assert remaining == 0

    def test_clear_model_dfs_populated(self, populated_api) -> None:
        """Clearing removes every dataset that had been added."""
        # Precondition: the fixture has registered at least one dataset.
        before = len(populated_api.model_dfs)
        assert before > 0

        populated_api.clear_model_dfs()

        after = len(populated_api.model_dfs)
        assert after == 0

338 assert len(populated_api.model_dfs) == 0