Coverage for intelligence_toolkit/match_entity_records/prepare_model.py: 100%
36 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
4from collections import defaultdict
6import polars as pl
8from intelligence_toolkit.match_entity_records.classes import (
9 AttributeToMatch,
10 RecordsModel,
11)
def format_model_df(
    model: RecordsModel,
    max_rows: int = 0,
) -> pl.DataFrame:
    """
    Format the dataset for training the model
    :param model: The model to format; its dataframe, id_column, name_column
        and columns attributes are read
    :param max_rows: The maximum number of rows to return; values of 0 or
        less return every row
    :return: A dataframe with "Entity ID" (cast to string) and "Entity name"
        first, followed by the model columns in sorted order; an empty
        dataframe when the model dataframe is empty
    """
    source_df = model.dataframe
    if source_df.is_empty():
        return pl.DataFrame()

    # Reuse the configured id column when there is one; otherwise fall back
    # to a generated row index as the entity identifier.
    if model.id_column:
        identified_df = source_df.rename({model.id_column: "Entity ID"})
    else:
        identified_df = source_df.with_row_index(name="Entity ID")

    formatted_df = (
        identified_df.rename({model.name_column: "Entity name"})
        .with_columns([pl.col("Entity ID").cast(pl.Utf8)])
        .select(["Entity ID", "Entity name", *sorted(model.columns)])
    )

    return formatted_df.head(max_rows) if max_rows > 0 else formatted_df
def build_attribute_options(matching_dfs: dict[str, pl.DataFrame]) -> list[str]:
    """
    List the selectable attributes across all datasets
    :param matching_dfs: The dataframes to inspect, keyed by dataset name
    :return: The sorted "<column>::<dataset>" options, leaving out the
        "Entity ID" and "Entity name" columns
    """
    reserved_columns = ("Entity ID", "Entity name")
    options = [
        f"{column}::{dataset_name}"
        for dataset_name, frame in matching_dfs.items()
        for column in frame.columns
        if column not in reserved_columns
    ]
    options.sort()
    return options
def build_attributes_list(attr_list: list[AttributeToMatch]) -> dict:
    """
    Build the per-dataset column renaming map for the attributes to match
    :param attr_list: The attributes to match; each entry is read through
        .get("label") and .get("columns"), where every column is a
        "<column>::<dataset>" string
    :return: A mapping of dataset name to {column: attribute label}
    :raises ValueError: If a column value has no "::" separator
    """
    df_renamed = defaultdict(dict)
    for attr in attr_list:
        columns = attr.get("columns")
        if not columns:
            # An attribute without selected columns contributes nothing.
            continue

        att_name = attr.get("label")
        if not att_name:
            # Unlabelled attributes take the name of their alphabetically
            # first column (without the dataset suffix).
            att_name = sorted(columns)[0].rsplit("::", 1)[0]

        for val in columns:
            # Split on the LAST separator so a column name that itself
            # contains "::" stays intact instead of breaking the unpacking.
            col, dataset = val.rsplit("::", 1)
            # defaultdict creates the per-dataset mapping on first access.
            df_renamed[dataset][col] = att_name

    return df_renamed