Coverage for intelligence_toolkit/tests/unit/detect_entity_networks/test_prepare_model.py: 100%
239 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
5from typing import Literal
6from unittest.mock import patch
8import polars as pl
9import pytest
11from intelligence_toolkit.detect_entity_networks.classes import FlagAggregatorType
12from intelligence_toolkit.detect_entity_networks.prepare_model import (
13 build_flag_links,
14 build_flags,
15 build_groups,
16 build_main_graph,
17 clean_text,
18 format_data_columns,
19 generate_attribute_links,
20 transform_entity,
21)
class TestCleanText:
    """Unit tests for ``clean_text``."""

    def test_remove_punctuation(self) -> None:
        """A comma and an exclamation mark are dropped from the text."""
        assert clean_text("Hello, world!") == "Hello world"

    def test_remove_special_characters(self) -> None:
        """Checks the same input and expectation as the punctuation test."""
        # NOTE(review): this is currently identical to test_remove_punctuation;
        # consider an input with other special characters once the expected
        # output has been confirmed against clean_text.
        assert clean_text("Hello, world!") == "Hello world"

    def test_reduce_multiple_spaces_to_single(self) -> None:
        """A run of spaces between words collapses to a single space."""
        # The input has to contain repeated spaces; with a single space the
        # assertion would hold even if no whitespace collapsing took place.
        assert clean_text("Hello    world") == "Hello world"

    def test_allow_special_characters(self) -> None:
        """``@``, ``&`` and ``+`` are kept while the dot is removed."""
        assert (
            clean_text("Email me@home.com & bring snacks+")
            == "Email me@homecom & bring snacks+"
        )

    def test_combined_scenarios(self) -> None:
        """Punctuation removal and retained characters work in one string."""
        assert (
            clean_text("Hello, world! Email me@home.com & bring snacks+")
            == "Hello world Email me@homecom & bring snacks+"
        )
class TestFormatDataColumns:
    """Unit tests for ``format_data_columns``."""

    def test_multiple_columns(self) -> None:
        """Entity ids are trimmed and the dots vanish from the e-mail column."""
        link_columns = ["name", "email"]
        id_column = "entity_id"
        raw = pl.DataFrame(
            {
                "entity_id": ["123 ", " 456"],
                "name": ["John Doe", "Jane Doe"],
                "email": ["john@doe.com", "jane@doe.com"],
            }
        )
        wanted = pl.DataFrame(
            {
                "entity_id": ["123", "456"],
                "name": ["John Doe", "Jane Doe"],
                "email": ["john@doecom", "jane@doecom"],
            }
        )

        formatted = format_data_columns(raw, link_columns, id_column)

        assert formatted.equals(wanted)

    @patch("re.sub")
    def test_empty_dataframe(self, mock_clean_text) -> None:
        """An empty frame comes back equal to itself; ``re.sub`` is not called."""
        # The injected mock replaces ``re.sub`` for the duration of the test.
        mock_clean_text.side_effect = lambda x: x
        empty = pl.DataFrame({"entity_id": [], "name": [], "email": []})

        formatted = format_data_columns(empty, ["name", "email"], "entity_id")

        assert mock_clean_text.call_count == 0
        assert formatted.equals(empty)

    @patch("re.sub")
    def test_special_characters_in_entity_id(self, mock_clean_text) -> None:
        """With ``re.sub`` stubbed, every entity id ends up as the stub value."""
        # Every substitution returns the same marker so the ids are predictable.
        mock_clean_text.side_effect = lambda _x, _y, _z: "cleaned"
        id_column = "entity_id"
        raw = pl.DataFrame(
            {
                "entity_id": ["@123!", "#456$"],
                "name": ["John Doe", "Jane Doe"],
            }
        )

        formatted = format_data_columns(raw, ["name"], id_column)

        assert mock_clean_text.call_count == 8  # 4 for entity_id + 4 for name
        assert all(value == "cleaned" for value in formatted[id_column])
