Coverage for intelligence_toolkit/tests/unit/detect_entity_networks/test_prepare_model.py: 100%

239 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4 

5from typing import Literal 

6from unittest.mock import patch 

7 

8import polars as pl 

9import pytest 

10 

11from intelligence_toolkit.detect_entity_networks.classes import FlagAggregatorType 

12from intelligence_toolkit.detect_entity_networks.prepare_model import ( 

13 build_flag_links, 

14 build_flags, 

15 build_groups, 

16 build_main_graph, 

17 clean_text, 

18 format_data_columns, 

19 generate_attribute_links, 

20 transform_entity, 

21) 

22 

23 

class TestCleanText:
    """Unit tests for ``clean_text``."""

    def test_remove_punctuation(self) -> None:
        """The comma and exclamation mark are dropped from the text."""
        assert clean_text("Hello, world!") == "Hello world"

    def test_remove_special_characters(self) -> None:
        """Same expectation as above for characters treated as special."""
        # NOTE(review): this input is identical to test_remove_punctuation;
        # consider a sample with other disallowed symbols once the exact
        # character set stripped by clean_text is confirmed.
        assert clean_text("Hello, world!") == "Hello world"

    def test_reduce_multiple_spaces_to_single(self) -> None:
        """A run of consecutive spaces is expected to collapse to one."""
        # The input must contain several spaces in a row; with a single
        # space the assertion passes without exercising any collapsing.
        assert clean_text("Hello     world") == "Hello world"

    def test_allow_special_characters(self) -> None:
        """``@``, ``&`` and ``+`` are kept while the dot is removed."""
        assert (
            clean_text("Email me@home.com & bring snacks+")
            == "Email me@homecom & bring snacks+"
        )

    def test_combined_scenarios(self) -> None:
        """Punctuation removal and allowed symbols work in one string."""
        assert (
            clean_text("Hello, world! Email me@home.com & bring snacks+")
            == "Hello world Email me@homecom & bring snacks+"
        )

45 

46 

class TestFormatDataColumns:
    """Unit tests for ``format_data_columns``."""

    def test_multiple_columns(self) -> None:
        """Entity ids lose surrounding spaces; emails lose their dot."""
        initial_df = pl.DataFrame(
            {
                "entity_id": ["123 ", " 456"],
                "name": ["John Doe", "Jane Doe"],
                "email": ["john@doe.com", "jane@doe.com"],
            }
        )
        expected_df = pl.DataFrame(
            {
                "entity_id": ["123", "456"],
                "name": ["John Doe", "Jane Doe"],
                "email": ["john@doecom", "jane@doecom"],
            }
        )
        columns_to_link = ["name", "email"]
        entity_id_column = "entity_id"

        result_df = format_data_columns(initial_df, columns_to_link, entity_id_column)

        assert result_df.equals(expected_df)

    @patch("re.sub")
    def test_empty_dataframe(self, mock_re_sub) -> None:
        """With no rows, ``re.sub`` is never called and the frame is unchanged."""
        # Setup
        # The stub mirrors the positional call shape re.sub(pattern, repl,
        # string) and echoes the string back, so an unexpected call would
        # surface through the call_count assertion rather than a TypeError.
        mock_re_sub.side_effect = lambda _pattern, _repl, text: text
        initial_df = pl.DataFrame({"entity_id": [], "name": [], "email": []})
        columns_to_link = ["name", "email"]
        entity_id_column = "entity_id"

        # Exercise
        result_df = format_data_columns(initial_df, columns_to_link, entity_id_column)

        assert mock_re_sub.call_count == 0
        assert result_df.equals(initial_df)

    @patch("re.sub")
    def test_special_characters_in_entity_id(self, mock_re_sub) -> None:
        """Entity id values take whatever the patched ``re.sub`` returns."""
        # Setup
        mock_re_sub.side_effect = lambda _pattern, _repl, _text: "cleaned"
        initial_df = pl.DataFrame(
            {
                "entity_id": ["@123!", "#456$"],
                "name": ["John Doe", "Jane Doe"],
            }
        )
        columns_to_link = ["name"]
        entity_id_column = "entity_id"

        result_df = format_data_columns(initial_df, columns_to_link, entity_id_column)

        assert mock_re_sub.call_count == 8  # 4 for entity_id + 4 for name
        for val in result_df[entity_id_column]:
            assert val == "cleaned"

102 

103 

