Coverage for intelligence_toolkit/match_entity_records/detect.py: 98%
151 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
5import re
6from collections import defaultdict
7from typing import Any
9import numpy as np
10import polars as pl
11from sklearn.neighbors import NearestNeighbors
13from intelligence_toolkit.AI.classes import VectorData
14from intelligence_toolkit.AI.utils import hash_text
15from intelligence_toolkit.match_entity_records.config import (
16 DEFAULT_COLUMNS_DONT_CONVERT,
17 DEFAULT_MAX_RECORD_DISTANCE,
18 DEFAULT_SENTENCE_PAIR_JACCARD_THRESHOLD,
19)
def convert_to_sentences(
    merged_dataframe: pl.DataFrame,
    skip_columns: list[str] | None = DEFAULT_COLUMNS_DONT_CONVERT,
) -> list[VectorData]:
    """Render every row of ``merged_dataframe`` as an upper-cased text sentence.

    Each sentence starts with ``"ENTITY NAME: <NAME>;"`` and is followed by one
    ``"<FIELD>: <VALUE>;"`` segment per column not listed in ``skip_columns``,
    in column order. Missing values (``None``, or a value whose string form is
    "nan" in any letter case) are rendered as an empty string.

    Args:
        merged_dataframe: Records to convert; must contain an ``"Entity name"``
            column holding strings.
        skip_columns: Column names to leave out of the sentence. ``None`` (or an
            empty list) includes every column.

    Returns:
        One ``{"text": sentence, "hash": hash_text(sentence)}`` entry per row,
        in row order.
    """
    skipped = set(skip_columns or [])
    # The set of emitted columns is the same for every row, so compute it once.
    fields = [col for col in merged_dataframe.columns if col not in skipped]

    sentences: list[VectorData] = []
    for row in merged_dataframe.iter_rows(named=True):
        parts = ["ENTITY NAME: " + row["Entity name"].upper()]
        for field in fields:
            raw = row[field]
            # Polars returns nulls as None; without this check they would be
            # embedded as the literal token "NONE".
            val = "" if raw is None else str(raw).upper()
            if val == "NAN":
                val = ""
            parts.append(field.upper() + ": " + val)
        # Same text as accumulating "<segment>; " pieces and stripping the end.
        sentence = "; ".join(parts) + ";"
        sentences.append({"text": sentence, "hash": hash_text(sentence)})
    return sentences
def build_nearest_neighbors(
    embeddings: np.ndarray,
    n_neighbors: int = 50,
    leaf_size: int = 20,
    metric: str = "cosine",
) -> tuple[np.ndarray, np.ndarray]:
    """Find the ``n_neighbors`` nearest embeddings of every embedding.

    The same ``embeddings`` are used both to fit the index and as the queries.
    Each row's own index is therefore normally among its results, at distance 0.

    Args:
        embeddings: Sequence/array of embedding vectors, one per record.
        n_neighbors: Number of neighbours returned per embedding.
        leaf_size: Passed through to ``sklearn.neighbors.NearestNeighbors``.
        metric: Distance metric passed through to ``NearestNeighbors``.

    Returns:
        ``(distances, indices)`` as returned by ``NearestNeighbors.kneighbors``.

    Raises:
        ValueError: If fewer than ``n_neighbors`` embeddings are supplied.
    """
    if len(embeddings) < n_neighbors:
        msg = f"Number of neighbors ({n_neighbors}) is greater than number of embeddings ({len(embeddings)})"
        raise ValueError(msg)

    nbrs = NearestNeighbors(
        n_neighbors=n_neighbors,
        n_jobs=1,
        algorithm="auto",
        leaf_size=leaf_size,
        metric=metric,
    ).fit(embeddings)

    distances, indices = nbrs.kneighbors(embeddings)
    return distances, indices