class TestPrepareEntityAttribute:
    """Unit tests for ``generate_attribute_links``."""

    @pytest.fixture()
    def data(self) -> pl.DataFrame:
        """Three entities described by two attribute columns."""
        columns = {
            "entity_id": [1, 2, 3],
            "attribute1": ["A", "B", "A"],
            "attribute2": ["X", "Y", "X"],
        }
        return pl.DataFrame(columns)

    def test_column_name(self, data) -> None:
        """Linking two columns yields a result of length two."""
        links = generate_attribute_links(
            data,
            "entity_id",
            ["attribute1", "attribute2"],
        )
        assert len(links) == 2
class TestBuildUndirectedGraph:
    """Unit tests for ``build_main_graph``."""

    def test_graph_empty(self) -> None:
        """Called without arguments, the graph reports a size of zero."""
        graph = build_main_graph()
        assert graph.size() == 0

    def test_attribute_links(self) -> None:
        """Each link produces its two qualified nodes and the edge between them."""
        links = [
            [("Entity1", "attribute", "Value1"), ("Entity2", "attribute", "Value2")],
            [("Entity3", "relation", "Entity4"), ("Entity5", "relation", "Entity6")],
            [("Entity7", "attribute", "Value3")],
        ]
        graph = build_main_graph(links)

        wanted_edges = [
            ("ENTITY==Entity1", "attribute==Value1"),
            ("ENTITY==Entity2", "attribute==Value2"),
            ("ENTITY==Entity3", "relation==Entity4"),
            ("ENTITY==Entity5", "relation==Entity6"),
            ("ENTITY==Entity7", "attribute==Value3"),
        ]
        # The expected nodes are exactly the endpoints of the expected edges.
        wanted_nodes = [endpoint for edge in wanted_edges for endpoint in edge]

        for node in wanted_nodes:
            assert graph.has_node(node)
        for source, target in wanted_edges:
            assert graph.has_edge(source, target)
class TestBuildFlagLinks:
    """Unit tests for ``build_flag_links``."""

    @pytest.fixture()
    def df_flag(self):
        """Five entities, each with one numeric flag value."""
        columns = {
            "Entity_N": ["A", "C", "D", "F", "Z"],
            "flags_numb": [1, 2, 3, 0, 3],
        }
        return pl.DataFrame(columns)

    def test_prepare_count(self, df_flag):
        """Count aggregation: the column name is the flag, the value the count."""
        links = build_flag_links(
            df_flag, "Entity_N", FlagAggregatorType.Count, ["flags_numb"]
        )

        wanted = [
            ["A", "flags_numb", "flags_numb", 1],
            ["C", "flags_numb", "flags_numb", 2],
            ["D", "flags_numb", "flags_numb", 3],
            ["F", "flags_numb", "flags_numb", 0],
            ["Z", "flags_numb", "flags_numb", 3],
        ]
        assert sorted(links) == sorted(wanted)

    def test_prepare_value_column_doesnt_exist(self, df_flag):
        """An unknown flag column raises ``ValueError``."""
        msg = "Column flags_numb123 not found in the DataFrame."
        with pytest.raises(ValueError, match=msg):
            build_flag_links(
                df_flag, "Entity_N", FlagAggregatorType.Count, ["flags_numb123"]
            )

    def test_prepare_entity_column_doesnt_exist(self, df_flag):
        """An unknown entity column raises ``ValueError``."""
        msg = "Column Entity_N12 not found in the DataFrame."
        with pytest.raises(ValueError, match=msg):
            build_flag_links(
                df_flag, "Entity_N12", FlagAggregatorType.Count, ["flags_numb"]
            )

    def test_prepare_count_existing(self, df_flag):
        """Links passed in as existing flags are part of the returned list."""
        prior_links = [["E", "flags_numb1", "flags_numb1", 2]]

        links = build_flag_links(
            df_flag,
            "Entity_N",
            FlagAggregatorType.Count,
            ["flags_numb"],
            prior_links,
        )

        wanted = [
            ["A", "flags_numb", "flags_numb", 1],
            ["C", "flags_numb", "flags_numb", 2],
            ["D", "flags_numb", "flags_numb", 3],
            ["F", "flags_numb", "flags_numb", 0],
            ["Z", "flags_numb", "flags_numb", 3],
            ["E", "flags_numb1", "flags_numb1", 2],
        ]
        assert sorted(links) == sorted(wanted)

    def test_prepare_instance(self, df_flag):
        """Instance aggregation: the cell value is the flag, with a count of 1."""
        links = build_flag_links(
            df_flag, "Entity_N", FlagAggregatorType.Instance, ["flags_numb"]
        )

        wanted = [
            ["A", "flags_numb", 1, 1],
            ["C", "flags_numb", 2, 1],
            ["D", "flags_numb", 3, 1],
            ["F", "flags_numb", 0, 1],
            ["Z", "flags_numb", 3, 1],
        ]
        assert sorted(links) == sorted(wanted)

    def test_prepare_instance_agg(self, df_flag):
        """A second row for entity ``A`` is folded into a single link."""
        # Append one more row for "A" so the entity appears twice in the input.
        extra_row = pl.DataFrame({"Entity_N": ["A"], "flags_numb": [2]})
        extended = pl.concat([df_flag, extra_row])

        links = build_flag_links(
            extended, "Entity_N", FlagAggregatorType.Instance, ["flags_numb"]
        )

        wanted = [
            ["A", "flags_numb", 3, 1],
            ["C", "flags_numb", 2, 1],
            ["D", "flags_numb", 3, 1],
            ["F", "flags_numb", 0, 1],
            ["Z", "flags_numb", 3, 1],
        ]
        assert sorted(links) == sorted(wanted)
