Coverage for intelligence_toolkit/compare_case_groups/build_dataframes.py: 100%

35 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4 

5import polars as pl 

6 

7 

def build_ranked_df(
    temporal_df: pl.DataFrame,
    group_df: pl.DataFrame,
    attribute_df: pl.DataFrame,
    temporal: str,
    groups: list[str],
) -> pl.DataFrame:
    """
    Combine group, attribute and (optionally) temporal statistics into one frame.

    When ``temporal`` is a non-empty column name the rows come from
    ``temporal_df``; otherwise they come from ``attribute_df``. Group
    statistics are attached with a left join on ``groups``, and attribute
    statistics with a left join on ``groups`` plus ``attribute_value``.
    Clashing column names from the right-hand side get an ``"_r"`` suffix.

    Parameters:
        temporal_df (pl.DataFrame): Per-period rows; only read when
            ``temporal`` is non-empty.
        group_df (pl.DataFrame): Group statistics keyed by ``groups``
            (presumably the output of ``build_grouped_df`` — verify with callers).
        attribute_df (pl.DataFrame): Attribute statistics keyed by ``groups``
            and ``attribute_value``.
        temporal (str): Name of the time column, or ``""`` for no time axis.
            When set, the joined frame must contain ``temporal``,
            ``<temporal>_window_rank`` and ``<temporal>_window_delta``.
        groups (list of str): Columns that identify a group.

    Returns:
        pl.DataFrame: The joined frame sorted by ``groups``.
        ``attribute_rank`` and ``group_rank`` are cast to Int32. In the
        temporal case, ``<temporal>_window_rank`` and
        ``<temporal>_window_delta`` are also cast to Int32, and a
        ``<temporal>_window`` copy of the time column is added.
    """
    use_temporal = temporal != ""
    # The left-hand side of the first join decides which rows are reported.
    base_df = temporal_df if use_temporal else attribute_df
    attribute_keys = [*groups, "attribute_value"]

    # NOTE(review): without a temporal column, base_df *is* attribute_df, so the
    # second join brings its non-key columns back a second time with an "_r"
    # suffix — confirm whether those duplicate columns are wanted.
    ranked = (
        base_df.join(group_df, on=groups, how="left", suffix="_r")
        .join(attribute_df, on=attribute_keys, how="left", suffix="_r")
        .sort(by=groups)
    )

    if use_temporal:
        window_col = f"{temporal}_window"
        ranked = ranked.with_columns(
            [
                pl.col(temporal).alias(window_col),
                pl.col(f"{window_col}_rank").cast(pl.Int32),
                pl.col(f"{window_col}_delta").cast(pl.Int32),
            ]
        )

    # Both rank columns are narrowed to Int32 in place.
    return ranked.with_columns(pl.col(["attribute_rank", "group_rank"]).cast(pl.Int32))

38 

39 

def build_grouped_df(main_dataset: pl.DataFrame, groups: list[str]) -> pl.DataFrame:
    """
    Count the records in each group and rank the groups by size.

    Parameters:
        main_dataset (pl.DataFrame): The main dataset; each row is one record.
        groups (list of str): The column names to group by.

    Returns:
        pl.DataFrame: One row per distinct combination of ``groups`` values,
        sorted by ``groups``, with:

        - ``group_count``: the number of records in the group.
        - ``group_rank``: the group's rank by count (1 = largest group; tied
          groups all receive the highest rank of their tie, i.e. ``method="max"``).

    Raises:
        ValueError: If any element of ``groups`` is not a string.
    """
    # Ensure groups is a list of strings
    if not all(isinstance(group, str) for group in groups):
        error_text = "All elements in groups must be strings"
        raise ValueError(error_text)

    # Every row is one record, so a group's size is simply its row count.
    # This is the same result the earlier record_id + melt round-trip
    # produced. It avoids copying the data several times and does not break
    # when a group column happens to be named "record_id", "Attribute" or
    # "Value".
    gdf = main_dataset.group_by(groups).agg(pl.len().alias("group_count"))

    gdf = gdf.with_columns(
        pl.col("group_count").rank(method="max", descending=True).alias("group_rank")
    )

    return gdf.sort(by=groups)

