Coverage for contextualized/analysis/effects.py: 31% (208 statements)


1""" 

2Utilities for plotting effects learned by Contextualized models. 

3""" 

4 

5from typing import * 

6 

7import numpy as np 

8import matplotlib.pyplot as plt 

9 

10from contextualized.easy.wrappers import SKLearnWrapper 

11 

12 

def simple_plot(
    x_vals: List[Union[float, int]],
    y_vals: List[Union[float, int]],
    **kwargs,
) -> None:
    """
    Simple plotting of y vs. x with kwargs passed to matplotlib helpers.

    Args:
        x_vals: x values to plot
        y_vals: y values to plot
        **kwargs: kwargs passed to matplotlib helpers (fill_alpha, fill_color,
            y_lowers, y_uppers, x_label, y_label, x_ticks, x_ticklabels,
            y_ticks, y_ticklabels)

    Returns:
        None
    """
    plt.figure(figsize=kwargs.get("figsize", (8, 8)))
    if "y_lowers" in kwargs and "y_uppers" in kwargs:
        plt.fill_between(
            x_vals,
            np.squeeze(kwargs["y_lowers"]),
            np.squeeze(kwargs["y_uppers"]),
            alpha=kwargs.get("fill_alpha", 0.2),
            color=kwargs.get("fill_color", "blue"),
        )
    plt.plot(x_vals, y_vals)
    plt.xlabel(kwargs.get("x_label", "X"))
    plt.ylabel(kwargs.get("y_label", "Y"))
    if (
        kwargs.get("x_ticks", None) is not None
        and kwargs.get("x_ticklabels", None) is not None
    ):
        plt.xticks(kwargs["x_ticks"], kwargs["x_ticklabels"])
    if (
        kwargs.get("y_ticks", None) is not None
        and kwargs.get("y_ticklabels", None) is not None
    ):
        plt.yticks(kwargs["y_ticks"], kwargs["y_ticklabels"])
    plt.show()
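# Added usage illustration (not part of the original module): a minimal sketch of
# simple_plot with the optional confidence-band kwargs. The data here is synthetic.
def _example_simple_plot():
    x = np.linspace(-2, 2, 50)
    y = x**2
    simple_plot(
        x,
        y,
        y_lowers=y - 0.5,  # lower edge of the shaded band
        y_uppers=y + 0.5,  # upper edge of the shaded band
        x_label="Context value",
        y_label="Estimated effect",
    )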

def plot_effect(x_vals, y_means, y_lowers=None, y_uppers=None, **kwargs):
    """
    Plots a single effect: values are shifted so the minimum is zero, optionally
    exponentiated (e.g. to show odds ratios), and skipped entirely if the maximum
    effect is below ``min_effect_size``.
    """
    min_val = np.min(y_means)
    y_means -= min_val
    if y_lowers is not None and y_uppers is not None:
        y_lowers -= min_val
        y_uppers -= min_val
    if kwargs.get("should_exponentiate", False):
        y_means = np.exp(y_means)
        if y_lowers is not None and y_uppers is not None:
            y_lowers = np.exp(y_lowers)
            y_uppers = np.exp(y_uppers)
    try:
        if "x_encoder" in kwargs and kwargs["x_encoder"] is not None:
            x_classes = kwargs["x_encoder"].classes_
            # Line up class values with centered values.
            x_ticks = np.array(list(range(len(x_classes))))
            if (
                kwargs.get("x_means", None) is not None
                and kwargs.get("x_stds", None) is not None
            ):
                x_ticks = (x_ticks - kwargs["x_means"]) / kwargs["x_stds"]
        else:
            x_ticks = None
            x_classes = None
    except Exception:  # fall back to unlabeled ticks if the encoder is unusable
        x_classes = None
        x_ticks = None

    if np.max(y_means) > kwargs.get("min_effect_size", 0.0):
        simple_plot(
            x_vals,
            y_means,
            x_label=kwargs.get("xlabel", "X"),
            y_label=kwargs.get("ylabel", "Odds Ratio of Outcome"),
            y_lowers=y_lowers,
            y_uppers=y_uppers,
            x_ticks=x_ticks,
            x_ticklabels=x_classes,
        )

def get_homogeneous_context_effects(
    model: SKLearnWrapper, C: np.ndarray, **kwargs
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Get the homogeneous (context-invariant) effects of context.

    Args:
        model (SKLearnWrapper): a fitted ``contextualized.easy`` model
        C: the context values to use to estimate the effects
        verbose (bool, optional): print progress. Default True.
        individual_preds (bool, optional): whether to return separate effects for each bootstrap. Default True.
        C_vis (np.ndarray, optional): Context bins used to visualize context (n_vis, n_contexts). Default None to construct anew.
        n_vis (int, optional): Number of bins to use to visualize context. Default 1000.

    Returns:
        Tuple[np.ndarray, np.ndarray]:
            c_vis: the context values that were used to estimate the effects
            effects: array of effects, one for each context variable. Each homogeneous effect is a matrix of shape
                (n_bootstraps, n_context_vals, n_outcomes).
    """
    if kwargs.get("verbose", True):
        print("Estimating Homogeneous Contextual Effects.")
    c_vis = maybe_make_c_vis(C, **kwargs)

    effects = []
    for j in range(C.shape[1]):
        c_j = np.zeros_like(c_vis)
        c_j[:, j] = c_vis[:, j]
        try:
            (_, mus) = model.predict_params(
                c_j, individual_preds=kwargs.get("individual_preds", True)
            )
        except ValueError:
            (_, mus) = model.predict_params(c_j)
        effects.append(mus)
    return c_vis, np.array(effects)
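# Added usage illustration (not part of the original module): a hypothetical
# end-to-end sketch, assuming the usual ``contextualized.easy`` API
# (ContextualizedRegressor with fit(C, X, Y)). Shapes follow the docstring above.
def _example_get_homogeneous_context_effects():
    from contextualized.easy import ContextualizedRegressor  # assumed import path

    rng = np.random.default_rng(0)
    C = rng.normal(size=(100, 2))  # contexts
    X = rng.normal(size=(100, 3))  # predictors
    Y = rng.normal(size=(100, 1))  # outcomes

    model = ContextualizedRegressor()
    model.fit(C, X, Y, max_epochs=1)  # assumed fit signature; tiny run for illustration

    c_vis, effects = get_homogeneous_context_effects(model, C, n_vis=50)
    # effects has one entry per context variable, each of shape
    # (n_bootstraps, n_context_vals, n_outcomes) per the docstring above.
    print(c_vis.shape, effects.shape)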

def get_homogeneous_predictor_effects(model, C, **kwargs):
    """
    Get the homogeneous (context-invariant) effects of predictors.

    :param model: a fitted ``contextualized.easy`` model
    :param C: the context values to use to estimate the effects

    returns:
        c_vis: the context values that were used to estimate the effects
        effects: np array of effects, one for each predictor. Each homogeneous effect is a matrix of shape
            (n_bootstraps, n_outcomes).
    """
    if kwargs.get("verbose", True):
        print("Estimating Homogeneous Predictor Effects.")
    c_vis = maybe_make_c_vis(C, **kwargs)
    c_idx = 0
    try:
        (betas, _) = model.predict_params(
            c_vis, individual_preds=kwargs.get("individual_preds", True)
        )
        # bootstraps x C_vis x outcomes x predictors
        if len(betas.shape) == 4:
            c_idx = 1
    except ValueError:
        (betas, _) = model.predict_params(c_vis)
    betas = np.mean(
        betas, axis=c_idx
    )  # homogeneous predictor effect is context-invariant
    return c_vis, np.transpose(betas, (2, 0, 1))

def get_heterogeneous_predictor_effects(model, C, **kwargs):
    """
    Get the heterogeneous (context-variant) effects of predictors.

    :param model: a fitted ``contextualized.easy`` model
    :param C: the context values to use to estimate the effects

    returns:
        c_vis: the context values that were used to estimate the effects
        effects: np array of effects, one for each context variable. Each entry has shape
            (n_predictors, n_bootstraps, n_context_vals, n_outcomes).
    """
    if kwargs.get("verbose", True):
        print("Estimating Heterogeneous Predictor Effects.")
    c_vis = maybe_make_c_vis(C, **kwargs)

    effects = []
    for j in range(C.shape[1]):
        c_j = np.zeros_like(c_vis)
        c_j[:, j] = c_vis[:, j]
        c_idx = 0
        try:
            (betas, _) = model.predict_params(
                c_j, individual_preds=kwargs.get("individual_preds", True)
            )
            # bootstraps x C_vis x outcomes x predictors
            if len(betas.shape) == 4:
                c_idx = 1
        except ValueError:
            (betas, _) = model.predict_params(c_j)
        # Heterogeneous effects are mean-centered with respect to C.
        effect = np.transpose(
            np.transpose(betas, (0, 2, 3, 1))
            - np.tile(
                np.expand_dims(np.mean(betas, axis=c_idx), -1),
                (1, 1, 1, betas.shape[1]),
            ),
            (0, 3, 1, 2),
        )
        effects.append(effect)
    effects = np.array(effects)
    if len(effects.shape) == 5:
        effects = np.transpose(
            effects, (0, 4, 1, 2, 3)
        )  # (n_contexts, n_predictors, n_bootstraps, n_context_vals, n_outcomes)
    else:
        effects = np.transpose(
            effects, (0, 3, 1, 2)
        )  # (n_contexts, n_predictors, n_context_vals, n_outcomes)
    return c_vis, effects
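# Added illustration (not part of the original module): a small numpy check showing
# that the transpose/tile centering above is equivalent to subtracting the mean over
# the context-bin axis with broadcasting, for the 4-D (bootstrapped) case.
def _example_mean_centering_equivalence():
    rng = np.random.default_rng(0)
    betas = rng.normal(size=(3, 5, 2, 4))  # bootstraps x C_vis x outcomes x predictors
    via_tile = np.transpose(
        np.transpose(betas, (0, 2, 3, 1))
        - np.tile(
            np.expand_dims(np.mean(betas, axis=1), -1), (1, 1, 1, betas.shape[1])
        ),
        (0, 3, 1, 2),
    )
    via_broadcast = betas - betas.mean(axis=1, keepdims=True)
    print(np.allclose(via_tile, via_broadcast))  # True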

def plot_boolean_vars(names, y_mean, y_err, **kwargs):
    """
    Plots Boolean variables as a bar chart with error bars.
    """
    plt.figure(figsize=kwargs.get("figsize", (12, 8)))
    sorted_i = np.argsort(y_mean)
    if kwargs.get("classification", True):
        y_mean = np.exp(y_mean)
        y_err = np.exp(y_err)
    for counter, i in enumerate(sorted_i):
        plt.bar(
            counter,
            y_mean[i],
            width=0.5,
            color=kwargs.get("fill_color", "blue"),
            edgecolor=kwargs.get("edge_color", "black"),
            yerr=y_err[i],  # error bar for this variable only
        )
    plt.xticks(
        range(len(names)),
        np.array(names)[sorted_i],
        rotation=60,
        fontsize=kwargs.get("boolean_x_ticksize", 18),
        ha="right",
    )
    plt.ylabel(
        kwargs.get("ylabel", "Odds Ratio of Outcome"),
        fontsize=kwargs.get("ylabel_fontsize", 32),
    )
    plt.yticks(fontsize=kwargs.get("ytick_fontsize", 18))
    if kwargs.get("bool_figname", None) is not None:
        plt.savefig(kwargs.get("bool_figname"), dpi=300, bbox_inches="tight")
    else:
        plt.show()

def plot_homogeneous_context_effects(
    model: SKLearnWrapper,
    C: np.ndarray,
    **kwargs,
) -> None:
    """
    Plot the direct effect of context on outcomes, disregarding other features.
    This context effect is homogeneous in that it is a static function of context (context-invariant).

    Args:
        model (SKLearnWrapper): a fitted ``contextualized.easy`` model
        C: the context values to use to estimate the effects (a pandas DataFrame; column names are used as axis labels)
        verbose (bool, optional): print progress. Default True.
        individual_preds (bool, optional): whether to plot each bootstrap. Default True.
        C_vis (np.ndarray, optional): Context bins used to visualize context (n_vis, n_contexts). Default None to construct anew.
        n_vis (int, optional): Number of bins to use to visualize context. Default 1000.
        lower_pct (int, optional): Lower percentile for bootstraps. Default 2.5.
        upper_pct (int, optional): Upper percentile for bootstraps. Default 97.5.
        classification (bool, optional): Whether to exponentiate the effects. Default True.
        C_encoders (List[sklearn.preprocessing.LabelEncoder], optional): encoders for each context. Default None.
        C_means (np.ndarray, optional): means for each context. Default None.
        C_stds (np.ndarray, optional): standard deviations for each context. Default None.
        xlabel_prefix (str, optional): prefix for x label. Default "".
        figname (str, optional): name of figure to save. Default None.

    Returns:
        None
    """
    c_vis, effects = get_homogeneous_context_effects(model, C, **kwargs)
    # effects.shape is (n_contexts, n_bootstraps, n_context_vals, n_outcomes)
    for outcome in range(effects.shape[-1]):
        for j in range(effects.shape[0]):
            try:
                mus = effects[j, :, :, outcome]
                means = np.mean(mus, axis=0)
                lowers = np.percentile(mus, kwargs.get("lower_pct", 2.5), axis=0)
                uppers = np.percentile(mus, kwargs.get("upper_pct", 97.5), axis=0)
            except ValueError:
                mus = effects[j, :, outcome]
                means = mus  # no bootstraps were provided.
                lowers, uppers = None, None

            if "C_encoders" in kwargs:
                encoder = kwargs["C_encoders"][j]
            else:
                encoder = None
            if "C_means" in kwargs:
                c_means = kwargs["C_means"][j]
            else:
                c_means = None
            if "C_stds" in kwargs:
                c_stds = kwargs["C_stds"][j]
            else:
                c_stds = None
            plot_effect(
                c_vis[:, j],
                means,
                lowers,
                uppers,
                should_exponentiate=kwargs.get("classification", True),
                x_encoder=encoder,
                x_means=c_means,
                x_stds=c_stds,
                xlabel=C.columns.tolist()[j],
                **kwargs,
            )
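# Added usage illustration (not part of the original module): a hypothetical sketch
# of the typical plotting call, assuming the usual ``contextualized.easy`` API and a
# pandas DataFrame for C (its column names become the x-axis labels).
def _example_plot_homogeneous_context_effects():
    import pandas as pd
    from contextualized.easy import ContextualizedRegressor  # assumed import path

    rng = np.random.default_rng(0)
    C = pd.DataFrame(rng.normal(size=(100, 2)), columns=["age", "dose"])
    X = rng.normal(size=(100, 3))
    Y = rng.normal(size=(100, 1))

    model = ContextualizedRegressor()
    model.fit(C.values, X, Y, max_epochs=1)  # assumed fit signature; tiny run only

    plot_homogeneous_context_effects(
        model,
        C,
        classification=False,  # plot raw effects instead of exponentiated odds ratios
        n_vis=100,
    )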

def plot_homogeneous_predictor_effects(
    model: SKLearnWrapper,
    C: np.ndarray,
    X: np.ndarray,
    **kwargs,
) -> None:
    """
    Plot the effect of predictors on outcomes that do not change with context (homogeneous).

    Args:
        model (SKLearnWrapper): a fitted ``contextualized.easy`` model
        C: the context values to use to estimate the effects
        X: the predictor values to use to estimate the effects
        max_classes_for_discrete (int, optional): maximum number of classes to treat as discrete. Default 10.
        min_effect_size (float, optional): minimum effect size to plot. Default 0.0.
        ylabel (str, optional): y label for plot. Default "Odds Ratio of Outcome".
        xlabel_prefix (str, optional): prefix for x label. Default "".
        X_names (List[str], optional): names of predictors. Default None.
        X_encoders (List[sklearn.preprocessing.LabelEncoder], optional): encoders for each predictor. Default None.
        X_means (np.ndarray, optional): means for each predictor. Default None.
        X_stds (np.ndarray, optional): standard deviations for each predictor. Default None.
        verbose (bool, optional): print progress. Default True.
        lower_pct (int, optional): Lower percentile for bootstraps. Default 2.5.
        upper_pct (int, optional): Upper percentile for bootstraps. Default 97.5.
        classification (bool, optional): Whether to exponentiate the effects. Default True.
        figname (str, optional): name of figure to save. Default None.

    Returns:
        None
    """
    c_vis = np.zeros_like(C.values)
    x_vis = make_grid_mat(X.values, 1000)
    (betas, _) = model.predict_params(
        c_vis, individual_preds=True
    )  # bootstraps x C_vis x outcomes x predictors
    homogeneous_betas = np.mean(betas, axis=1)  # bootstraps x outcomes x predictors
    for outcome in range(homogeneous_betas.shape[1]):
        betas = homogeneous_betas[:, outcome, :]  # bootstraps x predictors
        my_avg_betas = np.mean(betas, axis=0)
        lowers = np.percentile(betas, kwargs.get("lower_pct", 2.5), axis=0)
        uppers = np.percentile(betas, kwargs.get("upper_pct", 97.5), axis=0)
        max_impacts = []
        # Calculate the max impact of each effect.
        for k in range(my_avg_betas.shape[0]):
            effect_range = my_avg_betas[k] * np.ptp(x_vis[:, k])
            max_impacts.append(effect_range)
        effects_by_desc_impact = np.argsort(max_impacts)[::-1]

        boolean_vars = [j for j in range(X.shape[-1]) if len(set(X.values[:, j])) == 2]
        if len(boolean_vars) > 0:
            plot_boolean_vars(
                [X.columns[j] for j in boolean_vars],
                [max_impacts[j] for j in boolean_vars],
                [np.max(uppers[j]) - max_impacts[j] for j in boolean_vars],
                **kwargs,
            )
        for j in effects_by_desc_impact:
            if j in boolean_vars:
                continue
            means = my_avg_betas[j] * x_vis[:, j]
            my_lowers = lowers[j] * x_vis[:, j]
            my_uppers = uppers[j] * x_vis[:, j]
            if "X_encoders" in kwargs:
                encoder = kwargs["X_encoders"][j]
            else:
                encoder = None
            if "X_means" in kwargs:
                x_means = kwargs["X_means"][j]
            else:
                x_means = None
            if "X_stds" in kwargs:
                x_stds = kwargs["X_stds"][j]
            else:
                x_stds = None

            plot_effect(
                x_vis[:, j],
                means,
                my_lowers,
                my_uppers,
                should_exponentiate=kwargs.get("classification", True),
                x_encoder=encoder,
                x_means=x_means,
                x_stds=x_stds,
                xlabel=f"{kwargs.get('xlabel_prefix', '')} {X.columns[j]}",
                **kwargs,
            )

def plot_heterogeneous_predictor_effects(model, C, X, **kwargs):
    """
    Plot how the effect of predictors on outcomes changes with context (heterogeneous).

    Args:
        model (SKLearnWrapper): a fitted ``contextualized.easy`` model
        C: the context values to use to estimate the effects
        X: the predictor values to use to estimate the effects
        max_classes_for_discrete (int, optional): maximum number of classes to treat as discrete. Default 10.
        min_effect_size (float, optional): minimum effect size to plot. Default 0.0.
        y_prefix (str, optional): y prefix for plot. Default "Influence of ".
        X_names (List[str], optional): names of predictors. Default None.
        verbose (bool, optional): print progress. Default True.
        individual_preds (bool, optional): whether to plot each bootstrap. Default True.
        C_vis (np.ndarray, optional): Context bins used to visualize context (n_vis, n_contexts). Default None to construct anew.
        n_vis (int, optional): Number of bins to use to visualize context. Default 1000.
        lower_pct (int, optional): Lower percentile for bootstraps. Default 2.5.
        upper_pct (int, optional): Upper percentile for bootstraps. Default 97.5.
        classification (bool, optional): Whether to exponentiate the effects. Default True.
        encoders (List[sklearn.preprocessing.LabelEncoder], optional): encoders for each context. Default None.
        c_means (np.ndarray, optional): means for each context. Default None.
        c_stds (np.ndarray, optional): standard deviations for each context. Default None.
        xlabel_prefix (str, optional): prefix for x label. Default "".
        figname (str, optional): name of figure to save. Default None.

    Returns:
        None
    """
    c_vis = maybe_make_c_vis(C, **kwargs)
    n_vis = c_vis.shape[0]
    # c_names = C.columns.tolist()
    for j in range(C.shape[1]):
        c_j = c_vis.copy()
        c_j[:, :j] = 0.0
        c_j[:, j + 1 :] = 0.0
        (models, _) = model.predict_params(
            c_j, individual_preds=True
        )  # n_bootstraps x n_vis x outcomes x predictors
        homogeneous_effects = np.mean(
            models, axis=1
        )  # n_bootstraps x outcomes x predictors
        heterogeneous_effects = models.copy()
        for i in range(n_vis):
            heterogeneous_effects[:, i] -= homogeneous_effects
        # n_bootstraps x n_vis x outcomes x predictors

        for outcome in range(heterogeneous_effects.shape[2]):
            my_effects = heterogeneous_effects[
                :, :, outcome, :
            ]  # n_bootstraps x n_vis x predictors
            means = np.mean(my_effects, axis=0)  # n_vis x predictors
            my_lowers = np.percentile(my_effects, kwargs.get("lower_pct", 2.5), axis=0)
            my_uppers = np.percentile(my_effects, kwargs.get("upper_pct", 97.5), axis=0)

            x_ticks = None
            x_ticklabels = None
            try:
                x_classes = kwargs["encoders"][j].classes_
                if len(x_classes) <= kwargs.get("max_classes_for_discrete", 10):
                    # Use a float array so the centering/scaling below stays valid.
                    x_ticks = np.arange(len(x_classes), dtype=float)
                    if "c_means" in kwargs:
                        x_ticks -= kwargs["c_means"][j]
                    if "c_stds" in kwargs:
                        x_ticks /= kwargs["c_stds"][j]
                    x_ticklabels = x_classes
            except KeyError:
                pass
            for k in range(my_effects.shape[-1]):
                # Threshold on the effect of predictor k across bootstraps and context bins.
                if np.max(my_effects[:, :, k]) > kwargs.get("min_effect_size", 0.0):
                    simple_plot(
                        c_vis[:, j],
                        means[:, k],
                        x_label=C.columns[j],
                        y_label=f"{kwargs.get('y_prefix', 'Influence of')} {X.columns[k]}",
                        y_lowers=my_lowers[:, k],
                        y_uppers=my_uppers[:, k],
                        x_ticks=x_ticks,
                        x_ticklabels=x_ticklabels,
                        **kwargs,
                    )

def make_grid_mat(observation_mat, n_vis):
    """
    Build a grid for visualization: each feature is swept linearly over its observed range.

    :param observation_mat: (n_samples, n_features) array that defines the domain of each feature.
    :param n_vis: number of evenly spaced points to generate per feature.

    returns a matrix of n_vis x n_features that can be used to visualize the effects of the features.
    """
    ar_vis = np.zeros((n_vis, observation_mat.shape[1]))
    for j in range(observation_mat.shape[1]):
        ar_vis[:, j] = np.linspace(
            np.min(observation_mat[:, j]), np.max(observation_mat[:, j]), n_vis
        )
    return ar_vis
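# Added illustration (not part of the original module): make_grid_mat sweeps each
# column independently and linearly over its own observed min/max range.
def _example_make_grid_mat():
    obs = np.array([[0.0, 10.0], [1.0, 20.0], [2.0, 30.0]])
    grid = make_grid_mat(obs, n_vis=5)
    print(grid.shape)  # (5, 2)
    print(grid[:, 0])  # [0.  0.5 1.  1.5 2. ]
    print(grid[:, 1])  # [10. 15. 20. 25. 30.]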

def make_c_vis(C, n_vis):
    """
    Build a visualization grid over the context variables.

    :param C: context values (numpy array or pandas DataFrame).
    :param n_vis: number of evenly spaced points to generate per context variable.

    returns a matrix of n_vis x n_contexts that can be used to visualize the effects of the context variables.
    """
    if isinstance(C, np.ndarray):
        return make_grid_mat(C, n_vis)
    return make_grid_mat(C.values, n_vis)

def maybe_make_c_vis(C, **kwargs):
    """
    Build a visualization grid over the context variables, unless one is supplied.

    :param C: context values (numpy array or pandas DataFrame).
    :param kwargs: ``C_vis`` (precomputed grid), ``n_vis`` (number of points, default 1000),
        and ``verbose`` (default True) are recognized.

    returns a matrix of n_vis x n_contexts that can be used to visualize the effects of the context variables.
    If C_vis is supplied, it is returned unchanged.
    """
    if kwargs.get("C_vis", None) is None:
        if kwargs.get("verbose", True):
            print(
                """Generating datapoints for visualization by assuming the encoder is
                an additive model and thus doesn't require sampling on a manifold.
                If the encoder has interactions, please supply C_vis so that we
                can visualize these effects on the correct data manifold."""
            )
        return make_c_vis(C, kwargs.get("n_vis", 1000))
    return kwargs["C_vis"]
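# Added illustration (not part of the original module): the two ways to get a
# visualization grid. Either let maybe_make_c_vis build one from C, or pass a
# precomputed C_vis (e.g. one sampled on a known data manifold) to use as-is.
def _example_maybe_make_c_vis():
    rng = np.random.default_rng(0)
    C = rng.normal(size=(200, 3))

    built = maybe_make_c_vis(C, n_vis=50, verbose=False)
    print(built.shape)  # (50, 3)

    custom = C[:25]  # stands in for a manifold-aware grid
    passed_through = maybe_make_c_vis(C, C_vis=custom)
    print(passed_through is custom)  # True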