class TestBuildFlags:
    """Unit tests for ``build_flags``."""

    @pytest.fixture()
    def link_list_integrated(self) -> list[list]:
        """Flag links whose third item is a number and whose count is 1."""
        return [
            ["A", "flags_numb", 3, 1],
            ["C", "flags_numb", 2, 1],
            ["D", "flags_numb", 3, 1],
            ["F", "flags_numb", 0, 1],
            ["Z", "flags_numb", 3, 1],
        ]

    @pytest.fixture()
    def link_list_count(self) -> list[list]:
        """Flag links whose third item is the column name, followed by a count."""
        return [
            ["A", "flags_numb", "flags_numb", 3],
            ["C", "flags_numb", "flags_numb", 2],
            ["D", "flags_numb", "flags_numb", 3],
            ["F", "flags_numb", "flags_numb", 0],
            ["Z", "flags_numb", "flags_numb", 3],
        ]

    def test_flags_list_empty(self) -> None:
        """Without links the frame is empty and both statistics are zero."""
        flags, max_entity_flags, mean_entity_flags = build_flags()

        assert flags.equals(pl.DataFrame())
        assert max_entity_flags == 0
        assert mean_entity_flags == 0

    def test_flags_integrated(self, link_list_integrated) -> None:
        """Links become rows that also carry a ``qualified_entity`` column."""
        flags, _, _ = build_flags(link_list_integrated)

        wanted = pl.DataFrame(
            {
                "entity": ["A", "C", "D", "F", "Z"],
                "type": ["flags_numb"] * 5,
                "flag": [3, 2, 3, 0, 3],
                "count": [1, 1, 1, 1, 1],
                "qualified_entity": [
                    "ENTITY==A",
                    "ENTITY==C",
                    "ENTITY==D",
                    "ENTITY==F",
                    "ENTITY==Z",
                ],
            }
        )

        # Compare on a common ordering so row order does not matter.
        actual_sorted = flags.sort(by=["entity"])
        wanted_sorted = wanted.sort(by=["entity"])
        assert actual_sorted.equals(wanted_sorted)

    def test_max_entity_flags_integrated(self, link_list_integrated) -> None:
        """The maximum is 1 when every link has a count of 1."""
        _, max_entity_flags, _ = build_flags(link_list_integrated)

        assert max_entity_flags == 1

    def test_mean_entity_flags_integrated(self, link_list_integrated) -> None:
        """The mean is 1 when every link has a count of 1."""
        _, _, mean_entity_flags = build_flags(link_list_integrated)

        assert mean_entity_flags == 1

    def test_flags_count(self, link_list_count):
        """Count-style links keep their counts in the resulting frame."""
        flags, _, _ = build_flags(link_list_count)

        wanted = pl.DataFrame(
            {
                "entity": ["A", "C", "D", "F", "Z"],
                "type": ["flags_numb"] * 5,
                "flag": ["flags_numb"] * 5,
                "count": [3, 2, 3, 0, 3],
                "qualified_entity": [
                    "ENTITY==A",
                    "ENTITY==C",
                    "ENTITY==D",
                    "ENTITY==F",
                    "ENTITY==Z",
                ],
            }
        )

        assert flags.sort(by=["entity"]).equals(wanted)

    def test_flags_count_sum(self, link_list_count) -> None:
        """A second link for ``A`` is summed into its count (3 + 5 = 8)."""
        link_list_count.append(["A", "flags_numb", "flags_numb", 5])
        flags, _, _ = build_flags(link_list_count)

        wanted = pl.DataFrame(
            {
                "entity": ["A", "C", "D", "F", "Z"],
                "type": ["flags_numb"] * 5,
                "flag": ["flags_numb"] * 5,
                "count": [8, 2, 3, 0, 3],
                "qualified_entity": [
                    "ENTITY==A",
                    "ENTITY==C",
                    "ENTITY==D",
                    "ENTITY==F",
                    "ENTITY==Z",
                ],
            }
        )

        assert flags.sort(by=["entity"]).equals(wanted)

    def test_max_entity_flags_count_sum(self, link_list_count) -> None:
        """After adding 5 to ``A`` the maximum is 8."""
        link_list_count.append(["A", "flags_numb", "flags_numb", 5])

        outcome = build_flags(link_list_count)

        assert outcome[1] == 8

    def test_mean_entity_flags_count_sum(self, link_list_count) -> None:
        """After adding 5 to ``A`` the mean is 4.0."""
        link_list_count.append(["A", "flags_numb", "flags_numb", 5])

        outcome = build_flags(link_list_count)

        assert outcome[2] == 4.0

    def test_max_entity_flags_count(self, link_list_count) -> None:
        """The maximum over the unmodified count links is 3."""
        outcome = build_flags(link_list_count)

        assert outcome[1] == 3

    def test_mean_entity_flags_count(self, link_list_count) -> None:
        """The mean over the unmodified count links is 2.75."""
        # NOTE(review): 2.75 equals 11 / 4, which suggests the zero-count
        # entity is left out of the mean; confirm against build_flags.
        outcome = build_flags(link_list_count)

        assert outcome[2] == 2.75
