12 HPT: PyTorch With CIFAR10 Data

In this tutorial, we will show how spotPython can be integrated into the PyTorch training workflow.

This document refers to the following software versions:

pip list | grep  "spot[RiverPython]"
spotPython                 0.2.33
spotRiver                  0.0.93
Note: you may need to restart the kernel to use updated packages.

spotPython can be installed via pip. Alternatively, the source code can be downloaded from GitHub: https://github.com/sequential-parameter-optimization/spotPython.

!pip install spotPython
# import sys
# !{sys.executable} -m pip install --upgrade build
# !{sys.executable} -m pip install --upgrade --force-reinstall spotPython

12.1 Setup

Before we consider the detailed experimental setup, we select the parameters that affect the run time (MAX_TIME), the initial design size (INIT_SIZE), and the device to be used (DEVICE).

MAX_TIME = 1
INIT_SIZE = 5
DEVICE = None # "cpu" # "cuda:0"
from spotPython.utils.device import getDevice
DEVICE = getDevice(DEVICE)
print(DEVICE)
mps
import os
import copy
import socket
from datetime import datetime
from dateutil.tz import tzlocal
start_time = datetime.now(tzlocal())
HOSTNAME = socket.gethostname().split(".")[0]
experiment_name = '12-torch' + "_" + HOSTNAME + "_" + str(MAX_TIME) + "min_" + str(INIT_SIZE) + "init_" + str(start_time).split(".", 1)[0].replace(' ', '_')
experiment_name = experiment_name.replace(':', '-')
print(experiment_name)
if not os.path.exists('./figures'):
    os.makedirs('./figures')
12-torch_p040025_1min_5init_2023-06-16_20-31-13

12.2 Initialization of the fun_control Dictionary

spotPython uses a Python dictionary for storing the information required for the hyperparameter tuning process, which was described in Section 13.2.

from spotPython.utils.init import fun_control_init
fun_control = fun_control_init(task="classification",
    tensorboard_path="runs/12_spot_hpt_torch_cifar10",
    device=DEVICE)

12.3 PyTorch Data Loading

12.4 Load CIFAR10 Data

from torchvision import transforms
import torchvision
def load_data(data_dir="./data"):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform)

    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform)

    return trainset, testset
train, test = load_data()
Files already downloaded and verified
Files already downloaded and verified
  • Since the data loading works as expected, we can add the datasets to the fun_control dictionary:
n_samples = len(train)
# add the dataset to the fun_control
fun_control.update({"data": None, # dataset,
               "train": train,
               "test": test,
               "n_samples": n_samples,
               "target_column": None})

12.5 The Model (Algorithm) to be Tuned

12.6 Specification of the Preprocessing Model

After the training and test data are specified and added to the fun_control dictionary, spotPython allows the specification of a data preprocessing pipeline, e.g., for the scaling of the data or for the one-hot encoding of categorical variables, see Section 13.4.1. This feature is not used here, so we do not change the default value (which is None).
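
If preprocessing were needed, a scikit-learn compatible transformer could be attached; the following is a sketch only (not executed in this chapter), assuming the fun_control key "prep_model" described in Section 13.4.1:

from sklearn.preprocessing import StandardScaler
# Sketch only: attach a scaler as the preprocessing model.
fun_control.update({"prep_model": StandardScaler()})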

12.7 Select Algorithm and core_model_hyper_dict

12.7.1 Implementing a Configurable Neural Network With spotPython

spotPython includes the Net_CIFAR10 class, which is implemented in the file netcifar10.py and imported here.

This class inherits from the class Net_Core, which is implemented in the file netcore.py.

from spotPython.torch.netcifar10 import Net_CIFAR10
from spotPython.data.torch_hyper_dict import TorchHyperDict
from spotPython.hyperparameters.values import add_core_model_to_fun_control
fun_control = add_core_model_to_fun_control(core_model=Net_CIFAR10,
                              fun_control=fun_control,
                              hyper_dict=TorchHyperDict,
                              filename=None)
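
The attached hyperparameter dictionary can be inspected directly; for example (a quick check, assuming "core_model_hyper_dict" is the fun_control key under which spotPython stores it):

# Show the entry for the first layer width l1:
print(fun_control["core_model_hyper_dict"]["l1"])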

12.8 The Search Space

12.8.1 Configuring the Search Space With spotPython

12.8.1.1 The hyper_dict Hyperparameters for the Selected Algorithm

spotPython uses JSON files for the specification of the hyperparameters, which were described in Section 13.5.2.

The corresponding entries for the Net_CIFAR10 class are shown below.

    "Net_CIFAR10":
    {
        "l1": {
            "type": "int",
            "default": 5,
            "transform": "transform_power_2_int",
            "lower": 2,
            "upper": 9},
        "l2": {
            "type": "int",
            "default": 5,
            "transform": "transform_power_2_int",
            "lower": 2,
            "upper": 9},
        "lr_mult": {
            "type": "float",
            "default": 1.0,
            "transform": "None",
            "lower": 0.1,
            "upper": 10.0},
        "batch_size": {
            "type": "int",
            "default": 4,
            "transform": "transform_power_2_int",
            "lower": 1,
            "upper": 4},
        "epochs": {
            "type": "int",
            "default": 3,
            "transform": "transform_power_2_int",
            "lower": 3,
            "upper": 4},
        "k_folds": {
            "type": "int",
            "default": 1,
            "transform": "None",
            "lower": 1,
            "upper": 1},
        "patience": {
            "type": "int",
            "default": 5,
            "transform": "None",
            "lower": 2,
            "upper": 10
        },
        "optimizer": {
            "levels": ["Adadelta",
                       "Adagrad",
                       "Adam",
                       "AdamW",
                       "SparseAdam",
                       "Adamax",
                       "ASGD",
                       "NAdam",
                       "RAdam",
                       "RMSprop",
                       "Rprop",
                       "SGD"],
            "type": "factor",
            "default": "SGD",
            "transform": "None",
            "class_name": "torch.optim",
            "core_model_parameter_type": "str",
            "lower": 0,
            "upper": 12},
        "sgd_momentum": {
            "type": "float",
            "default": 0.0,
            "transform": "None",
            "lower": 0.0,
            "upper": 1.0}
    },
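
These transformations are applied before a value reaches the model. For example, transform_power_2_int maps the tuned integer to a power of two, so the default l1 = 5 corresponds to a layer width of 2^5 = 32. A minimal illustration of its effect:

def transform_power_2_int(x: int) -> int:
    # Power-of-two transform: 5 -> 32, 4 -> 16, etc.
    return 2 ** x

print(transform_power_2_int(5))  # 32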

12.9 Modifying the Hyperparameters

spotPython provides functions for modifying the hyperparameters, their bounds and factors as well as for activating and de-activating hyperparameters without re-compilation of the Python source code. These functions were described in Section 13.5.3.

12.9.1 Modify hyper_dict Hyperparameters for the Selected Algorithm (core_model)

12.9.1.1 Modify Hyperparameters of Type numeric and integer (boolean)

The hyperparameter k_folds is not used here; it is de-activated by setting its lower and upper bounds to the same value.

from spotPython.hyperparameters.values import modify_hyper_parameter_bounds
fun_control = modify_hyper_parameter_bounds(fun_control, "k_folds", bounds=[2, 2])

12.9.2 Modify Hyperparameters of Type factor

from spotPython.hyperparameters.values import modify_hyper_parameter_levels
fun_control = modify_hyper_parameter_levels(fun_control, "optimizer", ["Adam"])

12.9.3 Optimizers

Optimizers are described in Section 13.6; their selection is described in Section 18.8.2. Since the optimizer was fixed to Adam above, the learning-rate multiplier lr_mult and the momentum sgd_momentum are set to constant values here by assigning identical lower and upper bounds:

fun_control = modify_hyper_parameter_bounds(fun_control,
    "lr_mult", bounds=[1e-3, 1e-3])
fun_control = modify_hyper_parameter_bounds(fun_control,
    "sgd_momentum", bounds=[0.9, 0.9])

12.10 Evaluation

The evaluation procedure requires the specification of two elements:

  1. how the data is split into a train and a test set, and
  2. the loss function (and a metric).

These are described in Section 18.9.

The key "loss_function" specifies the loss function which is used during the optimization, see Section 13.8.

We will use CrossEntropyLoss for the multi-class classification task.

from torch.nn import CrossEntropyLoss
loss_function = CrossEntropyLoss()
fun_control.update({
        "loss_function": loss_function,
        "shuffle": True,
        "eval":  "train_hold_out"
        })

12.10.1 Metric

import torchmetrics
metric_torch = torchmetrics.Accuracy(task="multiclass",
     num_classes=10).to(fun_control["device"])
fun_control.update({"metric_torch": metric_torch})

12.11 Preparing the SPOT Call

The following code passes the information about the parameter ranges and bounds to spot.

# extract the variable types, names, and bounds
from spotPython.hyperparameters.values import (get_bound_values,
    get_var_name,
    get_var_type,)
var_type = get_var_type(fun_control)
var_name = get_var_name(fun_control)
fun_control.update({"var_type": var_type,
                    "var_name": var_name})
lower = get_bound_values(fun_control, "lower")
upper = get_bound_values(fun_control, "upper")
from spotPython.utils.eda import gen_design_table
print(gen_design_table(fun_control))
| name         | type   | default   |   lower |   upper | transform             |
|--------------|--------|-----------|---------|---------|-----------------------|
| l1           | int    | 5         |   2     |   9     | transform_power_2_int |
| l2           | int    | 5         |   2     |   9     | transform_power_2_int |
| lr_mult      | float  | 1.0       |   0.001 |   0.001 | None                  |
| batch_size   | int    | 4         |   1     |   4     | transform_power_2_int |
| epochs       | int    | 3         |   3     |   4     | transform_power_2_int |
| k_folds      | int    | 1         |   2     |   2     | None                  |
| patience     | int    | 5         |   2     |  10     | None                  |
| optimizer    | factor | SGD       |   0     |   0     | None                  |
| sgd_momentum | float  | 0.0       |   0.9   |   0.9   | None                  |

12.12 The Objective Function fun_torch

The objective function fun_torch is selected next. It implements an interface from PyTorch’s training, validation, and testing methods to spotPython.

from spotPython.fun.hypertorch import HyperTorch
fun = HyperTorch().fun_torch

12.13 Starting the Hyperparameter Tuning

import numpy as np
from spotPython.spot import spot
from spotPython.hyperparameters.values import get_default_hyperparameters_as_array
from math import inf
# Start the tuning run from the default hyperparameter configuration
# (X_start was undefined in the original snippet):
X_start = get_default_hyperparameters_as_array(fun_control)
spot_tuner = spot.Spot(fun=fun,
                   lower = lower,
                   upper = upper,
                   fun_evals = inf,
                   fun_repeats = 1,
                   max_time = MAX_TIME,
                   noise = False,
                   tolerance_x = np.sqrt(np.spacing(1)),
                   var_type = var_type,
                   var_name = var_name,
                   infill_criterion = "y",
                   n_points = 1,
                   seed=123,
                   log_level = 50,
                   show_models= False,
                   show_progress= True,
                   fun_control = fun_control,
                   design_control={"init_size": INIT_SIZE,
                                   "repeats": 1},
                   surrogate_control={"noise": True,
                                      "cod_type": "norm",
                                      "min_theta": -4,
                                      "max_theta": 3,
                                      "n_theta": len(var_name),
                                      "model_fun_evals": 10_000,
                                      "log_level": 50
                                      })
spot_tuner.run(X_start=X_start)

config: {'l1': 128, 'l2': 8, 'lr_mult': 0.001, 'batch_size': 16, 'epochs': 16, 'k_folds': 2, 'patience': 5, 'optimizer': 'Adam', 'sgd_momentum': 0.9}
Epoch: 1
Loss on hold-out set: 2.315646593475342
Accuracy on hold-out set: 0.10035
MulticlassAccuracy value on hold-out data: 0.10034999996423721
Epoch: 2
Loss on hold-out set: 2.314196505355835
Accuracy on hold-out set: 0.10035
MulticlassAccuracy value on hold-out data: 0.10034999996423721
Epoch: 3
Loss on hold-out set: 2.3125065156936646
Accuracy on hold-out set: 0.10035
MulticlassAccuracy value on hold-out data: 0.10034999996423721
Epoch: 4
Loss on hold-out set: 2.3105575498580935
Accuracy on hold-out set: 0.10035
MulticlassAccuracy value on hold-out data: 0.10034999996423721
Epoch: 5
Loss on hold-out set: 2.3082722026824953
Accuracy on hold-out set: 0.10035
MulticlassAccuracy value on hold-out data: 0.10034999996423721
Epoch: 6
Loss on hold-out set: 2.3054690366744994
Accuracy on hold-out set: 0.10335
MulticlassAccuracy value on hold-out data: 0.10334999859333038
Epoch: 7
Loss on hold-out set: 2.3020389120101927
Accuracy on hold-out set: 0.11445
MulticlassAccuracy value on hold-out data: 0.11445000022649765
Epoch: 8
Loss on hold-out set: 2.298271685028076
Accuracy on hold-out set: 0.12865
MulticlassAccuracy value on hold-out data: 0.1286499947309494
Epoch: 9
Loss on hold-out set: 2.2943480314254763
Accuracy on hold-out set: 0.14205
MulticlassAccuracy value on hold-out data: 0.14204999804496765
Epoch: 10
Loss on hold-out set: 2.2902308708190917
Accuracy on hold-out set: 0.1489
MulticlassAccuracy value on hold-out data: 0.14890000224113464
Epoch: 11
Loss on hold-out set: 2.2861568740844724
Accuracy on hold-out set: 0.1536
MulticlassAccuracy value on hold-out data: 0.15360000729560852
Epoch: 12
Loss on hold-out set: 2.282076177787781
Accuracy on hold-out set: 0.15475
MulticlassAccuracy value on hold-out data: 0.1547500044107437
Epoch: 13
Loss on hold-out set: 2.2780181921005247
Accuracy on hold-out set: 0.15465
MulticlassAccuracy value on hold-out data: 0.1546500027179718
Epoch: 14
Loss on hold-out set: 2.2738588747024537
Accuracy on hold-out set: 0.1606
MulticlassAccuracy value on hold-out data: 0.16060000658035278
Epoch: 15
Loss on hold-out set: 2.2696618770599364
Accuracy on hold-out set: 0.16835
MulticlassAccuracy value on hold-out data: 0.1683499962091446
Epoch: 16
Loss on hold-out set: 2.2653541986465453
Accuracy on hold-out set: 0.17415
MulticlassAccuracy value on hold-out data: 0.1741500049829483
Returned to Spot: Validation loss: 2.2653541986465453
----------------------------------------------

config: {'l1': 16, 'l2': 16, 'lr_mult': 0.001, 'batch_size': 8, 'epochs': 8, 'k_folds': 2, 'patience': 7, 'optimizer': 'Adam', 'sgd_momentum': 0.9}
Epoch: 1
Loss on hold-out set: 2.3130117963790893
Accuracy on hold-out set: 0.10005
MulticlassAccuracy value on hold-out data: 0.10005000233650208
Epoch: 2
Loss on hold-out set: 2.311689888572693
Accuracy on hold-out set: 0.10005
MulticlassAccuracy value on hold-out data: 0.10005000233650208
Epoch: 3
Loss on hold-out set: 2.310342220592499
Accuracy on hold-out set: 0.10005
MulticlassAccuracy value on hold-out data: 0.10005000233650208
Epoch: 4
Loss on hold-out set: 2.3088318652153017
Accuracy on hold-out set: 0.10005
MulticlassAccuracy value on hold-out data: 0.10005000233650208
Epoch: 5
Loss on hold-out set: 2.3070800992965697
Accuracy on hold-out set: 0.10005
MulticlassAccuracy value on hold-out data: 0.10005000233650208
Epoch: 6
Loss on hold-out set: 2.305096861553192
Accuracy on hold-out set: 0.10025
MulticlassAccuracy value on hold-out data: 0.1002499982714653
Epoch: 7
Loss on hold-out set: 2.3029436743736267
Accuracy on hold-out set: 0.10575
MulticlassAccuracy value on hold-out data: 0.10575000196695328
Epoch: 8
Loss on hold-out set: 2.3007870406150817
Accuracy on hold-out set: 0.1246
MulticlassAccuracy value on hold-out data: 0.12460000067949295
Returned to Spot: Validation loss: 2.3007870406150817
----------------------------------------------

config: {'l1': 256, 'l2': 128, 'lr_mult': 0.001, 'batch_size': 2, 'epochs': 16, 'k_folds': 2, 'patience': 9, 'optimizer': 'Adam', 'sgd_momentum': 0.9}
Epoch: 1
Loss on hold-out set: 2.292080225634575
Accuracy on hold-out set: 0.12195
MulticlassAccuracy value on hold-out data: 0.12195000052452087
Epoch: 2
Loss on hold-out set: 2.2527622716903686
Accuracy on hold-out set: 0.17955
MulticlassAccuracy value on hold-out data: 0.17955000698566437
Epoch: 3
Loss on hold-out set: 2.1898319549918175
Accuracy on hold-out set: 0.23835
MulticlassAccuracy value on hold-out data: 0.2383500039577484
Epoch: 4
Loss on hold-out set: 2.1320833272099495
Accuracy on hold-out set: 0.2541
MulticlassAccuracy value on hold-out data: 0.2540999948978424
Epoch: 5
Loss on hold-out set: 2.0898145262241363
Accuracy on hold-out set: 0.2582
MulticlassAccuracy value on hold-out data: 0.2581999897956848
Epoch: 6
Loss on hold-out set: 2.056964628946781
Accuracy on hold-out set: 0.2649
MulticlassAccuracy value on hold-out data: 0.26489999890327454
Epoch: 7
Loss on hold-out set: 2.0288013229608537
Accuracy on hold-out set: 0.27085
MulticlassAccuracy value on hold-out data: 0.2708500027656555
Epoch: 8
Loss on hold-out set: 2.005462641263008
Accuracy on hold-out set: 0.27575
MulticlassAccuracy value on hold-out data: 0.2757500112056732
Epoch: 9
Loss on hold-out set: 1.986916621595621
Accuracy on hold-out set: 0.2807
MulticlassAccuracy value on hold-out data: 0.2806999981403351
Epoch: 10
Loss on hold-out set: 1.9713387882947921
Accuracy on hold-out set: 0.28475
MulticlassAccuracy value on hold-out data: 0.2847500145435333
Epoch: 11
Loss on hold-out set: 1.9560879898548127
Accuracy on hold-out set: 0.29075
MulticlassAccuracy value on hold-out data: 0.29074999690055847
Epoch: 12
Loss on hold-out set: 1.9449392019420861
Accuracy on hold-out set: 0.29755
MulticlassAccuracy value on hold-out data: 0.2975499927997589
Epoch: 13
Loss on hold-out set: 1.9348831713944674
Accuracy on hold-out set: 0.30305
MulticlassAccuracy value on hold-out data: 0.3030500113964081
Epoch: 14
Loss on hold-out set: 1.9249073069006204
Accuracy on hold-out set: 0.3071
MulticlassAccuracy value on hold-out data: 0.30709999799728394
Epoch: 15
Loss on hold-out set: 1.9163196316242217
Accuracy on hold-out set: 0.31175
MulticlassAccuracy value on hold-out data: 0.31174999475479126
Epoch: 16
Loss on hold-out set: 1.9064452330052852
Accuracy on hold-out set: 0.319
MulticlassAccuracy value on hold-out data: 0.3190000057220459
Returned to Spot: Validation loss: 1.9064452330052852
----------------------------------------------

config: {'l1': 8, 'l2': 32, 'lr_mult': 0.001, 'batch_size': 4, 'epochs': 8, 'k_folds': 2, 'patience': 6, 'optimizer': 'Adam', 'sgd_momentum': 0.9}
Epoch: 1
Loss on hold-out set: 2.3122916830062867
Accuracy on hold-out set: 0.1002
MulticlassAccuracy value on hold-out data: 0.10019999742507935
Epoch: 2
Loss on hold-out set: 2.3099040962219237
Accuracy on hold-out set: 0.1002
MulticlassAccuracy value on hold-out data: 0.10019999742507935
Epoch: 3
Loss on hold-out set: 2.3068060571670532
Accuracy on hold-out set: 0.10125
MulticlassAccuracy value on hold-out data: 0.10125000029802322
Epoch: 4
Loss on hold-out set: 2.3029414982795715
Accuracy on hold-out set: 0.12065
MulticlassAccuracy value on hold-out data: 0.12065000087022781
Epoch: 5
Loss on hold-out set: 2.2983582141399386
Accuracy on hold-out set: 0.13935
MulticlassAccuracy value on hold-out data: 0.13934999704360962
Epoch: 6
Loss on hold-out set: 2.292820638370514
Accuracy on hold-out set: 0.1461
MulticlassAccuracy value on hold-out data: 0.1460999995470047
Epoch: 7
Loss on hold-out set: 2.2858299319267275
Accuracy on hold-out set: 0.1527
MulticlassAccuracy value on hold-out data: 0.1527000069618225
Epoch: 8
Loss on hold-out set: 2.27768601269722
Accuracy on hold-out set: 0.15835
MulticlassAccuracy value on hold-out data: 0.15835000574588776
Returned to Spot: Validation loss: 2.27768601269722
----------------------------------------------

config: {'l1': 64, 'l2': 512, 'lr_mult': 0.001, 'batch_size': 8, 'epochs': 16, 'k_folds': 2, 'patience': 3, 'optimizer': 'Adam', 'sgd_momentum': 0.9}
Epoch: 1
Loss on hold-out set: 2.301998101902008
Accuracy on hold-out set: 0.1036
MulticlassAccuracy value on hold-out data: 0.10360000282526016
Epoch: 2
Loss on hold-out set: 2.299367659664154
Accuracy on hold-out set: 0.13285
MulticlassAccuracy value on hold-out data: 0.13285000622272491
Epoch: 3
Loss on hold-out set: 2.2937819888114928
Accuracy on hold-out set: 0.1356
MulticlassAccuracy value on hold-out data: 0.1356000006198883
Epoch: 4
Loss on hold-out set: 2.2835525563240053
Accuracy on hold-out set: 0.15275
MulticlassAccuracy value on hold-out data: 0.15275000035762787
Epoch: 5
Loss on hold-out set: 2.2678538510322572
Accuracy on hold-out set: 0.16965
MulticlassAccuracy value on hold-out data: 0.16965000331401825
Epoch: 6
Loss on hold-out set: 2.2476910405158996
Accuracy on hold-out set: 0.17415
MulticlassAccuracy value on hold-out data: 0.1741500049829483
Epoch: 7
Loss on hold-out set: 2.2244457854270934
Accuracy on hold-out set: 0.18045
MulticlassAccuracy value on hold-out data: 0.18045000731945038
Epoch: 8
Loss on hold-out set: 2.198971457386017
Accuracy on hold-out set: 0.18675
MulticlassAccuracy value on hold-out data: 0.18674999475479126
Epoch: 9
Loss on hold-out set: 2.1728923971652985
Accuracy on hold-out set: 0.19975
MulticlassAccuracy value on hold-out data: 0.19975000619888306
Epoch: 10
Loss on hold-out set: 2.14876118311882
Accuracy on hold-out set: 0.2335
MulticlassAccuracy value on hold-out data: 0.23350000381469727
Epoch: 11
Loss on hold-out set: 2.1276989792346956
Accuracy on hold-out set: 0.23685
MulticlassAccuracy value on hold-out data: 0.23684999346733093
Epoch: 12
Loss on hold-out set: 2.1099014014244077
Accuracy on hold-out set: 0.2401
MulticlassAccuracy value on hold-out data: 0.24009999632835388
Epoch: 13
Loss on hold-out set: 2.0947401561737062
Accuracy on hold-out set: 0.2412
MulticlassAccuracy value on hold-out data: 0.24120000004768372
Epoch: 14
Loss on hold-out set: 2.081419317626953
Accuracy on hold-out set: 0.2437
MulticlassAccuracy value on hold-out data: 0.24369999766349792
Epoch: 15
Loss on hold-out set: 2.0695546627521515
Accuracy on hold-out set: 0.24535
MulticlassAccuracy value on hold-out data: 0.24535000324249268
Epoch: 16
Loss on hold-out set: 2.059267439651489
Accuracy on hold-out set: 0.24815
MulticlassAccuracy value on hold-out data: 0.24815000593662262
Returned to Spot: Validation loss: 2.059267439651489
----------------------------------------------

config: {'l1': 512, 'l2': 128, 'lr_mult': 0.001, 'batch_size': 2, 'epochs': 16, 'k_folds': 2, 'patience': 9, 'optimizer': 'Adam', 'sgd_momentum': 0.9}
Epoch: 1
Loss on hold-out set: 2.289809232711792
Accuracy on hold-out set: 0.1529
MulticlassAccuracy value on hold-out data: 0.15289999544620514
Epoch: 2
Loss on hold-out set: 2.2399055677771567
Accuracy on hold-out set: 0.2568
MulticlassAccuracy value on hold-out data: 0.25679999589920044
Epoch: 3
Loss on hold-out set: 2.159338474869728
Accuracy on hold-out set: 0.25525
MulticlassAccuracy value on hold-out data: 0.2552500069141388
Epoch: 4
Loss on hold-out set: 2.0896094763815403
Accuracy on hold-out set: 0.2629
MulticlassAccuracy value on hold-out data: 0.2628999948501587
Epoch: 5
Loss on hold-out set: 2.0389798625826834
Accuracy on hold-out set: 0.2717
MulticlassAccuracy value on hold-out data: 0.271699994802475
Epoch: 6
Loss on hold-out set: 2.0001385292083027
Accuracy on hold-out set: 0.286
MulticlassAccuracy value on hold-out data: 0.28600001335144043
Epoch: 7
Loss on hold-out set: 1.9710163355231285
Accuracy on hold-out set: 0.29465
MulticlassAccuracy value on hold-out data: 0.29464998841285706
Epoch: 8
Loss on hold-out set: 1.948514760518074
Accuracy on hold-out set: 0.3043
MulticlassAccuracy value on hold-out data: 0.3043000102043152
Epoch: 9
Loss on hold-out set: 1.9313356006026268
Accuracy on hold-out set: 0.3089
MulticlassAccuracy value on hold-out data: 0.30889999866485596
Epoch: 10
Loss on hold-out set: 1.9165579177349805
Accuracy on hold-out set: 0.3148
MulticlassAccuracy value on hold-out data: 0.3147999942302704
Epoch: 11
Loss on hold-out set: 1.904406348219514
Accuracy on hold-out set: 0.31645
MulticlassAccuracy value on hold-out data: 0.31644999980926514
Epoch: 12
Loss on hold-out set: 1.892529323834181
Accuracy on hold-out set: 0.32125
MulticlassAccuracy value on hold-out data: 0.32124999165534973
Epoch: 13
Loss on hold-out set: 1.8818038147062064
Accuracy on hold-out set: 0.32585
MulticlassAccuracy value on hold-out data: 0.3258500099182129
Epoch: 14
Loss on hold-out set: 1.8708130443632602
Accuracy on hold-out set: 0.32925
MulticlassAccuracy value on hold-out data: 0.3292500078678131
Epoch: 15
Loss on hold-out set: 1.86188459828496
Accuracy on hold-out set: 0.3329
MulticlassAccuracy value on hold-out data: 0.3328999876976013
Epoch: 16
Loss on hold-out set: 1.851201780089736
Accuracy on hold-out set: 0.3354
MulticlassAccuracy value on hold-out data: 0.3353999853134155
Returned to Spot: Validation loss: 1.851201780089736
----------------------------------------------
spotPython tuning: 1.851201780089736 [##########] 100.00% Done...
<spotPython.spot.spot.Spot at 0x2a576faf0>

12.14 Tensorboard

The textual output shown in the console (or code cell) can be visualized with Tensorboard as described in Section 13.13.
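
For example, the Tensorboard server can be started from the command line, pointing it at the tensorboard_path configured in Section 12.2; the dashboard is then served at http://localhost:6006:

tensorboard --logdir=runs/12_spot_hpt_torch_cifar10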

12.14.1 Results

After the hyperparameter tuning run is finished, the results can be analyzed as described in Section 13.14.

import pickle
SAVE = False
LOAD = False

if SAVE:
    result_file_name = "res_" + experiment_name + ".pkl"
    with open(result_file_name, 'wb') as f:
        pickle.dump(spot_tuner, f)

if LOAD:
    result_file_name = "ADD THE NAME here, e.g.: res_ch10-friedman-hpt-0_maans03_60min_20init_1K_2023-04-14_10-11-19.pkl"
    with open(result_file_name, 'rb') as f:
        spot_tuner =  pickle.load(f)

After the hyperparameter tuning run is finished, the progress of the hyperparameter tuning can be visualized. The following code generates the progress plot shown below.

spot_tuner.plot_progress(log_y=False,
    filename="./figures/" + experiment_name+"_progress.png")

Progress plot. Black dots denote results from the initial design. Red dots illustrate the improvement found by the surrogate model based optimization.
  • Print the results
print(gen_design_table(fun_control=fun_control,
    spot=spot_tuner))
| name         | type   | default   |   lower |   upper |   tuned | transform             |   importance | stars   |
|--------------|--------|-----------|---------|---------|---------|-----------------------|--------------|---------|
| l1           | int    | 5         |     2.0 |     9.0 |     9.0 | transform_power_2_int |        75.62 | **      |
| l2           | int    | 5         |     2.0 |     9.0 |     7.0 | transform_power_2_int |       100.00 | ***     |
| lr_mult      | float  | 1.0       |   0.001 |   0.001 |   0.001 | None                  |         0.00 |         |
| batch_size   | int    | 4         |     1.0 |     4.0 |     1.0 | transform_power_2_int |        50.78 | **      |
| epochs       | int    | 3         |     3.0 |     4.0 |     4.0 | transform_power_2_int |         0.06 |         |
| k_folds      | int    | 1         |     2.0 |     2.0 |     2.0 | None                  |         0.00 |         |
| patience     | int    | 5         |     2.0 |    10.0 |     9.0 | None                  |         0.37 | .       |
| optimizer    | factor | SGD       |     0.0 |     0.0 |     0.0 | None                  |         0.00 |         |
| sgd_momentum | float  | 0.0       |     0.9 |     0.9 |     0.9 | None                  |         0.00 |         |

12.15 Show Variable Importance

spot_tuner.plot_importance(threshold=0.025, filename="./figures/" + experiment_name+"_importance.png")

Variable importance plot, threshold 0.025.

12.16 Get the Tuned Architecture (SPOT Results)

The architecture of the spotPython model can be obtained by the following code:

from spotPython.hyperparameters.values import get_one_core_model_from_X
X = spot_tuner.to_all_dim(spot_tuner.min_X.reshape(1,-1))
model_spot = get_one_core_model_from_X(X, fun_control)
model_spot
Net_CIFAR10(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=10, bias=True)
)
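
As a quick plausibility check (plain PyTorch, not part of the spotPython API), the number of trainable parameters of this network can be computed directly:

n_params = sum(p.numel() for p in model_spot.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params}")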

12.17 Evaluation of the Tuned Architecture

from spotPython.torch.traintest import (
    train_tuned,
    test_tuned,
    )
train_tuned(net=model_spot, train_dataset=train,
        loss_function=fun_control["loss_function"],
        metric=fun_control["metric_torch"],
        shuffle=True,
        device = fun_control["device"],
        path=None,
        task=fun_control["task"],)
Epoch: 1
Batch: 10000. Batch Size: 2. Training Loss (running): 2.299
Loss on hold-out set: 2.2851744856357574
Accuracy on hold-out set: 0.15865
MulticlassAccuracy value on hold-out data: 0.1586499959230423
Epoch: 2
Batch: 10000. Batch Size: 2. Training Loss (running): 2.275
Loss on hold-out set: 2.245528294920921
Accuracy on hold-out set: 0.2211
MulticlassAccuracy value on hold-out data: 0.22110000252723694
Epoch: 3
Batch: 10000. Batch Size: 2. Training Loss (running): 2.228
Loss on hold-out set: 2.187386551630497
Accuracy on hold-out set: 0.25045
MulticlassAccuracy value on hold-out data: 0.2504499852657318
Epoch: 4
Batch: 10000. Batch Size: 2. Training Loss (running): 2.165
Loss on hold-out set: 2.1225022766947745
Accuracy on hold-out set: 0.2609
MulticlassAccuracy value on hold-out data: 0.26089999079704285
Epoch: 5
Batch: 10000. Batch Size: 2. Training Loss (running): 2.103
Loss on hold-out set: 2.0686626995503903
Accuracy on hold-out set: 0.26805
MulticlassAccuracy value on hold-out data: 0.2680499851703644
Epoch: 6
Batch: 10000. Batch Size: 2. Training Loss (running): 2.055
Loss on hold-out set: 2.0292678597331046
Accuracy on hold-out set: 0.28115
MulticlassAccuracy value on hold-out data: 0.2811500132083893
Epoch: 7
Batch: 10000. Batch Size: 2. Training Loss (running): 2.020
Loss on hold-out set: 1.9972160290777683
Accuracy on hold-out set: 0.29175
MulticlassAccuracy value on hold-out data: 0.2917500138282776
Epoch: 8
Batch: 10000. Batch Size: 2. Training Loss (running): 1.991
Loss on hold-out set: 1.9707989422470331
Accuracy on hold-out set: 0.30125
MulticlassAccuracy value on hold-out data: 0.30125001072883606
Epoch: 9
Batch: 10000. Batch Size: 2. Training Loss (running): 1.960
Loss on hold-out set: 1.951032400649786
Accuracy on hold-out set: 0.3083
MulticlassAccuracy value on hold-out data: 0.3082999885082245
Epoch: 10
Batch: 10000. Batch Size: 2. Training Loss (running): 1.945
Loss on hold-out set: 1.9330769111812114
Accuracy on hold-out set: 0.3106
MulticlassAccuracy value on hold-out data: 0.31060001254081726
Epoch: 11
Batch: 10000. Batch Size: 2. Training Loss (running): 1.927
Loss on hold-out set: 1.917974703899026
Accuracy on hold-out set: 0.3155
MulticlassAccuracy value on hold-out data: 0.3154999911785126
Epoch: 12
Batch: 10000. Batch Size: 2. Training Loss (running): 1.913
Loss on hold-out set: 1.9047532189786434
Accuracy on hold-out set: 0.3215
MulticlassAccuracy value on hold-out data: 0.3215000033378601
Epoch: 13
Batch: 10000. Batch Size: 2. Training Loss (running): 1.897
Loss on hold-out set: 1.8918233195841312
Accuracy on hold-out set: 0.32295
MulticlassAccuracy value on hold-out data: 0.32295000553131104
Epoch: 14
Batch: 10000. Batch Size: 2. Training Loss (running): 1.892
Loss on hold-out set: 1.88024304997921
Accuracy on hold-out set: 0.3269
MulticlassAccuracy value on hold-out data: 0.32690000534057617
Epoch: 15
Batch: 10000. Batch Size: 2. Training Loss (running): 1.873
Loss on hold-out set: 1.869371847102046
Accuracy on hold-out set: 0.33075
MulticlassAccuracy value on hold-out data: 0.3307499885559082
Epoch: 16
Batch: 10000. Batch Size: 2. Training Loss (running): 1.867
Loss on hold-out set: 1.8582122519776225
Accuracy on hold-out set: 0.33655
MulticlassAccuracy value on hold-out data: 0.3365499973297119
Returned to Spot: Validation loss: 1.8582122519776225
----------------------------------------------

If path is set to a filename, e.g., path = "model_spot_trained.pt", the weights of the trained model will be saved to this file by train_tuned() and can be loaded from it by test_tuned().

test_tuned(net=model_spot, test_dataset=test,
            shuffle=False,
            loss_function=fun_control["loss_function"],
            metric=fun_control["metric_torch"],
            device = fun_control["device"],
            task=fun_control["task"],)
Loss on hold-out set: 1.8505705657571554
Accuracy on hold-out set: 0.3421
MulticlassAccuracy value on hold-out data: 0.34209999442100525
Final evaluation: Validation loss: 1.8505705657571554
Final evaluation: Validation metric: 0.34209999442100525
----------------------------------------------
(1.8505705657571554, nan, tensor(0.3421, device='mps:0'))

12.18 Cross-validated Evaluations

from spotPython.torch.traintest import evaluate_cv
# modify k_folds:
setattr(model_spot, "k_folds", 10)
df_eval, df_preds, df_metrics = evaluate_cv(net=model_spot,
            dataset=fun_control["data"],
            loss_function=fun_control["loss_function"],
            metric=fun_control["metric_torch"],
            task=fun_control["task"],
            writer=fun_control["writer"],
            writerId="model_spot_cv",
            device = fun_control["device"])
Error in Net_Core. Call to evaluate_cv() failed. err=TypeError("Expected sequence or array-like, got <class 'NoneType'>"), type(err)=<class 'TypeError'>
metric_name = type(fun_control["metric_torch"]).__name__
print(f"loss: {df_eval}, Cross-validated {metric_name}: {df_metrics}")
loss: nan, Cross-validated MulticlassAccuracy: nan
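
The call fails because fun_control["data"] was set to None in Section 12.4 (only the train and test sets were provided), while evaluate_cv() expects the full dataset under this key. A possible fix (a sketch, not executed here) is to supply a dataset, e.g., the concatenation of the train and test sets:

from torch.utils.data import ConcatDataset
# Provide the full dataset under the "data" key before re-running evaluate_cv():
fun_control.update({"data": ConcatDataset([train, test])})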

12.19 Detailed Hyperparameter Plots

filename = "./figures/" + experiment_name
spot_tuner.plot_important_hyperparameter_contour(filename=filename)
l1:  75.61542448415453
l2:  100.0
batch_size:  50.778331725565074
epochs:  0.05529730402580659
patience:  0.36658479214640016

Contour plots.

12.20 Parallel Coordinates Plot

spot_tuner.parallel_plot()

Parallel coordinates plot.

12.21 Plot All Combinations of Hyperparameters

  • Warning: this may take a while.
PLOT_ALL = False
if PLOT_ALL:
    n = spot_tuner.k
    # Use the observed objective values for the z-range of the contour plots
    # (min_z and max_z were undefined in the original snippet):
    min_z = min(spot_tuner.y)
    max_z = max(spot_tuner.y)
    for i in range(n-1):
        for j in range(i+1, n):
            spot_tuner.plot_contour(i=i, j=j, min_z=min_z, max_z=max_z)