Coverage for contextualized/easy/wrappers/SKLearnWrapper.py: 92% (218 statements)
1"""
2An sklearn-like wrapper for Contextualized models.
3"""
5import copy
6import os
7from typing import *
9import numpy as np
10from pytorch_lightning.callbacks.early_stopping import EarlyStopping
11from pytorch_lightning.callbacks import ModelCheckpoint
12from sklearn.model_selection import train_test_split
13import torch
15from contextualized.functions import LINK_FUNCTIONS
16from contextualized.regression import REGULARIZERS, LOSSES
18DEFAULT_LEARNING_RATE = 1e-3
19DEFAULT_N_BOOTSTRAPS = 1
20DEFAULT_ES_PATIENCE = 1
21DEFAULT_VAL_BATCH_SIZE = 16
22DEFAULT_TRAIN_BATCH_SIZE = 1
23DEFAULT_TEST_BATCH_SIZE = 16
24DEFAULT_VAL_SPLIT = 0.2
25DEFAULT_ENCODER_TYPE = "mlp"
26DEFAULT_ENCODER_WIDTH = 25
27DEFAULT_ENCODER_LAYERS = 3
28DEFAULT_ENCODER_LINK_FN = LINK_FUNCTIONS["identity"]


class SKLearnWrapper:
    """
    An sklearn-like wrapper for Contextualized models.

    Args:
        base_constructor (class): The base class used to construct the model.
        extra_model_kwargs (dict): Extra kwargs to pass to the model constructor.
        extra_data_kwargs (dict): Extra kwargs to pass to the dataloader constructor.
        trainer_constructor (class): The trainer class to use.
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
        loss_fn (torch.nn.Module, optional): Loss function. Defaults to LOSSES["mse"].
        link_fn (torch.nn.Module, optional): Link function. Defaults to LINK_FUNCTIONS["identity"].
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float in range (0.0, 1.0) governing how the regularization is split between context-specific parameters and context-specific offsets.
        l1_ratio (float, optional): Float in range (0.0, 1.0) governing how much the regularization penalizes the L1 vs. L2 parameter norms.
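
    Example:
        A minimal usage sketch. This wrapper is not constructed directly;
        the sketch assumes a concrete subclass such as
        ``ContextualizedRegressor`` from ``contextualized.easy``, which
        supplies the base model and trainer classes::

            import numpy as np
            from contextualized.easy import ContextualizedRegressor

            C = np.random.normal(size=(100, 5))  # contexts
            X = np.random.normal(size=(100, 3))  # predictors
            Y = np.random.normal(size=(100, 2))  # targets

            model = ContextualizedRegressor(alpha=1e-3, encoder_type="mlp")
            model.fit(C, X, Y, max_epochs=5)
            betas, mus = model.predict_params(C)  # context-specific models
            Y_hat = model.predict(C, X)           # context-specific predictions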
47 """

    def _set_defaults(self):
        self.default_learning_rate = DEFAULT_LEARNING_RATE
        self.default_n_bootstraps = DEFAULT_N_BOOTSTRAPS
        self.default_es_patience = DEFAULT_ES_PATIENCE
        self.default_train_batch_size = DEFAULT_TRAIN_BATCH_SIZE
        self.default_test_batch_size = DEFAULT_TEST_BATCH_SIZE
        self.default_val_batch_size = DEFAULT_VAL_BATCH_SIZE
        self.default_val_split = DEFAULT_VAL_SPLIT
        self.default_encoder_width = DEFAULT_ENCODER_WIDTH
        self.default_encoder_layers = DEFAULT_ENCODER_LAYERS
        self.default_encoder_link_fn = DEFAULT_ENCODER_LINK_FN
        self.default_encoder_type = DEFAULT_ENCODER_TYPE

    def __init__(
        self,
        base_constructor,
        extra_model_kwargs,
        extra_data_kwargs,
        trainer_constructor,
        **kwargs,
    ):
        self._set_defaults()
        self.base_constructor = base_constructor
        self.n_bootstraps = 1
        self.models = None
        self.trainers = None
        self.dataloaders = None
        self.context_dim = None
        self.x_dim = None
        self.y_dim = None
        self.trainer_constructor = trainer_constructor
        self.accelerator = "gpu" if torch.cuda.is_available() else "cpu"
        self.acceptable_kwargs = {
            "data": [
                "train_batch_size",
                "val_batch_size",
                "test_batch_size",
                "C_val",
                "X_val",
                "val_split",
            ],
            "model": [
                "loss_fn",
                "link_fn",
                "univariate",
                "encoder_type",
                "encoder_kwargs",
                "model_regularizer",
                "num_archetypes",
                "learning_rate",
                "context_dim",
                "x_dim",
            ],
            "trainer": [
                "max_epochs",
                "check_val_every_n_epoch",
                "val_check_interval",
                "callbacks",
                "callback_constructors",
                "accelerator",
            ],
            "fit": [],
            "wrapper": [
                "n_bootstraps",
                "es_patience",
                "es_monitor",
                "es_mode",
                "es_min_delta",
                "es_verbose",
            ],
        }
        self._update_acceptable_kwargs("model", extra_model_kwargs)
        self._update_acceptable_kwargs("data", extra_data_kwargs)
        self._update_acceptable_kwargs(
            "model", kwargs.pop("remove_model_kwargs", []), acceptable=False
        )
        self._update_acceptable_kwargs(
            "data", kwargs.pop("remove_data_kwargs", []), acceptable=False
        )
        self.convenience_kwargs = [
            "alpha",
            "l1_ratio",
            "mu_ratio",
            "subtype_probabilities",
            "width",
            "layers",
            "encoder_link_fn",
        ]
        self.constructor_kwargs = self._organize_constructor_kwargs(**kwargs)
        self.constructor_kwargs["encoder_kwargs"]["width"] = kwargs.pop(
            "width", self.constructor_kwargs["encoder_kwargs"]["width"]
        )
        self.constructor_kwargs["encoder_kwargs"]["layers"] = kwargs.pop(
            "layers", self.constructor_kwargs["encoder_kwargs"]["layers"]
        )
        self.constructor_kwargs["encoder_kwargs"]["link_fn"] = kwargs.pop(
            "encoder_link_fn",
            self.constructor_kwargs["encoder_kwargs"].get(
                "link_fn", self.default_encoder_link_fn
            ),
        )
        self.not_constructor_kwargs = {
            k: v
            for k, v in kwargs.items()
            if k not in self.constructor_kwargs and k not in self.convenience_kwargs
        }
        # Some kwargs are intentionally not consumed here because the
        # sub-class will recognize and handle them.
        # self.private_kwargs = kwargs.pop("private_kwargs", [])
        # self.private_kwargs.append("private_kwargs")
        # Add Predictor-specific kwargs for parsing.
        self._init_kwargs, unrecognized_general_kwargs = self._organize_kwargs(
            **self.not_constructor_kwargs
        )
        for key, value in self.constructor_kwargs.items():
            self._init_kwargs["model"][key] = value
        recognized_private_init_kwargs = self._parse_private_init_kwargs(**kwargs)
        for kwarg in set(unrecognized_general_kwargs) - set(
            recognized_private_init_kwargs
        ):
            print(f"Received unknown keyword argument {kwarg}, ignoring it.")

    def _organize_and_expand_fit_kwargs(self, **kwargs):
        """
        Private function to organize kwargs passed to the constructor or
        the fit function.
        """
        organized_kwargs, unrecognized_general_kwargs = self._organize_kwargs(**kwargs)
        recognized_private_kwargs = self._parse_private_fit_kwargs(**kwargs)
        for kwarg in set(unrecognized_general_kwargs) - set(recognized_private_kwargs):
            print(f"Received unknown keyword argument {kwarg}, ignoring it.")
        # Add kwargs from __init__ to organized_kwargs, keeping more recent kwargs.
        for category, category_kwargs in self._init_kwargs.items():
            for key, value in category_kwargs.items():
                if key not in organized_kwargs[category]:
                    organized_kwargs[category][key] = value

        # Add necessary kwargs.
        def maybe_add_kwarg(category, kwarg, default_val):
            if kwarg in self.acceptable_kwargs[category]:
                organized_kwargs[category][kwarg] = organized_kwargs[category].get(
                    kwarg, default_val
                )

        # Model
        maybe_add_kwarg("model", "learning_rate", self.default_learning_rate)
        maybe_add_kwarg("model", "context_dim", self.context_dim)
        maybe_add_kwarg("model", "x_dim", self.x_dim)
        maybe_add_kwarg("model", "y_dim", self.y_dim)
        if (
            "num_archetypes" in organized_kwargs["model"]
            and organized_kwargs["model"]["num_archetypes"] == 0
        ):
            del organized_kwargs["model"]["num_archetypes"]

        # Data
        maybe_add_kwarg("data", "train_batch_size", self.default_train_batch_size)
        maybe_add_kwarg("data", "val_batch_size", self.default_val_batch_size)
        maybe_add_kwarg("data", "test_batch_size", self.default_test_batch_size)

        # Wrapper
        maybe_add_kwarg("wrapper", "n_bootstraps", self.default_n_bootstraps)

        # Trainer
        maybe_add_kwarg(
            "trainer",
            "callback_constructors",
            [
                lambda i: EarlyStopping(
                    monitor=kwargs.get("es_monitor", "val_loss"),
                    mode=kwargs.get("es_mode", "min"),
                    patience=kwargs.get("es_patience", self.default_es_patience),
                    verbose=kwargs.get("es_verbose", False),
                    min_delta=kwargs.get("es_min_delta", 0.00),
                )
            ],
        )
        organized_kwargs["trainer"]["callback_constructors"].append(
            lambda i: ModelCheckpoint(
                monitor=kwargs.get("es_monitor", "val_loss"),
                dirpath=f"{kwargs.get('checkpoint_path', './lightning_logs')}/boot_{i}_checkpoints",
                filename="{epoch}-{val_loss:.2f}",
            )
        )
        maybe_add_kwarg("trainer", "accelerator", self.accelerator)
        return organized_kwargs

    def _parse_private_fit_kwargs(self, **kwargs):
        """
        Parse private (model-specific) kwargs passed to the fit function.
        Return the list of parsed kwargs.
        """
        return []

    def _parse_private_init_kwargs(self, **kwargs):
        """
        Parse private (model-specific) kwargs passed to the constructor.
        Return the list of parsed kwargs.
        """
        return []

    def _update_acceptable_kwargs(self, category, new_kwargs, acceptable=True):
        """
        Helper function to update the acceptable kwargs.
        If acceptable=True, the new kwargs will be added to the list of acceptable kwargs.
        If acceptable=False, the new kwargs will be removed from the list of acceptable kwargs.
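
        Example:
            Illustrative only; the exact routing depends on the current
            ``acceptable_kwargs`` (``my_model_kwarg`` below is hypothetical)::

                self._update_acceptable_kwargs("model", ["my_model_kwarg"])
                # "my_model_kwarg" is now routed to the model constructor.
                self._update_acceptable_kwargs("data", ["X_val"], acceptable=False)
                # "X_val" is no longer recognized as a data kwarg.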
254 """
255 if acceptable:
256 self.acceptable_kwargs[category] = list(
257 set(self.acceptable_kwargs[category]).union(set(new_kwargs))
258 )
259 else:
260 self.acceptable_kwargs[category] = list(
261 set(self.acceptable_kwargs[category]) - set(new_kwargs)
262 )

    def _organize_kwargs(self, **kwargs):
        """
        Private helper function to organize kwargs passed to the constructor or
        the fit function.
        Organizes kwargs into data, model, trainer, fit, and wrapper categories.
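
        Example:
            Illustrative only, assuming the default ``acceptable_kwargs``::

                organized, unrecognized = self._organize_kwargs(max_epochs=10, foo=1)
                # organized["trainer"]["max_epochs"] == 10
                # unrecognized == ["foo"]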
269 """
271 # Combine default allowed keywords with subclass-specfic
272 organized_kwargs = {category: {} for category in self.acceptable_kwargs}
273 unrecognized_kwargs = []
274 for kwarg, value in kwargs.items():
275 # if kwarg in self.private_kwargs:
276 # continue
277 not_found = True
278 for category, category_kwargs in self.acceptable_kwargs.items():
279 if kwarg in category_kwargs:
280 organized_kwargs[category][kwarg] = value
281 not_found = False
282 break
283 if not_found:
284 unrecognized_kwargs.append(kwarg)
286 return organized_kwargs, unrecognized_kwargs

    def _organize_constructor_kwargs(self, **kwargs):
        """
        Helper function to set the default constructor kwargs and apply any allowed overrides.
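
        Example:
            Illustrative only; shows how the convenience kwargs ``alpha``,
            ``l1_ratio``, and ``mu_ratio`` are turned into a regularizer::

                kwargs = self._organize_constructor_kwargs(alpha=0.1, l1_ratio=1.0)
                # kwargs["model_regularizer"] behaves like
                # REGULARIZERS["l1_l2"](0.1, 1.0, 0.5), with mu_ratio defaulting to 0.5.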
291 """
292 constructor_kwargs = {}
294 def maybe_add_constructor_kwarg(kwarg, default_val):
295 if kwarg in self.acceptable_kwargs["model"]:
296 constructor_kwargs[kwarg] = kwargs.get(kwarg, default_val)
298 maybe_add_constructor_kwarg("link_fn", LINK_FUNCTIONS["identity"])
299 maybe_add_constructor_kwarg("univariate", False)
300 maybe_add_constructor_kwarg("encoder_type", self.default_encoder_type)
301 maybe_add_constructor_kwarg("loss_fn", LOSSES["mse"])
302 maybe_add_constructor_kwarg(
303 "encoder_kwargs",
304 {
305 "width": kwargs.get("encoder_width", self.default_encoder_width),
306 "layers": kwargs.get("encoder_layers", self.default_encoder_layers),
307 "link_fn": kwargs.get("encoder_link_fn", self.default_encoder_link_fn),
308 },
309 )
310 if kwargs.get("subtype_probabilities", False):
311 constructor_kwargs["encoder_kwargs"]["link_fn"] = LINK_FUNCTIONS["softmax"]
313 # Make regularizer
314 if "model_regularizer" in self.acceptable_kwargs["model"]:
315 if "alpha" in kwargs and kwargs["alpha"] > 0:
316 constructor_kwargs["model_regularizer"] = REGULARIZERS["l1_l2"](
317 kwargs["alpha"],
318 kwargs.get("l1_ratio", 1.0),
319 kwargs.get("mu_ratio", 0.5),
320 )
321 else:
322 constructor_kwargs["model_regularizer"] = kwargs.get(
323 "model_regularizer", REGULARIZERS["none"]
324 )
325 return constructor_kwargs

    def _split_train_data(self, C, X, Y=None, Y_required=False, **kwargs):
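        """
        Split (C, X[, Y]) into train and validation sets.
        Uses the user-provided validation set (C_val, X_val, Y_val) when it is
        complete; otherwise holds out a random val_split fraction of the data.
        """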
328 if "C_val" in kwargs:
329 if "X_val" in kwargs:
330 if Y_required and "Y_val" in kwargs:
331 train_data = [C, X, Y]
332 val_data = [kwargs["C_val"], X, kwargs["X_val"], Y, kwargs["Y_val"]]
333 return train_data, val_data
334 print("Y_val not provided, not using the provided C_val or X_val.")
335 else:
336 print("X_val not provided, not using the provided C_val.")
337 if "val_split" in kwargs:
338 if 0 < kwargs["val_split"] < 1:
339 val_split = kwargs["val_split"]
340 else:
341 print(
342 """val_split={kwargs['val_split']} provided but should be between 0
343 and 1 to indicate proportion of data to use as validation."""
344 )
345 raise ValueError
346 else:
347 val_split = self.default_val_split
        if Y is None:
            C_train, C_val, X_train, X_val = train_test_split(
                C, X, test_size=val_split, shuffle=True
            )
            train_data = [C_train, X_train]
            val_data = [C_val, X_val]
        else:
            C_train, C_val, X_train, X_val, Y_train, Y_val = train_test_split(
                C, X, Y, test_size=val_split, shuffle=True
            )
            train_data = [C_train, X_train, Y_train]
            val_data = [C_val, X_val, Y_val]
        return train_data, val_data

    def _build_dataloader(self, model, batch_size, *data):
        """
        Helper function to build a single dataloader.
        Expects *data to contain whatever data (C, X, Y) is necessary for this model.
        """
        return model.dataloader(*data, batch_size=batch_size)

    def _build_dataloaders(self, model, train_data, val_data, **kwargs):
        """
        Helper function to build train and validation dataloaders for a model.
        """
        train_dataloader = self._build_dataloader(
            model,
            kwargs.get("train_batch_size", self.default_train_batch_size),
            *train_data,
        )
        if val_data is None:
            val_dataloader = None
        else:
            val_dataloader = self._build_dataloader(
                model,
                kwargs.get("val_batch_size", self.default_val_batch_size),
                *val_data,
            )

        return train_dataloader, val_dataloader

    def predict(
        self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Predict outcomes from context C and predictors X.

        Args:
            C (np.ndarray): Context array of shape (n_samples, n_context_features)
            X (np.ndarray): Predictor array of shape (n_samples, n_features)
            individual_preds (bool, optional): Whether to return individual predictions for each bootstrap model. Defaults to False.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The outcomes predicted by the context-specific models, of shape (n_samples, y_dim). Returned as a list of individual bootstrap predictions if individual_preds is True.
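
        Example:
            A minimal sketch, where ``wrapper`` is a fitted subclass instance::

                Y_hat = wrapper.predict(C, X)  # (n_samples, y_dim), bootstrap mean
                Y_boots = wrapper.predict(C, X, individual_preds=True)
                # (n_bootstraps, n_samples, y_dim)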
402 """
403 if not hasattr(self, "models") or self.models is None:
404 raise ValueError(
405 "Trying to predict with a model that hasn't been trained yet."
406 )
407 predictions = np.array(
408 [
409 self.trainers[i].predict_y(
410 self.models[i],
411 self.models[i].dataloader(C, X, np.zeros((len(C), self.y_dim))),
412 **kwargs,
413 )
414 for i in range(len(self.models))
415 ]
416 )
417 if individual_preds:
418 return predictions
419 return np.mean(predictions, axis=0)

    def predict_params(
        self,
        C: np.ndarray,
        individual_preds: bool = False,
        model_includes_mus: bool = True,
        **kwargs,
    ) -> Union[
        np.ndarray,
        List[np.ndarray],
        Tuple[np.ndarray, np.ndarray],
        Tuple[List[np.ndarray], List[np.ndarray]],
    ]:
        """
        Predict context-specific model parameters from context C.

        Args:
            C (np.ndarray): Context array of shape (n_samples, n_context_features)
            individual_preds (bool, optional): Whether to return individual parameter predictions for each bootstrap. Defaults to False, averaging across bootstraps.
            model_includes_mus (bool, optional): Whether the model includes context-specific offsets (mu). Defaults to True.

        Returns:
            Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray, np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]]: The parameters of the predicted context-specific models.
            Returned as lists over individual bootstraps if individual_preds is True; otherwise the bootstraps are averaged for a better estimate.
            If model_includes_mus is True, returns both coefficients and offsets as a tuple (betas, mus); otherwise returns coefficients (betas) only.
            betas has shape (n_samples, x_dim, y_dim), or (n_samples, x_dim) if y_dim == 1.
            mus has shape (n_samples, y_dim), or (n_samples,) if y_dim == 1.
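
        Example:
            A minimal sketch, where ``wrapper`` is a fitted subclass instance::

                betas, mus = wrapper.predict_params(C)
                # betas: (n_samples, x_dim, y_dim); mus: (n_samples, y_dim)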
449 """
450 # Returns betas, mus
451 if kwargs.pop("uses_y", True):
452 get_dataloader = lambda i: self.models[i].dataloader(
453 C, np.zeros((len(C), self.x_dim)), np.zeros((len(C), self.y_dim))
454 )
455 else:
456 get_dataloader = lambda i: self.models[i].dataloader(
457 C, np.zeros((len(C), self.x_dim))
458 )
459 predictions = [
460 self.trainers[i].predict_params(self.models[i], get_dataloader(i), **kwargs)
461 for i in range(len(self.models))
462 ]
463 if model_includes_mus:
464 betas = np.array([p[0] for p in predictions])
465 mus = np.array([p[1] for p in predictions])
466 if individual_preds:
467 return betas, mus
468 else:
469 return np.mean(betas, axis=0), np.mean(mus, axis=0)
470 betas = np.array(predictions)
471 if not individual_preds:
472 return np.mean(betas, axis=0)
473 return betas

    def fit(self, *args, **kwargs) -> None:
        """
        Fit a contextualized model to data.

        Args:
            C (np.ndarray): Context array of shape (n_samples, n_context_features)
            X (np.ndarray): Predictor array of shape (n_samples, n_features)
            Y (np.ndarray, optional): Target array of shape (n_samples, n_targets). Defaults to None, in which case X is used as the target, as in Contextualized Networks.
            max_epochs (int, optional): Maximum number of epochs to train for. Defaults to 1.
            learning_rate (float, optional): Learning rate for the optimizer. Defaults to 1e-3.
            val_split (float, optional): Proportion of data to use for validation and early stopping. Defaults to 0.2.
            n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
            train_batch_size (int, optional): Batch size for training. Defaults to 1.
            val_batch_size (int, optional): Batch size for validation. Defaults to 16.
            test_batch_size (int, optional): Batch size for testing. Defaults to 16.
            es_patience (int, optional): Number of epochs to wait before early stopping. Defaults to 1.
            es_monitor (str, optional): Metric to monitor for early stopping. Defaults to "val_loss".
            es_mode (str, optional): Mode for early stopping. Defaults to "min".
            es_verbose (bool, optional): Whether to print early stopping updates. Defaults to False.
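
        Example:
            A minimal sketch, where ``wrapper`` is a subclass instance::

                wrapper.fit(C, X, Y, max_epochs=10, n_bootstraps=3, val_split=0.2)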
494 """
495 self.models = []
496 self.trainers = []
497 self.dataloaders = {"train": [], "val": [], "test": []}
498 self.context_dim = args[0].shape[-1]
499 self.x_dim = args[1].shape[-1]
500 if len(args) == 3:
501 Y = args[2]
502 if kwargs.get("Y", None) is not None:
503 Y = kwargs.get("Y")
504 if len(Y.shape) == 1: # add feature dimension to Y if not given.
505 Y = np.expand_dims(Y, 1)
506 self.y_dim = Y.shape[-1]
507 args = (args[0], args[1], Y)
508 else:
509 self.y_dim = self.x_dim
510 organized_kwargs = self._organize_and_expand_fit_kwargs(**kwargs)
511 self.n_bootstraps = organized_kwargs["wrapper"].get(
512 "n_bootstraps", self.n_bootstraps
513 )
514 for bootstrap in range(self.n_bootstraps):
515 model = self.base_constructor(**organized_kwargs["model"])
516 train_data, val_data = self._split_train_data(
517 *args, **organized_kwargs["data"]
518 )
519 train_dataloader, val_dataloader = self._build_dataloaders(
520 model,
521 train_data,
522 val_data,
523 **organized_kwargs["data"],
524 )
525 # Makes a new trainer for each bootstrap fit - bad practice, but necessary here.
526 my_trainer_kwargs = copy.deepcopy(organized_kwargs["trainer"])
527 # Must reconstruct the callbacks because they save state from fitting trajectories.
528 my_trainer_kwargs["callbacks"] = [
529 f(bootstrap)
530 for f in organized_kwargs["trainer"]["callback_constructors"]
531 ]
532 del my_trainer_kwargs["callback_constructors"]
533 trainer = self.trainer_constructor(
534 **my_trainer_kwargs, enable_progress_bar=False
535 )
536 checkpoint_callback = my_trainer_kwargs["callbacks"][1]
537 os.makedirs(checkpoint_callback.dirpath, exist_ok=True)
538 try:
539 trainer.fit(
540 model, train_dataloader, val_dataloader, **organized_kwargs["fit"]
541 )
542 except:
543 trainer.fit(model, train_dataloader, **organized_kwargs["fit"])
544 if kwargs.get("max_epochs", 1) > 0:
545 best_checkpoint = torch.load(checkpoint_callback.best_model_path)
546 model.load_state_dict(best_checkpoint["state_dict"])
547 self.dataloaders["train"].append(train_dataloader)
548 self.dataloaders["val"].append(val_dataloader)
549 self.models.append(model)
550 self.trainers.append(trainer)