79 

def build_attribute_df(
    filtered_df: pl.DataFrame, groups: list[str], aggregates: str = ""
) -> pl.DataFrame:
    """
    Count attribute values per group and rank the groups within each value.

    The columns named by ``aggregates`` are melted into ``Attribute``/``Value``
    pairs and combined into an ``"<Attribute>:<Value>"`` string. The result
    has one row for every pair of a group present in ``filtered_df`` and an
    attribute value observed anywhere; pairs that never occur get a count of 0.

    Parameters:
        filtered_df (pl.DataFrame): The (already filtered) dataset.
        groups (list of str): Columns that identify a group.
        aggregates: Passed to ``melt`` as ``value_vars``, i.e. the column(s)
            whose values become attributes. NOTE(review): it is annotated as
            ``str``, but callers presumably pass a list of column names. The
            default ``""`` names a column that likely does not exist — confirm.

    Returns:
        pl.DataFrame: Columns ``groups``, ``attribute_value``,
        ``attribute_count`` and ``attribute_rank``. ``attribute_rank`` ranks
        the group's count among all groups for the same attribute value
        (1 = most frequent; ties take the highest rank, ``method="max"``).
    """
    ndf = filtered_df.melt(
        id_vars=groups,
        value_vars=aggregates,
        variable_name="Attribute",
        value_name="Value",
    )
    # Drop rows whose "Value" is null (polars null, not float NaN).
    ndf = ndf.drop_nulls(subset=["Value"])
    # Create the "<Attribute>:<Value>" key.
    ndf = ndf.with_columns(
        (pl.col("Attribute") + ":" + pl.col("Value").cast(str)).alias("attribute_value")
    )

    attribute_keys = [*groups, "attribute_value"]

    # Group by and count the occurrences.
    attributes_df = ndf.group_by(attribute_keys).agg(pl.len().alias("attribute_count"))

    # Ensure every group has an entry for every attribute value.
    # select().unique() keeps the String dtype even when there are no
    # attribute values at all. A Python-list round trip would produce a
    # Null-typed column there, which then fails to join on "attribute_value".
    all_attribute_values = attributes_df.select("attribute_value").unique()
    groups_df = filtered_df.select(groups).unique()
    all_combinations = groups_df.join(all_attribute_values, how="cross")

    attributes_df = all_combinations.join(
        attributes_df, on=attribute_keys, how="left"
    ).with_columns(
        # Only the missing counts should become 0. Filling the whole frame
        # would also overwrite null keys in numeric group columns.
        pl.col("attribute_count").fill_null(0)
    )

    # Calculate the rank of each group within its attribute value.
    return attributes_df.with_columns(
        pl.col("attribute_count")
        .rank("max", descending=True)
        .over("attribute_value")
        .alias("attribute_rank")
    )

119 

120 

def filter_df(main_df: pl.DataFrame, filters: list[str]) -> pl.DataFrame:
    """
    Keep only the rows that match every ``"column:value"`` filter.

    Filters are combined with AND. Each filter is split on its first colon
    only, so values may themselves contain colons (e.g. ``"time:12:30"``).
    The column is compared for equality against the value string.

    Parameters:
        main_df (pl.DataFrame): The data to filter.
        filters (list of str): Filters in ``"column:value"`` form.

    Returns:
        pl.DataFrame: The rows of ``main_df`` that satisfy all filters
        (``main_df`` unchanged when ``filters`` is empty).

    Raises:
        ValueError: If a filter contains no colon.
    """
    for f in filters:
        # maxsplit=1: only the first colon separates the column from the
        # value. A filter with no colon still fails to unpack (ValueError).
        col, val = f.split(":", 1)
        main_df = main_df.filter(pl.col(col) == val)

    return main_df