Coverage for intelligence_toolkit/compare_case_groups/api.py: 0%
99 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
6import polars as pl
8import intelligence_toolkit.AI.utils as utils
9from intelligence_toolkit.AI.client import OpenAIClient
10from intelligence_toolkit.compare_case_groups import prompts
11from intelligence_toolkit.compare_case_groups.build_dataframes import (
12 build_attribute_df,
13 build_grouped_df,
14 build_ranked_df,
15 filter_df,
16)
17from intelligence_toolkit.compare_case_groups.temporal_process import (
18 build_temporal_data,
19 create_window_df,
20)
21from intelligence_toolkit.helpers.classes import IntelligenceWorkflow
class CompareCaseGroups(IntelligenceWorkflow):
    """Summarize and compare attribute-value counts across groups of case records.

    Workflow:
      1. ``create_data_summary`` filters the prepared data, counts records per
         group and attribute values per group (optionally per temporal window),
         and stores the ranked result in ``model_df``.
      2. ``get_report_data`` narrows ``model_df`` to selected or top-ranked groups.
      3. ``generate_group_report`` sends that selection to the OpenAI client and
         returns its response.
    """

    # Class-level empty placeholders. create_data_summary assigns new frames
    # on the instance (it never mutates these shared objects).
    model_df = pl.DataFrame()
    filtered_df = pl.DataFrame()
    prepared_df = pl.DataFrame()

    # String forms of "missing" values excluded from filter options; compared
    # against column values after they are cast to Utf8.
    _MISSING_VALUE_STRINGS = frozenset(
        {"", "<NA>", "nan", "NaN", "None", "none", "NULL", "null"}
    )

    def __init__(self):
        # Parameters of the most recent create_data_summary call.
        self.filters = []
        self.groups = []
        self.aggregates = []
        self.temporal = ""

    def get_dataset_proportion(self) -> int:
        """Return filtered rows as a whole-number percentage of prepared rows.

        Returns 0 when ``prepared_df`` is empty (avoids division by zero).
        """
        initial_row_count = len(self.prepared_df)
        if initial_row_count == 0:
            return 0
        # round() without ndigits returns an int, matching the annotation;
        # round(x, 0) would return a float such as 50.0.
        return round(100 * len(self.filtered_df) / initial_row_count)

    def get_report_groups_filter_options(self) -> list[dict]:
        """Return each distinct combination of group values in ``model_df``.

        Each option is a ``{group_column: value}`` dict, sorted by the group
        columns in ``self.groups`` order. Missing (None) values sort last.
        """
        unique_groups = self.model_df.select(self.groups).unique().to_dicts()

        def sort_key(row):
            # Pair each value with an "is None" flag so None is never compared
            # directly against a real value (which would raise TypeError).
            return tuple((row[group] is None, row[group]) for group in self.groups)

        return sorted(unique_groups, key=sort_key)

    def get_filter_options(self, input_df: pl.DataFrame) -> list[str]:
        """List every "column:value" pair available for filtering.

        Columns are visited in sorted name order; within each column, the
        distinct values (cast to string) are sorted. Nulls and common
        string spellings of missing values are excluded.

        Args:
            input_df: Data whose columns and values become filter options.

        Returns:
            Strings of the form ``f"{column}:{value}"``.
        """
        sorted_atts = []
        for col in sorted(input_df.columns):
            unique_sorted_values = (
                input_df.with_columns(pl.col(col).cast(pl.Utf8))
                .select(pl.col(col).unique())
                .to_series()
                .sort()
            )
            # Null entries iterate as Python None; skip them explicitly,
            # otherwise they would be rendered as "col:None".
            sorted_atts.extend(
                f"{col}:{value}"
                for value in unique_sorted_values
                if value is not None and value not in self._MISSING_VALUE_STRINGS
            )
        return sorted_atts

    def _select_columns_ranked_df(self, ranked_df: pl.DataFrame) -> None:
        """Store the group, count/rank and (optional) temporal columns of
        ``ranked_df`` in ``self.model_df``."""
        columns = self.groups.copy()
        columns.extend(
            [
                "group_count",
                "group_rank",
                "attribute_value",
                "attribute_count",
                "attribute_rank",
            ]
        )

        if self.temporal:
            columns.extend(
                [
                    f"{self.temporal}_window",
                    f"{self.temporal}_window_count",
                    f"{self.temporal}_window_rank",
                    f"{self.temporal}_window_delta",
                ]
            )
        self.model_df = ranked_df.select(columns)

    def create_data_summary(
        self,
        prepared_df: pl.DataFrame,
        filters: list[str],
        groups: list[str],
        aggregates: list[str],
        temporal: str = "",
    ):
        """Build the ranked group/attribute summary and store it in ``model_df``.

        Args:
            prepared_df: Input records. Rows with a null in any ``groups``
                column are dropped before being stored as ``prepared_df``.
            filters: Filter specifications forwarded to ``filter_df``
                (presumably "column:value" strings as produced by
                ``get_filter_options`` -- confirm against ``filter_df``).
            groups: Columns whose value combinations define the groups.
            aggregates: Attribute columns whose values are counted per group.
            temporal: Optional column used to build time windows; "" disables
                the temporal breakdown.
        """
        self.filters = filters
        self.groups = groups
        self.aggregates = aggregates
        self.temporal = temporal
        self.prepared_df = prepared_df.drop_nulls(subset=groups)

        # Replace empty-string cells with nulls in every column.
        self.model_df = self.prepared_df.with_columns(
            [
                pl.when(pl.col(col) == "").then(None).otherwise(pl.col(col)).alias(col)
                for col in self.prepared_df.columns
            ]
        )

        self.filtered_df = filter_df(self.model_df, filters)

        grouped_df = build_grouped_df(self.filtered_df, groups)
        attributes_df = build_attribute_df(self.filtered_df, groups, aggregates)

        temporal_df = pl.DataFrame()
        if temporal:
            window_df = create_window_df(groups, temporal, aggregates, self.filtered_df)
            # Window labels are taken from the unfiltered model_df, not from
            # filtered_df.
            temporal_atts = sorted(
                self.model_df[temporal].cast(pl.Utf8).unique().drop_nulls()
            )
            temporal_df = build_temporal_data(
                window_df, groups, temporal_atts, temporal
            )
        else:
            for group in groups:
                attributes_df = attributes_df.filter(pl.col(group).is_not_null())

        ranked_df = build_ranked_df(
            temporal_df,
            grouped_df,
            attributes_df,
            temporal,
            groups,
        )
        self._select_columns_ranked_df(ranked_df)

    def _format_list(self, items, bold=True, escape_colon=False) -> str:
        """Render ``items`` as a bracketed, comma-separated markdown list.

        Args:
            items: Strings to render.
            bold: Wrap each item in ``**`` markers.
            escape_colon: Backslash-escape ":" (used for "column:value" filters).
        """
        formatted_items = []
        for item in items:
            if escape_colon:
                item = item.replace(":", "\\:")
            if bold:
                item = f"**{item}**"
            formatted_items.append(item)
        return "[" + ", ".join(formatted_items) + "]"

    def get_summary_description(self) -> str:
        """Return a markdown bullet list describing the columns of ``model_df``
        and the filters/groups used to build it."""
        groups_text = self._format_list(self.groups)
        filters_text = self._format_list(self.filters, escape_colon=True)

        description_lines = ["This table shows:"]

        if self.filters:
            description_lines.append(
                f"- A summary of **{len(self.filtered_df)}** data records matching {filters_text}, representing **{self.get_dataset_proportion()}%** of the overall dataset with values for all grouping attributes"
            )
        else:
            description_lines.append(
                f"- A summary of all **{len(self.filtered_df)}** data records with values for all grouping attributes"
            )

        description_lines.extend(
            [
                f"- The **group_count** of records for all {groups_text} groups, and corresponding **group_rank**",
                f"- The **attribute_count** of each **attribute_value** for all {groups_text} groups, and corresponding **attribute_rank**",
            ]
        )

        if self.temporal:
            # NOTE(review): "attribute_value_count" is not a column selected in
            # _select_columns_ranked_df; the delta may refer to
            # "{temporal}_window_count" -- confirm before changing the text.
            description_lines.extend(
                [
                    f"- The **{self.temporal}_window_count** of each **attribute_value** for each **{self.temporal}_window** for all {groups_text} groups, and corresponding **{self.temporal}_window_rank**",
                    f"- The **{self.temporal}_window_delta**, or change in the **attribute_value_count** for successive **{self.temporal}_window** values, within each {groups_text} group",
                ]
            )

        return "\n".join(description_lines)

    def get_report_data(
        self,
        selected_groups=None,
        top_group_ranks=None,
    ) -> tuple[pl.DataFrame, str]:
        """Select the rows of ``model_df`` to include in a report.

        Args:
            selected_groups: Optional list of ``{group_column: value}`` dicts
                (as returned by ``get_report_groups_filter_options``). A row is
                kept when it matches any one dict on all of that dict's
                columns. Takes precedence over ``top_group_ranks``.
            top_group_ranks: Optional maximum ``group_rank`` to keep.

        Returns:
            The selected rows and a human-readable description of the
            selection ("" when no selection was applied).
        """
        selected_df = self.model_df

        filter_description = ""
        if selected_groups:
            filter_expr = pl.lit(False)
            for group in selected_groups:
                group_expr = pl.lit(True)
                for col, value in group.items():
                    # `col == None` evaluates to null and never matches, so a
                    # missing group value must be tested with is_null().
                    group_expr &= (
                        pl.col(col).is_null() if value is None else pl.col(col) == value
                    )
                filter_expr |= group_expr

            selected_df = self.model_df.filter(filter_expr)
            filter_description = f'Filtered to the following groups only: {", ".join([str(s) for s in selected_groups])}'
        elif top_group_ranks:
            selected_df = selected_df.filter(pl.col("group_rank") <= top_group_ranks)
            filter_description = (
                f"Filtered to the top {top_group_ranks} groups by record count"
            )
        return selected_df, filter_description

    def generate_group_report(
        self,
        report_data: pl.DataFrame,
        filter_description: str = "",
        ai_instructions=prompts.user_prompt,
        callbacks=None,
    ):
        """Ask the configured OpenAI model for a narrative report on ``report_data``.

        Args:
            report_data: Rows to report on (typically from ``get_report_data``);
                serialized to CSV for the prompt.
            filter_description: Description of how ``report_data`` was selected.
                Defaults to "" (previously the ``str`` type object itself).
            ai_instructions: User prompt template passed to
                ``utils.generate_messages``.
            callbacks: Optional callbacks forwarded to ``generate_chat``;
                ``None`` means an empty list.

        Returns:
            Whatever ``OpenAIClient.generate_chat`` returns (not visible here).
        """
        # Fresh list per call instead of a shared mutable default.
        if callbacks is None:
            callbacks = []

        variables = {
            "description": self.get_summary_description(),
            "dataset": report_data.write_csv(),
            "filters": filter_description,
        }

        messages = utils.generate_messages(
            ai_instructions,
            prompts.list_prompts["report_prompt"],
            variables,
            prompts.list_prompts["safety_prompt"],
        )
        # NOTE(review): self.ai_configuration is not set in this class;
        # presumably provided via IntelligenceWorkflow -- confirm.
        return OpenAIClient(self.ai_configuration).generate_chat(
            messages, callbacks=callbacks
        )