Coverage for intelligence_toolkit/anonymize_case_data/api.py: 100%

95 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3 

4import math 

5import pandas as pd 

6import plotly.graph_objects as go 

7from enum import Enum 

8from pacsynth import ( 

9 AccuracyMode, 

10 Dataset, 

11 DpAggregateSeededParametersBuilder, 

12 DpAggregateSeededSynthesizer, 

13 FabricationMode, 

14) 

15 

16import app.util.df_functions as df_functions 

17import intelligence_toolkit.anonymize_case_data.queries as queries 

18import intelligence_toolkit.anonymize_case_data.visuals as visuals 

19from intelligence_toolkit.anonymize_case_data.error_report import ErrorReport 

20from intelligence_toolkit.anonymize_case_data.synthesizability_statistics import ( 

21 SynthesizabilityStatistics, 

22) 

23from intelligence_toolkit.helpers.classes import IntelligenceWorkflow 

24 

25 

26 

class AnonymizeCaseData(IntelligenceWorkflow):
    """Anonymize categorical case data and query the anonymized results.

    ``anonymize_case_data`` fits a pacsynth ``DpAggregateSeededSynthesizer`` on a
    sensitive dataframe and stores the resulting aggregate tables, synthetic
    records and error reports on the instance. The remaining methods are thin
    wrappers over the ``queries`` and ``visuals`` modules that operate on that
    stored state.
    """

    class FabricationStrategy(Enum):
        """Enum whose member values are the pacsynth ``FabricationMode`` presets.

        The selected member's ``.value`` is what gets handed to the
        synthesizer's parameters builder.
        """

        BALANCED = FabricationMode.balanced()
        PROGRESSIVE = FabricationMode.progressive()
        MINIMIZED = FabricationMode.minimize()
        UNCONTROLLED = FabricationMode.uncontrolled()

    def __init__(self) -> None:
        """Initialise every result attribute to an empty placeholder."""
        # Scalars reported by the synthesizer after fitting.
        self.protected_number_of_records = 0
        self.delta = 0
        # Input data (after null-int fixing) and the generated outputs.
        self.sensitive_df = pd.DataFrame()
        self.aggregate_df = pd.DataFrame()
        self.synthetic_aggregate_df = pd.DataFrame()
        self.synthetic_df = pd.DataFrame()
        # Outputs of ErrorReport(...).gen() for each aggregate source.
        self.aggregate_error_report = pd.DataFrame()
        self.synthetic_error_report = pd.DataFrame()

    def analyze_synthesizability(self, df: pd.DataFrame) -> SynthesizabilityStatistics:
        """Compute summary statistics describing how combinatorially rich ``df`` is.

        Cells whose string form is ``""`` or ``"nan"`` are treated as missing and
        are excluded from every count.

        Args:
            df (pd.DataFrame): The dataframe to analyze.

        Returns:
            SynthesizabilityStatistics: Built positionally from the number of
            columns, the total count of distinct values across columns, the
            product of per-column distinct counts, that product divided by the
            number of rows (rounded to 1 decimal), the mean number of non-missing
            values per row, ``2 ** mean``, and the ratio of the per-row
            combinations to ``2 ** mean``.

        Raises:
            ZeroDivisionError: If ``df`` has no rows.
        """
        missing_markers = ("", "nan")

        # Distinct non-missing values per column; columns with none are skipped
        # so they do not zero out the product below.
        distinct_counts = []
        for col in df.columns.to_numpy():
            num_distinct = sum(
                1
                for value in df[col].astype(str).unique()
                if value not in missing_markers
            )
            if num_distinct > 0:
                distinct_counts.append(num_distinct)

        num_cols = len(df.columns)
        overall_att_count = sum(distinct_counts)
        possible_combinations = math.prod(distinct_counts)
        possible_combinations_per_row = round(possible_combinations / len(df), 1)

        # Total number of non-missing cells, averaged over rows.
        filled_cells = sum(
            1
            for row in df.to_numpy()
            for cell in row
            if str(cell) not in missing_markers
        )
        mean_vals_per_record = filled_cells / df.shape[0]

        max_combinations_per_record = 2**mean_vals_per_record
        excess_combinations_ratio = (
            possible_combinations_per_row / max_combinations_per_record
        )
        return SynthesizabilityStatistics(
            num_cols,
            overall_att_count,
            possible_combinations,
            possible_combinations_per_row,
            mean_vals_per_record,
            max_combinations_per_record,
            excess_combinations_ratio,
        )

    @staticmethod
    def _split_aggregate_keys(aggregates: dict, separator: str) -> dict:
        """Re-key ``aggregates`` by the tuple obtained from splitting each key on ``separator``."""
        return {
            tuple(combination.split(separator)): count
            for combination, count in aggregates.items()
        }

    @staticmethod
    def _build_aggregate_table(aggregates: dict, record_count) -> pd.DataFrame:
        """Build a (selections, protected_count) table plus a ``record_count`` row, sorted by count descending."""
        table = pd.DataFrame(
            data=aggregates.items(),
            columns=["selections", "protected_count"],
        )
        # Append the overall record count as one extra row at the end.
        table.loc[len(table)] = ["record_count", record_count]
        return table.sort_values(by=["protected_count"], ascending=False)

    def anonymize_case_data(
        self,
        df: pd.DataFrame,
        epsilon: float,
        reporting_length: int = 4,
        percentile_percentage: float = 99,
        percentile_epsilon_proportion: float = 0.01,
        number_of_records_epsilon_proportion: float = 0.005,
        weight_selection_percentile: float = 95,
        accuracy_mode: AccuracyMode = AccuracyMode.prioritize_long_combinations(),
        fabrication_mode: FabricationStrategy = FabricationStrategy.BALANCED,
        empty_value: str = "",
        use_synthetic_counts: bool = True,
        aggregate_counts_scale_factor: float = 1.0,
    ) -> None:
        """
        Anonymizes a given dataframe that has been preformatted as categorical microdata (one subject per row; one row per subject).

        See [Synthetic Data Showcase](https://github.com/microsoft/synthetic-data-showcase) for more information.

        Results are stored on the instance: ``sensitive_df``, ``protected_number_of_records``,
        ``delta``, ``synthetic_df``, ``aggregate_df``, ``synthetic_aggregate_df``,
        ``aggregate_error_report`` and ``synthetic_error_report``.

        Args:
            df (pd.DataFrame): The dataframe to be anonymized.
            epsilon (float): The epsilon value for differential privacy.
            reporting_length (int, optional): The maximum length of attribute value combination to compute. Defaults to 4.
            percentile_percentage (float, optional): The percentile to use for the epsilon budget. Defaults to 99.
            percentile_epsilon_proportion (float, optional): The proportion of the epsilon budget to use for percentile calculation. Defaults to 0.01.
            number_of_records_epsilon_proportion (float, optional): The proportion of the epsilon budget to use for the number of records. Defaults to 0.005.
            weight_selection_percentile (float, optional): The percentile to use for selecting weights. Defaults to 95.
            accuracy_mode (AccuracyMode, optional): The accuracy mode to use. Defaults to AccuracyMode.prioritize_long_combinations().
            fabrication_mode (FabricationStrategy, optional): The fabrication strategy to use; its ``.value`` is passed to the synthesizer. Defaults to FabricationStrategy.BALANCED.
            empty_value (str, optional): The value to use for empty cells. Defaults to "".
            use_synthetic_counts (bool, optional): Whether to use synthetic counts in progress to guide sampling. Defaults to True.
            aggregate_counts_scale_factor (float, optional): The scale factor to use for aggregate counts. Defaults to 1.0.

        Raises:
            ValueError: From ``math.log`` if the protected number of records is not positive.
            ZeroDivisionError: If the protected number of records is exactly 1 (``log(1) == 0``).
        """
        # Separator used both when exporting aggregate keys and when splitting them back.
        separator = ";"

        self.sensitive_df = df_functions.fix_null_ints(df)
        sensitive_dataset = Dataset.from_data_frame(self.sensitive_df)

        params = (
            DpAggregateSeededParametersBuilder()
            .reporting_length(reporting_length)
            .epsilon(epsilon)
            .percentile_percentage(percentile_percentage)
            .percentile_epsilon_proportion(percentile_epsilon_proportion)
            .accuracy_mode(accuracy_mode)
            .number_of_records_epsilon_proportion(number_of_records_epsilon_proportion)
            .fabrication_mode(fabrication_mode.value)
            .empty_value(empty_value)
            .weight_selection_percentile(weight_selection_percentile)
            .use_synthetic_counts(use_synthetic_counts)
            .aggregate_counts_scale_factor(aggregate_counts_scale_factor)
            .build()
        )
        synth = DpAggregateSeededSynthesizer(params)
        synth.fit(sensitive_dataset)

        protected_count = synth.get_dp_number_of_records()
        self.protected_number_of_records = protected_count
        # delta = 1 / (n * ln n); undefined for n <= 1 (see Raises).
        self.delta = 1.0 / (math.log(protected_count) * protected_count)

        synthetic_raw_data = synth.sample()
        synthetic_dataset = Dataset(synthetic_raw_data)
        self.synthetic_df = Dataset.raw_data_to_data_frame(synthetic_raw_data)

        sensitive_aggregates = sensitive_dataset.get_aggregates(
            reporting_length, separator
        )
        # export the differentially private aggregates (internal to the synthesizer)
        dp_aggregates = synth.get_dp_aggregates(separator)
        # generate aggregates from the synthetic data
        synthetic_aggregates = synthetic_dataset.get_aggregates(
            reporting_length, separator
        )

        self.aggregate_df = self._build_aggregate_table(dp_aggregates, protected_count)
        # NOTE(review): the synthetic table's "record_count" row also carries the
        # protected record count rather than the synthetic row count - presumably
        # intentional; confirm.
        self.synthetic_aggregate_df = self._build_aggregate_table(
            synthetic_aggregates, protected_count
        )

        # Error reports compare tuple-keyed aggregates against the sensitive baseline.
        sensitive_parsed = self._split_aggregate_keys(sensitive_aggregates, separator)
        self.aggregate_error_report = ErrorReport(
            sensitive_parsed,
            self._split_aggregate_keys(dp_aggregates, separator),
        ).gen()
        self.synthetic_error_report = ErrorReport(
            sensitive_parsed,
            self._split_aggregate_keys(synthetic_aggregates, separator),
        ).gen()

    def get_data_schema(self) -> dict[str, list[str]]:
        """Return ``queries.get_data_schema`` applied to the synthetic dataframe."""
        return queries.get_data_schema(self.synthetic_df)

    def compute_aggregate_graph_df(
        self,
        filters: list[str],
        source_attribute: str,
        target_attribute: str,
        highlight_attribute: str,
    ) -> pd.DataFrame:
        """Run ``queries.compute_aggregate_graph`` against the stored aggregate table.

        Args:
            filters (list[str]): Filter keys forwarded to the query.
            source_attribute (str): Source attribute forwarded to the query.
            target_attribute (str): Target attribute forwarded to the query.
            highlight_attribute (str): Highlight attribute forwarded to the query.

        Returns:
            pd.DataFrame: The query result.
        """
        return queries.compute_aggregate_graph(
            self.aggregate_df,
            filters,
            source_attribute,
            target_attribute,
            highlight_attribute,
        )

    def compute_synthetic_graph_df(
        self,
        filters: list[str],
        source_attribute: str,
        target_attribute: str,
        highlight_attribute: str,
    ) -> pd.DataFrame:
        """Run ``queries.compute_synthetic_graph`` against the stored synthetic records.

        Args:
            filters (list[str]): Filter keys forwarded to the query.
            source_attribute (str): Source attribute forwarded to the query.
            target_attribute (str): Target attribute forwarded to the query.
            highlight_attribute (str): Highlight attribute forwarded to the query.

        Returns:
            pd.DataFrame: The query result.
        """
        return queries.compute_synthetic_graph(
            self.synthetic_df,
            filters,
            source_attribute,
            target_attribute,
            highlight_attribute,
        )

    def compute_time_series_query_df(
        self,
        selection,
        time_attribute,
        series_attributes,
        att_separator: str = ";",
        val_separator: str = ":",
    ) -> pd.DataFrame:
        """Run ``queries.compute_time_series_query`` over the synthetic and aggregate tables.

        Args:
            selection: Forwarded to the query as ``query``.
            time_attribute: Forwarded to the query as ``time_attribute``.
            series_attributes: Forwarded to the query as ``time_series``.
            att_separator (str, optional): Forwarded unchanged. Defaults to ";".
            val_separator (str, optional): Forwarded unchanged. Defaults to ":".

        Returns:
            pd.DataFrame: The query result.
        """
        return queries.compute_time_series_query(
            query=selection,
            sdf=self.synthetic_df,
            adf=self.aggregate_df,
            time_attribute=time_attribute,
            time_series=series_attributes,
            att_separator=att_separator,
            val_separator=val_separator,
        )

    def compute_top_attributes_query_df(
        self,
        query: str,
        show_attributes: list[str],
        num_values: int,
        att_separator: str = ";",
        val_separator: str = ":",
    ) -> pd.DataFrame:
        """Run ``queries.compute_top_attributes_query`` over the synthetic and aggregate tables.

        Args:
            query (str): Forwarded to the query. NOTE(review): ``get_bar_chart_fig``
                passes its ``selection`` list here - confirm the intended type.
            show_attributes (list[str]): Forwarded unchanged.
            num_values (int): Forwarded unchanged.
            att_separator (str, optional): Forwarded unchanged. Defaults to ";".
            val_separator (str, optional): Forwarded unchanged. Defaults to ":".

        Returns:
            pd.DataFrame: The query result.
        """
        return queries.compute_top_attributes_query(
            query,
            self.synthetic_df,
            self.aggregate_df,
            show_attributes,
            num_values,
            att_separator,
            val_separator,
        )

    def get_bar_chart_fig(
        self,
        selection: list[str],
        show_attributes: list[str],
        unit: str,
        width: int,
        height: int,
        scheme: list[str],
        num_values: int,
        att_separator: str = ";",
        val_separator: str = ":",
    ) -> tuple[go.Figure, pd.DataFrame]:
        """Compute the top-attributes query and render it with ``visuals.get_bar_chart``.

        Returns:
            tuple[go.Figure, pd.DataFrame]: The chart and the dataframe it was built from.
        """
        chart_df = self.compute_top_attributes_query_df(
            query=selection,
            show_attributes=show_attributes,
            num_values=num_values,
            att_separator=att_separator,
            val_separator=val_separator,
        )
        chart = visuals.get_bar_chart(
            selection=selection,
            show_attributes=show_attributes,
            unit=unit,
            chart_df=chart_df,
            width=width,
            height=height,
            scheme=scheme,
        )
        return chart, chart_df

    def get_line_chart_fig(
        self,
        selection: list[str],
        series_attributes: list[str],
        unit: str,
        time_attribute: str,
        width: int,
        height: int,
        scheme: list[str],
        att_separator: str = ";",
        val_separator: str = ":",
    ) -> tuple[go.Figure, pd.DataFrame]:
        """Compute the time-series query and render it with ``visuals.get_line_chart``.

        Returns:
            tuple[go.Figure, pd.DataFrame]: The chart and the dataframe it was built from.
        """
        chart_df = self.compute_time_series_query_df(
            selection=selection,
            time_attribute=time_attribute,
            series_attributes=series_attributes,
            att_separator=att_separator,
            val_separator=val_separator,
        )
        chart = visuals.get_line_chart(
            selection=selection,
            series_attributes=series_attributes,
            unit=unit,
            chart_df=chart_df,
            time_attribute=time_attribute,
            width=width,
            height=height,
            scheme=scheme,
        )
        return chart, chart_df

    def get_flow_chart_fig(
        self,
        selection: list[dict[str, str]],
        source_attribute: str,
        target_attribute: str,
        highlight_attribute: str,
        width: int,
        height: int,
        unit: str,
        scheme: list[str],
        att_separator: str = ";",
        val_separator: str = ":",
    ):
        """Compute a source/target graph query and render it with ``visuals.get_flow_chart``.

        The aggregate table is queried when the number of attributes involved
        (source, target, the highlight attribute when it is not "", plus one per
        selection entry) is at most 4; otherwise the synthetic records are queried.

        Args:
            selection (list[dict[str, str]]): Entries with "attribute" and "value"
                keys; each is joined with ``val_separator`` to form a filter key.
            source_attribute (str): Source attribute of the flow.
            target_attribute (str): Target attribute of the flow.
            highlight_attribute (str): Highlight attribute, or "" for none.
            width (int): Forwarded to the chart builder.
            height (int): Forwarded to the chart builder.
            unit (str): Forwarded to the chart builder.
            scheme (list[str]): Forwarded to the chart builder.
            att_separator (str, optional): Forwarded to the query. Defaults to ";".
            val_separator (str, optional): Forwarded to the query and used to build filter keys. Defaults to ":".

        Returns:
            A ``(chart, chart_df)`` pair: the object returned by
            ``visuals.get_flow_chart`` and the dataframe it was built from.
        """
        selection_keys = [
            entry["attribute"] + val_separator + entry["value"] for entry in selection
        ]

        # NOTE(review): the threshold 4 matches the default reporting_length of
        # anonymize_case_data - presumably it should track the value actually
        # used there; confirm.
        att_count = len(selection) + (2 if highlight_attribute == "" else 3)
        if att_count <= 4:
            compute_graph, source_df = queries.compute_aggregate_graph, self.aggregate_df
        else:
            compute_graph, source_df = queries.compute_synthetic_graph, self.synthetic_df

        chart_df = compute_graph(
            source_df,
            selection_keys,
            source_attribute,
            target_attribute,
            highlight_attribute,
            att_separator,
            val_separator,
        )
        chart = visuals.get_flow_chart(
            chart_df,
            selection,
            source_attribute,
            target_attribute,
            highlight_attribute,
            width,
            height,
            unit,
            scheme,
        )
        return chart, chart_df