End-to-end example using the bundled buffalo GPS dataset and two static raster covariates (NDVI and slope). The notebook calls the package throughout.
Sections: (1) Import data; (2) Prepare data; (3) Train model; (4) Validate against withheld data; (5) Simulate a trajectory.
import math
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import rasterio
import deepssf
from deepssf import (
ConvJointModel,
EarlyStopping,
ModelParams,
create_gif,
fit,
get_device,
load_environmental_layers,
make_dataloaders,
make_optimisers,
negativeLogLikeLoss,
filter_steps_by_window,
prepare_movement_df,
simulate_trajectory,
validate_next_step_probs,
)
import pandas as pd
from datetime import datetime # Date/time utilities
from IPython.display import Image # For plotting GIFs
from rasterio.plot import show # For plotting raster layers
# Report the installed versions of the package under demonstration and of
# PyTorch, one per line.
for version_line in (
    f"deepssf {deepssf.__version__} ",
    f"torch {torch.__version__} ",
):
    print(version_line)
deepssf 0.1.0
torch 2.12.0
# Reproducibility: seed PyTorch's and NumPy's global random number generators
# with the same fixed value.
for seed_rng in (torch.manual_seed, np.random.seed):
    seed_rng(42)

# Let the package choose the compute device, then report the choice.
DEVICE = get_device()
print(f"Device: {DEVICE} ")
1 Import data
# Input locations. The path is relative, so it resolves against the current
# working directory.
DATA_DIR = Path("../src/deepssf/datasets/data")
CSV_PATH = DATA_DIR / "buffalo_djelk_id2005.csv"
LAYER_PATHS = {
    "ndvi": str(DATA_DIR / "ndvi_2005.tif"),
    "slope": str(DATA_DIR / "slope_2005.tif"),
}

# Output locations. Only OUTPUT_DIR is created here; SNAPSHOT_DIR is just a path.
OUTPUT_DIR = Path("outputs")
SNAPSHOT_DIR = OUTPUT_DIR / "training_snapshots"
OUTPUT_DIR.mkdir(exist_ok=True)

# Load the GPS fixes and report how many rows and individuals they contain.
raw_df = pd.read_csv(CSV_PATH)
print(f"Raw GPS fixes : {len(raw_df):,} ")
print(f"Individuals : {raw_df['id'].nunique()} ")

# Put the fixes in temporal order within each individual and renumber the
# rows from zero.
raw_df = raw_df.sort_values(["id", "time"], ignore_index=True)

# Optional (disabled): convert the 'time' column to the local timezone
# (here Australia/Darwin).
# raw_df["time"] = (
#     pd.to_datetime(raw_df["time"], utc=True)
#     .dt.tz_convert("Australia/Darwin")
# )

raw_df.head()
Raw GPS fixes : 10,297
Individuals : 1
0
2005
2018-07-25T00:04:02Z
41941.331695
-1.435875e+06
1
2005
2018-07-25T01:04:23Z
41969.310875
-1.435671e+06
2
2005
2018-07-25T02:04:39Z
41921.521939
-1.435654e+06
3
2005
2018-07-25T03:04:17Z
41779.439594
-1.435601e+06
4
2005
2018-07-25T04:04:39Z
41841.203272
-1.435635e+06
# Load the raster covariates named in LAYER_PATHS. The loader returns the
# layers keyed by name plus a second value kept here as raster_transform.
env_layers, raster_transform = load_environmental_layers(LAYER_PATHS)
layer_names = list(env_layers.keys())
print("Layers loaded:", layer_names)

# One summary line per layer.
# NOTE(review): the recorded run logs each layer as "scaled to [0, 1]" but
# then reports ndvi max=1.090 — confirm the scaling in the loader.
for name, arr in env_layers.items():
    print(f" {name} : shape= {arr.shape} , min= {arr.min():.3f} , max= {arr.max():.3f} ")
Layer 'ndvi': min=-0.06614641100168228, max=0.731696367263794 → scaled to [0, 1]
Layer 'slope': min=0.00013558653881773353, max=12.298083305358887 → scaled to [0, 1]
Layers loaded: ['ndvi', 'slope']
ndvi: shape=(758, 1175), min=0.000, max=1.090
slope: shape=(758, 1175), min=0.000, max=1.000
2 Prepare data
# Derive the movement-step table from the raw fixes and report how many rows
# were lost in the process.
step_df = prepare_movement_df(raw_df)
n_dropped = len(raw_df) - len(step_df)
print(f"Movement steps : {len(step_df):,} ( {n_dropped} dropped — last fix per individual)")
step_df.head()
Movement steps : 10,296 (1 dropped — last fix per individual)
0
2005
2018-07-25T00:04:02Z
41941.331695
-1.435875e+06
2018-07-25T01:04:23Z
41969.310875
-1.435671e+06
27.979180
204.064135
1.434536
0.000000
1.005833
0.066667
206.0
0.017452
0.999848
-0.391358
-0.920239
1
2005
2018-07-25T01:04:23Z
41969.310875
-1.435671e+06
2018-07-25T02:04:39Z
41921.521939
-1.435654e+06
-47.788936
16.857110
2.802478
1.434536
1.004444
1.066667
206.0
0.275637
0.961262
-0.391358
-0.920239
2
2005
2018-07-25T02:04:39Z
41921.521939
-1.435654e+06
2018-07-25T03:04:17Z
41779.439594
-1.435601e+06
-142.082345
53.568427
2.781049
2.802478
0.993889
2.066667
206.0
0.515038
0.857167
-0.391358
-0.920239
3
2005
2018-07-25T03:04:17Z
41779.439594
-1.435601e+06
2018-07-25T04:04:39Z
41841.203272
-1.435635e+06
61.763677
-34.322938
-0.507220
2.781049
1.006111
3.066667
206.0
0.719340
0.694658
-0.391358
-0.920239
4
2005
2018-07-25T04:04:39Z
41841.203272
-1.435635e+06
2018-07-25T05:04:27Z
41655.463332
-1.435604e+06
-185.739939
31.003534
2.976198
-0.507220
0.996667
4.066667
206.0
0.874620
0.484810
-0.391358
-0.920239
# Model configuration
# Scalar covariates that are broadcast into spatial maps. The list holds five
# entries, including "dt_hour".
# NOTE(review): an earlier comment here said "4 scalar covariates ... no
# dt_hour so simulate() matches" — confirm whether dt_hour belongs in this list.
SCALAR_COLS = ["hour_t1_sin1", "hour_t1_cos1", "yday_t1_sin1", "yday_t1_cos1", "dt_hour"]
WINDOW_SIZE = 101  # local window side length in pixels; 101 is typical for production
PIXEL_SIZE = 25  # raster resolution in metres (CRS units)
BATCH_SIZE = 32

# Calculate the flat size of the movement CNN output after 3× conv+maxpool.
# Integer floor division gives the same result as flooring the float formula.
OUTPUT_CHANNELS = 4
dim = WINDOW_SIZE
for _ in range(3):
    dim = (dim + 2 * 1 - 3) // 1 + 1  # conv: kernel=3, stride=1, pad=1 keeps dim
    dim = (dim - 2) // 2 + 1  # maxpool: kernel=2, stride=2 halves dim, rounding down
DENSE_DIM = OUTPUT_CHANNELS * dim * dim
print(f"After 3× conv+maxpool: dim= {dim} → dense_dim_in_all= {DENSE_DIM} ")
# Remove steps whose displacement is larger than the local window; leaving
# them in would cause an index error in the loss.
step_df = filter_steps_by_window(
    step_df,
    window_size=WINDOW_SIZE,
    pixel_size=PIXEL_SIZE,
)
print(f"Steps after window filter: {len(step_df):,} ")
After 3× conv+maxpool: dim=12 → dense_dim_in_all=576
Steps after window filter: 10,101
# Build the three data loaders from the filtered steps.
# train_split / val_split are the fractions given to training and validation;
# presumably the remainder becomes the test set (verify in make_dataloaders).
split_loaders = make_dataloaders(
    df=step_df,
    layer_paths=LAYER_PATHS,
    scalar_cols=SCALAR_COLS,
    window_size=WINDOW_SIZE,
    batch_size=BATCH_SIZE,
    train_split=0.8,
    val_split=0.1,
)
dl_train, dl_val, dl_test = split_loaders

