Coverage for intelligence_toolkit/compare_case_groups/build_dataframes.py: 100%
35 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
5import polars as pl
def build_ranked_df(
    temporal_df: pl.DataFrame,
    group_df: pl.DataFrame,
    attribute_df: pl.DataFrame,
    temporal: str,
    groups: list[str],
) -> pl.DataFrame:
    """
    Assemble the ranked output table by joining group, attribute and (optionally)
    temporal summaries.

    Parameters:
        temporal_df (pl.DataFrame): Per-window summary; used as the base frame only
            when ``temporal`` is non-empty.
        group_df (pl.DataFrame): Group-level summary, left-joined on ``groups``.
        attribute_df (pl.DataFrame): Attribute-level summary, left-joined on
            ``groups`` plus ``attribute_value``.
        temporal (str): Name of the temporal column, or ``""`` to skip the
            temporal view.
        groups (list of str): Column names that identify a group.

    Returns:
        pl.DataFrame: The joined frame sorted by ``groups``, with ``attribute_rank``
        and ``group_rank`` cast to Int32. When ``temporal`` is set, the frame also
        gains a ``{temporal}_window`` copy of the ``temporal`` column, and
        ``{temporal}_window_rank`` / ``{temporal}_window_delta`` are cast to Int32.
        The joined frame must provide all of these columns.
    """
    use_temporal = temporal != ""
    base = temporal_df if use_temporal else attribute_df

    # NOTE(review): without a temporal column the base frame is attribute_df itself,
    # so the second join re-attaches attribute_df and its non-key columns reappear
    # with an "_r" suffix. Kept as-is to preserve the output schema.
    merged = (
        base.join(group_df, on=groups, how="left", suffix="_r")
        .join(attribute_df, on=[*groups, "attribute_value"], how="left", suffix="_r")
        .sort(by=groups)
    )

    if use_temporal:
        merged = merged.with_columns(
            pl.col(temporal).alias(f"{temporal}_window"),
            pl.col(f"{temporal}_window_rank").cast(pl.Int32),
            pl.col(f"{temporal}_window_delta").cast(pl.Int32),
        )

    rank_casts = [
        pl.col("attribute_rank").cast(pl.Int32),
        pl.col("group_rank").cast(pl.Int32),
    ]
    return merged.with_columns(rank_casts)
def build_grouped_df(main_dataset: pl.DataFrame, groups: list[str]) -> pl.DataFrame:
    """
    Count the rows in each group of ``main_dataset`` and rank the groups by size.

    Parameters:
        main_dataset (pl.DataFrame): The main dataset to process.
        groups (list of str): The list of column names to group by.

    Returns:
        pl.DataFrame: One row per distinct combination of ``groups`` values, with a
        ``group_count`` column (number of rows in that group) and a ``group_rank``
        column (1 = largest group; tied groups all receive the largest rank of the
        positions they span, i.e. the "max" ranking method), sorted by ``groups``.

    Raises:
        ValueError: If any element of ``groups`` is not a string.
    """
    # Ensure groups is a list of strings
    if not all(isinstance(group, str) for group in groups):
        error_text = "All elements in groups must be strings"
        raise ValueError(error_text)

    # Counting rows per group directly is all that is needed: no per-row
    # identifier or intermediate reshaping contributes to the group sizes.
    gdf = main_dataset.group_by(groups).agg(pl.len().alias("group_count"))

    gdf = gdf.with_columns(
        pl.col("group_count").rank(method="max", descending=True).alias("group_rank")
    )

    return gdf.sort(by=groups)
def build_attribute_df(
    filtered_df: pl.DataFrame, groups: list[str], aggregates: "str | list[str]" = ""
) -> pl.DataFrame:
    """
    Count each ``attribute:value`` pair per group and rank groups within each pair.

    Parameters:
        filtered_df (pl.DataFrame): Dataset to summarise.
        groups (list of str): Column names that identify a group.
        aggregates (str or list of str): Column name(s) whose values are counted.
            An empty value (the default ``""`` or ``[]``) counts every column that
            is not in ``groups``.

    Returns:
        pl.DataFrame: One row per (group, ``attribute_value``) combination, where
        ``attribute_value`` is ``"<column>:<value>"``. Combinations that never
        occur are included with ``attribute_count`` 0. An ``attribute_rank``
        column ranks groups within each ``attribute_value`` (1 = highest count;
        ties receive the "max" rank).
    """
    # An empty string would be looked up as a column literally named "" and fail;
    # None tells melt to use every column not listed in id_vars.
    value_vars = aggregates if aggregates else None
    ndf = filtered_df.melt(
        id_vars=groups,
        value_vars=value_vars,
        variable_name="Attribute",
        value_name="Value",
    )
    # Drop rows with null values in the "Value" column
    ndf = ndf.drop_nulls(subset=["Value"])
    # Create "attribute_value" column
    ndf = ndf.with_columns(
        (pl.col("Attribute") + ":" + pl.col("Value").cast(str)).alias("attribute_value")
    )

    # Group by and count the occurrences
    attributes_df = ndf.group_by([*groups, "attribute_value"]).agg(
        pl.len().alias("attribute_count")
    )

    # Ensure all groups have entries for all attribute_values. Selecting the
    # unique values as a frame keeps the String dtype even when there are none.
    all_attribute_values = attributes_df.select(pl.col("attribute_value").unique())
    all_combinations = (
        filtered_df.select(groups).unique().join(all_attribute_values, how="cross")
    )

    # Only the count column is zero-filled. Filling the whole frame would also
    # overwrite null keys in integer group columns with 0.
    # NOTE(review): rows whose group key is null do not match in the left join
    # (polars does not join null keys by default), so their count becomes 0.
    attributes_df = all_combinations.join(
        attributes_df, on=[*groups, "attribute_value"], how="left"
    ).with_columns(pl.col("attribute_count").fill_null(0))

    # Calculate the rank
    return attributes_df.with_columns(
        pl.col("attribute_count")
        .rank("max", descending=True)
        .over("attribute_value")
        .alias("attribute_rank")
    )
def filter_df(main_df: pl.DataFrame, filters: list[str]) -> pl.DataFrame:
    """
    Keep only the rows that match every ``"column:value"`` filter.

    Parameters:
        main_df (pl.DataFrame): Dataset to filter.
        filters (list of str): Filters of the form ``"column:value"``. A row is kept
            when ``column`` equals ``value`` (compared as a string literal); all
            filters must hold.

    Returns:
        pl.DataFrame: The rows of ``main_df`` that satisfy every filter.

    Raises:
        ValueError: If a filter contains no ``":"`` separator.
    """
    for f in filters:
        # Split on the first ":" only, so values that themselves contain colons
        # (e.g. "time:12:30") stay intact instead of breaking the unpacking.
        col, val = f.split(":", 1)
        main_df = main_df.filter(pl.col(col) == val)

    return main_df