deepSSF Package Walkthrough

Author
Affiliation

Queensland University of Technology, CSIRO

Published

May 31, 2026

Abstract

This walkthrough shows how to use the functions in the deepSSF package: loading data, training and testing a model, and generating simulations.

To install and use the package, follow the README in the GitHub repository: https://github.com/swforrest/deepSSF_package. This script is not installed with the package, but it is available in the ‘examples’ folder of the deepSSF_package repo (as a Jupyter notebook or an HTML file).

We use the sample data bundled with the package in the datasets folder: GPS tracking data for a single water buffalo (~10,000 locations) and two spatial layers, NDVI (the 2018-2019 average, which covers the temporal extent of the tracking data) and slope.

Code
import math
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import rasterio

# Use this to load the functions used in this notebook
# You can also just use import deepssf, and then call the functions with deepssf.function_name
import deepssf
from deepssf import (
    ConvJointModel,
    EarlyStopping,
    ModelParams,
    create_gif,
    fit,
    get_device,
    load_environmental_layers,
    make_dataloaders,
    make_optimisers,
    negativeLogLikeLoss,
    filter_steps_by_window,
    prepare_movement_df,
    simulate_trajectory,
    validate_next_step_probs,
)

import pandas as pd
from datetime import datetime                # Date/time utilities
from IPython.display import HTML, display    # For plotting GIFs
from rasterio.plot import show               # For plotting raster layers

print(f"deepssf {deepssf.__version__}")
print(f"torch {torch.__version__}")
deepssf 0.1.2
torch 2.12.0

Set up the script

Set random seed and the device for training the model

If a GPU is available, it is detected and selected here automatically.

Code
# Reproducibility — fix random seeds so results are the same on every run.
torch.manual_seed(42)
np.random.seed(42)

# DEVICE is capitalised because it is a module-level constant (set once, never reassigned).
# get_device() returns 'cuda' if a GPU is available, 'mps' on Apple Silicon,
# or 'cpu' otherwise.  All tensors and the model are moved to this device.
DEVICE = get_device()
print(f"Device: {DEVICE}")
Device: mps

Set script paths

If you have cloned/forked/downloaded the GitHub repo (https://github.com/swforrest/deepSSF_package) you shouldn’t need to change these, but for your own analyses you will want to set these to point to where your data is and where you want outputs saved.

All-caps names below are module-level constants — set once here and referenced throughout the notebook without being reassigned. Python has no built-in constant keyword, so capitalisation is the community convention (PEP 8) to signal “do not change this value”.

Code
DATA_DIR = Path("../src/deepssf/datasets/data")  # location of the bundled sample data
CSV_PATH = DATA_DIR / "buffalo_djelk_id2005.csv"  # GPS tracking data for individual 2005
LAYER_PATHS = {
    "ndvi":  str(DATA_DIR / "ndvi_2005.tif"),   # Normalised Difference Vegetation Index
    "slope": str(DATA_DIR / "slope_2005.tif"),  # terrain slope derived from a DEM
}
OUTPUT_DIR   = Path("outputs")                    # figures and model checkpoints go here
SNAPSHOT_DIR = OUTPUT_DIR / "training_snapshots"  # per-batch prediction-map snapshots
OUTPUT_DIR.mkdir(exist_ok=True)

Import data

The CSV file contains a timestamp and x and y coordinates for each GPS fix. Here the coordinates are already in a projected CRS that is specific to Australia (EPSG:3112), which is the same CRS as the raster data imported in the next step.

Code
raw_df = pd.read_csv(CSV_PATH)
print(f"Raw GPS fixes : {len(raw_df):,}")
print(f"Individuals   : {raw_df['id'].nunique()}")

# Sort by the individual ID and timestamp to ensure correct temporal order
raw_df = raw_df.sort_values(by=["id", "time"]).reset_index(drop=True)

# Convert the 'time' column to the local timezone (here Australia/Darwin)
raw_df["time"] = (
    pd.to_datetime(raw_df["time"], utc=True)
    .dt.tz_convert("Australia/Darwin")
)

raw_df.head()
Raw GPS fixes : 10,297
Individuals   : 1
id time x y
0 2005 2018-07-25 09:34:02+09:30 41941.331695 -1.435875e+06
1 2005 2018-07-25 10:34:23+09:30 41969.310875 -1.435671e+06
2 2005 2018-07-25 11:34:39+09:30 41921.521939 -1.435654e+06
3 2005 2018-07-25 12:34:17+09:30 41779.439594 -1.435601e+06
4 2005 2018-07-25 13:34:39+09:30 41841.203272 -1.435635e+06

load_environmental_layers()

Reads each GeoTIFF listed in LAYER_PATHS, min–max scales every band to [0, 1], and returns:

  • env_layers: dict mapping layer name → 2-D NumPy array (H × W)
  • raster_transform: Affine object that converts pixel (row, col) indices to the projected coordinates (easting, northing) of the raster(s), which should be in the same CRS as the GPS data. This transform is used whenever we need to translate between pixel space and real-world coordinates.
Code
env_layers, raster_transform = load_environmental_layers(LAYER_PATHS)
print("Layers loaded:", list(env_layers.keys()))
for name, arr in env_layers.items():
    print(f"  {name}: shape={arr.shape}, min={arr.min():.3f}, max={arr.max():.3f}")
Layer 'ndvi': min=-0.06614641100168228, max=0.731696367263794 → scaled to [0, 1]
Layer 'slope': min=0.00013558653881773353, max=12.298083305358887 → scaled to [0, 1]
Layers loaded: ['ndvi', 'slope']
  ndvi: shape=(758, 1175), min=0.000, max=1.000
  slope: shape=(758, 1175), min=0.000, max=1.000
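
As a rough illustration of the two ideas above, the sketch below min–max scales a small grid and converts a pixel index to projected coordinates by hand. It is plain Python under stated assumptions: the helper names min_max_scale and pixel_to_xy, the raster origin and the 2 × 3 grid are all hypothetical, the raster is assumed to be north-up, and the real loader may treat NoData values differently.

```python
def min_max_scale(grid):
    """Rescale a 2-D list of numbers so the smallest value maps to 0 and the largest to 1."""
    flat = [v for row in grid for v in row]
    lo, hi = min(flat), max(flat)
    return [[(v - lo) / (hi - lo) for v in row] for row in grid]


def pixel_to_xy(row, col, origin_x, origin_y, pixel_size):
    """Top-left corner of pixel (row, col) for a north-up raster.

    Mirrors a north-up affine transform: x grows with the column index,
    y shrinks as the row index grows.
    """
    x = origin_x + col * pixel_size
    y = origin_y - row * pixel_size
    return x, y


# Hypothetical raw NDVI values (not taken from the package data)
raw = [[-0.05, 0.10, 0.30],
       [0.45, 0.60, 0.70]]
scaled = min_max_scale(raw)
print(scaled[0][0], scaled[1][2])  # smallest and largest raw values after scaling

# Hypothetical raster origin (top-left corner) with 25 m pixels
corner = pixel_to_xy(row=2, col=4, origin_x=30000.0, origin_y=-1420000.0, pixel_size=25)
print(corner)
```

Because the scaling uses each layer's own minimum and maximum, the scaled values are only comparable within a layer, not between layers.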

Plot each spatial covariate with the observed GPS locations overlaid.

This confirms the raster extent covers the animal’s home range and gives a visual sense of the landscape the model will learn from.

Code
fig, axes = plt.subplots(1, 2, figsize=(13, 5))

layer_plot = [
    ("ndvi",  "NDVI (scaled 0–1)",  "YlGn"),
    ("slope", "Slope (scaled 0–1)", "viridis"),
]

for ax, (name, label, cmap) in zip(axes, layer_plot):
    rasterio.plot.show(env_layers[name], transform=raster_transform, ax=ax, cmap=cmap)
    ax.plot(
        raw_df["x"], raw_df["y"],
        "r-o", alpha=0.5, markersize=1, linewidth=0.8, label="GPS trajectory",
    )
    ax.set_xlabel("Easting (m, projected)")
    ax.set_ylabel("Northing (m, projected)")
    ax.set_title(label)
    ax.legend(loc="upper right", markerscale=6)

plt.suptitle("Spatial covariates with observed GPS locations", y=1.01)
plt.tight_layout()
plt.savefig(str(OUTPUT_DIR / "spatial_covariates.png"), dpi=100)
plt.show()

prepare_movement_df()

Converts a fixes dataframe (one row per GPS location) into a step dataframe (one row per movement between consecutive fixes).

  • For each step it computes:
    • dx, dy: displacement in projected CRS units (metres)
    • bearing: direction of travel in radians, computed as atan2(dy, dx) (0 = east, counter-clockwise positive, range −π to π), as the printed values below confirm
    • bearing_tm1: bearing of the previous step, used as a model input so the network can capture directional persistence by estimating a turning angle
    • dt_hour: time elapsed between consecutive fixes (hours)
    • hour_t1, yday_t1: decimal hour-of-day and day-of-year at the step start location
    • *_sin1 / *_cos1: sine/cosine encodings of hour and yday (see SCALAR_COLS below)

Column names follow the convention of the amt package (Signer et al. 2019): the suffix 1_ (x1_, y1_, t1_) marks “start of step” values, and 2_ (x2_, y2_, t2_) marks the corresponding “end of step” values. The last fix for each individual is dropped because no forward step can be formed.

Code
step_df = prepare_movement_df(raw_df)
print(f"Movement steps : {len(step_df):,}  ({len(raw_df) - len(step_df)} dropped — last fix per individual)")
step_df.head()
Movement steps : 10,296  (1 dropped — last fix per individual)
id t1_ x1_ y1_ t2_ x2_ y2_ dx dy bearing bearing_tm1 dt_hour hour_t1 yday_t1 hour_t1_sin1 hour_t1_cos1 yday_t1_sin1 yday_t1_cos1
0 2005 2018-07-25 00:04:02 41941.331695 -1.435875e+06 2018-07-25 01:04:23 41969.310875 -1.435671e+06 27.979180 204.064135 1.434536 0.000000 1.005833 0.066667 206.0 0.017452 0.999848 -0.391358 -0.920239
1 2005 2018-07-25 01:04:23 41969.310875 -1.435671e+06 2018-07-25 02:04:39 41921.521939 -1.435654e+06 -47.788936 16.857110 2.802478 1.434536 1.004444 1.066667 206.0 0.275637 0.961262 -0.391358 -0.920239
2 2005 2018-07-25 02:04:39 41921.521939 -1.435654e+06 2018-07-25 03:04:17 41779.439594 -1.435601e+06 -142.082345 53.568427 2.781049 2.802478 0.993889 2.066667 206.0 0.515038 0.857167 -0.391358 -0.920239
3 2005 2018-07-25 03:04:17 41779.439594 -1.435601e+06 2018-07-25 04:04:39 41841.203272 -1.435635e+06 61.763677 -34.322938 -0.507220 2.781049 1.006111 3.066667 206.0 0.719340 0.694658 -0.391358 -0.920239
4 2005 2018-07-25 04:04:39 41841.203272 -1.435635e+06 2018-07-25 05:04:27 41655.463332 -1.435604e+06 -185.739939 31.003534 2.976198 -0.507220 0.996667 4.066667 206.0 0.874620 0.484810 -0.391358 -0.920239
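
The step columns can be reproduced by hand for a single pair of fixes. The sketch below uses the coordinates and timestamps printed in the first row above and assumes the bearing is atan2(dy, dx), which agrees with the printed value; the helper name step_metrics is hypothetical and is not part of deepssf, and the y coordinate is rounded because the table only prints it to seven significant figures.

```python
import math
from datetime import datetime


def step_metrics(x1, y1, t1, x2, y2, t2):
    """Displacement, bearing and elapsed time for one step between two fixes."""
    dx, dy = x2 - x1, y2 - y1
    bearing = math.atan2(dy, dx)                 # radians, 0 = east, counter-clockwise positive
    dt_hour = (t2 - t1).total_seconds() / 3600   # elapsed time in hours
    return dx, dy, bearing, dt_hour


t1 = datetime(2018, 7, 25, 0, 4, 2)
t2 = datetime(2018, 7, 25, 1, 4, 23)
y1 = -1435875.0  # rounded start northing (assumption, see lead-in)

dx, dy, bearing, dt_hour = step_metrics(
    41941.331695, y1, t1,
    41969.310875, y1 + 204.064135, t2,
)
print(round(dx, 3), round(dy, 3), round(bearing, 6), round(dt_hour, 6))
```

A mostly northward step like this one has a bearing close to π/2 under this convention, which is a quick sanity check when reading the bearing column.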

SCALAR_COLS

These are non-spatial inputs. For processing with convolutional layers, they are converted into spatial windows (rasters/arrays), and stacked with the truly spatial covariates.

The model receives a fixed-size spatial window centred on each step’s start location. To let the network express time-varying habitat preferences (e.g. different nocturnal vs diurnal use), scalar covariates are tiled into additional channels of the same H × W size, so every pixel “knows” the current time context when habitat suitability is estimated.

  • hour_t1_sin1 / hour_t1_cos1
    • Sine and cosine of (2π × hour / 24), encoding time-of-day as a circular variable. The circular encoding means 23:00 and 01:00 are treated as adjacent — a plain hour value would have an artificial discontinuity at midnight.
  • yday_t1_sin1 / yday_t1_cos1
    • Same circular encoding for day-of-year (1–365), capturing seasonal variation in habitat use (e.g. dry-season vs wet-season movements).
  • dt_hour
    • Elapsed time between consecutive fixes (hours). Because GPS collars do not always fix at a perfectly regular rate, including dt_hour lets the movement kernel rescale expected displacement proportionally to the actual time interval rather than assuming a fixed fix rate. Note: dt_hour must be supplied at simulation time as well, so that simulate_trajectory() produces inputs consistent with training.
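
The circular encoding described above can be checked with a few lines of plain Python. This is a minimal sketch: circular_encode is a hypothetical helper, and the day-of-year period of 365.25 is an assumption inferred from the values printed in step_df rather than something confirmed by the package documentation.

```python
import math


def circular_encode(value, period):
    """Map a periodic value onto the unit circle as (sin, cos)."""
    angle = 2 * math.pi * value / period
    return math.sin(angle), math.cos(angle)


# 23:00 and 01:00 are two hours apart on the clock, and their encodings are close too
s23, c23 = circular_encode(23, 24)
s01, c01 = circular_encode(1, 24)
gap_midnight = math.hypot(s23 - s01, c23 - c01)

# The same two-hour gap around midday gives the same distance in encoded space
s11, c11 = circular_encode(11, 24)
s13, c13 = circular_encode(13, 24)
gap_midday = math.hypot(s11 - s13, c11 - c13)
print(round(gap_midnight, 6), round(gap_midday, 6))

# First row of step_df: hour_t1 = 0.066667, yday_t1 = 206
hour_sin, hour_cos = circular_encode(0.066667, 24)
yday_sin, yday_cos = circular_encode(206, 365.25)  # period is an assumption (see lead-in)
print(round(hour_sin, 6), round(hour_cos, 6), round(yday_sin, 6), round(yday_cos, 6))
```

Both the sine and the cosine are needed: either one alone maps two different times of day to the same value.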
Code
SCALAR_COLS = ["hour_t1_sin1",
               "hour_t1_cos1", 
               "yday_t1_sin1", 
               "yday_t1_cos1", 
               "dt_hour"]

WINDOW_SIZE

The side length of the local availability window in cells/pixels. 101 pixels × 25 m = 2525 m across, so the window extends about 1262.5 m (50 pixels plus half of the centre pixel) from the centre location to each edge.

Code
WINDOW_SIZE = 101  

PIXEL_SIZE

Raster resolution in metres — must match the loaded GeoTIFFs

Code
PIXEL_SIZE  = 25    

BATCH_SIZE

Number of training samples (movement steps) per training mini-batch

Code
BATCH_SIZE  = 32    

OUTPUT_CHANNELS

Number of convolutional filters (and therefore feature maps) for each convolution layer. More convolutional filters can provide more flexible functions, but will increase the number of parameters and slow down training times.

Code
OUTPUT_CHANNELS = 4

Pre-calculate the flattened CNN output size

This value is the input size of the fully-connected layers in the movement subnetwork; it is calculated here so that ModelParams can be set without hard-coding it.

  • The formula mirrors the layers inside ConvJointModel:
    • conv (kernel 3, stride 1, padding 1) keeps the spatial dimension unchanged at each layer
    • maxpool (k=2, s=2) halves the spatial dimension (floor division) at each layer
    • the number of conv + max pooling blocks (3) is set in the model code and isn’t currently editable in this script
Code
dim = WINDOW_SIZE
for _ in range(3):
    dim = math.floor((dim + 2 * 1 - 3) / 1 + 1)  # conv (pad=1 preserves dim)
    dim = math.floor((dim - 2) / 2 + 1)           # maxpool k=2, s=2
DENSE_DIM = OUTPUT_CHANNELS * dim * dim
print(f"After 3× conv+maxpool: dim={dim}  →  dense_dim_in_all={DENSE_DIM}")
After 3× conv+maxpool: dim=12  →  dense_dim_in_all=576
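
To see where dim=12 comes from, it can help to trace the spatial size after each block. The sketch below wraps the same two formulas in a hypothetical helper, conv_pool_trace, whose defaults match the kernel, stride and padding values used in this walkthrough; it is only a convenience for checking other window sizes before training and is not part of deepssf.

```python
import math


def conv_pool_trace(dim, n_blocks=3, kernel=3, stride=1, padding=1,
                    pool_kernel=2, pool_stride=2):
    """Return the spatial size before the first block and after each conv + max-pool block."""
    trace = [dim]
    for _ in range(n_blocks):
        dim = math.floor((dim + 2 * padding - kernel) / stride + 1)  # convolution
        dim = math.floor((dim - pool_kernel) / pool_stride + 1)      # max pooling
        trace.append(dim)
    return trace


trace_101 = conv_pool_trace(101)
trace_51 = conv_pool_trace(51)
print(trace_101, trace_51)

# Flattened size for 4 output channels, as in DENSE_DIM above
dense_dim = 4 * trace_101[-1] ** 2
print(dense_dim)
```

Each pooling step discards the odd remainder, which is why 101 → 50 → 25 → 12 rather than 12.625.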

filter_steps_by_window()

Removes steps whose straight-line displacement exceeds half the spatial window (WINDOW_SIZE × PIXEL_SIZE / 2 metres). Such steps would place the true next location outside the predicted probability map, causing an index error in the loss function.

Code
step_df = filter_steps_by_window(step_df, window_size=WINDOW_SIZE, pixel_size=PIXEL_SIZE)
print(f"Steps after window filter: {len(step_df):,}")
Steps after window filter: 10,101
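
The filtering rule can be sketched as a simple distance check. The example below assumes the comparison is on Euclidean step length with an inclusive limit, which is one reading of the description above; within_window is a hypothetical helper, and the third step is made up to show a rejection.

```python
import math

WINDOW_SIZE = 101  # pixels
PIXEL_SIZE = 25    # metres per pixel
max_displacement = WINDOW_SIZE * PIXEL_SIZE / 2  # half the window width in metres


def within_window(dx, dy, limit=max_displacement):
    """True if the straight-line step length fits inside half the window."""
    return math.hypot(dx, dy) <= limit


steps = [
    (27.979180, 204.064135),    # first row of step_df
    (-142.082345, 53.568427),   # third row of step_df
    (1500.0, -200.0),           # hypothetical long step
]
kept = [s for s in steps if within_window(*s)]
print(max_displacement, len(kept))
```

In this dataset the filter removes 195 of 10,296 steps, so a 101-pixel window covers the large majority of hourly movements.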

make_dataloaders()

Splits step_df chronologically into train / val / test subsets (here 80 / 10 / 10 %) and returns a PyTorch DataLoader for each.

  • Each DataLoader yields tuples of:
    • x1: spatial window tensor (batch × n_spatial_channels × H × W)
      • One channel per entry in LAYER_PATHS, cropped around each step’s start pixel.
    • x2: scalar covariate tensor (batch × len(SCALAR_COLS))
      • Raw scalar values; broadcast into spatial maps inside the model.
    • x3: previous bearing tensor (batch × 1)
      • Conditions the movement kernel on the direction of the prior step.
    • (px2, py2): pixel indices of the true next location within the window
      • Used by the loss function to look up the predicted probability at the observed destination.

The layer_paths argument causes the function to re-read the rasters internally, so the DataLoader is independent of the env_layers dict loaded above.

Code
dl_train, dl_val, dl_test = make_dataloaders(
    layer_paths=LAYER_PATHS,
    window_size=WINDOW_SIZE,
    batch_size=BATCH_SIZE,
    train_split=0.8,
    val_split=0.1,
    scalar_cols=SCALAR_COLS,
    df=step_df,
)
print(f"Train batches : {len(dl_train)}")
print(f"Val   batches : {len(dl_val)}")
print(f"Test  batches : {len(dl_test)}")
Layer 'ndvi': min=-0.06614641100168228, max=0.731696367263794 → scaled to [0, 1]
Layer 'slope': min=0.00013558653881773353, max=12.298083305358887 → scaled to [0, 1]
/Users/scottforrest/Library/Mobile Documents/com~apple~CloudDocs/github_repos/deepSSF_package/src/deepssf/data.py:366: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /Users/runner/work/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:219.)
  self.scalar_to_grid_data = torch.from_numpy(
Train batches : 253
Val   batches : 32
Test  batches : 32
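
The batch counts above follow from a chronological split of the 10,101 filtered steps. The sketch below reproduces them under two assumptions: that the split sizes are rounded to the nearest integer, and that the final partial batch is kept. chronological_split is a hypothetical helper, not the package's implementation.

```python
import math


def chronological_split(n_steps, train_split, val_split):
    """Sizes of consecutive train / val / test blocks (rows are assumed sorted by time)."""
    n_train = round(n_steps * train_split)
    n_val = round(n_steps * val_split)
    n_test = n_steps - n_train - n_val
    return n_train, n_val, n_test


n_train, n_val, n_test = chronological_split(10_101, 0.8, 0.1)
batches = [math.ceil(n / 32) for n in (n_train, n_val, n_test)]
print(n_train, n_val, n_test)  # sizes of the three blocks
print(batches)                 # number of mini-batches per DataLoader
```

Splitting by time rather than at random means the validation and test steps come from later in the tracking period than the training steps, so they act as a check on how well the model carries forward in time.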

Inspect one batch to verify tensor shapes match ModelParams expectations

x1 has one channel per spatial layer (ndvi + slope = 2). The scalar covariates in x2 are broadcast to full H × W maps inside ConvJointModel and concatenated with x1 there, giving the CNN 2 + len(SCALAR_COLS) = 7 input channels total.

px2 and py2 are the targets the model is trying to predict: the pixel indices of the next location within the window, where the centre cell (the step’s start location) is (50, 50).

next(iter(dl_train)) creates a fresh iterator and returns its first batch. Because the training batches are randomly shuffled, running the cell again returns different training samples.

Code
x1, x2, x3, (px2, py2), _ = next(iter(dl_train))
print(f"Spatial   x1 : {x1.shape}   (batch × channels × H × W)")
print(f"Scalars   x2 : {x2.shape}   (batch × len(SCALAR_COLS))")
print(f"Bearing   x3 : {x3.shape}   (batch × 1)")
print(f"Next-step px2: {px2[:4].tolist()}  py2: {py2[:4].tolist()}")
Spatial   x1 : torch.Size([32, 2, 101, 101])   (batch × channels × H × W)
Scalars   x2 : torch.Size([32, 5])   (batch × len(SCALAR_COLS))
Bearing   x3 : torch.Size([32, 1])   (batch × 1)
Next-step px2: [54, 51, 50, 50]  py2: [51, 53, 50, 50]
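
The broadcasting of scalars into image channels can be pictured with nested lists. This is a minimal sketch of the idea only, assuming each scalar simply fills a constant H × W grid; the helpers tile_scalar and stack_channels are hypothetical, the 3 × 3 window values are made up, and the real model does this with tensors inside ConvJointModel.

```python
def tile_scalar(value, height, width):
    """Repeat one scalar over an H x W grid so it can be stacked as an image channel."""
    return [[value] * width for _ in range(height)]


def stack_channels(spatial_channels, scalars):
    """Append one constant channel per scalar to the list of spatial channels."""
    height, width = len(spatial_channels[0]), len(spatial_channels[0][0])
    return spatial_channels + [tile_scalar(s, height, width) for s in scalars]


# Hypothetical 3 x 3 window with two spatial layers (values are made up)
ndvi = [[0.2, 0.3, 0.4], [0.5, 0.6, 0.7], [0.1, 0.2, 0.3]]
slope = [[0.0, 0.1, 0.0], [0.2, 0.1, 0.0], [0.3, 0.2, 0.1]]

# Scalars from the first row of step_df, in SCALAR_COLS order
scalars = [0.017452, 0.999848, -0.391358, -0.920239, 1.005833]

stacked = stack_channels([ndvi, slope], scalars)
print(len(stacked), len(stacked[0]), len(stacked[0][0]))  # channels, height, width
```

With two spatial layers and five scalars the stack has seven channels, matching input_channels = 2 + len(SCALAR_COLS).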

ModelParams

A thin, dict-like wrapper around the supplied dictionary that groups all architecture hyper-parameters so they can be passed to ConvJointModel and logged together.

This is where the parameters that are passed to the model are set.

Code
params = ModelParams({
    "batch_size":                BATCH_SIZE,
    "image_dim":                 WINDOW_SIZE,           # spatial window side length (pixels)
    "pixel_size":                PIXEL_SIZE,            # metres per pixel
    "dim_in_nonspatial_to_grid": len(SCALAR_COLS),      # scalar channels tiled into the window
    "dense_dim_in_nonspatial":   len(SCALAR_COLS),      # scalar inputs to the dense MLP branch
    "dense_dim_hidden":          64,                    # hidden units in the scalar MLP
    "dense_dim_in_all":          DENSE_DIM,             # flattened CNN output size (computed above)
    "input_channels":            2 + len(SCALAR_COLS),  # 2 env layers + scalar-to-grid channels
    "output_channels":           OUTPUT_CHANNELS,       # feature maps out of each conv layer
    "kernel_size":               3,                     # conv kernel spatial size
    "stride":                    1,
    "kernel_size_mp":            2,                     # max-pool kernel size
    "stride_mp":                 2,
    "padding":                   1,                     # same-padding to preserve spatial dims through conv
    "num_movement_params":       12,                    # parameters of the movement kernel
    "dropout":                   0.0,
    "device":                    DEVICE,
})

ConvJointModel

The deepSSF model! A joint habitat–movement model with two heads:

  1. Habitat head: a small CNN applied to the spatial window (env layers + tiled scalars) that outputs a per-pixel log-suitability map.
  2. Movement head: a dense MLP (fully connected layers) that takes the scalar covariates and the previous bearing, and produces parameters for a parametric movement kernel (a mixture of wrapped distributions centred on the current location).

The two log-probability surfaces are summed pixel-wise (equivalent to multiplying the two probabilities) to give the joint predicted probability map for the animal’s next step.

Code
model = ConvJointModel(params).to(DEVICE)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters : {n_params:,}")
Model parameters : 43,009
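
Combining the two heads on the log scale can be sketched on a tiny grid. The example below adds two made-up 3 × 3 log surfaces and then renormalises with a log-sum-exp so the result is a proper probability map; the renormalisation step and the helper names log_softmax_2d and combine are assumptions for illustration, not a description of the package internals.

```python
import math


def log_softmax_2d(grid):
    """Normalise a 2-D grid of log-scores so that exp(values) sums to 1."""
    flat = [v for row in grid for v in row]
    peak = max(flat)
    log_total = peak + math.log(sum(math.exp(v - peak) for v in flat))  # log-sum-exp
    return [[v - log_total for v in row] for row in grid]


def combine(habitat_log, movement_log):
    """Add the two log surfaces cell by cell, then renormalise."""
    summed = [[h + m for h, m in zip(h_row, m_row)]
              for h_row, m_row in zip(habitat_log, movement_log)]
    return log_softmax_2d(summed)


# Hypothetical 3 x 3 log surfaces (values are made up)
habitat_log = [[-2.0, -1.0, -2.0], [-1.5, -0.5, -1.5], [-3.0, -2.0, -3.0]]
movement_log = [[-3.0, -2.0, -3.0], [-2.0, -0.2, -2.0], [-3.0, -2.0, -3.0]]

joint = combine(habitat_log, movement_log)
total = sum(math.exp(v) for row in joint for v in row)
print(round(total, 6), round(math.exp(joint[1][1]), 3))
```

Working in log space avoids numerical underflow when many small probabilities are multiplied, which is why the sum of logs is preferred over a product of probabilities.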

negativeLogLikeLoss

Scores each predicted probability map against the true next-step pixel index; minimising it pushes probability mass toward observed locations.

Code
loss_fn = negativeLogLikeLoss(reduction="mean")
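
The idea behind the loss can be sketched in plain Python: look up the log-probability at the observed pixel and negate it. The function nll_loss below is a hypothetical stand-in, and the (column, row) ordering of (px, py) is an assumption. It also gives a useful reference value, the loss of a completely uniform map over a 101 × 101 window.

```python
import math


def nll_loss(log_prob_maps, targets, reduction="mean"):
    """Negative log-probability at each observed next-step pixel.

    log_prob_maps: list of 2-D grids of log-probabilities
    targets: list of (px, py) pixel indices (column, row) of the observed next location
    """
    losses = [-grid[py][px] for grid, (px, py) in zip(log_prob_maps, targets)]
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    return losses


# A uniform map over a 101 x 101 window gives the same loss wherever the animal goes
uniform = [[-math.log(101 * 101)] * 101 for _ in range(101)]
baseline = nll_loss([uniform], [(54, 51)])
print(round(baseline, 4))
```

The uniform baseline is ln(101²) ≈ 9.23, so a model that has learned anything useful should score below this value on average.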

make_optimisers()

Creates two Adam optimisers with separate learning rates:

  • lr_habitat
    • learning rate for the CNN (habitat) head; typically higher
  • lr_movement
    • learning rate for the movement-kernel MLP; kept lower for stability

It also returns a ReduceLROnPlateau scheduler for each optimiser, which multiplies the learning rate by scheduler_factor (default = 0.1) when the validation loss has not improved for scheduler_patience epochs (default = 5; set to 3 in the call below).

Code
optimisers, schedulers = make_optimisers(
    model, lr_habitat=1e-3, lr_movement=1e-5, scheduler_patience=3, scheduler_factor=0.1
)
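
The plateau rule is easy to follow in a stripped-down form. The class below is a simplified stand-in for PyTorch's ReduceLROnPlateau in "min" mode: it ignores the improvement threshold and cooldown options, and the validation losses are made up. PlateauScheduler is a hypothetical name used only for this sketch.

```python
class PlateauScheduler:
    """Simplified stand-in for torch.optim.lr_scheduler.ReduceLROnPlateau (min mode)."""

    def __init__(self, lr, factor=0.1, patience=3):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # PyTorch reduces the rate once the count of bad epochs exceeds patience
        if self.bad_epochs > self.patience:
            self.lr *= self.factor
            self.bad_epochs = 0
        return self.lr


habitat = PlateauScheduler(lr=1e-3)
movement = PlateauScheduler(lr=1e-5)

# Hypothetical validation losses: improvement for three epochs, then a plateau
val_losses = [6.40, 6.28, 6.17, 6.20, 6.19, 6.18, 6.21, 6.16]
for loss in val_losses:
    lr_h = habitat.step(loss)
    lr_m = movement.step(loss)
print(lr_h, lr_m)
```

With patience=3 the rate drops on the fourth consecutive epoch without improvement, so the scheduler acts before the early-stopping patience of 5 used later in this walkthrough is exhausted.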

EarlyStopping

Monitors validation loss and saves the best model weights to disk whenever a new minimum is reached. Training stops automatically if the loss does not improve for patience consecutive epochs, preventing overfitting.

Code
early_stop = EarlyStopping(
    patience=5, verbose=True, path=str(OUTPUT_DIR / "best_model.pt")
)
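
The early-stopping logic can be sketched without any model at all. The class below is a hypothetical, simplified version that only tracks the best loss and a counter; the real EarlyStopping also writes the model weights to disk at each new minimum, and the validation losses here are made up.

```python
class SimpleEarlyStopping:
    """Track the best validation loss and flag when training should stop."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best = float("inf")
        self.best_epoch = None
        self.counter = 0
        self.should_stop = False

    def __call__(self, val_loss, epoch):
        if val_loss < self.best:
            # New minimum: this is where the real class would save the weights
            self.best = val_loss
            self.best_epoch = epoch
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True


stopper = SimpleEarlyStopping(patience=5)

# Hypothetical validation losses: the minimum is reached at epoch 3
val_losses = [6.40, 6.28, 6.17, 6.20, 6.19, 6.18, 6.21, 6.22, 6.25, 6.30]
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    stopper(loss, epoch)
    if stopper.should_stop:
        stopped_at = epoch
        break
print(stopper.best_epoch, stopper.best, stopped_at)
```

Because the weights are saved at the best epoch rather than the last one, the checkpoint on disk corresponds to the minimum validation loss even though training continues for several more epochs.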

fit()

Runs the full training loop.

  • Inputs:
    • model
      • the deepSSF model instance created with ConvJointModel above
    • image_trim_pixels
      • how many outer rings of pixels to trim from the output images before saving. The padding used in the convolution layers creates edge artefacts that would otherwise distort the colour scaling.
    • window_size
      • side length of the spatial window in pixels (WINDOW_SIZE)
    • dl_train
      • training dataloader (contains the training data)
    • dl_val
      • validation dataloader (contains the validation data used by the learning-rate schedulers and early stopping)
    • loss_fn
      • loss function
    • optimisers
      • optimisers used to update the model weights
    • schedulers
      • learning-rate schedulers (used to update the learning rates)
    • n_epochs
      • maximum number of training epochs (training can stop earlier through early stopping)
    • early_stopping
      • tracks the validation loss and stops training if it doesn’t improve for a given number of epochs (set using patience in the EarlyStopping definition)
    • snapshot_dir
      • where to save images of the training progress
    • snapshot_item
      • which sample (movement step) from the validation dataset to run the model on to generate the movement, habitat selection and next-step probability surfaces
  • Outputs:
    • history dict with keys train_losses and val_losses (one scalar value per completed epoch)

Every epoch, a plot of the training and validation loss and the predicted probability maps (movement, habitat selection and next-step) is written to SNAPSHOT_DIR; these images will be turned into a GIF in the next cell.

Code
history = fit(
    model=model,
    image_trim_pixels=4,
    window_size=WINDOW_SIZE,
    dl_train=dl_train,
    dl_val=dl_val,
    loss_fn=loss_fn,
    optimisers=optimisers,
    schedulers=schedulers,
    n_epochs=50,
    early_stopping=early_stop,
    snapshot_dir=str(SNAPSHOT_DIR),
    snapshot_item=250,
)

Epoch 1/50
loss:        6.513348  [   32/ 8081]
loss:        6.319979  [  352/ 8081]
loss:        6.238961  [  672/ 8081]
loss:        6.757659  [  992/ 8081]
loss:        6.764330  [ 1312/ 8081]
loss:        6.599177  [ 1632/ 8081]
loss:        6.631302  [ 1952/ 8081]
loss:        7.107268  [ 2272/ 8081]
loss:        6.204191  [ 2592/ 8081]
loss:        6.667402  [ 2912/ 8081]
loss:        6.679033  [ 3232/ 8081]
loss:        6.146152  [ 3552/ 8081]
loss:        6.724045  [ 3872/ 8081]
loss:        6.560881  [ 4192/ 8081]
loss:        6.125783  [ 4512/ 8081]
loss:        6.735407  [ 4832/ 8081]
loss:        6.529395  [ 5152/ 8081]
loss:        6.726855  [ 5472/ 8081]
loss:        6.412127  [ 5792/ 8081]
loss:        6.422827  [ 6112/ 8081]
loss:        6.490239  [ 6432/ 8081]
loss:        6.592959  [ 6752/ 8081]
loss:        6.916759  [ 7072/ 8081]
loss:        6.458790  [ 7392/ 8081]
loss:        6.815647  [ 7712/ 8081]
loss:        6.535911  [ 8032/ 8081]

Avg training loss:        6.512734
Val loss: 6.399613  (hab: 9.149054, mov: 6.424873)
Validation loss decreased (inf → 6.399613). Saving model…

Epoch 2/50
loss:        5.898046  [   32/ 8081]
loss:        6.333018  [  352/ 8081]
loss:        6.913955  [  672/ 8081]
loss:        6.317987  [  992/ 8081]
loss:        6.061097  [ 1312/ 8081]
loss:        6.165645  [ 1632/ 8081]
loss:        6.639863  [ 1952/ 8081]
loss:        5.940703  [ 2272/ 8081]
loss:        6.259355  [ 2592/ 8081]
loss:        6.342093  [ 2912/ 8081]
loss:        6.987063  [ 3232/ 8081]
loss:        6.345895  [ 3552/ 8081]
loss:        6.048259  [ 3872/ 8081]
loss:        6.225343  [ 4192/ 8081]
loss:        6.254364  [ 4512/ 8081]
loss:        6.314987  [ 4832/ 8081]
loss:        6.773647  [ 5152/ 8081]
loss:        6.380954  [ 5472/ 8081]
loss:        7.027800  [ 5792/ 8081]
loss:        6.871723  [ 6112/ 8081]
loss:        6.887583  [ 6432/ 8081]
loss:        6.625319  [ 6752/ 8081]
loss:        6.337022  [ 7072/ 8081]
loss:        6.670538  [ 7392/ 8081]
loss:        6.062923  [ 7712/ 8081]
loss:        5.442103  [ 8032/ 8081]

Avg training loss:        6.400464
Val loss: 6.279017  (hab: 9.092908, mov: 6.338993)
Validation loss decreased (6.399613 → 6.279017). Saving model…

Epoch 3/50
loss:        5.881618  [   32/ 8081]
loss:        6.321713  [  352/ 8081]
loss:        6.744115  [  672/ 8081]
loss:        6.266128  [  992/ 8081]
loss:        5.953803  [ 1312/ 8081]
loss:        5.952302  [ 1632/ 8081]
loss:        6.486915  [ 1952/ 8081]
loss:        6.306614  [ 2272/ 8081]
loss:        6.361903  [ 2592/ 8081]
loss:        6.050750  [ 2912/ 8081]
loss:        5.978547  [ 3232/ 8081]
loss:        6.266454  [ 3552/ 8081]
loss:        6.571918  [ 3872/ 8081]
loss:        6.279250  [ 4192/ 8081]
loss:        6.508870  [ 4512/ 8081]
loss:        6.340617  [ 4832/ 8081]
loss:        6.716526  [ 5152/ 8081]
loss:        5.994553  [ 5472/ 8081]
loss:        6.505113  [ 5792/ 8081]
loss:        5.982091  [ 6112/ 8081]
loss:        5.923322  [ 6432/ 8081]
loss:        5.906955  [ 6752/ 8081]
loss:        6.889729  [ 7072/ 8081]
loss:        6.473563  [ 7392/ 8081]
loss:        6.630983  [ 7712/ 8081]
loss:        6.277025  [ 8032/ 8081]

Avg training loss:        6.299221
Val loss: 6.166172  (hab: 9.073013, mov: 6.243046)
Validation loss decreased (6.279017 → 6.166172). Saving model…

Epoch 4/50
loss:        6.215015  [   32/ 8081]
loss:        5.581359  [  352/ 8081]
loss:        6.195288  [  672/ 8081]
loss:        5.711836  [  992/ 8081]
loss:        6.514735  [ 1312/ 8081]
loss:        6.181086  [ 1632/ 8081]
loss:        6.923858  [ 1952/ 8081]
loss:        6.033204  [ 2272/ 8081]
loss:        6.238012  [ 2592/ 8081]
loss:        6.453608  [ 2912/ 8081]
loss:        7.261230  [ 3232/ 8081]
loss:        6.563379  [ 3552/ 8081]
loss:        6.737093  [ 3872/ 8081]
loss:        6.068009  [ 4192/ 8081]
loss:        6.160167  [ 4512/ 8081]
loss:        5.942147  [ 4832/ 8081]
loss:        5.894120  [ 5152/ 8081]
loss:        5.960709  [ 5472/ 8081]
loss:        5.260208  [ 5792/ 8081]
loss:        5.780227  [ 6112/ 8081]
loss:        6.615031  [ 6432/ 8081]
loss:        6.715469  [ 6752/ 8081]
loss:        5.114902  [ 7072/ 8081]
loss:        5.679732  [ 7392/ 8081]
loss:        6.047050  [ 7712/ 8081]
loss:        5.955266  [ 8032/ 8081]

Avg training loss:        6.195487
Val loss: 6.060345  (hab: 9.058475, mov: 6.141308)
Validation loss decreased (6.166172 → 6.060345). Saving model…

Epoch 5/50
loss:        6.658419  [   32/ 8081]
loss:        6.726235  [  352/ 8081]
loss:        6.370027  [  672/ 8081]
loss:        6.625559  [  992/ 8081]
loss:        6.660704  [ 1312/ 8081]
loss:        5.754659  [ 1632/ 8081]
loss:        6.008422  [ 1952/ 8081]
loss:        7.062066  [ 2272/ 8081]
loss:        5.216234  [ 2592/ 8081]
loss:        6.090228  [ 2912/ 8081]
loss:        5.673282  [ 3232/ 8081]
loss:        6.185088  [ 3552/ 8081]
loss:        5.630754  [ 3872/ 8081]
loss:        5.691380  [ 4192/ 8081]
loss:        6.619287  [ 4512/ 8081]
loss:        5.246547  [ 4832/ 8081]
loss:        6.479588  [ 5152/ 8081]
loss:        5.380687  [ 5472/ 8081]
loss:        5.967161  [ 5792/ 8081]
loss:        6.535676  [ 6112/ 8081]
loss:        6.005552  [ 6432/ 8081]
loss:        6.613456  [ 6752/ 8081]
loss:        6.336710  [ 7072/ 8081]
loss:        5.654268  [ 7392/ 8081]
loss:        5.885343  [ 7712/ 8081]
loss:        5.609878  [ 8032/ 8081]

Avg training loss:        6.098873
Val loss: 5.955629  (hab: 9.030585, mov: 6.045432)
Validation loss decreased (6.060345 → 5.955629). Saving model…

Epoch 6/50
loss:        5.822225  [   32/ 8081]
loss:        5.761471  [  352/ 8081]
loss:        6.183834  [  672/ 8081]
loss:        6.046104  [  992/ 8081]
loss:        6.002414  [ 1312/ 8081]
loss:        5.816042  [ 1632/ 8081]
loss:        5.728461  [ 1952/ 8081]
loss:        6.272363  [ 2272/ 8081]
loss:        6.305082  [ 2592/ 8081]
loss:        6.311210  [ 2912/ 8081]
loss:        5.821839  [ 3232/ 8081]
loss:        6.511950  [ 3552/ 8081]
loss:        5.849454  [ 3872/ 8081]
loss:        6.237171  [ 4192/ 8081]
loss:        5.767322  [ 4512/ 8081]
loss:        6.190836  [ 4832/ 8081]
loss:        6.446795  [ 5152/ 8081]
loss:        5.346597  [ 5472/ 8081]
loss:        6.349314  [ 5792/ 8081]
loss:        5.806986  [ 6112/ 8081]
loss:        6.208708  [ 6432/ 8081]
loss:        6.476559  [ 6752/ 8081]
loss:        5.439232  [ 7072/ 8081]
loss:        6.363841  [ 7392/ 8081]
loss:        5.783147  [ 7712/ 8081]
loss:        6.551175  [ 8032/ 8081]

Avg training loss:        6.016759
Val loss: 5.883277  (hab: 9.037213, mov: 5.961697)
Validation loss decreased (5.955629 → 5.883277). Saving model…

Epoch 7/50
loss:        5.472890  [   32/ 8081]
loss:        6.087283  [  352/ 8081]
loss:        5.947907  [  672/ 8081]
loss:        5.972356  [  992/ 8081]
loss:        6.231291  [ 1312/ 8081]
loss:        6.414335  [ 1632/ 8081]
loss:        6.680422  [ 1952/ 8081]
loss:        6.363143  [ 2272/ 8081]
loss:        5.393204  [ 2592/ 8081]
loss:        5.106235  [ 2912/ 8081]
loss:        5.696716  [ 3232/ 8081]
loss:        6.890513  [ 3552/ 8081]
loss:        6.610801  [ 3872/ 8081]
loss:        5.626327  [ 4192/ 8081]
loss:        5.699317  [ 4512/ 8081]
loss:        6.330873  [ 4832/ 8081]
loss:        5.627152  [ 5152/ 8081]
loss:        6.197600  [ 5472/ 8081]
loss:        6.725983  [ 5792/ 8081]
loss:        5.723792  [ 6112/ 8081]
loss:        6.072978  [ 6432/ 8081]
loss:        5.844558  [ 6752/ 8081]
loss:        5.957417  [ 7072/ 8081]
loss:        5.529443  [ 7392/ 8081]
loss:        5.941032  [ 7712/ 8081]
loss:        6.097619  [ 8032/ 8081]

Avg training loss:        5.942231
Val loss: 5.812197  (hab: 9.032493, mov: 5.893594)
Validation loss decreased (5.883277 → 5.812197). Saving model…

Epoch 8/50
loss:        5.744102  [   32/ 8081]
loss:        5.656890  [  352/ 8081]
loss:        6.292316  [  672/ 8081]
loss:        6.252225  [  992/ 8081]
loss:        5.939830  [ 1312/ 8081]
loss:        6.296592  [ 1632/ 8081]
loss:        5.747702  [ 1952/ 8081]
loss:        5.712788  [ 2272/ 8081]
loss:        5.741662  [ 2592/ 8081]
loss:        5.285065  [ 2912/ 8081]
loss:        5.173178  [ 3232/ 8081]
loss:        6.903015  [ 3552/ 8081]
loss:        5.695376  [ 3872/ 8081]
loss:        5.744624  [ 4192/ 8081]
loss:        5.954119  [ 4512/ 8081]
loss:        4.970155  [ 4832/ 8081]
loss:        6.148893  [ 5152/ 8081]
loss:        6.026309  [ 5472/ 8081]
loss:        5.662178  [ 5792/ 8081]
loss:        6.148443  [ 6112/ 8081]
loss:        5.601367  [ 6432/ 8081]
loss:        5.772060  [ 6752/ 8081]
loss:        4.635138  [ 7072/ 8081]
loss:        5.972570  [ 7392/ 8081]
loss:        5.541584  [ 7712/ 8081]
loss:        5.249794  [ 8032/ 8081]

Avg training loss:        5.888501
Val loss: 5.770297  (hab: 9.054404, mov: 5.838027)
Validation loss decreased (5.812197 → 5.770297). Saving model…

Epoch 9/50
loss:        5.494358  [   32/ 8081]
loss:        5.966914  [  352/ 8081]
loss:        5.924324  [  672/ 8081]
loss:        5.797906  [  992/ 8081]
loss:        6.357792  [ 1312/ 8081]
loss:        6.078364  [ 1632/ 8081]
loss:        6.116479  [ 1952/ 8081]
loss:        6.078429  [ 2272/ 8081]
loss:        5.033019  [ 2592/ 8081]
loss:        4.953336  [ 2912/ 8081]
loss:        6.628636  [ 3232/ 8081]
loss:        5.549214  [ 3552/ 8081]
loss:        6.284534  [ 3872/ 8081]
loss:        4.622386  [ 4192/ 8081]
loss:        6.369568  [ 4512/ 8081]
loss:        4.868660  [ 4832/ 8081]
loss:        5.488200  [ 5152/ 8081]
loss:        5.862567  [ 5472/ 8081]
loss:        6.039328  [ 5792/ 8081]
loss:        5.234589  [ 6112/ 8081]
loss:        5.907556  [ 6432/ 8081]
loss:        5.743940  [ 6752/ 8081]
loss:        5.441957  [ 7072/ 8081]
loss:        5.440845  [ 7392/ 8081]
loss:        5.158034  [ 7712/ 8081]
loss:        6.050865  [ 8032/ 8081]

Avg training loss:        5.838698
Val loss: 5.730033  (hab: 9.061876, mov: 5.795771)
Validation loss decreased (5.770297 → 5.730033). Saving model…

Epoch 10/50
loss:        6.179825  [   32/ 8081]
loss:        5.357401  [  352/ 8081]
loss:        6.809822  [  672/ 8081]
loss:        6.472744  [  992/ 8081]
loss:        6.635061  [ 1312/ 8081]
loss:        6.397947  [ 1632/ 8081]
loss:        5.919065  [ 1952/ 8081]
loss:        5.730613  [ 2272/ 8081]
loss:        5.453868  [ 2592/ 8081]
loss:        5.638618  [ 2912/ 8081]
loss:        6.182692  [ 3232/ 8081]
loss:        6.336976  [ 3552/ 8081]
loss:        6.059975  [ 3872/ 8081]
loss:        6.442751  [ 4192/ 8081]
loss:        5.999722  [ 4512/ 8081]
loss:        6.285412  [ 4832/ 8081]
loss:        5.633966  [ 5152/ 8081]
loss:        6.046505  [ 5472/ 8081]
loss:        6.264528  [ 5792/ 8081]
loss:        5.914484  [ 6112/ 8081]
loss:        5.977823  [ 6432/ 8081]
loss:        5.719940  [ 6752/ 8081]
loss:        6.066643  [ 7072/ 8081]
loss:        5.849580  [ 7392/ 8081]
loss:        5.411000  [ 7712/ 8081]
loss:        6.022013  [ 8032/ 8081]

Avg training loss:        5.809240
Val loss: 5.689379  (hab: 9.040149, mov: 5.764468)
Validation loss decreased (5.730033 → 5.689379). Saving model…

Epoch 11/50
loss:        4.907570  [   32/ 8081]
loss:        5.889924  [  352/ 8081]
loss:        5.956709  [  672/ 8081]
loss:        5.759821  [  992/ 8081]
loss:        6.424210  [ 1312/ 8081]
loss:        6.346479  [ 1632/ 8081]
loss:        5.922520  [ 1952/ 8081]
loss:        5.436238  [ 2272/ 8081]
loss:        6.612879  [ 2592/ 8081]
loss:        5.339804  [ 2912/ 8081]
loss:        6.211967  [ 3232/ 8081]
loss:        5.464496  [ 3552/ 8081]
loss:        6.156346  [ 3872/ 8081]
loss:        5.153495  [ 4192/ 8081]
loss:        5.414938  [ 4512/ 8081]
loss:        5.473416  [ 4832/ 8081]
loss:        5.691952  [ 5152/ 8081]
loss:        5.547223  [ 5472/ 8081]
loss:        5.611948  [ 5792/ 8081]
loss:        5.708721  [ 6112/ 8081]
loss:        5.696208  [ 6432/ 8081]
loss:        6.019831  [ 6752/ 8081]
loss:        4.553470  [ 7072/ 8081]
loss:        6.459009  [ 7392/ 8081]
loss:        5.727405  [ 7712/ 8081]
loss:        6.545286  [ 8032/ 8081]

Avg training loss:        5.786154
Val loss: 5.668725  (hab: 9.042551, mov: 5.741380)
Validation loss decreased (5.689379 → 5.668725). Saving model…

Epoch 12/50
loss:        5.616603  [   32/ 8081]
loss:        6.049252  [  352/ 8081]
loss:        4.835304  [  672/ 8081]
loss:        5.257712  [  992/ 8081]
loss:        5.597521  [ 1312/ 8081]
loss:        5.995442  [ 1632/ 8081]
loss:        5.782512  [ 1952/ 8081]
loss:        6.130647  [ 2272/ 8081]
loss:        5.720591  [ 2592/ 8081]
loss:        6.001844  [ 2912/ 8081]
loss:        4.933335  [ 3232/ 8081]
loss:        6.482705  [ 3552/ 8081]
loss:        5.820237  [ 3872/ 8081]
loss:        5.733268  [ 4192/ 8081]
loss:        5.900891  [ 4512/ 8081]
loss:        6.137032  [ 4832/ 8081]
loss:        5.593785  [ 5152/ 8081]
loss:        6.598097  [ 5472/ 8081]
loss:        4.435269  [ 5792/ 8081]
loss:        5.126099  [ 6112/ 8081]
loss:        5.691244  [ 6432/ 8081]
loss:        6.051795  [ 6752/ 8081]
loss:        6.269933  [ 7072/ 8081]
loss:        5.587019  [ 7392/ 8081]
loss:        5.059365  [ 7712/ 8081]
loss:        6.231275  [ 8032/ 8081]

Avg training loss:        5.766855
Val loss: 5.656562  (hab: 9.043391, mov: 5.725366)
Validation loss decreased (5.668725 → 5.656562). Saving model…

Epoch 13/50
loss:        5.606125  [   32/ 8081]
loss:        5.722439  [  352/ 8081]
loss:        5.606833  [  672/ 8081]
loss:        5.216550  [  992/ 8081]
loss:        4.678419  [ 1312/ 8081]
loss:        5.843431  [ 1632/ 8081]
loss:        5.164611  [ 1952/ 8081]
loss:        4.773099  [ 2272/ 8081]
loss:        7.645373  [ 2592/ 8081]
loss:        5.885402  [ 2912/ 8081]
loss:        5.255199  [ 3232/ 8081]
loss:        6.466753  [ 3552/ 8081]
loss:        5.693366  [ 3872/ 8081]
loss:        5.614701  [ 4192/ 8081]
loss:        5.959176  [ 4512/ 8081]
loss:        5.948330  [ 4832/ 8081]
loss:        5.129210  [ 5152/ 8081]
loss:        5.860931  [ 5472/ 8081]
loss:        5.058755  [ 5792/ 8081]
loss:        4.907567  [ 6112/ 8081]
loss:        5.229248  [ 6432/ 8081]
loss:        5.879795  [ 6752/ 8081]
loss:        5.287624  [ 7072/ 8081]
loss:        6.196038  [ 7392/ 8081]
loss:        5.572518  [ 7712/ 8081]
loss:        6.908350  [ 8032/ 8081]

Avg training loss:        5.755663
Val loss: 5.641256  (hab: 9.037884, mov: 5.714141)
Validation loss decreased (5.656562 → 5.641256). Saving model…

Epoch 14/50
loss:        4.964145  [   32/ 8081]
loss:        5.318380  [  352/ 8081]
loss:        5.127728  [  672/ 8081]
loss:        6.157777  [  992/ 8081]
loss:        6.613916  [ 1312/ 8081]
loss:        6.080698  [ 1632/ 8081]
loss:        5.652702  [ 1952/ 8081]
loss:        5.900042  [ 2272/ 8081]
loss:        6.858476  [ 2592/ 8081]
loss:        5.943289  [ 2912/ 8081]
loss:        5.797606  [ 3232/ 8081]
loss:        5.417280  [ 3552/ 8081]
loss:        5.564807  [ 3872/ 8081]
loss:        5.571766  [ 4192/ 8081]
loss:        6.216476  [ 4512/ 8081]
loss:        5.974131  [ 4832/ 8081]
loss:        5.863035  [ 5152/ 8081]
loss:        6.162068  [ 5472/ 8081]
loss:        6.523361  [ 5792/ 8081]
loss:        5.621315  [ 6112/ 8081]
loss:        6.325366  [ 6432/ 8081]
loss:        6.285115  [ 6752/ 8081]
loss:        5.690443  [ 7072/ 8081]
loss:        6.123982  [ 7392/ 8081]
loss:        4.844786  [ 7712/ 8081]
loss:        5.590330  [ 8032/ 8081]

Avg training loss:        5.744881
Val loss: 5.644368  (hab: 9.050741, mov: 5.706220)
EarlyStopping counter: 1 out of 5

Epoch 15/50
loss:        5.449314  [   32/ 8081]
loss:        5.857429  [  352/ 8081]
loss:        6.305021  [  672/ 8081]
loss:        6.457657  [  992/ 8081]
loss:        5.897919  [ 1312/ 8081]
loss:        5.853174  [ 1632/ 8081]
loss:        5.240265  [ 1952/ 8081]
loss:        5.410861  [ 2272/ 8081]
loss:        5.105857  [ 2592/ 8081]
loss:        5.933477  [ 2912/ 8081]
loss:        6.734917  [ 3232/ 8081]
loss:        5.744835  [ 3552/ 8081]
loss:        5.404105  [ 3872/ 8081]
loss:        6.486819  [ 4192/ 8081]
loss:        6.072401  [ 4512/ 8081]
loss:        5.728040  [ 4832/ 8081]
loss:        5.800982  [ 5152/ 8081]
loss:        6.108822  [ 5472/ 8081]
loss:        6.193382  [ 5792/ 8081]
loss:        5.731816  [ 6112/ 8081]
loss:        5.498885  [ 6432/ 8081]
loss:        6.113375  [ 6752/ 8081]
loss:        5.672741  [ 7072/ 8081]
loss:        5.861865  [ 7392/ 8081]
loss:        4.941085  [ 7712/ 8081]
loss:        5.365587  [ 8032/ 8081]

Avg training loss:        5.736278
Val loss: 5.636805  (hab: 9.041551, mov: 5.700096)
Validation loss decreased (5.641256 → 5.636805). Saving model…

Epoch 16/50
loss:        6.240680  [   32/ 8081]
loss:        5.571955  [  352/ 8081]
loss:        5.680108  [  672/ 8081]
loss:        6.780542  [  992/ 8081]
loss:        5.652176  [ 1312/ 8081]
loss:        5.889809  [ 1632/ 8081]
loss:        5.862473  [ 1952/ 8081]
loss:        5.809608  [ 2272/ 8081]
loss:        5.217363  [ 2592/ 8081]
loss:        6.361870  [ 2912/ 8081]
loss:        5.102130  [ 3232/ 8081]
loss:        5.680561  [ 3552/ 8081]
loss:        5.477432  [ 3872/ 8081]
loss:        4.823541  [ 4192/ 8081]
loss:        5.801976  [ 4512/ 8081]
loss:        5.822965  [ 4832/ 8081]
loss:        5.812321  [ 5152/ 8081]
loss:        5.775498  [ 5472/ 8081]
loss:        5.370294  [ 5792/ 8081]
loss:        4.826260  [ 6112/ 8081]
loss:        5.204805  [ 6432/ 8081]
loss:        5.874338  [ 6752/ 8081]
loss:        6.099152  [ 7072/ 8081]
loss:        6.395084  [ 7392/ 8081]
loss:        5.619109  [ 7712/ 8081]
loss:        6.167995  [ 8032/ 8081]

Avg training loss:        5.734382
Val loss: 5.637286  (hab: 9.058554, mov: 5.695152)
EarlyStopping counter: 1 out of 5

Epoch 17/50
loss:        5.759869  [   32/ 8081]
loss:        5.908279  [  352/ 8081]
loss:        4.975919  [  672/ 8081]
loss:        6.317106  [  992/ 8081]
loss:        4.888456  [ 1312/ 8081]
loss:        5.638083  [ 1632/ 8081]
loss:        5.614165  [ 1952/ 8081]
loss:        5.462237  [ 2272/ 8081]
loss:        5.872178  [ 2592/ 8081]
loss:        5.855830  [ 2912/ 8081]
loss:        6.183403  [ 3232/ 8081]
loss:        6.292447  [ 3552/ 8081]
loss:        5.343740  [ 3872/ 8081]
loss:        4.814789  [ 4192/ 8081]
loss:        6.485147  [ 4512/ 8081]
loss:        6.744562  [ 4832/ 8081]
loss:        4.991343  [ 5152/ 8081]
loss:        5.253975  [ 5472/ 8081]
loss:        5.706602  [ 5792/ 8081]
loss:        5.745562  [ 6112/ 8081]
loss:        5.470129  [ 6432/ 8081]
loss:        5.297190  [ 6752/ 8081]
loss:        4.941541  [ 7072/ 8081]
loss:        5.556078  [ 7392/ 8081]
loss:        5.601835  [ 7712/ 8081]
loss:        5.165209  [ 8032/ 8081]

Avg training loss:        5.728193
Val loss: 5.634346  (hab: 9.078211, mov: 5.691624)
Validation loss decreased (5.636805 → 5.634346). Saving model…

Epoch 18/50
loss:        5.580297  [   32/ 8081]
loss:        6.001015  [  352/ 8081]
loss:        5.586500  [  672/ 8081]
loss:        6.604277  [  992/ 8081]
loss:        6.078627  [ 1312/ 8081]
loss:        5.917189  [ 1632/ 8081]
loss:        6.071002  [ 1952/ 8081]
loss:        5.724484  [ 2272/ 8081]
loss:        5.730997  [ 2592/ 8081]
loss:        5.547022  [ 2912/ 8081]
loss:        5.203174  [ 3232/ 8081]
loss:        5.844120  [ 3552/ 8081]
loss:        5.323066  [ 3872/ 8081]
loss:        6.056065  [ 4192/ 8081]
loss:        5.626657  [ 4512/ 8081]
loss:        6.334327  [ 4832/ 8081]
loss:        5.639535  [ 5152/ 8081]
loss:        5.718531  [ 5472/ 8081]
loss:        5.517420  [ 5792/ 8081]
loss:        5.156143  [ 6112/ 8081]
loss:        5.721713  [ 6432/ 8081]
loss:        5.354311  [ 6752/ 8081]
loss:        5.269536  [ 7072/ 8081]
loss:        5.242304  [ 7392/ 8081]
loss:        6.042232  [ 7712/ 8081]
loss:        5.906689  [ 8032/ 8081]

Avg training loss:        5.728079
Val loss: 5.618409  (hab: 9.059954, mov: 5.688412)
Validation loss decreased (5.634346 → 5.618409). Saving model…

Epoch 19/50
loss:        6.044871  [   32/ 8081]
loss:        5.803884  [  352/ 8081]
loss:        5.660613  [  672/ 8081]
loss:        5.212093  [  992/ 8081]
loss:        5.937016  [ 1312/ 8081]
loss:        5.890323  [ 1632/ 8081]
loss:        6.329659  [ 1952/ 8081]
loss:        6.529563  [ 2272/ 8081]
loss:        6.605713  [ 2592/ 8081]
loss:        5.811274  [ 2912/ 8081]
loss:        5.425788  [ 3232/ 8081]
loss:        5.407681  [ 3552/ 8081]
loss:        5.044333  [ 3872/ 8081]
loss:        6.168036  [ 4192/ 8081]
loss:        5.366661  [ 4512/ 8081]
loss:        4.933005  [ 4832/ 8081]
loss:        6.039813  [ 5152/ 8081]
loss:        5.939240  [ 5472/ 8081]
loss:        5.947885  [ 5792/ 8081]
loss:        6.139528  [ 6112/ 8081]
loss:        4.153102  [ 6432/ 8081]
loss:        5.557661  [ 6752/ 8081]
loss:        5.384722  [ 7072/ 8081]
loss:        5.586881  [ 7392/ 8081]
loss:        5.676156  [ 7712/ 8081]
loss:        5.556557  [ 8032/ 8081]

Avg training loss:        5.715559
Val loss: 5.616465  (hab: 9.050395, mov: 5.686038)
Validation loss decreased (5.618409 → 5.616465). Saving model…

Epoch 20/50
loss:        4.815418  [   32/ 8081]
loss:        5.574852  [  352/ 8081]
loss:        6.128708  [  672/ 8081]
loss:        5.583793  [  992/ 8081]
loss:        5.282274  [ 1312/ 8081]
loss:        5.726599  [ 1632/ 8081]
loss:        5.874033  [ 1952/ 8081]
loss:        6.185098  [ 2272/ 8081]
loss:        4.892152  [ 2592/ 8081]
loss:        5.481307  [ 2912/ 8081]
loss:        5.608483  [ 3232/ 8081]
loss:        5.819208  [ 3552/ 8081]
loss:        6.036858  [ 3872/ 8081]
loss:        6.429809  [ 4192/ 8081]
loss:        5.757306  [ 4512/ 8081]
loss:        5.805862  [ 4832/ 8081]
loss:        5.315531  [ 5152/ 8081]
loss:        5.275296  [ 5472/ 8081]
loss:        6.169271  [ 5792/ 8081]
loss:        5.387419  [ 6112/ 8081]
loss:        4.782458  [ 6432/ 8081]
loss:        5.168671  [ 6752/ 8081]
loss:        4.899096  [ 7072/ 8081]
loss:        4.789249  [ 7392/ 8081]
loss:        5.733751  [ 7712/ 8081]
loss:        5.973266  [ 8032/ 8081]

Avg training loss:        5.718753
Val loss: 5.615386  (hab: 9.056146, mov: 5.683999)
Validation loss decreased (5.616465 → 5.615386). Saving model…

Epoch 21/50
loss:        6.183380  [   32/ 8081]
loss:        5.188821  [  352/ 8081]
loss:        6.013844  [  672/ 8081]
loss:        5.708962  [  992/ 8081]
loss:        4.720342  [ 1312/ 8081]
loss:        5.139887  [ 1632/ 8081]
loss:        6.798734  [ 1952/ 8081]
loss:        5.688621  [ 2272/ 8081]
loss:        5.764734  [ 2592/ 8081]
loss:        5.178732  [ 2912/ 8081]
loss:        5.693705  [ 3232/ 8081]
loss:        5.228932  [ 3552/ 8081]
loss:        5.548631  [ 3872/ 8081]
loss:        6.048266  [ 4192/ 8081]
loss:        5.535581  [ 4512/ 8081]
loss:        5.740133  [ 4832/ 8081]
loss:        5.634562  [ 5152/ 8081]
loss:        5.548189  [ 5472/ 8081]
loss:        6.467731  [ 5792/ 8081]
loss:        5.138684  [ 6112/ 8081]
loss:        5.851662  [ 6432/ 8081]
loss:        5.746037  [ 6752/ 8081]
loss:        6.040398  [ 7072/ 8081]
loss:        5.061504  [ 7392/ 8081]
loss:        6.137201  [ 7712/ 8081]
loss:        5.381677  [ 8032/ 8081]

Avg training loss:        5.711886
Val loss: 5.612343  (hab: 9.060971, mov: 5.682530)
Validation loss decreased (5.615386 → 5.612343). Saving model…

Epoch 22/50
loss:        6.297585  [   32/ 8081]
loss:        5.578657  [  352/ 8081]
loss:        6.305127  [  672/ 8081]
loss:        5.026219  [  992/ 8081]
loss:        6.537750  [ 1312/ 8081]
loss:        5.491141  [ 1632/ 8081]
loss:        5.449276  [ 1952/ 8081]
loss:        5.750866  [ 2272/ 8081]
loss:        6.007241  [ 2592/ 8081]
loss:        5.954107  [ 2912/ 8081]
loss:        6.736318  [ 3232/ 8081]
loss:        4.786770  [ 3552/ 8081]
loss:        5.596262  [ 3872/ 8081]
loss:        5.670405  [ 4192/ 8081]
loss:        5.491155  [ 4512/ 8081]
loss:        5.541801  [ 4832/ 8081]
loss:        7.086232  [ 5152/ 8081]
loss:        4.686992  [ 5472/ 8081]
loss:        5.861276  [ 5792/ 8081]
loss:        5.906724  [ 6112/ 8081]
loss:        6.650864  [ 6432/ 8081]
loss:        5.598725  [ 6752/ 8081]
loss:        5.054667  [ 7072/ 8081]
loss:        6.816591  [ 7392/ 8081]
loss:        5.817470  [ 7712/ 8081]
loss:        7.007591  [ 8032/ 8081]

Avg training loss:        5.712276
Val loss: 5.613923  (hab: 9.068626, mov: 5.681161)
EarlyStopping counter: 1 out of 5

Epoch 23/50
loss:        6.970927  [   32/ 8081]
loss:        6.394039  [  352/ 8081]
loss:        5.422188  [  672/ 8081]
loss:        5.950150  [  992/ 8081]
loss:        4.757160  [ 1312/ 8081]
loss:        5.645199  [ 1632/ 8081]
loss:        5.659248  [ 1952/ 8081]
loss:        5.165676  [ 2272/ 8081]
loss:        5.686069  [ 2592/ 8081]
loss:        5.184331  [ 2912/ 8081]
loss:        5.856824  [ 3232/ 8081]
loss:        5.231144  [ 3552/ 8081]
loss:        6.129216  [ 3872/ 8081]
loss:        5.775649  [ 4192/ 8081]
loss:        6.582394  [ 4512/ 8081]
loss:        5.029494  [ 4832/ 8081]
loss:        5.339326  [ 5152/ 8081]
loss:        6.114333  [ 5472/ 8081]
loss:        6.511226  [ 5792/ 8081]
loss:        6.880102  [ 6112/ 8081]
loss:        5.542309  [ 6432/ 8081]
loss:        6.349352  [ 6752/ 8081]
loss:        5.943939  [ 7072/ 8081]
loss:        5.804504  [ 7392/ 8081]
loss:        5.836282  [ 7712/ 8081]
loss:        6.245509  [ 8032/ 8081]

Avg training loss:        5.708841
Val loss: 5.613580  (hab: 9.071318, mov: 5.680023)
EarlyStopping counter: 2 out of 5

Epoch 24/50
loss:        6.344127  [   32/ 8081]
loss:        6.336230  [  352/ 8081]
loss:        6.083117  [  672/ 8081]
loss:        5.233209  [  992/ 8081]
loss:        5.865287  [ 1312/ 8081]
loss:        5.156081  [ 1632/ 8081]
loss:        5.540485  [ 1952/ 8081]
loss:        6.385826  [ 2272/ 8081]
loss:        5.476202  [ 2592/ 8081]
loss:        5.373429  [ 2912/ 8081]
loss:        5.734528  [ 3232/ 8081]
loss:        6.616238  [ 3552/ 8081]
loss:        4.775983  [ 3872/ 8081]
loss:        5.515054  [ 4192/ 8081]
loss:        6.089822  [ 4512/ 8081]
loss:        4.916617  [ 4832/ 8081]
loss:        5.707019  [ 5152/ 8081]
loss:        6.043135  [ 5472/ 8081]
loss:        5.688140  [ 5792/ 8081]
loss:        5.170226  [ 6112/ 8081]
loss:        5.233944  [ 6432/ 8081]
loss:        5.507638  [ 6752/ 8081]
loss:        4.514205  [ 7072/ 8081]
loss:        5.612385  [ 7392/ 8081]
loss:        5.405685  [ 7712/ 8081]
loss:        5.144617  [ 8032/ 8081]

Avg training loss:        5.708871
Val loss: 5.612180  (hab: 9.070977, mov: 5.678934)
Validation loss decreased (5.612343 → 5.612180). Saving model…

Epoch 25/50
loss:        5.483903  [   32/ 8081]
loss:        5.557379  [  352/ 8081]
loss:        6.426230  [  672/ 8081]
loss:        5.649759  [  992/ 8081]
loss:        6.312637  [ 1312/ 8081]
loss:        5.019237  [ 1632/ 8081]
loss:        5.405530  [ 1952/ 8081]
loss:        5.900116  [ 2272/ 8081]
loss:        5.370254  [ 2592/ 8081]
loss:        6.170698  [ 2912/ 8081]
loss:        6.044189  [ 3232/ 8081]
loss:        5.418559  [ 3552/ 8081]
loss:        5.996602  [ 3872/ 8081]
loss:        5.943696  [ 4192/ 8081]
loss:        5.810556  [ 4512/ 8081]
loss:        5.618696  [ 4832/ 8081]
loss:        6.685282  [ 5152/ 8081]
loss:        5.771829  [ 5472/ 8081]
loss:        5.538842  [ 5792/ 8081]
loss:        5.731853  [ 6112/ 8081]
loss:        4.953280  [ 6432/ 8081]
loss:        6.067793  [ 6752/ 8081]
loss:        5.686928  [ 7072/ 8081]
loss:        5.837258  [ 7392/ 8081]
loss:        5.619393  [ 7712/ 8081]
loss:        6.014652  [ 8032/ 8081]

Avg training loss:        5.705319
Val loss: 5.614037  (hab: 9.075295, mov: 5.677942)
EarlyStopping counter: 1 out of 5

Epoch 26/50
loss:        5.019064  [   32/ 8081]
loss:        4.745184  [  352/ 8081]
loss:        5.631293  [  672/ 8081]
loss:        6.071359  [  992/ 8081]
loss:        5.766923  [ 1312/ 8081]
loss:        6.177391  [ 1632/ 8081]
loss:        5.471320  [ 1952/ 8081]
loss:        5.027010  [ 2272/ 8081]
loss:        5.461694  [ 2592/ 8081]
loss:        5.205660  [ 2912/ 8081]
loss:        5.392303  [ 3232/ 8081]
loss:        5.890113  [ 3552/ 8081]
loss:        5.996963  [ 3872/ 8081]
loss:        5.888836  [ 4192/ 8081]
loss:        5.782978  [ 4512/ 8081]
loss:        5.287475  [ 4832/ 8081]
loss:        5.440210  [ 5152/ 8081]
loss:        6.513184  [ 5472/ 8081]
loss:        5.154773  [ 5792/ 8081]
loss:        4.872709  [ 6112/ 8081]
loss:        5.144480  [ 6432/ 8081]
loss:        6.164564  [ 6752/ 8081]
loss:        5.565736  [ 7072/ 8081]
loss:        6.032124  [ 7392/ 8081]
loss:        6.110894  [ 7712/ 8081]
loss:        5.697278  [ 8032/ 8081]

Avg training loss:        5.697607
Val loss: 5.610400  (hab: 9.067782, mov: 5.677879)
Validation loss decreased (5.612180 → 5.610400). Saving model…

Epoch 27/50
loss:        6.097190  [   32/ 8081]
loss:        5.636503  [  352/ 8081]
loss:        6.254807  [  672/ 8081]
loss:        5.250633  [  992/ 8081]
loss:        5.113654  [ 1312/ 8081]
loss:        5.435678  [ 1632/ 8081]
loss:        5.517510  [ 1952/ 8081]
loss:        5.493887  [ 2272/ 8081]
loss:        6.069567  [ 2592/ 8081]
loss:        4.473845  [ 2912/ 8081]
loss:        4.820729  [ 3232/ 8081]
loss:        6.595756  [ 3552/ 8081]
loss:        5.957569  [ 3872/ 8081]
loss:        5.818820  [ 4192/ 8081]
loss:        6.787045  [ 4512/ 8081]
loss:        6.415217  [ 4832/ 8081]
loss:        6.801254  [ 5152/ 8081]
loss:        6.635920  [ 5472/ 8081]
loss:        6.249227  [ 5792/ 8081]
loss:        5.184406  [ 6112/ 8081]
loss:        6.028659  [ 6432/ 8081]
loss:        6.063338  [ 6752/ 8081]
loss:        4.167001  [ 7072/ 8081]
loss:        6.044289  [ 7392/ 8081]
loss:        5.910362  [ 7712/ 8081]
loss:        5.351403  [ 8032/ 8081]

Avg training loss:        5.697752
Val loss: 5.609746  (hab: 9.066412, mov: 5.677803)
Validation loss decreased (5.610400 → 5.609746). Saving model…

Epoch 28/50
loss:        6.096348  [   32/ 8081]
loss:        6.305573  [  352/ 8081]
loss:        5.602438  [  672/ 8081]
loss:        6.977356  [  992/ 8081]
loss:        5.385228  [ 1312/ 8081]
loss:        5.474465  [ 1632/ 8081]
loss:        4.792825  [ 1952/ 8081]
loss:        5.840420  [ 2272/ 8081]
loss:        6.240949  [ 2592/ 8081]
loss:        5.857591  [ 2912/ 8081]
loss:        5.636131  [ 3232/ 8081]
loss:        5.702084  [ 3552/ 8081]
loss:        5.729275  [ 3872/ 8081]
loss:        6.326180  [ 4192/ 8081]
loss:        5.094948  [ 4512/ 8081]
loss:        6.315386  [ 4832/ 8081]
loss:        5.568670  [ 5152/ 8081]
loss:        5.235243  [ 5472/ 8081]
loss:        5.828605  [ 5792/ 8081]
loss:        5.036157  [ 6112/ 8081]
loss:        5.372856  [ 6432/ 8081]
loss:        5.544307  [ 6752/ 8081]
loss:        6.524855  [ 7072/ 8081]
loss:        5.273438  [ 7392/ 8081]
loss:        6.219529  [ 7712/ 8081]
loss:        6.204034  [ 8032/ 8081]

Avg training loss:        5.697381
Val loss: 5.609159  (hab: 9.065305, mov: 5.677736)
Validation loss decreased (5.609746 → 5.609159). Saving model…

Epoch 29/50
loss:        6.185834  [   32/ 8081]
loss:        6.085011  [  352/ 8081]
loss:        6.083863  [  672/ 8081]
loss:        5.526052  [  992/ 8081]
loss:        6.035280  [ 1312/ 8081]
loss:        5.940582  [ 1632/ 8081]
loss:        5.954806  [ 1952/ 8081]
loss:        5.452516  [ 2272/ 8081]
loss:        5.605563  [ 2592/ 8081]
loss:        6.698833  [ 2912/ 8081]
loss:        5.330328  [ 3232/ 8081]
loss:        5.943144  [ 3552/ 8081]
loss:        5.416422  [ 3872/ 8081]
loss:        5.966122  [ 4192/ 8081]
loss:        5.665062  [ 4512/ 8081]
loss:        5.516191  [ 4832/ 8081]
loss:        5.457212  [ 5152/ 8081]
loss:        5.502567  [ 5472/ 8081]
loss:        6.032925  [ 5792/ 8081]
loss:        5.336675  [ 6112/ 8081]
loss:        5.972738  [ 6432/ 8081]
loss:        5.189960  [ 6752/ 8081]
loss:        5.246009  [ 7072/ 8081]
loss:        5.992759  [ 7392/ 8081]
loss:        5.654251  [ 7712/ 8081]
loss:        5.759189  [ 8032/ 8081]

Avg training loss:        5.696073
Val loss: 5.606671  (hab: 9.061607, mov: 5.677664)
Validation loss decreased (5.609159 → 5.606671). Saving model…

Epoch 30/50
loss:        5.685652  [   32/ 8081]
loss:        6.224775  [  352/ 8081]
loss:        5.143509  [  672/ 8081]
loss:        6.062613  [  992/ 8081]
loss:        6.063200  [ 1312/ 8081]
loss:        4.796700  [ 1632/ 8081]
loss:        6.347664  [ 1952/ 8081]
loss:        4.513020  [ 2272/ 8081]
loss:        5.597942  [ 2592/ 8081]
loss:        6.043799  [ 2912/ 8081]
loss:        6.838120  [ 3232/ 8081]
loss:        5.480160  [ 3552/ 8081]
loss:        5.526507  [ 3872/ 8081]
loss:        6.455516  [ 4192/ 8081]
loss:        5.712759  [ 4512/ 8081]
loss:        6.032924  [ 4832/ 8081]
loss:        5.487291  [ 5152/ 8081]
loss:        6.136601  [ 5472/ 8081]
loss:        5.893800  [ 5792/ 8081]
loss:        7.423604  [ 6112/ 8081]
loss:        6.322469  [ 6432/ 8081]
loss:        6.444270  [ 6752/ 8081]
loss:        5.683311  [ 7072/ 8081]
loss:        5.644976  [ 7392/ 8081]
loss:        6.008801  [ 7712/ 8081]
loss:        5.804708  [ 8032/ 8081]

Avg training loss:        5.696024
Val loss: 5.607674  (hab: 9.063704, mov: 5.677591)
EarlyStopping counter: 1 out of 5

Epoch 31/50
loss:        5.649486  [   32/ 8081]
loss:        6.746201  [  352/ 8081]
loss:        6.346245  [  672/ 8081]
loss:        5.224216  [  992/ 8081]
loss:        5.047760  [ 1312/ 8081]
loss:        6.521373  [ 1632/ 8081]
loss:        6.068427  [ 1952/ 8081]
loss:        6.159099  [ 2272/ 8081]
loss:        5.539003  [ 2592/ 8081]
loss:        6.329982  [ 2912/ 8081]
loss:        5.931009  [ 3232/ 8081]
loss:        6.581205  [ 3552/ 8081]
loss:        5.904395  [ 3872/ 8081]
loss:        5.335628  [ 4192/ 8081]
loss:        5.552714  [ 4512/ 8081]
loss:        6.431862  [ 4832/ 8081]
loss:        6.130784  [ 5152/ 8081]
loss:        5.946055  [ 5472/ 8081]
loss:        5.180882  [ 5792/ 8081]
loss:        5.691638  [ 6112/ 8081]
loss:        5.541974  [ 6432/ 8081]
loss:        5.711466  [ 6752/ 8081]
loss:        5.295522  [ 7072/ 8081]
loss:        6.359267  [ 7392/ 8081]
loss:        5.490843  [ 7712/ 8081]
loss:        5.619943  [ 8032/ 8081]

Avg training loss:        5.693800
Val loss: 5.608880  (hab: 9.063725, mov: 5.677515)
EarlyStopping counter: 2 out of 5

Epoch 32/50
loss:        5.864106  [   32/ 8081]
loss:        5.331354  [  352/ 8081]
loss:        6.225701  [  672/ 8081]
loss:        5.892544  [  992/ 8081]
loss:        5.271852  [ 1312/ 8081]
loss:        5.296095  [ 1632/ 8081]
loss:        5.734218  [ 1952/ 8081]
loss:        6.311450  [ 2272/ 8081]
loss:        6.817872  [ 2592/ 8081]
loss:        6.007078  [ 2912/ 8081]
loss:        5.932118  [ 3232/ 8081]
loss:        5.492791  [ 3552/ 8081]
loss:        5.910633  [ 3872/ 8081]
loss:        5.670037  [ 4192/ 8081]
loss:        6.026743  [ 4512/ 8081]
loss:        5.508593  [ 4832/ 8081]
loss:        5.124711  [ 5152/ 8081]
loss:        5.578349  [ 5472/ 8081]
loss:        5.548217  [ 5792/ 8081]
loss:        5.888811  [ 6112/ 8081]
loss:        5.293932  [ 6432/ 8081]
loss:        5.474451  [ 6752/ 8081]
loss:        5.046737  [ 7072/ 8081]
loss:        5.431853  [ 7392/ 8081]
loss:        5.443534  [ 7712/ 8081]
loss:        4.504044  [ 8032/ 8081]

Avg training loss:        5.696063
Val loss: 5.609391  (hab: 9.063499, mov: 5.677441)
EarlyStopping counter: 3 out of 5

Epoch 33/50
loss:        6.485788  [   32/ 8081]
loss:        6.010224  [  352/ 8081]
loss:        5.579084  [  672/ 8081]
loss:        6.287263  [  992/ 8081]
loss:        5.225684  [ 1312/ 8081]
loss:        5.554420  [ 1632/ 8081]
loss:        6.013733  [ 1952/ 8081]
loss:        5.896962  [ 2272/ 8081]
loss:        5.516796  [ 2592/ 8081]
loss:        4.663539  [ 2912/ 8081]
loss:        5.070152  [ 3232/ 8081]
loss:        5.772351  [ 3552/ 8081]
loss:        5.085042  [ 3872/ 8081]
loss:        6.083080  [ 4192/ 8081]
loss:        4.705304  [ 4512/ 8081]
loss:        6.495463  [ 4832/ 8081]
loss:        5.870428  [ 5152/ 8081]
loss:        5.244241  [ 5472/ 8081]
loss:        6.315904  [ 5792/ 8081]
loss:        7.024615  [ 6112/ 8081]
loss:        5.152605  [ 6432/ 8081]
loss:        5.637001  [ 6752/ 8081]
loss:        5.996116  [ 7072/ 8081]
loss:        5.707726  [ 7392/ 8081]
loss:        5.026938  [ 7712/ 8081]
loss:        5.096850  [ 8032/ 8081]

Avg training loss:        5.691970
Val loss: 5.607607  (hab: 9.061089, mov: 5.677368)
EarlyStopping counter: 4 out of 5

Epoch 34/50
loss:        5.588144  [   32/ 8081]
loss:        5.542765  [  352/ 8081]
loss:        6.053326  [  672/ 8081]
loss:        5.045709  [  992/ 8081]
loss:        5.643415  [ 1312/ 8081]
loss:        5.707616  [ 1632/ 8081]
loss:        6.182617  [ 1952/ 8081]
loss:        5.912023  [ 2272/ 8081]
loss:        5.646701  [ 2592/ 8081]
loss:        6.887000  [ 2912/ 8081]
loss:        5.904776  [ 3232/ 8081]
loss:        7.279239  [ 3552/ 8081]
loss:        5.329747  [ 3872/ 8081]
loss:        6.507558  [ 4192/ 8081]
loss:        5.086779  [ 4512/ 8081]
loss:        5.902202  [ 4832/ 8081]
loss:        5.610799  [ 5152/ 8081]
loss:        6.010343  [ 5472/ 8081]
loss:        5.060927  [ 5792/ 8081]
loss:        7.418796  [ 6112/ 8081]
loss:        6.054123  [ 6432/ 8081]
loss:        5.982379  [ 6752/ 8081]
loss:        6.061192  [ 7072/ 8081]
loss:        4.860736  [ 7392/ 8081]
loss:        5.596017  [ 7712/ 8081]
loss:        5.766304  [ 8032/ 8081]

Avg training loss:        5.689735
Val loss: 5.607455  (hab: 9.060824, mov: 5.677361)
EarlyStopping counter: 5 out of 5
Early stopping triggered.
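
The counter behaviour in the log above (reset on every improvement, stop after five epochs in a row without one) can be reproduced with a small patience tracker. This is a minimal sketch only: `PatienceTracker` is a hypothetical stand-in written for illustration, not the package's `EarlyStopping` class, and the strict "lower than the best so far" rule is an assumption. The validation losses are copied from epochs 29 to 34 of the log.

```python
class PatienceTracker:
    """Hypothetical patience-based early stopper (illustration only, not the deepssf class)."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best = float("inf")
        self.counter = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best:
            # Improvement: this is the point at which a checkpoint would be saved.
            self.best = val_loss
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience


# Validation losses for epochs 29-34, copied from the training log.
val_losses = [5.606671, 5.607674, 5.608880, 5.609391, 5.607607, 5.607455]

tracker = PatienceTracker(patience=5)
stopped_at = None
for epoch, loss in enumerate(val_losses, start=29):
    if tracker.step(loss):
        stopped_at = epoch
        break

print(f"best val loss = {tracker.best:.6f}, stopped at epoch {stopped_at}")
```

Under these assumptions the tracker stops at epoch 34, and the weights worth keeping are those saved at epoch 29, the last epoch at which the validation loss improved.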

create_gif()

Assembles the snapshots in snapshot_dir into an animation so you can see how the predictions evolve as training progresses.

Code
gif_path = str(OUTPUT_DIR / "training_progress.gif")
create_gif(str(SNAPSHOT_DIR), gif_path, fps=5)
print(f"Training GIF → {gif_path}")

# Play the gif
display(HTML(f'<img src="{gif_path}" />'))
Animation saved: outputs/training_progress.gif
Training GIF → outputs/training_progress.gif
Code
fig, ax = plt.subplots(figsize=(7, 3))
ax.plot(history["train_losses"], label="train")
ax.plot(history["val_losses"],   label="val")
ax.set_xlabel("Epoch")
ax.set_ylabel("NLL loss")
ax.legend()
ax.set_title("Training history")
plt.tight_layout()
plt.savefig(str(OUTPUT_DIR / "loss_history.png"), dpi=100)
plt.show()

Test the model against withheld data

Convert the NumPy rasters loaded earlier into PyTorch tensors.

Both validate_next_step_probs() and simulate_trajectory() call get_landscape(month_index) at each step to fetch the current environmental layers as a list of 2-D tensors.

With static (non-seasonal) rasters the month_index argument is unused. For seasonal data (e.g. monthly NDVI composites) you would index into a stack of layers here based on the current month.

Code
ndvi_t  = torch.from_numpy(env_layers["ndvi"].astype(np.float32))
slope_t = torch.from_numpy(env_layers["slope"].astype(np.float32))

def get_landscape(_month_index):
    # Returns layers in the same order as LAYER_PATHS.
    return [ndvi_t, slope_t]

# Use the last 10 % of steps as the test set (matches the DataLoader split).
n_test      = int(len(step_df) * 0.1)
test_sample = step_df.iloc[-n_test:].reset_index(drop=True)
print(f"Testing on {len(test_sample):,} steps")
Testing on 1,010 steps
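
For seasonal layers, the static `get_landscape` above could be swapped for a version that indexes a stack of monthly rasters. The sketch below is an illustration under stated assumptions: `month_index_from_yday` and `make_get_landscape` are hypothetical helper names, the default year of 2019 (a non-leap year) is assumed, and small nested lists stand in for the 2-D tensors so that the example runs without any data files.

```python
from datetime import date, timedelta


def month_index_from_yday(yday, year=2019):
    """Map day-of-year (1-366) to a zero-based month index (0 = January)."""
    return (date(year, 1, 1) + timedelta(days=int(yday) - 1)).month - 1


def make_get_landscape(monthly_ndvi, slope):
    """Build get_landscape(month_index) over 12 monthly NDVI layers and one static slope layer."""
    def get_landscape(month_index):
        # Same layer order as the static version: NDVI first, then slope.
        return [monthly_ndvi[month_index], slope]
    return get_landscape


# Placeholder 3 x 3 rasters; in practice these would be torch tensors.
monthly_ndvi = [[[m / 12.0] * 3 for _ in range(3)] for m in range(12)]
slope = [[0.0] * 3 for _ in range(3)]

get_landscape_seasonal = make_get_landscape(monthly_ndvi, slope)

# Day 45 falls in February, so the February NDVI layer (index 1) is returned.
layers = get_landscape_seasonal(month_index_from_yday(45))
print(f"{len(layers)} layers returned for month index {month_index_from_yday(45)}")
```

A function like `month_index_from_yday` would then take the place of the constant `lambda _yday: 0` passed as `month_index_fn` when the layers are static.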

validate_next_step_probs()

Runs the model forward on each test step and records the probability mass assigned to the true next location.

  • Inputs:
    • model
      • the trained model
    • test_sample
      • test dataframe that contains movement steps (from the data loaded above)
    • get_landscape
      • retrieves the landscape layers (spatial covariates) that were used to train the model
    • transform
      • raster Affine transform
    • window_size
      • window size of the local availability window (should be the same as for training)
    • scalar_cols
      • scalar column names (should be set from above)
    • yday_col
      • name of the day-of-year column, used to index the correct spatial layers when they change through time
    • month_index_fn
      • a function mapping day-of-year → month index (used by the get_landscape function). Here the spatial layers are static, so the function always returns 0.
  • Outputs:
    • test_sample dataframe with three new columns appended:
      • habitat_prob — probability at the true location from the habitat head alone
      • move_prob — probability at the true location from the movement head alone
      • next_step_prob — joint (habitat × movement) probability at the true location

Higher values indicate better performance: the model assigned more probability to the observed next step.
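
To make the three output columns concrete, here is a minimal sketch of reading a probability at the true pixel from log-scale habitat and movement surfaces over one local window. It rests on assumptions: the function names are hypothetical, nested lists stand in for tensors, and renormalising the summed log-surfaces is one plausible way to form the joint probability, not necessarily what `validate_next_step_probs()` does internally.

```python
import math


def normalise_log_surface(log_surface):
    """Turn a 2-D grid of unnormalised log-values into probabilities that sum to 1."""
    flat = [v for row in log_surface for v in row]
    peak = max(flat)  # subtract the maximum for numerical stability
    total = sum(math.exp(v - peak) for v in flat)
    return [[math.exp(v - peak) / total for v in row] for row in log_surface]


def probs_at_pixel(habitat_log, move_log, row, col):
    """Return (habitat, movement, joint) probabilities at one pixel of the window."""
    # Adding log-surfaces corresponds to multiplying the two probability surfaces.
    joint_log = [[h + m for h, m in zip(h_row, m_row)]
                 for h_row, m_row in zip(habitat_log, move_log)]
    habitat = normalise_log_surface(habitat_log)
    move = normalise_log_surface(move_log)
    joint = normalise_log_surface(joint_log)
    return habitat[row][col], move[row][col], joint[row][col]


# Toy 3 x 3 window in which both surfaces favour the centre pixel.
habitat_log = [[0.0, 0.0, 0.0],
               [0.0, 1.0, 0.0],
               [0.0, 0.0, 0.0]]
move_log = [[-2.0, -1.0, -2.0],
            [-1.0,  0.0, -1.0],
            [-2.0, -1.0, -2.0]]

hab_p, mov_p, joint_p = probs_at_pixel(habitat_log, move_log, row=1, col=1)
print(f"habitat={hab_p:.4f}  movement={mov_p:.4f}  joint={joint_p:.4f}")
```

In this toy window the joint probability at the centre exceeds both the habitat and the movement probability, because the product surface is renormalised over the window. Under the same assumption, a next_step_prob larger than move_prob is not contradictory.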

Code
val_results = validate_next_step_probs(
    model,
    test_sample,
    get_landscape=get_landscape,
    transform=raster_transform,
    window_size=WINDOW_SIZE,
    scalar_cols=tuple(SCALAR_COLS),
    yday_col="yday_t1",
    bearing_col="bearing_tm1",
    month_index_fn=lambda _yday: 0,   # static layers
)
print(val_results[["habitat_prob", "move_prob", "next_step_prob"]].describe())
       habitat_prob    move_prob  next_step_prob
count   1010.000000  1010.000000     1010.000000
mean       0.000193     0.026397        0.035083
std        0.000179     0.051699        0.072018
min        0.000000     0.000000        0.000000
25%        0.000092     0.000240        0.000247
50%        0.000123     0.006742        0.007731
75%        0.000189     0.029893        0.036953
max        0.000675     0.254268        0.552189
Code
# Use the last 10 % of steps as the test set (matches the DataLoader split)
n_test      = int(len(step_df) * 0.1)
test_sample = step_df.iloc[-n_test:].reset_index(drop=True)
print(f"Validating on {len(test_sample):,} steps")

val_results = validate_next_step_probs(
    model,
    test_sample,
    get_landscape=get_landscape,
    transform=raster_transform,
    window_size=WINDOW_SIZE,
    scalar_cols=tuple(SCALAR_COLS),
    yday_col="yday_t1",
    bearing_col="bearing_tm1",
    month_index_fn=lambda _yday: 0,   # static layers — month doesn't matter
)
print(val_results[["habitat_prob", "move_prob", "next_step_prob"]].describe())
Validating on 1,010 steps
       habitat_prob    move_prob  next_step_prob
count   1010.000000  1010.000000     1010.000000
mean       0.000193     0.026397        0.035083
std        0.000179     0.051699        0.072018
min        0.000000     0.000000        0.000000
25%        0.000092     0.000240        0.000247
50%        0.000123     0.006742        0.007731
75%        0.000189     0.029893        0.036953
max        0.000675     0.254268        0.552189
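One way to put these magnitudes in context is to compare them with a uniform baseline, which spreads probability evenly over every pixel in the window. The sketch below is an illustration under stated assumptions rather than part of the package: the window size of 101 pixels and the small `next_step_prob` array are made-up stand-ins for `WINDOW_SIZE` and `val_results["next_step_prob"]`.

```python
import numpy as np

# Hypothetical window size in pixels; substitute the WINDOW_SIZE used for training.
window_size = 101

# A model with no information assigns the same probability to every pixel.
uniform_prob = 1.0 / (window_size ** 2)

# Made-up probabilities standing in for val_results["next_step_prob"].
next_step_prob = np.array([0.0, 0.00005, 0.0002, 0.008, 0.04, 0.5])

# Share of steps where the model beat the uniform baseline.
share_above = float((next_step_prob > uniform_prob).mean())

print(f"uniform baseline: {uniform_prob:.6f}")
print(f"share of steps above baseline: {share_above:.2f}")
```

Probabilities that look small in absolute terms can still be well above this baseline, since the window contains thousands of candidate pixels.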
Code
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
for ax, col, title in zip(
    axes,
    ["habitat_prob", "move_prob", "next_step_prob"],
    ["Habitat", "Movement", "Next-step"],
):
    data = val_results[col].dropna()
    ax.hist(data, bins=40, edgecolor="white")
    ax.set_title(title)
    ax.set_xlabel("Probability")
axes[0].set_ylabel("Count")
plt.suptitle("Next-step probability distributions (test set)", y=1.02)
plt.tight_layout()
plt.savefig(str(OUTPUT_DIR / "validation_probs.png"), dpi=100)
plt.show()

Simulate new data

Select the starting parameters for the simulation.

We’ve set the starting location to be the first location of the observed data.

Code
start_x = step_df["x1_"].iloc[0]
start_y = step_df["y1_"].iloc[0]
print(f"Starting location: ({start_x:.0f}, {start_y:.0f})")
Starting location: (41941, -1435875)

Simulate a trajectory

simulate_trajectory()

Generates a sequence of steps by iteratively:

  1. Extracting a spatial window centred on the current location.
  2. Running the model to obtain a joint log-probability surface over that window.
  3. Sampling the next pixel from that surface (softmax-normalised probabilities).
  4. Converting the sampled pixel back to projected coordinates and advancing time.
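The sampling and coordinate-conversion steps described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the package's implementation: `sample_next_location`, the 25 m cell size and the random log-probability surface are hypothetical, the window is assumed to be centred on the current location with an odd side length, and cells are assumed to be square and north-up.

```python
import numpy as np

def sample_next_location(log_probs, x, y, cell_size, rng):
    """Sample a pixel from a log-probability window and return map coordinates.

    log_probs : 2D array (rows, cols), window centred on the current location.
    x, y      : current projected coordinates.
    cell_size : pixel width/height in map units (square, north-up cells assumed).
    """
    # Softmax over the whole window (subtract the max for numerical stability).
    flat = log_probs.ravel()
    probs = np.exp(flat - flat.max())
    probs /= probs.sum()

    # Draw one pixel index and convert it back to (row, col).
    idx = rng.choice(flat.size, p=probs)
    row, col = np.unravel_index(idx, log_probs.shape)

    # Offset from the centre pixel; rows increase southwards in a north-up raster.
    centre_row, centre_col = log_probs.shape[0] // 2, log_probs.shape[1] // 2
    new_x = x + (col - centre_col) * cell_size
    new_y = y - (row - centre_row) * cell_size
    return new_x, new_y

rng = np.random.default_rng(42)
log_probs = rng.normal(size=(11, 11))      # hypothetical model output
x, y = 41941.0, -1435875.0                 # starting location from this walkthrough

# A real simulation would recompute log_probs from the model at every step;
# the same surface is reused here only to keep the sketch short.
for _ in range(3):
    x, y = sample_next_location(log_probs, x, y, cell_size=25.0, rng=rng)
    print(f"{x:.0f}, {y:.0f}")
```

Sampling from the full surface, rather than taking the most probable pixel, is what makes each simulated trajectory a stochastic realisation of the fitted model.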

  • Inputs:
    • model
      • the trained model
    • get_landscape
      • retrieves the landscape layers (spatial covariates) that were used to train the model
    • transform
      • raster Affine transform
    • start_x
      • x coordinate of starting location
    • start_y
      • y coordinate of starting location
    • n_steps
      • number of steps
    • starting_yday
      • starting day-of-year
    • starting_hour
      • starting hour-of-day
    • time_between_steps
      • time between steps in hours
    • window_size
      • size in pixels of the local landscape
    • month_index_fn
      • a function mapping day-of-year → month index (used by the get_landscape function). Here we are using static spatial layers so we can set it to 0.
  • Outputs: DataFrame with one row per simulated step containing:
    • x, y
      • projected coordinates of the simulated location
    • hour, yday
      • time variables updated at each step
    • month_index
      • the month index used to index the spatial variables (will be 0 if using static spatial covariates)
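For static layers a constant function is sufficient. If the covariates instead varied by month, `month_index_fn` would need to map day-of-year to a layer index. The sketch below shows one hypothetical way to do that. It assumes the layers are stacked January to December at indices 0 to 11 and that a non-leap reference year is acceptable, neither of which is specified by the package.

```python
from datetime import date, timedelta

def yday_to_month_index(yday, year=2019):
    """Map a (possibly fractional) day-of-year to a 0-based month index.

    Assumes day-of-year starts at 1 on 1 January and wraps past year end.
    """
    days_in_year = (date(year + 1, 1, 1) - date(year, 1, 1)).days
    day = (int(yday) - 1) % days_in_year          # 0-based day within the year
    return (date(year, 1, 1) + timedelta(days=day)).month - 1

print(yday_to_month_index(1))        # 1 January
print(yday_to_month_index(255.04))   # mid-September
print(yday_to_month_index(366))      # wraps to 1 January in a non-leap year
```

A function like this could then be passed as `month_index_fn=yday_to_month_index` in place of the constant lambda.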
Code
sim_df = simulate_trajectory(
    model,
    get_landscape=get_landscape,
    transform=raster_transform,
    start_x=start_x,
    start_y=start_y,
    n_steps=1000,
    starting_yday=test_sample["yday_t1"].iloc[0],
    starting_hour=test_sample["hour_t1"].iloc[0],
    time_between_steps=1.0,
    window_size=WINDOW_SIZE,
    month_index_fn=lambda _yday: 0,   # static layers
)
print(f"Simulated {len(sim_df)} steps")
sim_df.head()
Simulated 1000 steps
              x              y       hour        yday  month_index
0  41905.272779  -1.435885e+06  18.916667  255.000000            0
1  41147.764883  -1.435758e+06  19.916667  255.041667            0
2  40708.724292  -1.435091e+06  20.916667  255.083333            0
3  40889.072755  -1.434925e+06  21.916667  255.125000            0
4  40892.806893  -1.434897e+06  22.916667  255.166667            0
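In the output above, `hour` advances by `time_between_steps` and `yday` by `time_between_steps / 24` at each step. A minimal sketch of that bookkeeping follows. Wrapping the hour at 24 and the day-of-year at 365 is an assumption, since the rows shown do not cross either boundary, and `advance_time` is a hypothetical helper rather than a package function.

```python
def advance_time(hour, yday, time_between_steps=1.0, days_in_year=365):
    """Advance hour-of-day and fractional day-of-year by one step (in hours)."""
    new_hour = (hour + time_between_steps) % 24.0
    # Day-of-year runs from 1 up to (but not including) days_in_year + 1.
    new_yday = (yday + time_between_steps / 24.0 - 1.0) % days_in_year + 1.0
    return new_hour, new_yday

hour, yday = 18.916667, 255.0
for _ in range(6):
    hour, yday = advance_time(hour, yday)
    print(f"hour={hour:.6f}  yday={yday:.6f}")
```

Keeping `yday` fractional means a day-of-year function such as `month_index_fn` sees a smoothly increasing value rather than one that jumps only at midnight.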

Plot the simulated data

Code
fig, ax = plt.subplots(figsize=(7, 6))
rasterio.plot.show(env_layers["ndvi"], transform=raster_transform, ax=ax, cmap='viridis')
ax.plot(
    step_df["x1_"], step_df["y1_"],
    "w-o", alpha=0.5, markersize=1, linewidth=0.8, label="observed",
)
ax.plot(
    sim_df["x"], sim_df["y"],
    "r-o", markersize=1, linewidth=0.8, label="simulated",
)
ax.set_xlabel("Easting (m, projected)")
ax.set_ylabel("Northing (m, projected)")
ax.legend()
ax.set_title("Simulated vs observed buffalo locations")
plt.tight_layout()
plt.savefig(str(OUTPUT_DIR / "simulated_trajectory.png"), dpi=100)
plt.show()

References

Signer, Johannes, John Fieberg, and Tal Avgar. 2019. “Animal Movement Tools (Amt): R Package for Managing Tracking Data and Conducting Habitat Selection Analyses.” Ecology and Evolution 9 (2): 880–90. https://doi.org/10.1002/ece3.4823.