def build_near_map(
    distances: np.ndarray,
    indices: np.ndarray,
    all_sentences: list[str],
    max_record_distance: float | None = DEFAULT_MAX_RECORD_DISTANCE,
) -> defaultdict[Any, list]:
    """Map each record index to the neighbour indices within a distance cap.

    Args:
        distances: Per-record neighbour distances, row ``i`` aligned with
            ``indices[i]`` (as returned by ``build_nearest_neighbors``).
        indices: Per-record neighbour indices.
        all_sentences: Only its length is used; it gives the number of records
            to process.
        max_record_distance: Neighbours farther than this are ignored. ``None``
            means no cap (previously this raised ``TypeError``).

    Returns:
        A ``defaultdict(list)`` from record index to the list of neighbour
        indices kept. Records with no kept neighbour have no key.
    """
    near_map = defaultdict(list)

    for ix in range(len(all_sentences)):
        # Zip only the aligned index/distance rows. The old code also zipped
        # the record's sentence string, which silently truncated the neighbour
        # list to the number of characters in that string.
        for near_i, near_d in zip(indices[ix], distances[ix], strict=False):
            # Skip the query record itself by index. With duplicate embeddings,
            # ties mean it is not guaranteed to be ranked first.
            if near_i == ix:
                continue
            if max_record_distance is None or near_d <= max_record_distance:
                near_map[ix].append(near_i)

    return near_map
# Characters removed from entity names before trigram comparison.
_PUNCTUATION_RE = re.compile(r"[^\w\s]")


def _char_ngrams(text: str, n: int = 3) -> set[str]:
    """Return the set of contiguous length-``n`` substrings of ``text``."""
    return {text[i : i + n] for i in range(len(text) - n + 1)}


def _name_similarity(ixn: str, nxn: str, n: int = 3) -> float:
    """Score two upper-cased names: 1 if identical, else trigram Jaccard."""
    if ixn == nxn:
        return 1
    ixn_c = _PUNCTUATION_RE.sub("", ixn)
    nxn_c = _PUNCTUATION_RE.sub("", nxn)
    igrams = _char_ngrams(ixn_c, n)
    ngrams = _char_ngrams(nxn_c, n)
    union = len(igrams | ngrams)
    if union == 0:
        # Both cleaned names are shorter than n characters, so n-grams cannot
        # compare them; fall back to equality of the cleaned names.
        return 1 if ixn_c == nxn_c else 0
    return len(igrams & ngrams) / union


def build_sentence_pair_scores(
    near_map: defaultdict[Any, list], merged_df: pl.DataFrame
) -> list:
    """Score the entity-name similarity of every neighbour pair in ``near_map``.

    Names are upper-cased. Identical names score 1. Otherwise the score is the
    Jaccard similarity of the character-trigram sets of the names, computed
    after removing punctuation (characters matching ``[^\\w\\s]``).

    Args:
        near_map: Record index -> list of neighbour record indices.
        merged_df: Records addressed by those indices; must contain a string
            ``"Entity name"`` column.

    Returns:
        A list of ``(ix, nx, score)`` tuples, one per neighbour pair.
    """
    sentence_pair_scores = []
    for ix, nx_list in near_map.items():
        # The query record is the same for every neighbour, so fetch it once.
        ixn = merged_df.row(ix, named=True)["Entity name"].upper()
        for nx in nx_list:
            nxn = merged_df.row(nx, named=True)["Entity name"].upper()
            sentence_pair_scores.append((ix, nx, _name_similarity(ixn, nxn)))
    return sentence_pair_scores
