In this tutorial, we will show how spotPython
can be integrated into the PyTorch
training workflow.
This document refers to the following software versions:
python: 3.10.10
torch: 2.0.1
torchvision: 0.15.0

!pip list | grep "spot[RiverPython]"

spotPython 0.2.34
spotRiver 0.0.93
Note: you may need to restart the kernel to use updated packages.
spotPython can be installed via pip. Alternatively, the source code can be downloaded from gitHub: https://github.com/sequential-parameter-optimization/spotPython.

!pip install spotPython

Alternatively, spotPython can be installed from gitHub:

# import sys
# !{sys.executable} -m pip install --upgrade build
# !{sys.executable} -m pip install --upgrade --force-reinstall spotPython
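To check which version is actually active in the running kernel, the installed package metadata can be queried. This is a minimal check using only the Python standard library (not part of the original notebook):

from importlib.metadata import version  # standard library, Python >= 3.8
print(version("spotPython"))  # e.g., 0.2.34, as in the pip listing above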
Before we consider the detailed experimental setup, we select the parameters that affect the run time, the initial design size, and the device to be used.
MAX_TIME = 1
INIT_SIZE = 5
DEVICE = "cpu" # "cuda:0"

from spotPython.utils.device import getDevice
DEVICE = getDevice(DEVICE)
print(DEVICE)
cpu
import os
import copy
import socket
from datetime import datetime
from dateutil.tz import tzlocal
start_time = datetime.now(tzlocal())
HOSTNAME = socket.gethostname().split(".")[0]
experiment_name = '11-torch' + "_" + HOSTNAME + "_" + str(MAX_TIME) + "min_" + str(INIT_SIZE) + "init_" + str(start_time).split(".", 1)[0].replace(' ', '_')
experiment_name = experiment_name.replace(':', '-')
print(experiment_name)

if not os.path.exists('./figures'):
    os.makedirs('./figures')
11-torch_p040025_1min_5init_2023-06-17_11-21-10
The fun_control Dictionary

spotPython uses a Python dictionary for storing the information required for the hyperparameter tuning process, which was described in Section 14.2.
from spotPython.utils.init import fun_control_init
fun_control = fun_control_init(task="classification",
                               tensorboard_path="runs/11_spot_hpt_torch_fashion_mnist",
                               device=DEVICE)
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
def load_data(data_dir="./data"):
    # Download training data from open datasets.
    training_data = datasets.FashionMNIST(
        root=data_dir,
        train=True,
        download=True,
        transform=ToTensor(),
    )
    # Download test data from open datasets.
    test_data = datasets.FashionMNIST(
        root=data_dir,
        train=False,
        download=True,
        transform=ToTensor(),
    )
    return training_data, test_data

train, test = load_data()
train.data.shape, test.data.shape
(torch.Size([60000, 28, 28]), torch.Size([10000, 28, 28]))
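Before the data is registered for tuning, it can be useful to look at a sample. The following sanity check (not part of the original workflow) uses matplotlib and the data, targets, and classes attributes of torchvision's FashionMNIST dataset:

import matplotlib.pyplot as plt

# display the first training image together with its class name
plt.imshow(train.data[0], cmap="gray")
plt.title(train.classes[int(train.targets[0])])
plt.show()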
n_samples = len(train)
# add the dataset to the fun_control
fun_control.update({"data": None,
                    "train": train,
                    "test": test,
                    "n_samples": n_samples,
                    "target_column": None})
After the training and test data are specified and added to the fun_control
dictionary, spotPython
allows the specification of a data preprocessing pipeline, e.g., for the scaling of the data or for the one-hot encoding of categorical variables, see Section 14.4.1. This feature is not used here, so we do not change the default value (which is None).
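If a pipeline were needed, it would be registered in the fun_control dictionary as well. The sketch below is hypothetical: the key name "prep_model" and the scikit-learn scaler are assumptions for illustration, not settings used in this run:

# hypothetical sketch of a preprocessing entry (not used here, default is None);
# FashionMNIST pixels are already mapped to [0, 1] by ToTensor():
# from sklearn.preprocessing import StandardScaler
# fun_control.update({"prep_model": StandardScaler()})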
The algorithm and core_model_hyper_dict

spotPython implements a class which is similar to the class described in the PyTorch tutorial. The class is called Net_fashionMNIST and is implemented in the file netfashionMNIST.py. The class is imported here.
from torch import nn
import spotPython.torch.netcore as netcore
class Net_fashionMNIST(netcore.Net_Core):
    def __init__(self, l1, l2, lr_mult, batch_size, epochs, k_folds, patience, optimizer, sgd_momentum):
        super(Net_fashionMNIST, self).__init__(
            lr_mult=lr_mult,
            batch_size=batch_size,
            epochs=epochs,
            k_folds=k_folds,
            patience=patience,
            optimizer=optimizer,
            sgd_momentum=sgd_momentum,
        )
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, l1),
            nn.ReLU(),
            nn.Linear(l1, l2),
            nn.ReLU(),
            nn.Linear(l2, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
This class inherits from the class Net_Core
which is implemented in the file netcore.py
, see Section 14.4.3.
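As a quick smoke test (illustrative only, using the transformed default values from the hyperparameter table below), the network can be instantiated directly and fed a random batch:

import torch

net = Net_fashionMNIST(l1=32, l2=32, lr_mult=1.0, batch_size=16,
                       epochs=8, k_folds=1, patience=5,
                       optimizer="SGD", sgd_momentum=0.0)
x = torch.rand(4, 28, 28)  # a batch of four fake 28x28 images
print(net(x).shape)        # torch.Size([4, 10]): one logit per class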
from spotPython.data.torch_hyper_dict import TorchHyperDict
from spotPython.torch.netfashionMNIST import Net_fashionMNIST
from spotPython.hyperparameters.values import add_core_model_to_fun_control
fun_control = add_core_model_to_fun_control(core_model=Net_fashionMNIST,
                                            fun_control=fun_control,
                                            hyper_dict=TorchHyperDict,
                                            filename=None)
The hyper_dict Hyperparameters for the Selected Algorithm

spotPython uses JSON files for the specification of the hyperparameters, which were described in Section 14.5.2.
The corresponding entries for the Net_fashionMNIST
class are shown below.
"Net_fashionMNIST":
{
"l1": {
"type": "int",
"default": 5,
"transform": "transform_power_2_int",
"lower": 2,
"upper": 9},
"l2": {
"type": "int",
"default": 5,
"transform": "transform_power_2_int",
"lower": 2,
"upper": 9},
"lr_mult": {
"type": "float",
"default": 1.0,
"transform": "None",
"lower": 0.1,
"upper": 10.0},
"batch_size": {
"type": "int",
"default": 4,
"transform": "transform_power_2_int",
"lower": 1,
"upper": 4},
"epochs": {
"type": "int",
"default": 3,
"transform": "transform_power_2_int",
"lower": 3,
"upper": 4},
"k_folds": {
"type": "int",
"default": 1,
"transform": "None",
"lower": 1,
"upper": 1},
"patience": {
"type": "int",
"default": 5,
"transform": "None",
"lower": 2,
"upper": 10
},
"optimizer": {
"levels": ["Adadelta",
"Adagrad",
"Adam",
"AdamW",
"SparseAdam",
"Adamax",
"ASGD",
"NAdam",
"RAdam",
"RMSprop",
"Rprop",
"SGD"],
"type": "factor",
"default": "SGD",
"transform": "None",
"core_model_parameter_type": "str",
"lower": 0,
"upper": 12},
"sgd_momentum": {
"type": "float",
"default": 0.0,
"transform": "None",
"lower": 0.0,
"upper": 1.0}
},
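The transform entries are applied before a value reaches the model: transform_power_2_int maps an encoded value v to 2**v. A quick check of what the integer defaults above translate to:

# transform_power_2_int: encoded value v -> 2**v
for name, v in [("l1", 5), ("l2", 5), ("batch_size", 4), ("epochs", 3)]:
    print(f"{name}: {v} -> {2**v}")
# l1: 5 -> 32, l2: 5 -> 32, batch_size: 4 -> 16, epochs: 3 -> 8,
# matching the default network architecture shown later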
Modify hyper_dict Hyperparameters for the Selected Algorithm aka core_model

spotPython provides functions for modifying the hyperparameters, their bounds and factors, as well as for activating and deactivating hyperparameters without re-compilation of the Python source code. These functions were described in Section 14.5.3.
from spotPython.hyperparameters.values import modify_hyper_parameter_bounds
# fun_control = modify_hyper_parameter_bounds(fun_control, "delta", bounds=[1e-10, 1e-6])
# fun_control = modify_hyper_parameter_bounds(fun_control, "min_samples_split", bounds=[3, 20])
# fun_control = modify_hyper_parameter_bounds(fun_control, "merit_preprune", bounds=[0, 0])
# fun_control["core_model_hyper_dict"]
fun_control = modify_hyper_parameter_bounds(fun_control, "k_folds", bounds=[0, 0])
fun_control = modify_hyper_parameter_bounds(fun_control, "patience", bounds=[2, 2])
fun_control = modify_hyper_parameter_bounds(fun_control, "epochs", bounds=[2, 3])
fun_control
from spotPython.hyperparameters.values import modify_hyper_parameter_levels
# fun_control = modify_hyper_parameter_levels(fun_control, "leaf_model", ["LinearRegression"])
# fun_control["core_model_hyper_dict"]
Optimizers are described in Section 14.6.
The evaluation procedure requires the specification of two elements: the way the data is split into training and hold-out (validation) data, and the loss function (and, optionally, a metric). These are described in Section 19.9.
The key "loss_function" specifies the loss function which is used during the optimization, see Section 14.8. We will use CrossEntropyLoss for the multiclass classification task.
from torch.nn import CrossEntropyLoss
loss_function = CrossEntropyLoss()
fun_control.update({"loss_function": loss_function,
                    "shuffle": True,
                    "eval": "train_hold_out"
                    })
from torchmetrics import Accuracy
= Accuracy(task="multiclass", num_classes=10).to(fun_control["device"])
metric_torch "metric_torch": metric_torch}) fun_control.update({
The following code passes the information about the parameter ranges and bounds to spot.
# extract the variable types, names, and bounds
from spotPython.hyperparameters.values import (get_bound_values,
                                               get_var_name,
                                               get_var_type,)
var_type = get_var_type(fun_control)
var_name = get_var_name(fun_control)
fun_control.update({"var_type": var_type,
                    "var_name": var_name})
lower = get_bound_values(fun_control, "lower")
upper = get_bound_values(fun_control, "upper")
from spotPython.utils.eda import gen_design_table
print(gen_design_table(fun_control))
| name | type | default | lower | upper | transform |
|--------------|--------|-----------|---------|---------|-----------------------|
| l1 | int | 5 | 2 | 9 | transform_power_2_int |
| l2 | int | 5 | 2 | 9 | transform_power_2_int |
| lr_mult | float | 1.0 | 0.1 | 10 | None |
| batch_size | int | 4 | 1 | 4 | transform_power_2_int |
| epochs | int | 3 | 2 | 3 | transform_power_2_int |
| k_folds | int | 1 | 0 | 0 | None |
| patience | int | 5 | 2 | 2 | None |
| optimizer | factor | SGD | 0 | 12 | None |
| sgd_momentum | float | 0.0 | 0 | 1 | None |
The Objective Function fun_torch
The objective function fun_torch
is selected next. It implements an interface from PyTorch
’s training, validation, and testing methods to spotPython
.
from spotPython.fun.hypertorch import HyperTorch
fun = HyperTorch().fun_torch
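The run below is started from the default configuration X_start. The cell defining it did not survive extraction; a reconstruction, assuming spotPython's get_default_hyperparameters_as_array helper:

# assumed reconstruction: encode the default hyperparameters as the starting point
from spotPython.hyperparameters.values import get_default_hyperparameters_as_array
X_start = get_default_hyperparameters_as_array(fun_control)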
import numpy as np
from spotPython.spot import spot
from math import inf
spot_tuner = spot.Spot(fun=fun,
                       lower=lower,
                       upper=upper,
                       fun_evals=inf,
                       fun_repeats=1,
                       max_time=MAX_TIME,
                       noise=False,
                       tolerance_x=np.sqrt(np.spacing(1)),
                       var_type=var_type,
                       var_name=var_name,
                       infill_criterion="y",
                       n_points=1,
                       seed=123,
                       log_level=50,
                       show_models=False,
                       show_progress=True,
                       fun_control=fun_control,
                       design_control={"init_size": INIT_SIZE,
                                       "repeats": 1},
                       surrogate_control={"noise": True,
                                          "cod_type": "norm",
                                          "min_theta": -4,
                                          "max_theta": 3,
                                          "n_theta": len(var_name),
                                          "model_fun_evals": 10_000,
                                          "log_level": 50
                                          })
spot_tuner.run(X_start=X_start)
config: {'l1': 16, 'l2': 32, 'lr_mult': 9.563687451910228, 'batch_size': 8, 'epochs': 8, 'k_folds': 0, 'patience': 2, 'optimizer': 'AdamW', 'sgd_momentum': 0.41533100039458876}
Epoch: 1
Loss on hold-out set: 0.6194007385865941
Accuracy on hold-out set: 0.80425
MulticlassAccuracy value on hold-out data: 0.8042500019073486
Epoch: 2
Loss on hold-out set: 0.5304514363297882
Accuracy on hold-out set: 0.8284166666666667
MulticlassAccuracy value on hold-out data: 0.828416645526886
Epoch: 3
Loss on hold-out set: 0.5181355149464604
Accuracy on hold-out set: 0.8395
MulticlassAccuracy value on hold-out data: 0.8395000100135803
Epoch: 4
Loss on hold-out set: 0.5617739174207672
Accuracy on hold-out set: 0.8196666666666667
MulticlassAccuracy value on hold-out data: 0.8196666836738586
Epoch: 5
Loss on hold-out set: 0.5308127590441145
Accuracy on hold-out set: 0.832875
MulticlassAccuracy value on hold-out data: 0.8328750133514404
Early stopping at epoch 4
Returned to Spot: Validation loss: 0.5308127590441145
----------------------------------------------
config: {'l1': 128, 'l2': 32, 'lr_mult': 6.258012467639852, 'batch_size': 2, 'epochs': 4, 'k_folds': 0, 'patience': 2, 'optimizer': 'RAdam', 'sgd_momentum': 0.9572474073249809}
Epoch: 1
Loss on hold-out set: 0.6282176408300402
Accuracy on hold-out set: 0.8405
MulticlassAccuracy value on hold-out data: 0.840499997138977
Epoch: 2
Loss on hold-out set: 0.5457000766642944
Accuracy on hold-out set: 0.8679166666666667
MulticlassAccuracy value on hold-out data: 0.8679166436195374
Epoch: 3
Loss on hold-out set: 0.5141142338671999
Accuracy on hold-out set: 0.880625
MulticlassAccuracy value on hold-out data: 0.8806250095367432
Epoch: 4
Loss on hold-out set: 0.5849477228425366
Accuracy on hold-out set: 0.8694583333333333
MulticlassAccuracy value on hold-out data: 0.8694583177566528
Returned to Spot: Validation loss: 0.5849477228425366
----------------------------------------------
config: {'l1': 256, 'l2': 256, 'lr_mult': 0.2437336281201693, 'batch_size': 16, 'epochs': 8, 'k_folds': 0, 'patience': 2, 'optimizer': 'Adagrad', 'sgd_momentum': 0.15368887503658651}
Epoch: 1
Loss on hold-out set: 0.4944321710815032
Accuracy on hold-out set: 0.8320833333333333
MulticlassAccuracy value on hold-out data: 0.8320833444595337
Epoch: 2
Loss on hold-out set: 0.4628898579701781
Accuracy on hold-out set: 0.8368333333333333
MulticlassAccuracy value on hold-out data: 0.8368333578109741
Epoch: 3
Loss on hold-out set: 0.4363566631687184
Accuracy on hold-out set: 0.8484583333333333
MulticlassAccuracy value on hold-out data: 0.8484583497047424
Epoch: 4
Loss on hold-out set: 0.4216897410278519
Accuracy on hold-out set: 0.8548333333333333
MulticlassAccuracy value on hold-out data: 0.8548333048820496
Epoch: 5
Loss on hold-out set: 0.41668745046854017
Accuracy on hold-out set: 0.857
MulticlassAccuracy value on hold-out data: 0.8569999933242798
Epoch: 6
Loss on hold-out set: 0.40203342540934683
Accuracy on hold-out set: 0.86075
MulticlassAccuracy value on hold-out data: 0.8607500195503235
Epoch: 7
Loss on hold-out set: 0.4060324877500534
Accuracy on hold-out set: 0.8573333333333333
MulticlassAccuracy value on hold-out data: 0.8573333621025085
Epoch: 8
Loss on hold-out set: 0.39424820882454514
Accuracy on hold-out set: 0.8646666666666667
MulticlassAccuracy value on hold-out data: 0.8646666407585144
Returned to Spot: Validation loss: 0.39424820882454514
----------------------------------------------
config: {'l1': 64, 'l2': 8, 'lr_mult': 2.906205211581667, 'batch_size': 8, 'epochs': 4, 'k_folds': 0, 'patience': 2, 'optimizer': 'SGD', 'sgd_momentum': 0.25435133436334767}
Epoch: 1
Loss on hold-out set: 1.0938524534106255
Accuracy on hold-out set: 0.621
MulticlassAccuracy value on hold-out data: 0.6209999918937683
Epoch: 2
Loss on hold-out set: 0.848190563261509
Accuracy on hold-out set: 0.685375
MulticlassAccuracy value on hold-out data: 0.6853749752044678
Epoch: 3
Loss on hold-out set: 0.7481328926657637
Accuracy on hold-out set: 0.7252916666666667
MulticlassAccuracy value on hold-out data: 0.7252916693687439
Epoch: 4
Loss on hold-out set: 0.7023135198329886
Accuracy on hold-out set: 0.7464166666666666
MulticlassAccuracy value on hold-out data: 0.7464166879653931
Returned to Spot: Validation loss: 0.7023135198329886
----------------------------------------------
config: {'l1': 4, 'l2': 128, 'lr_mult': 4.224097306355747, 'batch_size': 4, 'epochs': 8, 'k_folds': 0, 'patience': 2, 'optimizer': 'Adamax', 'sgd_momentum': 0.6538496127257492}
Epoch: 1
Loss on hold-out set: 1.4610868829339743
Accuracy on hold-out set: 0.361
MulticlassAccuracy value on hold-out data: 0.3610000014305115
Epoch: 2
Loss on hold-out set: 1.3173598170628151
Accuracy on hold-out set: 0.45125
MulticlassAccuracy value on hold-out data: 0.45124998688697815
Epoch: 3
Loss on hold-out set: 1.3123198301047088
Accuracy on hold-out set: 0.457875
MulticlassAccuracy value on hold-out data: 0.45787501335144043
Epoch: 4
Loss on hold-out set: 1.2965779286225636
Accuracy on hold-out set: 0.49191666666666667
MulticlassAccuracy value on hold-out data: 0.4919166564941406
Epoch: 5
Loss on hold-out set: 1.2610745450109244
Accuracy on hold-out set: 0.5101666666666667
MulticlassAccuracy value on hold-out data: 0.5101666450500488
Epoch: 6
Loss on hold-out set: 1.2758716597606738
Accuracy on hold-out set: 0.5029166666666667
MulticlassAccuracy value on hold-out data: 0.502916693687439
Epoch: 7
Loss on hold-out set: 1.2511051948765914
Accuracy on hold-out set: 0.486875
MulticlassAccuracy value on hold-out data: 0.4868749976158142
Epoch: 8
Loss on hold-out set: 1.2649959445620576
Accuracy on hold-out set: 0.49633333333333335
MulticlassAccuracy value on hold-out data: 0.4963333308696747
Returned to Spot: Validation loss: 1.2649959445620576
----------------------------------------------
config: {'l1': 256, 'l2': 256, 'lr_mult': 0.21396102096383096, 'batch_size': 16, 'epochs': 8, 'k_folds': 0, 'patience': 2, 'optimizer': 'Adagrad', 'sgd_momentum': 0.1471458256326419}
Epoch: 1
Loss on hold-out set: 0.4979291863093773
Accuracy on hold-out set: 0.8295
MulticlassAccuracy value on hold-out data: 0.8295000195503235
Epoch: 2
Loss on hold-out set: 0.4613260097379486
Accuracy on hold-out set: 0.8374166666666667
MulticlassAccuracy value on hold-out data: 0.8374166488647461
Epoch: 3
Loss on hold-out set: 0.43455552994211516
Accuracy on hold-out set: 0.84825
MulticlassAccuracy value on hold-out data: 0.8482499718666077
Epoch: 4
Loss on hold-out set: 0.4192953429793318
Accuracy on hold-out set: 0.8529583333333334
MulticlassAccuracy value on hold-out data: 0.8529583215713501
Epoch: 5
Loss on hold-out set: 0.4106680783244471
Accuracy on hold-out set: 0.85625
MulticlassAccuracy value on hold-out data: 0.856249988079071
Epoch: 6
Loss on hold-out set: 0.40636651575441163
Accuracy on hold-out set: 0.8585833333333334
MulticlassAccuracy value on hold-out data: 0.8585833311080933
Epoch: 7
Loss on hold-out set: 0.4053738015300284
Accuracy on hold-out set: 0.8561666666666666
MulticlassAccuracy value on hold-out data: 0.856166660785675
Epoch: 8
Loss on hold-out set: 0.39709795179590585
Accuracy on hold-out set: 0.8615
MulticlassAccuracy value on hold-out data: 0.8615000247955322
Returned to Spot: Validation loss: 0.39709795179590585
----------------------------------------------
spotPython tuning: 0.39424820882454514 [#####-----] 50.96%
config: {'l1': 256, 'l2': 256, 'lr_mult': 2.961232098479208, 'batch_size': 16, 'epochs': 8, 'k_folds': 0, 'patience': 2, 'optimizer': 'Adagrad', 'sgd_momentum': 0.7773902885423232}
Epoch: 1
Loss on hold-out set: 0.3948214785158634
Accuracy on hold-out set: 0.8554583333333333
MulticlassAccuracy value on hold-out data: 0.8554583191871643
Epoch: 2
Loss on hold-out set: 0.35802501618117094
Accuracy on hold-out set: 0.8705416666666667
MulticlassAccuracy value on hold-out data: 0.8705416917800903
Epoch: 3
Loss on hold-out set: 0.3484339735303074
Accuracy on hold-out set: 0.8772083333333334
MulticlassAccuracy value on hold-out data: 0.8772083520889282
Epoch: 4
Loss on hold-out set: 0.32964511326141654
Accuracy on hold-out set: 0.8821666666666667
MulticlassAccuracy value on hold-out data: 0.8821666836738586
Epoch: 5
Loss on hold-out set: 0.3429898156930382
Accuracy on hold-out set: 0.8795416666666667
MulticlassAccuracy value on hold-out data: 0.8795416951179504
Epoch: 6
Loss on hold-out set: 0.33033685816281166
Accuracy on hold-out set: 0.885
MulticlassAccuracy value on hold-out data: 0.8849999904632568
Early stopping at epoch 5
Returned to Spot: Validation loss: 0.33033685816281166
----------------------------------------------
spotPython tuning: 0.33033685816281166 [#########-] 90.18%
config: {'l1': 64, 'l2': 256, 'lr_mult': 10.0, 'batch_size': 16, 'epochs': 8, 'k_folds': 0, 'patience': 2, 'optimizer': 'Adagrad', 'sgd_momentum': 1.0}
Epoch: 1
Loss on hold-out set: 0.6928169537732999
Accuracy on hold-out set: 0.757375
MulticlassAccuracy value on hold-out data: 0.7573750019073486
Epoch: 2
Loss on hold-out set: 0.6009475722064574
Accuracy on hold-out set: 0.78375
MulticlassAccuracy value on hold-out data: 0.7837499976158142
Epoch: 3
Loss on hold-out set: 0.5835815471870204
Accuracy on hold-out set: 0.7854166666666667
MulticlassAccuracy value on hold-out data: 0.7854166626930237
Epoch: 4
Loss on hold-out set: 0.5971286185433468
Accuracy on hold-out set: 0.7808333333333334
MulticlassAccuracy value on hold-out data: 0.7808333039283752
Epoch: 5
Loss on hold-out set: 0.524177671401451
Accuracy on hold-out set: 0.8114583333333333
MulticlassAccuracy value on hold-out data: 0.8114583492279053
Epoch: 6
Loss on hold-out set: 0.513421056561172
Accuracy on hold-out set: 0.81175
MulticlassAccuracy value on hold-out data: 0.8117499947547913
Epoch: 7
Loss on hold-out set: 0.498869803994894
Accuracy on hold-out set: 0.81175
MulticlassAccuracy value on hold-out data: 0.8117499947547913
Epoch: 8
Loss on hold-out set: 0.4996524799987674
Accuracy on hold-out set: 0.8166666666666667
MulticlassAccuracy value on hold-out data: 0.8166666626930237
Returned to Spot: Validation loss: 0.4996524799987674
----------------------------------------------
spotPython tuning: 0.33033685816281166 [##########] 100.00% Done...
<spotPython.spot.spot.Spot at 0x29deb2bf0>
The textual output shown in the console (or code cell) can be visualized with TensorBoard as described in Section 14.13.
After the hyperparameter tuning run is finished, the results can be analyzed as described in Section 14.14.
import pickle

SAVE = False
LOAD = False

if SAVE:
    result_file_name = "res_" + experiment_name + ".pkl"
    with open(result_file_name, 'wb') as f:
        pickle.dump(spot_tuner, f)

if LOAD:
    result_file_name = "ADD THE NAME here, e.g.: res_ch10-friedman-hpt-0_maans03_60min_20init_1K_2023-04-14_10-11-19.pkl"
    with open(result_file_name, 'rb') as f:
        spot_tuner = pickle.load(f)
After the hyperparameter tuning run is finished, the progress of the hyperparameter tuning can be visualized. The following code generates the progress plot:

spot_tuner.plot_progress(log_y=False,
                         filename="./figures/" + experiment_name + "_progress.png")
print(gen_design_table(fun_control=fun_control,
                       spot=spot_tuner))
| name | type | default | lower | upper | tuned | transform | importance | stars |
|--------------|--------|-----------|---------|---------|--------------------|-----------------------|--------------|---------|
| l1 | int | 5 | 2.0 | 9.0 | 8.0 | transform_power_2_int | 0.22 | . |
| l2 | int | 5 | 2.0 | 9.0 | 8.0 | transform_power_2_int | 100.00 | *** |
| lr_mult | float | 1.0 | 0.1 | 10.0 | 2.961232098479208 | None | 0.00 | |
| batch_size | int | 4 | 1.0 | 4.0 | 4.0 | transform_power_2_int | 0.00 | |
| epochs | int | 3 | 2.0 | 3.0 | 3.0 | transform_power_2_int | 0.00 | |
| k_folds | int | 1 | 0.0 | 0.0 | 0.0 | None | 0.00 | |
| patience | int | 5 | 2.0 | 2.0 | 2.0 | None | 0.00 | |
| optimizer | factor | SGD | 0.0 | 12.0 | 1.0 | None | 0.01 | |
| sgd_momentum | float | 0.0 | 0.0 | 1.0 | 0.7773902885423232 | None | 0.01 | |
spot_tuner.plot_importance(threshold=0.025, filename="./figures/" + experiment_name + "_importance.png")
The architecture of the spotPython
model can be obtained by the following code:
from spotPython.hyperparameters.values import get_one_core_model_from_X
X = spot_tuner.to_all_dim(spot_tuner.min_X.reshape(1,-1))
model_spot = get_one_core_model_from_X(X, fun_control)
model_spot
Net_fashionMNIST(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=256, bias=True)
(3): ReLU()
(4): Linear(in_features=256, out_features=10, bias=True)
)
)
# hyper_dict holds the default entries from the JSON file shown above;
# if it is not defined at this point, it can be loaded via the
# TorchHyperDict class imported earlier (assumed API):
hyper_dict = TorchHyperDict().load()
fc = fun_control
fc.update({"core_model_hyper_dict":
           hyper_dict[fun_control["core_model"].__name__]})
model_default = get_one_core_model_from_X(X_start, fun_control=fc)
model_default
Net_fashionMNIST(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=32, bias=True)
(1): ReLU()
(2): Linear(in_features=32, out_features=32, bias=True)
(3): ReLU()
(4): Linear(in_features=32, out_features=10, bias=True)
)
)
The method train_tuned
takes a model architecture without trained weights and trains this model with the train data. The train data is split into train and validation data. The validation data is used for early stopping. The trained model weights are saved as a dictionary.
from spotPython.torch.traintest import train_tuned
train_tuned(net=model_default, train_dataset=train, shuffle=True,
            loss_function=fun_control["loss_function"],
            metric=fun_control["metric_torch"],
            device=fun_control["device"],
            show_batch_interval=1_000_000,
            path=None,
            task=fun_control["task"])
Epoch: 1
Loss on hold-out set: 2.0563841097354887
Accuracy on hold-out set: 0.41233333333333333
MulticlassAccuracy value on hold-out data: 0.41233333945274353
Epoch: 2
Loss on hold-out set: 1.4778506670792897
Accuracy on hold-out set: 0.5262083333333333
MulticlassAccuracy value on hold-out data: 0.5262083411216736
Epoch: 3
Loss on hold-out set: 1.2049028609593708
Accuracy on hold-out set: 0.5736666666666667
MulticlassAccuracy value on hold-out data: 0.5736666917800903
Epoch: 4
Loss on hold-out set: 1.0716215518712997
Accuracy on hold-out set: 0.6002083333333333
MulticlassAccuracy value on hold-out data: 0.6002083420753479
Epoch: 5
Loss on hold-out set: 0.9921316761374473
Accuracy on hold-out set: 0.6262083333333334
MulticlassAccuracy value on hold-out data: 0.6262083053588867
Epoch: 6
Loss on hold-out set: 0.9373754310806592
Accuracy on hold-out set: 0.6455
MulticlassAccuracy value on hold-out data: 0.6455000042915344
Epoch: 7
Loss on hold-out set: 0.8958202048738797
Accuracy on hold-out set: 0.6605
MulticlassAccuracy value on hold-out data: 0.6604999899864197
Epoch: 8
Loss on hold-out set: 0.8619507992466291
Accuracy on hold-out set: 0.6740833333333334
MulticlassAccuracy value on hold-out data: 0.6740833520889282
Returned to Spot: Validation loss: 0.8619507992466291
----------------------------------------------
from spotPython.torch.traintest import test_tuned
test_tuned(net=model_default, test_dataset=test,
           loss_function=fun_control["loss_function"],
           metric=fun_control["metric_torch"],
           shuffle=False,
           device=fun_control["device"],
           task=fun_control["task"])
Loss on hold-out set: 0.8744107367515564
Accuracy on hold-out set: 0.6637
MulticlassAccuracy value on hold-out data: 0.6636999845504761
Final evaluation: Validation loss: 0.8744107367515564
Final evaluation: Validation metric: 0.6636999845504761
----------------------------------------------
(0.8744107367515564, nan, tensor(0.6637))
The following code trains the model model_spot
. If path
is set to a filename, e.g., path = "model_spot_trained.pt"
, the weights of the trained model will be saved to this file.
train_tuned(net=model_spot, train_dataset=train,
            loss_function=fun_control["loss_function"],
            metric=fun_control["metric_torch"],
            shuffle=True,
            device=fun_control["device"],
            path=None,
            task=fun_control["task"])
Epoch: 1
Loss on hold-out set: 0.4294524548873305
Accuracy on hold-out set: 0.8432083333333333
MulticlassAccuracy value on hold-out data: 0.8432083129882812
Epoch: 2
Loss on hold-out set: 0.3704918979307016
Accuracy on hold-out set: 0.866375
MulticlassAccuracy value on hold-out data: 0.8663750290870667
Epoch: 3
Loss on hold-out set: 0.3632254023378094
Accuracy on hold-out set: 0.8699166666666667
MulticlassAccuracy value on hold-out data: 0.8699166774749756
Epoch: 4
Loss on hold-out set: 0.34430248334367447
Accuracy on hold-out set: 0.878
MulticlassAccuracy value on hold-out data: 0.878000020980835
Epoch: 5
Loss on hold-out set: 0.3381388572102102
Accuracy on hold-out set: 0.8815833333333334
MulticlassAccuracy value on hold-out data: 0.8815833330154419
Epoch: 6
Loss on hold-out set: 0.3301646394111837
Accuracy on hold-out set: 0.884
MulticlassAccuracy value on hold-out data: 0.8840000033378601
Epoch: 7
Loss on hold-out set: 0.33321633622717733
Accuracy on hold-out set: 0.8840416666666666
MulticlassAccuracy value on hold-out data: 0.8840416669845581
Epoch: 8
Loss on hold-out set: 0.3401115870640303
Accuracy on hold-out set: 0.8805
MulticlassAccuracy value on hold-out data: 0.8805000185966492
Early stopping at epoch 7
Returned to Spot: Validation loss: 0.3401115870640303
----------------------------------------------
test_tuned(net=model_spot, test_dataset=test,
           shuffle=False,
           loss_function=fun_control["loss_function"],
           metric=fun_control["metric_torch"],
           device=fun_control["device"],
           task=fun_control["task"])
Loss on hold-out set: 0.36662325962334874
Accuracy on hold-out set: 0.8741
MulticlassAccuracy value on hold-out data: 0.8741000294685364
Final evaluation: Validation loss: 0.36662325962334874
Final evaluation: Validation metric: 0.8741000294685364
----------------------------------------------
(0.36662325962334874, nan, tensor(0.8741))
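Because train_tuned stores the trained weights as a state dictionary when path is set (see above), the model can later be restored with plain PyTorch. A sketch, assuming the earlier call had used path="model_spot_trained.pt":

import torch

# hypothetical: requires that train_tuned was called with path="model_spot_trained.pt"
model_spot.load_state_dict(torch.load("model_spot_trained.pt"))
model_spot.eval()  # switch to evaluation mode before inference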
= "./figures/" + experiment_name
filename =filename) spot_tuner.plot_important_hyperparameter_contour(filename
l1: 0.21622658465044553
l2: 100.0
spot_tuner.parallel_plot()
Parallel coordinates plots
PLOT_ALL = False
if PLOT_ALL:
    n = spot_tuner.k
    for i in range(n-1):
        for j in range(i+1, n):
            # note: min_z and max_z are not defined in this notebook; they must be
            # set (e.g., to the range of the observed objective values) before use
            spot_tuner.plot_contour(i=i, j=j, min_z=min_z, max_z=max_z)