12  HPT: PyTorch With CIFAR10 Data

In this tutorial, we will show how spotPython can be integrated into the PyTorch training workflow.

This document refers to the following software versions:

pip list | grep  "spot[RiverPython]"
spotPython                 0.2.34
spotRiver                  0.0.93
Note: you may need to restart the kernel to use updated packages.

spotPython can be installed via pip. Alternatively, the source code can be downloaded from GitHub: https://github.com/sequential-parameter-optimization/spotPython.

!pip install spotPython
# import sys
# !{sys.executable} -m pip install --upgrade build
# !{sys.executable} -m pip install --upgrade --force-reinstall spotPython

12.1 Setup

Before we consider the detailed experimental setup, we set the maximum run time (MAX_TIME, in minutes), the size of the initial design (INIT_SIZE), and the device that is used (DEVICE).

MAX_TIME = 1
INIT_SIZE = 5
DEVICE = None # "cpu" # "cuda:0"
from spotPython.utils.device import getDevice
DEVICE = getDevice(DEVICE)
print(DEVICE)
mps
import os
import copy
import socket
from datetime import datetime
from dateutil.tz import tzlocal
start_time = datetime.now(tzlocal())
HOSTNAME = socket.gethostname().split(".")[0]
experiment_name = '12-torch' + "_" + HOSTNAME + "_" + str(MAX_TIME) + "min_" + str(INIT_SIZE) + "init_" + str(start_time).split(".", 1)[0].replace(' ', '_')
experiment_name = experiment_name.replace(':', '-')
print(experiment_name)
if not os.path.exists('./figures'):
    os.makedirs('./figures')
12-torch_p040025_1min_5init_2023-06-17_11-30-13

12.2 Initialization of the fun_control Dictionary

spotPython uses a Python dictionary for storing the information required for the hyperparameter tuning process, which was described in Section 14.2.

from spotPython.utils.init import fun_control_init
fun_control = fun_control_init(task="classification",
    tensorboard_path="runs/12_spot_hpt_torch_cifar10",
    device=DEVICE)

12.3 PyTorch Data Loading

12.4 Load the CIFAR10 Data

from torchvision import datasets, transforms
import torchvision
def load_data(data_dir="./data"):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform)

    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform)

    return trainset, testset
train, test = load_data()
Files already downloaded and verified
Files already downloaded and verified
  • Since this works fine, we can add the data loading to the fun_control dictionary:
n_samples = len(train)
# add the dataset to the fun_control
fun_control.update({"data": None, # dataset,
               "train": train,
               "test": test,
               "n_samples": n_samples,
               "target_column": None})
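The Normalize transform used in load_data above rescales each channel with (x - mean) / std. The following minimal sketch reproduces that arithmetic in plain Python for mean = std = 0.5; the helper name normalize is hypothetical and only illustrates the formula, it is not part of torchvision.

```python
# Sketch: per-channel arithmetic of Normalize(mean, std), i.e. (x - mean) / std.
# The function name `normalize` is a hypothetical helper for illustration.
def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    return (x - mean) / std

# ToTensor() scales pixel values to [0, 1]; endpoints and midpoint map to:
normalized = [normalize(x) for x in (0.0, 0.5, 1.0)]
print(normalized)  # [-1.0, 0.0, 1.0]
```

With these settings, the pixel values that ToTensor() scales to [0, 1] end up in [-1, 1].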

12.5 The Model (Algorithm) to be Tuned

12.6 Specification of the Preprocessing Model

After the training and test data are specified and added to the fun_control dictionary, spotPython allows the specification of a data preprocessing pipeline, e.g., for the scaling of the data or for the one-hot encoding of categorical variables, see Section 14.4.1. This feature is not used here, so we do not change the default value (which is None).

12.7 Step 4: Select algorithm and core_model_hyper_dict

12.7.1 Implementing a Configurable Neural Network With spotPython

spotPython includes the Net_CIFAR10 class which is implemented in the file netcifar10.py. The class is imported here.

This class inherits from the class Net_Core which is implemented in the file netcore.py, see Section 14.4.3.

from spotPython.torch.netcifar10 import Net_CIFAR10
from spotPython.data.torch_hyper_dict import TorchHyperDict
from spotPython.hyperparameters.values import add_core_model_to_fun_control
fun_control = add_core_model_to_fun_control(core_model=Net_CIFAR10,
                              fun_control=fun_control,
                              hyper_dict=TorchHyperDict,
                              filename=None)

12.8 The Search Space

12.8.1 Configuring the Search Space With spotPython

12.8.1.1 The hyper_dict Hyperparameters for the Selected Algorithm

spotPython uses JSON files for the specification of the hyperparameters, which were described in Section 14.5.2.

The corresponding entries for the Net_CIFAR10 class are shown below.

    "Net_CIFAR10":
    {
        "l1": {
            "type": "int",
            "default": 5,
            "transform": "transform_power_2_int",
            "lower": 2,
            "upper": 9},
        "l2": {
            "type": "int",
            "default": 5,
            "transform": "transform_power_2_int",
            "lower": 2,
            "upper": 9},
        "lr_mult": {
            "type": "float",
            "default": 1.0,
            "transform": "None",
            "lower": 0.1,
            "upper": 10.0},
        "batch_size": {
            "type": "int",
            "default": 4,
            "transform": "transform_power_2_int",
            "lower": 1,
            "upper": 4},
        "epochs": {
            "type": "int",
            "default": 3,
            "transform": "transform_power_2_int",
            "lower": 3,
            "upper": 4},
        "k_folds": {
            "type": "int",
            "default": 1,
            "transform": "None",
            "lower": 1,
            "upper": 1},
        "patience": {
            "type": "int",
            "default": 5,
            "transform": "None",
            "lower": 2,
            "upper": 10
        },
        "optimizer": {
            "levels": ["Adadelta",
                       "Adagrad",
                       "Adam",
                       "AdamW",
                       "SparseAdam",
                       "Adamax",
                       "ASGD",
                       "NAdam",
                       "RAdam",
                       "RMSprop",
                       "Rprop",
                       "SGD"],
            "type": "factor",
            "default": "SGD",
            "transform": "None",
            "class_name": "torch.optim",
            "core_model_parameter_type": "str",
            "lower": 0,
            "upper": 12},
        "sgd_momentum": {
            "type": "float",
            "default": 0.0,
            "transform": "None",
            "lower": 0.0,
            "upper": 1.0}
    },
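Several entries above use the transform "transform_power_2_int", so the bounds are given as exponents and not as the values seen by the network. The sketch below assumes that this transform maps an integer x to 2**x, which is consistent with the configurations printed during tuning (e.g., batch_size 16, epochs 16); the function power_2_int is a hypothetical stand-in, not spotPython's own implementation.

```python
# Assumption: "transform_power_2_int" maps an integer exponent x to 2**x.
def power_2_int(x: int) -> int:
    """Hypothetical stand-in for the transform named in the JSON entries."""
    return 2 ** x

# (lower, upper) exponents copied from the hyper_dict entries above.
bounds = {"l1": (2, 9), "l2": (2, 9), "batch_size": (1, 4), "epochs": (3, 4)}

# Effective value ranges after applying the transform.
effective = {name: (power_2_int(lo), power_2_int(hi))
             for name, (lo, hi) in bounds.items()}
for name, (lo, hi) in effective.items():
    print(f"{name}: {lo} .. {hi}")
```

Under this assumption, l1 and l2 range from 4 to 512 units, batch_size from 2 to 16, and epochs from 8 to 16.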

12.9 Modifying the Hyperparameters

spotPython provides functions for modifying the hyperparameters, their bounds and factors as well as for activating and de-activating hyperparameters without re-compilation of the Python source code. These functions were described in Section 14.5.3.

12.9.1 Step 5: Modify hyper_dict Hyperparameters for the Selected Algorithm aka core_model

12.9.1.1 Modify Hyperparameters of Type numeric and integer (boolean)

The hyperparameter k_folds is not used in this experiment. It is de-activated here by setting its lower and upper bounds to the same value.

from spotPython.hyperparameters.values import modify_hyper_parameter_bounds
# fun_control = modify_hyper_parameter_bounds(fun_control, "delta", bounds=[1e-10, 1e-6])
# fun_control = modify_hyper_parameter_bounds(fun_control, "min_samples_split", bounds=[3, 20])
#fun_control = modify_hyper_parameter_bounds(fun_control, "merit_preprune", bounds=[0, 0])
# fun_control["core_model_hyper_dict"]
fun_control = modify_hyper_parameter_bounds(fun_control, "k_folds", bounds=[2, 2])
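Why equal bounds de-activate a hyperparameter can be seen from a small illustration: if lower and upper coincide, only one value is admissible, so the tuner has nothing to vary. The sketch below uses a plain dictionary with three of the bounds from this chapter; it does not reproduce spotPython's internal handling.

```python
# Illustration only: a plain dict of (lower, upper) bounds, not spotPython internals.
bounds = {"l1": (2, 9), "k_folds": (2, 2), "patience": (2, 10)}

# Parameters whose bounds coincide are fixed to that single value.
fixed = {name: lo for name, (lo, hi) in bounds.items() if lo == hi}
# Parameters with lower < upper remain part of the search space.
tunable = [name for name, (lo, hi) in bounds.items() if lo < hi]

print(fixed)    # {'k_folds': 2}
print(tunable)  # ['l1', 'patience']
```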

12.9.2 Modify hyperparameter of type factor

from spotPython.hyperparameters.values import modify_hyper_parameter_levels
fun_control = modify_hyper_parameter_levels(fun_control, "optimizer", ["Adam"])
# fun_control = modify_hyper_parameter_levels(fun_control, "leaf_model", ["LinearRegression"])
# fun_control["core_model_hyper_dict"]

12.9.3 Optimizers

Optimizers can be selected as described in Section 19.8.2.

Optimizers are described in Section 14.6.

fun_control = modify_hyper_parameter_bounds(fun_control,
    "lr_mult", bounds=[1e-3, 1e-3])
fun_control = modify_hyper_parameter_bounds(fun_control,
    "sgd_momentum", bounds=[0.9, 0.9])

12.10 Evaluation

The evaluation procedure requires the specification of two elements:

  1. how the data is split into a training and a test set, and
  2. the loss function (and a metric).

These are described in Section 19.9.

The key "loss_function" specifies the loss function which is used during the optimization, see Section 14.8.

We will use CrossEntropy loss for the multiclass-classification task.

from torch.nn import CrossEntropyLoss
loss_function = CrossEntropyLoss()
fun_control.update({
        "loss_function": loss_function,
        "shuffle": True,
        "eval":  "train_hold_out"
        })
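As a point of reference for the loss values reported during tuning, the following sketch computes the cross-entropy of a classifier that assigns equal probability to all ten classes. The helper cross_entropy is a simplified, hypothetical single-sample version written for illustration; it is not torch.nn.CrossEntropyLoss, which operates on batches of unnormalized scores.

```python
import math

def cross_entropy(probs, target):
    """Negative log-probability assigned to the true class (single sample)."""
    return -math.log(probs[target])

num_classes = 10
uniform = [1.0 / num_classes] * num_classes  # chance-level prediction
baseline = cross_entropy(uniform, target=3)
print(round(baseline, 4))  # 2.3026
```

A hold-out loss close to ln(10) ≈ 2.30 therefore indicates predictions near chance level, which is what the first epochs of the tuning log show.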

12.10.1 Metric

import torchmetrics
metric_torch = torchmetrics.Accuracy(task="multiclass",
     num_classes=10).to(fun_control["device"])
fun_control.update({"metric_torch": metric_torch})
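The accuracy metric can be illustrated without torchmetrics. The sketch below assumes the overall fraction of correct predictions (micro averaging); the function multiclass_accuracy and the toy scores are hypothetical and serve only to show the computation.

```python
def multiclass_accuracy(scores, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in scores]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)

# Toy example: four samples, three classes (values are made up).
scores = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9], [0.4, 0.3, 0.2]]
labels = [1, 0, 1, 0]
accuracy = multiclass_accuracy(scores, labels)
print(accuracy)  # 0.75
```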

12.11 Preparing the SPOT Call

The following code passes the information about the parameter ranges and bounds to spot.

# extract the variable types, names, and bounds
from spotPython.hyperparameters.values import (get_bound_values,
    get_var_name,
    get_var_type,)
var_type = get_var_type(fun_control)
var_name = get_var_name(fun_control)
fun_control.update({"var_type": var_type,
                    "var_name": var_name})
lower = get_bound_values(fun_control, "lower")
upper = get_bound_values(fun_control, "upper")
from spotPython.utils.eda import gen_design_table
print(gen_design_table(fun_control))
| name         | type   | default   |   lower |   upper | transform             |
|--------------|--------|-----------|---------|---------|-----------------------|
| l1           | int    | 5         |   2     |   9     | transform_power_2_int |
| l2           | int    | 5         |   2     |   9     | transform_power_2_int |
| lr_mult      | float  | 1.0       |   0.001 |   0.001 | None                  |
| batch_size   | int    | 4         |   1     |   4     | transform_power_2_int |
| epochs       | int    | 3         |   3     |   4     | transform_power_2_int |
| k_folds      | int    | 1         |   2     |   2     | None                  |
| patience     | int    | 5         |   2     |  10     | None                  |
| optimizer    | factor | SGD       |   0     |   0     | None                  |
| sgd_momentum | float  | 0.0       |   0.9   |   0.9   | None                  |

12.12 The Objective Function fun_torch

The objective function fun_torch is selected next. It implements an interface from PyTorch’s training, validation, and testing methods to spotPython.

from spotPython.fun.hypertorch import HyperTorch
fun = HyperTorch().fun_torch

12.13 Starting the Hyperparameter Tuning

import numpy as np
from spotPython.spot import spot
from math import inf
spot_tuner = spot.Spot(fun=fun,
                   lower = lower,
                   upper = upper,
                   fun_evals = inf,
                   fun_repeats = 1,
                   max_time = MAX_TIME,
                   noise = False,
                   tolerance_x = np.sqrt(np.spacing(1)),
                   var_type = var_type,
                   var_name = var_name,
                   infill_criterion = "y",
                   n_points = 1,
                   seed=123,
                   log_level = 50,
                   show_models= False,
                   show_progress= True,
                   fun_control = fun_control,
                   design_control={"init_size": INIT_SIZE,
                                   "repeats": 1},
                   surrogate_control={"noise": True,
                                      "cod_type": "norm",
                                      "min_theta": -4,
                                      "max_theta": 3,
                                      "n_theta": len(var_name),
                                      "model_fun_evals": 10_000,
                                      "log_level": 50
                                      })
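# X_start was not defined above; assumption: start from the default hyperparameters
from spotPython.hyperparameters.values import get_default_hyperparameters_as_array
X_start = get_default_hyperparameters_as_array(fun_control)
spot_tuner.run(X_start=X_start)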

config: {'l1': 128, 'l2': 8, 'lr_mult': 0.001, 'batch_size': 16, 'epochs': 16, 'k_folds': 2, 'patience': 5, 'optimizer': 'Adam', 'sgd_momentum': 0.9}
Epoch: 1
Loss on hold-out set: 2.317682361984253
Accuracy on hold-out set: 0.10185
MulticlassAccuracy value on hold-out data: 0.1018500030040741
Epoch: 2
Loss on hold-out set: 2.310224766921997
Accuracy on hold-out set: 0.10185
MulticlassAccuracy value on hold-out data: 0.1018500030040741
Epoch: 3
Loss on hold-out set: 2.3018915784835814
Accuracy on hold-out set: 0.10185
MulticlassAccuracy value on hold-out data: 0.1018500030040741
Epoch: 4
Loss on hold-out set: 2.2944185848236085
Accuracy on hold-out set: 0.10185
MulticlassAccuracy value on hold-out data: 0.1018500030040741
Epoch: 5
Loss on hold-out set: 2.288027394866943
Accuracy on hold-out set: 0.10265
MulticlassAccuracy value on hold-out data: 0.1026500016450882
Epoch: 6
Loss on hold-out set: 2.282048793411255
Accuracy on hold-out set: 0.1233
MulticlassAccuracy value on hold-out data: 0.12330000102519989
Epoch: 7
Loss on hold-out set: 2.2761273420333863
Accuracy on hold-out set: 0.152
MulticlassAccuracy value on hold-out data: 0.15199999511241913
Epoch: 8
Loss on hold-out set: 2.2703126028060914
Accuracy on hold-out set: 0.1676
MulticlassAccuracy value on hold-out data: 0.16760000586509705
Epoch: 9
Loss on hold-out set: 2.264477591133118
Accuracy on hold-out set: 0.1727
MulticlassAccuracy value on hold-out data: 0.17270000278949738
Epoch: 10
Loss on hold-out set: 2.2586488855361937
Accuracy on hold-out set: 0.17725
MulticlassAccuracy value on hold-out data: 0.1772499978542328
Epoch: 11
Loss on hold-out set: 2.252915013885498
Accuracy on hold-out set: 0.17725
MulticlassAccuracy value on hold-out data: 0.1772499978542328
Epoch: 12
Loss on hold-out set: 2.247335859107971
Accuracy on hold-out set: 0.17835
MulticlassAccuracy value on hold-out data: 0.17835000157356262
Epoch: 13
Loss on hold-out set: 2.242020149040222
Accuracy on hold-out set: 0.17635
MulticlassAccuracy value on hold-out data: 0.17634999752044678
Epoch: 14
Loss on hold-out set: 2.2369145687103273
Accuracy on hold-out set: 0.1792
MulticlassAccuracy value on hold-out data: 0.17919999361038208
Epoch: 15
Loss on hold-out set: 2.2320204914093016
Accuracy on hold-out set: 0.17865
MulticlassAccuracy value on hold-out data: 0.17865000665187836
Epoch: 16
Loss on hold-out set: 2.2273454607963563
Accuracy on hold-out set: 0.1782
MulticlassAccuracy value on hold-out data: 0.17820000648498535
Returned to Spot: Validation loss: 2.2273454607963563
----------------------------------------------

config: {'l1': 16, 'l2': 16, 'lr_mult': 0.001, 'batch_size': 8, 'epochs': 8, 'k_folds': 2, 'patience': 7, 'optimizer': 'Adam', 'sgd_momentum': 0.9}
Epoch: 1
Loss on hold-out set: 2.3160439434051514
Accuracy on hold-out set: 0.1007
MulticlassAccuracy value on hold-out data: 0.1006999984383583
Epoch: 2
Loss on hold-out set: 2.3151568338394166
Accuracy on hold-out set: 0.1007
MulticlassAccuracy value on hold-out data: 0.1006999984383583
Epoch: 3
Loss on hold-out set: 2.314100040245056
Accuracy on hold-out set: 0.1007
MulticlassAccuracy value on hold-out data: 0.1006999984383583
Epoch: 4
Loss on hold-out set: 2.3128211247444153
Accuracy on hold-out set: 0.1007
MulticlassAccuracy value on hold-out data: 0.1006999984383583
Epoch: 5
Loss on hold-out set: 2.311337257385254
Accuracy on hold-out set: 0.1007
MulticlassAccuracy value on hold-out data: 0.1006999984383583
Epoch: 6
Loss on hold-out set: 2.3099444057464598
Accuracy on hold-out set: 0.1009
MulticlassAccuracy value on hold-out data: 0.10090000182390213
Epoch: 7
Loss on hold-out set: 2.3086226751327517
Accuracy on hold-out set: 0.10185
MulticlassAccuracy value on hold-out data: 0.1018500030040741
Epoch: 8
Loss on hold-out set: 2.3066345465660096
Accuracy on hold-out set: 0.1116
MulticlassAccuracy value on hold-out data: 0.11159999668598175
Returned to Spot: Validation loss: 2.3066345465660096
----------------------------------------------

config: {'l1': 256, 'l2': 128, 'lr_mult': 0.001, 'batch_size': 2, 'epochs': 16, 'k_folds': 2, 'patience': 9, 'optimizer': 'Adam', 'sgd_momentum': 0.9}
Epoch: 1
Loss on hold-out set: 2.295369023823738
Accuracy on hold-out set: 0.12705
MulticlassAccuracy value on hold-out data: 0.1270499974489212
Epoch: 2
Loss on hold-out set: 2.2650628816366196
Accuracy on hold-out set: 0.16465
MulticlassAccuracy value on hold-out data: 0.16464999318122864
Epoch: 3
Loss on hold-out set: 2.200050670218468
Accuracy on hold-out set: 0.21355
MulticlassAccuracy value on hold-out data: 0.21355000138282776
Epoch: 4
Loss on hold-out set: 2.1263058030486106
Accuracy on hold-out set: 0.24285
MulticlassAccuracy value on hold-out data: 0.24285000562667847
Epoch: 5
Loss on hold-out set: 2.075596194398403
Accuracy on hold-out set: 0.2633
MulticlassAccuracy value on hold-out data: 0.26330000162124634
Epoch: 6
Loss on hold-out set: 2.042118002486229
Accuracy on hold-out set: 0.2803
MulticlassAccuracy value on hold-out data: 0.28029999136924744
Epoch: 7
Loss on hold-out set: 2.0144165697336196
Accuracy on hold-out set: 0.2889
MulticlassAccuracy value on hold-out data: 0.2888999879360199
Epoch: 8
Loss on hold-out set: 1.9925568327248095
Accuracy on hold-out set: 0.2921
MulticlassAccuracy value on hold-out data: 0.2921000123023987
Epoch: 9
Loss on hold-out set: 1.9744776573359966
Accuracy on hold-out set: 0.2967
MulticlassAccuracy value on hold-out data: 0.29670000076293945
Epoch: 10
Loss on hold-out set: 1.9580025708377362
Accuracy on hold-out set: 0.2998
MulticlassAccuracy value on hold-out data: 0.29980000853538513
Epoch: 11
Loss on hold-out set: 1.94395217654109
Accuracy on hold-out set: 0.30575
MulticlassAccuracy value on hold-out data: 0.3057500123977661
Epoch: 12
Loss on hold-out set: 1.9320823233783244
Accuracy on hold-out set: 0.30865
MulticlassAccuracy value on hold-out data: 0.3086499869823456
Epoch: 13
Loss on hold-out set: 1.9210882222354413
Accuracy on hold-out set: 0.31335
MulticlassAccuracy value on hold-out data: 0.31334999203681946
Epoch: 14
Loss on hold-out set: 1.9106255203962326
Accuracy on hold-out set: 0.3153
MulticlassAccuracy value on hold-out data: 0.31529998779296875
Epoch: 15
Loss on hold-out set: 1.9017579320669173
Accuracy on hold-out set: 0.3179
MulticlassAccuracy value on hold-out data: 0.31790000200271606
Epoch: 16
Loss on hold-out set: 1.8938326651364565
Accuracy on hold-out set: 0.32315
MulticlassAccuracy value on hold-out data: 0.32315000891685486
Returned to Spot: Validation loss: 1.8938326651364565
----------------------------------------------

config: {'l1': 8, 'l2': 32, 'lr_mult': 0.001, 'batch_size': 4, 'epochs': 8, 'k_folds': 2, 'patience': 6, 'optimizer': 'Adam', 'sgd_momentum': 0.9}
Epoch: 1
Loss on hold-out set: 2.3082054862976076
Accuracy on hold-out set: 0.1273
MulticlassAccuracy value on hold-out data: 0.12729999423027039
Epoch: 2
Loss on hold-out set: 2.3043037899017333
Accuracy on hold-out set: 0.14435
MulticlassAccuracy value on hold-out data: 0.14435000717639923
Epoch: 3
Loss on hold-out set: 2.298952097606659
Accuracy on hold-out set: 0.13915
MulticlassAccuracy value on hold-out data: 0.1391499936580658
Epoch: 4
Loss on hold-out set: 2.2942004368782043
Accuracy on hold-out set: 0.152
MulticlassAccuracy value on hold-out data: 0.15199999511241913
Epoch: 5
Loss on hold-out set: 2.2894358127593994
Accuracy on hold-out set: 0.1655
MulticlassAccuracy value on hold-out data: 0.1655000001192093
Epoch: 6
Loss on hold-out set: 2.284435478925705
Accuracy on hold-out set: 0.17245
MulticlassAccuracy value on hold-out data: 0.1724500060081482
Epoch: 7
Loss on hold-out set: 2.279789707994461
Accuracy on hold-out set: 0.17505
MulticlassAccuracy value on hold-out data: 0.17505000531673431
Epoch: 8
Loss on hold-out set: 2.275451728010178
Accuracy on hold-out set: 0.1722
MulticlassAccuracy value on hold-out data: 0.17219999432563782
Returned to Spot: Validation loss: 2.275451728010178
----------------------------------------------

config: {'l1': 64, 'l2': 512, 'lr_mult': 0.001, 'batch_size': 8, 'epochs': 16, 'k_folds': 2, 'patience': 3, 'optimizer': 'Adam', 'sgd_momentum': 0.9}
Epoch: 1
Loss on hold-out set: 2.3014001943588256
Accuracy on hold-out set: 0.1083
MulticlassAccuracy value on hold-out data: 0.10830000042915344
Epoch: 2
Loss on hold-out set: 2.297946923542023
Accuracy on hold-out set: 0.1507
MulticlassAccuracy value on hold-out data: 0.15070000290870667
Epoch: 3
Loss on hold-out set: 2.292231347179413
Accuracy on hold-out set: 0.184
MulticlassAccuracy value on hold-out data: 0.18400000035762787
Epoch: 4
Loss on hold-out set: 2.282592613220215
Accuracy on hold-out set: 0.2019
MulticlassAccuracy value on hold-out data: 0.20190000534057617
Epoch: 5
Loss on hold-out set: 2.2672391525268556
Accuracy on hold-out set: 0.21315
MulticlassAccuracy value on hold-out data: 0.2131499946117401
Epoch: 6
Loss on hold-out set: 2.2460231662750245
Accuracy on hold-out set: 0.21915
MulticlassAccuracy value on hold-out data: 0.21915000677108765
Epoch: 7
Loss on hold-out set: 2.220496086406708
Accuracy on hold-out set: 0.218
MulticlassAccuracy value on hold-out data: 0.21799999475479126
Epoch: 8
Loss on hold-out set: 2.192908446598053
Accuracy on hold-out set: 0.21825
MulticlassAccuracy value on hold-out data: 0.21825000643730164
Epoch: 9
Loss on hold-out set: 2.1664687589168548
Accuracy on hold-out set: 0.2319
MulticlassAccuracy value on hold-out data: 0.23190000653266907
Epoch: 10
Loss on hold-out set: 2.143453182077408
Accuracy on hold-out set: 0.24535
MulticlassAccuracy value on hold-out data: 0.24535000324249268
Epoch: 11
Loss on hold-out set: 2.124103342294693
Accuracy on hold-out set: 0.24895
MulticlassAccuracy value on hold-out data: 0.24895000457763672
Epoch: 12
Loss on hold-out set: 2.107735202217102
Accuracy on hold-out set: 0.24895
MulticlassAccuracy value on hold-out data: 0.24895000457763672
Epoch: 13
Loss on hold-out set: 2.093512494468689
Accuracy on hold-out set: 0.25095
MulticlassAccuracy value on hold-out data: 0.25095000863075256
Epoch: 14
Loss on hold-out set: 2.081147567367554
Accuracy on hold-out set: 0.25575
MulticlassAccuracy value on hold-out data: 0.25575000047683716
Epoch: 15
Loss on hold-out set: 2.0699062724113464
Accuracy on hold-out set: 0.26005
MulticlassAccuracy value on hold-out data: 0.2600499987602234
Epoch: 16
Loss on hold-out set: 2.0595499613285067
Accuracy on hold-out set: 0.26205
MulticlassAccuracy value on hold-out data: 0.26205000281333923
Returned to Spot: Validation loss: 2.0595499613285067
----------------------------------------------

config: {'l1': 512, 'l2': 128, 'lr_mult': 0.001, 'batch_size': 2, 'epochs': 16, 'k_folds': 2, 'patience': 10, 'optimizer': 'Adam', 'sgd_momentum': 0.9}
Epoch: 1
Loss on hold-out set: 2.288093339014053
Accuracy on hold-out set: 0.14745
MulticlassAccuracy value on hold-out data: 0.14745000004768372
Epoch: 2
Loss on hold-out set: 2.2383644287228583
Accuracy on hold-out set: 0.24585
MulticlassAccuracy value on hold-out data: 0.24584999680519104
Epoch: 3
Loss on hold-out set: 2.1543905505657195
Accuracy on hold-out set: 0.25905
MulticlassAccuracy value on hold-out data: 0.25905001163482666
Epoch: 4
Loss on hold-out set: 2.087169475835562
Accuracy on hold-out set: 0.26725
MulticlassAccuracy value on hold-out data: 0.2672500014305115
Epoch: 5
Loss on hold-out set: 2.0454443979144097
Accuracy on hold-out set: 0.279
MulticlassAccuracy value on hold-out data: 0.27900001406669617
Epoch: 6
Loss on hold-out set: 2.0169653468310833
Accuracy on hold-out set: 0.2863
MulticlassAccuracy value on hold-out data: 0.28630000352859497
Epoch: 7
Loss on hold-out set: 1.993017890369892
Accuracy on hold-out set: 0.2943
MulticlassAccuracy value on hold-out data: 0.29429998993873596
Epoch: 8
Loss on hold-out set: 1.9708995192170142
Accuracy on hold-out set: 0.29925
MulticlassAccuracy value on hold-out data: 0.2992500066757202
Epoch: 9
Loss on hold-out set: 1.9503918004870415
Accuracy on hold-out set: 0.3068
MulticlassAccuracy value on hold-out data: 0.3068000078201294
Epoch: 10
Loss on hold-out set: 1.930957144320011
Accuracy on hold-out set: 0.31435
MulticlassAccuracy value on hold-out data: 0.3143500089645386
Epoch: 11
Loss on hold-out set: 1.9140663527846336
Accuracy on hold-out set: 0.31475
MulticlassAccuracy value on hold-out data: 0.31474998593330383
Epoch: 12
Loss on hold-out set: 1.8989707651793957
Accuracy on hold-out set: 0.3219
MulticlassAccuracy value on hold-out data: 0.32190001010894775
Epoch: 13
Loss on hold-out set: 1.8855368406295776
Accuracy on hold-out set: 0.3253
MulticlassAccuracy value on hold-out data: 0.325300008058548
Epoch: 14
Loss on hold-out set: 1.8729350258708
Accuracy on hold-out set: 0.33085
MulticlassAccuracy value on hold-out data: 0.3308500051498413
Epoch: 15
Loss on hold-out set: 1.862265262067318
Accuracy on hold-out set: 0.33375
MulticlassAccuracy value on hold-out data: 0.33375000953674316
Epoch: 16
Loss on hold-out set: 1.853241665238142
Accuracy on hold-out set: 0.33615
MulticlassAccuracy value on hold-out data: 0.33614999055862427
Returned to Spot: Validation loss: 1.853241665238142
----------------------------------------------
spotPython tuning: 1.853241665238142 [##########] 100.00% Done...
<spotPython.spot.spot.Spot at 0x29e412b90>

13 Tensorboard

The textual output shown in the console (or code cell) can be visualized with Tensorboard as described in Section 14.13.

13.0.1 Results

After the hyperparameter tuning run is finished, the results can be analyzed as described in Section 14.14.

import pickle
SAVE = False
LOAD = False

if SAVE:
    result_file_name = "res_" + experiment_name + ".pkl"
    with open(result_file_name, 'wb') as f:
        pickle.dump(spot_tuner, f)

if LOAD:
    result_file_name = "ADD THE NAME here, e.g.: res_ch10-friedman-hpt-0_maans03_60min_20init_1K_2023-04-14_10-11-19.pkl"
    with open(result_file_name, 'rb') as f:
        spot_tuner =  pickle.load(f)

After the hyperparameter tuning run is finished, the progress of the hyperparameter tuning can be visualized. The following code generates the progress plot.

spot_tuner.plot_progress(log_y=False,
    filename="./figures/" + experiment_name+"_progress.png")

Progress plot. Black dots denote results from the initial design. Red dots illustrate the improvement found by the surrogate model based optimization.
  • Print the results
print(gen_design_table(fun_control=fun_control,
    spot=spot_tuner))
| name         | type   | default   |   lower |   upper |   tuned | transform             |   importance | stars   |
|--------------|--------|-----------|---------|---------|---------|-----------------------|--------------|---------|
| l1           | int    | 5         |     2.0 |     9.0 |     9.0 | transform_power_2_int |       100.00 | ***     |
| l2           | int    | 5         |     2.0 |     9.0 |     7.0 | transform_power_2_int |        58.05 | **      |
| lr_mult      | float  | 1.0       |   0.001 |   0.001 |   0.001 | None                  |         0.00 |         |
| batch_size   | int    | 4         |     1.0 |     4.0 |     1.0 | transform_power_2_int |        46.96 | *       |
| epochs       | int    | 3         |     3.0 |     4.0 |     4.0 | transform_power_2_int |         0.13 | .       |
| k_folds      | int    | 1         |     2.0 |     2.0 |     2.0 | None                  |         0.00 |         |
| patience     | int    | 5         |     2.0 |    10.0 |    10.0 | None                  |         0.12 | .       |
| optimizer    | factor | SGD       |     0.0 |     0.0 |     0.0 | None                  |         0.00 |         |
| sgd_momentum | float  | 0.0       |     0.9 |     0.9 |     0.9 | None                  |         0.00 |         |

13.1 Show variable importance

spot_tuner.plot_importance(threshold=0.025, filename="./figures/" + experiment_name+"_importance.png")

Variable importance plot, threshold 0.025.

13.2 Get the Tuned Architecture (SPOT Results)

The architecture of the spotPython model can be obtained by the following code:

from spotPython.hyperparameters.values import get_one_core_model_from_X
X = spot_tuner.to_all_dim(spot_tuner.min_X.reshape(1,-1))
model_spot = get_one_core_model_from_X(X, fun_control)
model_spot
Net_CIFAR10(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=10, bias=True)
)
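The layer sizes printed above can be checked by hand. The sketch below assumes that the forward pass applies the 2x2 max pool after each of the two convolutions, which is consistent with in_features=400 for fc1 but is not shown in the printed summary. The helper functions are hypothetical and written only for this calculation.

```python
def conv_out(size: int, kernel: int) -> int:
    """Output width of a stride-1, unpadded convolution."""
    return size - kernel + 1

def pool_out(size: int) -> int:
    """Output width of a 2x2 max pool with stride 2."""
    return size // 2

size = 32                           # CIFAR10 images are 32x32
size = pool_out(conv_out(size, 5))  # conv1 + pool: 32 -> 28 -> 14
size = pool_out(conv_out(size, 5))  # conv2 + pool: 14 -> 10 -> 5
flat_features = 16 * size * size    # 16 output channels of conv2

def conv_params(c_in: int, c_out: int, k: int) -> int:
    return c_in * c_out * k * k + c_out  # weights + biases

def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out          # weights + biases

total = (conv_params(3, 6, 5) + conv_params(6, 16, 5)
         + linear_params(400, 512) + linear_params(512, 128)
         + linear_params(128, 10))
print(flat_features, total)  # 400 275138
```

Almost all trainable parameters sit in fc1 and fc2, the two layers whose widths are controlled by l1 and l2.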

13.3 Evaluation of the Tuned Architecture

from spotPython.torch.traintest import (
    train_tuned,
    test_tuned,
    )
train_tuned(net=model_spot, train_dataset=train,
        loss_function=fun_control["loss_function"],
        metric=fun_control["metric_torch"],
        shuffle=True,
        device = fun_control["device"],
        path=None,
        task=fun_control["task"],)
Epoch: 1
Batch: 10000. Batch Size: 2. Training Loss (running): 2.300
Loss on hold-out set: 2.2904104845523836
Accuracy on hold-out set: 0.1406
MulticlassAccuracy value on hold-out data: 0.14059999585151672
Epoch: 2
Batch: 10000. Batch Size: 2. Training Loss (running): 2.281
Loss on hold-out set: 2.2546629938364027
Accuracy on hold-out set: 0.2591
MulticlassAccuracy value on hold-out data: 0.2590999901294708
Epoch: 3
Batch: 10000. Batch Size: 2. Training Loss (running): 2.232
Loss on hold-out set: 2.1802040595412255
Accuracy on hold-out set: 0.27405
MulticlassAccuracy value on hold-out data: 0.2740499973297119
Epoch: 4
Batch: 10000. Batch Size: 2. Training Loss (running): 2.146
Loss on hold-out set: 2.0877344265937805
Accuracy on hold-out set: 0.2816
MulticlassAccuracy value on hold-out data: 0.2815999984741211
Epoch: 5
Batch: 10000. Batch Size: 2. Training Loss (running): 2.058
Loss on hold-out set: 2.0125756390690803
Accuracy on hold-out set: 0.28955
MulticlassAccuracy value on hold-out data: 0.2895500063896179
Epoch: 6
Batch: 10000. Batch Size: 2. Training Loss (running): 1.992
Loss on hold-out set: 1.9651088925540448
Accuracy on hold-out set: 0.3
MulticlassAccuracy value on hold-out data: 0.30000001192092896
Epoch: 7
Batch: 10000. Batch Size: 2. Training Loss (running): 1.945
Loss on hold-out set: 1.935444342005253
Accuracy on hold-out set: 0.30595
MulticlassAccuracy value on hold-out data: 0.30594998598098755
Epoch: 8
Batch: 10000. Batch Size: 2. Training Loss (running): 1.916
Loss on hold-out set: 1.9154097786486148
Accuracy on hold-out set: 0.313
MulticlassAccuracy value on hold-out data: 0.31299999356269836
Epoch: 9
Batch: 10000. Batch Size: 2. Training Loss (running): 1.894
Loss on hold-out set: 1.899637191927433
Accuracy on hold-out set: 0.31825
MulticlassAccuracy value on hold-out data: 0.31825000047683716
Epoch: 10
Batch: 10000. Batch Size: 2. Training Loss (running): 1.891
Loss on hold-out set: 1.8874678076148033
Accuracy on hold-out set: 0.3249
MulticlassAccuracy value on hold-out data: 0.3249000012874603
Epoch: 11
Batch: 10000. Batch Size: 2. Training Loss (running): 1.864
Loss on hold-out set: 1.8766552743077278
Accuracy on hold-out set: 0.32895
MulticlassAccuracy value on hold-out data: 0.3289499878883362
Epoch: 12
Batch: 10000. Batch Size: 2. Training Loss (running): 1.856
Loss on hold-out set: 1.8654536617636681
Accuracy on hold-out set: 0.3339
MulticlassAccuracy value on hold-out data: 0.33390000462532043
Epoch: 13
Batch: 10000. Batch Size: 2. Training Loss (running): 1.841
Loss on hold-out set: 1.8563501557201147
Accuracy on hold-out set: 0.3385
MulticlassAccuracy value on hold-out data: 0.3384999930858612
Epoch: 14
Batch: 10000. Batch Size: 2. Training Loss (running): 1.841
Loss on hold-out set: 1.8458486349493266
Accuracy on hold-out set: 0.34125
MulticlassAccuracy value on hold-out data: 0.3412500023841858
Epoch: 15
Batch: 10000. Batch Size: 2. Training Loss (running): 1.828
Loss on hold-out set: 1.836037072250247
Accuracy on hold-out set: 0.34615
MulticlassAccuracy value on hold-out data: 0.3461500108242035
Epoch: 16
Batch: 10000. Batch Size: 2. Training Loss (running): 1.814
Loss on hold-out set: 1.826788178345561
Accuracy on hold-out set: 0.3507
MulticlassAccuracy value on hold-out data: 0.3506999909877777
Returned to Spot: Validation loss: 1.826788178345561
----------------------------------------------

If path is set to a filename, e.g., path = "model_spot_trained.pt", the weights of the trained model will be loaded from this file.

test_tuned(net=model_spot, test_dataset=test,
            shuffle=False,
            loss_function=fun_control["loss_function"],
            metric=fun_control["metric_torch"],
            device = fun_control["device"],
            task=fun_control["task"],)
Loss on hold-out set: 1.805866453498602
Accuracy on hold-out set: 0.3538
MulticlassAccuracy value on hold-out data: 0.3537999987602234
Final evaluation: Validation loss: 1.805866453498602
Final evaluation: Validation metric: 0.3537999987602234
----------------------------------------------
(1.805866453498602, nan, tensor(0.3538, device='mps:0'))

13.4 Cross-validated Evaluations

from spotPython.torch.traintest import evaluate_cv
# modify k-folds:
setattr(model_spot, "k_folds",  10)
df_eval, df_preds, df_metrics = evaluate_cv(net=model_spot,
            dataset=fun_control["data"],
            loss_function=fun_control["loss_function"],
            metric=fun_control["metric_torch"],
            task=fun_control["task"],
            writer=fun_control["writer"],
            writerId="model_spot_cv",
            device = fun_control["device"])
Error in Net_Core. Call to evaluate_cv() failed. err=TypeError("Expected sequence or array-like, got <class 'NoneType'>"), type(err)=<class 'TypeError'>
metric_name = type(fun_control["metric_torch"]).__name__
print(f"loss: {df_eval}, Cross-validated {metric_name}: {df_metrics}")
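loss: nan, Cross-validated MulticlassAccuracy: nan

Both values are nan because the call to evaluate_cv() fails: fun_control["data"] was set to None in Section 12.4, so no dataset is passed to the function. To obtain cross-validated results, a dataset object would presumably have to be supplied instead.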
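For orientation, the idea behind k-fold cross-validation can be sketched with index lists alone. The generator k_fold_indices below is a hypothetical illustration with contiguous, unshuffled folds; it is not the splitting logic used by evaluate_cv.

```python
def k_fold_indices(n_samples: int, k: int):
    """Yield (train_idx, val_idx) pairs for k contiguous, near-equal folds."""
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0)
                  for i in range(k)]
    start = 0
    for size in fold_sizes:
        val_idx = list(range(start, start + size))
        train_idx = list(range(0, start)) + list(range(start + size, n_samples))
        yield train_idx, val_idx
        start += size

# Toy example: 50 samples split into 10 folds of 5 validation samples each.
folds = list(k_fold_indices(n_samples=50, k=10))
print(len(folds), len(folds[0][0]), len(folds[0][1]))  # 10 45 5
```

Each sample appears in exactly one validation fold, so every observation is used once for validation and k - 1 times for training.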

13.5 Detailed Hyperparameter Plots

filename = "./figures/" + experiment_name
spot_tuner.plot_important_hyperparameter_contour(filename=filename)
l1:  100.0
l2:  58.047350042903965
batch_size:  46.96088396373169
epochs:  0.12960315991572416
patience:  0.12189029934839649

Contour plots.

13.6 Parallel Coordinates Plot

spot_tuner.parallel_plot()

Parallel coordinates plots

13.7 Plot all Combinations of Hyperparameters

  • Warning: this may take a while.
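PLOT_ALL = False
# min_z and max_z were not defined above; assumption: None leaves the z-range to plot_contour
min_z = None
max_z = None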
if PLOT_ALL:
    n = spot_tuner.k
    for i in range(n-1):
        for j in range(i+1, n):
            spot_tuner.plot_contour(i=i, j=j, min_z=min_z, max_z = max_z)