# Number of batches in each loader.
print(f"Train batches : {len(dl_train)} ")
print(f"Val batches : {len(dl_val)} ")
print(f"Test batches : {len(dl_test)} ")
Layer 'ndvi': min=-0.06614641100168228, max=0.731696367263794 → scaled to [0, 1]
Layer 'slope': min=0.00013558653881773353, max=12.298083305358887 → scaled to [0, 1]
Train batches : 253
Val batches : 32
Test batches : 32
/Users/scottforrest/Library/Mobile Documents/com~apple~CloudDocs/github_repos/deepSSF_package/src/deepssf/data.py:358: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /Users/runner/work/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:219.)
self.scalar_to_grid_data = torch.from_numpy(
# Pull the first training batch and unpack it: spatial input, scalar
# covariates, bearing, the next-step pixel coordinates, and one ignored item.
first_batch = next(iter(dl_train))
x1, x2, x3, (px2, py2), _ = first_batch

# Show the tensor shapes and the first four next-step pixel targets.
print(f"Spatial x1 : {x1.shape} (batch × channels × H × W)")
print(f"Scalars x2 : {x2.shape} ")
print(f"Bearing x3 : {x3.shape} ")
print(f"Next-step px2: {px2[:4].tolist()} py2: {py2[:4].tolist()} ")
Spatial x1 : torch.Size([32, 2, 101, 101]) (batch × channels × H × W)
Scalars x2 : torch.Size([32, 5])
Bearing x3 : torch.Size([32, 1])
Next-step px2: [54, 51, 50, 50] py2: [51, 53, 50, 50]
3 Train model
# Hyper-parameters for the joint model. Sizes that depend on the data are
# derived from the configuration constants so they stay in step with them.
params = ModelParams({
    "batch_size": BATCH_SIZE,
    "image_dim": WINDOW_SIZE,
    "pixel_size": PIXEL_SIZE,
    "dim_in_nonspatial_to_grid": len(SCALAR_COLS),
    "dense_dim_in_nonspatial": len(SCALAR_COLS),
    "dense_dim_hidden": 64,
    "dense_dim_in_all": DENSE_DIM,
    # One input channel per raster layer plus one per scalar covariate that is
    # broadcast to a grid (2 + 5 with the current configuration).
    "input_channels": len(LAYER_PATHS) + len(SCALAR_COLS),
    "output_channels": OUTPUT_CHANNELS,
    # Convolution / max-pool settings; these are the same values assumed when
    # DENSE_DIM was calculated.
    "kernel_size": 3,
    "stride": 1,
    "kernel_size_mp": 2,
    "stride_mp": 2,
    "padding": 1,
    "num_movement_params": 12,
    "dropout": 0.0,
    "device": DEVICE,
})

# Instantiate the model on the chosen device and count its parameters.
model = ConvJointModel(params).to(DEVICE)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters : {n_params:,} ")
Model parameters : 43,009
# Loss: negative log-likelihood, averaged over the batch.
loss_fn = negativeLogLikeLoss(reduction="mean")

# Optimisers and schedulers, with separate learning rates for the habitat
# and movement parts of the model.
optimisers, schedulers = make_optimisers(
    model,
    lr_habitat=1e-4,
    lr_movement=1e-5,
    scheduler_patience=2,
)

# Early stopping with a patience of 5; the checkpoint path sits in OUTPUT_DIR.
checkpoint_path = str(OUTPUT_DIR / "best_model.pt")
early_stop = EarlyStopping(patience=5, verbose=True, path=checkpoint_path)

# Train for up to 50 epochs, keeping the history that fit() returns.
history = fit(
    model=model,
    # how many convolutional layers in the habitat selection process
    # (to filter edge artifacts that affect colour scaling)
    n_conv_layers=4,
    window_size=WINDOW_SIZE,
    dl_train=dl_train,
    dl_val=dl_val,
    loss_fn=loss_fn,
    optimisers=optimisers,
    schedulers=schedulers,
    n_epochs=50,
    early_stopping=early_stop,
    snapshot_dir=str(SNAPSHOT_DIR),
    snapshot_item=500,
)
Epoch 1/50
loss: 6.785439 [ 32/ 8081]
loss: 6.097230 [ 352/ 8081]
loss: 5.551458 [ 672/ 8081]
loss: 5.073307 [ 992/ 8081]
loss: 6.436917 [ 1312/ 8081]
loss: 6.255669 [ 1632/ 8081]
loss: 5.437774 [ 1952/ 8081]
loss: 5.842318 [ 2272/ 8081]
loss: 6.056246 [ 2592/ 8081]
loss: 5.573969 [ 2912/ 8081]
loss: 5.683355 [ 3232/ 8081]
loss: 5.886904 [ 3552/ 8081]
loss: 5.906250 [ 3872/ 8081]
loss: 4.869259 [ 4192/ 8081]
loss: 6.376064 [ 4512/ 8081]
loss: 5.547744 [ 4832/ 8081]
loss: 6.104920 [ 5152/ 8081]
loss: 5.175964 [ 5472/ 8081]
loss: 6.325367 [ 5792/ 8081]
loss: 5.473641 [ 6112/ 8081]
loss: 4.946332 [ 6432/ 8081]
loss: 5.967487 [ 6752/ 8081]
loss: 5.914863 [ 7072/ 8081]
loss: 5.850795 [ 7392/ 8081]
loss: 5.686168 [ 7712/ 8081]
loss: 5.253773 [ 8032/ 8081]
Avg training loss: 5.788497
Val loss: 5.687180 (hab: 9.166112, mov: 5.704777)
Validation loss decreased (5.708087 → 5.687180). Saving model…
Epoch 2/50
loss: 5.694406 [ 32/ 8081]
loss: 6.153904 [ 352/ 8081]
loss: 6.037060 [ 672/ 8081]
loss: 6.598245 [ 992/ 8081]
loss: 4.828341 [ 1312/ 8081]
loss: 5.693692 [ 1632/ 8081]
loss: 4.467628 [ 1952/ 8081]
loss: 6.476704 [ 2272/ 8081]
loss: 5.603591 [ 2592/ 8081]
loss: 5.886862 [ 2912/ 8081]
loss: 5.528289 [ 3232/ 8081]
loss: 5.574023 [ 3552/ 8081]
loss: 6.055640 [ 3872/ 8081]
loss: 6.516983 [ 4192/ 8081]
loss: 6.326391 [ 4512/ 8081]
loss: 4.843784 [ 4832/ 8081]
loss: 6.195689 [ 5152/ 8081]
loss: 5.668516 [ 5472/ 8081]
loss: 5.597651 [ 5792/ 8081]
loss: 4.522491 [ 6112/ 8081]
loss: 6.149788 [ 6432/ 8081]
loss: 5.482637 [ 6752/ 8081]
loss: 6.251376 [ 7072/ 8081]
loss: 4.979764 [ 7392/ 8081]
loss: 5.590664 [ 7712/ 8081]
loss: 6.732779 [ 8032/ 8081]
Avg training loss: 5.780206
Val loss: 5.680278 (hab: 9.167198, mov: 5.698563)
Validation loss decreased (5.687180 → 5.680278). Saving model…
Epoch 3/50
loss: 5.771159 [ 32/ 8081]
loss: 4.998755 [ 352/ 8081]
loss: 6.109458 [ 672/ 8081]
loss: 6.072502 [ 992/ 8081]
loss: 5.710176 [ 1312/ 8081]
loss: 5.001541 [ 1632/ 8081]
loss: 4.916896 [ 1952/ 8081]
loss: 5.842347 [ 2272/ 8081]
loss: 6.470130 [ 2592/ 8081]
loss: 5.437813 [ 2912/ 8081]
loss: 5.558701 [ 3232/ 8081]
loss: 5.935304 [ 3552/ 8081]
loss: 5.704793 [ 3872/ 8081]
loss: 7.646010 [ 4192/ 8081]
loss: 5.135283 [ 4512/ 8081]
loss: 5.725735 [ 4832/ 8081]
loss: 5.999853 [ 5152/ 8081]
loss: 5.170563 [ 5472/ 8081]
loss: 6.050817 [ 5792/ 8081]
loss: 5.566458 [ 6112/ 8081]
loss: 6.288377 [ 6432/ 8081]
loss: 5.568368 [ 6752/ 8081]
loss: 6.243468 [ 7072/ 8081]
loss: 5.513037 [ 7392/ 8081]
loss: 5.912124 [ 7712/ 8081]
loss: 5.924772 [ 8032/ 8081]
Avg training loss: 5.773384
Val loss: 5.673004 (hab: 9.161399, mov: 5.693920)
Validation loss decreased (5.680278 → 5.673004). Saving model…
Epoch 4/50
loss: 5.164478 [ 32/ 8081]
loss: 5.978908 [ 352/ 8081]
loss: 6.380840 [ 672/ 8081]
loss: 5.990369 [ 992/ 8081]
loss: 4.989443 [ 1312/ 8081]
loss: 6.138777 [ 1632/ 8081]
loss: 5.724806 [ 1952/ 8081]
loss: 5.964715 [ 2272/ 8081]
loss: 5.613798 [ 2592/ 8081]
loss: 5.038851 [ 2912/ 8081]
loss: 5.382350 [ 3232/ 8081]
loss: 6.172091 [ 3552/ 8081]
loss: 6.201478 [ 3872/ 8081]
loss: 6.695495 [ 4192/ 8081]
loss: 6.798589 [ 4512/ 8081]
loss: 6.260759 [ 4832/ 8081]
loss: 5.514187 [ 5152/ 8081]
loss: 5.574938 [ 5472/ 8081]
loss: 5.461888 [ 5792/ 8081]
loss: 5.341826 [ 6112/ 8081]
loss: 5.825900 [ 6432/ 8081]
loss: 5.031033 [ 6752/ 8081]
loss: 5.472840 [ 7072/ 8081]
loss: 6.054863 [ 7392/ 8081]
loss: 5.984641 [ 7712/ 8081]
loss: 5.449839 [ 8032/ 8081]
Avg training loss: 5.771757
Val loss: 5.670019 (hab: 9.161964, mov: 5.690092)
Validation loss decreased (5.673004 → 5.670019). Saving model…
Epoch 5/50
loss: 5.468553 [ 32/ 8081]
loss: 5.154895 [ 352/ 8081]
loss: 5.961631 [ 672/ 8081]
loss: 5.888359 [ 992/ 8081]
loss: 6.175714 [ 1312/ 8081]
loss: 6.281458 [ 1632/ 8081]
loss: 6.169101 [ 1952/ 8081]
loss: 6.215828 [ 2272/ 8081]
loss: 5.834133 [ 2592/ 8081]
loss: 5.974036 [ 2912/ 8081]
loss: 5.318532 [ 3232/ 8081]
loss: 5.839583 [ 3552/ 8081]
loss: 5.984657 [ 3872/ 8081]
loss: 6.032445 [ 4192/ 8081]
loss: 5.217687 [ 4512/ 8081]
loss: 5.736081 [ 4832/ 8081]
loss: 6.185085 [ 5152/ 8081]
loss: 6.298941 [ 5472/ 8081]
loss: 5.976015 [ 5792/ 8081]
loss: 6.169477 [ 6112/ 8081]
loss: 6.504289 [ 6432/ 8081]
loss: 5.326622 [ 6752/ 8081]
loss: 6.042302 [ 7072/ 8081]
loss: 5.940510 [ 7392/ 8081]
loss: 5.981143 [ 7712/ 8081]
loss: 6.376686 [ 8032/ 8081]
Avg training loss: 5.765343
Val loss: 5.667132 (hab: 9.160795, mov: 5.687070)
Validation loss decreased (5.670019 → 5.667132). Saving model…
Epoch 6/50
loss: 5.989162 [ 32/ 8081]
loss: 5.185757 [ 352/ 8081]
loss: 6.094879 [ 672/ 8081]
loss: 6.327469 [ 992/ 8081]
loss: 6.278698 [ 1312/ 8081]
loss: 7.185611 [ 1632/ 8081]
loss: 6.149991 [ 1952/ 8081]
loss: 6.799704 [ 2272/ 8081]
loss: 5.566303 [ 2592/ 8081]
loss: 5.896247 [ 2912/ 8081]
loss: 6.005388 [ 3232/ 8081]
loss: 5.651654 [ 3552/ 8081]
loss: 6.483175 [ 3872/ 8081]
loss: 5.532136 [ 4192/ 8081]
loss: 6.001493 [ 4512/ 8081]
loss: 6.396995 [ 4832/ 8081]
loss: 6.277179 [ 5152/ 8081]
loss: 5.335245 [ 5472/ 8081]
loss: 5.850138 [ 5792/ 8081]
loss: 5.985559 [ 6112/ 8081]
loss: 5.363245 [ 6432/ 8081]
loss: 5.474362 [ 6752/ 8081]
loss: 6.610009 [ 7072/ 8081]
loss: 4.994096 [ 7392/ 8081]
loss: 6.311173 [ 7712/ 8081]
loss: 5.433157 [ 8032/ 8081]
Avg training loss: 5.762437
Val loss: 5.664561 (hab: 9.163505, mov: 5.684692)
Validation loss decreased (5.667132 → 5.664561). Saving model…
Epoch 7/50
loss: 5.466299 [ 32/ 8081]
loss: 5.562140 [ 352/ 8081]
loss: 5.990706 [ 672/ 8081]
loss: 6.745186 [ 992/ 8081]
loss: 5.897082 [ 1312/ 8081]
loss: 6.342498 [ 1632/ 8081]
loss: 6.487095 [ 1952/ 8081]
loss: 5.308664 [ 2272/ 8081]
loss: 6.215004 [ 2592/ 8081]
loss: 5.250403 [ 2912/ 8081]
loss: 5.854526 [ 3232/ 8081]
loss: 5.654216 [ 3552/ 8081]
loss: 5.539589 [ 3872/ 8081]
loss: 5.537636 [ 4192/ 8081]
loss: 5.496979 [ 4512/ 8081]
loss: 5.635530 [ 4832/ 8081]
loss: 4.809999 [ 5152/ 8081]
loss: 5.620271 [ 5472/ 8081]
loss: 6.103968 [ 5792/ 8081]
loss: 6.340637 [ 6112/ 8081]
loss: 6.328131 [ 6432/ 8081]
loss: 5.185500 [ 6752/ 8081]
loss: 5.044396 [ 7072/ 8081]
loss: 6.326553 [ 7392/ 8081]
loss: 5.725074 [ 7712/ 8081]
loss: 5.338499 [ 8032/ 8081]
Avg training loss: 5.757804
Val loss: 5.663774 (hab: 9.160690, mov: 5.682512)
Validation loss decreased (5.664561 → 5.663774). Saving model…
Epoch 8/50
loss: 5.153199 [ 32/ 8081]
loss: 6.033429 [ 352/ 8081]
loss: 6.188162 [ 672/ 8081]
loss: 5.682156 [ 992/ 8081]
loss: 5.402650 [ 1312/ 8081]
loss: 6.339832 [ 1632/ 8081]
loss: 5.566161 [ 1952/ 8081]
loss: 5.892251 [ 2272/ 8081]
loss: 4.953355 [ 2592/ 8081]
loss: 5.078116 [ 2912/ 8081]
loss: 5.694687 [ 3232/ 8081]
loss: 6.260427 [ 3552/ 8081]
loss: 5.378967 [ 3872/ 8081]
loss: 5.724102 [ 4192/ 8081]
loss: 6.579778 [ 4512/ 8081]
loss: 6.327852 [ 4832/ 8081]
loss: 5.288838 [ 5152/ 8081]
loss: 5.571650 [ 5472/ 8081]
loss: 7.053420 [ 5792/ 8081]
loss: 6.272779 [ 6112/ 8081]
loss: 5.462452 [ 6432/ 8081]
loss: 5.919233 [ 6752/ 8081]
loss: 6.486927 [ 7072/ 8081]
loss: 5.169839 [ 7392/ 8081]
loss: 5.864375 [ 7712/ 8081]
loss: 6.144361 [ 8032/ 8081]
Avg training loss: 5.756941
Val loss: 5.659359 (hab: 9.160813, mov: 5.680856)
Validation loss decreased (5.663774 → 5.659359). Saving model…
Epoch 9/50
loss: 5.015366 [ 32/ 8081]
loss: 5.890463 [ 352/ 8081]
loss: 5.667079 [ 672/ 8081]
loss: 6.308356 [ 992/ 8081]
loss: 6.460155 [ 1312/ 8081]
loss: 5.740475 [ 1632/ 8081]
loss: 6.436217 [ 1952/ 8081]
loss: 7.255536 [ 2272/ 8081]
loss: 5.891561 [ 2592/ 8081]
loss: 5.757444 [ 2912/ 8081]
loss: 6.140157 [ 3232/ 8081]
loss: 5.553885 [ 3552/ 8081]
loss: 5.180971 [ 3872/ 8081]
loss: 5.593443 [ 4192/ 8081]
loss: 5.231184 [ 4512/ 8081]
loss: 5.650258 [ 4832/ 8081]
loss: 5.901844 [ 5152/ 8081]
loss: 5.418355 [ 5472/ 8081]
loss: 6.061893 [ 5792/ 8081]
loss: 5.725675 [ 6112/ 8081]
loss: 5.522688 [ 6432/ 8081]
loss: 5.422935 [ 6752/ 8081]
loss: 5.901719 [ 7072/ 8081]
loss: 6.634830 [ 7392/ 8081]
loss: 6.179843 [ 7712/ 8081]
loss: 6.570165 [ 8032/ 8081]
Avg training loss: 5.753735
Val loss: 5.657118 (hab: 9.155861, mov: 5.679439)
Validation loss decreased (5.659359 → 5.657118). Saving model…
Epoch 10/50
loss: 5.979346 [ 32/ 8081]
loss: 6.341550 [ 352/ 8081]
loss: 6.134659 [ 672/ 8081]
loss: 5.003821 [ 992/ 8081]
loss: 5.825746 [ 1312/ 8081]
loss: 6.957564 [ 1632/ 8081]
loss: 5.609220 [ 1952/ 8081]
loss: 5.420880 [ 2272/ 8081]
loss: 5.825545 [ 2592/ 8081]
loss: 6.414145 [ 2912/ 8081]
loss: 4.970078 [ 3232/ 8081]
loss: 4.309336 [ 3552/ 8081]
loss: 6.106659 [ 3872/ 8081]
loss: 6.817013 [ 4192/ 8081]
loss: 5.473575 [ 4512/ 8081]
loss: 5.892172 [ 4832/ 8081]
loss: 6.330991 [ 5152/ 8081]
loss: 4.686843 [ 5472/ 8081]
loss: 6.179124 [ 5792/ 8081]
loss: 5.744292 [ 6112/ 8081]
loss: 5.377276 [ 6432/ 8081]
loss: 5.348831 [ 6752/ 8081]
loss: 5.073267 [ 7072/ 8081]
loss: 4.481338 [ 7392/ 8081]
loss: 5.835503 [ 7712/ 8081]
loss: 4.866236 [ 8032/ 8081]
Avg training loss: 5.750079
Val loss: 5.655317 (hab: 9.154174, mov: 5.678362)
Validation loss decreased (5.657118 → 5.655317). Saving model…
Epoch 11/50
loss: 5.954041 [ 32/ 8081]
loss: 6.094968 [ 352/ 8081]
loss: 4.979434 [ 672/ 8081]
loss: 5.557578 [ 992/ 8081]
loss: 6.135086 [ 1312/ 8081]
loss: 6.096596 [ 1632/ 8081]
loss: 5.810320 [ 1952/ 8081]
loss: 5.773321 [ 2272/ 8081]
loss: 5.708895 [ 2592/ 8081]
loss: 5.832764 [ 2912/ 8081]
loss: 5.584544 [ 3232/ 8081]
loss: 5.994350 [ 3552/ 8081]
loss: 6.183756 [ 3872/ 8081]
loss: 7.011960 [ 4192/ 8081]
loss: 5.242159 [ 4512/ 8081]
loss: 4.903460 [ 4832/ 8081]
loss: 5.612700 [ 5152/ 8081]
loss: 6.321691 [ 5472/ 8081]
loss: 6.183264 [ 5792/ 8081]
loss: 6.033978 [ 6112/ 8081]
loss: 5.172969 [ 6432/ 8081]
loss: 6.431845 [ 6752/ 8081]
loss: 5.793088 [ 7072/ 8081]
loss: 4.720260 [ 7392/ 8081]
loss: 4.866962 [ 7712/ 8081]
loss: 6.244982 [ 8032/ 8081]
Avg training loss: 5.750686
Val loss: 5.653751 (hab: 9.154543, mov: 5.677163)
Validation loss decreased (5.655317 → 5.653751). Saving model…
Epoch 12/50
loss: 6.128786 [ 32/ 8081]
loss: 6.231214 [ 352/ 8081]
loss: 5.599900 [ 672/ 8081]
loss: 5.946187 [ 992/ 8081]
loss: 6.217439 [ 1312/ 8081]
loss: 5.561559 [ 1632/ 8081]
loss: 6.017690 [ 1952/ 8081]
loss: 6.422792 [ 2272/ 8081]
loss: 5.169445 [ 2592/ 8081]
loss: 5.734517 [ 2912/ 8081]
loss: 6.159955 [ 3232/ 8081]
loss: 6.048131 [ 3552/ 8081]
loss: 5.095168 [ 3872/ 8081]
loss: 5.633101 [ 4192/ 8081]
loss: 5.936021 [ 4512/ 8081]
loss: 4.978602 [ 4832/ 8081]
loss: 6.078108 [ 5152/ 8081]
loss: 6.416469 [ 5472/ 8081]
loss: 6.084074 [ 5792/ 8081]
loss: 7.058030 [ 6112/ 8081]
loss: 6.218597 [ 6432/ 8081]
loss: 5.706355 [ 6752/ 8081]
loss: 6.205107 [ 7072/ 8081]
loss: 6.251582 [ 7392/ 8081]
loss: 6.318515 [ 7712/ 8081]
loss: 5.446843 [ 8032/ 8081]
Avg training loss: 5.747251
Val loss: 5.651318 (hab: 9.156074, mov: 5.676356)
Validation loss decreased (5.653751 → 5.651318). Saving model…
Epoch 13/50
loss: 5.735549 [ 32/ 8081]
loss: 5.167077 [ 352/ 8081]
loss: 5.022010 [ 672/ 8081]
loss: 6.689678 [ 992/ 8081]
loss: 5.273829 [ 1312/ 8081]
loss: 5.824816 [ 1632/ 8081]
loss: 5.750481 [ 1952/ 8081]
loss: 5.760807 [ 2272/ 8081]
loss: 4.784536 [ 2592/ 8081]
loss: 7.326108 [ 2912/ 8081]
loss: 6.325606 [ 3232/ 8081]
loss: 5.315575 [ 3552/ 8081]
loss: 5.708946 [ 3872/ 8081]
loss: 6.236599 [ 4192/ 8081]
loss: 6.176717 [ 4512/ 8081]
loss: 6.667757 [ 4832/ 8081]
loss: 4.588668 [ 5152/ 8081]
loss: 5.248083 [ 5472/ 8081]
loss: 6.057298 [ 5792/ 8081]
loss: 5.856156 [ 6112/ 8081]
loss: 6.274911 [ 6432/ 8081]
loss: 6.290979 [ 6752/ 8081]
loss: 5.460048 [ 7072/ 8081]
loss: 4.758325 [ 7392/ 8081]
loss: 7.212662 [ 7712/ 8081]
loss: 6.885795 [ 8032/ 8081]
Avg training loss: 5.746758
Val loss: 5.648002 (hab: 9.150894, mov: 5.675554)
Validation loss decreased (5.651318 → 5.648002). Saving model…
Epoch 14/50
loss: 5.946347 [ 32/ 8081]
loss: 5.041623 [ 352/ 8081]
loss: 5.873816 [ 672/ 8081]
loss: 5.291799 [ 992/ 8081]
loss: 5.365646 [ 1312/ 8081]
loss: 5.800356 [ 1632/ 8081]
loss: 5.706845 [ 1952/ 8081]
loss: 6.113334 [ 2272/ 8081]
loss: 5.965096 [ 2592/ 8081]
loss: 6.214636 [ 2912/ 8081]
loss: 6.481359 [ 3232/ 8081]
loss: 4.853850 [ 3552/ 8081]
loss: 5.281315 [ 3872/ 8081]
loss: 5.762989 [ 4192/ 8081]
loss: 6.134844 [ 4512/ 8081]
loss: 6.813313 [ 4832/ 8081]
loss: 6.060571 [ 5152/ 8081]
loss: 7.026179 [ 5472/ 8081]
loss: 5.946495 [ 5792/ 8081]
loss: 5.180030 [ 6112/ 8081]
loss: 6.355926 [ 6432/ 8081]
loss: 5.490451 [ 6752/ 8081]
loss: 5.836087 [ 7072/ 8081]
loss: 6.027002 [ 7392/ 8081]
loss: 5.614779 [ 7712/ 8081]
loss: 6.035224 [ 8032/ 8081]
Avg training loss: 5.742599
Val loss: 5.648140 (hab: 9.153320, mov: 5.674904)
EarlyStopping counter: 1 out of 5
Epoch 15/50
loss: 5.597585 [ 32/ 8081]
loss: 4.928980 [ 352/ 8081]
loss: 5.134027 [ 672/ 8081]
loss: 5.157074 [ 992/ 8081]
loss: 5.909005 [ 1312/ 8081]
loss: 5.560033 [ 1632/ 8081]
loss: 6.205202 [ 1952/ 8081]
loss: 6.106179 [ 2272/ 8081]
loss: 7.436422 [ 2592/ 8081]
loss: 5.639592 [ 2912/ 8081]
loss: 4.822022 [ 3232/ 8081]
loss: 4.106799 [ 3552/ 8081]
loss: 6.259495 [ 3872/ 8081]
loss: 7.228454 [ 4192/ 8081]
loss: 5.105317 [ 4512/ 8081]
loss: 5.871928 [ 4832/ 8081]
loss: 5.165004 [ 5152/ 8081]
loss: 6.482831 [ 5472/ 8081]
loss: 4.902627 [ 5792/ 8081]
loss: 6.007688 [ 6112/ 8081]
loss: 4.781419 [ 6432/ 8081]
loss: 5.254600 [ 6752/ 8081]
loss: 6.425246 [ 7072/ 8081]
loss: 6.281509 [ 7392/ 8081]
loss: 5.598694 [ 7712/ 8081]
loss: 6.053635 [ 8032/ 8081]
Avg training loss: 5.743073
Val loss: 5.645316 (hab: 9.148739, mov: 5.674267)
Validation loss decreased (5.648002 → 5.645316). Saving model…
Epoch 16/50
loss: 6.885323 [ 32/ 8081]
loss: 5.594497 [ 352/ 8081]
loss: 5.678989 [ 672/ 8081]
loss: 6.227365 [ 992/ 8081]
loss: 5.790132 [ 1312/ 8081]
loss: 5.360517 [ 1632/ 8081]
loss: 5.787529 [ 1952/ 8081]
loss: 6.439590 [ 2272/ 8081]
loss: 5.170119 [ 2592/ 8081]
loss: 5.598507 [ 2912/ 8081]
loss: 5.916703 [ 3232/ 8081]
loss: 5.602601 [ 3552/ 8081]
loss: 4.940242 [ 3872/ 8081]
loss: 5.859461 [ 4192/ 8081]
loss: 5.760822 [ 4512/ 8081]
loss: 5.442616 [ 4832/ 8081]
loss: 5.790349 [ 5152/ 8081]
loss: 5.874994 [ 5472/ 8081]
loss: 5.619014 [ 5792/ 8081]
loss: 5.000046 [ 6112/ 8081]
loss: 5.243309 [ 6432/ 8081]
loss: 6.288419 [ 6752/ 8081]
loss: 6.087790 [ 7072/ 8081]
loss: 6.147535 [ 7392/ 8081]
loss: 6.012441 [ 7712/ 8081]
loss: 6.748470 [ 8032/ 8081]
Avg training loss: 5.742579
Val loss: 5.644631 (hab: 9.147775, mov: 5.673745)
Validation loss decreased (5.645316 → 5.644631). Saving model…
Epoch 17/50
loss: 6.283003 [ 32/ 8081]
loss: 5.595797 [ 352/ 8081]
loss: 5.715901 [ 672/ 8081]
loss: 5.501903 [ 992/ 8081]
loss: 5.439520 [ 1312/ 8081]
loss: 4.297513 [ 1632/ 8081]
loss: 5.481026 [ 1952/ 8081]
loss: 5.943962 [ 2272/ 8081]
loss: 5.792973 [ 2592/ 8081]
loss: 4.885669 [ 2912/ 8081]
loss: 5.561083 [ 3232/ 8081]
loss: 5.883833 [ 3552/ 8081]
loss: 6.213400 [ 3872/ 8081]
loss: 5.019048 [ 4192/ 8081]
loss: 5.765052 [ 4512/ 8081]
loss: 4.955160 [ 4832/ 8081]
loss: 5.329431 [ 5152/ 8081]
loss: 5.798806 [ 5472/ 8081]
loss: 5.645044 [ 5792/ 8081]
loss: 6.268931 [ 6112/ 8081]
loss: 5.830894 [ 6432/ 8081]
loss: 6.479350 [ 6752/ 8081]
loss: 5.524815 [ 7072/ 8081]
loss: 4.875284 [ 7392/ 8081]
loss: 6.527176 [ 7712/ 8081]
loss: 6.552918 [ 8032/ 8081]
Avg training loss: 5.741385
Val loss: 5.643943 (hab: 9.149630, mov: 5.673379)
Validation loss decreased (5.644631 → 5.643943). Saving model…
Epoch 18/50
loss: 5.646639 [ 32/ 8081]
loss: 6.221776 [ 352/ 8081]
loss: 5.486329 [ 672/ 8081]
loss: 6.544412 [ 992/ 8081]
loss: 5.341964 [ 1312/ 8081]
loss: 5.519305 [ 1632/ 8081]
loss: 5.982452 [ 1952/ 8081]
loss: 6.139508 [ 2272/ 8081]
loss: 5.064729 [ 2592/ 8081]
loss: 6.249910 [ 2912/ 8081]
loss: 5.974421 [ 3232/ 8081]
loss: 5.339847 [ 3552/ 8081]
loss: 6.463853 [ 3872/ 8081]
loss: 6.058188 [ 4192/ 8081]
loss: 5.772356 [ 4512/ 8081]
loss: 6.047544 [ 4832/ 8081]
loss: 6.791408 [ 5152/ 8081]
loss: 4.642981 [ 5472/ 8081]
loss: 5.114915 [ 5792/ 8081]
loss: 6.150812 [ 6112/ 8081]
loss: 6.051000 [ 6432/ 8081]
loss: 6.648256 [ 6752/ 8081]
loss: 6.250250 [ 7072/ 8081]
loss: 5.238799 [ 7392/ 8081]
loss: 5.869842 [ 7712/ 8081]
loss: 6.773789 [ 8032/ 8081]
Avg training loss: 5.740270
Val loss: 5.644076 (hab: 9.149081, mov: 5.672949)
EarlyStopping counter: 1 out of 5
Epoch 19/50
loss: 5.175504 [ 32/ 8081]
loss: 6.612113 [ 352/ 8081]
loss: 6.259410 [ 672/ 8081]
loss: 6.931705 [ 992/ 8081]
loss: 6.386226 [ 1312/ 8081]
loss: 5.127598 [ 1632/ 8081]
loss: 5.744873 [ 1952/ 8081]
loss: 5.399240 [ 2272/ 8081]
loss: 6.128780 [ 2592/ 8081]
loss: 4.790784 [ 2912/ 8081]
loss: 6.567341 [ 3232/ 8081]
loss: 4.790431 [ 3552/ 8081]
loss: 5.470789 [ 3872/ 8081]
loss: 5.554800 [ 4192/ 8081]
loss: 5.679316 [ 4512/ 8081]
loss: 5.368668 [ 4832/ 8081]
loss: 5.470472 [ 5152/ 8081]
loss: 5.754320 [ 5472/ 8081]
loss: 5.034759 [ 5792/ 8081]
loss: 6.242113 [ 6112/ 8081]
loss: 5.507895 [ 6432/ 8081]
loss: 5.377802 [ 6752/ 8081]
loss: 5.432044 [ 7072/ 8081]
loss: 7.036860 [ 7392/ 8081]
loss: 5.353534 [ 7712/ 8081]
loss: 5.177495 [ 8032/ 8081]
Avg training loss: 5.737458
Val loss: 5.644072 (hab: 9.149584, mov: 5.672635)
EarlyStopping counter: 2 out of 5
Epoch 20/50
loss: 5.072841 [ 32/ 8081]
loss: 5.977276 [ 352/ 8081]
loss: 6.188596 [ 672/ 8081]
loss: 5.472806 [ 992/ 8081]
loss: 5.054672 [ 1312/ 8081]
loss: 6.003091 [ 1632/ 8081]
loss: 5.515766 [ 1952/ 8081]
loss: 7.272111 [ 2272/ 8081]
loss: 5.377561 [ 2592/ 8081]
loss: 5.953638 [ 2912/ 8081]
loss: 5.927915 [ 3232/ 8081]
loss: 6.455359 [ 3552/ 8081]
loss: 6.164033 [ 3872/ 8081]
loss: 6.237619 [ 4192/ 8081]
loss: 6.076998 [ 4512/ 8081]
loss: 6.516166 [ 4832/ 8081]
loss: 5.407486 [ 5152/ 8081]
loss: 4.519997 [ 5472/ 8081]
loss: 6.256764 [ 5792/ 8081]
loss: 4.865379 [ 6112/ 8081]
loss: 5.583294 [ 6432/ 8081]
loss: 6.422597 [ 6752/ 8081]
loss: 5.656250 [ 7072/ 8081]
loss: 6.124544 [ 7392/ 8081]
loss: 6.916223 [ 7712/ 8081]
loss: 5.714869 [ 8032/ 8081]
Avg training loss: 5.738087
Val loss: 5.641866 (hab: 9.147915, mov: 5.672215)
Validation loss decreased (5.643943 → 5.641866). Saving model…
Epoch 21/50
loss: 6.010686 [ 32/ 8081]
loss: 6.227398 [ 352/ 8081]
loss: 5.143725 [ 672/ 8081]
loss: 5.645541 [ 992/ 8081]
loss: 5.240898 [ 1312/ 8081]
loss: 4.938787 [ 1632/ 8081]
loss: 5.559191 [ 1952/ 8081]
loss: 6.605025 [ 2272/ 8081]
loss: 5.396750 [ 2592/ 8081]
loss: 5.149271 [ 2912/ 8081]
loss: 5.966692 [ 3232/ 8081]
loss: 5.675520 [ 3552/ 8081]
loss: 5.866745 [ 3872/ 8081]
loss: 4.591894 [ 4192/ 8081]
loss: 6.107647 [ 4512/ 8081]
loss: 5.626565 [ 4832/ 8081]
loss: 6.301752 [ 5152/ 8081]
loss: 5.883070 [ 5472/ 8081]
loss: 6.466449 [ 5792/ 8081]
loss: 5.921516 [ 6112/ 8081]
loss: 4.517435 [ 6432/ 8081]
loss: 6.049242 [ 6752/ 8081]
loss: 7.047176 [ 7072/ 8081]
loss: 5.933296 [ 7392/ 8081]
loss: 6.263047 [ 7712/ 8081]
loss: 5.244785 [ 8032/ 8081]
Avg training loss: 5.737740
Val loss: 5.640463 (hab: 9.149177, mov: 5.671844)
Validation loss decreased (5.641866 → 5.640463). Saving model…
Epoch 22/50
loss: 5.056785 [ 32/ 8081]
loss: 5.107917 [ 352/ 8081]
loss: 6.367042 [ 672/ 8081]
loss: 5.297422 [ 992/ 8081]
loss: 6.158465 [ 1312/ 8081]
loss: 7.193504 [ 1632/ 8081]
loss: 5.873117 [ 1952/ 8081]
loss: 5.830358 [ 2272/ 8081]
loss: 5.546278 [ 2592/ 8081]
loss: 5.019753 [ 2912/ 8081]
loss: 6.140012 [ 3232/ 8081]
loss: 5.437500 [ 3552/ 8081]
loss: 5.131791 [ 3872/ 8081]
loss: 5.591730 [ 4192/ 8081]
loss: 5.345333 [ 4512/ 8081]
loss: 5.434434 [ 4832/ 8081]
loss: 5.019072 [ 5152/ 8081]
loss: 5.485474 [ 5472/ 8081]
loss: 5.990477 [ 5792/ 8081]
loss: 6.565684 [ 6112/ 8081]
loss: 5.539105 [ 6432/ 8081]
loss: 6.096859 [ 6752/ 8081]
loss: 5.047228 [ 7072/ 8081]
loss: 5.438942 [ 7392/ 8081]
loss: 4.245893 [ 7712/ 8081]
loss: 5.230383 [ 8032/ 8081]
Avg training loss: 5.737169
Val loss: 5.641476 (hab: 9.145061, mov: 5.671644)
EarlyStopping counter: 1 out of 5
Epoch 23/50
loss: 6.373130 [ 32/ 8081]
loss: 5.521699 [ 352/ 8081]
loss: 5.814208 [ 672/ 8081]
loss: 5.395493 [ 992/ 8081]
loss: 5.824369 [ 1312/ 8081]
loss: 6.002573 [ 1632/ 8081]
loss: 5.864894 [ 1952/ 8081]
loss: 6.124455 [ 2272/ 8081]
loss: 6.624372 [ 2592/ 8081]
loss: 5.793117 [ 2912/ 8081]
loss: 6.165062 [ 3232/ 8081]
loss: 6.720785 [ 3552/ 8081]
loss: 6.139231 [ 3872/ 8081]
loss: 5.599601 [ 4192/ 8081]
loss: 5.711744 [ 4512/ 8081]
loss: 5.504859 [ 4832/ 8081]
loss: 5.548381 [ 5152/ 8081]
loss: 6.631241 [ 5472/ 8081]
loss: 6.187800 [ 5792/ 8081]
loss: 5.437247 [ 6112/ 8081]
loss: 5.561949 [ 6432/ 8081]
loss: 5.568308 [ 6752/ 8081]
loss: 4.743848 [ 7072/ 8081]
loss: 4.785015 [ 7392/ 8081]
loss: 5.890841 [ 7712/ 8081]
loss: 5.016112 [ 8032/ 8081]
Avg training loss: 5.738002
Val loss: 5.641952 (hab: 9.144587, mov: 5.671266)
EarlyStopping counter: 2 out of 5
Epoch 24/50
loss: 5.746219 [ 32/ 8081]
loss: 5.143204 [ 352/ 8081]
loss: 6.581674 [ 672/ 8081]
loss: 5.834509 [ 992/ 8081]
loss: 5.494139 [ 1312/ 8081]
loss: 4.902992 [ 1632/ 8081]
loss: 5.277520 [ 1952/ 8081]
loss: 6.514735 [ 2272/ 8081]
loss: 5.920425 [ 2592/ 8081]
loss: 5.850049 [ 2912/ 8081]
loss: 5.195207 [ 3232/ 8081]
loss: 6.105504 [ 3552/ 8081]
loss: 5.616876 [ 3872/ 8081]
loss: 6.008060 [ 4192/ 8081]
loss: 5.422901 [ 4512/ 8081]
loss: 5.177522 [ 4832/ 8081]
loss: 4.751125 [ 5152/ 8081]
loss: 5.577045 [ 5472/ 8081]
loss: 6.131419 [ 5792/ 8081]
loss: 6.501829 [ 6112/ 8081]
loss: 5.214920 [ 6432/ 8081]
loss: 5.598885 [ 6752/ 8081]
loss: 5.734118 [ 7072/ 8081]
loss: 5.896174 [ 7392/ 8081]
loss: 5.807665 [ 7712/ 8081]
loss: 5.157558 [ 8032/ 8081]
Avg training loss: 5.735473
Val loss: 5.640158 (hab: 9.143460, mov: 5.671050)
Validation loss decreased (5.640463 → 5.640158). Saving model…
Epoch 25/50
loss: 6.161477 [ 32/ 8081]
loss: 5.648900 [ 352/ 8081]
loss: 5.558351 [ 672/ 8081]
loss: 5.186936 [ 992/ 8081]
loss: 5.539708 [ 1312/ 8081]
loss: 5.807642 [ 1632/ 8081]
loss: 6.611417 [ 1952/ 8081]
loss: 5.423437 [ 2272/ 8081]
loss: 5.694284 [ 2592/ 8081]
loss: 5.506063 [ 2912/ 8081]
loss: 4.919890 [ 3232/ 8081]
loss: 6.296585 [ 3552/ 8081]
loss: 5.463380 [ 3872/ 8081]
loss: 5.580927 [ 4192/ 8081]
loss: 6.874486 [ 4512/ 8081]
loss: 4.920260 [ 4832/ 8081]
loss: 6.043089 [ 5152/ 8081]
loss: 5.447857 [ 5472/ 8081]
loss: 6.368416 [ 5792/ 8081]
loss: 4.843377 [ 6112/ 8081]
loss: 6.345890 [ 6432/ 8081]
loss: 5.116454 [ 6752/ 8081]
loss: 6.050868 [ 7072/ 8081]
loss: 6.241920 [ 7392/ 8081]
loss: 6.305410 [ 7712/ 8081]
loss: 5.232047 [ 8032/ 8081]
Avg training loss: 5.734663
Val loss: 5.639725 (hab: 9.144361, mov: 5.670998)
Validation loss decreased (5.640158 → 5.639725). Saving model…
Epoch 26/50
loss: 7.151555 [ 32/ 8081]
loss: 5.414118 [ 352/ 8081]
loss: 6.189126 [ 672/ 8081]
loss: 5.227577 [ 992/ 8081]
loss: 6.818292 [ 1312/ 8081]
loss: 5.771779 [ 1632/ 8081]
loss: 5.577074 [ 1952/ 8081]
loss: 5.630198 [ 2272/ 8081]
loss: 5.653317 [ 2592/ 8081]
loss: 5.436688 [ 2912/ 8081]
loss: 5.958178 [ 3232/ 8081]
loss: 5.440780 [ 3552/ 8081]
loss: 5.431932 [ 3872/ 8081]
loss: 5.609800 [ 4192/ 8081]
loss: 5.665380 [ 4512/ 8081]
loss: 5.957418 [ 4832/ 8081]
loss: 4.937189 [ 5152/ 8081]
loss: 6.138237 [ 5472/ 8081]
loss: 6.117325 [ 5792/ 8081]
loss: 5.992046 [ 6112/ 8081]
loss: 5.035103 [ 6432/ 8081]
loss: 6.369143 [ 6752/ 8081]
loss: 5.848080 [ 7072/ 8081]
loss: 6.836145 [ 7392/ 8081]
loss: 5.299628 [ 7712/ 8081]
loss: 5.602996 [ 8032/ 8081]
Avg training loss: 5.735829
Val loss: 5.639682 (hab: 9.145324, mov: 5.670968)
Validation loss decreased (5.639725 → 5.639682). Saving model…
Epoch 27/50
loss: 4.927288 [ 32/ 8081]
loss: 6.503007 [ 352/ 8081]
loss: 5.772881 [ 672/ 8081]
loss: 5.601285 [ 992/ 8081]
loss: 5.897579 [ 1312/ 8081]
loss: 5.781460 [ 1632/ 8081]
loss: 5.490117 [ 1952/ 8081]
loss: 5.170106 [ 2272/ 8081]
loss: 6.280233 [ 2592/ 8081]
loss: 5.659313 [ 2912/ 8081]
loss: 5.048002 [ 3232/ 8081]
loss: 5.602965 [ 3552/ 8081]
loss: 6.199113 [ 3872/ 8081]
loss: 5.848446 [ 4192/ 8081]
loss: 5.756188 [ 4512/ 8081]
loss: 5.665813 [ 4832/ 8081]
loss: 4.977673 [ 5152/ 8081]
loss: 5.520265 [ 5472/ 8081]
loss: 5.220546 [ 5792/ 8081]
loss: 5.660740 [ 6112/ 8081]
loss: 5.151398 [ 6432/ 8081]
loss: 6.287862 [ 6752/ 8081]
loss: 5.646356 [ 7072/ 8081]
loss: 6.649393 [ 7392/ 8081]
loss: 5.878810 [ 7712/ 8081]
loss: 5.759536 [ 8032/ 8081]
Avg training loss: 5.733462
Val loss: 5.639660 (hab: 9.145306, mov: 5.670936)
Validation loss decreased (5.639682 → 5.639660). Saving model…
Epoch 28/50
loss: 6.236925 [ 32/ 8081]
loss: 6.146458 [ 352/ 8081]
loss: 5.195530 [ 672/ 8081]
loss: 5.583611 [ 992/ 8081]
loss: 5.273785 [ 1312/ 8081]
loss: 5.463507 [ 1632/ 8081]
loss: 5.881322 [ 1952/ 8081]
loss: 6.306096 [ 2272/ 8081]
loss: 5.130684 [ 2592/ 8081]
loss: 5.816171 [ 2912/ 8081]
loss: 6.420640 [ 3232/ 8081]
loss: 5.765622 [ 3552/ 8081]
loss: 5.842854 [ 3872/ 8081]
loss: 5.821796 [ 4192/ 8081]
loss: 5.756525 [ 4512/ 8081]
loss: 5.565989 [ 4832/ 8081]
loss: 5.573711 [ 5152/ 8081]
loss: 5.448873 [ 5472/ 8081]
loss: 6.393813 [ 5792/ 8081]
loss: 5.563116 [ 6112/ 8081]
loss: 7.175228 [ 6432/ 8081]
loss: 4.962247 [ 6752/ 8081]
loss: 5.873940 [ 7072/ 8081]
loss: 5.759718 [ 7392/ 8081]
loss: 5.624741 [ 7712/ 8081]
loss: 5.907386 [ 8032/ 8081]
Avg training loss: 5.733195
Val loss: 5.639641 (hab: 9.145159, mov: 5.670897)
Validation loss decreased (5.639660 → 5.639641). Saving model…
Epoch 29/50
loss: 6.709939 [ 32/ 8081]
loss: 6.194642 [ 352/ 8081]
loss: 5.739034 [ 672/ 8081]
loss: 5.670167 [ 992/ 8081]
loss: 4.941658 [ 1312/ 8081]
loss: 5.502157 [ 1632/ 8081]
loss: 6.519639 [ 1952/ 8081]
loss: 5.792060 [ 2272/ 8081]
loss: 5.513575 [ 2592/ 8081]
loss: 5.989444 [ 2912/ 8081]
loss: 6.193710 [ 3232/ 8081]
loss: 5.135377 [ 3552/ 8081]
loss: 5.945956 [ 3872/ 8081]
loss: 4.705295 [ 4192/ 8081]
loss: 5.501117 [ 4512/ 8081]
loss: 5.852671 [ 4832/ 8081]
loss: 6.271144 [ 5152/ 8081]
loss: 5.952385 [ 5472/ 8081]
loss: 5.365061 [ 5792/ 8081]
loss: 6.104717 [ 6112/ 8081]
loss: 5.881651 [ 6432/ 8081]
loss: 5.851075 [ 6752/ 8081]
loss: 5.142447 [ 7072/ 8081]
loss: 5.647768 [ 7392/ 8081]
loss: 5.986320 [ 7712/ 8081]
loss: 5.402556 [ 8032/ 8081]
Avg training loss: 5.734142
Val loss: 5.639626 (hab: 9.145162, mov: 5.670893)
Validation loss decreased (5.639641 → 5.639626). Saving model…
Epoch 30/50
loss: 5.612227 [ 32/ 8081]
loss: 5.494819 [ 352/ 8081]
loss: 4.629119 [ 672/ 8081]
loss: 6.452594 [ 992/ 8081]
loss: 5.972569 [ 1312/ 8081]
loss: 5.160095 [ 1632/ 8081]
loss: 5.795186 [ 1952/ 8081]
loss: 5.309612 [ 2272/ 8081]
loss: 5.891468 [ 2592/ 8081]
loss: 5.696829 [ 2912/ 8081]
loss: 5.713593 [ 3232/ 8081]
loss: 6.069798 [ 3552/ 8081]
loss: 4.935905 [ 3872/ 8081]
loss: 5.428632 [ 4192/ 8081]
loss: 5.866337 [ 4512/ 8081]
loss: 5.686249 [ 4832/ 8081]
loss: 5.791387 [ 5152/ 8081]
loss: 6.364533 [ 5472/ 8081]
loss: 5.466154 [ 5792/ 8081]
loss: 5.667849 [ 6112/ 8081]
loss: 5.313776 [ 6432/ 8081]
loss: 6.761433 [ 6752/ 8081]
loss: 6.629786 [ 7072/ 8081]
loss: 5.532158 [ 7392/ 8081]
loss: 6.032206 [ 7712/ 8081]
loss: 5.922668 [ 8032/ 8081]
Avg training loss: 5.734589
Val loss: 5.639633 (hab: 9.145213, mov: 5.670889)
EarlyStopping counter: 1 out of 5
Epoch 31/50
loss: 6.855468 [ 32/ 8081]
loss: 5.549836 [ 352/ 8081]
loss: 5.501370 [ 672/ 8081]
loss: 6.163689 [ 992/ 8081]
loss: 6.146986 [ 1312/ 8081]
loss: 5.668081 [ 1632/ 8081]
loss: 5.154352 [ 1952/ 8081]
loss: 6.032871 [ 2272/ 8081]
loss: 5.482692 [ 2592/ 8081]
loss: 6.838123 [ 2912/ 8081]
loss: 6.182957 [ 3232/ 8081]
loss: 5.312598 [ 3552/ 8081]
loss: 5.461801 [ 3872/ 8081]
loss: 5.386752 [ 4192/ 8081]
loss: 6.306053 [ 4512/ 8081]
loss: 5.683695 [ 4832/ 8081]
loss: 5.795979 [ 5152/ 8081]
loss: 5.956809 [ 5472/ 8081]
loss: 5.825184 [ 5792/ 8081]
loss: 5.104815 [ 6112/ 8081]
loss: 5.338410 [ 6432/ 8081]
loss: 4.392776 [ 6752/ 8081]
loss: 5.910175 [ 7072/ 8081]
loss: 4.851960 [ 7392/ 8081]
loss: 5.446443 [ 7712/ 8081]
loss: 5.372225 [ 8032/ 8081]
Avg training loss: 5.733252
Val loss: 5.639620 (hab: 9.145216, mov: 5.670887)
Validation loss decreased (5.639626 → 5.639620). Saving model…
Epoch 32/50
loss: 5.426272 [ 32/ 8081]
loss: 5.630202 [ 352/ 8081]
loss: 5.632508 [ 672/ 8081]
loss: 5.438181 [ 992/ 8081]
loss: 5.693160 [ 1312/ 8081]
loss: 5.824766 [ 1632/ 8081]
loss: 4.332841 [ 1952/ 8081]
loss: 5.520078 [ 2272/ 8081]
loss: 5.456306 [ 2592/ 8081]
loss: 5.956411 [ 2912/ 8081]
loss: 5.646355 [ 3232/ 8081]
loss: 5.780666 [ 3552/ 8081]
loss: 5.656393 [ 3872/ 8081]
loss: 6.113107 [ 4192/ 8081]
loss: 5.247138 [ 4512/ 8081]
loss: 6.456845 [ 4832/ 8081]
loss: 5.673826 [ 5152/ 8081]
loss: 6.329019 [ 5472/ 8081]
loss: 6.144670 [ 5792/ 8081]
loss: 6.344707 [ 6112/ 8081]
loss: 6.364787 [ 6432/ 8081]
loss: 5.852273 [ 6752/ 8081]
loss: 5.987990 [ 7072/ 8081]
loss: 5.569731 [ 7392/ 8081]
loss: 6.279569 [ 7712/ 8081]
loss: 5.453013 [ 8032/ 8081]
Avg training loss: 5.732120
Val loss: 5.639620 (hab: 9.145214, mov: 5.670886)
Validation loss decreased (5.639620 → 5.639620). Saving model…
Epoch 33/50
loss: 5.268016 [ 32/ 8081]
loss: 5.237939 [ 352/ 8081]
loss: 5.922724 [ 672/ 8081]
loss: 5.021144 [ 992/ 8081]
loss: 6.242224 [ 1312/ 8081]
loss: 5.304910 [ 1632/ 8081]
loss: 6.274157 [ 1952/ 8081]
loss: 5.443479 [ 2272/ 8081]
loss: 5.825380 [ 2592/ 8081]
loss: 4.930949 [ 2912/ 8081]
loss: 5.148467 [ 3232/ 8081]
loss: 5.614328 [ 3552/ 8081]
loss: 5.556447 [ 3872/ 8081]
loss: 6.509217 [ 4192/ 8081]
loss: 5.418288 [ 4512/ 8081]
loss: 5.584445 [ 4832/ 8081]
loss: 6.188684 [ 5152/ 8081]
loss: 5.922045 [ 5472/ 8081]
loss: 5.749983 [ 5792/ 8081]
loss: 5.129826 [ 6112/ 8081]
loss: 5.641623 [ 6432/ 8081]
loss: 7.205474 [ 6752/ 8081]
loss: 5.569323 [ 7072/ 8081]
loss: 5.049457 [ 7392/ 8081]
loss: 5.763159 [ 7712/ 8081]
loss: 5.636391 [ 8032/ 8081]
Avg training loss: 5.731416
Val loss: 5.639621 (hab: 9.145215, mov: 5.670886)
EarlyStopping counter: 1 out of 5
Epoch 34/50
loss: 5.573732 [ 32/ 8081]
loss: 4.562861 [ 352/ 8081]
loss: 6.424152 [ 672/ 8081]
loss: 5.473699 [ 992/ 8081]
loss: 5.866921 [ 1312/ 8081]
loss: 5.599736 [ 1632/ 8081]
loss: 5.116332 [ 1952/ 8081]
loss: 5.631752 [ 2272/ 8081]
loss: 5.095023 [ 2592/ 8081]
loss: 6.444837 [ 2912/ 8081]
loss: 5.993291 [ 3232/ 8081]
loss: 6.616193 [ 3552/ 8081]
loss: 5.721396 [ 3872/ 8081]
loss: 5.316614 [ 4192/ 8081]
loss: 4.425422 [ 4512/ 8081]
loss: 5.275918 [ 4832/ 8081]
loss: 5.434562 [ 5152/ 8081]
loss: 5.448647 [ 5472/ 8081]
loss: 5.664916 [ 5792/ 8081]
loss: 5.117067 [ 6112/ 8081]
loss: 5.216536 [ 6432/ 8081]
loss: 6.023912 [ 6752/ 8081]
loss: 6.081383 [ 7072/ 8081]
loss: 6.285754 [ 7392/ 8081]
loss: 4.963202 [ 7712/ 8081]
loss: 5.865150 [ 8032/ 8081]
Avg training loss: 5.731098
Val loss: 5.639621 (hab: 9.145216, mov: 5.670886)
EarlyStopping counter: 2 out of 5
Epoch 35/50
loss: 6.021615 [ 32/ 8081]
loss: 4.532571 [ 352/ 8081]
loss: 5.175756 [ 672/ 8081]
loss: 5.827329 [ 992/ 8081]
loss: 5.984054 [ 1312/ 8081]
loss: 5.343431 [ 1632/ 8081]
loss: 6.666986 [ 1952/ 8081]
loss: 6.816831 [ 2272/ 8081]
loss: 5.876522 [ 2592/ 8081]
loss: 5.746874 [ 2912/ 8081]
loss: 5.963414 [ 3232/ 8081]
loss: 4.622033 [ 3552/ 8081]
loss: 6.336753 [ 3872/ 8081]
loss: 5.816704 [ 4192/ 8081]
loss: 5.761320 [ 4512/ 8081]
loss: 5.750998 [ 4832/ 8081]
loss: 4.660877 [ 5152/ 8081]
loss: 5.898358 [ 5472/ 8081]
loss: 5.581824 [ 5792/ 8081]
loss: 4.712275 [ 6112/ 8081]
loss: 5.580365 [ 6432/ 8081]
loss: 5.149635 [ 6752/ 8081]
loss: 4.571750 [ 7072/ 8081]
loss: 5.565580 [ 7392/ 8081]
loss: 6.358205 [ 7712/ 8081]
loss: 5.638477 [ 8032/ 8081]
Avg training loss: 5.731328
Val loss: 5.639621 (hab: 9.145216, mov: 5.670886)
EarlyStopping counter: 3 out of 5
Epoch 36/50
loss: 5.219586 [ 32/ 8081]
loss: 5.866108 [ 352/ 8081]
loss: 5.157986 [ 672/ 8081]
loss: 5.456277 [ 992/ 8081]
loss: 6.346135 [ 1312/ 8081]
loss: 6.265892 [ 1632/ 8081]
loss: 5.365193 [ 1952/ 8081]
loss: 5.935129 [ 2272/ 8081]
loss: 5.507366 [ 2592/ 8081]
loss: 5.146176 [ 2912/ 8081]
loss: 6.031135 [ 3232/ 8081]
loss: 5.098520 [ 3552/ 8081]
loss: 6.378906 [ 3872/ 8081]
loss: 5.300342 [ 4192/ 8081]
loss: 6.126292 [ 4512/ 8081]
loss: 5.154186 [ 4832/ 8081]
loss: 5.361709 [ 5152/ 8081]
loss: 5.572066 [ 5472/ 8081]
loss: 5.279382 [ 5792/ 8081]
loss: 4.960338 [ 6112/ 8081]
loss: 6.301264 [ 6432/ 8081]
loss: 5.728517 [ 6752/ 8081]
loss: 6.473097 [ 7072/ 8081]
loss: 6.075172 [ 7392/ 8081]
loss: 5.903062 [ 7712/ 8081]
loss: 6.081424 [ 8032/ 8081]
Avg training loss: 5.735796
Val loss: 5.639621 (hab: 9.145216, mov: 5.670885)
EarlyStopping counter: 4 out of 5
Epoch 37/50
loss: 6.125269 [ 32/ 8081]
loss: 5.612249 [ 352/ 8081]
loss: 5.744735 [ 672/ 8081]
loss: 5.696525 [ 992/ 8081]
loss: 6.524700 [ 1312/ 8081]
loss: 5.233836 [ 1632/ 8081]
loss: 6.115005 [ 1952/ 8081]
loss: 5.521945 [ 2272/ 8081]
loss: 5.175142 [ 2592/ 8081]
loss: 6.153426 [ 2912/ 8081]
loss: 6.875118 [ 3232/ 8081]
loss: 5.568776 [ 3552/ 8081]
loss: 6.070964 [ 3872/ 8081]
loss: 4.977057 [ 4192/ 8081]
loss: 6.165647 [ 4512/ 8081]
loss: 6.339566 [ 4832/ 8081]
loss: 5.571197 [ 5152/ 8081]
loss: 5.252230 [ 5472/ 8081]
loss: 6.161972 [ 5792/ 8081]
loss: 5.660664 [ 6112/ 8081]
loss: 6.967310 [ 6432/ 8081]
loss: 4.854573 [ 6752/ 8081]
loss: 5.381294 [ 7072/ 8081]
loss: 5.590603 [ 7392/ 8081]
loss: 5.565978 [ 7712/ 8081]
loss: 5.514409 [ 8032/ 8081]
Avg training loss: 5.734960
Val loss: 5.639621 (hab: 9.145216, mov: 5.670885)
EarlyStopping counter: 5 out of 5
Early stopping triggered.
# Assemble the contents of SNAPSHOT_DIR into one animated GIF under OUTPUT_DIR.
# create_gif prints its own "Animation saved: ..." line (see the cell output).
gif_path = str(OUTPUT_DIR / "training_progress.gif")
create_gif(str(SNAPSHOT_DIR), gif_path, fps=5)

# The message ends right after the path: no stray trailing space in the literal.
print(f"Training GIF → {gif_path}")

# Play the GIF: as the cell's last expression, the Image object is what the
# notebook displays.
Image(gif_path)
Animation saved: outputs/training_progress.gif
Training GIF → outputs/training_progress.gif
<IPython.core.display.Image object>
# Draw the train / validation loss curves stored in `history` on a single axis.
fig, ax = plt.subplots(figsize=(7, 3))
for key, label in (("train_losses", "train"), ("val_losses", "val")):
    ax.plot(history[key], label=label)
ax.set(xlabel="Epoch", ylabel="NLL loss", title="Training history")
ax.legend()

# Tighten the layout, write a PNG copy into OUTPUT_DIR, then display the figure.
fig.tight_layout()
fig.savefig(str(OUTPUT_DIR / "loss_history.png"), dpi=100)
plt.show()
4 Validate against withheld test data
# Load the best checkpoint saved during training.
# weights_only=True is passed explicitly so torch.load never unpickles arbitrary
# objects, whatever the installed torch version's default is. The loaded object
# goes straight into load_state_dict, so only a state dict is expected here.
best_state = torch.load(
    str(OUTPUT_DIR / "best_model.pt"),
    map_location=DEVICE,
    weights_only=True,
)
model.load_state_dict(best_state)

# Put the model in evaluation mode for the validation and simulation cells below.
model.eval()
print("Best checkpoint loaded.")
# Expose the static rasters through a landscape-loader callable.
# Each layer is cast to float32 and turned into a tensor once, up front, so the
# loader hands back the same two tensors on every call.
ndvi_t, slope_t = (
    torch.from_numpy(env_layers[layer_name].astype(np.float32))
    for layer_name in ("ndvi", "slope")
)


def get_landscape(_month_index):
    """Return the covariate tensors as ``[ndvi, slope]``.

    The month index is accepted but ignored, because the layers are static
    (non-S2) and do not vary by month.
    """
    return [ndvi_t, slope_t]
# Hold out the last 10 % of steps as the test set.
# NOTE(review): the original comment says this matches the DataLoader split —
# confirm against the make_dataloaders call earlier in the notebook.
n_test = int(len(step_df) * 0.1)

# Slice from an explicit start index rather than `-n_test`: when n_test is 0,
# `iloc[-0:]` selects every row instead of none.
test_sample = step_df.iloc[len(step_df) - n_test:].reset_index(drop=True)
print(f"Validating on {len(test_sample):,} steps")
def _static_month_index(_yday):
    """Always return month index 0: the layers are static, so month doesn't matter."""
    return 0


# Run validate_next_step_probs over the held-out steps with the same rasters,
# window size and scalar columns that are configured earlier in the notebook.
val_results = validate_next_step_probs(
    model,
    test_sample,
    get_landscape=get_landscape,
    transform=raster_transform,
    window_size=WINDOW_SIZE,
    scalar_cols=tuple(SCALAR_COLS),
    yday_col="yday_t1",
    bearing_col="bearing_tm1",
    month_index_fn=_static_month_index,
)

# Summary statistics for the three probability columns.
prob_summary = val_results[["habitat_prob", "move_prob", "next_step_prob"]].describe()
print(prob_summary)
Validating on 1,010 steps
habitat_prob move_prob next_step_prob
count 1010.000000 1010.000000 1010.000000
mean 0.000126 0.027330 0.030069
std 0.000050 0.054359 0.060509
min 0.000000 0.000000 0.000000
25% 0.000093 0.000224 0.000228
50% 0.000112 0.006624 0.007082
75% 0.000155 0.030587 0.034205
max 0.000430 0.255475 0.334267
# One histogram panel per probability column of the validation results.
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
panels = (
    ("habitat_prob", "Habitat"),
    ("move_prob", "Movement"),
    ("next_step_prob", "Joint next-step"),
)
for ax, (col, title) in zip(axes, panels):
    # Missing values are dropped before binning.
    ax.hist(val_results[col].dropna(), bins=40, edgecolor="white")
    ax.set(title=title, xlabel="Probability")

# Only the left-most panel carries the y-axis label.
axes[0].set_ylabel("Count")

fig.suptitle("Next-step probability distributions (test set)", y=1.02)
fig.tight_layout()
fig.savefig(str(OUTPUT_DIR / "validation_probs.png"), dpi=100)
plt.show()
5 Simulate a trajectory
# Start the simulation at the per-axis median of the step coordinates
# (columns x1_ / y1_ of step_df).
start_x = float(step_df["x1_"].median())
start_y = float(step_df["y1_"].median())

# Literal text carries no padding inside the parentheses or before the comma,
# so the line prints as "(36111, -1436905)" like the recorded output.
print(f"Starting location: ({start_x:.0f}, {start_y:.0f})")
# The simulation clock is seeded from the first row of the held-out sample.
first_yday = test_sample["yday_t1"].iloc[0]
first_hour = test_sample["hour_t1"].iloc[0]

# Simulate 1000 steps from (start_x, start_y) over the static landscape.
sim_df = simulate_trajectory(
    model,
    get_landscape=get_landscape,
    transform=raster_transform,
    window_size=WINDOW_SIZE,
    month_index_fn=lambda _yday: 0,  # static layers
    start_x=start_x,
    start_y=start_y,
    starting_yday=first_yday,
    starting_hour=first_hour,
    time_between_steps=1.0,
    n_steps=1000,
)
print(f"Simulated {len(sim_df)} steps")

# Last expression of the cell: the notebook renders the first rows as a table.
sim_df.head()
Starting location: (36111, -1436905)
Simulated 1000 steps
0
36108.269587
-1.436912e+06
18.916667
255.000000
0
[[-9.412999, -9.240336, -9.038588, -8.982068, ...
[[-13.963823, -13.919329, -13.875055, -13.8310...
[[-23.376823, -23.159664, -22.913643, -22.8130...
1
35898.710528
-1.436614e+06
19.916667
255.041667
0
[[-9.405268, -9.207009, -8.999849, -8.948765, ...
[[-13.95573, -13.91391, -13.872377, -13.831140...
[[-23.360998, -23.120918, -22.872227, -22.7799...
2
36464.367638
-1.436123e+06
20.916667
255.083333
0
[[-9.160653, -9.023066, -8.815113, -8.739051, ...
[[-13.513008, -13.469606, -13.426506, -13.3837...
[[-22.67366, -22.492672, -22.24162, -22.122772...
3
36435.540267
-1.436041e+06
21.916667
255.125000
0
[[-8.481404, -8.342018, -8.182885, -8.192506, ...
[[-13.816451, -13.771689, -13.727157, -13.6828...
[[-22.297855, -22.113708, -21.910042, -21.8753...
4
36439.641807
-1.436043e+06
22.916667
255.166667
0
[[-8.254223, -8.155582, -7.9846506, -7.9550314...
[[-13.654331, -13.611277, -13.568495, -13.5259...
[[-21.908554, -21.76686, -21.553146, -21.48103...
# Overlay the observed and simulated tracks on the NDVI raster.
fig, ax = plt.subplots(figsize=(7, 6))
rasterio.plot.show(
    env_layers["ndvi"], transform=raster_transform, ax=ax, cmap="viridis"
)

# Both tracks share marker size and line width; the observed one is drawn in
# semi-transparent white, the simulated one in red.
tracks = (
    (step_df["x1_"], step_df["y1_"], "w-o", {"alpha": 0.5, "label": "observed"}),
    (sim_df["x"], sim_df["y"], "r-o", {"label": "simulated"}),
)
for xs, ys, fmt, extra in tracks:
    ax.plot(xs, ys, fmt, markersize=1, linewidth=0.8, **extra)

ax.set(
    xlabel="Easting (m, projected)",
    ylabel="Northing (m, projected)",
    title="Simulated vs observed buffalo locations",
)
ax.legend()

fig.tight_layout()
fig.savefig(str(OUTPUT_DIR / "simulated_trajectory.png"), dpi=100)
plt.show()