Coverage for intelligence_toolkit/match_entity_records/prepare_model.py: 100%

36 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4from collections import defaultdict 

5 

6import polars as pl 

7 

8from intelligence_toolkit.match_entity_records.classes import ( 

9 AttributeToMatch, 

10 RecordsModel, 

11) 

12 

13 

def format_model_df(
    model: RecordsModel,
    max_rows: int = 0,
) -> pl.DataFrame:
    """
    Format a records dataset into the layout used for entity matching.

    The returned frame has, in order: an "Entity ID" column (taken from
    ``model.id_column``, or a generated 0-based row index when no ID column
    is set) cast to string; an "Entity name" column (taken from
    ``model.name_column``); then the columns listed in ``model.columns``,
    sorted by name.

    :param model: The records model whose ``dataframe`` is formatted
    :param max_rows: The maximum number of rows to return; ``0`` (the
        default) or a negative value returns all rows
    :return: The formatted dataset, or an empty DataFrame when
        ``model.dataframe`` has no rows
    """
    source_df = model.dataframe
    if source_df.is_empty():
        return pl.DataFrame()

    # Limit rows up front so the rename/cast/select below only process the
    # rows that will be returned. A generated row index still starts at 0
    # from the first row, exactly as if the frame were trimmed afterwards.
    if max_rows > 0:
        source_df = source_df.head(max_rows)

    if not model.id_column:
        selected_df = source_df.with_row_index(name="Entity ID")
    elif model.id_column == model.name_column:
        # The same column serves as both ID and name. Renaming it to
        # "Entity ID" would leave nothing for the name rename below to find,
        # so copy it into "Entity ID" and let that rename consume the original.
        selected_df = source_df.with_columns(
            pl.col(model.id_column).alias("Entity ID")
        )
    else:
        selected_df = source_df.rename({model.id_column: "Entity ID"})

    selected_df = selected_df.rename({model.name_column: "Entity name"})
    # IDs may be numeric (e.g. the generated row index); normalize to strings.
    selected_df = selected_df.with_columns(pl.col("Entity ID").cast(pl.Utf8))

    return selected_df.select(
        ["Entity ID", "Entity name", *sorted(model.columns)]
    )

42 

43 

def build_attribute_options(matching_dfs: dict[str, pl.DataFrame]) -> list[str]:
    """
    List every matchable attribute across all datasets.

    Each option is the string ``"<column>::<dataset>"``. The "Entity ID" and
    "Entity name" columns are never offered as options.

    :param matching_dfs: Mapping of dataset name to its DataFrame
    :return: All options, sorted alphabetically
    """
    excluded = {"Entity ID", "Entity name"}
    options = [
        f"{column}::{name}"
        for name, frame in matching_dfs.items()
        for column in frame.columns
        if column not in excluded
    ]
    options.sort()
    return options

52 

53 

def build_attributes_list(attr_list: list[AttributeToMatch]) -> dict:
    """
    Build per-dataset column rename maps from the selected attributes.

    Each attribute lists option strings of the form ``"<column>::<dataset>"``
    (the format produced by ``build_attribute_options``). Every listed column
    is mapped to the attribute's label, so that columns from different
    datasets can be aligned under a common name.

    :param attr_list: Attributes to match; each is read with ``.get("label")``
        and ``.get("columns")``. Attributes with no columns are skipped. When
        the label is missing or empty, the column part of the alphabetically
        first option is used as the label.
    :return: A ``defaultdict(dict)`` mapping dataset name to
        ``{original column name: attribute label}``
    :raises ValueError: If an option string contains no ``"::"`` separator
    """
    df_renamed = defaultdict(dict)
    for attr in attr_list:
        columns = attr.get("columns")
        if not columns:
            continue
        # Split on the LAST "::" only: the dataset name is appended as the
        # final suffix, while column names come from user data and may
        # themselves contain "::". Assumes dataset names contain no "::".
        att_name = attr.get("label") or sorted(columns)[0].rsplit("::", 1)[0]
        for option in columns:
            col, dataset = option.rsplit("::", 1)
            # defaultdict(dict) creates the per-dataset map on first access.
            df_renamed[dataset][col] = att_name

    return df_renamed