class TestPrepareEntityAttribute:
    """Unit tests for ``generate_attribute_links``."""

    @pytest.fixture()
    def data(self) -> pl.DataFrame:
        """Three entities with two attribute columns each."""
        columns = {
            "entity_id": [1, 2, 3],
            "attribute1": ["A", "B", "A"],
            "attribute2": ["X", "Y", "X"],
        }
        return pl.DataFrame(columns)

    def test_column_name(self, data) -> None:
        """Linking two columns produces a result of length two."""
        links = generate_attribute_links(
            data,
            "entity_id",
            ["attribute1", "attribute2"],
        )
        assert len(links) == 2

124 

125 

class TestBuildUndirectedGraph:
    """Unit tests for ``build_main_graph``."""

    def test_graph_empty(self) -> None:
        """Calling without arguments gives a graph whose ``size()`` is 0."""
        graph = build_main_graph()
        assert graph.size() == 0

    def test_attribute_links(self) -> None:
        """Each link triple shows up as two prefixed nodes joined by an edge."""
        links_by_column = [
            [("Entity1", "attribute", "Value1"), ("Entity2", "attribute", "Value2")],
            [("Entity3", "relation", "Entity4"), ("Entity5", "relation", "Entity6")],
            [("Entity7", "attribute", "Value3")],
        ]
        graph = build_main_graph(links_by_column)

        # (entity node, value node) pairs expected in the graph.
        wanted_edges = [
            ("ENTITY==Entity1", "attribute==Value1"),
            ("ENTITY==Entity2", "attribute==Value2"),
            ("ENTITY==Entity3", "relation==Entity4"),
            ("ENTITY==Entity5", "relation==Entity6"),
            ("ENTITY==Entity7", "attribute==Value3"),
        ]
        # Every endpoint of an expected edge must also exist as a node.
        wanted_nodes = [endpoint for pair in wanted_edges for endpoint in pair]

        for name in wanted_nodes:
            assert graph.has_node(name)
        for source, target in wanted_edges:
            assert graph.has_edge(source, target)

162 

163 

class TestBuildFlagLinks:
    """Unit tests for ``build_flag_links``."""

    @pytest.fixture()
    def df_flag(self):
        """Five entities, each with one numeric flag value."""
        columns = {
            "Entity_N": ["A", "C", "D", "F", "Z"],
            "flags_numb": [1, 2, 3, 0, 3],
        }
        return pl.DataFrame(columns)

    def test_prepare_count(self, df_flag):
        """Count aggregation repeats the column name and carries the value."""
        links = build_flag_links(
            df_flag, "Entity_N", FlagAggregatorType.Count, ["flags_numb"]
        )

        per_entity = [("A", 1), ("C", 2), ("D", 3), ("F", 0), ("Z", 3)]
        wanted = [
            [entity, "flags_numb", "flags_numb", value]
            for entity, value in per_entity
        ]

        assert sorted(links) == sorted(wanted)

    def test_prepare_value_column_doesnt_exist(self, df_flag):
        """An unknown flag column raises ``ValueError``."""
        with pytest.raises(
            ValueError, match="Column flags_numb123 not found in the DataFrame."
        ):
            build_flag_links(
                df_flag, "Entity_N", FlagAggregatorType.Count, ["flags_numb123"]
            )

    def test_prepare_entity_column_doesnt_exist(self, df_flag):
        """An unknown entity column raises ``ValueError``."""
        with pytest.raises(
            ValueError, match="Column Entity_N12 not found in the DataFrame."
        ):
            build_flag_links(
                df_flag, "Entity_N12", FlagAggregatorType.Count, ["flags_numb"]
            )

    def test_prepare_count_existing(self, df_flag):
        """Links passed in as existing flags are kept alongside the new ones."""
        prior_links = [["E", "flags_numb1", "flags_numb1", 2]]
        links = build_flag_links(
            df_flag,
            "Entity_N",
            FlagAggregatorType.Count,
            ["flags_numb"],
            prior_links,
        )

        per_entity = [("A", 1), ("C", 2), ("D", 3), ("F", 0), ("Z", 3)]
        wanted = [
            [entity, "flags_numb", "flags_numb", value]
            for entity, value in per_entity
        ]
        wanted.append(["E", "flags_numb1", "flags_numb1", 2])

        assert sorted(links) == sorted(wanted)

    def test_prepare_instance(self, df_flag):
        """Instance aggregation puts the value third and a count of 1 last."""
        links = build_flag_links(
            df_flag, "Entity_N", FlagAggregatorType.Instance, ["flags_numb"]
        )

        per_entity = [("A", 1), ("C", 2), ("D", 3), ("F", 0), ("Z", 3)]
        wanted = [[entity, "flags_numb", value, 1] for entity, value in per_entity]

        assert sorted(links) == sorted(wanted)

    def test_prepare_instance_agg(self, df_flag):
        """A second row for entity A is merged into one link with value 3."""
        # add one row to the dataframe
        extra_row = pl.DataFrame({"Entity_N": ["A"], "flags_numb": [2]})
        extended = pl.concat([df_flag, extra_row])

        links = build_flag_links(
            extended, "Entity_N", FlagAggregatorType.Instance, ["flags_numb"]
        )

        per_entity = [("A", 3), ("C", 2), ("D", 3), ("F", 0), ("Z", 3)]
        wanted = [[entity, "flags_numb", value, 1] for entity, value in per_entity]

        assert sorted(links) == sorted(wanted)

