Coverage for intelligence_toolkit/anonymize_case_data/queries.py: 100%

132 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3 

4from collections import defaultdict 

5from typing import Any 

6 

7import pandas as pd 

8 

9 

def get_data_schema(sdf) -> dict[str, list[str]]:
    """Map each column of ``sdf`` to the sorted list of its distinct values.

    Every distinct value is converted with ``str``; values whose string form
    is empty are dropped. A column with no remaining values still receives an
    (empty) entry, so every column name appears as a key.

    Args:
        sdf: DataFrame to summarise, or ``None``.

    Returns:
        A ``defaultdict(list)`` keyed by column name. It is empty when ``sdf``
        is ``None``.
    """
    data_schema = defaultdict(list)
    if sdf is None:
        return data_schema
    for att in sdf.columns.values:
        # Sorting happens on the string form, so ordering is lexicographic.
        data_schema[att] = sorted(
            text for text in map(str, sdf[att].unique()) if text
        )
    return data_schema

19 

def compute_aggregate_graph(
    adf,
    filters,
    source_attribute,
    target_attribute,
    highlight_attribute,
    att_separator=";",
    val_separator=":",
) -> pd.DataFrame:
    """Build source/target edge rows from the aggregate counts in ``adf``.

    Each row of ``adf`` has a ``selections`` string (attribute/value pairs
    joined by ``att_separator``, each pair written as
    ``attribute<val_separator>value``) and a ``protected_count``. A row is
    used only when its selections contain both edge attributes and every
    entry of ``filters``. What is left over besides those decides its role:

    * nothing left over: the row supplies the edge's count;
    * exactly ``highlight_attribute`` left over: it supplies the edge's
      highlight count;
    * anything else: the row is ignored.

    Args:
        adf: DataFrame with ``selections`` and ``protected_count`` columns.
        filters: Selections (``attribute<val_separator>value`` strings) that
            a row must contain.
        source_attribute: Attribute whose value becomes the edge source.
        target_attribute: Attribute whose value becomes the edge target.
        highlight_attribute: A full selection string to highlight, or ``None``
            for no highlight.
        att_separator: Separator between selections in a ``selections`` string.
        val_separator: Separator between attribute and value in a selection.

    Returns:
        DataFrame with columns ``Source``, ``Target``, ``Count``,
        ``Highlight``, ``Proportion`` (highlight / count) and ``Dataset``
        (always ``"Aggregate"``). Only edges with a positive count appear.
    """
    edge_atts = {source_attribute, target_attribute}
    required = set(filters)
    edge_counts = {}
    edge_highlights = {}

    for _, row in adf.iterrows():
        selections = row["selections"].split(att_separator)
        # Split on the FIRST separator only, so a value that itself contains
        # the separator (e.g. "time:12:30") is kept whole.
        parsed = [selection.partition(val_separator) for selection in selections]

        if not edge_atts <= {att for att, _, _ in parsed}:
            continue  # record lacks one of the edge attributes
        if not required <= set(selections):
            continue  # record does not satisfy every filter

        source_val = None
        target_val = None
        remaining = []
        for selection, (att, _, val) in zip(selections, parsed):
            if att == source_attribute:
                source_val = val
            if att == target_attribute:
                target_val = val
            if (
                att not in edge_atts
                and selection not in required
                and selection not in remaining
            ):
                remaining.append(selection)

        edge = (source_val, target_val)
        if not remaining:
            edge_counts[edge] = row["protected_count"]
        elif highlight_attribute is not None and remaining == [highlight_attribute]:
            edge_highlights[edge] = row["protected_count"]

    edges = []
    for edge, count in edge_counts.items():
        if count > 0:
            highlight = edge_highlights.get(edge, 0)
            edges.append(
                [edge[0], edge[1], count, highlight, highlight / count, "Aggregate"]
            )

    return pd.DataFrame(
        edges,
        columns=["Source", "Target", "Count", "Highlight", "Proportion", "Dataset"],
    )

86 

87 

