deepSSF — train, validate, simulate

End-to-end example using the bundled buffalo GPS dataset and two static raster covariates (NDVI and slope). The notebook calls the package throughout.

Sections:

1. Import data
2. Prepare data
3. Train model
4. Validate against withheld data
5. Simulate a trajectory

import math
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import rasterio

import deepssf
from deepssf import (
    ConvJointModel,
    EarlyStopping,
    ModelParams,
    create_gif,
    fit,
    get_device,
    load_environmental_layers,
    make_dataloaders,
    make_optimisers,
    negativeLogLikeLoss,
    filter_steps_by_window,
    prepare_movement_df,
    simulate_trajectory,
    validate_next_step_probs,
)

import pandas as pd
from datetime import datetime                # Date/time utilities
from IPython.display import Image            # For plotting GIFs
from rasterio.plot import show               # For plotting raster layers

# Record the library versions used for this run (one line per library).
print(
    f"deepssf {deepssf.__version__}",
    f"torch {torch.__version__}",
    sep="\n",
)
deepssf 0.1.0
torch 2.12.0
# Reproducibility: seed the NumPy and PyTorch global RNGs with the same value
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Compute device, as selected by the package helper
DEVICE = get_device()
print(f"Device: {DEVICE}")
Device: mps

1 Import data

# Bundled example data; the relative path is resolved against the current
# working directory when the files are opened.
DATA_DIR = Path("../src/deepssf/datasets/data")
CSV_PATH = DATA_DIR.joinpath("buffalo_djelk_id2005.csv")

# Raster covariates, keyed by layer name (paths stored as plain strings)
LAYER_PATHS = dict(
    ndvi=str(DATA_DIR.joinpath("ndvi_2005.tif")),
    slope=str(DATA_DIR.joinpath("slope_2005.tif")),
)

# Output locations used further down (model checkpoint, training snapshots).
# Only OUTPUT_DIR itself is created here.
OUTPUT_DIR = Path("outputs")
OUTPUT_DIR.mkdir(exist_ok=True)
SNAPSHOT_DIR = OUTPUT_DIR.joinpath("training_snapshots")
# Read the raw GPS fixes and report how much data the file holds
raw_df = pd.read_csv(CSV_PATH)
print(f"Raw GPS fixes : {len(raw_df):,}")
print(f"Individuals   : {raw_df['id'].nunique()}")

# Put the fixes in temporal order within each individual and renumber rows from 0.
# 'time' is sorted as read from the CSV (ISO-8601 UTC strings in the bundled file).
raw_df = raw_df.sort_values(["id", "time"], ignore_index=True)

# Optional: convert the 'time' column to local time (here Australia/Darwin)
# raw_df["time"] = (
#     pd.to_datetime(raw_df["time"], utc=True)
#     .dt.tz_convert("Australia/Darwin")
# )

raw_df.head()
Raw GPS fixes : 10,297
Individuals   : 1
id time x y
0 2005 2018-07-25T00:04:02Z 41941.331695 -1.435875e+06
1 2005 2018-07-25T01:04:23Z 41969.310875 -1.435671e+06
2 2005 2018-07-25T02:04:39Z 41921.521939 -1.435654e+06
3 2005 2018-07-25T03:04:17Z 41779.439594 -1.435601e+06
4 2005 2018-07-25T04:04:39Z 41841.203272 -1.435635e+06
# Load the raster covariates listed in LAYER_PATHS. The second return value is
# bound to raster_transform (presumably the rasters' georeferencing transform — confirm).
env_layers, raster_transform = load_environmental_layers(LAYER_PATHS)
print("Layers loaded:", list(env_layers.keys()))
# Summarise each layer: array shape and value range after loading
for name, arr in env_layers.items():
    print(f"  {name}: shape={arr.shape}, min={arr.min():.3f}, max={arr.max():.3f}")
# NOTE(review): in the recorded run the loader reports ndvi as "scaled to [0, 1]",
# yet the summary above prints max=1.090 for ndvi — confirm the scaling in the loader.
Layer 'ndvi': min=-0.06614641100168228, max=0.731696367263794 → scaled to [0, 1]
Layer 'slope': min=0.00013558653881773353, max=12.298083305358887 → scaled to [0, 1]
Layers loaded: ['ndvi', 'slope']
  ndvi: shape=(758, 1175), min=0.000, max=1.090
  slope: shape=(758, 1175), min=0.000, max=1.000

2 Prepare data

# Turn consecutive fixes into movement steps (start/end positions, displacement,
# bearing and time covariates — see the preview below).
step_df = prepare_movement_df(raw_df)
n_dropped = len(raw_df) - len(step_df)
print(f"Movement steps : {len(step_df):,}  ({n_dropped} dropped — last fix per individual)")
step_df.head()
Movement steps : 10,296  (1 dropped — last fix per individual)
id t1_ x1_ y1_ t2_ x2_ y2_ dx dy bearing bearing_tm1 dt_hour hour_t1 yday_t1 hour_t1_sin1 hour_t1_cos1 yday_t1_sin1 yday_t1_cos1
0 2005 2018-07-25T00:04:02Z 41941.331695 -1.435875e+06 2018-07-25T01:04:23Z 41969.310875 -1.435671e+06 27.979180 204.064135 1.434536 0.000000 1.005833 0.066667 206.0 0.017452 0.999848 -0.391358 -0.920239
1 2005 2018-07-25T01:04:23Z 41969.310875 -1.435671e+06 2018-07-25T02:04:39Z 41921.521939 -1.435654e+06 -47.788936 16.857110 2.802478 1.434536 1.004444 1.066667 206.0 0.275637 0.961262 -0.391358 -0.920239
2 2005 2018-07-25T02:04:39Z 41921.521939 -1.435654e+06 2018-07-25T03:04:17Z 41779.439594 -1.435601e+06 -142.082345 53.568427 2.781049 2.802478 0.993889 2.066667 206.0 0.515038 0.857167 -0.391358 -0.920239
3 2005 2018-07-25T03:04:17Z 41779.439594 -1.435601e+06 2018-07-25T04:04:39Z 41841.203272 -1.435635e+06 61.763677 -34.322938 -0.507220 2.781049 1.006111 3.066667 206.0 0.719340 0.694658 -0.391358 -0.920239
4 2005 2018-07-25T04:04:39Z 41841.203272 -1.435635e+06 2018-07-25T05:04:27Z 41655.463332 -1.435604e+06 -185.739939 31.003534 2.976198 -0.507220 0.996667 4.066667 206.0 0.874620 0.484810 -0.391358 -0.920239
# Model configuration
# Five scalar covariates are broadcast into spatial maps: sine/cosine terms for
# hour and day of year, plus dt_hour.
# NOTE(review): an earlier comment here said dt_hour was left out so that
# simulate() matches, but the list below includes it — confirm which is intended.
SCALAR_COLS = ["hour_t1_sin1", "hour_t1_cos1", "yday_t1_sin1", "yday_t1_cos1", "dt_hour"]
WINDOW_SIZE = 101   # side length of the local window, in pixels
PIXEL_SIZE = 25     # raster resolution in metres (CRS units)
BATCH_SIZE = 32

# Flattened size of the movement CNN output after 3× (conv → maxpool).
# The literals mirror the conv/pool settings passed to ModelParams:
# kernel 3, stride 1, padding 1, then maxpool kernel 2, stride 2.
OUTPUT_CHANNELS = 4
dim = WINDOW_SIZE
for _ in range(3):
    dim = (dim + 2 * 1 - 3) // 1 + 1   # conv: padding 1 with a 3×3 kernel keeps dim
    dim = (dim - 2) // 2 + 1           # maxpool: halves dim, rounding down
DENSE_DIM = OUTPUT_CHANNELS * dim ** 2
print(f"After 3× conv+maxpool: dim={dim}  →  dense_dim_in_all={DENSE_DIM}")