266 

267 

class TestBuildFlags:
    """Unit tests for ``build_flags``."""

    @staticmethod
    def _expected_frame(flag_values: list, counts: list) -> pl.DataFrame:
        """Build the expected frame for entities A, C, D, F, Z."""
        entities = ["A", "C", "D", "F", "Z"]
        return pl.DataFrame(
            {
                "entity": entities,
                "type": ["flags_numb"] * len(entities),
                "flag": flag_values,
                "count": counts,
                "qualified_entity": [f"ENTITY=={name}" for name in entities],
            }
        )

    @pytest.fixture()
    def link_list_integrated(self) -> list[list]:
        """Links whose third item is a numeric flag and whose count is 1."""
        values = [("A", 3), ("C", 2), ("D", 3), ("F", 0), ("Z", 3)]
        return [[entity, "flags_numb", value, 1] for entity, value in values]

    @pytest.fixture()
    def link_list_count(self) -> list[list]:
        """Links whose third item is the column name and whose last is a count."""
        values = [("A", 3), ("C", 2), ("D", 3), ("F", 0), ("Z", 3)]
        return [
            [entity, "flags_numb", "flags_numb", value] for entity, value in values
        ]

    def test_flags_list_empty(self) -> None:
        """No input gives an empty frame and zero for both statistics."""
        flags, max_entity_flags, mean_entity_flags = build_flags()

        assert flags.equals(pl.DataFrame())
        assert max_entity_flags == 0
        assert mean_entity_flags == 0

    def test_flags_integrated(self, link_list_integrated) -> None:
        """Instance-style links become one row per entity."""
        flags, _, _ = build_flags(link_list_integrated)

        expected = self._expected_frame([3, 2, 3, 0, 3], [1, 1, 1, 1, 1])

        actual_sorted = flags.sort(by=["entity"])
        expected_sorted = expected.sort(by=["entity"])

        assert actual_sorted.equals(expected_sorted)

    def test_max_entity_flags_integrated(self, link_list_integrated) -> None:
        """Maximum flags per entity is 1 for instance-style links."""
        _, max_entity_flags, _ = build_flags(link_list_integrated)

        assert max_entity_flags == 1

    def test_mean_entity_flags_integrated(self, link_list_integrated) -> None:
        """Mean flags per entity is 1 for instance-style links."""
        _, _, mean_entity_flags = build_flags(link_list_integrated)

        assert mean_entity_flags == 1

    def test_flags_count(self, link_list_count):
        """Count-style links keep their counts in the ``count`` column."""
        flags, _, _ = build_flags(link_list_count)

        expected = self._expected_frame(["flags_numb"] * 5, [3, 2, 3, 0, 3])

        assert flags.sort(by=["entity"]).equals(expected)

    def test_flags_count_sum(self, link_list_count) -> None:
        """A second link for entity A is summed into its count (3 + 5)."""
        links = [*link_list_count, ["A", "flags_numb", "flags_numb", 5]]
        flags, _, _ = build_flags(links)

        expected = self._expected_frame(["flags_numb"] * 5, [8, 2, 3, 0, 3])

        assert flags.sort(by=["entity"]).equals(expected)

    def test_max_entity_flags_count_sum(self, link_list_count) -> None:
        """The maximum reflects the summed count for entity A."""
        links = [*link_list_count, ["A", "flags_numb", "flags_numb", 5]]
        _, max_entity_flags, _ = build_flags(links)

        assert max_entity_flags == 8

    def test_mean_entity_flags_count_sum(self, link_list_count) -> None:
        """The mean is 4.0 once entity A's extra link is included."""
        links = [*link_list_count, ["A", "flags_numb", "flags_numb", 5]]
        _, _, mean_entity_flags = build_flags(links)

        assert mean_entity_flags == 4.0

    def test_max_entity_flags_count(self, link_list_count) -> None:
        """The maximum is 3 for the unmodified count links."""
        _, max_entity_flags, _ = build_flags(link_list_count)

        assert max_entity_flags == 3

    def test_mean_entity_flags_count(self, link_list_count) -> None:
        """The mean is 2.75 for the unmodified count links."""
        _, _, mean_entity_flags = build_flags(link_list_count)

        assert mean_entity_flags == 2.75