def compute_synthetic_graph(
    sdf,
    filters,
    source_attribute,
    target_attribute,
    highlight_attribute,
    att_separator=";",
    val_separator=":",
) -> pd.DataFrame:
    """Build source/target edge rows by counting records in ``sdf``.

    ``filters`` are ``attribute<val_separator>value`` strings. Filters on the
    same attribute are alternatives (a record may match any of them), while
    different attributes must all match. For every pair of non-empty source
    and target values, the matching records are counted, along with how many
    of them also match ``highlight_attribute``.

    Args:
        sdf: DataFrame of records.
        filters: Iterable of ``attribute<val_separator>value`` strings.
        source_attribute: Column supplying edge sources.
        target_attribute: Column supplying edge targets.
        highlight_attribute: ``attribute<val_separator>value`` string to
            highlight. ``""`` or ``None`` means no highlight.
        att_separator: Unused here; kept so the signature parallels
            ``compute_aggregate_graph``.
        val_separator: Separator between attribute and value.

    Returns:
        DataFrame with columns ``Source``, ``Target``, ``Count``,
        ``Highlight``, ``Proportion`` (highlight / count) and ``Dataset``
        (always ``"Synthetic"``). Only pairs with at least one record appear.

    Raises:
        ValueError: If a filter or a non-empty ``highlight_attribute`` does
            not contain ``val_separator``.
    """
    att_groups = {}
    for selection in filters:
        # maxsplit=1 keeps values that themselves contain the separator.
        att, val = selection.split(val_separator, 1)
        att_groups.setdefault(att, []).append(val)

    highlight_att = None
    highlight_val = None
    if highlight_attribute:  # treats both "" and None as "no highlight"
        highlight_att, highlight_val = highlight_attribute.split(val_separator, 1)

    # The filters do not depend on the source/target pair, so apply them once.
    # Boolean indexing returns new frames; ``sdf`` itself is never modified.
    filtered = sdf
    for att, vals in att_groups.items():
        filtered = filtered[filtered[att].isin(vals)]

    sources = [s for s in sdf[source_attribute].unique() if len(str(s)) > 0]
    targets = [t for t in sdf[target_attribute].unique() if len(str(t)) > 0]

    edges = []
    for source in sources:
        by_source = filtered[filtered[source_attribute] == source]
        if len(by_source) == 0:
            continue
        for target in targets:
            pair = by_source[by_source[target_attribute] == target]
            count = len(pair)
            if count == 0:
                continue
            highlight = 0
            if highlight_att is not None:
                highlight = len(pair[pair[highlight_att] == highlight_val])
            edges.append(
                [source, target, count, highlight, highlight / count, "Synthetic"]
            )

    return pd.DataFrame(
        edges,
        columns=["Source", "Target", "Count", "Highlight", "Proportion", "Dataset"],
    )

136 

def compute_top_attributes_query(
    query, sdf, adf, show_attributes, num_values, att_separator=";", val_separator=":"
) -> pd.DataFrame | Any:
    """Count attribute values among the records of ``sdf`` matching ``query``.

    ``query`` is a collection of ``{"attribute": ..., "value": ...}`` dicts.
    Several values for one attribute are alternatives; different attributes
    must all match. Counts are taken from the filtered ``sdf`` records
    (``Dataset == "Synthetic"``). When no attribute has more than one queried
    value, a count found in ``adf`` for the same combination of selections
    (``Dataset == "Aggregate"``) replaces the synthetic row for that
    attribute value.

    Args:
        query: Collection of ``{"attribute": ..., "value": ...}`` dicts.
        sdf: DataFrame of records. Its values are concatenated with strings
            below, so they are expected to be strings.
        adf: DataFrame with ``selections`` and ``protected_count`` columns.
        show_attributes: Attributes to keep in the result; empty keeps all.
        num_values: Maximum number of rows to return; ``0`` or less keeps all.
        att_separator: Separator between selections in an ``adf`` key.
        val_separator: Separator between attribute and value.

    Returns:
        DataFrame with columns ``Attribute``, ``Attribute Value``, ``Count``
        and ``Dataset``, sorted by ``Count`` in descending order.
    """
    data_schema = get_data_schema(sdf)
    df = sdf
    selection = []
    has_unions = False
    for att, vals in data_schema.items():
        filter_vals = [
            v for v in vals if v and {"attribute": att, "value": v} in query
        ]
        if not filter_vals:
            continue
        df = df[df[att].isin(filter_vals)]
        if len(filter_vals) == 1:
            selection.append(f"{att}{val_separator}{filter_vals[0]}")
        else:
            has_unions = True

    # Add ID based on row number. ``assign`` returns a new frame, so neither
    # ``sdf`` nor a filtered view of it is written to.
    df = df.assign(Id=list(range(len(df))))
    sdf_filtered = df.melt(id_vars=["Id"], var_name="Attribute", value_name="Value")
    sdf_filtered["AttributeValue"] = (
        sdf_filtered["Attribute"] + val_separator + sdf_filtered["Value"]
    )
    sdf_filtered = sdf_filtered[sdf_filtered["Value"] != ""]
    syn_counts = (
        sdf_filtered["AttributeValue"]
        .value_counts()
        .rename_axis("Attribute Value")
        .to_frame("Count")
    )
    syn_counts["Attribute"] = syn_counts.index.str.split(val_separator).str[0]
    syn_counts.reset_index(level=0, inplace=True)
    syn_counts["Dataset"] = "Synthetic"

    result_df = syn_counts[["Attribute", "Attribute Value", "Count", "Dataset"]]

    if not has_unions:
        # Index the aggregate counts once instead of scanning ``adf`` for
        # every schema value. ``setdefault`` keeps the first row per key.
        agg_counts = {}
        for key, count in zip(adf["selections"], adf["protected_count"]):
            agg_counts.setdefault(key, count)

        agg_rows = []
        for att, vals in data_schema.items():
            for val in vals:
                candidate = f"{att}{val_separator}{val}"
                extended_selection = (
                    sorted(selection + [candidate])
                    if candidate not in selection
                    else sorted(selection)
                )
                extended_selection_key = att_separator.join(extended_selection)
                extended_agg_count = agg_counts.get(extended_selection_key, 0)
                if extended_agg_count > 0:
                    agg_rows.append([att, candidate, extended_agg_count, "Aggregate"])
        agg_df = pd.DataFrame(
            agg_rows, columns=["Attribute", "Attribute Value", "Count", "Dataset"]
        )
        result_df = pd.concat([syn_counts, agg_df], axis=0, ignore_index=True)
        # Drop a Synthetic row when an Aggregate row exists for the same
        # Attribute Value.
        aggregate_values = result_df[result_df["Dataset"] == "Aggregate"][
            "Attribute Value"
        ].values
        superseded = (result_df["Dataset"] == "Synthetic") & result_df[
            "Attribute Value"
        ].isin(aggregate_values)
        result_df = result_df[~superseded]

    if len(show_attributes) > 0:
        result_df = result_df[result_df["Attribute"].isin(show_attributes)]

    result_df = result_df.sort_values(by=["Count"], ascending=False)
    if num_values > 0:
        result_df = result_df[:num_values]
    return result_df[["Attribute", "Attribute Value", "Count", "Dataset"]]

