Coverage for intelligence_toolkit/match_entity_records/api.py: 100%
63 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
6import io
7from typing import ClassVar
9import numpy as np
10import polars as pl
12import intelligence_toolkit.AI.utils as utils
13from intelligence_toolkit.AI.classes import LLMCallback
14from intelligence_toolkit.AI.client import OpenAIClient
15from intelligence_toolkit.helpers.classes import IntelligenceWorkflow
16from intelligence_toolkit.match_entity_records import prompts
17from intelligence_toolkit.match_entity_records.classes import (
18 AttributeToMatch,
19 RecordsModel,
20)
21from intelligence_toolkit.match_entity_records.detect import (
22 build_attributes_dataframe,
23 build_matches,
24 build_matches_dataset,
25 build_near_map,
26 build_nearest_neighbors,
27 build_sentence_pair_scores,
28 convert_to_sentences,
29)
30from intelligence_toolkit.match_entity_records.prepare_model import (
31 build_attribute_options,
32 build_attributes_list,
33 format_model_df,
34)
class MatchEntityRecords(IntelligenceWorkflow):
    """Workflow that groups records likely to describe the same entity.

    The methods build on each other's state and are meant to be called in
    this order: ``add_df_to_model`` -> ``build_model_df`` ->
    ``embed_sentences`` -> ``detect_record_groups`` -> ``evaluate_groups``.

    ``self.embedder``, ``self.cache_embeddings`` and ``self.ai_configuration``
    are read but not defined here; presumably they come from
    ``IntelligenceWorkflow`` (verify against the base class).
    """

    # NOTE(review): this dict lives on the class and ``add_df_to_model``
    # mutates it in place, so (unless the base class rebinds it per instance)
    # registered datasets are shared by every instance. ``clear_model_dfs``
    # only rebinds the name on the calling instance. Confirm this is intended.
    model_dfs: ClassVar[dict] = {}
    # Forwarded unchanged to ``format_model_df``; the meaning of 0 is defined
    # by that helper.
    max_rows_to_process = 0
    # Replaced by ``evaluate_groups`` with the parsed LLM evaluation table.
    evaluations_df = pl.DataFrame()
    # Replaced by ``detect_record_groups`` with the detected match groups.
    matches_df = pl.DataFrame()

    @property
    def total_records(self) -> int:
        """Return the total number of rows across all registered datasets."""
        return sum(frame.shape[0] for frame in self.model_dfs.values())

    @property
    def attribute_options(self) -> str:
        """Return ``build_attribute_options`` applied to the registered datasets."""
        return build_attribute_options(self.model_dfs)

    @property
    def integrated_results(self) -> pl.DataFrame:
        """Return matches inner-joined with their evaluations on "Group ID".

        Evaluation rows containing any null are dropped before the join, so
        groups without a complete evaluation do not appear in the result.
        """
        complete_evaluations = self.evaluations_df.drop_nulls()
        return self.matches_df.join(
            complete_evaluations, on="Group ID", how="inner"
        )

    def add_df_to_model(self, model: RecordsModel) -> pl.DataFrame:
        """Format ``model`` and register it under its dataframe name.

        When ``model.dataframe_name`` is empty, a name of the form
        ``dataset_<n>`` is generated and written back onto ``model``, where
        ``n`` is one more than the number of datasets already registered.

        Returns:
            The formatted dataframe stored for this dataset.
        """
        if not model.dataframe_name:
            model.dataframe_name = f"dataset_{len(self.model_dfs) + 1}"

        formatted = format_model_df(model, self.max_rows_to_process)
        self.model_dfs[model.dataframe_name] = formatted
        return self.model_dfs[model.dataframe_name]

    def build_model_df(self, attributes_list: list[AttributeToMatch]) -> pl.DataFrame:
        """Build the combined model dataframe and its sentence representation.

        Sets ``self.model_df`` and ``self.sentences_vector_data``.

        Returns:
            The combined model dataframe.
        """
        attributes = build_attributes_list(attributes_list)
        self.model_df = build_attributes_dataframe(self.model_dfs, attributes)
        # NOTE(review): ``.alias("Unique ID")`` binds to ``pl.col("Dataset")``
        # only, not to the whole concatenation (a method call binds tighter
        # than ``+``). Polars names a binary expression after its left-most
        # input, so this presumably overwrites "Entity ID" with
        # "<id>::<dataset>" rather than adding a "Unique ID" column. Code
        # elsewhere may depend on that format, so it is left unchanged here;
        # confirm before "fixing" it.
        self.model_df = self.model_df.with_columns(
            (pl.col("Entity ID").cast(pl.Utf8))
            + "::"
            + pl.col("Dataset").alias("Unique ID")
        )

        self.sentences_vector_data = convert_to_sentences(self.model_df)
        return self.model_df

    async def embed_sentences(self) -> None:
        """Embed the model sentences and store them aligned with their text.

        Sets ``self.all_sentences`` (texts in ``sentences_vector_data`` order)
        and ``self.embeddings`` (one ``numpy`` array per sentence, same order).

        Raises:
            KeyError: if the embedder returned no entry for one of the texts.
        """
        embedded = await self.embedder.embed_store_many(
            self.sentences_vector_data, cache_data=self.cache_embeddings
        )
        self.all_sentences = [item["text"] for item in self.sentences_vector_data]

        # Index the vectors by text once instead of rescanning the embedder
        # output for every sentence. ``setdefault`` keeps the first vector seen
        # for a text, which is what a first-match linear scan would return.
        vector_by_text: dict = {}
        for item in embedded:
            vector_by_text.setdefault(item["text"], item["vector"])

        self.embeddings = [
            np.array(vector_by_text[text]) for text in self.all_sentences
        ]

    def detect_record_groups(
        self, pair_embedding_threshold: int, pair_jaccard_threshold: int
    ) -> pl.DataFrame:
        """Detect groups of matching records from the stored embeddings.

        Requires ``embed_sentences`` and ``build_model_df`` to have run.
        Stores the result in ``self.matches_df`` and returns it.

        Args:
            pair_embedding_threshold: forwarded to ``build_near_map``.
            pair_jaccard_threshold: forwarded to ``build_matches``.
                NOTE(review): both are annotated ``int``; confirm whether
                fractional thresholds are expected by those helpers.
        """
        distances, indices = build_nearest_neighbors(self.embeddings)
        near_map = build_near_map(
            distances,
            indices,
            self.all_sentences,
            pair_embedding_threshold,
        )

        pair_scores = build_sentence_pair_scores(near_map, self.model_df)

        entity_to_group, matches, pair_to_match = build_matches(
            pair_scores,
            self.model_df,
            pair_jaccard_threshold,
        )

        # Each match row is expected to be a "Group ID" followed by the model
        # columns, in that order.
        grouped = pl.DataFrame(
            list(matches),
            schema=["Group ID", *self.model_df.columns],
        ).sort(by=["Group ID", "Entity name", "Dataset"], descending=False)

        self.matches_df = build_matches_dataset(
            grouped, pair_to_match, entity_to_group
        )
        return self.matches_df

    async def evaluate_groups(
        self,
        ai_instructions=prompts.list_prompts,
        callbacks: list[LLMCallback] | None = None,
    ) -> str:
        """Ask the LLM to evaluate the records and parse its CSV answer.

        The batches are sent one after another (non-streaming). The responses
        are concatenated under a "Group ID,Relatedness,Explanation" header,
        the code-fence markers are stripped, and the text is parsed into
        ``self.evaluations_df``.

        Args:
            ai_instructions: prompt instructions given to
                ``utils.generate_batch_messages``.
            callbacks: accepted for interface compatibility; currently unused.

        Returns:
            The cleaned CSV text that was parsed into ``self.evaluations_df``.
        """
        data = self.model_df.drop(
            [
                "Entity ID",
                "Dataset",
                "Name similarity",
            ]
        ).to_pandas()
        # Only the first 500 rows are sent for evaluation.
        data = data.head(500)

        batch_messages = utils.generate_batch_messages(
            ai_instructions, batch_name="data", batch_value=data
        )
        csv_text = "```\nGroup ID,Relatedness,Explanation\n"

        for messages in batch_messages:
            response = await OpenAIClient(self.ai_configuration).generate_chat_async(
                messages, stream=False
            )
            csv_text = csv_text + response + "\n"

        # Removes the opening fence added above and any "```\n" the model
        # echoed back in its responses.
        result = csv_text.replace("```\n", "").strip()
        self.evaluations_df = pl.read_csv(io.StringIO(result))
        return result

    def clear_model_dfs(self) -> None:
        """Reset this instance's registered datasets to an empty dict.

        This rebinds ``model_dfs`` on the instance; the class-level dict is
        not emptied.
        """
        self.model_dfs = {}