444 

445 

class TestTransformEntity:
    """Unit tests for ``transform_entity``."""

    def test_transform_entity_basic(self) -> None:
        """A plain id string is prefixed with ``ENTITY==``."""
        assert transform_entity("12345") == "ENTITY==12345"

    def test_transform_entity_empty_string(self) -> None:
        """An empty string yields just the prefix."""
        assert transform_entity("") == "ENTITY=="

    def test_transform_entity_special_characters(self) -> None:
        """Symbols are passed through unchanged after the prefix."""
        assert transform_entity("@$%^&*()") == "ENTITY==@$%^&*()"

    def test_transform_entity_numeric(self) -> None:
        """A longer digit string is prefixed the same way."""
        assert transform_entity("9876543210") == "ENTITY==9876543210"

    def test_transform_entity_whitespace(self) -> None:
        """Whitespace input is preserved after the prefix."""
        assert transform_entity(" ") == "ENTITY== "

    def test_transform_entity_none(self) -> None:
        """``None`` is rendered as the text ``None`` after the prefix."""
        assert transform_entity(None) == "ENTITY==None"

476 

477 

class TestBuildGroups:
    """Unit tests for ``build_groups``."""

    @staticmethod
    def _links_for(column: str, values: list[str]) -> list[list[str]]:
        """Return ``[entity, column, value]`` rows for entities A to D."""
        return [
            [entity, column, value]
            for entity, value in zip(["A", "B", "C", "D"], values)
        ]

    @pytest.fixture()
    def df_groups(self):
        """Four entities with two grouping attribute columns."""
        columns = {
            "entity_id": ["A", "B", "C", "D"],
            "attribute1": ["X", "Y", "Z", "X"],
            "attribute2": ["X", "Y", "Y", "X"],
        }
        return pl.DataFrame(columns)

    @pytest.fixture()
    def entity_col(self) -> Literal["entity_id"]:
        """Name of the entity id column in ``df_groups``."""
        return "entity_id"

    def test_build_groups(self, df_groups, entity_col) -> None:
        """One value column gives one list of links."""
        produced = build_groups(["attribute1"], df_groups, entity_col)

        wanted = [self._links_for("attribute1", ["X", "Y", "Z", "X"])]

        assert sorted(produced) == wanted

    def test_build_groups_existing_groups(self, df_groups, entity_col) -> None:
        """Existing links stay first and the new list is appended after them."""
        prior_links = [
            ["Z", "attribute2", "X"],
        ]
        produced = build_groups(["attribute1"], df_groups, entity_col, prior_links)

        wanted = [
            ["Z", "attribute2", "X"],
            self._links_for("attribute1", ["X", "Y", "Z", "X"]),
        ]

        assert produced == wanted

    def test_build_groups_two_columns(self, df_groups, entity_col) -> None:
        """Two value columns give two lists, in column order."""
        produced = build_groups(["attribute1", "attribute2"], df_groups, entity_col)

        wanted_first = self._links_for("attribute1", ["X", "Y", "Z", "X"])
        wanted_second = self._links_for("attribute2", ["X", "Y", "Y", "X"])

        assert sorted(produced[0]) == wanted_first
        assert sorted(produced[1]) == wanted_second

    def test_build_groups_column_empty(self, df_groups, entity_col) -> None:
        """No value columns gives an empty result."""
        produced = build_groups([], df_groups, entity_col)

        assert produced == []

    def test_build_groups_df_empty(self, entity_col) -> None:
        """An empty DataFrame gives an empty result."""
        produced = build_groups(["attribute1"], pl.DataFrame(), entity_col)

        assert produced == []

    def test_build_groups_column_doesnt_exists(self, entity_col, df_groups) -> None:
        """An unknown value column raises ``ValueError``."""
        expected_message = "Column attribute12 not found in the DataFrame."
        with pytest.raises(ValueError, match=expected_message):
            build_groups(["attribute12"], df_groups, entity_col)

    def test_build_groups_entity_doesnt_exist(self, df_groups) -> None:
        """An unknown entity column raises ``ValueError``."""
        expected_message = "Column entity123 not found in the DataFrame."
        with pytest.raises(ValueError, match=expected_message):
            build_groups(["attribute1"], df_groups, "entity123")

571 

572 def test_build_groups_entity_doesnt_exist(self, df_groups) -> None: 

573 entity_col = "entity123" 

574 value_cols = ["attribute1"] 

575 with pytest.raises( 

576 ValueError, 

577 match="Column entity123 not found in the DataFrame.", 

578 ): 

579 build_groups(value_cols, df_groups, entity_col)