Coverage for intelligence_toolkit/match_entity_records/api.py: 100%

63 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4 

5 

6import io 

7from typing import ClassVar 

8 

9import numpy as np 

10import polars as pl 

11 

12import intelligence_toolkit.AI.utils as utils 

13from intelligence_toolkit.AI.classes import LLMCallback 

14from intelligence_toolkit.AI.client import OpenAIClient 

15from intelligence_toolkit.helpers.classes import IntelligenceWorkflow 

16from intelligence_toolkit.match_entity_records import prompts 

17from intelligence_toolkit.match_entity_records.classes import ( 

18 AttributeToMatch, 

19 RecordsModel, 

20) 

21from intelligence_toolkit.match_entity_records.detect import ( 

22 build_attributes_dataframe, 

23 build_matches, 

24 build_matches_dataset, 

25 build_near_map, 

26 build_nearest_neighbors, 

27 build_sentence_pair_scores, 

28 convert_to_sentences, 

29) 

30from intelligence_toolkit.match_entity_records.prepare_model import ( 

31 build_attribute_options, 

32 build_attributes_list, 

33 format_model_df, 

34) 

35 

36 

class MatchEntityRecords(IntelligenceWorkflow):
    """Workflow that matches entity records across one or more datasets.

    Call order implied by the state each method reads:
    ``add_df_to_model`` (once per dataset) -> ``build_model_df`` ->
    ``embed_sentences`` -> ``detect_record_groups`` -> ``evaluate_groups``.
    After that, ``integrated_results`` combines the detected groups with the
    LLM evaluations.

    ``embedder``, ``cache_embeddings`` and ``ai_configuration`` are read but
    never set here. NOTE(review): presumably they are provided by
    ``IntelligenceWorkflow``; confirm.
    """

    # Shared *empty* default. Never mutate it in place. Methods rebind
    # ``self.model_dfs`` to a fresh per-instance dict, so datasets added to one
    # workflow instance cannot leak into others through this class attribute.
    model_dfs: ClassVar[dict] = {}
    # Forwarded to ``format_model_df``.
    # NOTE(review): 0 presumably means "no row limit"; confirm there.
    max_rows_to_process = 0
    # LLM group evaluations, populated by ``evaluate_groups``.
    evaluations_df = pl.DataFrame()
    # Detected record groups, populated by ``detect_record_groups``.
    matches_df = pl.DataFrame()

    @property
    def total_records(self) -> int:
        """Return the total number of rows across all datasets added so far."""
        return sum(df.shape[0] for df in self.model_dfs.values())

    @property
    def attribute_options(self) -> str:
        """Return the attribute options that ``build_attribute_options`` derives from the loaded datasets."""
        return build_attribute_options(self.model_dfs)

    @property
    def integrated_results(self) -> pl.DataFrame:
        """Return detected groups joined with their LLM evaluations on ``"Group ID"``.

        Evaluation rows containing any null are dropped first. The join is
        inner, so only groups that have a complete evaluation are kept.
        """
        value = self.evaluations_df.drop_nulls()
        return self.matches_df.join(value, on="Group ID", how="inner")

    def add_df_to_model(self, model: RecordsModel) -> pl.DataFrame:
        """Format a dataset and register it under its name.

        If ``model.dataframe_name`` is empty, a name ``dataset_<n>`` is
        assigned. ``<n>`` starts at one past the current dataset count and is
        advanced until it no longer collides with an existing name. The chosen
        name is written back to ``model.dataframe_name``.

        Args:
            model: The records model to format and add.

        Returns:
            The formatted frame produced by ``format_model_df``.
        """
        if not model.dataframe_name:
            index = len(self.model_dfs) + 1
            # Skip names already taken, e.g. a user-supplied "dataset_2".
            # Otherwise that dataset would be silently overwritten.
            while f"dataset_{index}" in self.model_dfs:
                index += 1
            model.dataframe_name = f"dataset_{index}"

        formatted = format_model_df(model, self.max_rows_to_process)
        # Rebind instead of inserting in place. Until this instance owns a dict,
        # ``self.model_dfs`` resolves to the class-level default shared by every
        # instance.
        self.model_dfs = {**self.model_dfs, model.dataframe_name: formatted}
        return formatted

    def build_model_df(self, attributes_list: list[AttributeToMatch]) -> pl.DataFrame:
        """Combine the loaded datasets into one frame of the selected attributes.

        Also converts every row to sentence data, stored in
        ``self.sentences_vector_data``, for ``embed_sentences`` to embed.

        Args:
            attributes_list: Attributes to match on. They are passed through
                ``build_attributes_list``.

        Returns:
            The combined frame. It is also stored as ``self.model_df``.
        """
        attributes = build_attributes_list(attributes_list)
        self.model_df = build_attributes_dataframe(self.model_dfs, attributes)
        # NOTE(review): ``.alias("Unique ID")`` binds only to pl.col("Dataset").
        # Polars names a binary expression after its left-most operand, so this
        # appears to overwrite "Entity ID" with "<id>::<dataset>" rather than add
        # a "Unique ID" column. Downstream helpers may depend on that, so the
        # expression is intentionally left unchanged. Confirm intent before
        # altering it.
        self.model_df = self.model_df.with_columns(
            (pl.col("Entity ID").cast(pl.Utf8))
            + "::"
            + pl.col("Dataset").alias("Unique ID")
        )

        self.sentences_vector_data = convert_to_sentences(self.model_df)
        return self.model_df

    async def embed_sentences(self) -> None:
        """Embed every sentence produced by ``build_model_df``.

        Sets ``self.all_sentences`` to the sentence texts in their original
        order. Sets ``self.embeddings`` to one numpy vector per sentence, in
        the same order.

        Raises:
            RuntimeError: If the embedder returned no vector for some sentence.
        """
        sentences_data = await self.embedder.embed_store_many(
            self.sentences_vector_data, cache_data=self.cache_embeddings
        )
        # Index the vectors by text once, instead of scanning the whole result
        # for each sentence (quadratic). setdefault keeps the *first* vector for
        # a repeated text, matching the original first-match lookup.
        vector_by_text: dict = {}
        for item in sentences_data:
            vector_by_text.setdefault(item["text"], item["vector"])

        self.all_sentences = [x["text"] for x in self.sentences_vector_data]
        try:
            self.embeddings = [
                np.array(vector_by_text[text]) for text in self.all_sentences
            ]
        except KeyError as err:
            msg = f"No embedding returned for sentence: {err.args[0]!r}"
            raise RuntimeError(msg) from err

    def detect_record_groups(
        self, pair_embedding_threshold: float, pair_jaccard_threshold: float
    ) -> pl.DataFrame:
        """Group records whose embeddings and attributes are similar enough.

        Pipeline:
        1. Nearest neighbours of ``self.embeddings``.
        2. Candidate pairs filtered by ``pair_embedding_threshold`` (``build_near_map``).
        3. Pair scores (``build_sentence_pair_scores``).
        4. Groups formed from pairs that pass ``pair_jaccard_threshold`` (``build_matches``).

        Args:
            pair_embedding_threshold: Threshold forwarded to ``build_near_map``.
            pair_jaccard_threshold: Threshold forwarded to ``build_matches``.

        Returns:
            The matches frame. It has a leading "Group ID" column plus the
            ``model_df`` columns, sorted by group, entity name and dataset, and
            is finalised by ``build_matches_dataset``. It is also stored as
            ``self.matches_df``.
        """
        distances, indices = build_nearest_neighbors(self.embeddings)
        near_map = build_near_map(
            distances,
            indices,
            self.all_sentences,
            pair_embedding_threshold,
        )

        pair_scores = build_sentence_pair_scores(near_map, self.model_df)

        entity_to_group, matches, pair_to_match = build_matches(
            pair_scores,
            self.model_df,
            pair_jaccard_threshold,
        )

        matches_df = pl.DataFrame(
            list(matches),
            schema=["Group ID", *self.model_df.columns],
        ).sort(by=["Group ID", "Entity name", "Dataset"], descending=False)

        self.matches_df = build_matches_dataset(
            matches_df, pair_to_match, entity_to_group
        )
        return self.matches_df

    async def evaluate_groups(
        self,
        ai_instructions=prompts.list_prompts,
        callbacks: list[LLMCallback] | None = None,
        max_rows: int | None = 500,
    ) -> str:
        """Ask the LLM to rate the relatedness of each detected record group.

        The detected groups (``self.matches_df``) are sent in batches. The LLM
        responses are concatenated under a ``Group ID,Relatedness,Explanation``
        header, parsed as CSV and stored in ``self.evaluations_df``.

        Args:
            ai_instructions: Prompt instructions passed to
                ``utils.generate_batch_messages``.
            callbacks: Accepted for API compatibility. NOTE(review): this
                argument is currently not forwarded to the LLM client.
            max_rows: Maximum number of rows sent to the LLM. ``None`` sends
                all rows. The default 500 keeps the previous cap.

        Returns:
            The combined CSV text with code fences removed.
        """
        # Use matches_df: it is the only frame that carries the "Group ID"
        # column. The LLM is asked to report that column, and
        # integrated_results joins on it. Exclude only columns that exist,
        # because polars' drop raises on unknown columns.
        excluded = ("Entity ID", "Dataset", "Name similarity")
        data = self.matches_df.drop(
            [column for column in excluded if column in self.matches_df.columns]
        ).to_pandas()
        if max_rows is not None:
            data = data.head(max_rows)

        batch_messages = utils.generate_batch_messages(
            ai_instructions, batch_name="data", batch_value=data
        )

        # One client serves every batch; it does not need rebuilding per call.
        client = OpenAIClient(self.ai_configuration)
        responses = []
        for messages in batch_messages:
            responses.append(await client.generate_chat_async(messages, stream=False))

        csv_text = "```\nGroup ID,Relatedness,Explanation\n" + "".join(
            f"{response}\n" for response in responses
        )
        # Remove code fences that open a line, then surrounding whitespace.
        result = csv_text.replace("```\n", "").strip()
        self.evaluations_df = pl.read_csv(io.StringIO(result))
        return result

    def clear_model_dfs(self) -> None:
        """Forget every dataset added to this workflow instance."""
        self.model_dfs = {}