def build_matches(
    sentence_pair_scores,
    merged_df: pl.DataFrame,
    sentence_pair_jaccard_threshold: float = DEFAULT_SENTENCE_PAIR_JACCARD_THRESHOLD,
) -> tuple[dict, set, dict]:
    """Cluster accepted record pairs into groups that refer to the same entity.

    Pairs are processed from highest to lowest score. Pairs scoring below
    ``sentence_pair_jaccard_threshold`` are ignored. A record is identified by
    ``"<Entity name>::<Dataset>"``, and any identifiers linked through accepted
    pairs end up in one group (groups are merged transitively).

    Args:
        sentence_pair_scores: Iterable of ``(row_index, row_index, score)``
            tuples, such as the output of ``build_sentence_pair_scores``.
        merged_df: Records addressed by those row indices; must contain
            ``"Entity name"`` and ``"Dataset"`` columns.
        sentence_pair_jaccard_threshold: Minimum score for a pair to be accepted.

    Returns:
        ``(entity_to_group, matches, pair_to_match)``:
        - ``entity_to_group`` maps each identifier to its final group id.
        - ``matches`` is a set of ``(group_id, *row_values)`` tuples, one for
          every row that took part in an accepted pair, using the final group id.
        - ``pair_to_match`` maps each sorted identifier pair to its score.
    """
    entity_to_group: dict[str, int] = {}
    group_members: dict[int, set[str]] = {}  # group id -> identifiers in it
    matched_rows: dict[int, str] = {}  # row index -> identifier
    pair_to_match = {}
    group_id = 0

    for ix, nx, score in sorted(
        sentence_pair_scores,
        key=lambda pair: pair[2],
        reverse=True,
    ):
        if score < sentence_pair_jaccard_threshold:
            # Sorted descending: every remaining pair is below the threshold too.
            break

        ixrec = merged_df.row(ix, named=True)
        nxrec = merged_df.row(nx, named=True)
        ix_id = f"{ixrec['Entity name']}::{ixrec['Dataset']}"
        nx_id = f"{nxrec['Entity name']}::{nxrec['Dataset']}"

        ig = entity_to_group.get(ix_id)
        ng = entity_to_group.get(nx_id)
        if ig is not None and ng is not None:
            if ig != ng:
                # Fold ix's group into nx's group; only its members need relabelling.
                moved = group_members.pop(ig)
                for member in moved:
                    entity_to_group[member] = ng
                group_members[ng] |= moved
        elif ig is not None:
            entity_to_group[nx_id] = ig
            group_members[ig].add(nx_id)
        elif ng is not None:
            entity_to_group[ix_id] = ng
            group_members[ng].add(ix_id)
        else:
            entity_to_group[ix_id] = group_id
            entity_to_group[nx_id] = group_id
            group_members[group_id] = {ix_id, nx_id}
            group_id += 1

        matched_rows[ix] = ix_id
        matched_rows[nx] = nx_id
        pair_to_match[tuple(sorted([ix_id, nx_id]))] = score

    # Build the match rows only after every merge. Otherwise rows recorded
    # before a later merge keep their stale group id and the groups fall apart.
    matches = {
        (entity_to_group[entity_id], *merged_df.row(row_ix))
        for row_ix, entity_id in matched_rows.items()
    }
    return entity_to_group, matches, pair_to_match
def _calculate_mean_score(pair_to_match: dict, entity_to_group: dict) -> dict:
    """Average the pair scores that fall inside each group.

    A pair contributes only when both of its identifiers are in
    ``entity_to_group`` and belong to the same group.

    Args:
        pair_to_match: ``(id_a, id_b)`` -> score.
        entity_to_group: identifier -> group id.

    Returns:
        group id -> mean score, for groups with at least one contributing pair,
        in first-seen order.
    """
    scores_by_group = defaultdict(list)
    for (left_id, right_id), score in pair_to_match.items():
        if left_id not in entity_to_group or right_id not in entity_to_group:
            continue
        group = entity_to_group[left_id]
        if group != entity_to_group[right_id]:
            continue
        scores_by_group[group].append(score)

    # Lists are only created via append, so they are never empty; the fallback
    # value of 1 is kept as a defensive guard.
    return {
        group: (sum(scores) / len(scores) if scores else 1)
        for group, scores in scores_by_group.items()
    }
