Coverage for intelligence_toolkit/detect_case_patterns/api.py: 0%
55 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
4import altair as alt
5import pandas as pd
7import intelligence_toolkit.AI.utils as utils
8import intelligence_toolkit.detect_case_patterns.model as model
9import intelligence_toolkit.detect_case_patterns.prompts as prompts
10import intelligence_toolkit.graph.graph_fusion_encoder_embedding as gfee
11from intelligence_toolkit.AI.client import OpenAIClient
12from intelligence_toolkit.helpers.classes import IntelligenceWorkflow
class DetectCasePatterns(IntelligenceWorkflow):
    """Workflow facade over the ``detect_case_patterns`` model functions.

    The methods are thin wrappers that store intermediate results on the
    instance so later steps can reuse them. Judging by which attributes each
    method reads, the intended call order is::

        generate_graph_model -> generate_embedding_model -> detect_patterns
        -> create_time_series_df -> create_time_series_chart / explain_pattern
    """

    # Titles longer than this many characters are split onto two lines when
    # ``resize_title`` is requested in ``create_time_series_chart``.
    _TITLE_WRAP_THRESHOLD = 100

    def __init__(self):
        """Declare every piece of state the workflow methods read or write."""
        # NOTE(review): the base-class initializer is not invoked here (same as
        # before this change); ``explain_pattern`` reads ``self.ai_configuration``,
        # which is presumably provided by ``IntelligenceWorkflow`` -- confirm.
        self.input_df = pd.DataFrame()
        self.dynamic_graph_df = pd.DataFrame()
        self.detect_patterns_df = pd.DataFrame()
        self.patterns_df = pd.DataFrame()
        self.time_series_df = pd.DataFrame()
        self.period_to_graph = {}
        self.node_to_label = {}
        self.period_col = ""
        self.type_val_sep = ":"
        # Populated by generate_embedding_model().
        self.node_to_period_to_pos = None
        self.node_to_period_to_shift = None
        # Populated by detect_patterns().
        self.close_pairs = None
        self.all_pairs = None
        self.min_pattern_count = None
        self.max_pattern_length = None

    def generate_graph_model(
        self,
        df: pd.DataFrame,
        period_col: str,
        type_val_sep: str = ":",
        min_edge_weight: float = 0.001,
        missing_edge_prop: float = 0.1,
    ) -> None:
        """Build the dynamic graph table from ``df`` and prepare the graphs.

        Stores ``df``, ``period_col`` and ``type_val_sep`` on the instance,
        delegates to ``model.generate_graph_model`` and then to
        ``_prepare_graph`` with the two edge parameters.

        Args:
            df: Input data; kept as ``self.input_df`` for later attribute counts.
            period_col: Name of the period column, forwarded to the model.
            type_val_sep: Separator forwarded to the model functions.
            min_edge_weight: Forwarded to ``model.prepare_graph``.
            missing_edge_prop: Forwarded to ``model.prepare_graph``.
        """
        self.input_df = df
        self.period_col = period_col
        self.type_val_sep = type_val_sep
        self.dynamic_graph_df = model.generate_graph_model(
            df, period_col, type_val_sep
        )
        self._prepare_graph(min_edge_weight, missing_edge_prop)

    def _prepare_graph(
        self,
        min_edge_weight: float,
        missing_edge_prop: float,
    ) -> None:
        """Run ``model.prepare_graph`` on the stored dynamic graph table.

        The two returned values are stored as ``self.detect_patterns_df`` and
        ``self.period_to_graph``.
        """
        self.detect_patterns_df, self.period_to_graph = model.prepare_graph(
            self.dynamic_graph_df, min_edge_weight, missing_edge_prop
        )

    def generate_embedding_model(self) -> None:
        """Compute graph-fusion-encoder embeddings for the prepared graphs.

        Each "Full Attribute" is labelled with an integer code derived from its
        "Attribute Type"; the labels and ``self.period_to_graph`` are then
        passed to ``gfee.generate_graph_fusion_encoder_embedding``. The two
        returned values are stored as ``self.node_to_period_to_pos`` and
        ``self.node_to_period_to_shift``.
        """
        node_to_label_str = dict(
            self.dynamic_graph_df[["Full Attribute", "Attribute Type"]].values
        )
        # Convert string labels to int labels; sorting keeps the codes
        # deterministic for a given set of attribute types.
        sorted_labels = sorted(set(node_to_label_str.values()))
        label_to_code = {label: code for code, label in enumerate(sorted_labels)}
        # Each node maps to {0: code}; the embedding call below uses max_level=0.
        self.node_to_label = {
            node: {0: label_to_code[label]}
            for node, label in node_to_label_str.items()
        }
        self.node_to_period_to_pos, self.node_to_period_to_shift = (
            gfee.generate_graph_fusion_encoder_embedding(
                self.period_to_graph,
                self.node_to_label,
                correlation=True,
                diaga=True,
                laplacian=True,
                max_level=0,
            )
        )

    def detect_patterns(self, min_pattern_count, max_pattern_length) -> None:
        """Run ``model.detect_patterns`` with the stored embeddings.

        Both arguments are remembered on the instance; the three returned
        values are stored as ``self.patterns_df``, ``self.close_pairs`` and
        ``self.all_pairs``.
        """
        self.min_pattern_count = min_pattern_count
        self.max_pattern_length = max_pattern_length
        (self.patterns_df, self.close_pairs, self.all_pairs) = model.detect_patterns(
            self.node_to_period_to_pos,
            self.dynamic_graph_df,
            self.type_val_sep,
            self.min_pattern_count,
            self.max_pattern_length,
        )

    def create_time_series_df(self) -> None:
        """Store ``model.create_time_series_df(...)`` as ``self.time_series_df``."""
        self.time_series_df = model.create_time_series_df(
            self.dynamic_graph_df, self.patterns_df
        )

    def compute_attribute_counts(self, selected_pattern, selected_pattern_period):
        """Return ``model.compute_attribute_counts`` for one pattern and period.

        The stored input data, period column and separator are forwarded
        alongside the two arguments.
        """
        return model.compute_attribute_counts(
            df=self.input_df,
            pattern=selected_pattern,
            period_col=self.period_col,
            period=selected_pattern_period,
            type_val_sep=self.type_val_sep,
        )

    @staticmethod
    def _split_title(title: str):
        """Split ``title`` into two stripped lines.

        The break is placed after the last '&' found in the first half of the
        title; if there is none, the title is cut at its midpoint.
        """
        midpoint = len(title) // 2
        split_index = title.rfind("&", 0, midpoint)
        # Keep the '&' on the first line when one was found.
        cut = split_index + 1 if split_index != -1 else midpoint
        return [title[:cut].strip(), title[cut:].strip()]

    def create_time_series_chart(
        self,
        selected_pattern,
        selected_pattern_period,
        resize_title: bool = False,
    ) -> alt.Chart:
        """Return a line chart of ``count`` per ``period`` for one pattern.

        Args:
            selected_pattern: Value matched against the "pattern" column of
                ``self.time_series_df``.
            selected_pattern_period: Period label shown in the chart title.
            resize_title: When true and the title exceeds
                ``_TITLE_WRAP_THRESHOLD`` characters, the title is passed to
                the chart as a two-element list instead of a single string.
        """
        selected_pattern_df = self.time_series_df[
            self.time_series_df["pattern"] == selected_pattern
        ]
        # f-string rather than '+' so a non-string period label (e.g. an int)
        # does not raise TypeError; output is identical for string inputs.
        title = f"Pattern: {selected_pattern} ({selected_pattern_period})"
        if resize_title and len(title) > self._TITLE_WRAP_THRESHOLD:
            title = self._split_title(title)

        return (
            alt.Chart(selected_pattern_df)
            .mark_line()
            .encode(x="period:O", y="count:Q", color=alt.ColorValue("blue"))
            .properties(title=title, height=220, width=600)
        )

    def explain_pattern(
        self,
        selected_pattern,
        selected_pattern_period,
        attribute_counts=None,
        ai_instructions=prompts.user_prompt,
        callbacks=None,
    ):
        """Ask the chat model to report on one pattern.

        Args:
            selected_pattern: Pattern to explain.
            selected_pattern_period: Period of interest.
            attribute_counts: Optional precomputed counts; when ``None`` they
                are obtained from ``compute_attribute_counts``. The value must
                support ``to_csv(index=False)``.
            ai_instructions: Instructions passed to ``utils.generate_messages``.
            callbacks: Optional callbacks forwarded to ``generate_chat``;
                ``None`` is treated as an empty list.

        Returns:
            Whatever ``OpenAIClient.generate_chat`` returns for the messages.
        """
        # A fresh list per call avoids sharing one mutable default between calls.
        if callbacks is None:
            callbacks = []
        if attribute_counts is None:
            attribute_counts = self.compute_attribute_counts(
                selected_pattern, selected_pattern_period
            )
        pattern_rows = self.time_series_df[
            self.time_series_df["pattern"] == selected_pattern
        ]
        variables = {
            "pattern": selected_pattern,
            "period": selected_pattern_period,
            "time_series": pattern_rows.to_csv(index=False),
            "attribute_counts": attribute_counts.to_csv(index=False),
        }
        messages = utils.generate_messages(
            ai_instructions,
            prompts.list_prompts["report_prompt"],
            variables,
            prompts.list_prompts["safety_prompt"],
        )
        return OpenAIClient(self.ai_configuration).generate_chat(
            messages, callbacks=callbacks
        )