class TestTransformEntity:
    """Unit tests for ``transform_entity``."""

    def test_transform_entity_basic(self) -> None:
        """A digit string gets the ``ENTITY==`` prefix."""
        assert transform_entity("12345") == "ENTITY==12345"

    def test_transform_entity_empty_string(self) -> None:
        """An empty string yields only the prefix."""
        assert transform_entity("") == "ENTITY=="

    def test_transform_entity_special_characters(self) -> None:
        """Special characters are passed through untouched."""
        assert transform_entity("@$%^&*()") == "ENTITY==@$%^&*()"

    def test_transform_entity_numeric(self) -> None:
        """A longer digit string is prefixed in the same way."""
        assert transform_entity("9876543210") == "ENTITY==9876543210"

    def test_transform_entity_whitespace(self) -> None:
        """Whitespace input is kept after the prefix."""
        assert transform_entity(" ") == "ENTITY== "

    def test_transform_entity_none(self) -> None:
        """``None`` is rendered as the text ``None`` after the prefix."""
        assert transform_entity(None) == "ENTITY==None"
class TestBuildGroups:
    """Unit tests for ``build_groups``.

    Where link lists are compared, each group's links are sorted first so
    the assertions do not depend on the order in which ``build_groups``
    emits the rows of a group.
    """

    @pytest.fixture()
    def df_groups(self):
        """Four entities with two grouping attribute columns."""
        return pl.DataFrame(
            {
                "entity_id": ["A", "B", "C", "D"],
                "attribute1": ["X", "Y", "Z", "X"],
                "attribute2": ["X", "Y", "Y", "X"],
            }
        )

    @pytest.fixture()
    def entity_col(self) -> Literal["entity_id"]:
        """Name of the entity id column in ``df_groups``."""
        return "entity_id"

    def test_build_groups(self, df_groups, entity_col) -> None:
        """One value column yields one group of ``[entity, column, value]``."""
        value_cols = ["attribute1"]
        group_links = build_groups(value_cols, df_groups, entity_col)

        expected_group_links = [
            [
                ["A", "attribute1", "X"],
                ["B", "attribute1", "Y"],
                ["C", "attribute1", "Z"],
                ["D", "attribute1", "X"],
            ]
        ]

        # Sort the links inside each group; sorting the outer one-element
        # list would leave the inner order unnormalised.
        assert [sorted(group) for group in group_links] == expected_group_links

    def test_build_groups_existing_groups(self, df_groups, entity_col) -> None:
        """Existing links stay first and the new group is appended after them."""
        value_cols = ["attribute1"]
        existing_groups_links = [
            ["Z", "attribute2", "X"],
        ]
        group_links = build_groups(
            value_cols, df_groups, entity_col, existing_groups_links
        )

        expected_new_group = [
            ["A", "attribute1", "X"],
            ["B", "attribute1", "Y"],
            ["C", "attribute1", "Z"],
            ["D", "attribute1", "X"],
        ]

        assert len(group_links) == 2
        # The pre-existing entry is a single link and must be kept as it is.
        assert group_links[0] == ["Z", "attribute2", "X"]
        # The appended group is compared independently of row order.
        assert sorted(group_links[1]) == expected_new_group

    def test_build_groups_two_columns(self, df_groups, entity_col) -> None:
        """Two value columns yield two groups, in column order."""
        value_cols = ["attribute1", "attribute2"]
        group_links = build_groups(value_cols, df_groups, entity_col)

        expected_group_links = [
            [
                ["A", "attribute1", "X"],
                ["B", "attribute1", "Y"],
                ["C", "attribute1", "Z"],
                ["D", "attribute1", "X"],
            ],
            [
                ["A", "attribute2", "X"],
                ["B", "attribute2", "Y"],
                ["C", "attribute2", "Y"],
                ["D", "attribute2", "X"],
            ],
        ]

        assert sorted(group_links[0]) == expected_group_links[0]
        assert sorted(group_links[1]) == expected_group_links[1]

    def test_build_groups_column_empty(self, df_groups, entity_col) -> None:
        """No value columns means no groups."""
        value_cols = []
        group_links = build_groups(value_cols, df_groups, entity_col)

        assert group_links == []

    def test_build_groups_df_empty(self, entity_col) -> None:
        """An empty frame means no groups."""
        value_cols = ["attribute1"]
        df_groups = pl.DataFrame()
        group_links = build_groups(value_cols, df_groups, entity_col)

        assert group_links == []

    def test_build_groups_column_doesnt_exists(self, entity_col, df_groups) -> None:
        """An unknown value column raises ``ValueError``."""
        value_cols = ["attribute12"]
        with pytest.raises(
            ValueError,
            match="Column attribute12 not found in the DataFrame.",
        ):
            build_groups(value_cols, df_groups, entity_col)

    def test_build_groups_entity_doesnt_exist(self, df_groups) -> None:
        """An unknown entity column raises ``ValueError``."""
        entity_col = "entity123"
        value_cols = ["attribute1"]
        with pytest.raises(
            ValueError,
            match="Column entity123 not found in the DataFrame.",
        ):
            build_groups(value_cols, df_groups, entity_col)