def build_matches_dataset(
    matches_df: pl.DataFrame, pair_to_match: dict, entity_to_group: dict
) -> pl.DataFrame:
    """Annotate matched records with group size and mean name similarity.

    Expects ``matches_df`` to contain at least ``"Group ID"``, ``"Entity name"``,
    ``"Dataset"`` and ``"Entity ID"`` columns. Presumably it is built from the
    ``matches`` set returned by ``build_matches``; this is an assumption, so
    confirm against the caller.

    The result:
    - Adds ``"Group size"`` (non-null ``"Entity ID"`` count per group).
    - Moves the key columns to the front.
    - Truncates ``"Entity ID"`` at its first ``"::"``.
    - Keeps only groups of size > 1.
    - Adds ``"Name similarity"`` (mean in-group pair score) and drops groups
      without one.
    - Sorts by ``["Name similarity", "Group ID"]`` ascending.

    Args:
        matches_df: Matched records, one row per record.
        pair_to_match: Sorted identifier pair -> score.
        entity_to_group: Identifier -> group id.

    Returns:
        The annotated frame, or the (empty) input when there is nothing to report.
    """
    if matches_df.is_empty():
        return matches_df

    sizes = (
        matches_df.group_by("Group ID")
        .agg(pl.col("Entity ID").count().alias("Size"))
        .to_dict()
    )
    group_to_size = dict(zip(sizes["Group ID"], sizes["Size"], strict=False))
    matches_df = matches_df.with_columns(
        matches_df["Group ID"].replace(group_to_size).alias("Group size")
    )

    order_first_columns = [
        "Group ID",
        "Group size",
        "Entity name",
        "Dataset",
        "Entity ID",
    ]
    remaining_columns = [c for c in matches_df.columns if c not in order_first_columns]
    # select() reorders columns. with_columns() replaces them in place and
    # keeps the existing order, so the old call never moved anything.
    matches_df = matches_df.select(order_first_columns + remaining_columns)

    # Keep only the part of the entity ID before "::", then keep only groups
    # larger than 1.
    matches_df = matches_df.with_columns(
        pl.col("Entity ID").str.split("::").list.first().alias("Entity ID")
    ).filter(pl.col("Group size") > 1)
    if matches_df.is_empty():
        return matches_df

    group_to_mean_similarity = _calculate_mean_score(pair_to_match, entity_to_group)
    matches_df = matches_df.with_columns(
        matches_df["Group ID"]
        .map_elements(group_to_mean_similarity.get, return_dtype=pl.Float64)
        .alias("Name similarity")
    )

    # Groups with no scored in-group pair get a null similarity and are dropped.
    matches_df = matches_df.filter(pl.col("Name similarity").is_not_null())
    return matches_df.sort(by=["Name similarity", "Group ID"])
def build_attributes_dataframe(
    matching_dfs: dict[str, pl.DataFrame], atts_to_datasets: defaultdict[str, dict]
) -> pl.DataFrame:
    """Align per-dataset frames onto shared attribute names and stack them.

    For every dataset that has an entry in ``atts_to_datasets``:
    - Rename its columns with that mapping.
    - Drop unmapped columns, except ``"Entity ID"`` and ``"Entity name"``.
    - Add a ``"Dataset"`` column holding the dataset key.
    - Sort the columns by name and cast them all to strings.

    The frames are then concatenated vertically, so every aligned frame must
    end up with the same columns. Rows with an empty (or null) entity name are
    removed.

    Args:
        matching_dfs: Dataset name -> its records.
        atts_to_datasets: Dataset name -> ``{original column: shared name}``.

    Returns:
        The stacked string-typed frame, or an empty frame when there is
        nothing to align.
    """
    if not matching_dfs:
        return pl.DataFrame()

    always_keep = {"Entity ID", "Entity name"}
    string_dfs = []
    for dataset, source_df in matching_dfs.items():
        if dataset not in atts_to_datasets:
            continue
        mapping = atts_to_datasets[dataset]
        rdf = source_df.rename(mapping)
        # Renamed columns no longer exist under their original name. Of the
        # rest, keep only mapped columns and the entity key columns.
        to_drop = [
            col
            for col in source_df.columns
            if col in rdf.columns and col not in mapping and col not in always_keep
        ]
        rdf = rdf.drop(to_drop).with_columns(pl.lit(dataset).alias("Dataset"))
        rdf = rdf.select(sorted(rdf.columns)).with_columns(pl.all().cast(pl.Utf8))
        string_dfs.append(rdf)

    # pl.concat raises on an empty list; mirror the empty-input return instead.
    if not string_dfs:
        return pl.DataFrame()
    return pl.concat(string_dfs).filter(pl.col("Entity name") != "")