Coverage for intelligence_toolkit/detect_case_patterns/api.py: 0%

55 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3 

4import altair as alt 

5import pandas as pd 

6 

7import intelligence_toolkit.AI.utils as utils 

8import intelligence_toolkit.detect_case_patterns.model as model 

9import intelligence_toolkit.detect_case_patterns.prompts as prompts 

10import intelligence_toolkit.graph.graph_fusion_encoder_embedding as gfee 

11from intelligence_toolkit.AI.client import OpenAIClient 

12from intelligence_toolkit.helpers.classes import IntelligenceWorkflow 

13 

14 

class DetectCasePatterns(IntelligenceWorkflow):
    """Stateful API for the detect-case-patterns workflow.

    Each step delegates to the ``model`` module (or the graph-fusion encoder)
    and stores its result on the instance for later steps to read. Judging by
    the attributes each method reads, the steps are meant to be called in this
    order:

    1. ``generate_graph_model``
    2. ``generate_embedding_model``
    3. ``detect_patterns``
    4. ``create_time_series_df``
    5. ``compute_attribute_counts`` / ``create_time_series_chart`` /
       ``explain_pattern``
    """

    def __init__(self):
        # NOTE(review): the base-class initializer is not invoked here, and
        # ``explain_pattern`` reads ``self.ai_configuration``, which this class
        # never assigns -- presumably it is provided via IntelligenceWorkflow;
        # confirm against that class.
        self.dynamic_graph_df = pd.DataFrame()
        self.detect_patterns_df = pd.DataFrame()
        self.patterns_df = pd.DataFrame()
        self.period_to_graph = {}
        self.node_to_label = {}
        self.period_col = ""
        self.type_val_sep = ":"

    def generate_graph_model(
        self,
        df,
        period_col,
        type_val_sep=":",
        min_edge_weight=0.001,
        missing_edge_prop=0.1,
    ):
        """Build the dynamic graph model from ``df`` and prepare the graphs.

        Stores ``df``, ``period_col`` and ``type_val_sep`` on the instance,
        saves the result of ``model.generate_graph_model`` as
        ``self.dynamic_graph_df``, then calls ``_prepare_graph``.

        Args:
            df: Input data; kept as ``self.input_df`` for later use by
                ``compute_attribute_counts``.
            period_col: Name of the period column in ``df``.
            type_val_sep: Separator forwarded to the model functions.
            min_edge_weight: Forwarded to ``model.prepare_graph``.
            missing_edge_prop: Forwarded to ``model.prepare_graph``.
        """
        self.input_df = df
        self.period_col = period_col
        self.type_val_sep = type_val_sep
        self.dynamic_graph_df = model.generate_graph_model(
            df, period_col, type_val_sep
        )
        self._prepare_graph(min_edge_weight, missing_edge_prop)

    def _prepare_graph(self, min_edge_weight, missing_edge_prop):
        """Populate ``detect_patterns_df`` and ``period_to_graph``.

        Both values come from ``model.prepare_graph`` applied to
        ``self.dynamic_graph_df``.
        """
        self.detect_patterns_df, self.period_to_graph = model.prepare_graph(
            self.dynamic_graph_df, min_edge_weight, missing_edge_prop
        )

    def generate_embedding_model(self):
        """Compute node embeddings with the graph-fusion encoder.

        Reads ``self.dynamic_graph_df`` and ``self.period_to_graph``; sets
        ``self.node_to_label``, ``self.node_to_period_to_pos`` and
        ``self.node_to_period_to_shift``.
        """
        # Map each "Full Attribute" value to its "Attribute Type" string.
        node_to_label_str = dict(
            self.dynamic_graph_df[["Full Attribute", "Attribute Type"]].values
        )
        # Convert string labels to int codes; sorting keeps the codes stable
        # for a given set of labels.
        sorted_labels = sorted(set(node_to_label_str.values()))
        label_to_code = {label: code for code, label in enumerate(sorted_labels)}
        # Each node's label is wrapped in a dict under key 0 -- presumably the
        # label level, in line with ``max_level=0`` below; confirm against gfee.
        self.node_to_label = {
            node: {0: label_to_code[label]}
            for node, label in node_to_label_str.items()
        }
        self.node_to_period_to_pos, self.node_to_period_to_shift = (
            gfee.generate_graph_fusion_encoder_embedding(
                self.period_to_graph,
                self.node_to_label,
                correlation=True,
                diaga=True,
                laplacian=True,
                max_level=0,
            )
        )

    def detect_patterns(self, min_pattern_count, max_pattern_length):
        """Run pattern detection and store the results on the instance.

        Records both arguments as attributes and sets ``self.patterns_df``,
        ``self.close_pairs`` and ``self.all_pairs`` from
        ``model.detect_patterns``. Requires ``generate_embedding_model`` to
        have been called, since it reads ``self.node_to_period_to_pos``.
        """
        self.min_pattern_count = min_pattern_count
        self.max_pattern_length = max_pattern_length
        (self.patterns_df, self.close_pairs, self.all_pairs) = model.detect_patterns(
            self.node_to_period_to_pos,
            self.dynamic_graph_df,
            self.type_val_sep,
            self.min_pattern_count,
            self.max_pattern_length,
        )

    def create_time_series_df(self):
        """Set ``self.time_series_df`` from the graph model and patterns."""
        self.time_series_df = model.create_time_series_df(
            self.dynamic_graph_df, self.patterns_df
        )

    def compute_attribute_counts(self, selected_pattern, selected_pattern_period):
        """Return ``model.compute_attribute_counts`` for a pattern and period.

        Uses the input data, period column and separator stored by
        ``generate_graph_model``.
        """
        return model.compute_attribute_counts(
            df=self.input_df,
            pattern=selected_pattern,
            period_col=self.period_col,
            period=selected_pattern_period,
            type_val_sep=self.type_val_sep,
        )

    def create_time_series_chart(
        self,
        selected_pattern,
        selected_pattern_period,
        resize_title=False,
    ):
        """Return an Altair line chart of the selected pattern's time series.

        Args:
            selected_pattern: Value matched against the ``"pattern"`` column of
                ``self.time_series_df``.
            selected_pattern_period: Period label; used only in the title.
            resize_title: When true and the title exceeds 100 characters, the
                title is split into two lines.

        Returns:
            The chart, plotting ``count`` against ``period``.
        """
        selected_pattern_df = self.time_series_df[
            self.time_series_df["pattern"] == selected_pattern
        ]
        # f-string formatting also tolerates non-string period labels.
        title = f"Pattern: {selected_pattern} ({selected_pattern_period})"
        if resize_title and len(title) > 100:
            midpoint = len(title) // 2
            # Prefer breaking right after the last '&' in the first half of
            # the title; otherwise break at the midpoint.
            ampersand_index = title.rfind("&", 0, midpoint)
            cut = ampersand_index + 1 if ampersand_index != -1 else midpoint
            # A list of strings is rendered by Altair as a multi-line title.
            title = [title[:cut].strip(), title[cut:].strip()]

        return (
            alt.Chart(selected_pattern_df)
            .mark_line()
            .encode(x="period:O", y="count:Q", color=alt.ColorValue("blue"))
            .properties(title=title, height=220, width=600)
        )

    def explain_pattern(
        self,
        selected_pattern,
        selected_pattern_period,
        attribute_counts=None,
        ai_instructions=prompts.user_prompt,
        callbacks=None,
    ):
        """Generate an AI explanation of the selected pattern.

        Args:
            selected_pattern: Pattern to explain.
            selected_pattern_period: Period of interest for the pattern.
            attribute_counts: Attribute counts for the pattern; computed with
                ``compute_attribute_counts`` when omitted.
            ai_instructions: Instructions passed to
                ``utils.generate_messages``.
            callbacks: Callbacks forwarded to ``generate_chat``; an empty list
                is used when omitted.

        Returns:
            Whatever ``OpenAIClient.generate_chat`` returns for the generated
            messages.
        """
        # A fresh list per call avoids sharing one mutable default between
        # calls and instances.
        if callbacks is None:
            callbacks = []
        if attribute_counts is None:
            attribute_counts = self.compute_attribute_counts(
                selected_pattern, selected_pattern_period
            )
        pattern_series = self.time_series_df[
            self.time_series_df["pattern"] == selected_pattern
        ]
        variables = {
            "pattern": selected_pattern,
            "period": selected_pattern_period,
            "time_series": pattern_series.to_csv(index=False),
            "attribute_counts": attribute_counts.to_csv(index=False),
        }
        messages = utils.generate_messages(
            ai_instructions,
            prompts.list_prompts["report_prompt"],
            variables,
            prompts.list_prompts["safety_prompt"],
        )
        return OpenAIClient(self.ai_configuration).generate_chat(
            messages, callbacks=callbacks
        )

158 messages, callbacks=callbacks 

159 )