Coverage for intelligence_toolkit/compare_case_groups/api.py: 0%

99 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4 

5 

6import polars as pl 

7 

8import intelligence_toolkit.AI.utils as utils 

9from intelligence_toolkit.AI.client import OpenAIClient 

10from intelligence_toolkit.compare_case_groups import prompts 

11from intelligence_toolkit.compare_case_groups.build_dataframes import ( 

12 build_attribute_df, 

13 build_grouped_df, 

14 build_ranked_df, 

15 filter_df, 

16) 

17from intelligence_toolkit.compare_case_groups.temporal_process import ( 

18 build_temporal_data, 

19 create_window_df, 

20) 

21from intelligence_toolkit.helpers.classes import IntelligenceWorkflow 

22 

23 

class CompareCaseGroups(IntelligenceWorkflow):
    """Workflow that compares attribute counts and ranks across groups of records.

    ``create_data_summary`` builds the ranked table stored in ``model_df``.
    ``get_report_data`` selects rows from ``model_df``, and
    ``generate_group_report`` sends a selection to an LLM together with the
    text from ``get_summary_description``.
    """

    # Class-level empty defaults. Within this class they are only ever rebound
    # on the instance (by create_data_summary), never mutated in place.
    model_df = pl.DataFrame()
    filtered_df = pl.DataFrame()
    prepared_df = pl.DataFrame()

    # Stringified cell values that denote "missing" and must never be offered
    # as filter options.
    _MISSING_VALUE_STRINGS = frozenset(
        {"", "<NA>", "nan", "NaN", "None", "none", "NULL", "null"}
    )

    def __init__(self):
        # NOTE(review): IntelligenceWorkflow.__init__ is not called here. If the
        # base class initializes state (e.g. ``ai_configuration``, which
        # generate_group_report reads), confirm it is set elsewhere.
        self.filters = []
        self.groups = []
        self.aggregates = []
        self.temporal = ""

    def get_dataset_proportion(self) -> int:
        """Return the filtered share of the prepared dataset as a whole percentage.

        Returns 0 when the prepared dataset is empty.
        """
        initial_row_count = len(self.prepared_df)
        if initial_row_count == 0:
            return 0
        # round() without ndigits returns an int, matching the annotation.
        # round(x, 0) would return a float and render as e.g. "42.0%".
        return round(100 * len(self.filtered_df) / initial_row_count)

    def get_report_groups_filter_options(self) -> list[dict]:
        """Return each distinct combination of group values in ``model_df``.

        Each combination is a dict mapping group column to value. The list is
        sorted by the group values in ``self.groups`` order, with null values
        placed after non-null ones.
        """
        unique_groups = self.model_df.select(self.groups).unique().to_dicts()

        def sort_key(row: dict) -> tuple:
            # (is_null, value) pairs keep the original ordering for non-null
            # values. They also avoid a TypeError from comparing None with a
            # real value: nulls are never compared by value.
            return tuple(
                (row[group] is None, "" if row[group] is None else row[group])
                for group in self.groups
            )

        return sorted(unique_groups, key=sort_key)

    def get_filter_options(self, input_df: pl.DataFrame) -> list[str]:
        """List every ``"column:value"`` pair available for filtering.

        Columns are visited in alphabetical order. Within each column, the
        distinct values are cast to strings and sorted. Nulls and the
        placeholder strings in ``_MISSING_VALUE_STRINGS`` are skipped.
        """
        options = []
        for col in sorted(input_df.columns):
            values = (
                input_df.get_column(col)
                .cast(pl.Utf8)
                .unique()
                # Without this, nulls would be emitted as "col:None".
                .drop_nulls()
                .sort()
            )
            options.extend(
                f"{col}:{value}"
                for value in values
                if value not in self._MISSING_VALUE_STRINGS
            )
        return options

    def _select_columns_ranked_df(self, ranked_df: pl.DataFrame) -> None:
        """Store the reportable columns of ``ranked_df`` in ``model_df``.

        The kept columns are the group columns, the standard count/rank
        columns and, when a temporal column is set, its four
        ``{temporal}_window*`` columns.
        """
        columns = self.groups.copy()
        default_columns = [
            "group_count",
            "group_rank",
            "attribute_value",
            "attribute_count",
            "attribute_rank",
        ]

        columns.extend(default_columns)

        if self.temporal:
            columns.extend(
                [
                    f"{self.temporal}_window",
                    f"{self.temporal}_window_count",
                    f"{self.temporal}_window_rank",
                    f"{self.temporal}_window_delta",
                ]
            )
        self.model_df = ranked_df.select(columns)

    def create_data_summary(
        self,
        prepared_df: pl.DataFrame,
        filters: list[str],
        groups: list[str],
        aggregates: list[str],
        temporal: str = "",
    ):
        """Build the ranked group/attribute table and store it in ``model_df``.

        Args:
            prepared_df: Input records. Rows with a null in any ``groups``
                column are dropped.
            filters: Passed to ``filter_df``. These are presumably
                ``"column:value"`` strings as produced by
                ``get_filter_options``; verify against filter_df.
            groups: Columns that define the groups being compared.
            aggregates: Attribute columns passed to ``build_attribute_df``
                and ``create_window_df``.
            temporal: Optional column for windowed (temporal) analysis. An
                empty string disables it.

        Side effects:
            Sets ``filters``, ``groups``, ``aggregates``, ``temporal``,
            ``prepared_df``, ``filtered_df`` and ``model_df`` on the instance.
        """
        self.filters = filters
        self.groups = groups
        self.aggregates = aggregates
        self.temporal = temporal
        self.prepared_df = prepared_df.drop_nulls(subset=self.groups)

        # Treat empty strings as missing values in every column.
        # NOTE(review): this runs after the drop_nulls above, so rows whose
        # group value is "" survive as null groups. Only the non-temporal
        # branch below filters them out (and only from attributes_df).
        # Confirm this is intended.
        self.model_df = self.prepared_df.with_columns(
            [
                pl.when(pl.col(col) == "").then(None).otherwise(pl.col(col)).alias(col)
                for col in self.prepared_df.columns
            ]
        )

        self.filtered_df = filter_df(self.model_df, filters)

        grouped_df = build_grouped_df(self.filtered_df, groups)

        attributes_df = build_attribute_df(self.filtered_df, groups, aggregates)

        temporal_df = pl.DataFrame()
        if temporal:
            window_df = create_window_df(groups, temporal, aggregates, self.filtered_df)

            # Distinct, non-null temporal values as sorted strings.
            temporal_atts = sorted(
                self.model_df[temporal].cast(pl.Utf8).unique().drop_nulls()
            )

            temporal_df = build_temporal_data(
                window_df, groups, temporal_atts, temporal
            )
        else:
            for group in groups:
                attributes_df = attributes_df.filter(pl.col(group).is_not_null())

        ranked_df = build_ranked_df(
            temporal_df,
            grouped_df,
            attributes_df,
            temporal,
            groups,
        )
        self._select_columns_ranked_df(ranked_df)

    def _format_list(self, items, bold=True, escape_colon=False) -> str:
        """Render ``items`` as a bracketed, comma-separated Markdown list.

        Args:
            items: Strings to render.
            bold: Wrap each item in ``**...**``.
            escape_colon: Replace ``:`` with ``\\:`` in each item.
        """
        formatted_items = []
        for item in items:
            if escape_colon:
                item = item.replace(":", "\\:")
            if bold:
                item = f"**{item}**"
            formatted_items.append(item)
        return "[" + ", ".join(formatted_items) + "]"

    def get_summary_description(self) -> str:
        """Return a Markdown bullet list describing the columns of ``model_df``.

        The text covers the record count (and percentage when filters are
        set), the group/attribute counts and ranks, and the temporal window
        columns when a temporal column is configured.
        """
        groups_text = self._format_list(self.groups)
        filters_text = self._format_list(self.filters, escape_colon=True)

        description_lines = ["This table shows:"]

        if self.filters:
            description_lines.append(
                f"- A summary of **{len(self.filtered_df)}** data records matching {filters_text}, representing **{self.get_dataset_proportion()}%** of the overall dataset with values for all grouping attributes"
            )
        else:
            description_lines.append(
                f"- A summary of all **{len(self.filtered_df)}** data records with values for all grouping attributes"
            )

        description_lines.extend(
            [
                f"- The **group_count** of records for all {groups_text} groups, and corresponding **group_rank**",
                f"- The **attribute_count** of each **attribute_value** for all {groups_text} groups, and corresponding **attribute_rank**",
            ]
        )

        if self.temporal:
            description_lines.extend(
                [
                    f"- The **{self.temporal}_window_count** of each **attribute_value** for each **{self.temporal}_window** for all {groups_text} groups, and corresponding **{self.temporal}_window_rank**",
                    f"- The **{self.temporal}_window_delta**, or change in the **attribute_value_count** for successive **{self.temporal}_window** values, within each {groups_text} group",
                ]
            )

        return "\n".join(description_lines)

    def get_report_data(
        self,
        selected_groups=None,
        top_group_ranks=None,
    ) -> tuple[pl.DataFrame, str]:
        """Select the rows of ``model_df`` to include in a report.

        Args:
            selected_groups: Optional list of ``{column: value}`` dicts
                (presumably rows from ``get_report_groups_filter_options``).
                A row is kept if it matches every pair of at least one dict.
                This takes precedence over ``top_group_ranks``.
            top_group_ranks: Optional maximum ``group_rank`` to keep. It is
                only used when ``selected_groups`` is empty.

        Returns:
            The selected rows and a human-readable description of the
            selection. The description is empty when no selection was applied.
        """
        selected_df = self.model_df

        filter_description = ""
        if selected_groups:
            # OR over groups, AND over the column/value pairs within a group.
            filter_expr = pl.lit(False)
            for group in selected_groups:
                group_expr = pl.lit(True)
                for col, value in group.items():
                    group_expr &= pl.col(col) == value
                filter_expr |= group_expr

            selected_df = self.model_df.filter(filter_expr)
            filter_description = f'Filtered to the following groups only: {", ".join([str(s) for s in selected_groups])}'
        elif top_group_ranks:
            selected_df = selected_df.filter(pl.col("group_rank") <= top_group_ranks)
            filter_description = (
                f"Filtered to the top {top_group_ranks} groups by record count"
            )
        return selected_df, filter_description

    def generate_group_report(
        self,
        report_data: pl.DataFrame,
        filter_description: str = "",
        ai_instructions=prompts.user_prompt,
        callbacks=None,
    ):
        """Ask the configured LLM to write a report about ``report_data``.

        Args:
            report_data: Rows to report on (e.g. from ``get_report_data``).
                They are serialized to CSV for the prompt.
            filter_description: Description of how ``report_data`` was
                selected. Previously this defaulted to the ``str`` type object
                by mistake; it now defaults to an empty string.
            ai_instructions: User-editable instructions placed first in the
                prompt.
            callbacks: Optional callbacks forwarded to
                ``OpenAIClient.generate_chat``. ``None`` means no callbacks.

        Returns:
            Whatever ``OpenAIClient.generate_chat`` returns.
        """
        # Avoid a shared mutable default list.
        if callbacks is None:
            callbacks = []

        variables = {
            "description": self.get_summary_description(),
            "dataset": report_data.write_csv(),
            "filters": filter_description,
        }

        messages = utils.generate_messages(
            ai_instructions,
            prompts.list_prompts["report_prompt"],
            variables,
            prompts.list_prompts["safety_prompt"],
        )
        # NOTE(review): ai_configuration is not set in this class; presumably
        # provided by IntelligenceWorkflow — confirm.
        return OpenAIClient(self.ai_configuration).generate_chat(
            messages, callbacks=callbacks
        )