Coverage for contextualized/easy/wrappers/SKLearnWrapper.py: 92% (218 statements)


1""" 

2An sklearn-like wrapper for Contextualized models. 

3""" 

4 

5import copy 

6import os 

7from typing import * 

8 

9import numpy as np 

10from pytorch_lightning.callbacks.early_stopping import EarlyStopping 

11from pytorch_lightning.callbacks import ModelCheckpoint 

12from sklearn.model_selection import train_test_split 

13import torch 

14 

15from contextualized.functions import LINK_FUNCTIONS 

16from contextualized.regression import REGULARIZERS, LOSSES 

17 

18DEFAULT_LEARNING_RATE = 1e-3 

19DEFAULT_N_BOOTSTRAPS = 1 

20DEFAULT_ES_PATIENCE = 1 

21DEFAULT_VAL_BATCH_SIZE = 16 

22DEFAULT_TRAIN_BATCH_SIZE = 1 

23DEFAULT_TEST_BATCH_SIZE = 16 

24DEFAULT_VAL_SPLIT = 0.2 

25DEFAULT_ENCODER_TYPE = "mlp" 

26DEFAULT_ENCODER_WIDTH = 25 

27DEFAULT_ENCODER_LAYERS = 3 

28DEFAULT_ENCODER_LINK_FN = LINK_FUNCTIONS["identity"] 

29 

30 

class SKLearnWrapper:
    """
    An sklearn-like wrapper for Contextualized models.

    Args:
        base_constructor (class): The base class used to construct the model.
        extra_model_kwargs (dict): Extra kwargs to pass to the model constructor.
        extra_data_kwargs (dict): Extra kwargs to pass to the dataloader constructor.
        trainer_constructor (class): The trainer class to use.
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
        loss_fn (torch.nn.Module, optional): Loss function. Defaults to LOSSES["mse"].
        link_fn (torch.nn.Module, optional): Link function. Defaults to LINK_FUNCTIONS["identity"].
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float in range (0.0, 1.0) governing how much of the regularization applies to context-specific parameters (betas) vs. context-specific offsets (mus).
        l1_ratio (float, optional): Float in range (0.0, 1.0) governing how much the regularization penalizes the L1 vs. the L2 norm of the parameters.
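
    Example:
        An illustrative sketch; SKLearnWrapper is typically instantiated through
        a subclass such as contextualized.easy.ContextualizedRegressor (assumed
        here, not defined in this file)::

            from contextualized.easy import ContextualizedRegressor

            model = ContextualizedRegressor(encoder_type="mlp", alpha=1e-3)
            model.fit(C, X, Y, max_epochs=10)  # C, X, Y: numpy arrays
            betas, mus = model.predict_params(C)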

47 """ 

48 

    def _set_defaults(self):
        self.default_learning_rate = DEFAULT_LEARNING_RATE
        self.default_n_bootstraps = DEFAULT_N_BOOTSTRAPS
        self.default_es_patience = DEFAULT_ES_PATIENCE
        self.default_train_batch_size = DEFAULT_TRAIN_BATCH_SIZE
        self.default_test_batch_size = DEFAULT_TEST_BATCH_SIZE
        self.default_val_batch_size = DEFAULT_VAL_BATCH_SIZE
        self.default_val_split = DEFAULT_VAL_SPLIT
        self.default_encoder_width = DEFAULT_ENCODER_WIDTH
        self.default_encoder_layers = DEFAULT_ENCODER_LAYERS
        self.default_encoder_link_fn = DEFAULT_ENCODER_LINK_FN
        self.default_encoder_type = DEFAULT_ENCODER_TYPE

    def __init__(
        self,
        base_constructor,
        extra_model_kwargs,
        extra_data_kwargs,
        trainer_constructor,
        **kwargs,
    ):
        self._set_defaults()
        self.base_constructor = base_constructor
        self.n_bootstraps = DEFAULT_N_BOOTSTRAPS
        self.models = None
        self.trainers = None
        self.dataloaders = None
        self.context_dim = None
        self.x_dim = None
        self.y_dim = None
        self.trainer_constructor = trainer_constructor
        self.accelerator = "gpu" if torch.cuda.is_available() else "cpu"

        self.acceptable_kwargs = {
            "data": [
                "train_batch_size",
                "val_batch_size",
                "test_batch_size",
                "C_val",
                "X_val",
                "val_split",
            ],
            "model": [
                "loss_fn",
                "link_fn",
                "univariate",
                "encoder_type",
                "encoder_kwargs",
                "model_regularizer",
                "num_archetypes",
                "learning_rate",
                "context_dim",
                "x_dim",
            ],
            "trainer": [
                "max_epochs",
                "check_val_every_n_epoch",
                "val_check_interval",
                "callbacks",
                "callback_constructors",
                "accelerator",
            ],
            "fit": [],
            "wrapper": [
                "n_bootstraps",
                "es_patience",
                "es_monitor",
                "es_mode",
                "es_min_delta",
                "es_verbose",
            ],
        }
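        # Illustrative routing note (editorial): a call like
        # fit(C, X, Y, max_epochs=5, val_split=0.1) sends max_epochs to the
        # "trainer" category and val_split to the "data" category via
        # _organize_kwargs below.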

        self._update_acceptable_kwargs("model", extra_model_kwargs)
        self._update_acceptable_kwargs("data", extra_data_kwargs)
        self._update_acceptable_kwargs(
            "model", kwargs.pop("remove_model_kwargs", []), acceptable=False
        )
        self._update_acceptable_kwargs(
            "data", kwargs.pop("remove_data_kwargs", []), acceptable=False
        )
        self.convenience_kwargs = [
            "alpha",
            "l1_ratio",
            "mu_ratio",
            "subtype_probabilities",
            "width",
            "layers",
            "encoder_link_fn",
        ]
        self.constructor_kwargs = self._organize_constructor_kwargs(**kwargs)
        self.constructor_kwargs["encoder_kwargs"]["width"] = kwargs.pop(
            "width", self.constructor_kwargs["encoder_kwargs"]["width"]
        )
        self.constructor_kwargs["encoder_kwargs"]["layers"] = kwargs.pop(
            "layers", self.constructor_kwargs["encoder_kwargs"]["layers"]
        )
        self.constructor_kwargs["encoder_kwargs"]["link_fn"] = kwargs.pop(
            "encoder_link_fn",
            self.constructor_kwargs["encoder_kwargs"].get(
                "link_fn", self.default_encoder_link_fn
            ),
        )
        self.not_constructor_kwargs = {
            k: v
            for k, v in kwargs.items()
            if k not in self.constructor_kwargs and k not in self.convenience_kwargs
        }
        # Some kwargs are not ignored by the wrapper because a subclass will handle them.
        # self.private_kwargs = kwargs.pop("private_kwargs", [])
        # self.private_kwargs.append("private_kwargs")
        # Add predictor-specific kwargs for parsing.
        self._init_kwargs, unrecognized_general_kwargs = self._organize_kwargs(
            **self.not_constructor_kwargs
        )
        for key, value in self.constructor_kwargs.items():
            self._init_kwargs["model"][key] = value
        recognized_private_init_kwargs = self._parse_private_init_kwargs(**kwargs)
        for kwarg in set(unrecognized_general_kwargs) - set(
            recognized_private_init_kwargs
        ):
            print(f"Received unknown keyword argument {kwarg}, probably ignoring.")

    def _organize_and_expand_fit_kwargs(self, **kwargs):
        """
        Private function to organize kwargs passed to the constructor or
        the fit function.
        """
        organized_kwargs, unrecognized_general_kwargs = self._organize_kwargs(**kwargs)
        recognized_private_kwargs = self._parse_private_fit_kwargs(**kwargs)
        for kwarg in set(unrecognized_general_kwargs) - set(recognized_private_kwargs):
            print(f"Received unknown keyword argument {kwarg}, probably ignoring.")
        # Add kwargs from __init__ to organized_kwargs, keeping the more recent kwargs.
        for category, category_kwargs in self._init_kwargs.items():
            for key, value in category_kwargs.items():
                if key not in organized_kwargs[category]:
                    organized_kwargs[category][key] = value

        # Add necessary kwargs.
        def maybe_add_kwarg(category, kwarg, default_val):
            if kwarg in self.acceptable_kwargs[category]:
                organized_kwargs[category][kwarg] = organized_kwargs[category].get(
                    kwarg, default_val
                )

        # Model
        maybe_add_kwarg("model", "learning_rate", self.default_learning_rate)
        maybe_add_kwarg("model", "context_dim", self.context_dim)
        maybe_add_kwarg("model", "x_dim", self.x_dim)
        maybe_add_kwarg("model", "y_dim", self.y_dim)
        if (
            "num_archetypes" in organized_kwargs["model"]
            and organized_kwargs["model"]["num_archetypes"] == 0
        ):
            del organized_kwargs["model"]["num_archetypes"]

        # Data
        maybe_add_kwarg("data", "train_batch_size", self.default_train_batch_size)
        maybe_add_kwarg("data", "val_batch_size", self.default_val_batch_size)
        maybe_add_kwarg("data", "test_batch_size", self.default_test_batch_size)

        # Wrapper
        maybe_add_kwarg("wrapper", "n_bootstraps", self.default_n_bootstraps)

        # Trainer
        maybe_add_kwarg(
            "trainer",
            "callback_constructors",
            [
                lambda i: EarlyStopping(
                    monitor=kwargs.get("es_monitor", "val_loss"),
                    mode=kwargs.get("es_mode", "min"),
                    patience=kwargs.get("es_patience", self.default_es_patience),
                    verbose=kwargs.get("es_verbose", False),
                    min_delta=kwargs.get("es_min_delta", 0.00),
                )
            ],
        )
        organized_kwargs["trainer"]["callback_constructors"].append(
            lambda i: ModelCheckpoint(
                monitor=kwargs.get("es_monitor", "val_loss"),
                dirpath=f"{kwargs.get('checkpoint_path', './lightning_logs')}/boot_{i}_checkpoints",
                filename="{epoch}-{val_loss:.2f}",
            )
        )
        maybe_add_kwarg("trainer", "accelerator", self.accelerator)
        return organized_kwargs

    def _parse_private_fit_kwargs(self, **kwargs):
        """
        Parse private (model-specific) kwargs passed to the fit function.
        Return the list of parsed kwargs.
        """
        return []

    def _parse_private_init_kwargs(self, **kwargs):
        """
        Parse private (model-specific) kwargs passed to the constructor.
        Return the list of parsed kwargs.
        """
        return []

    def _update_acceptable_kwargs(self, category, new_kwargs, acceptable=True):
        """
        Helper function to update the acceptable kwargs.
        If acceptable=True, the new kwargs are added to the list of acceptable kwargs.
        If acceptable=False, the new kwargs are removed from the list of acceptable kwargs.
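
        Example:
            An illustrative sketch; "num_factors" is a hypothetical kwarg name
            used only to show the effect::

                self._update_acceptable_kwargs("model", ["num_factors"])
                # "num_factors" is now routed to the model constructor.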

254 """ 

255 if acceptable: 

256 self.acceptable_kwargs[category] = list( 

257 set(self.acceptable_kwargs[category]).union(set(new_kwargs)) 

258 ) 

259 else: 

260 self.acceptable_kwargs[category] = list( 

261 set(self.acceptable_kwargs[category]) - set(new_kwargs) 

262 ) 

263 

    def _organize_kwargs(self, **kwargs):
        """
        Private helper function to organize kwargs passed to the constructor or
        the fit function.
        Organizes kwargs into data, model, trainer, fit, and wrapper categories.
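
        Example:
            An illustrative sketch of the return value::

                organized, unknown = self._organize_kwargs(max_epochs=5, fake_kwarg=1)
                # organized["trainer"] == {"max_epochs": 5}
                # unknown == ["fake_kwarg"]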

269 """ 

270 

271 # Combine default allowed keywords with subclass-specfic 

272 organized_kwargs = {category: {} for category in self.acceptable_kwargs} 

273 unrecognized_kwargs = [] 

274 for kwarg, value in kwargs.items(): 

275 # if kwarg in self.private_kwargs: 

276 # continue 

277 not_found = True 

278 for category, category_kwargs in self.acceptable_kwargs.items(): 

279 if kwarg in category_kwargs: 

280 organized_kwargs[category][kwarg] = value 

281 not_found = False 

282 break 

283 if not_found: 

284 unrecognized_kwargs.append(kwarg) 

285 

286 return organized_kwargs, unrecognized_kwargs 

287 

    def _organize_constructor_kwargs(self, **kwargs):
        """
        Helper function to set the default constructor kwargs and apply any
        allowed changes.
        """
        constructor_kwargs = {}

        def maybe_add_constructor_kwarg(kwarg, default_val):
            if kwarg in self.acceptable_kwargs["model"]:
                constructor_kwargs[kwarg] = kwargs.get(kwarg, default_val)

        maybe_add_constructor_kwarg("link_fn", LINK_FUNCTIONS["identity"])
        maybe_add_constructor_kwarg("univariate", False)
        maybe_add_constructor_kwarg("encoder_type", self.default_encoder_type)
        maybe_add_constructor_kwarg("loss_fn", LOSSES["mse"])
        maybe_add_constructor_kwarg(
            "encoder_kwargs",
            {
                "width": kwargs.get("encoder_width", self.default_encoder_width),
                "layers": kwargs.get("encoder_layers", self.default_encoder_layers),
                "link_fn": kwargs.get("encoder_link_fn", self.default_encoder_link_fn),
            },
        )
        if kwargs.get("subtype_probabilities", False):
            constructor_kwargs["encoder_kwargs"]["link_fn"] = LINK_FUNCTIONS["softmax"]
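
        # How the convenience kwargs combine (an editorial note inferred from the
        # class docstring, not from REGULARIZERS internals): REGULARIZERS["l1_l2"]
        # is called as l1_l2(alpha, l1_ratio, mu_ratio), where alpha scales the
        # total penalty, l1_ratio trades off L1 vs. L2 norms, and mu_ratio splits
        # the penalty between context-specific parameters (betas) and
        # context-specific offsets (mus).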

        # Make the regularizer.
        if "model_regularizer" in self.acceptable_kwargs["model"]:
            if "alpha" in kwargs and kwargs["alpha"] > 0:
                constructor_kwargs["model_regularizer"] = REGULARIZERS["l1_l2"](
                    kwargs["alpha"],
                    kwargs.get("l1_ratio", 1.0),
                    kwargs.get("mu_ratio", 0.5),
                )
            else:
                constructor_kwargs["model_regularizer"] = kwargs.get(
                    "model_regularizer", REGULARIZERS["none"]
                )
        return constructor_kwargs

    def _split_train_data(self, C, X, Y=None, Y_required=False, **kwargs):
        """
        Split (C, X, Y) into train and validation sets, preferring explicit
        C_val/X_val/Y_val kwargs over a random val_split.
        """
        if "C_val" in kwargs:
            if "X_val" in kwargs:
                if Y_required and "Y_val" in kwargs:
                    train_data = [C, X, Y]
                    val_data = [kwargs["C_val"], kwargs["X_val"], kwargs["Y_val"]]
                    return train_data, val_data
                print("Y_val not provided, not using the provided C_val or X_val.")
            else:
                print("X_val not provided, not using the provided C_val.")
        if "val_split" in kwargs:
            if 0 < kwargs["val_split"] < 1:
                val_split = kwargs["val_split"]
            else:
                raise ValueError(
                    f"""val_split={kwargs['val_split']} provided but should be between 0
                    and 1 to indicate the proportion of data to use as validation."""
                )
        else:
            val_split = self.default_val_split
        if Y is None:
            C_train, C_val, X_train, X_val = train_test_split(
                C, X, test_size=val_split, shuffle=True
            )
            train_data = [C_train, X_train]
            val_data = [C_val, X_val]
        else:
            C_train, C_val, X_train, X_val, Y_train, Y_val = train_test_split(
                C, X, Y, test_size=val_split, shuffle=True
            )
            train_data = [C_train, X_train, Y_train]
            val_data = [C_val, X_val, Y_val]
        return train_data, val_data

    def _build_dataloader(self, model, batch_size, *data):
        """
        Helper function to build a single dataloader.
        Expects *data to contain whatever data (C, X, Y) is necessary for this model.
        """
        return model.dataloader(*data, batch_size=batch_size)

    def _build_dataloaders(self, model, train_data, val_data, **kwargs):
        """
        Helper function to build the train dataloader and the (optional)
        validation dataloader for the given model.
        """
        train_dataloader = self._build_dataloader(
            model,
            kwargs.get("train_batch_size", self.default_train_batch_size),
            *train_data,
        )
        if val_data is None:
            val_dataloader = None
        else:
            val_dataloader = self._build_dataloader(
                model,
                kwargs.get("val_batch_size", self.default_val_batch_size),
                *val_data,
            )

        return train_dataloader, val_dataloader

    def predict(
        self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Predict outcomes from context C and predictors X.

        Args:
            C (np.ndarray): Context array of shape (n_samples, n_context_features)
            X (np.ndarray): Predictor array of shape (n_samples, n_features)
            individual_preds (bool, optional): Whether to return individual predictions for each bootstrap. Defaults to False.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The outcomes predicted by the context-specific models, of shape (n_samples, y_dim). Returned as a list of individual bootstrap predictions if individual_preds is True.
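
        Example:
            An illustrative sketch, assuming a fitted wrapper and numpy arrays C
            and X with matching sample counts::

                preds = model.predict(C, X)  # (n_samples, y_dim), mean over bootstraps
                boot_preds = model.predict(C, X, individual_preds=True)  # per-bootstrap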

402 """ 

403 if not hasattr(self, "models") or self.models is None: 

404 raise ValueError( 

405 "Trying to predict with a model that hasn't been trained yet." 

406 ) 

407 predictions = np.array( 

408 [ 

409 self.trainers[i].predict_y( 

410 self.models[i], 

411 self.models[i].dataloader(C, X, np.zeros((len(C), self.y_dim))), 

412 **kwargs, 

413 ) 

414 for i in range(len(self.models)) 

415 ] 

416 ) 

417 if individual_preds: 

418 return predictions 

419 return np.mean(predictions, axis=0) 

420 

    def predict_params(
        self,
        C: np.ndarray,
        individual_preds: bool = False,
        model_includes_mus: bool = True,
        **kwargs,
    ) -> Union[
        np.ndarray,
        List[np.ndarray],
        Tuple[np.ndarray, np.ndarray],
        Tuple[List[np.ndarray], List[np.ndarray]],
    ]:
        """
        Predict context-specific model parameters from context C.

        Args:
            C (np.ndarray): Context array of shape (n_samples, n_context_features)
            individual_preds (bool, optional): Whether to return individual model predictions for each bootstrap. Defaults to False, averaging across bootstraps.
            model_includes_mus (bool, optional): Whether the model includes context-specific offsets (mu). Defaults to True.

        Returns:
            Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray, np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]]: The parameters of the predicted context-specific models.
                Returned as lists of individual bootstraps if individual_preds is True, otherwise averaged over bootstraps for a better estimate.
                If model_includes_mus is True, returns both coefficients and offsets as a tuple of (betas, mus). Otherwise, returns coefficients (betas) only.
                For model_includes_mus=True, ([betas], [mus]) if individual_preds is True, otherwise (betas, mus).
                For model_includes_mus=False, [betas] if individual_preds is True, otherwise betas.
                betas has shape (n_samples, x_dim, y_dim), or (n_samples, x_dim) if y_dim == 1.
                mus has shape (n_samples, y_dim), or (n_samples,) if y_dim == 1.
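
        Example:
            An illustrative sketch, assuming a fitted wrapper with
            model_includes_mus=True and y_dim > 1::

                betas, mus = model.predict_params(C)
                # Context-specific linear read-out (illustrative only):
                y_hat = np.einsum("nxy,nx->ny", betas, X) + mus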

449 """ 

450 # Returns betas, mus 

451 if kwargs.pop("uses_y", True): 

452 get_dataloader = lambda i: self.models[i].dataloader( 

453 C, np.zeros((len(C), self.x_dim)), np.zeros((len(C), self.y_dim)) 

454 ) 

455 else: 

456 get_dataloader = lambda i: self.models[i].dataloader( 

457 C, np.zeros((len(C), self.x_dim)) 

458 ) 

459 predictions = [ 

460 self.trainers[i].predict_params(self.models[i], get_dataloader(i), **kwargs) 

461 for i in range(len(self.models)) 

462 ] 

463 if model_includes_mus: 

464 betas = np.array([p[0] for p in predictions]) 

465 mus = np.array([p[1] for p in predictions]) 

466 if individual_preds: 

467 return betas, mus 

468 else: 

469 return np.mean(betas, axis=0), np.mean(mus, axis=0) 

470 betas = np.array(predictions) 

471 if not individual_preds: 

472 return np.mean(betas, axis=0) 

473 return betas 

474 

    def fit(self, *args, **kwargs) -> None:
        """
        Fit contextualized model to data.

        Args:
            C (np.ndarray): Context array of shape (n_samples, n_context_features)
            X (np.ndarray): Predictor array of shape (n_samples, n_features)
            Y (np.ndarray, optional): Target array of shape (n_samples, n_targets). Defaults to None, in which case X is used as the target, as in Contextualized Networks.
            max_epochs (int, optional): Maximum number of epochs to train for. Defaults to 1.
            learning_rate (float, optional): Learning rate for the optimizer. Defaults to 1e-3.
            val_split (float, optional): Proportion of data to use for validation and early stopping. Defaults to 0.2.
            n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
            train_batch_size (int, optional): Batch size for training. Defaults to 1.
            val_batch_size (int, optional): Batch size for validation. Defaults to 16.
            test_batch_size (int, optional): Batch size for testing. Defaults to 16.
            es_patience (int, optional): Number of epochs to wait before early stopping. Defaults to 1.
            es_monitor (str, optional): Metric to monitor for early stopping. Defaults to "val_loss".
            es_mode (str, optional): Mode for early stopping. Defaults to "min".
            es_verbose (bool, optional): Whether to print early stopping updates. Defaults to False.
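
        Example:
            An illustrative sketch, assuming numpy arrays C, X, Y with matching
            sample counts::

                model.fit(C, X, Y, max_epochs=10, n_bootstraps=3, val_split=0.2)
                preds = model.predict(C, X)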

494 """ 

495 self.models = [] 

496 self.trainers = [] 

497 self.dataloaders = {"train": [], "val": [], "test": []} 

498 self.context_dim = args[0].shape[-1] 

499 self.x_dim = args[1].shape[-1] 

500 if len(args) == 3: 

501 Y = args[2] 

502 if kwargs.get("Y", None) is not None: 

503 Y = kwargs.get("Y") 

504 if len(Y.shape) == 1: # add feature dimension to Y if not given. 

505 Y = np.expand_dims(Y, 1) 

506 self.y_dim = Y.shape[-1] 

507 args = (args[0], args[1], Y) 

508 else: 

509 self.y_dim = self.x_dim 

510 organized_kwargs = self._organize_and_expand_fit_kwargs(**kwargs) 

511 self.n_bootstraps = organized_kwargs["wrapper"].get( 

512 "n_bootstraps", self.n_bootstraps 

513 ) 

514 for bootstrap in range(self.n_bootstraps): 

515 model = self.base_constructor(**organized_kwargs["model"]) 

516 train_data, val_data = self._split_train_data( 

517 *args, **organized_kwargs["data"] 

518 ) 

519 train_dataloader, val_dataloader = self._build_dataloaders( 

520 model, 

521 train_data, 

522 val_data, 

523 **organized_kwargs["data"], 

524 ) 

525 # Makes a new trainer for each bootstrap fit - bad practice, but necessary here. 

526 my_trainer_kwargs = copy.deepcopy(organized_kwargs["trainer"]) 

527 # Must reconstruct the callbacks because they save state from fitting trajectories. 

528 my_trainer_kwargs["callbacks"] = [ 

529 f(bootstrap) 

530 for f in organized_kwargs["trainer"]["callback_constructors"] 

531 ] 

532 del my_trainer_kwargs["callback_constructors"] 

533 trainer = self.trainer_constructor( 

534 **my_trainer_kwargs, enable_progress_bar=False 

535 ) 

536 checkpoint_callback = my_trainer_kwargs["callbacks"][1] 

537 os.makedirs(checkpoint_callback.dirpath, exist_ok=True) 

538 try: 

539 trainer.fit( 

540 model, train_dataloader, val_dataloader, **organized_kwargs["fit"] 

541 ) 

542 except: 

543 trainer.fit(model, train_dataloader, **organized_kwargs["fit"]) 

544 if kwargs.get("max_epochs", 1) > 0: 

545 best_checkpoint = torch.load(checkpoint_callback.best_model_path) 

546 model.load_state_dict(best_checkpoint["state_dict"]) 

547 self.dataloaders["train"].append(train_dataloader) 

548 self.dataloaders["val"].append(val_dataloader) 

549 self.models.append(model) 

550 self.trainers.append(trainer)