# Drop steps whose displacement exceeds the window; would cause an index error in the loss
# (in the recorded run this keeps 10,101 of the 10,296 steps).
step_df = filter_steps_by_window(step_df, window_size=WINDOW_SIZE, pixel_size=PIXEL_SIZE)
print(f"Steps after window filter: {len(step_df):,}")
After 3× conv+maxpool: dim=12  →  dense_dim_in_all=576
Steps after window filter: 10,101
# Build the train / validation / test loaders from the window-filtered steps.
# 80% of the steps go to training and 10% to validation; the remainder is
# presumably the test split (confirm against make_dataloaders).
dl_train, dl_val, dl_test = make_dataloaders(
    df=step_df,
    layer_paths=LAYER_PATHS,
    scalar_cols=SCALAR_COLS,
    window_size=WINDOW_SIZE,
    batch_size=BATCH_SIZE,
    train_split=0.8,
    val_split=0.1,
)
print(
    f"Train batches : {len(dl_train)}",
    f"Val   batches : {len(dl_val)}",
    f"Test  batches : {len(dl_test)}",
    sep="\n",
)
Layer 'ndvi': min=-0.06614641100168228, max=0.731696367263794 → scaled to [0, 1]
Layer 'slope': min=0.00013558653881773353, max=12.298083305358887 → scaled to [0, 1]
Train batches : 253
Val   batches : 32
Test  batches : 32
/Users/scottforrest/Library/Mobile Documents/com~apple~CloudDocs/github_repos/deepSSF_package/src/deepssf/data.py:358: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /Users/runner/work/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:219.)
  self.scalar_to_grid_data = torch.from_numpy(
# Inspect one batch
# Each batch unpacks into five items: spatial inputs, scalar covariates, bearing,
# a (px2, py2) pair for the next step, and a fifth item that is ignored here.
# NOTE(review): px2/py2 look like pixel indices of the next step within the
# window (values near 50 for a 101-pixel window) — confirm against the dataset class.
x1, x2, x3, (px2, py2), _ = next(iter(dl_train))
print(f"Spatial   x1 : {x1.shape}   (batch × channels × H × W)")
print(f"Scalars   x2 : {x2.shape}")
print(f"Bearing   x3 : {x3.shape}")
print(f"Next-step px2: {px2[:4].tolist()}  py2: {py2[:4].tolist()}")
Spatial   x1 : torch.Size([32, 2, 101, 101])   (batch × channels × H × W)
Scalars   x2 : torch.Size([32, 5])
Bearing   x3 : torch.Size([32, 1])
Next-step px2: [54, 51, 50, 50]  py2: [51, 53, 50, 50]

3 Train model

# Hyperparameters for ConvJointModel. Sizes that depend on the data
# (scalar count, window size, channel count) are derived from the configuration
# above rather than typed in, so they cannot drift out of sync.
params = ModelParams({
    "batch_size":                BATCH_SIZE,
    "image_dim":                 WINDOW_SIZE,
    "pixel_size":                PIXEL_SIZE,
    "dim_in_nonspatial_to_grid": len(SCALAR_COLS),
    "dense_dim_in_nonspatial":   len(SCALAR_COLS),
    "dense_dim_hidden":          64,
    "dense_dim_in_all":          DENSE_DIM,
    # one channel per raster layer + one per scalar broadcast to a grid
    "input_channels":            len(LAYER_PATHS) + len(SCALAR_COLS),
    "output_channels":           OUTPUT_CHANNELS,
    # conv / maxpool settings: these are the values assumed by the DENSE_DIM
    # calculation earlier in the notebook, so change both together
    "kernel_size":               3,
    "stride":                    1,
    "kernel_size_mp":            2,
    "stride_mp":                 2,
    "padding":                   1,
    "num_movement_params":       12,
    "dropout":                   0.0,
    "device":                    DEVICE,
})

# Instantiate the network on the selected device and report its size
model = ConvJointModel(params).to(DEVICE)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters : {n_params:,}")
Model parameters : 43,009
# Training setup: loss, optimisers/schedulers, and early stopping.
# The loss is averaged over the batch (reduction="mean").
loss_fn     = negativeLogLikeLoss(reduction="mean")
# Separate learning rates are passed for the habitat and movement parts of the model.
optimisers, schedulers = make_optimisers(
    model, lr_habitat=1e-4, lr_movement=1e-5, scheduler_patience=2
)
# Early stopping with patience=5; the checkpoint path is outputs/best_model.pt.
early_stop  = EarlyStopping(
    patience=5, verbose=True, path=str(OUTPUT_DIR / "best_model.pt")
)
# Train for up to 50 epochs; early_stop is passed in so training can end sooner.
# Snapshots are directed to SNAPSHOT_DIR.
# NOTE(review): the meaning of snapshot_item=500 is not visible here — presumably
# the index of the sample used for snapshots; confirm against fit().
history = fit(
    model=model,
    n_conv_layers=4, # how many convolutional layers in the habitat selection process (to filter edge artifacts that affect colour scaling)
    window_size=WINDOW_SIZE,
    dl_train=dl_train,
    dl_val=dl_val,
    loss_fn=loss_fn,
    optimisers=optimisers,
    schedulers=schedulers,
    n_epochs=50,
    early_stopping=early_stop,
    snapshot_dir=str(SNAPSHOT_DIR),
    snapshot_item=500,
)

Epoch 1/50
loss:        6.785439  [   32/ 8081]
loss:        6.097230  [  352/ 8081]
loss:        5.551458  [  672/ 8081]
loss:        5.073307  [  992/ 8081]
loss:        6.436917  [ 1312/ 8081]
loss:        6.255669  [ 1632/ 8081]
loss:        5.437774  [ 1952/ 8081]
loss:        5.842318  [ 2272/ 8081]
loss:        6.056246  [ 2592/ 8081]
loss:        5.573969  [ 2912/ 8081]
loss:        5.683355  [ 3232/ 8081]
loss:        5.886904  [ 3552/ 8081]
loss:        5.906250  [ 3872/ 8081]
loss:        4.869259  [ 4192/ 8081]
loss:        6.376064  [ 4512/ 8081]
loss:        5.547744  [ 4832/ 8081]
loss:        6.104920  [ 5152/ 8081]
loss:        5.175964  [ 5472/ 8081]
loss:        6.325367  [ 5792/ 8081]
loss:        5.473641  [ 6112/ 8081]
loss:        4.946332  [ 6432/ 8081]
loss:        5.967487  [ 6752/ 8081]
loss:        5.914863  [ 7072/ 8081]
loss:        5.850795  [ 7392/ 8081]
loss:        5.686168  [ 7712/ 8081]
loss:        5.253773  [ 8032/ 8081]

Avg training loss:        5.788497
Val loss: 5.687180  (hab: 9.166112, mov: 5.704777)
Validation loss decreased (5.708087 → 5.687180). Saving model…

Epoch 2/50
loss:        5.694406  [   32/ 8081]
loss:        6.153904  [  352/ 8081]
loss:        6.037060  [  672/ 8081]
loss:        6.598245  [  992/ 8081]
loss:        4.828341  [ 1312/ 8081]
loss:        5.693692  [ 1632/ 8081]
loss:        4.467628  [ 1952/ 8081]
loss:        6.476704  [ 2272/ 8081]
loss:        5.603591  [ 2592/ 8081]
loss:        5.886862  [ 2912/ 8081]
loss:        5.528289  [ 3232/ 8081]
loss:        5.574023  [ 3552/ 8081]
loss:        6.055640  [ 3872/ 8081]
loss:        6.516983  [ 4192/ 8081]
loss:        6.326391  [ 4512/ 8081]
loss:        4.843784  [ 4832/ 8081]
loss:        6.195689  [ 5152/ 8081]
loss:        5.668516  [ 5472/ 8081]
loss:        5.597651  [ 5792/ 8081]
loss:        4.522491  [ 6112/ 8081]
loss:        6.149788  [ 6432/ 8081]
loss:        5.482637  [ 6752/ 8081]
loss:        6.251376  [ 7072/ 8081]
loss:        4.979764  [ 7392/ 8081]
loss:        5.590664  [ 7712/ 8081]
loss:        6.732779  [ 8032/ 8081]

Avg training loss:        5.780206
Val loss: 5.680278  (hab: 9.167198, mov: 5.698563)
Validation loss decreased (5.687180 → 5.680278). Saving model…

Epoch 3/50
loss:        5.771159  [   32/ 8081]
loss:        4.998755  [  352/ 8081]
loss:        6.109458  [  672/ 8081]
loss:        6.072502  [  992/ 8081]
loss:        5.710176  [ 1312/ 8081]
loss:        5.001541  [ 1632/ 8081]
loss:        4.916896  [ 1952/ 8081]
loss:        5.842347  [ 2272/ 8081]
loss:        6.470130  [ 2592/ 8081]
loss:        5.437813  [ 2912/ 8081]
loss:        5.558701  [ 3232/ 8081]
loss:        5.935304  [ 3552/ 8081]
loss:        5.704793  [ 3872/ 8081]
loss:        7.646010  [ 4192/ 8081]
loss:        5.135283  [ 4512/ 8081]
loss:        5.725735  [ 4832/ 8081]
loss:        5.999853  [ 5152/ 8081]
loss:        5.170563  [ 5472/ 8081]
loss:        6.050817  [ 5792/ 8081]
loss:        5.566458  [ 6112/ 8081]
loss:        6.288377  [ 6432/ 8081]
loss:        5.568368  [ 6752/ 8081]
loss:        6.243468  [ 7072/ 8081]
loss:        5.513037  [ 7392/ 8081]
loss:        5.912124  [ 7712/ 8081]
loss:        5.924772  [ 8032/ 8081]

Avg training loss:        5.773384
Val loss: 5.673004  (hab: 9.161399, mov: 5.693920)
Validation loss decreased (5.680278 → 5.673004). Saving model…

Epoch 4/50
loss:        5.164478  [   32/ 8081]
loss:        5.978908  [  352/ 8081]
loss:        6.380840  [  672/ 8081]
loss:        5.990369  [  992/ 8081]
loss:        4.989443  [ 1312/ 8081]
loss:        6.138777  [ 1632/ 8081]
loss:        5.724806  [ 1952/ 8081]
loss:        5.964715  [ 2272/ 8081]
loss:        5.613798  [ 2592/ 8081]
loss:        5.038851  [ 2912/ 8081]
loss:        5.382350  [ 3232/ 8081]
loss:        6.172091  [ 3552/ 8081]
loss:        6.201478  [ 3872/ 8081]
loss:        6.695495  [ 4192/ 8081]
loss:        6.798589  [ 4512/ 8081]
loss:        6.260759  [ 4832/ 8081]
loss:        5.514187  [ 5152/ 8081]
loss:        5.574938  [ 5472/ 8081]
loss:        5.461888  [ 5792/ 8081]
loss:        5.341826  [ 6112/ 8081]
loss:        5.825900  [ 6432/ 8081]
loss:        5.031033  [ 6752/ 8081]
loss:        5.472840  [ 7072/ 8081]
loss:        6.054863  [ 7392/ 8081]
loss:        5.984641  [ 7712/ 8081]
loss:        5.449839  [ 8032/ 8081]

Avg training loss:        5.771757
Val loss: 5.670019  (hab: 9.161964, mov: 5.690092)
Validation loss decreased (5.673004 → 5.670019). Saving model…

Epoch 5/50
loss:        5.468553  [   32/ 8081]
loss:        5.154895  [  352/ 8081]
loss:        5.961631  [  672/ 8081]
loss:        5.888359  [  992/ 8081]
loss:        6.175714  [ 1312/ 8081]
loss:        6.281458  [ 1632/ 8081]
loss:        6.169101  [ 1952/ 8081]
loss:        6.215828  [ 2272/ 8081]
loss:        5.834133  [ 2592/ 8081]
loss:        5.974036  [ 2912/ 8081]
loss:        5.318532  [ 3232/ 8081]
loss:        5.839583  [ 3552/ 8081]
loss:        5.984657  [ 3872/ 8081]
loss:        6.032445  [ 4192/ 8081]
loss:        5.217687  [ 4512/ 8081]
loss:        5.736081  [ 4832/ 8081]
loss:        6.185085  [ 5152/ 8081]
loss:        6.298941  [ 5472/ 8081]
loss:        5.976015  [ 5792/ 8081]
loss:        6.169477  [ 6112/ 8081]
loss:        6.504289  [ 6432/ 8081]
loss:        5.326622  [ 6752/ 8081]
loss:        6.042302  [ 7072/ 8081]
loss:        5.940510  [ 7392/ 8081]
loss:        5.981143  [ 7712/ 8081]
loss:        6.376686  [ 8032/ 8081]

Avg training loss:        5.765343
Val loss: 5.667132  (hab: 9.160795, mov: 5.687070)
Validation loss decreased (5.670019 → 5.667132). Saving model…

Epoch 6/50
loss:        5.989162  [   32/ 8081]
loss:        5.185757  [  352/ 8081]
loss:        6.094879  [  672/ 8081]
loss:        6.327469  [  992/ 8081]
loss:        6.278698  [ 1312/ 8081]
loss:        7.185611  [ 1632/ 8081]
loss:        6.149991  [ 1952/ 8081]
loss:        6.799704  [ 2272/ 8081]
loss:        5.566303  [ 2592/ 8081]
loss:        5.896247  [ 2912/ 8081]
loss:        6.005388  [ 3232/ 8081]
loss:        5.651654  [ 3552/ 8081]
loss:        6.483175  [ 3872/ 8081]
loss:        5.532136  [ 4192/ 8081]
loss:        6.001493  [ 4512/ 8081]
loss:        6.396995  [ 4832/ 8081]
loss:        6.277179  [ 5152/ 8081]
loss:        5.335245  [ 5472/ 8081]
loss:        5.850138  [ 5792/ 8081]
loss:        5.985559  [ 6112/ 8081]
loss:        5.363245  [ 6432/ 8081]
loss:        5.474362  [ 6752/ 8081]
loss:        6.610009  [ 7072/ 8081]
loss:        4.994096  [ 7392/ 8081]
loss:        6.311173  [ 7712/ 8081]
loss:        5.433157  [ 8032/ 8081]

Avg training loss:        5.762437
Val loss: 5.664561  (hab: 9.163505, mov: 5.684692)
Validation loss decreased (5.667132 → 5.664561). Saving model…

Epoch 7/50
loss:        5.466299  [   32/ 8081]
loss:        5.562140  [  352/ 8081]
loss:        5.990706  [  672/ 8081]
loss:        6.745186  [  992/ 8081]
loss:        5.897082  [ 1312/ 8081]
loss:        6.342498  [ 1632/ 8081]
loss:        6.487095  [ 1952/ 8081]
loss:        5.308664  [ 2272/ 8081]
loss:        6.215004  [ 2592/ 8081]
loss:        5.250403  [ 2912/ 8081]
loss:        5.854526  [ 3232/ 8081]
loss:        5.654216  [ 3552/ 8081]
loss:        5.539589  [ 3872/ 8081]
loss:        5.537636  [ 4192/ 8081]
loss:        5.496979  [ 4512/ 8081]
loss:        5.635530  [ 4832/ 8081]
loss:        4.809999  [ 5152/ 8081]
loss:        5.620271  [ 5472/ 8081]
loss:        6.103968  [ 5792/ 8081]
loss:        6.340637  [ 6112/ 8081]
loss:        6.328131  [ 6432/ 8081]
loss:        5.185500  [ 6752/ 8081]
loss:        5.044396  [ 7072/ 8081]
loss:        6.326553  [ 7392/ 8081]
loss:        5.725074  [ 7712/ 8081]
loss:        5.338499  [ 8032/ 8081]

Avg training loss:        5.757804
Val loss: 5.663774  (hab: 9.160690, mov: 5.682512)
Validation loss decreased (5.664561 → 5.663774). Saving model…

Epoch 8/50
loss:        5.153199  [   32/ 8081]
loss:        6.033429  [  352/ 8081]
loss:        6.188162  [  672/ 8081]
loss:        5.682156  [  992/ 8081]
loss:        5.402650  [ 1312/ 8081]
loss:        6.339832  [ 1632/ 8081]
loss:        5.566161  [ 1952/ 8081]
loss:        5.892251  [ 2272/ 8081]
loss:        4.953355  [ 2592/ 8081]
loss:        5.078116  [ 2912/ 8081]
loss:        5.694687  [ 3232/ 8081]
loss:        6.260427  [ 3552/ 8081]
loss:        5.378967  [ 3872/ 8081]
loss:        5.724102  [ 4192/ 8081]
loss:        6.579778  [ 4512/ 8081]
loss:        6.327852  [ 4832/ 8081]
loss:        5.288838  [ 5152/ 8081]
loss:        5.571650  [ 5472/ 8081]
loss:        7.053420  [ 5792/ 8081]
loss:        6.272779  [ 6112/ 8081]
loss:        5.462452  [ 6432/ 8081]
loss:        5.919233  [ 6752/ 8081]
loss:        6.486927  [ 7072/ 8081]
loss:        5.169839  [ 7392/ 8081]
loss:        5.864375  [ 7712/ 8081]
loss:        6.144361  [ 8032/ 8081]

Avg training loss:        5.756941
Val loss: 5.659359  (hab: 9.160813, mov: 5.680856)
Validation loss decreased (5.663774 → 5.659359). Saving model…

Epoch 9/50
loss:        5.015366  [   32/ 8081]
loss:        5.890463  [  352/ 8081]
loss:        5.667079  [  672/ 8081]
loss:        6.308356  [  992/ 8081]
loss:        6.460155  [ 1312/ 8081]
loss:        5.740475  [ 1632/ 8081]
loss:        6.436217  [ 1952/ 8081]
loss:        7.255536  [ 2272/ 8081]
loss:        5.891561  [ 2592/ 8081]
loss:        5.757444  [ 2912/ 8081]
loss:        6.140157  [ 3232/ 8081]
loss:        5.553885  [ 3552/ 8081]
loss:        5.180971  [ 3872/ 8081]
loss:        5.593443  [ 4192/ 8081]
loss:        5.231184  [ 4512/ 8081]
loss:        5.650258  [ 4832/ 8081]
loss:        5.901844  [ 5152/ 8081]
loss:        5.418355  [ 5472/ 8081]
loss:        6.061893  [ 5792/ 8081]
loss:        5.725675  [ 6112/ 8081]
loss:        5.522688  [ 6432/ 8081]
loss:        5.422935  [ 6752/ 8081]
loss:        5.901719  [ 7072/ 8081]
loss:        6.634830  [ 7392/ 8081]
loss:        6.179843  [ 7712/ 8081]
loss:        6.570165  [ 8032/ 8081]

Avg training loss:        5.753735
Val loss: 5.657118  (hab: 9.155861, mov: 5.679439)
Validation loss decreased (5.659359 → 5.657118). Saving model…

Epoch 10/50
loss:        5.979346  [   32/ 8081]
loss:        6.341550  [  352/ 8081]
loss:        6.134659  [  672/ 8081]
loss:        5.003821  [  992/ 8081]
loss:        5.825746  [ 1312/ 8081]
loss:        6.957564  [ 1632/ 8081]
loss:        5.609220  [ 1952/ 8081]
loss:        5.420880  [ 2272/ 8081]
loss:        5.825545  [ 2592/ 8081]
loss:        6.414145  [ 2912/ 8081]
loss:        4.970078  [ 3232/ 8081]
loss:        4.309336  [ 3552/ 8081]
loss:        6.106659  [ 3872/ 8081]
loss:        6.817013  [ 4192/ 8081]
loss:        5.473575  [ 4512/ 8081]
loss:        5.892172  [ 4832/ 8081]
loss:        6.330991  [ 5152/ 8081]
loss:        4.686843  [ 5472/ 8081]
loss:        6.179124  [ 5792/ 8081]
loss:        5.744292  [ 6112/ 8081]
loss:        5.377276  [ 6432/ 8081]
loss:        5.348831  [ 6752/ 8081]
loss:        5.073267  [ 7072/ 8081]
loss:        4.481338  [ 7392/ 8081]
loss:        5.835503  [ 7712/ 8081]
loss:        4.866236  [ 8032/ 8081]

Avg training loss:        5.750079
Val loss: 5.655317  (hab: 9.154174, mov: 5.678362)
Validation loss decreased (5.657118 → 5.655317). Saving model…

Epoch 11/50
loss:        5.954041  [   32/ 8081]
loss:        6.094968  [  352/ 8081]
loss:        4.979434  [  672/ 8081]
loss:        5.557578  [  992/ 8081]
loss:        6.135086  [ 1312/ 8081]
loss:        6.096596  [ 1632/ 8081]
loss:        5.810320  [ 1952/ 8081]
loss:        5.773321  [ 2272/ 8081]
loss:        5.708895  [ 2592/ 8081]
loss:        5.832764  [ 2912/ 8081]
loss:        5.584544  [ 3232/ 8081]
loss:        5.994350  [ 3552/ 8081]
loss:        6.183756  [ 3872/ 8081]
loss:        7.011960  [ 4192/ 8081]
loss:        5.242159  [ 4512/ 8081]
loss:        4.903460  [ 4832/ 8081]
loss:        5.612700  [ 5152/ 8081]
loss:        6.321691  [ 5472/ 8081]
loss:        6.183264  [ 5792/ 8081]
loss:        6.033978  [ 6112/ 8081]
loss:        5.172969  [ 6432/ 8081]
loss:        6.431845  [ 6752/ 8081]
loss:        5.793088  [ 7072/ 8081]
loss:        4.720260  [ 7392/ 8081]
loss:        4.866962  [ 7712/ 8081]
loss:        6.244982  [ 8032/ 8081]

Avg training loss:        5.750686
Val loss: 5.653751  (hab: 9.154543, mov: 5.677163)
Validation loss decreased (5.655317 → 5.653751). Saving model…

Epoch 12/50
loss:        6.128786  [   32/ 8081]
loss:        6.231214  [  352/ 8081]
loss:        5.599900  [  672/ 8081]
loss:        5.946187  [  992/ 8081]
loss:        6.217439  [ 1312/ 8081]
loss:        5.561559  [ 1632/ 8081]
loss:        6.017690  [ 1952/ 8081]
loss:        6.422792  [ 2272/ 8081]
loss:        5.169445  [ 2592/ 8081]
loss:        5.734517  [ 2912/ 8081]
loss:        6.159955  [ 3232/ 8081]
loss:        6.048131  [ 3552/ 8081]
loss:        5.095168  [ 3872/ 8081]
loss:        5.633101  [ 4192/ 8081]
loss:        5.936021  [ 4512/ 8081]
loss:        4.978602  [ 4832/ 8081]
loss:        6.078108  [ 5152/ 8081]
loss:        6.416469  [ 5472/ 8081]
loss:        6.084074  [ 5792/ 8081]
loss:        7.058030  [ 6112/ 8081]
loss:        6.218597  [ 6432/ 8081]
loss:        5.706355  [ 6752/ 8081]
loss:        6.205107  [ 7072/ 8081]
loss:        6.251582  [ 7392/ 8081]
loss:        6.318515  [ 7712/ 8081]
loss:        5.446843  [ 8032/ 8081]

Avg training loss:        5.747251
Val loss: 5.651318  (hab: 9.156074, mov: 5.676356)
Validation loss decreased (5.653751 → 5.651318). Saving model…

Epoch 13/50
loss:        5.735549  [   32/ 8081]
loss:        5.167077  [  352/ 8081]
loss:        5.022010  [  672/ 8081]
loss:        6.689678  [  992/ 8081]
loss:        5.273829  [ 1312/ 8081]
loss:        5.824816  [ 1632/ 8081]
loss:        5.750481  [ 1952/ 8081]
loss:        5.760807  [ 2272/ 8081]
loss:        4.784536  [ 2592/ 8081]
loss:        7.326108  [ 2912/ 8081]
loss:        6.325606  [ 3232/ 8081]
loss:        5.315575  [ 3552/ 8081]
loss:        5.708946  [ 3872/ 8081]
loss:        6.236599  [ 4192/ 8081]
loss:        6.176717  [ 4512/ 8081]
loss:        6.667757  [ 4832/ 8081]
loss:        4.588668  [ 5152/ 8081]
loss:        5.248083  [ 5472/ 8081]
loss:        6.057298  [ 5792/ 8081]
loss:        5.856156  [ 6112/ 8081]
loss:        6.274911  [ 6432/ 8081]
loss:        6.290979  [ 6752/ 8081]
loss:        5.460048  [ 7072/ 8081]
loss:        4.758325  [ 7392/ 8081]
loss:        7.212662  [ 7712/ 8081]
loss:        6.885795  [ 8032/ 8081]

Avg training loss:        5.746758
Val loss: 5.648002  (hab: 9.150894, mov: 5.675554)
Validation loss decreased (5.651318 → 5.648002). Saving model…

Epoch 14/50
loss:        5.946347  [   32/ 8081]
loss:        5.041623  [  352/ 8081]
loss:        5.873816  [  672/ 8081]
loss:        5.291799  [  992/ 8081]
loss:        5.365646  [ 1312/ 8081]
loss:        5.800356  [ 1632/ 8081]
loss:        5.706845  [ 1952/ 8081]
loss:        6.113334  [ 2272/ 8081]
loss:        5.965096  [ 2592/ 8081]
loss:        6.214636  [ 2912/ 8081]
loss:        6.481359  [ 3232/ 8081]
loss:        4.853850  [ 3552/ 8081]
loss:        5.281315  [ 3872/ 8081]
loss:        5.762989  [ 4192/ 8081]
loss:        6.134844  [ 4512/ 8081]
loss:        6.813313  [ 4832/ 8081]
loss:        6.060571  [ 5152/ 8081]
loss:        7.026179  [ 5472/ 8081]
loss:        5.946495  [ 5792/ 8081]
loss:        5.180030  [ 6112/ 8081]
loss:        6.355926  [ 6432/ 8081]
loss:        5.490451  [ 6752/ 8081]
loss:        5.836087  [ 7072/ 8081]
loss:        6.027002  [ 7392/ 8081]
loss:        5.614779  [ 7712/ 8081]
loss:        6.035224  [ 8032/ 8081]

Avg training loss:        5.742599
Val loss: 5.648140  (hab: 9.153320, mov: 5.674904)
EarlyStopping counter: 1 out of 5

Epoch 15/50
loss:        5.597585  [   32/ 8081]
loss:        4.928980  [  352/ 8081]
loss:        5.134027  [  672/ 8081]
loss:        5.157074  [  992/ 8081]
loss:        5.909005  [ 1312/ 8081]
loss:        5.560033  [ 1632/ 8081]
loss:        6.205202  [ 1952/ 8081]
loss:        6.106179  [ 2272/ 8081]
loss:        7.436422  [ 2592/ 8081]
loss:        5.639592  [ 2912/ 8081]
loss:        4.822022  [ 3232/ 8081]
loss:        4.106799  [ 3552/ 8081]
loss:        6.259495  [ 3872/ 8081]
loss:        7.228454  [ 4192/ 8081]
loss:        5.105317  [ 4512/ 8081]
loss:        5.871928  [ 4832/ 8081]
loss:        5.165004  [ 5152/ 8081]
loss:        6.482831  [ 5472/ 8081]
loss:        4.902627  [ 5792/ 8081]
loss:        6.007688  [ 6112/ 8081]
loss:        4.781419  [ 6432/ 8081]
loss:        5.254600  [ 6752/ 8081]
loss:        6.425246  [ 7072/ 8081]
loss:        6.281509  [ 7392/ 8081]
loss:        5.598694  [ 7712/ 8081]
loss:        6.053635  [ 8032/ 8081]

Avg training loss:        5.743073
Val loss: 5.645316  (hab: 9.148739, mov: 5.674267)
Validation loss decreased (5.648002 → 5.645316). Saving model…

Epoch 16/50
loss:        6.885323  [   32/ 8081]
loss:        5.594497  [  352/ 8081]
loss:        5.678989  [  672/ 8081]
loss:        6.227365  [  992/ 8081]
loss:        5.790132  [ 1312/ 8081]
loss:        5.360517  [ 1632/ 8081]
loss:        5.787529  [ 1952/ 8081]
loss:        6.439590  [ 2272/ 8081]
loss:        5.170119  [ 2592/ 8081]
loss:        5.598507  [ 2912/ 8081]
loss:        5.916703  [ 3232/ 8081]
loss:        5.602601  [ 3552/ 8081]
loss:        4.940242  [ 3872/ 8081]
loss:        5.859461  [ 4192/ 8081]
loss:        5.760822  [ 4512/ 8081]
loss:        5.442616  [ 4832/ 8081]
loss:        5.790349  [ 5152/ 8081]
loss:        5.874994  [ 5472/ 8081]
loss:        5.619014  [ 5792/ 8081]
loss:        5.000046  [ 6112/ 8081]
loss:        5.243309  [ 6432/ 8081]
loss:        6.288419  [ 6752/ 8081]
loss:        6.087790  [ 7072/ 8081]
loss:        6.147535  [ 7392/ 8081]
loss:        6.012441  [ 7712/ 8081]
loss:        6.748470  [ 8032/ 8081]

Avg training loss:        5.742579
Val loss: 5.644631  (hab: 9.147775, mov: 5.673745)
Validation loss decreased (5.645316 → 5.644631). Saving model…

Epoch 17/50
loss:        6.283003  [   32/ 8081]
loss:        5.595797  [  352/ 8081]
loss:        5.715901  [  672/ 8081]
loss:        5.501903  [  992/ 8081]
loss:        5.439520  [ 1312/ 8081]
loss:        4.297513  [ 1632/ 8081]
loss:        5.481026  [ 1952/ 8081]
loss:        5.943962  [ 2272/ 8081]
loss:        5.792973  [ 2592/ 8081]
loss:        4.885669  [ 2912/ 8081]
loss:        5.561083  [ 3232/ 8081]
loss:        5.883833  [ 3552/ 8081]
loss:        6.213400  [ 3872/ 8081]
loss:        5.019048  [ 4192/ 8081]
loss:        5.765052  [ 4512/ 8081]
loss:        4.955160  [ 4832/ 8081]
loss:        5.329431  [ 5152/ 8081]
loss:        5.798806  [ 5472/ 8081]
loss:        5.645044  [ 5792/ 8081]
loss:        6.268931  [ 6112/ 8081]
loss:        5.830894  [ 6432/ 8081]
loss:        6.479350  [ 6752/ 8081]
loss:        5.524815  [ 7072/ 8081]
loss:        4.875284  [ 7392/ 8081]
loss:        6.527176  [ 7712/ 8081]
loss:        6.552918  [ 8032/ 8081]

Avg training loss:        5.741385
Val loss: 5.643943  (hab: 9.149630, mov: 5.673379)
Validation loss decreased (5.644631 → 5.643943). Saving model…

Epoch 18/50
loss:        5.646639  [   32/ 8081]
loss:        6.221776  [  352/ 8081]
loss:        5.486329  [  672/ 8081]
loss:        6.544412  [  992/ 8081]
loss:        5.341964  [ 1312/ 8081]
loss:        5.519305  [ 1632/ 8081]
loss:        5.982452  [ 1952/ 8081]
loss:        6.139508  [ 2272/ 8081]
loss:        5.064729  [ 2592/ 8081]
loss:        6.249910  [ 2912/ 8081]
loss:        5.974421  [ 3232/ 8081]
loss:        5.339847  [ 3552/ 8081]
loss:        6.463853  [ 3872/ 8081]
loss:        6.058188  [ 4192/ 8081]
loss:        5.772356  [ 4512/ 8081]
loss:        6.047544  [ 4832/ 8081]
loss:        6.791408  [ 5152/ 8081]
loss:        4.642981  [ 5472/ 8081]
loss:        5.114915  [ 5792/ 8081]
loss:        6.150812  [ 6112/ 8081]
loss:        6.051000  [ 6432/ 8081]
loss:        6.648256  [ 6752/ 8081]
loss:        6.250250  [ 7072/ 8081]
loss:        5.238799  [ 7392/ 8081]
loss:        5.869842  [ 7712/ 8081]
loss:        6.773789  [ 8032/ 8081]

Avg training loss:        5.740270
Val loss: 5.644076  (hab: 9.149081, mov: 5.672949)
EarlyStopping counter: 1 out of 5

Epoch 19/50
loss:        5.175504  [   32/ 8081]
loss:        6.612113  [  352/ 8081]
loss:        6.259410  [  672/ 8081]
loss:        6.931705  [  992/ 8081]
loss:        6.386226  [ 1312/ 8081]
loss:        5.127598  [ 1632/ 8081]
loss:        5.744873  [ 1952/ 8081]
loss:        5.399240  [ 2272/ 8081]
loss:        6.128780  [ 2592/ 8081]
loss:        4.790784  [ 2912/ 8081]
loss:        6.567341  [ 3232/ 8081]
loss:        4.790431  [ 3552/ 8081]
loss:        5.470789  [ 3872/ 8081]
loss:        5.554800  [ 4192/ 8081]
loss:        5.679316  [ 4512/ 8081]
loss:        5.368668  [ 4832/ 8081]
loss:        5.470472  [ 5152/ 8081]
loss:        5.754320  [ 5472/ 8081]
loss:        5.034759  [ 5792/ 8081]
loss:        6.242113  [ 6112/ 8081]
loss:        5.507895  [ 6432/ 8081]
loss:        5.377802  [ 6752/ 8081]
loss:        5.432044  [ 7072/ 8081]
loss:        7.036860  [ 7392/ 8081]
loss:        5.353534  [ 7712/ 8081]
loss:        5.177495  [ 8032/ 8081]

Avg training loss:        5.737458
Val loss: 5.644072  (hab: 9.149584, mov: 5.672635)
EarlyStopping counter: 2 out of 5

Epoch 20/50
loss:        5.072841  [   32/ 8081]
loss:        5.977276  [  352/ 8081]
loss:        6.188596  [  672/ 8081]
loss:        5.472806  [  992/ 8081]
loss:        5.054672  [ 1312/ 8081]
loss:        6.003091  [ 1632/ 8081]
loss:        5.515766  [ 1952/ 8081]
loss:        7.272111  [ 2272/ 8081]
loss:        5.377561  [ 2592/ 8081]
loss:        5.953638  [ 2912/ 8081]
loss:        5.927915  [ 3232/ 8081]
loss:        6.455359  [ 3552/ 8081]
loss:        6.164033  [ 3872/ 8081]
loss:        6.237619  [ 4192/ 8081]
loss:        6.076998  [ 4512/ 8081]
loss:        6.516166  [ 4832/ 8081]
loss:        5.407486  [ 5152/ 8081]
loss:        4.519997  [ 5472/ 8081]
loss:        6.256764  [ 5792/ 8081]
loss:        4.865379  [ 6112/ 8081]
loss:        5.583294  [ 6432/ 8081]
loss:        6.422597  [ 6752/ 8081]
loss:        5.656250  [ 7072/ 8081]
loss:        6.124544  [ 7392/ 8081]
loss:        6.916223  [ 7712/ 8081]
loss:        5.714869  [ 8032/ 8081]

Avg training loss:        5.738087
Val loss: 5.641866  (hab: 9.147915, mov: 5.672215)
Validation loss decreased (5.643943 → 5.641866). Saving model…

Epoch 21/50
loss:        6.010686  [   32/ 8081]
loss:        6.227398  [  352/ 8081]
loss:        5.143725  [  672/ 8081]
loss:        5.645541  [  992/ 8081]
loss:        5.240898  [ 1312/ 8081]
loss:        4.938787  [ 1632/ 8081]
loss:        5.559191  [ 1952/ 8081]
loss:        6.605025  [ 2272/ 8081]
loss:        5.396750  [ 2592/ 8081]
loss:        5.149271  [ 2912/ 8081]
loss:        5.966692  [ 3232/ 8081]
loss:        5.675520  [ 3552/ 8081]
loss:        5.866745  [ 3872/ 8081]
loss:        4.591894  [ 4192/ 8081]
loss:        6.107647  [ 4512/ 8081]
loss:        5.626565  [ 4832/ 8081]
loss:        6.301752  [ 5152/ 8081]
loss:        5.883070  [ 5472/ 8081]
loss:        6.466449  [ 5792/ 8081]
loss:        5.921516  [ 6112/ 8081]
loss:        4.517435  [ 6432/ 8081]
loss:        6.049242  [ 6752/ 8081]
loss:        7.047176  [ 7072/ 8081]
loss:        5.933296  [ 7392/ 8081]
loss:        6.263047  [ 7712/ 8081]
loss:        5.244785  [ 8032/ 8081]

Avg training loss:        5.737740
Val loss: 5.640463  (hab: 9.149177, mov: 5.671844)
Validation loss decreased (5.641866 → 5.640463). Saving model…

Epoch 22/50
loss:        5.056785  [   32/ 8081]
loss:        5.107917  [  352/ 8081]
loss:        6.367042  [  672/ 8081]
loss:        5.297422  [  992/ 8081]
loss:        6.158465  [ 1312/ 8081]
loss:        7.193504  [ 1632/ 8081]
loss:        5.873117  [ 1952/ 8081]
loss:        5.830358  [ 2272/ 8081]
loss:        5.546278  [ 2592/ 8081]
loss:        5.019753  [ 2912/ 8081]
loss:        6.140012  [ 3232/ 8081]
loss:        5.437500  [ 3552/ 8081]
loss:        5.131791  [ 3872/ 8081]
loss:        5.591730  [ 4192/ 8081]
loss:        5.345333  [ 4512/ 8081]
loss:        5.434434  [ 4832/ 8081]
loss:        5.019072  [ 5152/ 8081]
loss:        5.485474  [ 5472/ 8081]
loss:        5.990477  [ 5792/ 8081]
loss:        6.565684  [ 6112/ 8081]
loss:        5.539105  [ 6432/ 8081]
loss:        6.096859  [ 6752/ 8081]
loss:        5.047228  [ 7072/ 8081]
loss:        5.438942  [ 7392/ 8081]
loss:        4.245893  [ 7712/ 8081]
loss:        5.230383  [ 8032/ 8081]

Avg training loss:        5.737169
Val loss: 5.641476  (hab: 9.145061, mov: 5.671644)
EarlyStopping counter: 1 out of 5

Epoch 23/50
loss:        6.373130  [   32/ 8081]
loss:        5.521699  [  352/ 8081]
loss:        5.814208  [  672/ 8081]
loss:        5.395493  [  992/ 8081]
loss:        5.824369  [ 1312/ 8081]
loss:        6.002573  [ 1632/ 8081]
loss:        5.864894  [ 1952/ 8081]
loss:        6.124455  [ 2272/ 8081]
loss:        6.624372  [ 2592/ 8081]
loss:        5.793117  [ 2912/ 8081]
loss:        6.165062  [ 3232/ 8081]
loss:        6.720785  [ 3552/ 8081]
loss:        6.139231  [ 3872/ 8081]
loss:        5.599601  [ 4192/ 8081]
loss:        5.711744  [ 4512/ 8081]
loss:        5.504859  [ 4832/ 8081]
loss:        5.548381  [ 5152/ 8081]
loss:        6.631241  [ 5472/ 8081]
loss:        6.187800  [ 5792/ 8081]
loss:        5.437247  [ 6112/ 8081]
loss:        5.561949  [ 6432/ 8081]
loss:        5.568308  [ 6752/ 8081]
loss:        4.743848  [ 7072/ 8081]
loss:        4.785015  [ 7392/ 8081]
loss:        5.890841  [ 7712/ 8081]
loss:        5.016112  [ 8032/ 8081]

Avg training loss:        5.738002
Val loss: 5.641952  (hab: 9.144587, mov: 5.671266)
EarlyStopping counter: 2 out of 5

Epoch 24/50
loss:        5.746219  [   32/ 8081]
loss:        5.143204  [  352/ 8081]
loss:        6.581674  [  672/ 8081]
loss:        5.834509  [  992/ 8081]
loss:        5.494139  [ 1312/ 8081]
loss:        4.902992  [ 1632/ 8081]
loss:        5.277520  [ 1952/ 8081]
loss:        6.514735  [ 2272/ 8081]
loss:        5.920425  [ 2592/ 8081]
loss:        5.850049  [ 2912/ 8081]
loss:        5.195207  [ 3232/ 8081]
loss:        6.105504  [ 3552/ 8081]
loss:        5.616876  [ 3872/ 8081]
loss:        6.008060  [ 4192/ 8081]
loss:        5.422901  [ 4512/ 8081]
loss:        5.177522  [ 4832/ 8081]
loss:        4.751125  [ 5152/ 8081]
loss:        5.577045  [ 5472/ 8081]
loss:        6.131419  [ 5792/ 8081]
loss:        6.501829  [ 6112/ 8081]
loss:        5.214920  [ 6432/ 8081]
loss:        5.598885  [ 6752/ 8081]
loss:        5.734118  [ 7072/ 8081]
loss:        5.896174  [ 7392/ 8081]
loss:        5.807665  [ 7712/ 8081]
loss:        5.157558  [ 8032/ 8081]

Avg training loss:        5.735473
Val loss: 5.640158  (hab: 9.143460, mov: 5.671050)
Validation loss decreased (5.640463 → 5.640158). Saving model…

Epoch 25/50
loss:        6.161477  [   32/ 8081]
loss:        5.648900  [  352/ 8081]
loss:        5.558351  [  672/ 8081]
loss:        5.186936  [  992/ 8081]
loss:        5.539708  [ 1312/ 8081]
loss:        5.807642  [ 1632/ 8081]
loss:        6.611417  [ 1952/ 8081]
loss:        5.423437  [ 2272/ 8081]
loss:        5.694284  [ 2592/ 8081]
loss:        5.506063  [ 2912/ 8081]
loss:        4.919890  [ 3232/ 8081]
loss:        6.296585  [ 3552/ 8081]
loss:        5.463380  [ 3872/ 8081]
loss:        5.580927  [ 4192/ 8081]
loss:        6.874486  [ 4512/ 8081]
loss:        4.920260  [ 4832/ 8081]
loss:        6.043089  [ 5152/ 8081]
loss:        5.447857  [ 5472/ 8081]
loss:        6.368416  [ 5792/ 8081]
loss:        4.843377  [ 6112/ 8081]
loss:        6.345890  [ 6432/ 8081]
loss:        5.116454  [ 6752/ 8081]
loss:        6.050868  [ 7072/ 8081]
loss:        6.241920  [ 7392/ 8081]
loss:        6.305410  [ 7712/ 8081]
loss:        5.232047  [ 8032/ 8081]

Avg training loss:        5.734663
Val loss: 5.639725  (hab: 9.144361, mov: 5.670998)
Validation loss decreased (5.640158 → 5.639725). Saving model…

Epoch 26/50
loss:        7.151555  [   32/ 8081]
loss:        5.414118  [  352/ 8081]
loss:        6.189126  [  672/ 8081]
loss:        5.227577  [  992/ 8081]
loss:        6.818292  [ 1312/ 8081]
loss:        5.771779  [ 1632/ 8081]
loss:        5.577074  [ 1952/ 8081]
loss:        5.630198  [ 2272/ 8081]
loss:        5.653317  [ 2592/ 8081]
loss:        5.436688  [ 2912/ 8081]
loss:        5.958178  [ 3232/ 8081]
loss:        5.440780  [ 3552/ 8081]
loss:        5.431932  [ 3872/ 8081]
loss:        5.609800  [ 4192/ 8081]
loss:        5.665380  [ 4512/ 8081]
loss:        5.957418  [ 4832/ 8081]
loss:        4.937189  [ 5152/ 8081]
loss:        6.138237  [ 5472/ 8081]
loss:        6.117325  [ 5792/ 8081]
loss:        5.992046  [ 6112/ 8081]
loss:        5.035103  [ 6432/ 8081]
loss:        6.369143  [ 6752/ 8081]
loss:        5.848080  [ 7072/ 8081]
loss:        6.836145  [ 7392/ 8081]
loss:        5.299628  [ 7712/ 8081]
loss:        5.602996  [ 8032/ 8081]

Avg training loss:        5.735829
Val loss: 5.639682  (hab: 9.145324, mov: 5.670968)
Validation loss decreased (5.639725 → 5.639682). Saving model…

Epoch 27/50
loss:        4.927288  [   32/ 8081]
loss:        6.503007  [  352/ 8081]
loss:        5.772881  [  672/ 8081]
loss:        5.601285  [  992/ 8081]
loss:        5.897579  [ 1312/ 8081]
loss:        5.781460  [ 1632/ 8081]
loss:        5.490117  [ 1952/ 8081]
loss:        5.170106  [ 2272/ 8081]
loss:        6.280233  [ 2592/ 8081]
loss:        5.659313  [ 2912/ 8081]
loss:        5.048002  [ 3232/ 8081]
loss:        5.602965  [ 3552/ 8081]
loss:        6.199113  [ 3872/ 8081]
loss:        5.848446  [ 4192/ 8081]
loss:        5.756188  [ 4512/ 8081]
loss:        5.665813  [ 4832/ 8081]
loss:        4.977673  [ 5152/ 8081]
loss:        5.520265  [ 5472/ 8081]
loss:        5.220546  [ 5792/ 8081]
loss:        5.660740  [ 6112/ 8081]
loss:        5.151398  [ 6432/ 8081]
loss:        6.287862  [ 6752/ 8081]
loss:        5.646356  [ 7072/ 8081]
loss:        6.649393  [ 7392/ 8081]
loss:        5.878810  [ 7712/ 8081]
loss:        5.759536  [ 8032/ 8081]

Avg training loss:        5.733462
Val loss: 5.639660  (hab: 9.145306, mov: 5.670936)
Validation loss decreased (5.639682 → 5.639660). Saving model…

Epoch 28/50
loss:        6.236925  [   32/ 8081]
loss:        6.146458  [  352/ 8081]
loss:        5.195530  [  672/ 8081]
loss:        5.583611  [  992/ 8081]
loss:        5.273785  [ 1312/ 8081]
loss:        5.463507  [ 1632/ 8081]
loss:        5.881322  [ 1952/ 8081]
loss:        6.306096  [ 2272/ 8081]
loss:        5.130684  [ 2592/ 8081]
loss:        5.816171  [ 2912/ 8081]
loss:        6.420640  [ 3232/ 8081]
loss:        5.765622  [ 3552/ 8081]
loss:        5.842854  [ 3872/ 8081]
loss:        5.821796  [ 4192/ 8081]
loss:        5.756525  [ 4512/ 8081]
loss:        5.565989  [ 4832/ 8081]
loss:        5.573711  [ 5152/ 8081]
loss:        5.448873  [ 5472/ 8081]
loss:        6.393813  [ 5792/ 8081]
loss:        5.563116  [ 6112/ 8081]
loss:        7.175228  [ 6432/ 8081]
loss:        4.962247  [ 6752/ 8081]
loss:        5.873940  [ 7072/ 8081]
loss:        5.759718  [ 7392/ 8081]
loss:        5.624741  [ 7712/ 8081]
loss:        5.907386  [ 8032/ 8081]

Avg training loss:        5.733195
Val loss: 5.639641  (hab: 9.145159, mov: 5.670897)
Validation loss decreased (5.639660 → 5.639641). Saving model…

Epoch 29/50
loss:        6.709939  [   32/ 8081]
loss:        6.194642  [  352/ 8081]
loss:        5.739034  [  672/ 8081]
loss:        5.670167  [  992/ 8081]
loss:        4.941658  [ 1312/ 8081]
loss:        5.502157  [ 1632/ 8081]
loss:        6.519639  [ 1952/ 8081]
loss:        5.792060  [ 2272/ 8081]
loss:        5.513575  [ 2592/ 8081]
loss:        5.989444  [ 2912/ 8081]
loss:        6.193710  [ 3232/ 8081]
loss:        5.135377  [ 3552/ 8081]
loss:        5.945956  [ 3872/ 8081]
loss:        4.705295  [ 4192/ 8081]
loss:        5.501117  [ 4512/ 8081]
loss:        5.852671  [ 4832/ 8081]
loss:        6.271144  [ 5152/ 8081]
loss:        5.952385  [ 5472/ 8081]
loss:        5.365061  [ 5792/ 8081]
loss:        6.104717  [ 6112/ 8081]
loss:        5.881651  [ 6432/ 8081]
loss:        5.851075  [ 6752/ 8081]
loss:        5.142447  [ 7072/ 8081]
loss:        5.647768  [ 7392/ 8081]
loss:        5.986320  [ 7712/ 8081]
loss:        5.402556  [ 8032/ 8081]

Avg training loss:        5.734142
Val loss: 5.639626  (hab: 9.145162, mov: 5.670893)
Validation loss decreased (5.639641 → 5.639626). Saving model…

Epoch 30/50
loss:        5.612227  [   32/ 8081]
loss:        5.494819  [  352/ 8081]
loss:        4.629119  [  672/ 8081]
loss:        6.452594  [  992/ 8081]
loss:        5.972569  [ 1312/ 8081]
loss:        5.160095  [ 1632/ 8081]
loss:        5.795186  [ 1952/ 8081]
loss:        5.309612  [ 2272/ 8081]
loss:        5.891468  [ 2592/ 8081]
loss:        5.696829  [ 2912/ 8081]
loss:        5.713593  [ 3232/ 8081]
loss:        6.069798  [ 3552/ 8081]
loss:        4.935905  [ 3872/ 8081]
loss:        5.428632  [ 4192/ 8081]
loss:        5.866337  [ 4512/ 8081]
loss:        5.686249  [ 4832/ 8081]
loss:        5.791387  [ 5152/ 8081]
loss:        6.364533  [ 5472/ 8081]
loss:        5.466154  [ 5792/ 8081]
loss:        5.667849  [ 6112/ 8081]
loss:        5.313776  [ 6432/ 8081]
loss:        6.761433  [ 6752/ 8081]
loss:        6.629786  [ 7072/ 8081]
loss:        5.532158  [ 7392/ 8081]
loss:        6.032206  [ 7712/ 8081]
loss:        5.922668  [ 8032/ 8081]

Avg training loss:        5.734589
Val loss: 5.639633  (hab: 9.145213, mov: 5.670889)
EarlyStopping counter: 1 out of 5

Epoch 31/50
loss:        6.855468  [   32/ 8081]
loss:        5.549836  [  352/ 8081]
loss:        5.501370  [  672/ 8081]
loss:        6.163689  [  992/ 8081]
loss:        6.146986  [ 1312/ 8081]
loss:        5.668081  [ 1632/ 8081]
loss:        5.154352  [ 1952/ 8081]
loss:        6.032871  [ 2272/ 8081]
loss:        5.482692  [ 2592/ 8081]
loss:        6.838123  [ 2912/ 8081]
loss:        6.182957  [ 3232/ 8081]
loss:        5.312598  [ 3552/ 8081]
loss:        5.461801  [ 3872/ 8081]
loss:        5.386752  [ 4192/ 8081]
loss:        6.306053  [ 4512/ 8081]
loss:        5.683695  [ 4832/ 8081]
loss:        5.795979  [ 5152/ 8081]
loss:        5.956809  [ 5472/ 8081]
loss:        5.825184  [ 5792/ 8081]
loss:        5.104815  [ 6112/ 8081]
loss:        5.338410  [ 6432/ 8081]
loss:        4.392776  [ 6752/ 8081]
loss:        5.910175  [ 7072/ 8081]
loss:        4.851960  [ 7392/ 8081]
loss:        5.446443  [ 7712/ 8081]
loss:        5.372225  [ 8032/ 8081]

Avg training loss:        5.733252
Val loss: 5.639620  (hab: 9.145216, mov: 5.670887)
Validation loss decreased (5.639626 → 5.639620). Saving model…

Epoch 32/50
loss:        5.426272  [   32/ 8081]
loss:        5.630202  [  352/ 8081]
loss:        5.632508  [  672/ 8081]
loss:        5.438181  [  992/ 8081]
loss:        5.693160  [ 1312/ 8081]
loss:        5.824766  [ 1632/ 8081]
loss:        4.332841  [ 1952/ 8081]
loss:        5.520078  [ 2272/ 8081]
loss:        5.456306  [ 2592/ 8081]
loss:        5.956411  [ 2912/ 8081]
loss:        5.646355  [ 3232/ 8081]
loss:        5.780666  [ 3552/ 8081]
loss:        5.656393  [ 3872/ 8081]
loss:        6.113107  [ 4192/ 8081]
loss:        5.247138  [ 4512/ 8081]
loss:        6.456845  [ 4832/ 8081]
loss:        5.673826  [ 5152/ 8081]
loss:        6.329019  [ 5472/ 8081]
loss:        6.144670  [ 5792/ 8081]
loss:        6.344707  [ 6112/ 8081]
loss:        6.364787  [ 6432/ 8081]
loss:        5.852273  [ 6752/ 8081]
loss:        5.987990  [ 7072/ 8081]
loss:        5.569731  [ 7392/ 8081]
loss:        6.279569  [ 7712/ 8081]
loss:        5.453013  [ 8032/ 8081]

Avg training loss:        5.732120
Val loss: 5.639620  (hab: 9.145214, mov: 5.670886)
Validation loss decreased (5.639620 → 5.639620). Saving model…

Epoch 33/50
loss:        5.268016  [   32/ 8081]
loss:        5.237939  [  352/ 8081]
loss:        5.922724  [  672/ 8081]
loss:        5.021144  [  992/ 8081]
loss:        6.242224  [ 1312/ 8081]
loss:        5.304910  [ 1632/ 8081]
loss:        6.274157  [ 1952/ 8081]
loss:        5.443479  [ 2272/ 8081]
loss:        5.825380  [ 2592/ 8081]
loss:        4.930949  [ 2912/ 8081]
loss:        5.148467  [ 3232/ 8081]
loss:        5.614328  [ 3552/ 8081]
loss:        5.556447  [ 3872/ 8081]
loss:        6.509217  [ 4192/ 8081]
loss:        5.418288  [ 4512/ 8081]
loss:        5.584445  [ 4832/ 8081]
loss:        6.188684  [ 5152/ 8081]
loss:        5.922045  [ 5472/ 8081]
loss:        5.749983  [ 5792/ 8081]
loss:        5.129826  [ 6112/ 8081]
loss:        5.641623  [ 6432/ 8081]
loss:        7.205474  [ 6752/ 8081]
loss:        5.569323  [ 7072/ 8081]
loss:        5.049457  [ 7392/ 8081]
loss:        5.763159  [ 7712/ 8081]
loss:        5.636391  [ 8032/ 8081]

Avg training loss:        5.731416
Val loss: 5.639621  (hab: 9.145215, mov: 5.670886)
EarlyStopping counter: 1 out of 5

Epoch 34/50
loss:        5.573732  [   32/ 8081]
loss:        4.562861  [  352/ 8081]
loss:        6.424152  [  672/ 8081]
loss:        5.473699  [  992/ 8081]
loss:        5.866921  [ 1312/ 8081]
loss:        5.599736  [ 1632/ 8081]
loss:        5.116332  [ 1952/ 8081]
loss:        5.631752  [ 2272/ 8081]
loss:        5.095023  [ 2592/ 8081]
loss:        6.444837  [ 2912/ 8081]
loss:        5.993291  [ 3232/ 8081]
loss:        6.616193  [ 3552/ 8081]
loss:        5.721396  [ 3872/ 8081]
loss:        5.316614  [ 4192/ 8081]
loss:        4.425422  [ 4512/ 8081]
loss:        5.275918  [ 4832/ 8081]
loss:        5.434562  [ 5152/ 8081]
loss:        5.448647  [ 5472/ 8081]
loss:        5.664916  [ 5792/ 8081]
loss:        5.117067  [ 6112/ 8081]
loss:        5.216536  [ 6432/ 8081]
loss:        6.023912  [ 6752/ 8081]
loss:        6.081383  [ 7072/ 8081]
loss:        6.285754  [ 7392/ 8081]
loss:        4.963202  [ 7712/ 8081]
loss:        5.865150  [ 8032/ 8081]

Avg training loss:        5.731098
Val loss: 5.639621  (hab: 9.145216, mov: 5.670886)
EarlyStopping counter: 2 out of 5

Epoch 35/50
loss:        6.021615  [   32/ 8081]
loss:        4.532571  [  352/ 8081]
loss:        5.175756  [  672/ 8081]
loss:        5.827329  [  992/ 8081]
loss:        5.984054  [ 1312/ 8081]
loss:        5.343431  [ 1632/ 8081]
loss:        6.666986  [ 1952/ 8081]
loss:        6.816831  [ 2272/ 8081]
loss:        5.876522  [ 2592/ 8081]
loss:        5.746874  [ 2912/ 8081]
loss:        5.963414  [ 3232/ 8081]
loss:        4.622033  [ 3552/ 8081]
loss:        6.336753  [ 3872/ 8081]
loss:        5.816704  [ 4192/ 8081]
loss:        5.761320  [ 4512/ 8081]
loss:        5.750998  [ 4832/ 8081]
loss:        4.660877  [ 5152/ 8081]
loss:        5.898358  [ 5472/ 8081]
loss:        5.581824  [ 5792/ 8081]
loss:        4.712275  [ 6112/ 8081]
loss:        5.580365  [ 6432/ 8081]
loss:        5.149635  [ 6752/ 8081]
loss:        4.571750  [ 7072/ 8081]
loss:        5.565580  [ 7392/ 8081]
loss:        6.358205  [ 7712/ 8081]
loss:        5.638477  [ 8032/ 8081]

Avg training loss:        5.731328
Val loss: 5.639621  (hab: 9.145216, mov: 5.670886)
EarlyStopping counter: 3 out of 5

Epoch 36/50
loss:        5.219586  [   32/ 8081]
loss:        5.866108  [  352/ 8081]
loss:        5.157986  [  672/ 8081]
loss:        5.456277  [  992/ 8081]
loss:        6.346135  [ 1312/ 8081]
loss:        6.265892  [ 1632/ 8081]
loss:        5.365193  [ 1952/ 8081]
loss:        5.935129  [ 2272/ 8081]
loss:        5.507366  [ 2592/ 8081]
loss:        5.146176  [ 2912/ 8081]
loss:        6.031135  [ 3232/ 8081]
loss:        5.098520  [ 3552/ 8081]
loss:        6.378906  [ 3872/ 8081]
loss:        5.300342  [ 4192/ 8081]
loss:        6.126292  [ 4512/ 8081]
loss:        5.154186  [ 4832/ 8081]
loss:        5.361709  [ 5152/ 8081]
loss:        5.572066  [ 5472/ 8081]
loss:        5.279382  [ 5792/ 8081]
loss:        4.960338  [ 6112/ 8081]
loss:        6.301264  [ 6432/ 8081]
loss:        5.728517  [ 6752/ 8081]
loss:        6.473097  [ 7072/ 8081]
loss:        6.075172  [ 7392/ 8081]
loss:        5.903062  [ 7712/ 8081]
loss:        6.081424  [ 8032/ 8081]

Avg training loss:        5.735796
Val loss: 5.639621  (hab: 9.145216, mov: 5.670885)
EarlyStopping counter: 4 out of 5

Epoch 37/50
loss:        6.125269  [   32/ 8081]
loss:        5.612249  [  352/ 8081]
loss:        5.744735  [  672/ 8081]
loss:        5.696525  [  992/ 8081]
loss:        6.524700  [ 1312/ 8081]
loss:        5.233836  [ 1632/ 8081]
loss:        6.115005  [ 1952/ 8081]
loss:        5.521945  [ 2272/ 8081]
loss:        5.175142  [ 2592/ 8081]
loss:        6.153426  [ 2912/ 8081]
loss:        6.875118  [ 3232/ 8081]
loss:        5.568776  [ 3552/ 8081]
loss:        6.070964  [ 3872/ 8081]
loss:        4.977057  [ 4192/ 8081]
loss:        6.165647  [ 4512/ 8081]
loss:        6.339566  [ 4832/ 8081]
loss:        5.571197  [ 5152/ 8081]
loss:        5.252230  [ 5472/ 8081]
loss:        6.161972  [ 5792/ 8081]
loss:        5.660664  [ 6112/ 8081]
loss:        6.967310  [ 6432/ 8081]
loss:        4.854573  [ 6752/ 8081]
loss:        5.381294  [ 7072/ 8081]
loss:        5.590603  [ 7392/ 8081]
loss:        5.565978  [ 7712/ 8081]
loss:        5.514409  [ 8032/ 8081]

Avg training loss:        5.734960
Val loss: 5.639621  (hab: 9.145216, mov: 5.670885)
EarlyStopping counter: 5 out of 5
Early stopping triggered.
# Build an animated GIF from the contents of SNAPSHOT_DIR (presumably the
# snapshots written during training — confirm against `fit`) and save it
# under OUTPUT_DIR.
gif_path = str(OUTPUT_DIR / "training_progress.gif")
create_gif(str(SNAPSHOT_DIR), gif_path, fps=5)
print(f"Training GIF → {gif_path}")

# Play the gif
# (kept as the last expression of the cell so the notebook renders it)
Image(gif_path)
Animation saved: outputs/training_progress.gif
Training GIF → outputs/training_progress.gif
<IPython.core.display.Image object>
# Plot the training and validation loss curves stored in `history`
# on a single axis and save the figure under OUTPUT_DIR.
fig, ax = plt.subplots(figsize=(7, 3))
for series_key, series_label in (("train_losses", "train"), ("val_losses", "val")):
    ax.plot(history[series_key], label=series_label)
ax.set(xlabel="Epoch", ylabel="NLL loss", title="Training history")
ax.legend()
fig.tight_layout()
fig.savefig(str(OUTPUT_DIR / "loss_history.png"), dpi=100)
plt.show()

4 Validate against withheld test data

# Load the best checkpoint saved during training.
# weights_only=True limits unpickling to tensors and plain containers, so a
# corrupted or tampered checkpoint file cannot run arbitrary code, and the cell
# behaves the same on torch releases where this is not yet the default.
best_state = torch.load(
    str(OUTPUT_DIR / "best_model.pt"),
    map_location=DEVICE,
    weights_only=True,
)
model.load_state_dict(best_state)
# Put the model in evaluation mode before validation and simulation.
model.eval()
print("Best checkpoint loaded.")
Best checkpoint loaded.
# Expose the static rasters through a landscape-loader callable.
# The month index argument is unused because the layers are static (non-S2).
def _as_float32_tensor(layer):
    """Return a float32 torch tensor built from a copy of the numpy array *layer*."""
    return torch.from_numpy(layer.astype(np.float32))


ndvi_t = _as_float32_tensor(env_layers["ndvi"])
slope_t = _as_float32_tensor(env_layers["slope"])


def get_landscape(_month_index):
    """Return the [ndvi, slope] tensors; the month index is ignored."""
    return [ndvi_t, slope_t]
# Use the last 10 % of steps as the test set (matches the DataLoader split)
n_test = int(len(step_df) * 0.1)
if n_test == 0:
    # With fewer than 10 steps the slice `iloc[-0:]` would select *every* row,
    # silently validating on the full dataset instead of a held-out tail.
    raise ValueError(
        f"Need at least 10 steps to hold out a 10 % test set, got {len(step_df)}"
    )
# tail(n) returns exactly the last n rows.
test_sample = step_df.tail(n_test).reset_index(drop=True)
print(f"Validating on {len(test_sample):,} steps")

# Score each test step with the trained model; the result is a DataFrame that
# includes habitat, movement and joint next-step probability columns.
val_results = validate_next_step_probs(
    model,
    test_sample,
    get_landscape=get_landscape,
    transform=raster_transform,
    window_size=WINDOW_SIZE,
    scalar_cols=tuple(SCALAR_COLS),
    yday_col="yday_t1",
    bearing_col="bearing_tm1",
    month_index_fn=lambda _yday: 0,   # static layers — month doesn't matter
)
print(val_results[["habitat_prob", "move_prob", "next_step_prob"]].describe())
Validating on 1,010 steps
       habitat_prob    move_prob  next_step_prob
count   1010.000000  1010.000000     1010.000000
mean       0.000126     0.027330        0.030069
std        0.000050     0.054359        0.060509
min        0.000000     0.000000        0.000000
25%        0.000093     0.000224        0.000228
50%        0.000112     0.006624        0.007082
75%        0.000155     0.030587        0.034205
max        0.000430     0.255475        0.334267
# One histogram per probability column of the validation results.
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
for ax, col, title in zip(
    axes,
    ["habitat_prob", "move_prob", "next_step_prob"],
    ["Habitat", "Movement", "Joint next-step"],
):
    # Drop missing values before binning.
    data = val_results[col].dropna()
    ax.hist(data, bins=40, edgecolor="white")
    ax.set_title(title)
    ax.set_xlabel("Probability")
axes[0].set_ylabel("Count")
plt.suptitle("Next-step probability distributions (test set)", y=1.02)
plt.tight_layout()
# The suptitle is anchored at y=1.02, i.e. above the figure canvas, so the
# saved image must use a tight bounding box or the title is cropped.
plt.savefig(
    str(OUTPUT_DIR / "validation_probs.png"),
    dpi=100,
    bbox_inches="tight",
)
plt.show()

5 Simulate a trajectory

# Start the simulation at the median of the x1_ / y1_ columns of step_df.
start_x, start_y = (float(step_df[coord].median()) for coord in ("x1_", "y1_"))
print(f"Starting location: ({start_x:.0f}, {start_y:.0f})")

# Take the starting day-of-year and hour from the first test-set step.
first_yday = test_sample["yday_t1"].iloc[0]
first_hour = test_sample["hour_t1"].iloc[0]

sim_df = simulate_trajectory(
    model,
    get_landscape=get_landscape,
    transform=raster_transform,
    start_x=start_x,
    start_y=start_y,
    n_steps=1000,
    starting_yday=first_yday,
    starting_hour=first_hour,
    time_between_steps=1.0,
    window_size=WINDOW_SIZE,
    month_index_fn=lambda _yday: 0,   # static layers
)
print(f"Simulated {len(sim_df)} steps")
# Final expression of the cell: show the first simulated rows.
sim_df.head()
Starting location: (36111, -1436905)
Simulated 1000 steps
x y hour yday month_index hab_log_prob move_log_prob step_log_prob
0 36108.269587 -1.436912e+06 18.916667 255.000000 0 [[-9.412999, -9.240336, -9.038588, -8.982068, ... [[-13.963823, -13.919329, -13.875055, -13.8310... [[-23.376823, -23.159664, -22.913643, -22.8130...
1 35898.710528 -1.436614e+06 19.916667 255.041667 0 [[-9.405268, -9.207009, -8.999849, -8.948765, ... [[-13.95573, -13.91391, -13.872377, -13.831140... [[-23.360998, -23.120918, -22.872227, -22.7799...
2 36464.367638 -1.436123e+06 20.916667 255.083333 0 [[-9.160653, -9.023066, -8.815113, -8.739051, ... [[-13.513008, -13.469606, -13.426506, -13.3837... [[-22.67366, -22.492672, -22.24162, -22.122772...
3 36435.540267 -1.436041e+06 21.916667 255.125000 0 [[-8.481404, -8.342018, -8.182885, -8.192506, ... [[-13.816451, -13.771689, -13.727157, -13.6828... [[-22.297855, -22.113708, -21.910042, -21.8753...
4 36439.641807 -1.436043e+06 22.916667 255.166667 0 [[-8.254223, -8.155582, -7.9846506, -7.9550314... [[-13.654331, -13.611277, -13.568495, -13.5259... [[-21.908554, -21.76686, -21.553146, -21.48103...
fig, ax = plt.subplots(figsize=(7, 6))

# NDVI raster as the backdrop, positioned with raster_transform.
rasterio.plot.show(
    env_layers["ndvi"], transform=raster_transform, ax=ax, cmap="viridis"
)

# Overlay both tracks; they share marker and line sizing and differ only in
# colour, opacity and legend label.
track_style = dict(markersize=1, linewidth=0.8)
ax.plot(
    step_df["x1_"], step_df["y1_"], "w-o",
    alpha=0.5, label="observed", **track_style,
)
ax.plot(sim_df["x"], sim_df["y"], "r-o", label="simulated", **track_style)

ax.set(
    xlabel="Easting (m, projected)",
    ylabel="Northing (m, projected)",
    title="Simulated vs observed buffalo locations",
)
ax.legend()
fig.tight_layout()
fig.savefig(str(OUTPUT_DIR / "simulated_trajectory.png"), dpi=100)
plt.show()