220 

def compute_time_series_query(
    query, sdf, adf, time_attribute, time_series, att_separator=";", val_separator=":"
) -> pd.DataFrame:
    """Run the top-attributes query once per time value and stack the results.

    For each distinct, non-empty value of ``sdf[time_attribute]`` (in sorted
    order), ``query`` is extended with that time value and passed to
    ``compute_top_attributes_query``, restricted to the ``time_series``
    attributes. Any attribute value seen at one time but absent at another is
    then added for the absent time with ``Count`` 0 and ``Dataset``
    ``"Aggregate"``.

    Args:
        query: List of ``{"attribute": ..., "value": ...}`` dicts.
        sdf: DataFrame of records.
        adf: DataFrame of aggregate counts, passed through unchanged.
        time_attribute: Column of ``sdf`` holding the time value.
        time_series: Attributes to report, passed as ``show_attributes``.
        att_separator: Separator between selections, passed through.
        val_separator: Separator between attribute and value, passed through.

    Returns:
        DataFrame with columns ``time_attribute``, ``Attribute``,
        ``Attribute Value``, ``Count`` and ``Dataset``, de-duplicated and
        sorted by time and attribute value. It has no rows when ``sdf`` has
        no non-empty time values.
    """
    columns = [time_attribute, "Attribute", "Attribute Value", "Count", "Dataset"]
    times = [t for t in sorted(sdf[time_attribute].unique()) if len(str(t)) > 0]
    if not times:
        # pd.concat raises on an empty list; return an empty, well-formed frame.
        return pd.DataFrame(columns=columns)

    tdfs = []
    for time in times:
        tdf = compute_top_attributes_query(
            query=query + [{"attribute": time_attribute, "value": time}],
            sdf=sdf,
            adf=adf,
            show_attributes=time_series,
            num_values=0,
            att_separator=att_separator,
            val_separator=val_separator,
        ).copy()  # own the frame before adding a column to it
        tdf[time_attribute] = time
        tdfs.append(tdf)
    final_tdf = pd.concat(tdfs, axis=0, ignore_index=True)

    # Collect the (time, attribute, value) combinations already present, so
    # each membership test is a set lookup rather than a scan of the frame.
    present = set(
        zip(
            final_tdf[time_attribute],
            final_tdf["Attribute"],
            final_tdf["Attribute Value"],
        )
    )
    missing = [
        [time, att, val, 0, "Aggregate"]
        for att, val in zip(final_tdf["Attribute"], final_tdf["Attribute Value"])
        for time in times
        if (time, att, val) not in present
    ]

    if missing:
        missing_df = pd.DataFrame(missing, columns=columns)
        final_tdf = pd.concat([final_tdf, missing_df], axis=0, ignore_index=True)

    return (
        final_tdf[columns]
        .drop_duplicates()
        .sort_values(by=[time_attribute, "Attribute Value"])
    )