deepSSF Training - S2

Author
Affiliation

Queensland University of Technology, CSIRO

Published

July 10, 2025

Abstract

In this script, we will train a deepSSF model on Sentinel-2 data directly (instead of derived covariates such as NDVI). The training data was generated using the deepSSF_data_prep_S2_id.qmd script, which crops out local images for each step of the observed telemetry data.

There are fewer comments and less explanatory information here than in the main training script, so please refer to the deepSSF_train script for more detail.

Detect computing environment

If using Google Colab, mount the drive and set the base directory to the working folder. If using local, set the base directory.

Code
import os       # Operating system utilities
import sys

# Detect environment
def is_colab():
    """Report whether this code is executing inside Google Colab.

    The check attempts to import the ``google.colab`` package and treats a
    successful import as the marker of a Colab runtime.

    Returns:
        bool: True when ``google.colab`` can be imported, False otherwise.
    """
    try:
        import google.colab
    except ImportError:
        return False
    else:
        return True

# Set up environment-specific configurations
if is_colab():

    # Colab-specific setup
    !pip install rasterio
    from google.colab import drive
    drive.mount('/content/drive')
    sys.path.append('/content/drive/MyDrive/GitHub/deepSSF/Python')

    # for saving plots etc
    base_path = '/content/drive/MyDrive/GitHub/deepSSF'
    print("Running in Google Colab environment")

else:

    # Local environment setup
    base_path = '..'
    print("Running in local environment")

# Now you can use base_path regardless of environment
print(f"Using base path: {base_path}")
Running in local environment
Using base path: ..

Import packages

Code
# If using Google Colab, uncomment the following line
# !pip install rasterio

print(sys.version)  # Print Python version in use

import numpy as np                                      # Array operations
import matplotlib.pyplot as plt                         # Plotting library
import mpmath as mp                                     # Math library
import torch                                            # Main PyTorch library
import torch.optim as optim                             # Optimization algorithms
import torch.nn as nn                                   # Neural network modules
import os                                               # Operating system utilities
import glob                                             # Pattern matching
import imageio.v2 as imageio                            # Image manipulation - for creating GIFs
from IPython.display import Image, display              # For plotting GIFs
import pandas as pd                                     # Data manipulation
import rasterio                                         # Geospatial raster data

from torch.utils.data import Dataset, DataLoader        # Dataset and batch data loading
from datetime import datetime, timedelta                # Date/time utilities
from rasterio.plot import show                           # Plot raster data

import deepSSF_utils                                    # Import the .py file containing utility functions

# Today's local date as a YYYY-MM-DD string
today_date = datetime.now().date().isoformat()
print(today_date)

# Set random seed for reproducibility
# seed = 42
3.12.9 | packaged by Anaconda, Inc. | (main, Feb  6 2025, 12:55:12) [Clang 14.0.6 ]
2025-06-05

Set the device (accelerator - cuda for NVIDIA GPU or mps for Mac)

Code
# Choose the compute device: an available accelerator (e.g. "cuda" or "mps"),
# otherwise the CPU.
# The torch.accelerator namespace only exists in recent PyTorch releases, so on
# older installs we fall back to the backend-specific availability checks
# instead of failing with an AttributeError.
if hasattr(torch, "accelerator"):
    device = (torch.accelerator.current_accelerator().type
              if torch.accelerator.is_available() else "cpu")
elif torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

if torch.backends.mps.is_available():
    # Set default tensor type for PyTorch
    torch.set_default_dtype(torch.float32)
    print('Set default tensor type to float32')
Using mps device
Set default tensor type to float32

Select individual and create directory to save model weights and outputs

Code
# Individual (buffalo id) whose data the model is trained on
buffalo_id = 2005
# Nominal sample count; in our case the actual dataset will be slightly smaller
# because steps that were outside the extent were removed
n_samples = 10_297

Create a directory to save the outputs

If we have already run this code today, we will increment the index to create a new folder.

Code
# Find the run folders already created today for this individual
run_root = f'{base_path}/Python/outputs/model_training_S2'
run_prefix = f'id{buffalo_id}_scalar_movement_'
run_suffix = f'_{today_date}'
pattern = f'{run_root}/{run_prefix}*{run_suffix}'
existing_dirs = glob.glob(pattern)

# Use one more than the highest index already present. Simply counting the
# matches would re-use an existing index (and, because of exist_ok=True,
# silently write into that folder) whenever an earlier run folder was deleted.
used_indices = []
for existing_dir in existing_dirs:
    folder_name = os.path.basename(os.path.normpath(existing_dir))
    index_text = folder_name[len(run_prefix):-len(run_suffix)]
    # Ignore matches whose wildcard part is not a plain run number
    if index_text.isdigit():
        used_indices.append(int(index_text))
dir_index = max(used_indices, default=0) + 1

# Create directory with index
output_dir = f'{run_root}/{run_prefix}{dir_index}{run_suffix}'
os.makedirs(output_dir, exist_ok=True)

print(f"Created directory: {output_dir}")

# To use an existing directory for loading trained model
# output_dir = f'{base_path}/Python/outputs/model_training/id2005_2025-04-01'
Created directory: ../Python/outputs/model_training_S2/id2005_scalar_movement_2_2025-06-05

Import data

Set paths to data

Code
# Individual and sample count that identify the input files
buffalo_id = 2005
n_samples = 10297

# Specify the path to CSV file
# csv_file_path = f'{base_path}/buffalo_local_data_id/buffalo_{buffalo_id}_data_df_lag_1hr_n{n_samples}.csv'
csv_file_path = f'{base_path}/buffalo_local_data_id/buffalo_temporal_cont_{buffalo_id}_data_df_lag_1hr_n{n_samples}.csv'

# Every local raster layer shares the same leading and trailing file-name parts
_layer_prefix = f'{base_path}/buffalo_local_layers_id/buffalo_{buffalo_id}'
_layer_suffix = f'cent101x101_lag_1hr_n{n_samples}.tif'

# Path to the slope TIF file
slope_path = f'{_layer_prefix}_slope_{_layer_suffix}'

# Paths to the Sentinel-2 bands
b1_path = f'{_layer_prefix}_s2_b1_{_layer_suffix}'
b2_path = f'{_layer_prefix}_s2_b2_{_layer_suffix}'
b3_path = f'{_layer_prefix}_s2_b3_{_layer_suffix}'
b4_path = f'{_layer_prefix}_s2_b4_{_layer_suffix}'
b5_path = f'{_layer_prefix}_s2_b5_{_layer_suffix}'
b6_path = f'{_layer_prefix}_s2_b6_{_layer_suffix}'
b7_path = f'{_layer_prefix}_s2_b7_{_layer_suffix}'
b8_path = f'{_layer_prefix}_s2_b8_{_layer_suffix}'
b8a_path = f'{_layer_prefix}_s2_b8a_{_layer_suffix}'
b9_path = f'{_layer_prefix}_s2_b9_{_layer_suffix}'
b11_path = f'{_layer_prefix}_s2_b11_{_layer_suffix}'
b12_path = f'{_layer_prefix}_s2_b12_{_layer_suffix}'

# Path to the target-variable TIF file (note the 'fixed_' file-name prefix)
# pres_path = f'{base_path}/buffalo_local_layers_id/buffalo_{buffalo_id}_pres_cent101x101_lag_1hr_n{n_samples}.tif'
pres_path = f'{base_path}/buffalo_local_layers_id/fixed_buffalo_{buffalo_id}_pres_{_layer_suffix}'

Read buffalo data

Code
# Load the step data from the CSV file
buffalo_df = pd.read_csv(csv_file_path)
print(buffalo_df.shape)

# Bearing of the previous step: shift the 'bearing' column down by one row and
# set the resulting missing values (the first row has no previous step) to 0
buffalo_df['bearing_tm1'] = buffalo_df['bearing'].shift(1).fillna(0)

# Preview the first few rows
print(buffalo_df.head())
(10103, 43)
             x_            y_                    t_    id           x1_  \
0  41969.310875 -1.435671e+06  2018-07-25T01:04:23Z  2005  41969.310875   
1  41921.521939 -1.435654e+06  2018-07-25T02:04:39Z  2005  41921.521939   
2  41779.439594 -1.435601e+06  2018-07-25T03:04:17Z  2005  41779.439594   
3  41841.203272 -1.435635e+06  2018-07-25T04:04:39Z  2005  41841.203272   
4  41655.463332 -1.435604e+06  2018-07-25T05:04:27Z  2005  41655.463332   

            y1_           x2_           y2_     x2_cent    y2_cent  ...  \
0 -1.435671e+06  41921.521939 -1.435654e+06  -47.788936  16.857110  ...   
1 -1.435654e+06  41779.439594 -1.435601e+06 -142.082345  53.568427  ...   
2 -1.435601e+06  41841.203272 -1.435635e+06   61.763677 -34.322938  ...   
3 -1.435635e+06  41655.463332 -1.435604e+06 -185.739939  31.003534  ...   
4 -1.435604e+06  41618.651923 -1.435608e+06  -36.811409  -4.438037  ...   

    bearing  bearing_sin  bearing_cos        ta    cos_ta         x_min  \
0  2.802478     0.332652    -0.943050  1.367942  0.201466  40706.810875   
1  2.781049     0.352783    -0.935705 -0.021429  0.999770  40659.021939   
2 -0.507220    -0.485749     0.874098  2.994917 -0.989262  40516.939594   
3  2.976198     0.164641    -0.986354 -2.799767 -0.942144  40578.703272   
4 -3.021610    -0.119695    -0.992811  0.285377  0.959556  40392.963332   

          x_max         y_min         y_max  bearing_tm1  
0  43231.810875 -1.436934e+06 -1.434409e+06     0.000000  
1  43184.021939 -1.436917e+06 -1.434392e+06     2.802478  
2  43041.939594 -1.436863e+06 -1.434338e+06     2.781049  
3  43103.703272 -1.436898e+06 -1.434373e+06    -0.507220  
4  42917.963332 -1.436867e+06 -1.434342e+06     2.976198  

[5 rows x 44 columns]

Importing spatial data

Slope

Code
# Open the slope raster; the dataset is closed when the with-block exits
with rasterio.open(slope_path) as slope:
    # rasterio numbers bands from 1, so request indexes 1..count inclusive
    # to load every layer into a single numpy array
    slope_stack = slope.read(list(range(1, slope.count + 1)))

print(slope_stack.shape)
(10103, 101, 101)
Code
# Wrap the numpy array as a PyTorch tensor, the format used for model training
slope_tens = torch.from_numpy(slope_stack)
print(slope_tens.shape)

# Raw summary statistics; the extremes are kept for the rescaling below
print("Mean = ", slope_tens.mean())
slope_max = slope_tens.max()
slope_min = slope_tens.min()
print("Max = ", slope_max)
print("Min = ", slope_min)

# Min-max rescale the slope values, then confirm the new summary statistics
slope_tens = (slope_tens - slope_min).div(slope_max - slope_min)
print("Mean = ", slope_tens.mean())
print("Max = ", slope_tens.max())
print("Min = ", slope_tens.min())
torch.Size([10103, 101, 101])
Mean =  tensor(0.7779)
Max =  tensor(12.2981)
Min =  tensor(0.0006)
Mean =  tensor(0.0632)
Max =  tensor(1.)
Min =  tensor(0.)
Code
# Show the first slope layer as a quick visual check
for i in range(1):
    plt.imshow(slope_tens[i])
    plt.colorbar()
    plt.show()

Sentinel-2 bands

During the data preparation (in the deepSSF_data_prep_id_S2 script) for the Sentinel-2 bands, we scaled them by 10,000, so we do not need to scale them again here (as we did for the other covariates).

Band 1

Code
# Open the band 1 raster; the dataset is closed when the with-block exits
with rasterio.open(b1_path) as b1:
    # rasterio numbers bands from 1, so request indexes 1..count inclusive
    # to load every layer into a single numpy array
    b1_stack = b1.read(list(range(1, b1.count + 1)))
Code
# Report the array dimensions before cleaning
print(b1_stack.shape)

# Missing pixels (NaN) are set to the sentinel value -1
b1_stack = np.nan_to_num(b1_stack, nan=-1.0)

# Wrap the cleaned array as a PyTorch tensor
b1_tens = torch.from_numpy(b1_stack)
print(b1_tens.shape)

# Summary statistics of the band 1 values
print(f'Min =  {b1_tens.min()}')
print(f'Mean = {b1_tens.mean()}')
print(f'Max =  {b1_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  9.999999747378752e-05
Mean = 0.04444881156086922
Max =  0.1517084836959839
Code
# Show the first band 1 layer as a quick visual check
for i in range(1):
    plt.imshow(b1_tens[i].numpy())
    plt.colorbar()
    plt.show()

Band 2

Code
# Open the band 2 raster; the dataset is closed when the with-block exits
with rasterio.open(b2_path) as b2:
    # rasterio numbers bands from 1, so request indexes 1..count inclusive
    # to load every layer into a single numpy array
    b2_stack = b2.read(list(range(1, b2.count + 1)))
Code
# Report the array dimensions before cleaning
print(b2_stack.shape)

# Missing pixels (NaN) are set to the sentinel value -1
b2_stack = np.nan_to_num(b2_stack, nan=-1.0)

# Wrap the cleaned array as a PyTorch tensor
b2_tens = torch.from_numpy(b2_stack)
print(b2_tens.shape)

# Summary statistics of the band 2 values
print(f'Min =  {b2_tens.min()}')
print(f'Mean = {b2_tens.mean()}')
print(f'Max =  {b2_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.002810720121487975
Mean = 0.05629923567175865
Max =  0.1931755244731903
Code
# Show the first band 2 layer as a quick visual check
for i in range(1):
    plt.imshow(b2_tens[i].numpy())
    plt.colorbar()
    plt.show()

Band 3

Code
# Open the band 3 raster; the dataset is closed when the with-block exits
with rasterio.open(b3_path) as b3:
    # rasterio numbers bands from 1, so request indexes 1..count inclusive
    # to load every layer into a single numpy array
    b3_stack = b3.read(list(range(1, b3.count + 1)))
Code
# Report the array dimensions before cleaning
print(b3_stack.shape)

# Missing pixels (NaN) are set to the sentinel value -1
b3_stack = np.nan_to_num(b3_stack, nan=-1.0)

# Wrap the cleaned array as a PyTorch tensor
b3_tens = torch.from_numpy(b3_stack)
print(b3_tens.shape)

# Summary statistics of the band 3 values
print(f'Min =  {b3_tens.min()}')
print(f'Mean = {b3_tens.mean()}')
print(f'Max =  {b3_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.02109863981604576
Mean = 0.08027872443199158
Max =  0.2795756757259369
Code
# Show the first band 3 layer as a quick visual check
for i in range(1):
    plt.imshow(b3_tens[i].numpy())
    plt.colorbar()
    plt.show()

Band 4

Code
# Open the band 4 raster; the dataset is closed when the with-block exits
with rasterio.open(b4_path) as b4:
    # rasterio numbers bands from 1, so request indexes 1..count inclusive
    # to load every layer into a single numpy array
    b4_stack = b4.read(list(range(1, b4.count + 1)))
Code
# Report the array dimensions before cleaning
print(b4_stack.shape)

# Missing pixels (NaN) are set to the sentinel value -1
b4_stack = np.nan_to_num(b4_stack, nan=-1.0)

# Wrap the cleaned array as a PyTorch tensor
b4_tens = torch.from_numpy(b4_stack)
print(b4_tens.shape)

# Summary statistics of the band 4 values
print(f'Min =  {b4_tens.min()}')
print(f'Mean = {b4_tens.mean()}')
print(f'Max =  {b4_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.006578320171684027
Mean = 0.09937984496355057
Max =  0.43867969512939453
Code
# Show the first band 4 layer as a quick visual check
for i in range(1):
    plt.imshow(b4_tens[i].numpy())
    plt.colorbar()
    plt.show()

Band 5

Code
# Open the band 5 raster; the dataset is closed when the with-block exits
with rasterio.open(b5_path) as b5:
    # rasterio numbers bands from 1, so request indexes 1..count inclusive
    # to load every layer into a single numpy array
    b5_stack = b5.read(list(range(1, b5.count + 1)))
Code
# Report the array dimensions before cleaning
print(b5_stack.shape)

# Missing pixels (NaN) are set to the sentinel value -1
b5_stack = np.nan_to_num(b5_stack, nan=-1.0)

# Wrap the cleaned array as a PyTorch tensor
b5_tens = torch.from_numpy(b5_stack)
print(b5_tens.shape)

# Summary statistics of the band 5 values
print(f'Min =  {b5_tens.min()}')
print(f'Mean = {b5_tens.mean()}')
print(f'Max =  {b5_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.03587600216269493
Mean = 0.1369013786315918
Max =  0.4592735767364502
Code
# Show the first band 5 layer as a quick visual check
for i in range(1):
    plt.imshow(b5_tens[i].numpy())
    plt.colorbar()
    plt.show()

Band 6

Code
# Open the band 6 raster; the dataset is closed when the with-block exits
with rasterio.open(b6_path) as b6:
    # rasterio numbers bands from 1, so request indexes 1..count inclusive
    # to load every layer into a single numpy array
    b6_stack = b6.read(list(range(1, b6.count + 1)))
Code
# Report the array dimensions before cleaning
print(b6_stack.shape)

# Missing pixels (NaN) are set to the sentinel value -1
b6_stack = np.nan_to_num(b6_stack, nan=-1.0)

# Wrap the cleaned array as a PyTorch tensor
b6_tens = torch.from_numpy(b6_stack)
print(b6_tens.shape)

# Summary statistics of the band 6 values
print(f'Min =  {b6_tens.min()}')
print(f'Mean = {b6_tens.mean()}')
print(f'Max =  {b6_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.038534000515937805
Mean = 0.17727354168891907
Max =  0.5120914578437805
Code
# Show the first band 6 layer as a quick visual check
for i in range(1):
    plt.imshow(b6_tens[i].numpy())
    plt.colorbar()
    plt.show()

Band 7

Code
# Open the band 7 raster; the dataset is closed when the with-block exits
with rasterio.open(b7_path) as b7:
    # rasterio numbers bands from 1, so request indexes 1..count inclusive
    # to load every layer into a single numpy array
    b7_stack = b7.read(list(range(1, b7.count + 1)))
Code
# Report the array dimensions before cleaning
print(b7_stack.shape)

# Missing pixels (NaN) are set to the sentinel value -1
b7_stack = np.nan_to_num(b7_stack, nan=-1.0)

# Wrap the cleaned array as a PyTorch tensor
b7_tens = torch.from_numpy(b7_stack)
print(b7_tens.shape)

# Summary statistics of the band 7 values
print(f'Min =  {b7_tens.min()}')
print(f'Mean = {b7_tens.mean()}')
print(f'Max =  {b7_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.04165744036436081
Mean = 0.19983430206775665
Max =  0.6045699119567871
Code
# Show the first band 7 layer as a quick visual check
for i in range(1):
    plt.imshow(b7_tens[i].numpy())
    plt.colorbar()
    plt.show()

Band 8

Code
# Open the band 8 raster; the dataset is closed when the with-block exits
with rasterio.open(b8_path) as b8:
    # rasterio numbers bands from 1, so request indexes 1..count inclusive
    # to load every layer into a single numpy array
    b8_stack = b8.read(list(range(1, b8.count + 1)))
Code
# Report the array dimensions before cleaning
print(b8_stack.shape)

# Missing pixels (NaN) are set to the sentinel value -1
b8_stack = np.nan_to_num(b8_stack, nan=-1.0)

# Wrap the cleaned array as a PyTorch tensor
b8_tens = torch.from_numpy(b8_stack)
print(b8_tens.shape)

# Summary statistics of the band 8 values
print(f'Min =  {b8_tens.min()}')
print(f'Mean = {b8_tens.mean()}')
print(f'Max =  {b8_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.03680320084095001
Mean = 0.2095790058374405
Max =  0.6004582643508911
Code
# Show the first band 8 layer as a quick visual check
for i in range(1):
    plt.imshow(b8_tens[i].numpy())
    plt.colorbar()
    plt.show()

Band 8a

Code
# Open the band 8a raster; the dataset is closed when the with-block exits
with rasterio.open(b8a_path) as b8a:
    # rasterio numbers bands from 1, so request indexes 1..count inclusive
    # to load every layer into a single numpy array
    b8a_stack = b8a.read(list(range(1, b8a.count + 1)))
Code
# Report the array dimensions before cleaning
print(b8a_stack.shape)

# Missing pixels (NaN) are set to the sentinel value -1
b8a_stack = np.nan_to_num(b8a_stack, nan=-1.0)

# Wrap the cleaned array as a PyTorch tensor
b8a_tens = torch.from_numpy(b8a_stack)
print(b8a_tens.shape)

# Summary statistics of the band 8a values
print(f'Min =  {b8a_tens.min()}')
print(f'Mean = {b8a_tens.mean()}')
print(f'Max =  {b8a_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.03570704162120819
Mean = 0.22782424092292786
Max =  0.6218413710594177
Code
# Show the first band 8a layer as a quick visual check
for i in range(1):
    plt.imshow(b8a_tens[i].numpy())
    plt.colorbar()
    plt.show()

Band 9

Code
# Open the band 9 raster; the dataset is closed when the with-block exits
with rasterio.open(b9_path) as b9:
    # rasterio numbers bands from 1, so request indexes 1..count inclusive
    # to load every layer into a single numpy array
    b9_stack = b9.read(list(range(1, b9.count + 1)))
Code
# Report the array dimensions before cleaning
print(b9_stack.shape)

# Missing pixels (NaN) are set to the sentinel value -1
b9_stack = np.nan_to_num(b9_stack, nan=-1.0)

# Wrap the cleaned array as a PyTorch tensor
b9_tens = torch.from_numpy(b9_stack)
print(b9_tens.shape)

# Summary statistics of the band 9 values
print(f'Min =  {b9_tens.min()}')
print(f'Mean = {b9_tens.mean()}')
print(f'Max =  {b9_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.012299999594688416
Mean = 0.22701695561408997
Max =  0.5680500268936157
Code
# Show the first band 9 layer as a quick visual check
for i in range(1):
    plt.imshow(b9_tens[i].numpy())
    plt.colorbar()
    plt.show()

Band 11

Code
# Open the band 11 raster; the dataset is closed when the with-block exits
with rasterio.open(b11_path) as b11:
    # rasterio numbers bands from 1, so request indexes 1..count inclusive
    # to load every layer into a single numpy array
    b11_stack = b11.read(list(range(1, b11.count + 1)))
Code
# Report the array dimensions before cleaning
print(b11_stack.shape)

# Missing pixels (NaN) are set to the sentinel value -1
b11_stack = np.nan_to_num(b11_stack, nan=-1.0)

# Wrap the cleaned array as a PyTorch tensor
b11_tens = torch.from_numpy(b11_stack)
print(b11_tens.shape)

# Summary statistics of the band 11 values
print(f'Min =  {b11_tens.min()}')
print(f'Mean = {b11_tens.mean()}')
print(f'Max =  {b11_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.01741199940443039
Mean = 0.27866700291633606
Max =  0.657039225101471
Code
# Show the first band 11 layer as a quick visual check
for i in range(1):
    plt.imshow(b11_tens[i].numpy())
    plt.colorbar()
    plt.show()

Band 12

Code
# Open the band 12 raster; the dataset is closed when the with-block exits
with rasterio.open(b12_path) as b12:
    # Read all layers/channels into a single numpy array.
    # rasterio indexes channels starting from 1, hence the range is 1 to b12.count + 1.
    # The layer count must come from this dataset (b12), not from the band 1
    # dataset, so that the number of layers read always matches this file.
    b12_stack = b12.read(list(range(1, b12.count + 1)))
Code
# Report the array dimensions before cleaning
print(b12_stack.shape)

# Missing pixels (NaN) are set to the sentinel value -1
b12_stack = np.nan_to_num(b12_stack, nan=-1.0)

# Wrap the cleaned array as a PyTorch tensor
b12_tens = torch.from_numpy(b12_stack)
print(b12_tens.shape)

# Summary statistics of the band 12 values
print(f'Min =  {b12_tens.min()}')
print(f'Mean = {b12_tens.mean()}')
print(f'Max =  {b12_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.012337599880993366
Mean = 0.19245103001594543
Max =  0.5119996666908264
Code
# Show the first band 12 layer as a quick visual check
for i in range(1):
    plt.imshow(b12_tens[i].numpy())
    plt.colorbar()
    plt.show()

View as RGB

Given the Red (B4), Green (B3), and Blue (B2) bands, we can create an RGB image.

Code
# Stack the red (B4), green (B3) and blue (B2) tensors along a new last axis,
# so each pixel carries its three colour values
rgb_image = torch.stack((b4_tens, b3_tens, b2_tens), dim=-1)

# Take the first image and convert it to NumPy for plotting
rgb_image_np = rgb_image[0].cpu().numpy()

# Min-max rescale the values to the range [0, 1] for display
rgb_image_np = (rgb_image_np - np.min(rgb_image_np)) / (np.max(rgb_image_np) - np.min(rgb_image_np))

# Draw the composite
plt.imshow(rgb_image_np)
plt.title('Sentinel-2 RGB Image')
plt.show()

Presence records - target of model

The target is what we are trying to predict with the deepSSF model, which is the location of the observed next step.

Code
# Open the presence (target) raster; the dataset is closed when the with-block exits
with rasterio.open(pres_path) as pres:
    # rasterio numbers bands from 1, so request indexes 1..count inclusive
    # to load every layer into a single numpy array
    pres_stack = pres.read(list(range(1, pres.count + 1)))

# Confirm the dimensions and that the result is still a numpy array
print(pres_stack.shape)
print(type(pres_stack))
(10103, 101, 101)
<class 'numpy.ndarray'>
Code
# Show the first presence layer as a quick visual check
for i in range(1):
    plt.imshow(pres_stack[i])
    plt.show()

Combine the spatial layers into channels

Code
# Combine the twelve Sentinel-2 band tensors and the slope tensor into one
# tensor, inserting a new channel axis at position 1
combined_stack = torch.stack(
    (b1_tens, b2_tens, b3_tens, b4_tens, b5_tens, b6_tens,
     b7_tens, b8_tens, b8a_tens, b9_tens, b11_tens, b12_tens,
     slope_tens),
    dim=1,
)

print(combined_stack.shape)
torch.Size([10103, 13, 101, 101])

Defining data sets and data loaders

Creating a dataset class

This custom PyTorch Dataset organizes all your input (spatial data, scalar covariates, bearing, and target) in a single object, allowing you to neatly manage how samples are accessed. The __init__ method prepares and stores all the data, __len__ returns the total number of samples, and __getitem__ retrieves a single sample by index—enabling straightforward batching and iteration when used with a DataLoader.

Code
class buffalo_data(Dataset):
    """Dataset bundling the model inputs and the target for every step.

    Indexing returns a 4-tuple ``(spatial, scalars, bearing, target)``:
    the spatial covariate layers, the temporal scalar covariates that are
    later converted to grid layers, the bearing of the previous step, and
    the target layer.

    Args:
        spatial_data: Tensor of spatial covariates with samples on the first
            axis. Defaults to the notebook-level ``combined_stack``.
        covariates_df: DataFrame holding the temporal columns and
            ``bearing_tm1``. Defaults to the notebook-level ``buffalo_df``.
        target: Array or tensor of target layers with samples on the first
            axis. Defaults to the notebook-level ``pres_stack``.

    Raises:
        ValueError: If the inputs do not all contain the same number of
            samples.
    """

    def __init__(self, spatial_data=None, covariates_df=None, target=None):
        # Fall back to the notebook-level objects so that the original
        # zero-argument call buffalo_data() keeps working unchanged
        if spatial_data is None:
            spatial_data = combined_stack
        if covariates_df is None:
            covariates_df = buffalo_df
        if target is None:
            target = pres_stack

        # data loading
        self.spatial_data_x = spatial_data

        # the scalar data that will be converted to grid data and added to the spatial covariates for CNN components
        self.scalar_to_grid_data = torch.from_numpy(covariates_df[['hour_t1_sin1',
                                                                   'hour_t1_cos1',
                                                                   'hour_t1_sin2',
                                                                   'hour_t1_cos2',
                                                                   'yday_t1_sin1',
                                                                   'yday_t1_cos1',
                                                                   'yday_t1_sin2',
                                                                   'yday_t1_cos2']].values).float()

        # the bearing data that will be added as a channel to the spatial covariates
        self.bearing_x = torch.from_numpy(covariates_df[['bearing_tm1']].values).float()

        # the target data (tensors are used as given; anything else is copied into a tensor)
        self.target = target if isinstance(target, torch.Tensor) else torch.tensor(target)

        # number of samples
        self.n_samples = self.spatial_data_x.shape[0]

        # Every input is indexed with the same sample index, so a length
        # mismatch would pair inputs with the wrong target or fail later
        sample_counts = {
            'spatial_data': self.n_samples,
            'scalar covariates': self.scalar_to_grid_data.shape[0],
            'bearing': self.bearing_x.shape[0],
            'target': self.target.shape[0],
        }
        if len(set(sample_counts.values())) != 1:
            raise ValueError(f"Inputs have different numbers of samples: {sample_counts}")

    def __len__(self):
        # allows for the use of len() function
        return self.n_samples

    def __getitem__(self, index):
        # allows for indexing of the dataset
        return self.spatial_data_x[index], self.scalar_to_grid_data[index], self.bearing_x[index], self.target[index]

Now we can create an instance of the dataset class and check that it is working as expected.

Code
# Build the dataset object from the data prepared above
dataset = buffalo_data()

# Total number of samples held by the dataset
print(len(dataset))

# Slicing with [:] goes through __getitem__ and returns, for all samples,
# the tuple (spatial data, scalar-to-grid data, bearing data, target labels)
features1, features2, features3, labels = dataset[:]

# Check the dimensions of each returned tensor:

# Spatial data
print(features1.size())

# Scalar-to-grid data
print(features2.size())

# Bearing data
print(features3.size())

# Target labels
print(labels.size())
10103
torch.Size([10103, 13, 101, 101])
torch.Size([10103, 8])
torch.Size([10103, 1])
torch.Size([10103, 101, 101])

Split into training, validation and test sets

Code
# Fraction of the data assigned to each purpose:
#   training   (80%) - fitting the model
#   validation (10%) - deciding when to stop training
#   test       (10%) - model evaluation
training_split, validation_split, test_split = 0.8, 0.1, 0.1

# To split the data randomly
# dataset_train, dataset_val, dataset_test = torch.utils.data.random_split(dataset,
#                                                                          [training_split,
#                                                                           validation_split,
#                                                                           test_split])

To split the data sequentially, we can use the Subset function from PyTorch, which allows us to select a subset of the data based on the indices we provide.

Code
# Sequential split: Subset takes explicit integer indices
n_samples = len(dataset)
n_train = int(training_split * n_samples)
n_val = int(validation_split * n_samples)
n_test = n_samples - n_train - n_val  # remainder, so the three counts sum to n_samples

# Index positions at which the training and validation portions end
train_end = int(training_split * n_samples)
val_end = int((training_split + validation_split) * n_samples)

# Consecutive, non-overlapping index ranges: training first, then validation, then test
train_indices = list(range(train_end))
val_indices = list(range(train_end, val_end))
test_indices = list(range(val_end, n_samples))

# Wrap each index range as a view onto the full dataset
dataset_train = torch.utils.data.Subset(dataset, train_indices)
dataset_val = torch.utils.data.Subset(dataset, val_indices)
dataset_test = torch.utils.data.Subset(dataset, test_indices)

print("Number of training samples: ",   len(dataset_train))
print("Number of validation samples: ", len(dataset_val))
print("Number of testing samples: ",    len(dataset_test))
Number of training samples:  8082
Number of validation samples:  1010
Number of testing samples:  1011

Create dataloaders

The DataLoader in PyTorch wraps an iterable around the Dataset to enable easy access to the samples.

Code
# Number of samples processed together in each step
batch_size = 32

# Training loader: shuffled, so the model does not see the samples in the
# same order every epoch
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

# Validation loader: also shuffled here, although shuffling validation data
# is not mandatory
dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=True)

# Test loader: NOT shuffled, because we want to index the testing data for plotting
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

Check that the data loader is working as expected.

Code
# Pull one batch from the training loader and report the shape of each part
# (spatial data, scalar-to-grid data, bearing data, target labels)
features1, features2, features3, labels = next(iter(dataloader_train))
print(f"Feature 1 batch shape: {features1.shape}")
print(f"Feature 2 batch shape: {features2.shape}")
print(f"Feature 3 batch shape: {features3.shape}")
print(f"Labels batch shape: {labels.shape}")
Feature 1 batch shape: torch.Size([32, 13, 101, 101])
Feature 2 batch shape: torch.Size([32, 8])
Feature 3 batch shape: torch.Size([32, 1])
Labels batch shape: torch.Size([32, 101, 101])

Define the model

Deep learning can be considered as a sequence of blocks, each of which performs some (typically nonlinear) transformation on input data to produce some output. Provided each block has the appropriate inputs, they can be combined to build a larger network that is capable of achieving complex and abstract transformations and can be used to represent complex processes.

A block is a modular component of a neural network, in our case defined as a Python class (a type of object with certain functionality described by its definition) inheriting from torch.nn.Module in PyTorch. A block encapsulates a sequence of operations, including layers (such as fully connected layers or convolutional layers) and activation functions, to process input data. Each block has a forward method (i.e. instructions) that defines the data flow through the network during inference or training.

Convolutional block for the habitat selection subnetwork

This block is a convolutional layer that takes in the spatial covariates (including the layers created from the scalar values such as time), goes through a series of convolution operations and ReLU activation functions and outputs a feature map, which is the habitat selection probability surface.

Code
class Conv2d_block_spatial(nn.Module):
    """Convolutional block of the habitat selection subnetwork.

    Four convolutions (the first three followed by ReLU) reduce the stacked
    input layers to a single channel, which is then normalised on the log-scale
    so that its exponential sums to 1 over the two spatial dimensions.

    `params` must provide: batch_size, input_channels, output_channels,
    kernel_size, stride, padding, image_dim and device.
    """

    def __init__(self, params):
        super().__init__()

        # keep a copy of the hyperparameters on the module
        self.batch_size = params.batch_size
        self.input_channels = params.input_channels
        self.output_channels = params.output_channels
        self.kernel_size = params.kernel_size
        self.stride = params.stride
        self.padding = params.padding
        self.image_dim = params.image_dim
        self.device = params.device

        # every convolution in this block shares the same kernel geometry
        conv_geometry = dict(kernel_size=self.kernel_size,
                             stride=self.stride,
                             padding=self.padding)

        # nn.Sequential applies the layers in the order they are listed
        self.conv2d = nn.Sequential(
            # convolutional layer 1: input layers -> feature maps
            nn.Conv2d(self.input_channels, self.output_channels, **conv_geometry),
            nn.ReLU(),

            # convolutional layer 2
            nn.Conv2d(self.output_channels, self.output_channels, **conv_geometry),
            nn.ReLU(),

            # convolutional layer 3
            nn.Conv2d(self.output_channels, self.output_channels, **conv_geometry),
            nn.ReLU(),

            # convolutional layer 4: collapses to one channel, the habitat selection map
            nn.Conv2d(self.output_channels, 1, **conv_geometry),
        )

    def forward(self, x):
        """Return the log-scale habitat selection map for a batch of inputs.

        x is passed straight to the convolutions; the single output channel is
        squeezed away, leaving one 2D map per sample.
        """
        # drop the (size 1) channel dimension left by the last convolution
        habitat_map = self.conv2d(x).squeeze(dim=1)

        # log of the normalising constant of each map (log-sum-exp over both spatial axes)
        log_total = torch.logsumexp(habitat_map, dim=(1, 2), keepdim=True)

        # subtracting on the log-scale makes exp(map) sum to 1 for every sample
        return habitat_map - log_total

Convolutional block for the movement subnetwork

This block is also a convolutional block, with the same inputs, but it also has max pooling layers to reduce the spatial resolution of the feature maps whilst preserving the most prominent features in the feature maps, and outputs a ‘flattened’ feature map. A flattened feature map is a 1D tensor (a vector) that can be used as input to a fully connected layer.

Code
class Conv2d_block_toFC(nn.Module):
    """Convolutional block of the movement subnetwork.

    Two (convolution -> ReLU -> max pooling) stages followed by a flatten, so
    that the result can be fed to a fully connected block.

    `params` must provide: batch_size, input_channels, output_channels_movement,
    kernel_size, stride, kernel_size_mp, stride_mp, padding, image_dim and device.
    """

    def __init__(self, params):
        super().__init__()

        # keep a copy of the hyperparameters on the module
        self.batch_size = params.batch_size
        self.input_channels = params.input_channels
        self.output_channels_movement = params.output_channels_movement
        self.kernel_size = params.kernel_size
        self.stride = params.stride
        self.kernel_size_mp = params.kernel_size_mp
        self.stride_mp = params.stride_mp
        self.padding = params.padding
        self.image_dim = params.image_dim
        self.device = params.device

        def conv_relu_pool(in_channels):
            # one stage: convolution, ReLU activation, then max pooling, which shrinks
            # the spatial dimensions whilst retaining the most prominent features
            return [
                nn.Conv2d(in_channels=in_channels,
                          out_channels=self.output_channels_movement,
                          kernel_size=self.kernel_size,
                          stride=self.stride,
                          padding=self.padding),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=self.kernel_size_mp,
                             stride=self.stride_mp),
            ]

        # nn.Sequential applies the layers in the order they are listed
        # (a third conv/ReLU/pool stage is not used in this version of the model)
        self.conv2d = nn.Sequential(
            # stage 1: input layers -> movement feature maps
            *conv_relu_pool(self.input_channels),
            # stage 2: movement feature maps -> movement feature maps
            *conv_relu_pool(self.output_channels_movement),
            # flatten each sample so it can be passed to the fully connected layers
            nn.Flatten(),
        )

    def forward(self, x):
        """Run x through the conv/pool stages and return one flat vector per sample."""
        flattened = self.conv2d(x)
        return flattened

Fully connected block for the movement subnetwork

This block takes in the flattened feature map from the previous block, passes through several fully connected layers, which extracts information from the spatial covariates that is relevant for movement, and outputs the parameters that define the movement kernel.

Code
class FCN_block_all_movement(nn.Module):
    """Fully connected block of the movement subnetwork.

    Maps the flattened feature vector to the vector of movement parameters via
    three (Linear -> Dropout -> ReLU) stages and a final Linear layer.

    `params` must provide: batch_size, dense_dim_in_all, dense_dim_hidden,
    image_dim, device, num_movement_params and dropout.
    """

    def __init__(self, params):
        super().__init__()

        # keep a copy of the hyperparameters on the module
        self.batch_size = params.batch_size
        self.dense_dim_in_all = params.dense_dim_in_all
        self.dense_dim_hidden = params.dense_dim_hidden
        self.image_dim = params.image_dim
        self.device = params.device
        self.num_movement_params = params.num_movement_params
        self.dropout = params.dropout

        # dense_dim_in_all is the number of input features and should match the
        # length of the flattened output of the Conv2d_block_toFC block;
        # dense_dim_hidden is the width of the hidden layers and is free to differ
        stages = []
        width_in = self.dense_dim_in_all
        for _ in range(3):
            stages.append(nn.Linear(width_in, self.dense_dim_hidden))
            # dropout helps to reduce overfitting
            stages.append(nn.Dropout(self.dropout))
            stages.append(nn.ReLU())
            # every later layer takes the hidden width as its input
            width_in = self.dense_dim_hidden

        # final layer: one output neuron per movement parameter
        stages.append(nn.Linear(self.dense_dim_hidden, self.num_movement_params))

        # nn.Sequential applies the layers in the order they are listed
        self.ffn = nn.Sequential(*stages)

    def forward(self, x):
        """Return the vector of movement parameters predicted for each sample in x."""
        movement_params = self.ffn(x)
        return movement_params

Block to convert the movement parameters to a probability distribution

What the block does

This block is a bit longer and more involved, but there are no parameters in here that need to be learned (estimated). It is just a series of operations that are applied to the movement parameters to convert them to a probability distribution.

This block takes in the movement parameters and converts them to a probability distribution. This essentially just applies the appropriate density functions using the parameter values predicted by the movement blocks, which in our case is a finite mixture of Gamma distributions and a finite mixture of von Mises distributions.

The formulation of predicting parameters and converting them to a movement kernel ensures that the movement kernel is very flexible, and can be any combination of distributions, which need not all be the same (e.g., a step length distribution may be a combination of a Gamma and a log-normal distribution).

Constraints

One constraint to ensure that we can perform backpropagation is that the entire forward pass, including the block below that produces the density functions, must be differentiable with respect to the parameters of the model. PyTorch’s torch.distributions module and its special functions (e.g., torch.special) provide differentiable implementations for many common distributions. Examples are the

  • The log-Gamma function (the logarithm of the Gamma function), which is needed for the log-density of the Gamma distribution, torch.lgamma()
  • The modified Bessel function of the first kind of order 0 for the von Mises distribution, torch.special.i0()

Some of the movement parameters, such as the shape and scale of the Gamma distribution, must be positive. We therefore exponentiate them in this block to ensure that they are positive. This means that the model is actually learning the log of the shape and scale parameters. For the von Mises mu parameters however, they can be any value, so we do not need to exponentiate them. We could constrain them to be between -pi and pi, but this is not necessary as the von Mises distribution is periodic, so any value will be equivalent to another value that is within the range -pi to pi.

Notes

To help with identifiability, it is possible to fix certain parameter values, such as the mu parameters in the mixture of von Mises distributions to pi and -pi for instance (one would then reduce the number of predicted parameters by the previous block, as these no longer need to be predicted).

We can also transform certain parameters such that they are being estimated in a similar range (analogous to standardising variables in linear regression). In our case we know that the scale parameter of one of the Gamma distributions is around 500. What we can then do after exponentiating is multiply the scale parameter by 500, so the model is learning the log of the scale parameter divided by 500. This will ensure that this parameter is in a similar range to the other parameters, and can help with convergence. To do this we:

Pull out the relevant parameters from the input tensor (output of previous block) - gamma_scale2 = torch.exp(x[:, 4]).unsqueeze(0).unsqueeze(0)

Multiply the scale parameter by 500, so the model is learning the log of the scale parameter divided by 500 - gamma_scale2 = gamma_scale2 * 500

Consideration for the centre cell

As the centre cell has a distance of exactly 0 (using the distance layer we created), this can cause numerical issues for the Gamma distribution: when the shape parameter is less than one, the density approaches infinity as the distance approaches 0. To avoid this, we replace the value of the central cell of the distance layer with a small positive value, so that the distance is never exactly 0.

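To get the value to use in the central cell, we will calculate the average distance from the very centre to any point in the cell (assuming that the distance within the cell is continuous). Working in units of the cell width, the cell has an area of 1, so the average distance is simply the integral of the distance over the cell (the result is multiplied by the pixel size when it is used in the model). This comes out to be: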

\(\int_{-0.5}^{0.5} \int_{-0.5}^{0.5} \sqrt{x^2 + y^2} \, dx \, dy\)

We calculate a constant numerically below:

Code
def integrand(x, y):
    """Distance of the point (x, y) from the origin."""
    return mp.sqrt(x**2 + y**2)


def _integrate_over_x(y_fixed):
    """Inner integral: integrate the distance over x in [-0.5, 0.5] for a fixed y."""
    return mp.quad(lambda x_var: integrand(x_var, y_fixed), [-0.5, 0.5])


# Outer integral over y in [-0.5, 0.5]. Together with the inner integral this covers a
# cell of width 1 centred on the origin.
val = mp.quad(_integrate_over_x, [-0.5, 0.5])

print(val)
0.382597668656132
Code
class Params_to_Grid_Block(nn.Module):
    """Convert predicted movement parameters into a log-scale movement grid.

    The block has no learnable parameters. For every sample it evaluates a
    two-component Gamma mixture on the distance layer and a two-component
    von Mises mixture on the bearing layer (both on the log-scale) and returns
    their sum.

    Columns of the parameter tensor read by `forward`:
        0, 1, 2   : log shape, log scale and mixture logit of Gamma 1
        3, 4, 5   : log shape, log(scale / 500) and mixture logit of Gamma 2
        6, 7, 8   : turning-angle mean, log kappa and mixture logit of von Mises 1
        9, 10, 11 : turning-angle mean, log kappa and mixture logit of von Mises 2

    `params` must provide: batch_size, image_dim, pixel_size and device.
    """

    def __init__(self, params):
        super(Params_to_Grid_Block, self).__init__()

        # define the parameters
        self.batch_size = params.batch_size
        self.image_dim = params.image_dim
        self.pixel_size = params.pixel_size

        # create distance and bearing layers
        # determine the distance of each pixel from the centre of the image
        self.center = self.image_dim // 2
        y, x = np.indices((self.image_dim, self.image_dim))
        self.distance_layer = torch.from_numpy(np.sqrt((self.pixel_size*(x - self.center))**2 +
                                                       (self.pixel_size*(y - self.center))**2)).float()

        # The centre cell would otherwise have a distance of exactly 0, so it is replaced by
        # the average distance from the centre to any point within the pixel, calculated as a
        # double integral of sqrt(x^2 + y^2) dx dy over the area of the pixel.
        # (An alternative is the average distance from the centre to the perimeter of the
        # pixel, 0.56*self.pixel_size.)
        self.distance_layer[self.center, self.center] = 0.3826*self.pixel_size

        # determine the bearing of each pixel from the centre of the image
        self.bearing_layer = torch.from_numpy(np.arctan2(self.center - y,
                                                         x - self.center)).float()
        self.device = params.device

    @staticmethod
    def _per_sample(values):
        """Reshape a (batch,) tensor to (batch, 1, 1) so it broadcasts over the grid.

        Broadcasting gives the same result as repeating the value across every
        cell, without materialising a full-size copy for each parameter.
        """
        return values[:, None, None]

    @staticmethod
    def _log_mixture(log_density1, log_density2, weight1, weight2):
        """Return log(weight1*exp(log_density1) + weight2*exp(log_density2)).

        The element-wise maximum is subtracted before exponentiating and added
        back afterwards (the log-sum-exp trick) to avoid underflow.
        """
        shift = torch.max(log_density1, log_density2)
        return shift + torch.log(weight1 * torch.exp(log_density1 - shift) +
                                 weight2 * torch.exp(log_density2 - shift))

    # Gamma densities (on the log-scale) for the mixture distribution
    def gamma_density(self, x, shape, scale):
        """Log-density of a Gamma(shape, scale) distribution evaluated at x."""
        # Ensure all tensors are on the same device as x
        shape = shape.to(x.device)
        scale = scale.to(x.device)
        return -1*torch.lgamma(shape) - shape*torch.log(scale) + (shape - 1)*torch.log(x) - x/scale

    # von Mises densities (on the log-scale) for the mixture distribution
    def vonmises_density(self, x, kappa, vm_mu):
        """Log-density of a von Mises(vm_mu, kappa) distribution evaluated at x."""
        # Ensure all tensors are on the same device as x
        kappa = kappa.to(x.device)
        vm_mu = vm_mu.to(x.device)
        # log(I0(kappa)) is computed as log(i0e(kappa)) + kappa, where i0e is the
        # exponentially scaled Bessel function i0e(k) = exp(-|k|) * I0(k). kappa is
        # positive here, and this form does not overflow for large kappa.
        log_i0 = torch.log(torch.special.i0e(kappa)) + kappa
        return kappa*torch.cos(x - vm_mu) - (log_i0 + np.log(2*torch.pi))

    def forward(self, x, bearing):
        """Build the log-scale movement grid for each sample.

        x are the outputs from the fully connected layers (one row of movement
        parameters per sample) and bearing holds the bearing of the previous
        step in its first column. Returns one image_dim x image_dim grid per
        sample, placed on self.device.
        """

        ## Gamma distributions

        # shape and scale are exponentiated to ensure they are positive
        gamma_shape1 = self._per_sample(torch.exp(x[:, 0]))
        gamma_scale1 = self._per_sample(torch.exp(x[:, 1]))
        gamma_shape2 = self._per_sample(torch.exp(x[:, 3]))
        # the second scale is multiplied by 500 so that it can be estimated
        # near the same range as the other parameters
        gamma_scale2 = self._per_sample(torch.exp(x[:, 4]) * 500)

        # Apply softmax to the mixture weights to ensure they sum to 1
        gamma_weights = torch.nn.functional.softmax(torch.stack([x[:, 2], x[:, 5]], dim=0), dim=0)
        gamma_weight1 = self._per_sample(gamma_weights[0])
        gamma_weight2 = self._per_sample(gamma_weights[1])

        # calculation of Gamma densities at every cell of the distance layer
        gamma_density_layer1 = self.gamma_density(self.distance_layer,
                                                  gamma_shape1,
                                                  gamma_scale1).to(self.device)
        gamma_density_layer2 = self.gamma_density(self.distance_layer,
                                                  gamma_shape2,
                                                  gamma_scale2).to(self.device)

        # combining both densities to create a mixture distribution
        # (the Gamma layer is not normalised in this block)
        gamma_density_layer = self._log_mixture(gamma_density_layer1, gamma_density_layer2,
                                                gamma_weight1, gamma_weight2)

        ## Von Mises distributions

        # the new bearing is the bearing from the previous step plus the turning angle
        # estimated by the model, and becomes the mean of the von Mises distribution;
        # the mu parameters are not exponentiated as they are allowed to be negative
        previous_bearing = bearing[:, 0]
        vonmises_mu1 = self._per_sample(x[:, 6] + previous_bearing)
        vonmises_kappa1 = self._per_sample(torch.exp(x[:, 7]))
        vonmises_mu2 = self._per_sample(x[:, 9] + previous_bearing)
        vonmises_kappa2 = self._per_sample(torch.exp(x[:, 10]))

        # Apply softmax to the weights
        vonmises_weights = torch.nn.functional.softmax(torch.stack([x[:, 8], x[:, 11]], dim=0), dim=0)
        vonmises_weight1 = self._per_sample(vonmises_weights[0])
        vonmises_weight2 = self._per_sample(vonmises_weights[1])

        # calculation of von Mises densities at every cell of the bearing layer
        vonmises_density_layer1 = self.vonmises_density(self.bearing_layer,
                                                        vonmises_kappa1,
                                                        vonmises_mu1).to(self.device)
        vonmises_density_layer2 = self.vonmises_density(self.bearing_layer,
                                                        vonmises_kappa2,
                                                        vonmises_mu2).to(self.device)

        # combining both densities to create a mixture distribution
        vonmises_density_layer = self._log_mixture(vonmises_density_layer1, vonmises_density_layer2,
                                                   vonmises_weight1, vonmises_weight2)

        # Normalise the von Mises density layer to sum to 1
        vonmises_density_layer = vonmises_density_layer - torch.logsumexp(vonmises_density_layer,
                                                                          dim=(1, 2), keepdim=True)

        # combining the two distributions (both are on the log-scale, so they are added);
        # the combined grid is not normalised in this block
        movement_grid = gamma_density_layer + vonmises_density_layer

        return movement_grid

Duplicate the block but account for the change in movement variables

Code
class Params_to_Grid_Block_ChV(nn.Module):
    """Convert predicted movement parameters into a log-scale movement grid,
    with a change-of-variables term in the Gamma density.

    The block has no learnable parameters. For every sample it evaluates a
    two-component Gamma mixture on the distance layer (each component has
    log(distance) subtracted to account for the change of variables) and a
    two-component von Mises mixture on the bearing layer, and returns their
    sum. Neither mixture is normalised over the grid in this block.

    Columns of the parameter tensor read by `forward`:
        0, 1, 2   : log shape, log scale and mixture logit of Gamma 1
        3, 4, 5   : log shape, log(scale / 500) and mixture logit of Gamma 2
        6, 7, 8   : turning-angle mean, log kappa and mixture logit of von Mises 1
        9, 10, 11 : turning-angle mean, log kappa and mixture logit of von Mises 2

    `params` must provide: batch_size, image_dim, pixel_size and device.
    """

    def __init__(self, params):
        super(Params_to_Grid_Block_ChV, self).__init__()

        # define the parameters
        self.batch_size = params.batch_size
        self.image_dim = params.image_dim
        self.pixel_size = params.pixel_size

        # create distance and bearing layers
        # determine the distance of each pixel from the centre of the image
        self.center = self.image_dim // 2
        y, x = np.indices((self.image_dim, self.image_dim))
        self.distance_layer = torch.from_numpy(np.sqrt((self.pixel_size*(x - self.center))**2 +
                                                       (self.pixel_size*(y - self.center))**2)).float()

        # The centre cell would otherwise have a distance of exactly 0, so it is replaced by
        # the average distance from the centre to any point within the pixel, calculated as a
        # double integral of sqrt(x^2 + y^2) dx dy over the area of the pixel.
        # (An alternative is the average distance from the centre to the perimeter of the
        # pixel, 0.56*self.pixel_size.)
        self.distance_layer[self.center, self.center] = 0.3826*self.pixel_size

        # determine the bearing of each pixel from the centre of the image
        self.bearing_layer = torch.from_numpy(np.arctan2(self.center - y,
                                                         x - self.center)).float()
        self.device = params.device

    @staticmethod
    def _per_sample(values):
        """Reshape a (batch,) tensor to (batch, 1, 1) so it broadcasts over the grid.

        Broadcasting gives the same result as repeating the value across every
        cell, without materialising a full-size copy for each parameter.
        """
        return values[:, None, None]

    @staticmethod
    def _log_mixture(log_density1, log_density2, weight1, weight2):
        """Return log(weight1*exp(log_density1) + weight2*exp(log_density2)).

        The element-wise maximum is subtracted before exponentiating and added
        back afterwards (the log-sum-exp trick) to avoid underflow.
        """
        shift = torch.max(log_density1, log_density2)
        return shift + torch.log(weight1 * torch.exp(log_density1 - shift) +
                                 weight2 * torch.exp(log_density2 - shift))

    # Gamma densities (on the log-scale) for the mixture distribution
    def gamma_density(self, x, shape, scale):
        """Log-density of a Gamma(shape, scale) distribution at x, minus log(x)."""
        # Ensure all tensors are on the same device as x
        shape = shape.to(x.device)
        scale = scale.to(x.device)
        log_gamma = -1*torch.lgamma(shape) - shape*torch.log(scale) + (shape - 1)*torch.log(x) - x/scale
        # the extra -log(x) accounts for the change of variables
        return log_gamma - torch.log(x)

    # von Mises densities (on the log-scale) for the mixture distribution
    def vonmises_density(self, x, kappa, vm_mu):
        """Log-density of a von Mises(vm_mu, kappa) distribution evaluated at x."""
        # Ensure all tensors are on the same device as x
        kappa = kappa.to(x.device)
        vm_mu = vm_mu.to(x.device)
        # log(I0(kappa)) is computed as log(i0e(kappa)) + kappa, where i0e is the
        # exponentially scaled Bessel function i0e(k) = exp(-|k|) * I0(k). kappa is
        # positive here, and this form does not overflow for large kappa.
        log_i0 = torch.log(torch.special.i0e(kappa)) + kappa
        return kappa*torch.cos(x - vm_mu) - (log_i0 + np.log(2*torch.pi))

    def forward(self, x, bearing):
        """Build the log-scale movement grid for each sample.

        x are the outputs from the fully connected layers (one row of movement
        parameters per sample) and bearing holds the bearing of the previous
        step in its first column. Returns one image_dim x image_dim grid per
        sample, placed on self.device.
        """

        ## Gamma distributions

        # shape and scale are exponentiated to ensure they are positive
        gamma_shape1 = self._per_sample(torch.exp(x[:, 0]))
        gamma_scale1 = self._per_sample(torch.exp(x[:, 1]))
        gamma_shape2 = self._per_sample(torch.exp(x[:, 3]))
        # the second scale is multiplied by 500 so that it can be estimated
        # near the same range as the other parameters
        gamma_scale2 = self._per_sample(torch.exp(x[:, 4]) * 500)

        # Apply softmax to the mixture weights to ensure they sum to 1
        gamma_weights = torch.nn.functional.softmax(torch.stack([x[:, 2], x[:, 5]], dim=0), dim=0)
        gamma_weight1 = self._per_sample(gamma_weights[0])
        gamma_weight2 = self._per_sample(gamma_weights[1])

        # calculation of Gamma densities at every cell of the distance layer
        gamma_density_layer1 = self.gamma_density(self.distance_layer,
                                                  gamma_shape1,
                                                  gamma_scale1).to(self.device)
        gamma_density_layer2 = self.gamma_density(self.distance_layer,
                                                  gamma_shape2,
                                                  gamma_scale2).to(self.device)

        # combining both densities to create a mixture distribution
        gamma_density_layer = self._log_mixture(gamma_density_layer1, gamma_density_layer2,
                                                gamma_weight1, gamma_weight2)

        ## Von Mises distributions

        # the new bearing is the bearing from the previous step plus the turning angle
        # estimated by the model, and becomes the mean of the von Mises distribution;
        # the mu parameters are not exponentiated as they are allowed to be negative
        previous_bearing = bearing[:, 0]
        vonmises_mu1 = self._per_sample(x[:, 6] + previous_bearing)
        vonmises_kappa1 = self._per_sample(torch.exp(x[:, 7]))
        vonmises_mu2 = self._per_sample(x[:, 9] + previous_bearing)
        vonmises_kappa2 = self._per_sample(torch.exp(x[:, 10]))

        # Apply softmax to the weights
        vonmises_weights = torch.nn.functional.softmax(torch.stack([x[:, 8], x[:, 11]], dim=0), dim=0)
        vonmises_weight1 = self._per_sample(vonmises_weights[0])
        vonmises_weight2 = self._per_sample(vonmises_weights[1])

        # calculation of von Mises densities at every cell of the bearing layer
        vonmises_density_layer1 = self.vonmises_density(self.bearing_layer,
                                                        vonmises_kappa1,
                                                        vonmises_mu1).to(self.device)
        vonmises_density_layer2 = self.vonmises_density(self.bearing_layer,
                                                        vonmises_kappa2,
                                                        vonmises_mu2).to(self.device)

        # combining both densities to create a mixture distribution
        vonmises_density_layer = self._log_mixture(vonmises_density_layer1, vonmises_density_layer2,
                                                   vonmises_weight1, vonmises_weight2)

        # combining the two distributions (both are on the log-scale, so they are added);
        # the combined grid is not normalised in this block
        movement_grid = gamma_density_layer + vonmises_density_layer

        return movement_grid

Scalar to grid block

This block takes any scalar value (e.g., time of day, day of year) and converts it to a 2D image, with the same values for all pixels.

This is so that the scalar values can be used as input to the convolutional layers.

Code
class Scalar_to_Grid_Block(nn.Module):
    """Broadcast per-sample scalar covariates into constant 2D layers.

    Every scalar in a ``(batch, num_scalars)`` input becomes one
    ``image_dim x image_dim`` layer filled with that value, giving a
    ``(batch, num_scalars, image_dim, image_dim)`` output that can be
    concatenated with the spatial covariates ahead of the convolutions.
    """

    def __init__(self, params):
        super().__init__()

        # Settings copied from the parameter container
        self.batch_size = params.batch_size
        self.image_dim = params.image_dim
        self.device = params.device

    def forward(self, x):
        # Leading dimensions: samples in the batch, scalars per sample
        n_samples, n_scalars = x.shape[0], x.shape[1]

        # Give every scalar a 1 x 1 "image", then stretch it to the full
        # patch size (expand creates a broadcast view rather than copying)
        as_pixels = x.view(n_samples, n_scalars, 1, 1)
        side = self.image_dim
        return as_pixels.expand(n_samples, n_scalars, side, side)

Combine the blocks into the deepSSF model

Here is where we combine the blocks into a model. Similarly to the previous blocks, the model is a Python class that inherits from torch.nn.Module, which combines other torch.nn.Module modules.

For example, we can instantiate the habitat selection convolution block using self.conv_habitat = Conv2d_block_spatial(params) in the __init__ method (the ‘constructor’ for a class). We can now access that block using self.conv_habitat in the forward method.

In the forward method, we pass the input data through the habitat selection convolution block using output_habitat = self.conv_habitat(all_spatial), where all_spatial is the input data, which is a combination of the spatial covariates and the scalar values converted to 2D images.

First we instantiate the blocks, and then define the forward method, which defines the data flow through the network during inference or training.

Code
class ConvJointModel(nn.Module):
    """
    Joint habitat-selection and movement model.

    The model is assembled from four blocks, all configured from the same
    params object:
    - scalar_grid_output: turns the scalar inputs into constant 2D layers
    - conv_habitat: convolutional block producing the habitat selection surface
    - fcn_movement_all: fully connected block predicting the movement parameters
    - movement_grid_output: turns the movement parameters into a 2D movement surface

    The convolutional movement block (Conv2d_block_toFC) is not used in this
    version of the model; the movement parameters are predicted from the
    scalar inputs only.
    """

    def __init__(self, params):
        super().__init__()

        # Scalars -> grids, so they can be stacked with the spatial layers
        self.scalar_grid_output = Scalar_to_Grid_Block(params)

        # Habitat selection subnetwork (convolutional)
        self.conv_habitat = Conv2d_block_spatial(params)

        # Movement subnetwork (fully connected)
        self.fcn_movement_all = FCN_block_all_movement(params)

        # Movement parameters -> 2D movement surface
        self.movement_grid_output = Params_to_Grid_Block_ChV(params)

        # Device the model is intended to run on (e.g., CPU or GPU)
        self.device = params.device

    def forward(self, x):
        """
        Forward pass.

        x is an indexable collection holding, in order:
        - x[0]: the spatial (image-like) layers
        - x[1]: the scalar features
        - x[2]: the bearing of the previous step (turning angles are
                expressed relative to this)

        Returns the habitat and movement surfaces stacked along a new last
        dimension (habitat at index 0, movement at index 1); they are kept
        separate rather than merged.
        """
        spatial_data_x, scalar_inputs, bearing_x = x[0], x[1], x[2]

        # HABITAT SUBNETWORK
        # The scalar grids are appended to the spatial layers as extra channels
        habitat_input = torch.cat(
            [spatial_data_x, self.scalar_grid_output(scalar_inputs)],
            dim=1,
        )
        habitat_surface = self.conv_habitat(habitat_input)

        # MOVEMENT SUBNETWORK
        # Movement parameters come from the scalar inputs alone, and are then
        # laid out on the grid using the previous bearing
        movement_params = self.fcn_movement_all(scalar_inputs)
        movement_surface = self.movement_grid_output(movement_params, bearing_x)

        return torch.stack((habitat_surface, movement_surface), dim=-1)

Set the parameters for the model which will be specified in a dictionary

This Python class serves as a simple parameter container for a model that involves both spatial (e.g., convolutional layers) and non-spatial inputs. It captures all relevant hyperparameters and settings—such as image dimensions, kernel sizes, and fully connected layer dimensions—along with the target device (CPU or GPU). This structure allows easy configuration of the model without scattering parameters throughout the code.

Code
class ModelParams():
    """Attribute-style container for the model's parameters.

    Each key listed in ``_FIELDS`` is read from the dictionary passed to the
    constructor and stored as an attribute of the same name. A missing key
    raises ``KeyError``.
    """

    # Keys copied from the dictionary, in the order they are read
    _FIELDS = (
        "batch_size",
        "image_dim",
        "pixel_size",
        "dim_in_nonspatial_to_grid",
        "dense_dim_in_nonspatial",
        "dense_dim_hidden",
        "dense_dim_in_all",
        "input_channels",
        "output_channels",
        "output_channels_movement",
        "kernel_size",
        "stride",
        "kernel_size_mp",
        "stride_mp",
        "padding",
        "num_movement_params",
        "dropout",
        "device",
    )

    def __init__(self, dict_params):
        for field_name in self._FIELDS:
            setattr(self, field_name, dict_params[field_name])

Define the parameters for the model

Here we enter the specific parameter values and hyperparameters for the model. These are the values that will be used to instantiate the model.

Code
n_max_pool_layers = 2 # number of max pooling layers - only used by the commented-out dense_dim_in_all formula at the end of this cell, and needs to be manually changed if the number of max pooling layers is changed
n_spatial_inputs = features1.shape[1] # number of spatial input layers
n_scalar_inputs = features2.shape[1] # number of scalar inputs

params_dict = {"batch_size": batch_size, #number of samples in each batch
               "image_dim": 101, #number of pixels along the edge of each local patch/image
               "pixel_size": 25, #number of metres along the edge of a pixel
               "input_channels": n_spatial_inputs + n_scalar_inputs, #number of spatial layers in each image + number of scalar layers that are converted to a grid
               "dim_in_nonspatial_to_grid": n_scalar_inputs, #the number of scalar predictors that are converted to a grid and appended to the spatial features
               "dense_dim_in_nonspatial": n_scalar_inputs, #change this to however many other scalar predictors you have (bearing, velocity etc)
               "kernel_size": 3, #the size of the 2D moving windows / kernels that are being learned
               "stride": 1, #the stride used when applying the kernel.  This reduces the dimension of the output if set to greater than 1
               "kernel_size_mp": 2, #the size of the kernel that is used in max pooling operations
               "stride_mp": 2, #the stride that is used in max pooling operations
               "padding": 1, #the amount of padding to apply to images prior to applying the 2D convolution
               "num_movement_params": 12, #number of parameters used to parameterise the movement kernel
               "dropout": 0.1, #the proportion of nodes that are dropped out in the dropout layers

               # hyperparameters that change the model architecture
               "output_channels": 4, #number of convolution filters to learn
               "output_channels_movement": 4, #number of convolution filters to learn for the movement kernel
               "dense_dim_hidden": 256, #number of nodes in the hidden layers

               # set directly here: in this model the fully connected movement block receives only the scalar inputs
               "dense_dim_in_all": n_scalar_inputs, #number of inputs entering the fully connected block
               "device": device
               }

# Disabled: formula for dense_dim_in_all when the output of a convolutional movement block
# (with max pooling) is fed into the fully connected block. It is not needed while the
# movement block only receives the scalar inputs.
# params_dict["dense_dim_in_all"] = int(((params_dict["image_dim"] - (params_dict["image_dim"] % 2))**2) * (params_dict["output_channels_movement"] / (4**n_max_pool_layers)))

Note about the model

In future scripts (such as when simulating from the deepSSF model or training the model on Sentinel-2 data) we want to load the same model, so to prevent copying and pasting, we will save the model definition to a Python file and import it into future scripts.

We do this by copying the model definition above to a Python file named deepSSF_model.py, which can be imported into future scripts using import deepSSF_model.

Ideally we would just use that file to define the model in this script, but we include it here as it’s helpful to test components of the model and see how it all works.

Just remember that if you make changes to the model in this script, you will have to copy them across to the deepSSF_model.py file, or just change the model definition in the deepSSF_model.py file and call that directly from here to train it.

To call it you would uncomment the lines in the next cell.

Code
# # Import the functions in the deepSSF_model.py file
# import deepSSF_model

# # Create an instance of the ModelParams class using the params_dict
# params = deepSSF_model.ModelParams(deepSSF_model.params_dict)

# # Create an instance of the ConvJointModel class using the params
# model = deepSSF_model.ConvJointModel(params).to(device)

# # Print the model architecture to check that it worked
# print(model)

Instantiate the model

Here we instantiate the model using the parameters defined above.

Code
# Wrap the 'params_dict' dictionary in a ModelParams object so that the model blocks
# can read each setting as an attribute (e.g. params.image_dim)
params = ModelParams(params_dict)

# Create an instance of the ConvJointModel using the parameters,
# and move the model's weights to the specified device (e.g., CPU or GPU)
model = ConvJointModel(params).to(device)

# Print the model architecture (the nested blocks and their layers)
print(model)
ConvJointModel(
  (scalar_grid_output): Scalar_to_Grid_Block()
  (conv_habitat): Conv2d_block_spatial(
    (conv2d): Sequential(
      (0): Conv2d(21, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU()
      (4): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): ReLU()
      (6): Conv2d(4, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (fcn_movement_all): FCN_block_all_movement(
    (ffn): Sequential(
      (0): Linear(in_features=8, out_features=256, bias=True)
      (1): Dropout(p=0.1, inplace=False)
      (2): ReLU()
      (3): Linear(in_features=256, out_features=256, bias=True)
      (4): Dropout(p=0.1, inplace=False)
      (5): ReLU()
      (6): Linear(in_features=256, out_features=256, bias=True)
      (7): Dropout(p=0.1, inplace=False)
      (8): ReLU()
      (9): Linear(in_features=256, out_features=12, bias=True)
    )
  )
  (movement_grid_output): Params_to_Grid_Block_ChV()
)

Pull out some testing data

To test the other blocks, and the full model, we will need some data. We can pull that out from the training set.

Code
# 1. Report how many samples the training dataset holds
print("Number of samples in the train dataset: ", len(dataloader_train.dataset))
print('\n')

# Index of the training sample to pull out, between 0 and the number of samples.
# We picked this fairly arbitrarily, but with some interesting environmental features to illustrate the model's predictions
iteration_index = 2700

# 2. Retrieve that single sample from the training dataset and
# 3. give each tensor a leading batch dimension (the model expects batches) on the model's device
#
# sample_spatial_covs  - spatial covariates for a single step
# sample_temporal_covs - temporal covariates for a single step
# sample_prev_bearing  - bearing of the previous step
# sample_next_step     - target label (what we are trying to predict) for the next step
#
# These are reused later in the script to check how the model's predictions look,
# and when we extract feature maps from the convolutional layers
sample_spatial_covs, sample_temporal_covs, sample_prev_bearing, sample_next_step = (
    sample_tensor.unsqueeze(0).to(device)
    for sample_tensor in dataloader_train.dataset[iteration_index]
)

print(f'Shape of the sample spatial covariates:  {sample_spatial_covs.shape}')
print(f'Shape of the sample temporal covariates: {sample_temporal_covs.shape}')
print(f'Shape of the sample previous bearing:    {sample_prev_bearing.shape}')
print(f'Shape of the sample next step:           {sample_next_step.shape}')
Number of samples in the train dataset:  8082


Shape of the sample spatial covariates:  torch.Size([1, 13, 101, 101])
Shape of the sample temporal covariates: torch.Size([1, 8])
Shape of the sample previous bearing:    torch.Size([1, 1])
Shape of the sample next step:           torch.Size([1, 101, 101])

For visualisation, we can return the scale of the covariates to their original values.

Code
# Slope
# Pull the slope layer out of the sample (first sample in the batch, channel 12) as a CPU tensor.
# NOTE(review): channel 12 is taken to be slope based on this cell - confirm against the layer order used in data prep
slope_norm = sample_spatial_covs.detach().cpu()[0, 12, :, :]
# Undo the min-max scaling: x * (max - min) + min
# (slope_min and slope_max are defined earlier in the script)
slope_natural = (slope_norm * (slope_max - slope_min)) + slope_min

Pull out the scalar values

Code
# Convert the temporal covariates of the sample to a NumPy array:
#   1) Detach from the computation graph so no gradients are tracked.
#   2) Move to CPU memory.
#   3) Convert to NumPy.
# Then take the first four values of the first sample (index 0), in order:
# sin and cos of the hour, followed by sin and cos of the day of the year
hour_t2_sin, hour_t2_cos, yday_t2_sin, yday_t2_cos = sample_temporal_covs.detach().cpu().numpy()[0, :4]

# Convert the previous bearing in the same way and take the single value of the first sample
bearing = sample_prev_bearing.detach().cpu().numpy()[0, 0]

Helper functions

To return the hour and day of the year to their original values, we can use the following functions.

Code
def recover_hour(sin_term, cos_term):
    """Recover the hour of the day from its sine and cosine terms.

    The angle of the (cos, sin) pair is mapped onto a 24-hour clock
    (pi radians = 12 hours) and wrapped with modulo 24 so that negative
    angles land on the afternoon/evening hours.
    """
    # Angle represented by the pair, in (-pi, pi]
    angle = np.arctan2(sin_term, cos_term)
    # Rescale the angle to hours
    hours = 12 * angle / np.pi
    return hours % 24

def recover_yday(sin_term, cos_term):
    """Recover the day of the year from its sine and cosine terms.

    The angle of the (cos, sin) pair is mapped onto a 365-day year
    (2 * pi radians = 365 days) and wrapped with modulo 365 so that negative
    angles land in the second half of the year.
    """
    # Angle represented by the pair, in (-pi, pi]
    angle = np.arctan2(sin_term, cos_term)
    # Rescale the angle to days
    days = 365 * angle / (2 * np.pi)
    return days % 365

Calculate the hour, day of year and previous bearing of the test sample

Code
# Hour of the day of the sample, truncated to a whole hour
hour_t2 = recover_hour(hour_t2_sin, hour_t2_cos)
hour_t2_integer = int(hour_t2)
print(f'Hour:               {hour_t2_integer}')

# Day of the year of the sample, truncated to a whole day
yday_t2 = recover_yday(yday_t2_sin, yday_t2_cos)
yday_t2_integer = int(yday_t2)
print(f'Day of the year:    {yday_t2_integer}')

# Previous bearing in degrees, wrapped to [0, 360): rounded to 1 decimal place,
# then truncated to a whole number of degrees
bearing_degrees = int(round(np.degrees(bearing) % 360, 1))
print(f'Bearing (radians):  {bearing}')
print(f'Bearing (degrees):  {bearing_degrees}')
Hour:               16
Day of the year:    122
Bearing (radians):  2.7272613048553467
Bearing (degrees):  156

Grab the row and column of the observed next step (label or target)

Code
# Target grid of the first (and only) sample in the batch, as a NumPy array
target = sample_next_step.detach().cpu().numpy()[0]
# Indices of the cells equal to 1, as one index array per axis (rows, columns)
coordinates = np.nonzero(target == 1)
# Take the first match along each axis
row, column = (axis_indices[0] for axis_indices in coordinates)
print(f"The location of the next step is (row, column): ({row}, {column})")
The location of the next step is (row, column): (46, 44)

Plot the sample covariates

Code
# Plot the covariates: an RGB composite on the left and slope on the right
fig, axs = plt.subplots(1, 2, figsize=(9, 5))

# Channels 3, 2 and 1 of the sample are used as red, green and blue
# NOTE(review): presumably Sentinel-2 bands B4, B3 and B2 - confirm against the band order used in data prep
red = sample_spatial_covs.detach().cpu()[0, 3, :, :]
green = sample_spatial_covs.detach().cpu()[0, 2, :, :]
blue = sample_spatial_covs.detach().cpu()[0, 1, :, :]

# Stack the three bands along a new last axis to get a (rows, columns, 3) image
rgb_image = torch.stack([red, green, blue], dim=-1)
print(rgb_image.shape)
# Convert to NumPy
rgb_image_np = rgb_image.numpy()

# Normalize to the range [0, 1] for display (a single min and max across all three bands)
rgb_image_np = (rgb_image_np - rgb_image_np.min()) / (rgb_image_np.max() - rgb_image_np.min())

# Plot RGB
im1 = axs[0].imshow(rgb_image_np)
axs[0].set_title('Sentinel-2 RGB Image')

# Plot Slope (on its natural scale)
im2 = axs[1].imshow(slope_natural.numpy())
axs[1].set_title('Slope')
# fig.colorbar(im2, ax=axs[1])

# The filename records the individual, time of the step, previous bearing and the location of the next step
filename_covs = f'{output_dir}/covs_id{buffalo_id}_yday{yday_t2_integer}_hour{hour_t2_integer}_bearing{bearing_degrees}_next_r{row}_c{column}.png'
plt.tight_layout()
plt.savefig(filename_covs, dpi=300, bbox_inches='tight') # if we want to save the figure
plt.show()
plt.close()  # Close the figure to free memory
torch.Size([101, 101, 3])

Plot the target (observed location of the next step)

The model is trying to maximise the probability at the location of the next step, which is the target.

Code
# The filename records the individual, time of the step, previous bearing and the location of the next step
filename_target = f'{output_dir}/target_id{buffalo_id}_yday{yday_t2_integer}_hour{hour_t2_integer}_bearing{bearing_degrees}_next_r{row}_c{column}.png'

plt.imshow(target)
# Add the colourbar before laying out and saving the figure, so that the saved
# image matches what is displayed
plt.colorbar()
plt.tight_layout()
plt.savefig(filename_target, dpi=300, bbox_inches='tight') # if we want to save the figure
plt.show()
plt.close()  # Close the figure to free memory

Prepare for training

Loss function

We use a custom negative log-likelihood loss function. Essentially, it extracts the next-step log-probability at the location of the observed next step and then takes the negative of this value. This is the loss that we want to minimise, as we want to maximise the probability of the observed next step.

We will also save this loss function in a script called deepSSF_loss.py so that we can import it into future scripts.

Code
class negativeLogLikeLoss(nn.Module):
    """
    Custom negative log-likelihood loss for a 4D prediction tensor
    (batch, height, width, 2), where the last dimension holds the habitat
    selection log-densities (index 0) and the movement log-densities (index 1).

    The forward pass:
    1. Adds the two log-density layers to obtain a combined log-density.
    2. Multiplies by the target, which is 0 everywhere except at the location
       of the next step, and sums over the grid, which extracts the value at
       that location. The sign is flipped so that minimising the loss
       maximises the probability of the observed step.
    3. Does the same for the habitat and movement layers separately.
    4. Applies the chosen reduction ('mean', 'sum' or 'none') to all three.
    """

    def __init__(self, reduction='mean'):
        """
        Args:
            reduction (str): How the per-sample losses are combined:
                             'mean', 'sum', or 'none'.
        """
        super().__init__()
        assert reduction in ['mean', 'sum', 'none'], \
            "reduction should be 'mean', 'sum', or 'none'"
        self.reduction = reduction

    def forward(self, predict, target):
        """
        Compute the total, habitat and movement negative log-likelihoods.

        Args:
            predict (Tensor): Shape (B, H, W, 2), log-densities for habitat
                              selection and movement.
            target  (Tensor): Shape (B, H, W), marking where the log-densities
                              should be evaluated.

        Returns:
            tuple: (total, habitat, movement) losses. Scalars for 'mean' and
                   'sum', one value per sample for 'none'.

        Raises:
            ValueError: If the combined log-densities, or their product with
                        the target, contain NaNs.
        """
        # Split the two layers (multiplying by 1.0 yields new tensors)
        habitat_log_density = predict[:, :, :, 0] * 1.0
        movement_log_density = predict[:, :, :, 1] * 1.0

        # Adding log-densities is multiplying the densities
        combined_log_density = habitat_log_density + movement_log_density

        # Fail loudly if the predictions already contain NaNs
        if torch.isnan(combined_log_density).any():
            print("NaNs detected in predict_prod")
            print("predict_prod:", combined_log_density)
            raise ValueError("NaNs detected in predict_prod")

        # Keep only the value at the observed next step, with the sign flipped
        masked_total = -1 * (combined_log_density * target)

        # The masking itself can create NaNs (e.g. an infinite value times 0)
        if torch.isnan(masked_total).any():
            print("NaNs detected in negLogLike")
            print("negLogLike:", masked_total)
            raise ValueError("NaNs detected in negLogLike")

        # One value per sample for the total, habitat and movement losses
        per_sample_losses = (
            masked_total.sum(dim=(1, 2)),
            (-1 * (habitat_log_density * target)).sum(dim=(1, 2)),
            (-1 * (movement_log_density * target)).sum(dim=(1, 2)),
        )

        # 'none' (and any unrecognised value) leaves the losses unreduced
        reducer = {'mean': torch.mean, 'sum': torch.sum}.get(self.reduction)
        if reducer is None:
            return per_sample_losses
        return tuple(reducer(loss) for loss in per_sample_losses)

Test the loss function

Code
# Negative log-likelihood loss, averaged over the samples in the batch (mean reduction)
loss_fn = negativeLogLikeLoss(reduction='mean')

# Run the sample through the model, then score the predictions against the observed next step
sample_predictions = model((sample_spatial_covs, sample_temporal_covs, sample_prev_bearing))
total_loss, habitat_loss, movement_loss = loss_fn(sample_predictions, sample_next_step)
print(f'Total loss:     {total_loss}')
print(f'Habitat loss:   {habitat_loss}')
print(f'Movement loss:  {movement_loss}')
Total loss:     22.783802032470703
Habitat loss:   9.230388641357422
Movement loss:  13.553414344787598

Early stopping code

This code will be used to stop training if the validation loss does not improve after a certain number of epochs.

When the loss of the validation data (which is held out from the training data) decreases (i.e. the model improves), the model weights are saved. Each time the validation loss does not decrease, a counter is incremented. If the counter reaches the patience value, the training loop will break and the model will stop training. The ‘final’ model is then the model that had the lowest validation loss.

We have saved this code in a script called deepSSF_early_stopping.py so that we can import it into future scripts.

Code
class EarlyStopping:
    """Stop training when the validation loss stops improving.

    Call the instance once per epoch with the validation loss and the model.
    Whenever the loss improves on the best value seen so far (by at least
    ``delta``), the model's weights are saved to ``path`` and the patience
    counter is reset. Otherwise the counter is incremented, and once it
    reaches ``patience`` the ``early_stop`` flag is set to True.

    A NaN validation loss is never treated as an improvement: it counts as an
    epoch without improvement and leaves the saved checkpoint untouched.
    """

    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 5
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """

        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record the validation loss of one epoch and update the stopping state.

        Args:
            val_loss: Validation loss of the epoch (a number or scalar tensor).
            model: Model whose ``state_dict`` is saved when the loss improves.
        """

        # NaN is the only value that is not equal to itself. Without this guard
        # a NaN loss fails the "score < best" comparison, is mistaken for an
        # improvement and overwrites the best checkpoint.
        loss_is_nan = bool(val_loss != val_loss)

        # Higher score = lower validation loss
        score = -val_loss

        # The first valid loss always counts as an improvement; afterwards the
        # score has to beat the best score by at least delta
        improved = (not loss_is_nan) and (
            self.best_score is None or score >= self.best_score + self.delta
        )

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
            return

        # No improvement: count the epoch and stop once patience runs out
        self.counter += 1
        self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

Training loop

This code defines the main training loop for a single epoch. It iterates over batches from the training dataloader, moves the data to the correct device (e.g., CPU or GPU), calculates the loss, and performs backpropagation to update the model parameters. It also prints periodic updates of the current loss.

Code
def train_loop(dataloader_train, 
               model, 
               loss_fn, 
               optimisers, 
               skip_epoch0_training=False):
    """
    Runs the training process for one epoch using the given dataloader, model,
    loss function, and optimisers. Prints progress updates every 20 batches.

    Args:
        dataloader_train: Dataloader yielding batches of (x1, x2, x3, y).
        model: Model called as model((x1, x2, x3)); its conv_habitat and
            fcn_movement_all sub-blocks are accessed directly.
        loss_fn: Loss returning (total_loss, habitat_loss, movement_loss).
        optimisers: Pair of optimisers, unpacked as (movement, habitat).
        skip_epoch0_training: If True, the loss is only observed: the forward
            pass runs with gradients disabled and no parameters are updated.

    Returns nothing. The average loss of the epoch is printed and appended to
    the global list train_losses. The globals device and batch_size are also used.
    """

    # Unpack optimisers
    optimiser_movement, optimiser_habitat = optimisers

    # 1. Number of batches and total number of training examples
    num_train_batches = len(dataloader_train)
    size = len(dataloader_train.dataset)

    # 2. Put model in training mode (affects layers like dropout, batchnorm)
    # NOTE(review): this also applies when skip_epoch0_training is True, so dropout
    # stays active for the observation-only losses - confirm that this is intended
    model.train()

    # 3. Variable to accumulate the total loss over the epoch
    epoch_loss = 0.0

    # 4. Loop over batches in the training dataloader
    for batch, (x1, x2, x3, y) in enumerate(dataloader_train):

        # Move the batch of data to the specified device (CPU/GPU)
        x1 = x1.to(device)
        x2 = x2.to(device)
        x3 = x3.to(device)
        y = y.to(device)

        # Forward pass: compute the model output and loss
        # (gradients are only tracked when the parameters will be updated)
        with torch.set_grad_enabled(not skip_epoch0_training):
            outputs = model((x1, x2, x3))
            total_loss, habitat_loss, movement_loss = loss_fn(outputs, y,)

        epoch_loss += total_loss.detach()  # Use detach to prevent memory leaks

        # Only perform optimization if not skipping training
        if not skip_epoch0_training:
            # Backpropagation: compute gradients and update parameters
            # Reset gradients before the next iteration

            # Zero all gradients
            optimiser_movement.zero_grad()
            optimiser_habitat.zero_grad()

            # Backward pass on the total loss (a single backward pass fills the
            # gradients of both sub-networks)
            # habitat_loss.backward(retain_graph=True)
            # movement_loss.backward()
            total_loss.backward()

            # The two updates below are kept separate by hiding the other sub-network's
            # gradients during each optimiser step.
            # NOTE(review): presumably a safeguard in case the optimisers' parameter
            # groups overlap - confirm against how the optimisers are constructed

            # For movement optimizer: save habitat gradients, then zero them out
            habitat_grads = []
            for param in model.conv_habitat.parameters():
                # Save the gradient
                if param.grad is not None:
                    habitat_grads.append(param.grad.clone())
                else:
                    habitat_grads.append(None)
                # Zero out habitat gradient for movement update
                param.grad = None

            # Update movement parameters
            optimiser_movement.step()

            # For habitat optimizer: restore habitat gradients and zero movement gradients
            # for param in model.conv_movement.parameters():
            #     param.grad = None
            for param in model.fcn_movement_all.parameters():
                param.grad = None

            # Restore habitat gradients (same parameter order as when they were saved)
            for i, param in enumerate(model.conv_habitat.parameters()):
                param.grad = habitat_grads[i]

            # Update habitat parameters
            optimiser_habitat.step()

        # Print an update every 20 batches to keep track of training progress
        if batch % 20 == 0:
            loss_val = total_loss.item()
            # Number of samples processed so far (uses the global batch_size)
            current = batch * batch_size + len(x1)
            if skip_epoch0_training:
                print(f"[Observation only] loss: {loss_val:>15f}  [{current:>5d}/{size:>5d}]")
            else:
                print(f"loss: {loss_val:>15f}  [{current:>5d}/{size:>5d}]")

        # Release cached GPU memory after every batch
        torch.cuda.empty_cache()

    # Compute the average training loss and print it
    epoch_loss /= num_train_batches
    if skip_epoch0_training:
        print(f"\nAvg training loss (observation only): {epoch_loss:>15f}")
    else:
        print(f"\nAvg training loss: {epoch_loss:>15f}")
    # Record the epoch average in the global list of training losses
    train_losses.append(epoch_loss.item())

Test loop

The test loop is similar to the training loop, but it does not perform backpropagation. It calculates the loss on the test set and prints the average loss.

Code
def test_loop(dataloader_test, model, loss_fn):
    """
    Evaluate the model on the test dataloader and report the mean batch loss.

    The model is switched to evaluation mode and every batch is run under
    ``torch.no_grad()``, so no gradients are computed and no parameters change.

    Args:
        dataloader_test: Iterable with a length that yields ``(x1, x2, x3, y)``
            batches. x1, x2, x3 are the spatial covariates, temporal covariates,
            and bearing, respectively; y is the label (observed location of the
            next step).
        model: Callable module invoked as ``model((x1, x2, x3))``.
        loss_fn: Callable invoked as ``loss_fn(prediction, y)`` that returns a
            ``(total_loss, habitat_loss, movement_loss)`` tuple. Only the total
            loss is used here.

    Returns:
        float: The total loss averaged over the number of batches.

    Raises:
        ValueError: If the dataloader contains no batches.

    Note:
        Batches are moved to the module-level ``device``. The model is left in
        evaluation mode when the function returns.
    """

    # 1. Set the model to evaluation mode (affects layers like dropout, batchnorm).
    model.eval()

    num_batches = len(dataloader_test)
    if num_batches == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError below
        raise ValueError("dataloader_test is empty; cannot compute an average test loss")

    test_loss = 0.0

    # 2. Disable gradient computation to speed up evaluation and reduce memory usage
    with torch.no_grad():
        # 3. Loop through each batch in the test dataloader
        for x1, x2, x3, y in dataloader_test:

            # Move the batch of data to the appropriate device (CPU/GPU)
            x1 = x1.to(device)
            x2 = x2.to(device)
            x3 = x3.to(device)
            y = y.to(device)

            # Compute the loss on the test set (no backward pass needed);
            # the habitat and movement components are not needed here
            total_loss, _, _ = loss_fn(model((x1, x2, x3)), y)
            test_loss += total_loss.detach()

    # 4. Compute average test loss over all batches
    test_loss /= num_batches

    torch.cuda.empty_cache()

    # Convert to a plain Python float so callers can store or compare it
    avg_test_loss = float(test_loss)

    # Print the average test loss
    print(f"Avg test loss:    {avg_test_loss:>15f} \n")

    return avg_test_loss

Train the model

Code
# ------------------------------------------------------------------
# Train the deepSSF model.
#
# Builds the model, one optimiser and one LR scheduler per process
# (movement / habitat), then runs up to `epochs` epochs. Each epoch:
# training pass -> validation pass -> scheduler steps -> early-stopping
# check -> diagnostic figures written to `output_dir`.
# When training ends (early stop or all epochs used) the checkpoint
# written by EarlyStopping is reloaded and the test set is evaluated.
# ------------------------------------------------------------------

# Checkpoint file handed to EarlyStopping and reloaded once training ends
path_save_weights = f'{output_dir}/checkpoint_deepSSF_buffalo{buffalo_id}.pt'
print(f'Saving weights at: \n {path_save_weights}')

print(f'Output directory: {output_dir}')

epochs = 100
train_losses = []  # Track training losses across epochs
val_losses = []   # Track validation losses across epochs
val_habitat_losses = []  # Track validation habitat losses across epochs
val_movement_losses = []  # Track validation movement losses across epochs
# Difference in loss between epochs
train_diff = []
val_diff = []
val_habitat_diff = []
val_movement_diff = []

# Initialize the parameter container using the parameters defined in 'params_dict'
params = ModelParams(params_dict)
# Create an instance of the ConvJointModel using the parameters,
# and move the model to the specified device (e.g., CPU or GPU)
model = ConvJointModel(params).to(device)
# Print the model architecture
print(model)

# Define the negative log-likelihood loss function with mean reduction
loss_fn = negativeLogLikeLoss(reduction='mean')

# Set the initial learning rates for each process
initial_learning_rate_movement = 1e-5
initial_learning_rate_habitat = 1e-4

# Movement-related parameters (only the fully connected movement block here)
movement_params = model.fcn_movement_all.parameters()

# Define separate optimizers for each component
optimiser_movement = optim.Adam(movement_params, lr=initial_learning_rate_movement)
optimiser_habitat = optim.Adam(model.conv_habitat.parameters(), lr=initial_learning_rate_habitat)

# Put optimisers into a tuple to call in the training loop
optimisers = (optimiser_movement, optimiser_habitat)

# Create separate schedulers for each optimizer
scheduler_movement = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimiser_movement, 'min', factor=0.1, patience=5)
scheduler_habitat = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimiser_habitat, 'min', factor=0.1, patience=5)

# Initialise early stopping
early_stopping = EarlyStopping(patience=15, verbose=True, path=path_save_weights)

# Create directories for saving training images
os.makedirs(f'{output_dir}/training_images', exist_ok=True)
os.makedirs(f'{output_dir}/loss_images', exist_ok=True)

for t in range(epochs):

    # Initialise variables to store during training
    train_loss = 0.0
    num_train_batches = len(dataloader_train)

    val_loss = 0.0
    val_loss_habitat = 0.0
    val_loss_movement = 0.0
    num_batches = len(dataloader_val)

    print(f"Epoch {t+1}\n-------------------------------")

    # Skip training in the first epoch, but still calculate losses
    skip_training = (t == 0)

    # 1. Run the training loop for one epoch using the training dataloader
    train_loop(dataloader_train,
               model,
               loss_fn,
               optimisers,
               skip_epoch0_training=skip_training)

    # 2. Evaluate model performance on the validation dataset
    model.eval()  # Switch to evaluation mode for proper layer behavior
    with torch.no_grad():

        # Loop through each batch in the validation dataloader
        for x1, x2, x3, y in dataloader_val:
            # Move data to the chosen device (CPU/GPU)
            x1 = x1.to(device)
            x2 = x2.to(device)
            x3 = x3.to(device)
            y = y.to(device)

            # Accumulate the total validation loss and its two components
            total_loss, habitat_loss, movement_loss = loss_fn(model((x1, x2, x3)), y)
            val_loss += total_loss.detach()
            val_loss_habitat += habitat_loss.detach()
            val_loss_movement += movement_loss.detach()

    # 3. Compute the average validation losses over all batches
    val_loss /= num_batches
    val_loss_habitat /= num_batches
    val_loss_movement /= num_batches

    # 4. Step each scheduler on the averaged loss of its own process, so the
    #    monitored metric is the same quantity that is printed and tracked below
    scheduler_movement.step(val_loss_movement)
    scheduler_habitat.step(val_loss_habitat)

    print(f"Avg validation loss:            {val_loss:>15f}")
    print(f"Avg validation habitat loss:    {val_loss_habitat:>15f}")
    print(f"Avg validation movement loss:   {val_loss_movement:>15f}")
    print(f"Movement learning rate:         {scheduler_movement.get_last_lr()}")
    print(f"Habitat learning rate:          {scheduler_habitat.get_last_lr()}")

    # 5. Track the validation loss for plotting or monitoring
    val_losses.append(val_loss.item())
    val_habitat_losses.append(val_loss_habitat.item())
    val_movement_losses.append(val_loss_movement.item())

    # 6. Early stopping: if no improvement in validation loss for a set patience, stop training
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        # Restore the best model weights saved by EarlyStopping
        model.load_state_dict(torch.load(path_save_weights, weights_only=True, map_location=device))
        test_loop(dataloader_test, model, loss_fn)  # Evaluate on test set once training stops
        break
    else:
        model.eval()
        print("\n")

    torch.cuda.empty_cache()

    # ----------------------------------------------------
    # Per-epoch diagnostics: loss curves plus the habitat,
    # movement and next-step surfaces for one fixed sample
    # ----------------------------------------------------

    # Convert the per-epoch loss lists (Python floats) to NumPy arrays for plotting
    train_losses_np = torch.tensor(train_losses).detach().cpu().numpy()
    val_losses_np = torch.tensor(val_losses).detach().cpu().numpy()
    val_habitat_losses_np = torch.tensor(val_habitat_losses).detach().cpu().numpy()
    val_movement_losses_np = torch.tensor(val_movement_losses).detach().cpu().numpy()

    # Get the difference in losses between epochs.
    # For t == 0 the index t-1 is -1, i.e. the same (only) element, so the first difference is 0.
    train_diff.append(train_losses_np[t] - train_losses_np[t-1])
    val_diff.append(val_losses_np[t] - val_losses_np[t-1])
    val_habitat_diff.append(val_habitat_losses_np[t] - val_habitat_losses_np[t-1])
    val_movement_diff.append(val_movement_losses_np[t] - val_movement_losses_np[t-1])

    # Number of epochs completed so far
    n_epochs = len(val_losses)

    # -----------------------------------------------------------
    # 1. Retrieve a single example (covariates and labels) at the
    #    specified 'iteration_index' from the training dataset
    # -----------------------------------------------------------
    x1, x2, x3, labels = dataloader_train.dataset[iteration_index]

    # -----------------------------------------------------------
    # 2. Add a batch dimension and move tensors to the device
    #    for model inference
    # -----------------------------------------------------------
    x1 = x1.unsqueeze(0).to(device)
    x2 = x2.unsqueeze(0).to(device)
    x3 = x3.unsqueeze(0).to(device)

    # -----------------------------------------------------------
    # 3. Run the model on the single example. This is inference only,
    #    so disable autograd to avoid building an unused graph each epoch
    # -----------------------------------------------------------
    with torch.no_grad():
        test = model((x1, x2, x3))

    # -----------------------------------------------------------
    # 4. Extract habitat and movement outputs;
    #    convert them to NumPy arrays for visualization
    # -----------------------------------------------------------
    hab_density = test.detach().cpu().numpy()[0, :, :, 0]
    movement_density = test.detach().cpu().numpy()[0, :, :, 1]

    # -----------------------------------------------------------
    # 5. Generate masks to exclude the border cells for color scale
    #    reasons. Multiplying by -inf makes those cells non-finite
    #    (the sign depends on the cell value), which is expected to
    #    keep them out of the color scale - NOTE(review): confirm.
    # -----------------------------------------------------------
    x_mask = np.ones_like(hab_density)
    y_mask = np.ones_like(hab_density)

    # Mask out the first and last three columns and rows. Negative indices
    # are used so the mask does not depend on a 101-cell grid.
    x_mask[:, :3] = -np.inf
    x_mask[:, -3:] = -np.inf
    y_mask[:3, :] = -np.inf
    y_mask[-3:, :] = -np.inf

    # Apply the masks to habitat density
    hab_density_mask = hab_density * x_mask * y_mask

    # Combine habitat and movement densities to represent
    # next-step probability (sum of the two log-surfaces)
    step_density = hab_density + movement_density
    step_density_mask = step_density * x_mask * y_mask

    # Plot the loss curves and the three surfaces
    fig, axs = plt.subplots(2, 2, figsize=(9, 7.5))

    # Plot Training and Validation Loss
    axs[0, 0].plot(range(n_epochs), train_losses_np, label='Training Loss', color='blue')
    axs[0, 0].plot(range(n_epochs), val_losses_np, label='Validation Loss', color='red')
    axs[0, 0].set_xlim(0, 120)  # fixed x-range so the frames of successive epochs line up
    axs[0, 0].set_title('Training and validation loss')
    axs[0, 0].set_xlabel('Epoch')
    axs[0, 0].set_ylabel('Loss')
    axs[0, 0].legend()

    # Plot habitat selection log-probability
    im2 = axs[0, 1].imshow(hab_density_mask, cmap='viridis')
    axs[0, 1].set_title('Habitat log-probability')
    fig.colorbar(im2, ax=axs[0, 1])

    # Plot movement log-probability
    im3 = axs[1, 0].imshow(movement_density, cmap='viridis')
    axs[1, 0].set_title('Movement log-probability')
    fig.colorbar(im3, ax=axs[1, 0])

    # Plot next-step log-probability
    im4 = axs[1, 1].imshow(step_density_mask, cmap='viridis')
    axs[1, 1].set_title('Next-step log-probability')
    fig.colorbar(im4, ax=axs[1, 1])

    filename_covs = f'{output_dir}/training_images/training_epoch_index{t}_id{buffalo_id}_yday{yday_t2_integer}_hour{hour_t2_integer}_bearing{bearing_degrees}_next_r{row}_c{column}.png'
    plt.tight_layout()
    # bbox_inches='tight' is deliberately not used: it creates inconsistent image sizes
    plt.savefig(filename_covs, dpi=150)
    plt.close()  # Close the figure to free memory

    # Plot the difference in the loss of each component between epochs
    filename_diff = f'{output_dir}/loss_images/training_diff_epoch_index{t}_id{buffalo_id}_yday{yday_t2_integer}_hour{hour_t2_integer}_bearing{bearing_degrees}_next_r{row}_c{column}.png'
    plt.axhline(y=0, color='black', linestyle='--', label='Null Probability')  # zero-change reference line
    plt.plot(range(n_epochs), val_habitat_diff, label='Validation Habitat Loss Difference', color='green')
    plt.plot(range(n_epochs), val_movement_diff, label='Validation Movement Loss Difference', color='orange')
    plt.xlim(0, 120)
    plt.title('Habitat and movement loss difference')
    plt.xlabel('Epoch')
    plt.ylabel('Loss difference')
    plt.legend()
    # bbox_inches='tight' is deliberately not used: it creates inconsistent image sizes
    plt.savefig(filename_diff, dpi=150)
    plt.close()  # Close the figure to free memory

else:
    # The loop used every epoch without triggering early stopping. Mirror the
    # early-stopping path so the run always ends with the best checkpoint
    # loaded and a test-set loss reported.
    model.load_state_dict(torch.load(path_save_weights, weights_only=True, map_location=device))
    test_loop(dataloader_test, model, loss_fn)

print("Done!")
Saving weights at: 
 ../Python/outputs/model_training_S2/id2005_scalar_movement_2_2025-06-05/checkpoint_deepSSF_buffalo2005.pt
Output directory: ../Python/outputs/model_training_S2/id2005_scalar_movement_2_2025-06-05
ConvJointModel(
  (scalar_grid_output): Scalar_to_Grid_Block()
  (conv_habitat): Conv2d_block_spatial(
    (conv2d): Sequential(
      (0): Conv2d(21, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU()
      (4): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): ReLU()
      (6): Conv2d(4, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (fcn_movement_all): FCN_block_all_movement(
    (ffn): Sequential(
      (0): Linear(in_features=8, out_features=256, bias=True)
      (1): Dropout(p=0.1, inplace=False)
      (2): ReLU()
      (3): Linear(in_features=256, out_features=256, bias=True)
      (4): Dropout(p=0.1, inplace=False)
      (5): ReLU()
      (6): Linear(in_features=256, out_features=256, bias=True)
      (7): Dropout(p=0.1, inplace=False)
      (8): ReLU()
      (9): Linear(in_features=256, out_features=12, bias=True)
    )
  )
  (movement_grid_output): Params_to_Grid_Block_ChV()
)
Epoch 1
-------------------------------
[Observation only] loss:       23.145914  [   32/ 8082]
[Observation only] loss:       22.848213  [  672/ 8082]
[Observation only] loss:       22.685690  [ 1312/ 8082]
[Observation only] loss:       22.802162  [ 1952/ 8082]
[Observation only] loss:       22.507214  [ 2592/ 8082]
[Observation only] loss:       22.503035  [ 3232/ 8082]
[Observation only] loss:       22.014181  [ 3872/ 8082]
[Observation only] loss:       23.221727  [ 4512/ 8082]
[Observation only] loss:       22.761583  [ 5152/ 8082]
[Observation only] loss:       22.621414  [ 5792/ 8082]
[Observation only] loss:       23.055664  [ 6432/ 8082]
[Observation only] loss:       23.195740  [ 7072/ 8082]
[Observation only] loss:       22.849449  [ 7712/ 8082]

Avg training loss (observation only):       22.776478
Avg validation loss:                  22.567806
Avg validation habitat loss:           9.229526
Avg validation movement loss:         13.338280
Movement learning rate:         [1e-05]
Habitat learning rate:          [0.0001]
Validation loss decreased (inf --> 22.567806).  Saving model ...

Epoch 2
-------------------------------
loss:       22.491375  [   32/ 8082]
loss:       22.600468  [  672/ 8082]
loss:       22.646358  [ 1312/ 8082]
loss:       22.734280  [ 1952/ 8082]
loss:       22.936218  [ 2592/ 8082]
loss:       22.677103  [ 3232/ 8082]
loss:       22.834663  [ 3872/ 8082]
loss:       22.540688  [ 4512/ 8082]
loss:       22.024836  [ 5152/ 8082]
loss:       22.735624  [ 5792/ 8082]
loss:       22.714054  [ 6432/ 8082]
loss:       22.350060  [ 7072/ 8082]
loss:       22.097446  [ 7712/ 8082]

Avg training loss:       22.557842
Avg validation loss:                  22.097454
Avg validation habitat loss:           9.210783
Avg validation movement loss:         12.886669
Movement learning rate:         [1e-05]
Habitat learning rate:          [0.0001]
Validation loss decreased (22.567806 --> 22.097454).  Saving model ...

Epoch 3
-------------------------------
loss:       21.527706  [   32/ 8082]
loss:       22.656925  [  672/ 8082]
loss:       23.031595  [ 1312/ 8082]
loss:       22.638725  [ 1952/ 8082]
loss:       22.121407  [ 2592/ 8082]
loss:       21.547007  [ 3232/ 8082]
loss:       22.066875  [ 3872/ 8082]
loss:       22.300606  [ 4512/ 8082]
loss:       22.719770  [ 5152/ 8082]
loss:       22.616438  [ 5792/ 8082]
loss:       22.053728  [ 6432/ 8082]
loss:       22.056820  [ 7072/ 8082]
loss:       22.040588  [ 7712/ 8082]

Avg training loss:       22.084496
Avg validation loss:                  21.604164
Avg validation habitat loss:           9.156961
Avg validation movement loss:         12.447203
Movement learning rate:         [1e-05]
Habitat learning rate:          [0.0001]
Validation loss decreased (22.097454 --> 21.604164).  Saving model ...

Epoch 4
-------------------------------
loss:       21.828339  [   32/ 8082]
loss:       21.696220  [  672/ 8082]
loss:       22.224220  [ 1312/ 8082]
loss:       21.740173  [ 1952/ 8082]
loss:       21.926682  [ 2592/ 8082]
loss:       21.545471  [ 3232/ 8082]
loss:       22.046978  [ 3872/ 8082]
loss:       21.794256  [ 4512/ 8082]
loss:       21.563906  [ 5152/ 8082]
loss:       21.904961  [ 5792/ 8082]
loss:       21.438297  [ 6432/ 8082]
loss:       20.840139  [ 7072/ 8082]
loss:       21.188772  [ 7712/ 8082]

Avg training loss:       21.699593
Avg validation loss:                  21.228951
Avg validation habitat loss:           9.102291
Avg validation movement loss:         12.126659
Movement learning rate:         [1e-05]
Habitat learning rate:          [0.0001]
Validation loss decreased (21.604164 --> 21.228951).  Saving model ...

Epoch 5
-------------------------------
loss:       22.049572  [   32/ 8082]
loss:       20.473804  [  672/ 8082]
loss:       21.991619  [ 1312/ 8082]
loss:       21.252106  [ 1952/ 8082]
loss:       21.590950  [ 2592/ 8082]
loss:       21.615730  [ 3232/ 8082]
loss:       21.985062  [ 3872/ 8082]
loss:       21.055092  [ 4512/ 8082]
loss:       21.311924  [ 5152/ 8082]
loss:       22.055412  [ 5792/ 8082]
loss:       21.721199  [ 6432/ 8082]
loss:       21.341251  [ 7072/ 8082]
loss:       22.237782  [ 7712/ 8082]

Avg training loss:       21.458605
Avg validation loss:                  21.085865
Avg validation habitat loss:           9.055049
Avg validation movement loss:         12.030817
Movement learning rate:         [1e-05]
Habitat learning rate:          [0.0001]
Validation loss decreased (21.228951 --> 21.085865).  Saving model ...

Epoch 6
-------------------------------
loss:       21.424976  [   32/ 8082]
loss:       21.490717  [  672/ 8082]
loss:       22.324043  [ 1312/ 8082]
loss:       20.973207  [ 1952/ 8082]
loss:       20.831932  [ 2592/ 8082]
loss:       22.177689  [ 3232/ 8082]
loss:       21.719204  [ 3872/ 8082]
loss:       21.357609  [ 4512/ 8082]
loss:       20.709713  [ 5152/ 8082]
loss:       21.533581  [ 5792/ 8082]
loss:       21.478928  [ 6432/ 8082]
loss:       20.919470  [ 7072/ 8082]
loss:       20.734505  [ 7712/ 8082]

Avg training loss:       21.368002
Avg validation loss:                  21.006781
Avg validation habitat loss:           9.037350
Avg validation movement loss:         11.969433
Movement learning rate:         [1e-05]
Habitat learning rate:          [0.0001]
Validation loss decreased (21.085865 --> 21.006781).  Saving model ...

Epoch 7
-------------------------------
loss:       20.711094  [   32/ 8082]
loss:       20.795650  [  672/ 8082]
loss:       22.060495  [ 1312/ 8082]
loss:       21.569458  [ 1952/ 8082]
loss:       21.310856  [ 2592/ 8082]
loss:       21.400623  [ 3232/ 8082]
loss:       20.481079  [ 3872/ 8082]
loss:       21.412466  [ 4512/ 8082]
loss:       20.609871  [ 5152/ 8082]
loss:       21.441900  [ 5792/ 8082]
loss:       21.346174  [ 6432/ 8082]
loss:       20.842838  [ 7072/ 8082]
loss:       21.067366  [ 7712/ 8082]

Avg training loss:       21.317116
Avg validation loss:                  20.973610
Avg validation habitat loss:           9.043344
Avg validation movement loss:         11.930265
Movement learning rate:         [1e-05]
Habitat learning rate:          [0.0001]
Validation loss decreased (21.006781 --> 20.973610).  Saving model ...

Epoch 8
-------------------------------
loss:       21.454086  [   32/ 8082]
loss:       21.713417  [  672/ 8082]
loss:       22.085426  [ 1312/ 8082]
loss:       20.907196  [ 1952/ 8082]
loss:       21.070141  [ 2592/ 8082]
loss:       20.884434  [ 3232/ 8082]
loss:       21.188454  [ 3872/ 8082]
loss:       21.116657  [ 4512/ 8082]
loss:       21.897556  [ 5152/ 8082]
loss:       20.972239  [ 5792/ 8082]
loss:       21.931610  [ 6432/ 8082]
loss:       20.399981  [ 7072/ 8082]
loss:       20.078285  [ 7712/ 8082]

Avg training loss:       21.286648
Avg validation loss:                  20.924185
Avg validation habitat loss:           9.046171
Avg validation movement loss:         11.878015
Movement learning rate:         [1e-05]
Habitat learning rate:          [0.0001]
Validation loss decreased (20.973610 --> 20.924185).  Saving model ...

Epoch 9
-------------------------------
loss:       21.400911  [   32/ 8082]
loss:       21.863380  [  672/ 8082]
loss:       21.227234  [ 1312/ 8082]
loss:       20.729401  [ 1952/ 8082]
loss:       21.836586  [ 2592/ 8082]
loss:       20.710028  [ 3232/ 8082]
loss:       20.347727  [ 3872/ 8082]
loss:       21.351017  [ 4512/ 8082]
loss:       20.610641  [ 5152/ 8082]
loss:       21.459038  [ 5792/ 8082]
loss:       20.033573  [ 6432/ 8082]
loss:       21.909445  [ 7072/ 8082]
loss:       20.222166  [ 7712/ 8082]

Avg training loss:       21.259096
Avg validation loss:                  20.915016
Avg validation habitat loss:           9.046423
Avg validation movement loss:         11.868596
Movement learning rate:         [1e-05]
Habitat learning rate:          [0.0001]
Validation loss decreased (20.924185 --> 20.915016).  Saving model ...

Epoch 10
-------------------------------
loss:       20.958054  [   32/ 8082]
loss:       20.429680  [  672/ 8082]
loss:       20.913551  [ 1312/ 8082]
loss:       20.481876  [ 1952/ 8082]
loss:       21.408901  [ 2592/ 8082]
loss:       21.134224  [ 3232/ 8082]
loss:       20.760900  [ 3872/ 8082]
loss:       21.440830  [ 4512/ 8082]
loss:       21.488529  [ 5152/ 8082]
loss:       21.755604  [ 5792/ 8082]
loss:       20.790857  [ 6432/ 8082]
loss:       21.009857  [ 7072/ 8082]
loss:       21.048809  [ 7712/ 8082]

Avg training loss:       21.239704
Avg validation loss:                  20.902134
Avg validation habitat loss:           9.044733
Avg validation movement loss:         11.857397
Movement learning rate:         [1e-05]
Habitat learning rate:          [0.0001]
Validation loss decreased (20.915016 --> 20.902134).  Saving model ...

Epoch 11
-------------------------------
loss:       21.096672  [   32/ 8082]
loss:       20.399212  [  672/ 8082]
loss:       21.214495  [ 1312/ 8082]
loss:       21.236334  [ 1952/ 8082]
loss:       20.733944  [ 2592/ 8082]
loss:       21.345058  [ 3232/ 8082]
loss:       21.785999  [ 3872/ 8082]
loss:       21.050419  [ 4512/ 8082]
loss:       21.359404  [ 5152/ 8082]
loss:       20.748739  [ 5792/ 8082]
loss:       21.613403  [ 6432/ 8082]
loss:       21.788340  [ 7072/ 8082]
loss:       20.938786  [ 7712/ 8082]

Avg training loss:       21.227646
Avg validation loss:                  20.908264
Avg validation habitat loss:           9.045177
Avg validation movement loss:         11.863087
Movement learning rate:         [1e-05]
Habitat learning rate:          [0.0001]
EarlyStopping counter: 1 out of 15

Epoch 12
-------------------------------
loss:       21.036507  [   32/ 8082]
loss:       21.327248  [  672/ 8082]
loss:       20.278337  [ 1312/ 8082]
loss:       21.455257  [ 1952/ 8082]
loss:       21.334621  [ 2592/ 8082]
loss:       21.642715  [ 3232/ 8082]
loss:       21.558609  [ 3872/ 8082]
loss:       21.044666  [ 4512/ 8082]
loss:       21.344259  [ 5152/ 8082]
loss:       22.116482  [ 5792/ 8082]
loss:       21.511631  [ 6432/ 8082]
loss:       20.931982  [ 7072/ 8082]
loss:       21.954626  [ 7712/ 8082]

Avg training loss:       21.210945
Avg validation loss:                  20.867819
Avg validation habitat loss:           9.038441
Avg validation movement loss:         11.829380
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
Validation loss decreased (20.902134 --> 20.867819).  Saving model ...

Epoch 13
-------------------------------
loss:       21.788040  [   32/ 8082]
loss:       20.729189  [  672/ 8082]
loss:       22.053547  [ 1312/ 8082]
loss:       22.079924  [ 1952/ 8082]
loss:       20.847200  [ 2592/ 8082]
loss:       22.619478  [ 3232/ 8082]
loss:       21.114822  [ 3872/ 8082]
loss:       21.663542  [ 4512/ 8082]
loss:       21.553707  [ 5152/ 8082]
loss:       21.946461  [ 5792/ 8082]
loss:       21.476524  [ 6432/ 8082]
loss:       21.598766  [ 7072/ 8082]
loss:       21.414360  [ 7712/ 8082]

Avg training loss:       21.193920
Avg validation loss:                  20.848740
Avg validation habitat loss:           9.037801
Avg validation movement loss:         11.810939
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
Validation loss decreased (20.867819 --> 20.848740).  Saving model ...

Epoch 14
-------------------------------
loss:       20.528849  [   32/ 8082]
loss:       20.863527  [  672/ 8082]
loss:       20.738653  [ 1312/ 8082]
loss:       21.296501  [ 1952/ 8082]
loss:       20.686237  [ 2592/ 8082]
loss:       21.790766  [ 3232/ 8082]
loss:       21.660587  [ 3872/ 8082]
loss:       20.720257  [ 4512/ 8082]
loss:       20.333691  [ 5152/ 8082]
loss:       21.582300  [ 5792/ 8082]
loss:       21.665226  [ 6432/ 8082]
loss:       21.233349  [ 7072/ 8082]
loss:       20.756908  [ 7712/ 8082]

Avg training loss:       21.196007
Avg validation loss:                  20.861876
Avg validation habitat loss:           9.036444
Avg validation movement loss:         11.825429
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
EarlyStopping counter: 1 out of 15

Epoch 15
-------------------------------
loss:       22.023518  [   32/ 8082]
loss:       21.034973  [  672/ 8082]
loss:       21.153305  [ 1312/ 8082]
loss:       21.196768  [ 1952/ 8082]
loss:       21.342583  [ 2592/ 8082]
loss:       21.717873  [ 3232/ 8082]
loss:       20.382698  [ 3872/ 8082]
loss:       21.448109  [ 4512/ 8082]
loss:       20.241938  [ 5152/ 8082]
loss:       20.688532  [ 5792/ 8082]
loss:       21.202389  [ 6432/ 8082]
loss:       22.816853  [ 7072/ 8082]
loss:       20.808300  [ 7712/ 8082]

Avg training loss:       21.186493
Avg validation loss:                  20.829815
Avg validation habitat loss:           9.033096
Avg validation movement loss:         11.796720
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
Validation loss decreased (20.848740 --> 20.829815).  Saving model ...

Epoch 16
-------------------------------
loss:       21.150539  [   32/ 8082]
loss:       20.694864  [  672/ 8082]
loss:       20.638567  [ 1312/ 8082]
loss:       20.713860  [ 1952/ 8082]
loss:       20.920170  [ 2592/ 8082]
loss:       21.498886  [ 3232/ 8082]
loss:       20.461460  [ 3872/ 8082]
loss:       21.459799  [ 4512/ 8082]
loss:       21.617128  [ 5152/ 8082]
loss:       21.100677  [ 5792/ 8082]
loss:       21.362625  [ 6432/ 8082]
loss:       21.115896  [ 7072/ 8082]
loss:       21.681713  [ 7712/ 8082]

Avg training loss:       21.179407
Avg validation loss:                  20.816412
Avg validation habitat loss:           9.031187
Avg validation movement loss:         11.785226
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
Validation loss decreased (20.829815 --> 20.816412).  Saving model ...

Epoch 17
-------------------------------
loss:       21.203239  [   32/ 8082]
loss:       21.881290  [  672/ 8082]
loss:       20.080597  [ 1312/ 8082]
loss:       20.878534  [ 1952/ 8082]
loss:       21.071144  [ 2592/ 8082]
loss:       22.074572  [ 3232/ 8082]
loss:       20.642521  [ 3872/ 8082]
loss:       20.214436  [ 4512/ 8082]
loss:       21.128044  [ 5152/ 8082]
loss:       21.824587  [ 5792/ 8082]
loss:       21.733837  [ 6432/ 8082]
loss:       21.076471  [ 7072/ 8082]
loss:       21.004824  [ 7712/ 8082]

Avg training loss:       21.171009
Avg validation loss:                  20.818497
Avg validation habitat loss:           9.029499
Avg validation movement loss:         11.788999
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
EarlyStopping counter: 1 out of 15

Epoch 18
-------------------------------
loss:       20.757704  [   32/ 8082]
loss:       19.870464  [  672/ 8082]
loss:       20.571629  [ 1312/ 8082]
loss:       20.682262  [ 1952/ 8082]
loss:       21.916908  [ 2592/ 8082]
loss:       21.891731  [ 3232/ 8082]
loss:       21.072895  [ 3872/ 8082]
loss:       20.267975  [ 4512/ 8082]
loss:       21.196260  [ 5152/ 8082]
loss:       20.240715  [ 5792/ 8082]
loss:       22.083355  [ 6432/ 8082]
loss:       21.363201  [ 7072/ 8082]
loss:       20.723520  [ 7712/ 8082]

Avg training loss:       21.170866
Avg validation loss:                  20.824989
Avg validation habitat loss:           9.030242
Avg validation movement loss:         11.794747
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
EarlyStopping counter: 2 out of 15

Epoch 19
-------------------------------
loss:       20.547075  [   32/ 8082]
loss:       20.789583  [  672/ 8082]
loss:       20.778284  [ 1312/ 8082]
loss:       20.744684  [ 1952/ 8082]
loss:       21.462093  [ 2592/ 8082]
loss:       22.259712  [ 3232/ 8082]
loss:       20.948105  [ 3872/ 8082]
loss:       21.170586  [ 4512/ 8082]
loss:       20.006477  [ 5152/ 8082]
loss:       21.592457  [ 5792/ 8082]
loss:       21.636990  [ 6432/ 8082]
loss:       20.897472  [ 7072/ 8082]
loss:       20.539326  [ 7712/ 8082]

Avg training loss:       21.159136
Avg validation loss:                  20.816593
Avg validation habitat loss:           9.032276
Avg validation movement loss:         11.784320
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
EarlyStopping counter: 3 out of 15

Epoch 20
-------------------------------
loss:       21.492580  [   32/ 8082]
loss:       20.724102  [  672/ 8082]
loss:       21.460249  [ 1312/ 8082]
loss:       21.157827  [ 1952/ 8082]
loss:       20.799656  [ 2592/ 8082]
loss:       21.152050  [ 3232/ 8082]
loss:       21.262957  [ 3872/ 8082]
loss:       20.597225  [ 4512/ 8082]
loss:       20.945023  [ 5152/ 8082]
loss:       21.101837  [ 5792/ 8082]
loss:       21.255554  [ 6432/ 8082]
loss:       21.472183  [ 7072/ 8082]
loss:       21.319832  [ 7712/ 8082]

Avg training loss:       21.157312
Avg validation loss:                  20.797131
Avg validation habitat loss:           9.027948
Avg validation movement loss:         11.769186
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
Validation loss decreased (20.816412 --> 20.797131).  Saving model ...

Epoch 21
-------------------------------
loss:       21.497169  [   32/ 8082]
loss:       20.091866  [  672/ 8082]
loss:       20.546898  [ 1312/ 8082]
loss:       20.548870  [ 1952/ 8082]
loss:       21.073519  [ 2592/ 8082]
loss:       21.380587  [ 3232/ 8082]
loss:       20.156319  [ 3872/ 8082]
loss:       21.297672  [ 4512/ 8082]
loss:       20.537281  [ 5152/ 8082]
loss:       19.831913  [ 5792/ 8082]
loss:       19.995611  [ 6432/ 8082]
loss:       21.524014  [ 7072/ 8082]
loss:       20.353510  [ 7712/ 8082]

Avg training loss:       21.152542
Avg validation loss:                  20.815090
Avg validation habitat loss:           9.027156
Avg validation movement loss:         11.787934
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
EarlyStopping counter: 1 out of 15

Epoch 22
-------------------------------
loss:       21.215414  [   32/ 8082]
loss:       21.545546  [  672/ 8082]
loss:       21.544437  [ 1312/ 8082]
loss:       20.946671  [ 1952/ 8082]
loss:       21.506306  [ 2592/ 8082]
loss:       20.425323  [ 3232/ 8082]
loss:       21.513355  [ 3872/ 8082]
loss:       20.467644  [ 4512/ 8082]
loss:       20.775734  [ 5152/ 8082]
loss:       22.148407  [ 5792/ 8082]
loss:       22.264874  [ 6432/ 8082]
loss:       21.231478  [ 7072/ 8082]
loss:       21.023224  [ 7712/ 8082]

Avg training loss:       21.146730
Avg validation loss:                  20.789963
Avg validation habitat loss:           9.027478
Avg validation movement loss:         11.762482
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
Validation loss decreased (20.797131 --> 20.789963).  Saving model ...

Epoch 23
-------------------------------
loss:       21.881798  [   32/ 8082]
loss:       20.770687  [  672/ 8082]
loss:       21.535299  [ 1312/ 8082]
loss:       20.559221  [ 1952/ 8082]
loss:       20.637714  [ 2592/ 8082]
loss:       20.155331  [ 3232/ 8082]
loss:       21.661261  [ 3872/ 8082]
loss:       21.339867  [ 4512/ 8082]
loss:       21.098503  [ 5152/ 8082]
loss:       21.415815  [ 5792/ 8082]
loss:       21.304445  [ 6432/ 8082]
loss:       22.567953  [ 7072/ 8082]
loss:       22.135626  [ 7712/ 8082]

Avg training loss:       21.146635
Avg validation loss:                  20.805067
Avg validation habitat loss:           9.028492
Avg validation movement loss:         11.776574
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
EarlyStopping counter: 1 out of 15

Epoch 24
-------------------------------
loss:       20.872395  [   32/ 8082]
loss:       20.697403  [  672/ 8082]
loss:       21.186888  [ 1312/ 8082]
loss:       20.876577  [ 1952/ 8082]
loss:       21.625328  [ 2592/ 8082]
loss:       20.267830  [ 3232/ 8082]
loss:       20.733440  [ 3872/ 8082]
loss:       21.001404  [ 4512/ 8082]
loss:       20.038193  [ 5152/ 8082]
loss:       21.685505  [ 5792/ 8082]
loss:       20.601254  [ 6432/ 8082]
loss:       21.231255  [ 7072/ 8082]
loss:       22.374275  [ 7712/ 8082]

Avg training loss:       21.140148
Avg validation loss:                  20.795481
Avg validation habitat loss:           9.026320
Avg validation movement loss:         11.769160
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
EarlyStopping counter: 2 out of 15

Epoch 25
-------------------------------
loss:       21.400166  [   32/ 8082]
loss:       20.420086  [  672/ 8082]
loss:       21.357122  [ 1312/ 8082]
loss:       20.841751  [ 1952/ 8082]
loss:       21.458828  [ 2592/ 8082]
loss:       20.734526  [ 3232/ 8082]
loss:       20.903400  [ 3872/ 8082]
loss:       21.872757  [ 4512/ 8082]
loss:       20.753246  [ 5152/ 8082]
loss:       20.245499  [ 5792/ 8082]
loss:       22.019331  [ 6432/ 8082]
loss:       21.230900  [ 7072/ 8082]
loss:       20.654531  [ 7712/ 8082]

Avg training loss:       21.138622
Avg validation loss:                  20.788698
Avg validation habitat loss:           9.027100
Avg validation movement loss:         11.761598
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
Validation loss decreased (20.789963 --> 20.788698).  Saving model ...

Epoch 26
-------------------------------
loss:       21.430428  [   32/ 8082]
loss:       20.878675  [  672/ 8082]
loss:       21.062071  [ 1312/ 8082]
loss:       20.910215  [ 1952/ 8082]
loss:       21.569189  [ 2592/ 8082]
loss:       20.719362  [ 3232/ 8082]
loss:       20.836317  [ 3872/ 8082]
loss:       21.302094  [ 4512/ 8082]
loss:       20.843605  [ 5152/ 8082]
loss:       21.386097  [ 5792/ 8082]
loss:       21.322710  [ 6432/ 8082]
loss:       20.824400  [ 7072/ 8082]
loss:       21.304245  [ 7712/ 8082]

Avg training loss:       21.135345
Avg validation loss:                  20.770597
Avg validation habitat loss:           9.029612
Avg validation movement loss:         11.740987
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
Validation loss decreased (20.788698 --> 20.770597).  Saving model ...

Epoch 27
-------------------------------
loss:       20.302856  [   32/ 8082]
loss:       21.906572  [  672/ 8082]
loss:       22.442581  [ 1312/ 8082]
loss:       22.169353  [ 1952/ 8082]
loss:       20.855629  [ 2592/ 8082]
loss:       20.862471  [ 3232/ 8082]
loss:       21.079117  [ 3872/ 8082]
loss:       21.846558  [ 4512/ 8082]
loss:       20.597416  [ 5152/ 8082]
loss:       21.325996  [ 5792/ 8082]
loss:       21.242905  [ 6432/ 8082]
loss:       22.024933  [ 7072/ 8082]
loss:       21.318340  [ 7712/ 8082]

Avg training loss:       21.135565
Avg validation loss:                  20.760530
Avg validation habitat loss:           9.028321
Avg validation movement loss:         11.732207
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
Validation loss decreased (20.770597 --> 20.760530).  Saving model ...

Epoch 28
-------------------------------
loss:       20.626581  [   32/ 8082]
loss:       22.088329  [  672/ 8082]
loss:       21.844406  [ 1312/ 8082]
loss:       21.127743  [ 1952/ 8082]
loss:       20.958706  [ 2592/ 8082]
loss:       21.240135  [ 3232/ 8082]
loss:       22.820835  [ 3872/ 8082]
loss:       21.115845  [ 4512/ 8082]
loss:       20.969172  [ 5152/ 8082]
loss:       20.997906  [ 5792/ 8082]
loss:       21.205708  [ 6432/ 8082]
loss:       21.276062  [ 7072/ 8082]
loss:       21.129511  [ 7712/ 8082]

Avg training loss:       21.131155
Avg validation loss:                  20.781105
Avg validation habitat loss:           9.028436
Avg validation movement loss:         11.752671
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
EarlyStopping counter: 1 out of 15

Epoch 29
-------------------------------
loss:       21.324337  [   32/ 8082]
loss:       21.249756  [  672/ 8082]
loss:       21.615227  [ 1312/ 8082]
loss:       20.040539  [ 1952/ 8082]
loss:       21.820065  [ 2592/ 8082]
loss:       20.980541  [ 3232/ 8082]
loss:       20.271732  [ 3872/ 8082]
loss:       21.416475  [ 4512/ 8082]
loss:       21.544765  [ 5152/ 8082]
loss:       21.252304  [ 5792/ 8082]
loss:       20.818727  [ 6432/ 8082]
loss:       20.915058  [ 7072/ 8082]
loss:       21.240482  [ 7712/ 8082]

Avg training loss:       21.117781
Avg validation loss:                  20.793734
Avg validation habitat loss:           9.028555
Avg validation movement loss:         11.765176
Movement learning rate:         [1e-05]
Habitat learning rate:          [1e-05]
EarlyStopping counter: 2 out of 15

Epoch 30
-------------------------------
loss:       21.820110  [   32/ 8082]
loss:       20.382082  [  672/ 8082]
loss:       20.078415  [ 1312/ 8082]
loss:       20.171513  [ 1952/ 8082]
loss:       21.684881  [ 2592/ 8082]
loss:       20.520363  [ 3232/ 8082]
loss:       20.990616  [ 3872/ 8082]
loss:       21.473301  [ 4512/ 8082]
loss:       22.159891  [ 5152/ 8082]
loss:       20.516306  [ 5792/ 8082]
loss:       20.767601  [ 6432/ 8082]
loss:       22.186560  [ 7072/ 8082]
loss:       20.912476  [ 7712/ 8082]

Avg training loss:       21.121313
Avg validation loss:                  20.767521
Avg validation habitat loss:           9.029516
Avg validation movement loss:         11.738002
Movement learning rate:         [1e-05]
Habitat learning rate:          [1.0000000000000002e-06]
EarlyStopping counter: 3 out of 15

Epoch 31
-------------------------------
loss:       21.437902  [   32/ 8082]
loss:       21.117199  [  672/ 8082]
loss:       20.584345  [ 1312/ 8082]
loss:       21.065020  [ 1952/ 8082]
loss:       20.349684  [ 2592/ 8082]
loss:       20.933506  [ 3232/ 8082]
loss:       20.635319  [ 3872/ 8082]
loss:       21.275661  [ 4512/ 8082]
loss:       21.781719  [ 5152/ 8082]
loss:       20.462719  [ 5792/ 8082]
loss:       21.778208  [ 6432/ 8082]
loss:       20.609955  [ 7072/ 8082]
loss:       22.015806  [ 7712/ 8082]

Avg training loss:       21.118505
Avg validation loss:                  20.784393
Avg validation habitat loss:           9.028644
Avg validation movement loss:         11.755750
Movement learning rate:         [1e-05]
Habitat learning rate:          [1.0000000000000002e-06]
EarlyStopping counter: 4 out of 15

Epoch 32
-------------------------------
loss:       22.343058  [   32/ 8082]
loss:       21.422207  [  672/ 8082]
loss:       21.509037  [ 1312/ 8082]
loss:       20.849970  [ 1952/ 8082]
loss:       21.604668  [ 2592/ 8082]
loss:       20.841194  [ 3232/ 8082]
loss:       21.065008  [ 3872/ 8082]
loss:       20.930733  [ 4512/ 8082]
loss:       21.069557  [ 5152/ 8082]
loss:       20.454048  [ 5792/ 8082]
loss:       20.782249  [ 6432/ 8082]
loss:       20.741388  [ 7072/ 8082]
loss:       20.857525  [ 7712/ 8082]

Avg training loss:       21.114100
Avg validation loss:                  20.786816
Avg validation habitat loss:           9.026937
Avg validation movement loss:         11.759875
Movement learning rate:         [1e-05]
Habitat learning rate:          [1.0000000000000002e-06]
EarlyStopping counter: 5 out of 15

Epoch 33
-------------------------------
loss:       20.861525  [   32/ 8082]
loss:       21.250650  [  672/ 8082]
loss:       21.768620  [ 1312/ 8082]
loss:       22.034649  [ 1952/ 8082]
loss:       21.176085  [ 2592/ 8082]
loss:       21.546896  [ 3232/ 8082]
loss:       21.842052  [ 3872/ 8082]
loss:       20.757820  [ 4512/ 8082]
loss:       20.842663  [ 5152/ 8082]
loss:       22.104130  [ 5792/ 8082]
loss:       20.857515  [ 6432/ 8082]
loss:       21.619083  [ 7072/ 8082]
loss:       20.846825  [ 7712/ 8082]

Avg training loss:       21.115108
Avg validation loss:                  20.764582
Avg validation habitat loss:           9.029498
Avg validation movement loss:         11.735083
Movement learning rate:         [1.0000000000000002e-06]
Habitat learning rate:          [1.0000000000000002e-06]
EarlyStopping counter: 6 out of 15

Epoch 34
-------------------------------
loss:       20.288479  [   32/ 8082]
loss:       21.287521  [  672/ 8082]
loss:       21.729200  [ 1312/ 8082]
loss:       21.560881  [ 1952/ 8082]
loss:       20.850574  [ 2592/ 8082]
loss:       21.010719  [ 3232/ 8082]
loss:       21.940323  [ 3872/ 8082]
loss:       21.398573  [ 4512/ 8082]
loss:       20.784222  [ 5152/ 8082]
loss:       21.050602  [ 5792/ 8082]
loss:       21.449968  [ 6432/ 8082]
loss:       20.466709  [ 7072/ 8082]
loss:       21.046553  [ 7712/ 8082]

Avg training loss:       21.110161
Avg validation loss:                  20.765018
Avg validation habitat loss:           9.027593
Avg validation movement loss:         11.737428
Movement learning rate:         [1.0000000000000002e-06]
Habitat learning rate:          [1.0000000000000002e-06]
EarlyStopping counter: 7 out of 15

Epoch 35
-------------------------------
loss:       20.279297  [   32/ 8082]
loss:       20.959492  [  672/ 8082]
loss:       20.999603  [ 1312/ 8082]
loss:       20.784821  [ 1952/ 8082]
loss:       21.112902  [ 2592/ 8082]
loss:       19.672127  [ 3232/ 8082]
loss:       21.199701  [ 3872/ 8082]
loss:       20.662323  [ 4512/ 8082]
loss:       22.284149  [ 5152/ 8082]
loss:       21.463650  [ 5792/ 8082]
loss:       21.206123  [ 6432/ 8082]
loss:       22.159784  [ 7072/ 8082]
loss:       20.282604  [ 7712/ 8082]

Avg training loss:       21.109623
Avg validation loss:                  20.783741
Avg validation habitat loss:           9.026810
Avg validation movement loss:         11.756932
Movement learning rate:         [1.0000000000000002e-06]
Habitat learning rate:          [1.0000000000000002e-06]
EarlyStopping counter: 8 out of 15

Epoch 36
-------------------------------
loss:       21.121698  [   32/ 8082]
loss:       21.805082  [  672/ 8082]
loss:       20.060991  [ 1312/ 8082]
loss:       20.904667  [ 1952/ 8082]
loss:       20.950287  [ 2592/ 8082]
loss:       21.621799  [ 3232/ 8082]
loss:       21.658867  [ 3872/ 8082]
loss:       21.363186  [ 4512/ 8082]
loss:       20.292938  [ 5152/ 8082]
loss:       22.284636  [ 5792/ 8082]
loss:       21.797125  [ 6432/ 8082]
loss:       21.356319  [ 7072/ 8082]
loss:       20.655933  [ 7712/ 8082]

Avg training loss:       21.110859
Avg validation loss:                  20.761776
Avg validation habitat loss:           9.026420
Avg validation movement loss:         11.735357
Movement learning rate:         [1.0000000000000002e-06]
Habitat learning rate:          [1.0000000000000002e-07]
EarlyStopping counter: 9 out of 15

Epoch 37
-------------------------------
loss:       20.587643  [   32/ 8082]
loss:       21.220551  [  672/ 8082]
loss:       21.822920  [ 1312/ 8082]
loss:       21.134798  [ 1952/ 8082]
loss:       20.682344  [ 2592/ 8082]
loss:       21.379015  [ 3232/ 8082]
loss:       20.234079  [ 3872/ 8082]
loss:       21.346920  [ 4512/ 8082]
loss:       20.513264  [ 5152/ 8082]
loss:       21.536575  [ 5792/ 8082]
loss:       20.888409  [ 6432/ 8082]
loss:       20.758640  [ 7072/ 8082]
loss:       20.939220  [ 7712/ 8082]

Avg training loss:       21.111229
Avg validation loss:                  20.776789
Avg validation habitat loss:           9.028404
Avg validation movement loss:         11.748384
Movement learning rate:         [1.0000000000000002e-06]
Habitat learning rate:          [1.0000000000000002e-07]
EarlyStopping counter: 10 out of 15

Epoch 38
-------------------------------
loss:       21.311371  [   32/ 8082]
loss:       20.956961  [  672/ 8082]
loss:       20.587088  [ 1312/ 8082]
loss:       20.941013  [ 1952/ 8082]
loss:       19.801998  [ 2592/ 8082]
loss:       21.290096  [ 3232/ 8082]
loss:       21.000519  [ 3872/ 8082]
loss:       22.737528  [ 4512/ 8082]
loss:       21.201664  [ 5152/ 8082]
loss:       21.008396  [ 5792/ 8082]
loss:       21.519379  [ 6432/ 8082]
loss:       22.196375  [ 7072/ 8082]
loss:       20.736612  [ 7712/ 8082]

Avg training loss:       21.113522
Avg validation loss:                  20.775139
Avg validation habitat loss:           9.029838
Avg validation movement loss:         11.745304
Movement learning rate:         [1.0000000000000002e-06]
Habitat learning rate:          [1.0000000000000002e-07]
EarlyStopping counter: 11 out of 15

Epoch 39
-------------------------------
loss:       21.876328  [   32/ 8082]
loss:       20.419033  [  672/ 8082]
loss:       22.091740  [ 1312/ 8082]
loss:       20.317142  [ 1952/ 8082]
loss:       20.854733  [ 2592/ 8082]
loss:       20.268436  [ 3232/ 8082]
loss:       20.823586  [ 3872/ 8082]
loss:       21.214470  [ 4512/ 8082]
loss:       22.452995  [ 5152/ 8082]
loss:       21.880571  [ 5792/ 8082]
loss:       20.642220  [ 6432/ 8082]
loss:       20.872694  [ 7072/ 8082]
loss:       21.650322  [ 7712/ 8082]

Avg training loss:       21.104427
Avg validation loss:                  20.759424
Avg validation habitat loss:           9.026233
Avg validation movement loss:         11.733191
Movement learning rate:         [1.0000000000000002e-07]
Habitat learning rate:          [1.0000000000000002e-07]
Validation loss decreased (20.760530 --> 20.759424).  Saving model ...

Epoch 40
-------------------------------
loss:       21.736853  [   32/ 8082]
loss:       20.840431  [  672/ 8082]
loss:       20.314442  [ 1312/ 8082]
loss:       21.004602  [ 1952/ 8082]
loss:       20.445990  [ 2592/ 8082]
loss:       21.628838  [ 3232/ 8082]
loss:       21.115950  [ 3872/ 8082]
loss:       21.102341  [ 4512/ 8082]
loss:       21.647709  [ 5152/ 8082]
loss:       20.956989  [ 5792/ 8082]
loss:       22.170771  [ 6432/ 8082]
loss:       22.143444  [ 7072/ 8082]
loss:       20.541891  [ 7712/ 8082]

Avg training loss:       21.106998
Avg validation loss:                  20.773697
Avg validation habitat loss:           9.029427
Avg validation movement loss:         11.744270
Movement learning rate:         [1.0000000000000002e-07]
Habitat learning rate:          [1.0000000000000002e-07]
EarlyStopping counter: 1 out of 15

Epoch 41
-------------------------------
loss:       21.053188  [   32/ 8082]
loss:       20.294916  [  672/ 8082]
loss:       21.261162  [ 1312/ 8082]
loss:       20.793102  [ 1952/ 8082]
loss:       20.550575  [ 2592/ 8082]
loss:       20.478981  [ 3232/ 8082]
loss:       21.028191  [ 3872/ 8082]
loss:       21.505920  [ 4512/ 8082]
loss:       20.963558  [ 5152/ 8082]
loss:       21.153084  [ 5792/ 8082]
loss:       21.394444  [ 6432/ 8082]
loss:       20.944214  [ 7072/ 8082]
loss:       21.471575  [ 7712/ 8082]

Avg training loss:       21.112028
Avg validation loss:                  20.774817
Avg validation habitat loss:           9.028699
Avg validation movement loss:         11.746118
Movement learning rate:         [1.0000000000000002e-07]
Habitat learning rate:          [1.0000000000000002e-07]
EarlyStopping counter: 2 out of 15

Epoch 42
-------------------------------
loss:       21.156864  [   32/ 8082]
loss:       20.018555  [  672/ 8082]
loss:       21.987658  [ 1312/ 8082]
loss:       21.277771  [ 1952/ 8082]
loss:       21.080284  [ 2592/ 8082]
loss:       21.003010  [ 3232/ 8082]
loss:       20.559223  [ 3872/ 8082]
loss:       21.242886  [ 4512/ 8082]
loss:       20.538317  [ 5152/ 8082]
loss:       21.429745  [ 5792/ 8082]
loss:       21.864704  [ 6432/ 8082]
loss:       20.961948  [ 7072/ 8082]
loss:       21.007004  [ 7712/ 8082]

Avg training loss:       21.108147
Avg validation loss:                  20.778454
Avg validation habitat loss:           9.028488
Avg validation movement loss:         11.749967
Movement learning rate:         [1.0000000000000002e-07]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 3 out of 15

Epoch 43
-------------------------------
loss:       20.276619  [   32/ 8082]
loss:       20.783964  [  672/ 8082]
loss:       21.657917  [ 1312/ 8082]
loss:       21.723740  [ 1952/ 8082]
loss:       20.786205  [ 2592/ 8082]
loss:       21.333647  [ 3232/ 8082]
loss:       21.763710  [ 3872/ 8082]
loss:       20.915974  [ 4512/ 8082]
loss:       20.391705  [ 5152/ 8082]
loss:       21.182034  [ 5792/ 8082]
loss:       21.093012  [ 6432/ 8082]
loss:       21.890554  [ 7072/ 8082]
loss:       21.857841  [ 7712/ 8082]

Avg training loss:       21.115002
Avg validation loss:                  20.782640
Avg validation habitat loss:           9.028945
Avg validation movement loss:         11.753695
Movement learning rate:         [1.0000000000000002e-07]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 4 out of 15

Epoch 44
-------------------------------
loss:       20.126511  [   32/ 8082]
loss:       21.423092  [  672/ 8082]
loss:       20.504999  [ 1312/ 8082]
loss:       21.392654  [ 1952/ 8082]
loss:       21.174973  [ 2592/ 8082]
loss:       20.999424  [ 3232/ 8082]
loss:       21.461527  [ 3872/ 8082]
loss:       20.943792  [ 4512/ 8082]
loss:       21.129875  [ 5152/ 8082]
loss:       21.391418  [ 5792/ 8082]
loss:       20.764061  [ 6432/ 8082]
loss:       19.966536  [ 7072/ 8082]
loss:       20.792652  [ 7712/ 8082]

Avg training loss:       21.111006
Avg validation loss:                  20.783051
Avg validation habitat loss:           9.027270
Avg validation movement loss:         11.755780
Movement learning rate:         [1.0000000000000002e-07]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 5 out of 15

Epoch 45
-------------------------------
loss:       20.755146  [   32/ 8082]
loss:       21.586885  [  672/ 8082]
loss:       20.459595  [ 1312/ 8082]
loss:       21.361931  [ 1952/ 8082]
loss:       20.014366  [ 2592/ 8082]
loss:       21.905636  [ 3232/ 8082]
loss:       21.608587  [ 3872/ 8082]
loss:       21.900686  [ 4512/ 8082]
loss:       21.148649  [ 5152/ 8082]
loss:       22.457621  [ 5792/ 8082]
loss:       21.671501  [ 6432/ 8082]
loss:       21.282991  [ 7072/ 8082]
loss:       22.441212  [ 7712/ 8082]

Avg training loss:       21.109976
Avg validation loss:                  20.778486
Avg validation habitat loss:           9.025890
Avg validation movement loss:         11.752597
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 6 out of 15

Epoch 46
-------------------------------
loss:       20.833817  [   32/ 8082]
loss:       20.552261  [  672/ 8082]
loss:       21.360552  [ 1312/ 8082]
loss:       21.143639  [ 1952/ 8082]
loss:       20.412758  [ 2592/ 8082]
loss:       20.493855  [ 3232/ 8082]
loss:       20.962029  [ 3872/ 8082]
loss:       20.957211  [ 4512/ 8082]
loss:       21.096090  [ 5152/ 8082]
loss:       20.204882  [ 5792/ 8082]
loss:       20.266146  [ 6432/ 8082]
loss:       20.215599  [ 7072/ 8082]
loss:       20.816164  [ 7712/ 8082]

Avg training loss:       21.109566
Avg validation loss:                  20.764332
Avg validation habitat loss:           9.030427
Avg validation movement loss:         11.733905
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 7 out of 15

Epoch 47
-------------------------------
loss:       20.373940  [   32/ 8082]
loss:       21.585772  [  672/ 8082]
loss:       20.756474  [ 1312/ 8082]
loss:       20.826685  [ 1952/ 8082]
loss:       21.001717  [ 2592/ 8082]
loss:       20.707537  [ 3232/ 8082]
loss:       21.680599  [ 3872/ 8082]
loss:       20.417879  [ 4512/ 8082]
loss:       21.324600  [ 5152/ 8082]
loss:       21.920219  [ 5792/ 8082]
loss:       21.690548  [ 6432/ 8082]
loss:       20.678190  [ 7072/ 8082]
loss:       21.010994  [ 7712/ 8082]

Avg training loss:       21.113253
Avg validation loss:                  20.763006
Avg validation habitat loss:           9.029927
Avg validation movement loss:         11.733077
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 8 out of 15

Epoch 48
-------------------------------
loss:       20.036011  [   32/ 8082]
loss:       20.569271  [  672/ 8082]
loss:       21.323219  [ 1312/ 8082]
loss:       20.030180  [ 1952/ 8082]
loss:       21.612413  [ 2592/ 8082]
loss:       20.307692  [ 3232/ 8082]
loss:       21.452917  [ 3872/ 8082]
loss:       21.396126  [ 4512/ 8082]
loss:       21.374596  [ 5152/ 8082]
loss:       21.264673  [ 5792/ 8082]
loss:       20.684439  [ 6432/ 8082]
loss:       21.023285  [ 7072/ 8082]
loss:       21.543339  [ 7712/ 8082]

Avg training loss:       21.114367
Avg validation loss:                  20.768066
Avg validation habitat loss:           9.028212
Avg validation movement loss:         11.739855
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 9 out of 15

Epoch 49
-------------------------------
loss:       20.997339  [   32/ 8082]
loss:       21.919893  [  672/ 8082]
loss:       20.857544  [ 1312/ 8082]
loss:       20.710470  [ 1952/ 8082]
loss:       20.649572  [ 2592/ 8082]
loss:       21.327763  [ 3232/ 8082]
loss:       20.981363  [ 3872/ 8082]
loss:       21.350046  [ 4512/ 8082]
loss:       21.919949  [ 5152/ 8082]
loss:       21.537434  [ 5792/ 8082]
loss:       20.947874  [ 6432/ 8082]
loss:       20.415073  [ 7072/ 8082]
loss:       21.466240  [ 7712/ 8082]

Avg training loss:       21.108202
Avg validation loss:                  20.790010
Avg validation habitat loss:           9.029243
Avg validation movement loss:         11.760770
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 10 out of 15

Epoch 50
-------------------------------
loss:       21.044529  [   32/ 8082]
loss:       20.704817  [  672/ 8082]
loss:       22.084221  [ 1312/ 8082]
loss:       22.019882  [ 1952/ 8082]
loss:       21.799543  [ 2592/ 8082]
loss:       21.718569  [ 3232/ 8082]
loss:       21.682419  [ 3872/ 8082]
loss:       20.722166  [ 4512/ 8082]
loss:       21.392447  [ 5152/ 8082]
loss:       20.025097  [ 5792/ 8082]
loss:       20.765682  [ 6432/ 8082]
loss:       20.803421  [ 7072/ 8082]
loss:       20.800808  [ 7712/ 8082]

Avg training loss:       21.109673
Avg validation loss:                  20.783222
Avg validation habitat loss:           9.032985
Avg validation movement loss:         11.750238
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 11 out of 15

Epoch 51
-------------------------------
loss:       20.257008  [   32/ 8082]
loss:       20.532137  [  672/ 8082]
loss:       21.723961  [ 1312/ 8082]
loss:       20.733177  [ 1952/ 8082]
loss:       20.488813  [ 2592/ 8082]
loss:       21.372883  [ 3232/ 8082]
loss:       21.451153  [ 3872/ 8082]
loss:       21.253405  [ 4512/ 8082]
loss:       21.505219  [ 5152/ 8082]
loss:       20.761768  [ 5792/ 8082]
loss:       20.405777  [ 6432/ 8082]
loss:       22.343805  [ 7072/ 8082]
loss:       20.622820  [ 7712/ 8082]

Avg training loss:       21.109854
Avg validation loss:                  20.760973
Avg validation habitat loss:           9.028950
Avg validation movement loss:         11.732022
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 12 out of 15

Epoch 52
-------------------------------
loss:       20.756983  [   32/ 8082]
loss:       20.882374  [  672/ 8082]
loss:       20.980854  [ 1312/ 8082]
loss:       21.705879  [ 1952/ 8082]
loss:       20.807159  [ 2592/ 8082]
loss:       21.933815  [ 3232/ 8082]
loss:       20.952028  [ 3872/ 8082]
loss:       22.038109  [ 4512/ 8082]
loss:       22.093548  [ 5152/ 8082]
loss:       21.470623  [ 5792/ 8082]
loss:       21.565010  [ 6432/ 8082]
loss:       20.713081  [ 7072/ 8082]
loss:       21.439686  [ 7712/ 8082]

Avg training loss:       21.109364
Avg validation loss:                  20.768993
Avg validation habitat loss:           9.028571
Avg validation movement loss:         11.740425
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 13 out of 15

Epoch 53
-------------------------------
loss:       21.685266  [   32/ 8082]
loss:       22.062162  [  672/ 8082]
loss:       21.183498  [ 1312/ 8082]
loss:       21.420483  [ 1952/ 8082]
loss:       20.503111  [ 2592/ 8082]
loss:       20.824589  [ 3232/ 8082]
loss:       21.013004  [ 3872/ 8082]
loss:       21.812647  [ 4512/ 8082]
loss:       21.492619  [ 5152/ 8082]
loss:       20.997688  [ 5792/ 8082]
loss:       21.130159  [ 6432/ 8082]
loss:       21.188469  [ 7072/ 8082]
loss:       21.332197  [ 7712/ 8082]

Avg training loss:       21.111416
Avg validation loss:                  20.757061
Avg validation habitat loss:           9.026418
Avg validation movement loss:         11.730642
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
Validation loss decreased (20.759424 --> 20.757061).  Saving model ...

Epoch 54
-------------------------------
loss:       21.341465  [   32/ 8082]
loss:       21.986927  [  672/ 8082]
loss:       21.042246  [ 1312/ 8082]
loss:       21.393723  [ 1952/ 8082]
loss:       20.944571  [ 2592/ 8082]
loss:       20.909498  [ 3232/ 8082]
loss:       20.689037  [ 3872/ 8082]
loss:       21.450266  [ 4512/ 8082]
loss:       20.910793  [ 5152/ 8082]
loss:       21.407951  [ 5792/ 8082]
loss:       20.452225  [ 6432/ 8082]
loss:       20.545444  [ 7072/ 8082]
loss:       21.031971  [ 7712/ 8082]

Avg training loss:       21.107767
Avg validation loss:                  20.775972
Avg validation habitat loss:           9.030288
Avg validation movement loss:         11.745685
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 1 out of 15

Epoch 55
-------------------------------
loss:       20.788361  [   32/ 8082]
loss:       20.858458  [  672/ 8082]
loss:       22.218189  [ 1312/ 8082]
loss:       22.717186  [ 1952/ 8082]
loss:       20.757273  [ 2592/ 8082]
loss:       21.272762  [ 3232/ 8082]
loss:       20.271229  [ 3872/ 8082]
loss:       20.457176  [ 4512/ 8082]
loss:       20.745560  [ 5152/ 8082]
loss:       21.188873  [ 5792/ 8082]
loss:       20.985039  [ 6432/ 8082]
loss:       20.731689  [ 7072/ 8082]
loss:       20.740452  [ 7712/ 8082]

Avg training loss:       21.109037
Avg validation loss:                  20.759003
Avg validation habitat loss:           9.027705
Avg validation movement loss:         11.731300
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 2 out of 15

Epoch 56
-------------------------------
loss:       21.623131  [   32/ 8082]
loss:       20.550379  [  672/ 8082]
loss:       20.471752  [ 1312/ 8082]
loss:       21.106680  [ 1952/ 8082]
loss:       21.489708  [ 2592/ 8082]
loss:       21.512596  [ 3232/ 8082]
loss:       20.833183  [ 3872/ 8082]
loss:       20.972534  [ 4512/ 8082]
loss:       21.607494  [ 5152/ 8082]
loss:       19.893278  [ 5792/ 8082]
loss:       20.395542  [ 6432/ 8082]
loss:       21.356148  [ 7072/ 8082]
loss:       22.237072  [ 7712/ 8082]

Avg training loss:       21.110117
Avg validation loss:                  20.770212
Avg validation habitat loss:           9.030463
Avg validation movement loss:         11.739747
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 3 out of 15

Epoch 57
-------------------------------
loss:       21.471134  [   32/ 8082]
loss:       20.403673  [  672/ 8082]
loss:       21.805651  [ 1312/ 8082]
loss:       21.556252  [ 1952/ 8082]
loss:       21.516830  [ 2592/ 8082]
loss:       21.764790  [ 3232/ 8082]
loss:       20.962631  [ 3872/ 8082]
loss:       22.426086  [ 4512/ 8082]
loss:       21.106247  [ 5152/ 8082]
loss:       22.032450  [ 5792/ 8082]
loss:       20.856766  [ 6432/ 8082]
loss:       21.661938  [ 7072/ 8082]
loss:       20.319122  [ 7712/ 8082]

Avg training loss:       21.106272
Avg validation loss:                  20.787678
Avg validation habitat loss:           9.027486
Avg validation movement loss:         11.760192
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 4 out of 15

Epoch 58
-------------------------------
loss:       21.594254  [   32/ 8082]
loss:       20.480169  [  672/ 8082]
loss:       21.706699  [ 1312/ 8082]
loss:       20.566048  [ 1952/ 8082]
loss:       21.498877  [ 2592/ 8082]
loss:       21.729258  [ 3232/ 8082]
loss:       20.862799  [ 3872/ 8082]
loss:       20.733498  [ 4512/ 8082]
loss:       20.754040  [ 5152/ 8082]
loss:       21.192673  [ 5792/ 8082]
loss:       20.812071  [ 6432/ 8082]
loss:       21.474667  [ 7072/ 8082]
loss:       21.143600  [ 7712/ 8082]

Avg training loss:       21.111210
Avg validation loss:                  20.787897
Avg validation habitat loss:           9.029156
Avg validation movement loss:         11.758740
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 5 out of 15

Epoch 59
-------------------------------
loss:       21.229256  [   32/ 8082]
loss:       22.690159  [  672/ 8082]
loss:       21.201874  [ 1312/ 8082]
loss:       20.482384  [ 1952/ 8082]
loss:       20.652485  [ 2592/ 8082]
loss:       21.737740  [ 3232/ 8082]
loss:       21.341522  [ 3872/ 8082]
loss:       21.518597  [ 4512/ 8082]
loss:       20.663443  [ 5152/ 8082]
loss:       22.043026  [ 5792/ 8082]
loss:       20.103086  [ 6432/ 8082]
loss:       20.547050  [ 7072/ 8082]
loss:       21.169407  [ 7712/ 8082]

Avg training loss:       21.111944
Avg validation loss:                  20.758141
Avg validation habitat loss:           9.027023
Avg validation movement loss:         11.731115
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 6 out of 15

Epoch 60
-------------------------------
loss:       20.931267  [   32/ 8082]
loss:       20.691107  [  672/ 8082]
loss:       20.470964  [ 1312/ 8082]
loss:       21.566435  [ 1952/ 8082]
loss:       20.457657  [ 2592/ 8082]
loss:       21.211708  [ 3232/ 8082]
loss:       21.773178  [ 3872/ 8082]
loss:       20.698442  [ 4512/ 8082]
loss:       20.305294  [ 5152/ 8082]
loss:       20.531590  [ 5792/ 8082]
loss:       21.057816  [ 6432/ 8082]
loss:       20.357700  [ 7072/ 8082]
loss:       22.122181  [ 7712/ 8082]

Avg training loss:       21.110146
Avg validation loss:                  20.776733
Avg validation habitat loss:           9.028346
Avg validation movement loss:         11.748386
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 7 out of 15

Epoch 61
-------------------------------
loss:       21.237446  [   32/ 8082]
loss:       20.303053  [  672/ 8082]
loss:       20.483995  [ 1312/ 8082]
loss:       21.131941  [ 1952/ 8082]
loss:       21.023201  [ 2592/ 8082]
loss:       21.809341  [ 3232/ 8082]
loss:       21.812557  [ 3872/ 8082]
loss:       20.312874  [ 4512/ 8082]
loss:       21.387962  [ 5152/ 8082]
loss:       22.553179  [ 5792/ 8082]
loss:       21.530895  [ 6432/ 8082]
loss:       20.653793  [ 7072/ 8082]
loss:       20.768784  [ 7712/ 8082]

Avg training loss:       21.114136
Avg validation loss:                  20.780287
Avg validation habitat loss:           9.029342
Avg validation movement loss:         11.750943
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 8 out of 15

Epoch 62
-------------------------------
loss:       20.616972  [   32/ 8082]
loss:       21.211975  [  672/ 8082]
loss:       20.891670  [ 1312/ 8082]
loss:       21.081400  [ 1952/ 8082]
loss:       21.075647  [ 2592/ 8082]
loss:       21.460037  [ 3232/ 8082]
loss:       21.184305  [ 3872/ 8082]
loss:       20.263901  [ 4512/ 8082]
loss:       20.499617  [ 5152/ 8082]
loss:       20.833942  [ 5792/ 8082]
loss:       21.430456  [ 6432/ 8082]
loss:       21.733582  [ 7072/ 8082]
loss:       21.636635  [ 7712/ 8082]

Avg training loss:       21.111383
Avg validation loss:                  20.758785
Avg validation habitat loss:           9.028771
Avg validation movement loss:         11.730015
Movement learning rate:         [1.0000000000000004e-08]
Habitat learning rate:          [1.0000000000000004e-08]
EarlyStopping counter: 9 out of 15

Epoch 63
-------------------------------
loss:       20.826603  [   32/ 8082]
loss:       21.042696  [  672/ 8082]
loss:       22.261703  [ 1312/ 8082]
loss:       20.870951  [ 1952/ 8082]
loss:       19.816666  [ 2592/ 8082]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[215], line 73
     70 skip_training = (t == 0)
     72 # 1. Run the training loop for one epoch using the training dataloader
---> 73 train_loop(dataloader_train, 
     74            model, 
     75            loss_fn, 
     76            optimisers, 
     77            skip_epoch0_training=skip_training)
     79 # 2. Evaluate model performance on the validation dataset
     80 model.eval()  # Switch to evaluation mode for proper layer behavior

Cell In[213], line 35, in train_loop(dataloader_train, model, loss_fn, optimisers, skip_epoch0_training)
     33 # Forward pass: compute the model output and loss
     34 with torch.set_grad_enabled(not skip_epoch0_training):
---> 35     outputs = model((x1, x2, x3))
     36     total_loss, habitat_loss, movement_loss = loss_fn(outputs, y,)
     38 epoch_loss += total_loss.detach()  # Use detach to prevent memory leaks

File /opt/miniconda3/envs/deepSSF/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /opt/miniconda3/envs/deepSSF/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

Cell In[197], line 64, in ConvJointModel.forward(self, x)
     61 output_movement = self.fcn_movement_all(scalar_inputs)
     63 # Transform the movement parameters into a grid, using bearing information
---> 64 output_movement = self.movement_grid_output(output_movement, bearing_x)
     66 # Combine (stack) habitat and movement outputs without merging them
     67 output = torch.stack((output_habitat, output_movement), dim=-1)

File /opt/miniconda3/envs/deepSSF/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /opt/miniconda3/envs/deepSSF/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

Cell In[195], line 98, in Params_to_Grid_Block_ChV.forward(self, x, bearing)
     95 gamma_weight2 = gamma_weights[1]
     97 # calculation of Gamma densities
---> 98 gamma_density_layer1 = self.gamma_density(self.distance_layer,
     99                                           gamma_shape1,
    100                                           gamma_scale1).to(device)
    102 gamma_density_layer2 = self.gamma_density(self.distance_layer,
    103                                           gamma_shape2,
    104                                           gamma_scale2).to(device)
    106 # combining both densities to create a mixture distribution using logsumexp

Cell In[195], line 39, in Params_to_Grid_Block_ChV.gamma_density(self, x, shape, scale)
     35 scale = scale.to(x.device)
     36 # return -1*torch.lgamma(shape) -shape*torch.log(scale) + (shape - 1)*torch.log(x) - x/scale
     37 
     38 # to account for change of variables
---> 39 return (-1*torch.lgamma(shape) -shape*torch.log(scale) + (shape - 1)*torch.log(x) - x/scale) - torch.log(x)

KeyboardInterrupt: 

Make a GIF of the training images

First, here’s a function to call to make a gif from a given directory.

Code
# Example sorting by the epoch number
def extract_index(filename):
    """
    Return the integer that follows 'index' in a filename, for use as a sort key.

    The filename is searched for the pattern ``index<digits>_``. When it is
    found, the digits are returned as an int; otherwise 0 is returned.
    Adjust the pattern to suit your own naming scheme.

    Parameters:
    - filename: File name (or full path) to search
    """
    import re

    found = re.search(r'index(\d+)_', filename)
    return int(found.group(1)) if found else 0

def create_gif(image_folder, output_filename, fps=10):
    """
    Build an animated GIF from the PNG images in a folder and display it.

    Frames are ordered by the integer that extract_index returns for each
    file. If the folder holds no PNG files, a message is printed and no GIF
    is written.

    Parameters:
    - image_folder: Path to the folder containing the PNG images
    - output_filename: Path of the GIF file to write
    - fps: Frames per second of the GIF (default 10)
    """
    # Collect the PNG paths, ordered by the index embedded in each file name
    png_pattern = os.path.join(image_folder, '*.png')
    image_paths = sorted(glob.glob(png_pattern), key=extract_index)

    # Bail out early when there is nothing to animate
    if not image_paths:
        print(f"No images found in {image_folder}")
        return

    # Load every frame into memory
    frames = [imageio.imread(path) for path in image_paths]

    # Write the GIF; loop=0 is handed to the writer
    # (NOTE(review): presumably "repeat indefinitely" - confirm against imageio docs)
    imageio.mimsave(output_filename, frames, fps=fps, loop=0)

    # Show the finished GIF inline
    display(Image(filename=output_filename))

    print(f"GIF created successfully: {output_filename}")

Create training GIF

Code
# Folder of PNG images to animate (create_gif reads every '*.png' in it)
image_folder =  f'{output_dir}/training_images'
# Path of the GIF to write. The name is built from values already defined in
# the session (buffalo_id, yday_t2_integer, hour_t2_integer, bearing_degrees,
# row, column), so those must exist before this cell is run.
output_filename = f'{output_dir}/training_gif_id{buffalo_id}_yday{yday_t2_integer}_hour{hour_t2_integer}_bearing{bearing_degrees}_next_r{row}_c{column}.gif'
# Build, save and display the GIF at 10 frames per second
create_gif(image_folder, output_filename, fps=10)
<IPython.core.display.Image object>
GIF created successfully: ../Python/outputs/model_training_S2/id2005_scalar_movement_2_2025-06-05/training_gif_id2005_yday122_hour16_bearing156_next_r46_c44.gif

Create loss GIF

Code
# Folder of PNG images to animate (create_gif reads every '*.png' in it)
image_folder =  f'{output_dir}/loss_images'
# Path of the GIF to write
output_filename = f'{output_dir}/loss_gif.gif'
# Build, save and display the GIF at 10 frames per second
create_gif(image_folder, output_filename, fps=10)
<IPython.core.display.Image object>
GIF created successfully: ../Python/outputs/model_training_S2/id2005_scalar_movement_2_2025-06-05/loss_gif.gif

Loading in previous models

Because we have just trained the model, its parameters are already stored in the model object. During training, however, the model weights were also saved to file, and these (or the weights of any other previously trained model) can be loaded instead.

The model parameters that are being loaded must match the model object that has been defined above. If the model object has changed, the model parameters will not be able to be loaded.

Code
# Display the path of the saved model weights (the file loaded in the next step)
path_save_weights
'../Python/outputs/model_training_S2/id2005_scalar_movement_2_2025-06-05/checkpoint_deepSSF_buffalo2005.pt'

If loading a previously trained model

Code
# To load weights from a different, previously saved file, uncomment the line
# below and edit the path
# path_save_weights = f'{output_dir}/checkpoint_deepSSF_buffalo2005_2025-04-01.pt'

# Load the saved parameters into the existing model object.
#   weights_only=True    : passed to torch.load (restricts what is unpickled)
#   map_location='cpu'   : tensors are loaded onto the CPU
# The saved parameters must match the architecture of `model`.
model.load_state_dict(torch.load(path_save_weights,
                                 weights_only=True,
                                 map_location=torch.device('cpu')))
<All keys matched successfully>

View model outputs

Create a directory to save model outputs

Save the validation loss as a dataframe

Code
# Path of the CSV file that holds the validation loss for each epoch
filename_loss_csv = f'{output_dir}/deepSSF_val_loss_buffalo{buffalo_id}.csv'

# val_losses only exists when a model has been trained in this session.
# Referencing it otherwise raises NameError, and the saved CSV is read instead.
try:

    # Collapse the list of loss values into a single tensor
    val_losses_tensor = torch.tensor(val_losses)

    print("val_losses has been defined - storing as csv\n")

    # The number of stored validation losses is used as the number of epochs
    n_epochs = len(val_losses)
    print(f'Number of epochs: {n_epochs}')

    # Assemble the table column by column (epochs are numbered from 1)
    epoch_column = range(1, n_epochs + 1)
    loss_column = val_losses_tensor.detach().cpu().numpy()
    val_losses_df = pd.DataFrame({"epoch": epoch_column,
                                  "val_losses": loss_column})

    print(val_losses_df.head())

    # Write the table to disk so it can be reloaded in a later session
    val_losses_df.to_csv(filename_loss_csv, index=False)

# No model was trained in this session (e.g. weights were loaded from file)
except NameError:

    print("val_losses has not been defined - loading from saved csv\n")

    # Fall back to the table written by an earlier run
    val_losses_df = pd.read_csv(filename_loss_csv)
    print(val_losses_df.head())

    # One row per epoch in the saved table
    n_epochs = len(val_losses_df)
    print(f'\nNumber of epochs: {n_epochs}')
val_losses has been defined - storing as csv

Number of epochs: 62
   epoch  val_losses
0      1   22.567806
1      2   22.097454
2      3   21.604164
3      4   21.228951
4      5   21.085865
Code
# Path of the CSV file that holds the training losses
filename_train_loss_csv = f'{output_dir}/deepSSF_train_loss_buffalo{buffalo_id}.csv'

# train_losses only exists when a model has been trained in this session.
# Referencing it otherwise raises NameError, and the saved CSV is read instead.
try:

    # Collapse the list of loss values into a single tensor
    train_losses_tensor = torch.tensor(train_losses)

    print("train_losses has been defined - storing as csv\n")

    # Spread the stored training losses evenly over the epoch axis 1..n_epochs
    # (n_epochs comes from the validation-loss cell), so that both curves can
    # share one x-axis even if the two lists differ in length
    epoch_positions = np.linspace(1, n_epochs, len(train_losses))
    train_loss_values = train_losses_tensor.detach().cpu().numpy()
    train_losses_df = pd.DataFrame({"epoch": epoch_positions,
                                    "train_losses": train_loss_values})

    # Write the table to disk so it can be reloaded in a later session
    train_losses_df.to_csv(filename_train_loss_csv, index=False)

# No model was trained in this session (e.g. weights were loaded from file)
except NameError:

    print("train_losses has not been defined - loading from saved csv\n")

    # Fall back to the table written by an earlier run
    train_losses_df = pd.read_csv(filename_train_loss_csv)
train_losses has been defined - storing as csv

Plot the validation loss

Code
# Path of the PNG file the loss plot is saved to
filename_loss = f'{output_dir}/val_loss_buffalo{buffalo_id}.png'

# Draw the training and validation loss curves on the same axes, using the
# 'epoch' column of each dataframe as the x-axis
plt.plot(train_losses_df['epoch'], train_losses_df['train_losses'], label='Training Loss', color='blue')  # Training loss in blue
plt.plot(val_losses_df['epoch'], val_losses_df['val_losses'], label='Validation Loss', color='red')  # Validation loss in red
# NOTE(review): the title names only the validation loss although both curves are drawn
plt.title('Validation Losses')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()  # Legend distinguishes the two curves
# Save before plt.show() so that the saved file contains the drawn figure
plt.savefig(filename_loss, dpi=300, bbox_inches='tight')
plt.show()

Test model

Take some random samples from the test dataset and generate predictions for them. We loop through the samples (which are shuffled randomly), make predictions, and plot the results.

Code
# Draw five random samples from the test dataset, run the model on each one,
# and plot/save the inputs next to the predicted log-probability surfaces.
# The variables left over from the last iteration (x1, x2, x3, x_mask, y_mask,
# row, column, ...) are reused by other cells in this notebook.

# Set the model in evaluation mode
model.eval()

# Loop over five randomly chosen samples from the test dataset
for i in range(0, 5):

  # Pick a random index (sampling is with replacement across iterations)
  sample_number = np.random.randint(0, len(dataloader_test.dataset))
  print(f'Sample number: {sample_number}')

  # Fetch the three model inputs and the target for this sample
  x1, x2, x3, labels = dataloader_test.dataset[sample_number]

  # Add a batch dimension and make sure the tensors are on the CPU
  x1 = x1.unsqueeze(0).cpu()
  x2 = x2.unsqueeze(0).cpu()
  x3 = x3.unsqueeze(0).cpu()
  labels = labels.unsqueeze(0).cpu()

  # Pull out the scalars: x2 holds eight values (sin/cos pairs for hour and
  # day of the year), x3 holds the bearing
  hour_t2_sin1 = x2.detach().numpy()[0,0]
  hour_t2_cos1 = x2.detach().numpy()[0,1]
  hour_t2_sin2 = x2.detach().numpy()[0,2]
  hour_t2_cos2 = x2.detach().numpy()[0,3]
  yday_t2_sin1 = x2.detach().numpy()[0,4]
  yday_t2_cos1 = x2.detach().numpy()[0,5]
  yday_t2_sin2 = x2.detach().numpy()[0,6]
  yday_t2_cos2 = x2.detach().numpy()[0,7]
  bearing = x3.detach().numpy()[0,0]

  # Recover the hour from its first sin/cos pair
  hour_t2 = recover_hour(hour_t2_sin1, hour_t2_cos1)
  hour_t2_integer = int(hour_t2)  # Convert to integer
  print(f'Hour:                        {hour_t2_integer}')

  # Recover the day of the year from its first sin/cos pair
  yday_t2 = recover_yday(yday_t2_sin1, yday_t2_cos1)
  yday_t2_integer = int(yday_t2)  # Convert to integer
  print(f'Day of the year:             {yday_t2_integer}')

  # Convert the bearing from radians to degrees in [0, 360)
  bearing_degrees = np.degrees(bearing) % 360
  bearing_degrees = round(bearing_degrees, 1)  # Round to 1 decimal place
  bearing_degrees = int(bearing_degrees)  # Truncate to integer (used in filenames)
  print(f'Bearing (radians):           {bearing}')
  print(f'Bearing (degrees):           {bearing_degrees}')

  # Pull out the RGB layers for plotting: channels 1, 2 and 3 of x1
  # (labelled 'S2 B2', 'S2 B3' and 'S2 B4' in covariate_names in this notebook)
  blue_layer = x1.detach().cpu().numpy()[0,1,:,:]
  green_layer = x1.detach().cpu().numpy()[0,2,:,:]
  red_layer = x1.detach().cpu().numpy()[0,3,:,:]

  # Stack the layers in red, green, blue order along the last axis
  rgb_image_np = np.stack([red_layer, green_layer, blue_layer], axis=-1)

  # Min-max normalise to the range [0, 1] for display
  rgb_image_np = (rgb_image_np - rgb_image_np.min()) / (rgb_image_np.max() - rgb_image_np.min())

  # Find the coordinates of the target cell(s) equal to 1
  target = labels.detach().cpu().numpy()[0,:,:]
  coordinates = np.where(target == 1)

  # Take the first matching cell as the observed next step
  row, column = coordinates[0][0], coordinates[1][0]
  print(f"Next step is (row, column):  ({row}, {column})")


  # -------------------------------------------------------------------------
  # Run the model on the input data
  # -------------------------------------------------------------------------

  # Move input tensors to the model's device (GPU if available)
  x1 = x1.to(device)
  x2 = x2.to(device)
  x3 = x3.to(device)

  # The last axis of the output holds habitat (index 0) and movement (index 1)
  test = model((x1, x2, x3))
  # print(test.shape)

  # Extract the habitat channel (log scale) and exponentiate it
  hab_density = test.detach().cpu().numpy()[0, :, :, 0]
  hab_density_exp = np.exp(hab_density)
  # print(np.sum(hab_density_exp))  # Debug: check the sum of exponentiated values

  # Create masks for the edge cells, which would otherwise affect the
  #    colour scale of the plots
  x_mask = np.ones_like(hab_density)
  y_mask = np.ones_like(hab_density)

  # Mark the three outermost columns/rows on each side with -inf
  # (indices 3 and 98 are hard-coded; assumes a 101-cell grid - TODO confirm)
  x_mask[:, :3] = -np.inf
  x_mask[:, 98:] = -np.inf
  y_mask[:3, :] = -np.inf
  y_mask[98:, :] = -np.inf

  # Apply the masks to the habitat density (log scale) and exponentiated version;
  # multiplying by -inf makes the edge cells non-finite
  hab_density_mask = hab_density * x_mask * y_mask
  hab_density_exp_mask = hab_density_exp * x_mask * y_mask

  # Extract the movement channel (log scale) and exponentiate it
  move_density = test.detach().cpu().numpy()[0,:,:,1]
  move_density_exp = np.exp(move_density)

  # Apply the same masking strategy to movement densities
  move_density_mask = move_density * x_mask * y_mask
  move_density_exp_mask = move_density_exp * x_mask * y_mask

  # Compute the next-step density by adding habitat + movement (log-space)
  step_density = test[0, :, :, 0] + test[0, :, :, 1]
  step_density = step_density.detach().cpu().numpy()
  step_density_exp = np.exp(step_density)

  # Apply masks to the step densities (log and exponentiated)
  step_density_mask = step_density * x_mask * y_mask
  step_density_exp_mask = step_density_exp * x_mask * y_mask

  # -------------------------------------------------------------------------
  # Plot the RGB image, slope, habitat selection, and next-step density
  #   Change the panels to visualize different layers
  # -------------------------------------------------------------------------
  fig, axs = plt.subplots(2, 2, figsize=(10, 10))

  # Plot RGB
  im1 = axs[0, 0].imshow(rgb_image_np)
  axs[0, 0].set_title('Sentinel-2 RGB')

  # Plot slope (channel 12 of x1)
  im2 = axs[0, 1].imshow(x1.detach().cpu().numpy()[0,12,:,:], cmap='viridis')
  axs[0, 1].set_title('Slope')
  fig.colorbar(im2, ax=axs[0, 1], shrink=0.7)

  # Plot habitat selection (masked, log scale)
  im3 = axs[1, 0].imshow(hab_density_mask, cmap='viridis')
  axs[1, 0].set_title('Habitat selection log-probability')
  fig.colorbar(im3, ax=axs[1, 0], shrink=0.7)

  # # Movement density (change the axis and uncomment one of the other panels)
  # im3 = axs[1, 0].imshow(move_density_mask, cmap='viridis')
  # axs[1, 0].set_title('Movement log-probability')
  # fig.colorbar(im3, ax=axs[0, 1], shrink=0.7)

  # Next-step probability (masked, log scale)
  im4 = axs[1, 1].imshow(step_density_mask, cmap='viridis')
  axs[1, 1].set_title('Next-step log-probability')
  fig.colorbar(im4, ax=axs[1, 1], shrink=0.7)

  # Save the figure; the filename encodes the sample's id, time, bearing and next step
  filename_covs = f'{output_dir}/deepSSF_S2_slope_id{buffalo_id}_yday{yday_t2_integer}_hour{hour_t2_integer}_bearing{bearing_degrees}_next_r{row}_c{column}.png'
  plt.tight_layout()
  plt.savefig(filename_covs, dpi=600, bbox_inches='tight')
  plt.show()
  plt.close()  # Close the figure to free memory
Sample number: 133
Hour:                        5
Day of the year:             262
Bearing (radians):           1.7486798763275146
Bearing (degrees):           100
Next step is (row, column):  (50, 50)

Sample number: 135
Hour:                        7
Day of the year:             262
Bearing (radians):           -0.8036912679672241
Bearing (degrees):           314
Next step is (row, column):  (56, 65)

Sample number: 171
Hour:                        0
Day of the year:             263
Bearing (radians):           0.37439003586769104
Bearing (degrees):           21
Next step is (row, column):  (50, 50)

Sample number: 994
Hour:                        4
Day of the year:             303
Bearing (radians):           1.3337944746017456
Bearing (degrees):           76
Next step is (row, column):  (50, 50)

Sample number: 308
Hour:                        18
Day of the year:             270
Bearing (radians):           2.5887136459350586
Bearing (degrees):           148
Next step is (row, column):  (23, 33)

Extracting convolution layer outputs

In the convolutional blocks, each convolutional layer learns a set of filters (kernels) that extract different features from the input data. In the habitat selection subnetwork, the convolution filters (and their associated bias parameters - not shown below) are the only parameters that are trained, and it is the filters that transform the set of input covariates into the habitat selection probabilities. They do this by maximising features of the inputs that correlate with observed next-steps.

For each convolutional layer, there are typically a number of filters. For the habitat selection subnetwork, we used 4 filters in each of the first three layers, and a single filter in the last layer (see the printed model structure below). Each of these filters has a number of channels which correspond one-to-one with the input layers. The outputs of the filter channels are then combined to produce a feature map, with a single feature map produced for each filter. In successive layers, the feature maps become the input layers, and the filters operate on these layers. Because there are multiple filters in each layer, they can ‘specialise’ in extracting different features from the input layers.

By visualizing and inspecting these filters, and the corresponding feature maps, we can:

  • Gain interpretability: Understand what kind of features the network is detecting—e.g., edges, shapes, or textures.
  • Debug: Check if the filters have meaningful patterns or if something went wrong (e.g., all zeros or random noise).
  • Compare layers: See how early layers often learn low-level patterns while deeper layers learn more abstract features.

We will first set up some activation hooks for storing the feature maps. Activation hooks are placed at certain points within the model’s forward pass and store intermediate results. We will also extract the convolution filters (which are weights of the model and as such don’t require hooks - we can access them directly).

We will then run the sample covariates through the model and extract the feature maps from the habitat selection convolutional block, and plot them along with the covariates and convolution filters.

Note that there are also ReLU activation functions in the convolutional blocks, which are not shown below. These are applied to the feature maps, and set all negative values to zero. They are not learned parameters, but are part of the forward pass of the model.

Create scalar grids for plotting

Using the Scalar_to_Grid_Block class from the deepSSF_model script, we can convert the scalar covariates into grids for plotting.

Code
# Create an instance of the scalar-to-grid block using the model parameters
# (Scalar_to_Grid_Block and params are defined earlier in the notebook)
scalar_to_grid_block = Scalar_to_Grid_Block(params)

# Convert the scalars in x2 (left over from the last test sample) into
# spatial grids, one grid per scalar
scalar_maps = scalar_to_grid_block(x2)
print(scalar_maps.shape)  # Check the shape of the generated spatial maps
torch.Size([1, 8, 101, 101])

Convolutional layer 1

Activation hook

Code
# -----------------------------------------------------------
# Module-level store for captured layer outputs, keyed by name
# -----------------------------------------------------------
activation = {}

def get_activation(name):
    """
    Build a forward hook that records a layer's output in the 'activation' dict.

    Args:
        name (str): The key under which the detached output is stored
            in the 'activation' dict.

    Returns:
        A function taking (module, inputs, output) that can be passed to a
        layer's register_forward_hook.
    """
    def _store_output(module, inputs, output):
        # Store the output detached from the autograd graph
        activation[name] = output.detach()
    return _store_output

# -----------------------------------------------------------
# Register a forward hook on the first convolution layer
#    in the model's 'conv_habitat' block
# NOTE(review): the handle returned by register_forward_hook is not kept,
#    so the hook stays registered; re-running this cell registers another one.
# -----------------------------------------------------------
model.conv_habitat.conv2d[0].register_forward_hook(get_activation("hab_conv1"))

# -----------------------------------------------------------
# Perform a forward pass through the model with the desired input
#    (x1, x2, x3 are left over from the last test sample).
#    The output of the hooked layer is stored in 'activation'
# -----------------------------------------------------------
out = model((x1, x2, x3))  # e.g., model((spatial_data_x, scalars_to_grid, bearing_x))

# -----------------------------------------------------------
# Retrieve the captured feature maps from the dictionary
#    and move them to the CPU for inspection
# -----------------------------------------------------------
feat_maps1 = activation["hab_conv1"].cpu()
print("Feature map shape:", feat_maps1.shape)
# Typically shape: (batch_size, out_channels, height, width)

# -----------------------------------------------------------
# Select the feature maps of the first sample in the batch
#    and count them (one per output channel)
# -----------------------------------------------------------
feat_maps1_sample = feat_maps1[0]  # Shape: (out_channels, H, W)
num_maps1 = feat_maps1_sample.shape[0]
print("Number of feature maps:", num_maps1)
Feature map shape: torch.Size([1, 4, 101, 101])
Number of feature maps: 4

Stack spatial and scalar (as grid) covariates

For plotting. Also create a vector of names to index over.

Code
# Concatenate the spatial covariates (x1) and the scalar grids (scalar_maps)
# along the channel dimension (dim=1)
covariate_stack = torch.cat([x1, scalar_maps], dim=1)
print(covariate_stack.shape)

# Display names for the stacked channels, used as plot titles.
# The order must match the channel order of covariate_stack: the channels of
# x1 first, followed by the scalar grids.
covariate_names = ['S2 B1',
                   'S2 B2',
                   'S2 B3',
                   'S2 B4',
                   'S2 B5',
                   'S2 B6',
                   'S2 B7',
                   'S2 B8',
                   'S2 B8a',
                   'S2 B9',
                   'S2 B11',
                   'S2 B12',
                   'Slope',
                   'Hour sin1',
                   'Hour cos1',
                   'Hour sin2',
                   'Hour cos2',
                   'yday sin1',
                   'yday cos1',
                   'yday sin2',
                   'yday cos2',]
torch.Size([1, 21, 101, 101])

Extract filters and plot

Code
# -------------------------------------------------------------------------
# Print the layers of the conv_habitat block (for reference/debugging)
# -------------------------------------------------------------------------
print(model.conv_habitat.conv2d)

# -------------------------------------------------------------------------
# Set the model to evaluation mode (disables dropout, etc.)
# -------------------------------------------------------------------------
model.eval()

# -------------------------------------------------------------------------
# Extract the weights (filters) from the first convolution layer in conv_habitat
# -------------------------------------------------------------------------
filters_c1 = model.conv_habitat.conv2d[0].weight.data.clone().cpu()
print("Filters shape:", filters_c1.shape)
# Typically (out_channels, in_channels, kernel_height, kernel_width)

# -------------------------------------------------------------------------
# For every filter, plot each input covariate (top row) above the kernel
# that the filter applies to that covariate (bottom row)
# -------------------------------------------------------------------------
# Despite its name, this is the size of dimension 1 of the weights,
# i.e. the number of input channels (one kernel per input covariate)
num_filters_c1 = filters_c1.shape[1]
print(num_filters_c1)

# z indexes the filters / feature maps of layer 1
for z in range(num_maps1):

    fig, axes = plt.subplots(2, num_filters_c1, figsize=(2*num_filters_c1, 4))
    # i indexes the input channels of covariate_stack
    for i in range(num_filters_c1):

        # Add the covariates as the first row of subplots
        axes[0,i].imshow(covariate_stack[0, i].detach().cpu().numpy(), cmap='viridis')
        axes[0,i].axis('off')
        axes[0,i].set_title(f'{covariate_names[i]}')
        # Channels beyond those of x1 are the scalar grids: redraw them with
        # colour limits fixed to [-1, 1] and print the scalar value on top
        if i > x1.shape[1] - 1:
            im1 = axes[0,i].imshow(covariate_stack[0, i].detach().cpu().numpy(), cmap='viridis')
            im1.set_clim(-1, 1)
            axes[0,i].text(scalar_maps.shape[2] // 2, scalar_maps.shape[3] // 2,
                f'Value: {round(x2[0, i-x1.shape[1]].item(), 2)}',
                ha='center', va='center', color='white', fontsize=12)

        kernel = filters_c1[z, i, :, :]  # Kernel of filter z for input channel i
        im = axes[1,i].imshow(kernel, cmap='viridis')
        axes[1,i].axis('off')
        axes[1,i].set_title(f'Layer 1, Filter {z+1}')
        # Annotate each cell with the numeric value
        for (j, k), val in np.ndenumerate(kernel):
            axes[1,i].text(k, j, f'{val:.2f}', ha='center', va='center', color='white')

    plt.tight_layout()
    plt.savefig(f'{output_dir}/id{buffalo_id}_conv_layer1_filters{z}_{today_date}.png', dpi=600, bbox_inches='tight')
    plt.show()


    # -----------------------------------------------------------
    # Plot and save the feature map produced by filter z.
    #    It is multiplied by x_mask * y_mask (from the test-sample cell)
    #    to mask out the edge cells.
    # -----------------------------------------------------------

    plt.figure()
    plt.imshow(feat_maps1_sample[z].numpy() * x_mask * y_mask, cmap='viridis')
    plt.title(f"Layer 1, Feature Map {z+1}")
    # Hide axis if you prefer: plt.axis('off')
    plt.savefig(f'{output_dir}/id{buffalo_id}_conv_layer1_feature_map{z}_{today_date}.png', dpi=600, bbox_inches='tight')
    plt.show()
Sequential(
  (0): Conv2d(21, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU()
  (2): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU()
  (4): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): ReLU()
  (6): Conv2d(4, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
Filters shape: torch.Size([4, 21, 3, 3])
21

Convolutional layer 2

Activation hook

Code
# -----------------------------------------------------------
# Register a forward hook on the second convolution layer
#    in the model's 'conv_habitat' block (index 2 of the Sequential;
#    index 1 is the ReLU that follows the first convolution).
#    get_activation and the 'activation' dict come from the layer-1 cell.
# -----------------------------------------------------------
model.conv_habitat.conv2d[2].register_forward_hook(get_activation("hab_conv2"))

# -----------------------------------------------------------
# Perform a forward pass through the model with the desired input
#    The output of the hooked layer is stored in 'activation'
# -----------------------------------------------------------
out = model((x1, x2, x3))  # e.g., model((spatial_data_x, scalars_to_grid, bearing_x))

# -----------------------------------------------------------
# Retrieve the captured feature maps from the dictionary
#    and move them to the CPU for inspection
# -----------------------------------------------------------
feat_maps2 = activation["hab_conv2"].cpu()
print("Feature map shape:", feat_maps2.shape)
# Typically shape: (batch_size, out_channels, height, width)

# -----------------------------------------------------------
# Select the feature maps of the first sample in the batch
#    and count them (one per output channel)
# -----------------------------------------------------------
feat_maps2_sample = feat_maps2[0]  # Shape: (out_channels, H, W)
num_maps2 = feat_maps2_sample.shape[0]
print("Number of feature maps:", num_maps2)
Feature map shape: torch.Size([1, 4, 101, 101])
Number of feature maps: 4

Extract filters and plot

Code
# -------------------------------------------------------------------------
# Extract the weights (filters) from the second convolution layer in conv_habitat
# -------------------------------------------------------------------------
filters_c2 = model.conv_habitat.conv2d[2].weight.data.clone().cpu()
print("Filters shape:", filters_c2.shape)
# Indexed below as (out_channels, in_channels, kernel_height, kernel_width)

# -------------------------------------------------------------------------
# For every output channel z, draw one figure with a column per input channel:
#    top row    = the layer 1 feature map feeding that input channel
#    bottom row = the 3x3 kernel that output channel z applies to it
# -------------------------------------------------------------------------
num_filters_c2 = filters_c2.shape[1]  # number of input channels
print(num_filters_c2)

for z in range(num_maps2):

    fig, axes = plt.subplots(2, num_filters_c2, figsize=(2*num_filters_c2, 4))
    for i in range(num_filters_c2):

        # Top row: the layer 1 feature map for input channel i.
        # The title uses i (not z) so that it names the map actually drawn.
        axes[0,i].imshow(feat_maps1_sample[i].numpy() * x_mask * y_mask, cmap='viridis')
        axes[0,i].axis('off')
        axes[0,i].set_title(f"Layer 1, Map {i+1}")

        # Bottom row: kernel of output channel z for input channel i
        kernel = filters_c2[z, i, :, :]
        im = axes[1,i].imshow(kernel, cmap='viridis')
        axes[1,i].axis('off')
        axes[1,i].set_title(f'Layer 2, Filter {z+1}')
        # Annotate each cell with the numeric value
        for (j, k), val in np.ndenumerate(kernel):
            axes[1,i].text(k, j, f'{val:.2f}', ha='center', va='center', color='white')

    plt.tight_layout()
    plt.savefig(f'{output_dir}/id{buffalo_id}_conv_layer2_filters{z}_{today_date}.png', dpi=600, bbox_inches='tight')
    plt.show()


    # -----------------------------------------------------------
    # Save the layer 2 feature map produced by output channel z.
    #    Multiplying by x_mask * y_mask blanks the edge cells.
    # -----------------------------------------------------------

    plt.figure()
    plt.imshow(feat_maps2_sample[z].numpy() * x_mask * y_mask, cmap='viridis')
    plt.title(f"Layer 2, Feature Map {z+1}")
    # Hide axis if you prefer: plt.axis('off')
    plt.savefig(f'{output_dir}/id{buffalo_id}_conv_layer2_feature_map{z}_{today_date}.png', dpi=600, bbox_inches='tight')
    plt.show()
Filters shape: torch.Size([4, 4, 3, 3])
4

Convolutional layer 3

Activation hook

Code
# -----------------------------------------------------------
# Register a forward hook on conv2d[4] of the model's 'conv_habitat'
#    block (the third convolution layer). The handle is kept so the
#    hook can be detached again once the activation has been captured.
# -----------------------------------------------------------
hab_conv3_handle = model.conv_habitat.conv2d[4].register_forward_hook(get_activation("hab_conv3"))

# -----------------------------------------------------------
# Perform a forward pass through the model with the desired input
#    The feature maps from the hooked layer will be stored in 'activation'
# -----------------------------------------------------------
try:
    out = model((x1, x2, x3))  # e.g., model((spatial_data_x, scalars_to_grid, bearing_x))
finally:
    # Without this the hook would keep firing on every later forward pass
    # (and re-running this cell would stack duplicate hooks on the layer).
    hab_conv3_handle.remove()

# -----------------------------------------------------------
# Retrieve the captured feature maps from the dictionary
#    and move them to the CPU for inspection
# -----------------------------------------------------------
feat_maps3 = activation["hab_conv3"].cpu()
print("Feature map shape:", feat_maps3.shape)
# Printed shape is read below as (batch_size, out_channels, height, width)

# -----------------------------------------------------------
# Keep the feature maps of the first sample in the batch
# -----------------------------------------------------------
feat_maps3_sample = feat_maps3[0]  # Shape: (out_channels, H, W)
num_maps3 = feat_maps3_sample.shape[0]
print("Number of feature maps:", num_maps3)
Feature map shape: torch.Size([1, 4, 101, 101])
Number of feature maps: 4

Extract filters and plot

Code
# -------------------------------------------------------------------------
# Extract the weights (filters) from the third convolution layer in conv_habitat
# -------------------------------------------------------------------------
filters_c3 = model.conv_habitat.conv2d[4].weight.data.clone().cpu()
print("Filters shape:", filters_c3.shape)
# Indexed below as (out_channels, in_channels, kernel_height, kernel_width)

# -------------------------------------------------------------------------
# For every output channel z, draw one figure with a column per input channel:
#    top row    = the layer 2 feature map feeding that input channel
#    bottom row = the 3x3 kernel that output channel z applies to it
# -------------------------------------------------------------------------
num_filters_c3 = filters_c3.shape[1]  # number of input channels
print(num_filters_c3)

for z in range(num_maps3):

    fig, axes = plt.subplots(2, num_filters_c3, figsize=(2*num_filters_c3, 4))
    for i in range(num_filters_c3):

        # Top row: the layer 2 feature map for input channel i.
        # The title uses i (not z) so that it names the map actually drawn.
        axes[0,i].imshow(feat_maps2_sample[i].numpy() * x_mask * y_mask, cmap='viridis')
        axes[0,i].axis('off')
        axes[0,i].set_title(f"Layer 2, Map {i+1}")

        # Bottom row: kernel of output channel z for input channel i
        kernel = filters_c3[z, i, :, :]
        im = axes[1,i].imshow(kernel, cmap='viridis')
        axes[1,i].axis('off')
        axes[1,i].set_title(f'Layer 3, Filter {z+1}')
        # Annotate each cell with the numeric value
        for (j, k), val in np.ndenumerate(kernel):
            axes[1,i].text(k, j, f'{val:.2f}', ha='center', va='center', color='white')

    plt.tight_layout()
    plt.savefig(f'{output_dir}/id{buffalo_id}_conv_layer3_filters{z}_{today_date}.png', dpi=600, bbox_inches='tight')
    plt.show()


    # -----------------------------------------------------------
    # Save the layer 3 feature map produced by output channel z.
    #    Multiplying by x_mask * y_mask blanks the edge cells.
    #    The title names the layer and channel: this is an intermediate
    #    feature map from conv2d[4], not the output of the habitat block.
    # -----------------------------------------------------------

    plt.figure()
    plt.imshow(feat_maps3_sample[z].numpy() * x_mask * y_mask, cmap='viridis')
    plt.title(f"Layer 3, Feature Map {z+1}")
    # Hide axis if you prefer: plt.axis('off')
    plt.savefig(f'{output_dir}/id{buffalo_id}_conv_layer3_feature_map{z}_{today_date}.png', dpi=600, bbox_inches='tight')
    plt.show()
Filters shape: torch.Size([4, 4, 3, 3])
4

Checking estimated movement parameters

Similarly to the convolutional layers, we can set hooks to extract the predicted movement parameters from the model, and assess how variable that is across samples.

Code
# -------------------------------------------------------------------------
# Create a list to store the intermediate output from the fully connected
#    movement sub-network (fcn_movement_all)
# -------------------------------------------------------------------------
intermediate_output = []

def hook(module, input, output):
    """
    Forward hook that records the output of the layer it is registered on.

    The output is appended to the module-level 'intermediate_output' list,
    one entry per forward pass. It is detached first so that the stored
    tensor does not keep the autograd graph alive when the forward pass is
    run outside 'torch.no_grad()'.

    Args:
        module: The layer the hook is attached to (unused).
        input: The inputs passed to that layer (unused).
        output (Tensor): The layer's output for this forward pass.
    """
    intermediate_output.append(output.detach())

# -------------------------------------------------------------------------
# Register the forward hook on 'fcn_movement_all', so its outputs
#    are recorded every time the model does a forward pass.
# -------------------------------------------------------------------------
hook_handle = model.fcn_movement_all.register_forward_hook(hook)

# -------------------------------------------------------------------------
# Perform a forward pass with the model in evaluation mode,
#    disabling gradient computation. The hook is removed in 'finally'
#    so it cannot stay attached if the forward pass fails.
# -------------------------------------------------------------------------
model.eval()
try:
    with torch.no_grad():
        final_output = model((x1, x2, x3))
finally:
    hook_handle.remove()

# -------------------------------------------------------------------------
# Inspect the captured intermediate output
#    The last entry is the capture from the forward pass just made; using
#    index 0 would return a stale capture if the list already held entries
#    (e.g. after re-running this cell).
# -------------------------------------------------------------------------
movement_params = intermediate_output[-1]
print("Intermediate output shape:", movement_params.shape)
print("Intermediate output values:", movement_params[0])

# -------------------------------------------------------------------------
# Unpack the 12 parameters of the first sample
#    (assumes this specific ordering of the FCN output)
# -------------------------------------------------------------------------
gamma_shape1, gamma_scale1, gamma_weight1, \
gamma_shape2, gamma_scale2, gamma_weight2, \
vonmises_mu1, vonmises_kappa1, vonmises_weight1, \
vonmises_mu2, vonmises_kappa2, vonmises_weight2 = movement_params[0]

# Normalisers for the two pairs of mixture weights
gamma_weight_sum = torch.exp(gamma_weight1) + torch.exp(gamma_weight2)
vonmises_weight_sum = torch.exp(vonmises_weight1) + torch.exp(vonmises_weight2)

# -------------------------------------------------------------------------
# Exponentiate the raw outputs and print them
#    Gamma and von Mises parameters
# -------------------------------------------------------------------------
# --- Gamma #1 ---
print("Gamma shape 1:", torch.exp(gamma_shape1))
print("Gamma scale 1:", torch.exp(gamma_scale1))
print("Gamma weight 1:", torch.exp(gamma_weight1) / gamma_weight_sum)

# --- Gamma #2 ---
print("Gamma shape 2:", torch.exp(gamma_shape2))
print("Gamma scale 2:", torch.exp(gamma_scale2) * 500)  # scale factor 500
print("Gamma weight 2:", torch.exp(gamma_weight2) / gamma_weight_sum)

# --- von Mises #1 ---
# % (2*np.pi) ensures the mu (angle) is wrapped within [0, 2π)
print("Von Mises mu 1:", vonmises_mu1 % (2*np.pi))
print("Von Mises kappa 1:", torch.exp(vonmises_kappa1))
print("Von Mises weight 1:", torch.exp(vonmises_weight1) / vonmises_weight_sum)

# --- von Mises #2 ---
print("Von Mises mu 2:", vonmises_mu2 % (2*np.pi))
print("Von Mises kappa 2:", torch.exp(vonmises_kappa2))
print("Von Mises weight 2:", torch.exp(vonmises_weight2) / vonmises_weight_sum)
Intermediate output shape: torch.Size([1, 12])
Intermediate output values: tensor([ 1.1870,  1.6020, -0.9378,  0.6392, -0.9164,  0.9960,  0.0760, -0.6234,
         0.5451,  0.1789, -0.3929, -1.2263], device='mps:0')
Gamma shape 1: tensor(3.2773, device='mps:0')
Gamma scale 1: tensor(4.9631, device='mps:0')
Gamma weight 1: tensor(0.1263, device='mps:0')
Gamma shape 2: tensor(1.8950, device='mps:0')
Gamma scale 2: tensor(199.9785, device='mps:0')
Gamma weight 2: tensor(0.8737, device='mps:0')
Von Mises mu 1: tensor(0.0760, device='mps:0')
Von Mises kappa 1: tensor(0.5361, device='mps:0')
Von Mises weight 1: tensor(0.8546, device='mps:0')
Von Mises mu 2: tensor(0.1789, device='mps:0')
Von Mises kappa 2: tensor(0.6751, device='mps:0')
Von Mises weight 2: tensor(0.1454, device='mps:0')

Plot the movement distributions

We can use the movement parameters to plot the step length and turning angle distributions for the sample covariates.

Code
# -------------------------------------------------------------------------
# Define helper functions for calculating Gamma and von Mises log-densities
# -------------------------------------------------------------------------
def gamma_density(x, shape, scale):
    """
    Computes the log of the Gamma density for each value in x.

    Despite the function name, the returned values are log-densities;
    exponentiate them to obtain densities.

    Args:
      x (Tensor or float): Input values for which to compute the log-density.
      shape (Tensor or float): Gamma shape parameter
      scale (Tensor or float): Gamma scale parameter

    Returns:
      Tensor: The log of the Gamma probability density at each x.
    """
    x = torch.as_tensor(x)
    # torch.lgamma / torch.log only accept tensors, so plain Python numbers
    # are promoted here. Tensors are passed through untouched.
    float_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    if not torch.is_tensor(shape):
        shape = torch.as_tensor(shape, dtype=float_dtype, device=x.device)
    if not torch.is_tensor(scale):
        scale = torch.as_tensor(scale, dtype=float_dtype, device=x.device)

    return -1*torch.lgamma(shape) - shape*torch.log(scale) \
           + (shape - 1)*torch.log(x) - x/scale

def vonmises_density(x, kappa, vm_mu):
    """
    Computes the log of the von Mises density for each value in x.

    Despite the function name, the returned values are log-densities;
    exponentiate them to obtain densities.

    Args:
      x (Tensor or float): Input angles in radians.
      kappa (Tensor or float): Concentration parameter (kappa)
      vm_mu (Tensor or float): Mean direction parameter (mu)

    Returns:
      Tensor: The log of the von Mises probability density at each x.
    """
    x = torch.as_tensor(x)
    if not torch.is_tensor(kappa):
        float_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        kappa = torch.as_tensor(kappa, dtype=float_dtype, device=x.device)

    # log(I0(kappa)) computed via the exponentially scaled Bessel function:
    # i0e(k) = exp(-|k|) * i0(k), so log(i0(k)) = log(i0e(k)) + |k|.
    # Calling i0 directly overflows to inf for large kappa, which would turn
    # the whole log-density into -inf.
    log_i0_kappa = torch.log(torch.special.i0e(kappa)) + torch.abs(kappa)

    # float() keeps the constant a plain Python scalar
    log_two_pi = float(np.log(2*np.pi))

    return kappa*torch.cos(x - vm_mu) - 1*(log_two_pi + log_i0_kappa)


# -------------------------------------------------------------------------
# Recover the Gamma mixture weights from the raw outputs
#    (exp(w_i) / sum of exp(w)) and round them for the plot legend
# -------------------------------------------------------------------------
gamma_weight_total = torch.exp(gamma_weight1) + torch.exp(gamma_weight2)

gamma_weight1_recovered = torch.exp(gamma_weight1)/gamma_weight_total
gamma_weight2_recovered = torch.exp(gamma_weight2)/gamma_weight_total

rounded_gamma_weight1 = round(gamma_weight1_recovered.item(), 2)
rounded_gamma_weight2 = round(gamma_weight2_recovered.item(), 2)

# -------------------------------------------------------------------------
# Same recovery and rounding for the von Mises mixture weights
# -------------------------------------------------------------------------
vonmises_weight_total = torch.exp(vonmises_weight1) + torch.exp(vonmises_weight2)

vonmises_weight1_recovered = torch.exp(vonmises_weight1)/vonmises_weight_total
vonmises_weight2_recovered = torch.exp(vonmises_weight2)/vonmises_weight_total

rounded_vm_weight1 = round(vonmises_weight1_recovered.item(), 2)
rounded_vm_weight2 = round(vonmises_weight2_recovered.item(), 2)


# -------------------------------------------------------------------------
# 1. Gamma mixture (step lengths)
#    Evaluate both component log-densities on a grid of x values, then
#    exponentiate and blend them with the recovered weights
# -------------------------------------------------------------------------
x_values = torch.linspace(1, 101, 1000).to(device)

gamma1_density = gamma_density(x_values, torch.exp(gamma_shape1), torch.exp(gamma_scale1))
gamma2_density = gamma_density(x_values, torch.exp(gamma_shape2), torch.exp(gamma_scale2)*500)

gamma_mixture_density = gamma_weight1_recovered*torch.exp(gamma1_density) \
                        + gamma_weight2_recovered*torch.exp(gamma2_density)

# NumPy copies on the CPU for matplotlib (components are exponentiated here)
x_values_np = x_values.cpu().numpy()
gamma1_density_np = np.exp(gamma1_density.cpu().numpy())
gamma2_density_np = np.exp(gamma2_density.cpu().numpy())
gamma_mixture_density_np = gamma_mixture_density.cpu().numpy()

# -------------------------------------------------------------------------
# 2. Draw the two Gamma components and their mixture on one set of axes
# -------------------------------------------------------------------------
gamma_curves = [
    (gamma1_density_np, f'Gamma 1 Density: weight = {rounded_gamma_weight1}'),
    (gamma2_density_np, f'Gamma 2 Density: weight = {rounded_gamma_weight2}'),
    (gamma_mixture_density_np, 'Gamma Mixture Density'),
]
for curve, curve_label in gamma_curves:
    plt.plot(x_values_np, curve, label=curve_label)
plt.xlabel('x')
plt.ylabel('Density')
plt.title('Gamma Density Function')
plt.legend()
plt.show()


# -------------------------------------------------------------------------
# 3. von Mises mixture (turning angles)
#    Evaluate both component log-densities on a grid from -π to π, then
#    exponentiate and blend them with the recovered weights
# -------------------------------------------------------------------------
x_values = torch.linspace(-np.pi, np.pi, 1000).to(device)

vonmises1_density = vonmises_density(x_values, torch.exp(vonmises_kappa1), vonmises_mu1)
vonmises2_density = vonmises_density(x_values, torch.exp(vonmises_kappa2), vonmises_mu2)

vonmises_mixture_density = vonmises_weight1_recovered*torch.exp(vonmises1_density) \
                           + vonmises_weight2_recovered*torch.exp(vonmises2_density)

# NumPy copies on the CPU for matplotlib (components are exponentiated here)
x_values_np = x_values.cpu().numpy()
vonmises1_density_np = np.exp(vonmises1_density.cpu().numpy())
vonmises2_density_np = np.exp(vonmises2_density.cpu().numpy())
vonmises_mixture_density_np = vonmises_mixture_density.cpu().numpy()

# -------------------------------------------------------------------------
# 4. Draw the two von Mises components and their mixture on one set of axes
# -------------------------------------------------------------------------
vonmises_curves = [
    (vonmises1_density_np, f'Von Mises 1 Density: weight = {rounded_vm_weight1}'),
    (vonmises2_density_np, f'Von Mises 2 Density: weight = {rounded_vm_weight2}'),
    (vonmises_mixture_density_np, 'Von Mises Mixture Density'),
]
for curve, curve_label in vonmises_curves:
    plt.plot(x_values_np, curve, label=curve_label)
plt.xlabel('x (radians)')
plt.ylabel('Density')
plt.title('Von Mises Density Function')
plt.ylim(0, 0.4)  # Fixed y-axis range for the angle plot
plt.legend()
plt.show()

Generate a distribution of movement parameters

To see how variable the movement parameters are across samples, we can generate a distribution of movement parameters from a batch of samples.

We take the code from above that we used to create the DataLoader for the test data and increase the batch size (to get more samples to create the distribution from).

As we’re not using the test dataset any more, we’ll just put all of the samples in the same batch, and generate movement parameters for all of them.

Code
# Use the whole test dataset as one batch, so a single forward pass
# yields movement parameters for every sample
n_test_samples = len(dataset_test)
print(f'There are {n_test_samples} samples in the test dataset')

bs = n_test_samples  # batch size
dataloader_test = DataLoader(dataset_test, batch_size=bs, shuffle=True)
There are 1011 samples in the test dataset

Take all of the samples from the test dataset and put them in a single batch.

Code
# -----------------------------------------------------------
# Fetch a batch of data from the test dataloader
# -----------------------------------------------------------
x1_batch, x2_batch, x3_batch, labels = next(iter(dataloader_test))

x1_batch = x1_batch.to(device)
x2_batch = x2_batch.to(device)
x3_batch = x3_batch.to(device)
labels = labels.to(device)

# -----------------------------------------------------------
# Drop captures left over from earlier forward passes; otherwise
#    their samples would be counted again in the distributions below
# -----------------------------------------------------------
intermediate_output.clear()

# -----------------------------------------------------------
# Register a forward hook to capture the outputs
#    from 'fcn_movement_all' during the forward pass
# -----------------------------------------------------------
hook_handle = model.fcn_movement_all.register_forward_hook(hook)

# -----------------------------------------------------------
# Perform a forward pass in evaluation mode to generate
#    and capture the sub-network's outputs in 'intermediate_output'.
#    Gradients are not needed here, so autograd is switched off to avoid
#    building a graph for the whole batch.
# -----------------------------------------------------------
model.eval()  # Disables certain layers like dropout

try:
    with torch.no_grad():
        # Pass the batch through the model
        final_output = model((x1_batch, x2_batch, x3_batch))
finally:
    # Stop capturing as soon as the pass is done. Calling
    # hook_handle.remove() again later is harmless.
    hook_handle.remove()

# -----------------------------------------------------------
# Prepare lists to store the distribution parameters
#    for each sample in the batch
# -----------------------------------------------------------
gamma_shape1_list = []
gamma_scale1_list = []
gamma_weight1_list = []
gamma_shape2_list = []
gamma_scale2_list = []
gamma_weight2_list = []
vonmises_mu1_list = []
vonmises_kappa1_list = []
vonmises_weight1_list = []
vonmises_mu2_list = []
vonmises_kappa2_list = []
vonmises_weight2_list = []

# -----------------------------------------------------------
# Extract parameters from 'intermediate_output'
#    for every sample in the batch
# -----------------------------------------------------------
for batch_output in intermediate_output:
    # Each 'batch_output' corresponds to one forward pass;
    # it might contain multiple samples if the batch size > 1
    for sample_output in batch_output:
        # Unpack the 12 parameters of the Gamma and von Mises mixtures
        gamma_shape1, gamma_scale1, gamma_weight1, \
        gamma_shape2, gamma_scale2, gamma_weight2, \
        vonmises_mu1, vonmises_kappa1, vonmises_weight1, \
        vonmises_mu2, vonmises_kappa2, vonmises_weight2 = sample_output

        # Normalisers for the two pairs of mixture weights
        gamma_weight_sum = torch.exp(gamma_weight1) + torch.exp(gamma_weight2)
        vonmises_weight_sum = torch.exp(vonmises_weight1) + torch.exp(vonmises_weight2)

        # Exponentiate the raw outputs, then store them as Python floats
        gamma_shape1_list.append(torch.exp(gamma_shape1).item())
        gamma_scale1_list.append(torch.exp(gamma_scale1).item())
        gamma_weight1_list.append((torch.exp(gamma_weight1)/gamma_weight_sum).item())
        gamma_shape2_list.append(torch.exp(gamma_shape2).item())
        gamma_scale2_list.append((torch.exp(gamma_scale2)*500).item())  # scale factor 500
        gamma_weight2_list.append((torch.exp(gamma_weight2)/gamma_weight_sum).item())
        # % (2*np.pi) wraps the angle into [0, 2π)
        vonmises_mu1_list.append((vonmises_mu1 % (2*np.pi)).item())
        vonmises_kappa1_list.append(torch.exp(vonmises_kappa1).item())
        vonmises_weight1_list.append((torch.exp(vonmises_weight1)/vonmises_weight_sum).item())
        vonmises_mu2_list.append((vonmises_mu2 % (2*np.pi)).item())
        vonmises_kappa2_list.append(torch.exp(vonmises_kappa2).item())
        vonmises_weight2_list.append((torch.exp(vonmises_weight2)/vonmises_weight_sum).item())

Plot the distribution of movement parameters

Code
# -----------------------------------------------------------
# Helper that draws one histogram per call
#    (used for each collected movement parameter)
# -----------------------------------------------------------
def plot_histogram(data, title, xlabel):
    """
    Draws a 30-bin histogram of 'data' on a new figure and shows it.

    Args:
        data (list): Values to bin.
        title (str): Text for the plot title.
        xlabel (str): Text for the x-axis label.
    """
    _, ax = plt.subplots()
    ax.hist(data, bins=30, alpha=0.75)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    plt.show()

# -----------------------------------------------------------
# One histogram per movement parameter:
#    (values, plot title, x-axis label)
# -----------------------------------------------------------
histogram_specs = [
    (gamma_shape1_list, 'Gamma Shape 1 Distribution', 'Shape 1'),
    (gamma_scale1_list, 'Gamma Scale 1 Distribution', 'Scale 1'),
    (gamma_weight1_list, 'Gamma Weight 1 Distribution', 'Weight 1'),
    (gamma_shape2_list, 'Gamma Shape 2 Distribution', 'Shape 2'),
    (gamma_scale2_list, 'Gamma Scale 2 Distribution', 'Scale 2'),
    (gamma_weight2_list, 'Gamma Weight 2 Distribution', 'Weight 2'),
    (vonmises_mu1_list, 'Von Mises Mu 1 Distribution', 'Mu 1'),
    (vonmises_kappa1_list, 'Von Mises Kappa 1 Distribution', 'Kappa 1'),
    (vonmises_weight1_list, 'Von Mises Weight 1 Distribution', 'Weight 1'),
    (vonmises_mu2_list, 'Von Mises Mu 2 Distribution', 'Mu 2'),
    (vonmises_kappa2_list, 'Von Mises Kappa 2 Distribution', 'Kappa 2'),
    (vonmises_weight2_list, 'Von Mises Weight 2 Distribution', 'Weight 2'),
]

for values, hist_title, hist_xlabel in histogram_specs:
    plot_histogram(values, hist_title, hist_xlabel)

# -----------------------------------------------------------
# Detach the forward hook so later forward passes
#    are no longer recorded
# -----------------------------------------------------------
hook_handle.remove()

Importing spatial data

Instead of importing the stacks of local layers (one for each step), here we want to import the spatial covariates for the extent we want to simulate over. We use an extent that covers all of the observed locations, which we refer to as the ‘landscape’.

Sentinel-2 bands

Each stack represents a month of median values of cloud-free pixels, and each layer in the stack is one of the Sentinel-2 bands.

During the data preparation all of these layers were scaled by 10,000, and don’t need to be scaled any further.

Code
# Directory containing the monthly Sentinel-2 TIFF files
data_dir = f'{base_path}/mapping/cropped rasters/sentinel2/25m'  # Replace with the actual path to your TIFF files

# Collect every TIFF matching the pattern. glob returns matches in arbitrary
# order, so the list is sorted to make the listing (and any later iteration
# over it) run in the same order on every run.
tif_files = sorted(glob.glob(os.path.join(data_dir, 'S2_SR_masked_scaled_25m_*.tif')))
print(f'Found {len(tif_files)} TIFF files')
print('\n'.join(tif_files))
Found 12 TIFF files
../mapping/cropped rasters/sentinel2/25m/S2_SR_masked_scaled_25m_2019_08.tif
../mapping/cropped rasters/sentinel2/25m/S2_SR_masked_scaled_25m_2019_09.tif
../mapping/cropped rasters/sentinel2/25m/S2_SR_masked_scaled_25m_2019_01.tif
../mapping/cropped rasters/sentinel2/25m/S2_SR_masked_scaled_25m_2019_02.tif
../mapping/cropped rasters/sentinel2/25m/S2_SR_masked_scaled_25m_2019_03.tif
../mapping/cropped rasters/sentinel2/25m/S2_SR_masked_scaled_25m_2019_07.tif
../mapping/cropped rasters/sentinel2/25m/S2_SR_masked_scaled_25m_2019_06.tif
../mapping/cropped rasters/sentinel2/25m/S2_SR_masked_scaled_25m_2019_12.tif
../mapping/cropped rasters/sentinel2/25m/S2_SR_masked_scaled_25m_2019_04.tif
../mapping/cropped rasters/sentinel2/25m/S2_SR_masked_scaled_25m_2019_10.tif
../mapping/cropped rasters/sentinel2/25m/S2_SR_masked_scaled_25m_2019_11.tif
../mapping/cropped rasters/sentinel2/25m/S2_SR_masked_scaled_25m_2019_05.tif
Code
# Maps a date string such as '2019_01' to the array read from that month's file
data_dict = {}

for tif_file in tif_files:
    # Strip the fixed prefix and the extension from the filename,
    # e.g. 'S2_SR_masked_scaled_25m_2019_01.tif' -> '2019_01'
    filename = os.path.basename(tif_file)
    date_str = filename.replace('S2_SR_masked_scaled_25m_', '').replace('.tif', '')

    # Read every band; 'data' has shape (bands, height, width)
    with rasterio.open(tif_file) as src:
        data = src.read()

    # Report how many cells are NaN before they are overwritten
    n_nan = np.isnan(data).sum()

    print(f"Date: {date_str}")
    print(f"Number of NaN values in {date_str}: {n_nan}")
    print(f'Proportion of NaN values: {n_nan / data.size:.4%}\n')

    # Store the array with its NaN cells set to zero
    data = np.nan_to_num(data, nan=0)
    data_dict[date_str] = data
Date: 2019_08
Number of NaN values in 2019_08: 144
Proportion of NaN values: 0.0002%

Date: 2019_09
Number of NaN values in 2019_09: 36
Proportion of NaN values: 0.0001%

Date: 2019_01
Number of NaN values in 2019_01: 2460
Proportion of NaN values: 0.0037%

Date: 2019_02
Number of NaN values in 2019_02: 420
Proportion of NaN values: 0.0006%

Date: 2019_03
Number of NaN values in 2019_03: 478731
Proportion of NaN values: 0.7291%

Date: 2019_07
Number of NaN values in 2019_07: 96
Proportion of NaN values: 0.0001%

Date: 2019_06
Number of NaN values in 2019_06: 144
Proportion of NaN values: 0.0002%

Date: 2019_12
Number of NaN values in 2019_12: 0
Proportion of NaN values: 0.0000%

Date: 2019_04
Number of NaN values in 2019_04: 13296
Proportion of NaN values: 0.0202%

Date: 2019_10
Number of NaN values in 2019_10: 0
Proportion of NaN values: 0.0000%

Date: 2019_11
Number of NaN values in 2019_11: 48
Proportion of NaN values: 0.0001%

Date: 2019_05
Number of NaN values in 2019_05: 144
Proportion of NaN values: 0.0002%
Code
# Months (keys of 'data_dict') and band indices to display
dates_to_plot = ['2019_01', '2019_05']
bands_to_plot = [1, 2, 3]  # Band indices for bands 2, 3, and 4, which are B, G, and R

# Gather one (array, band number, date) entry per requested layer.
# The band number is the 0-based index shifted to 1-based numbering.
layers_to_plot = []
for date_str in dates_to_plot:
    data = data_dict[date_str]
    for band_idx in bands_to_plot:
        layer_entry = (data[band_idx], band_idx + 1, date_str)
        layers_to_plot.append(layer_entry)

# Draw each collected layer on its own figure with a colour bar
for band, band_number, date_str in layers_to_plot:
    fig, ax = plt.subplots(figsize=(8, 6))
    band_image = ax.imshow(band, cmap='viridis')
    ax.set_title(f'Band {band_number} - {date_str}')
    fig.colorbar(band_image, ax=ax)
    plt.show()

Plot as RGB

We can also visualise the Sentinel-2 bands as an RGB image, using the Red, Green and Blue bands.

The plotting was a bit dark so we will adjust the brightness of the image using a gamma correction.

Code
# Month to display as an RGB composite
date_str = '2019_08'

# Red, green and blue are stored at indices 3, 2 and 1 of the band stack
month_stack = data_dict[date_str]
r_band = month_stack[3]
g_band = month_stack[2]
b_band = month_stack[1]

# Put the three bands on a trailing channel axis: (height, width, 3)
rgb_image = np.stack([r_band, g_band, b_band], axis=-1)

# Rescale to [0, 1] for display
rgb_min = rgb_image.min()
rgb_max = rgb_image.max()
rgb_image = (rgb_image - rgb_min) / (rgb_max - rgb_min)

# Brighten the image with a gamma correction
gamma = 1.75
rgb_image = rgb_image ** (1/gamma)

fig, ax = plt.subplots()
ax.imshow(rgb_image)
ax.set_title('Sentinel 2 RGB')
plt.show()
plt.close()  # Close the figure to free memory

Slope

Code
# Location of the slope raster
file_path = f'{base_path}/mapping/cropped rasters/slope.tif'

# Open the raster and keep its metadata, affine transform and first band
with rasterio.open(file_path) as src:
    slope_meta = src.meta
    raster_transform = src.transform  # reused below when subsetting rasters
    slope_landscape = src.read(1)
Code
# Report the raster metadata
print("Slope metadata:")
print(slope_meta)
print("\n")

# Report the raster dimensions (rows, columns)
print("Shape of slope landscape raster:")
print(slope_landscape.shape)
print("\n")

# Report how many cells are NaN before they are overwritten
n_slope_nan = np.isnan(slope_landscape).sum()
print("Number of NA values in the slope raster:")
print(n_slope_nan)

# Set the NaN cells to 0.0
slope_landscape = np.nan_to_num(slope_landscape, nan=0.0)

# Bounds used for min-max scaling
# (per the original notebook, taken from the stack of local layers)
slope_max = 12.2981
slope_min = 0.0006

# Wrap the NumPy array in a PyTorch tensor
slope_landscape_tens = torch.from_numpy(slope_landscape)

# Min-max scale the landscape with the bounds above
slope_range = slope_max - slope_min
slope_landscape_norm = (slope_landscape_tens - slope_min) / slope_range

# Show the un-normalised slope values
plt.imshow(slope_landscape_tens.numpy())
plt.colorbar()
plt.show()
Slope metadata:
{'driver': 'GTiff', 'dtype': 'float32', 'nodata': nan, 'width': 2400, 'height': 2280, 'count': 1, 'crs': CRS.from_wkt('PROJCS["GDA94 / Geoscience Australia Lambert",GEOGCS["GDA94",DATUM["Geocentric_Datum_of_Australia_1994",SPHEROID["GRS 1980",6378137,298.257222101,AUTHORITY["EPSG","7019"]],AUTHORITY["EPSG","6283"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4283"]],PROJECTION["Lambert_Conformal_Conic_2SP"],PARAMETER["latitude_of_origin",0],PARAMETER["central_meridian",134],PARAMETER["standard_parallel_1",-18],PARAMETER["standard_parallel_2",-36],PARAMETER["false_easting",0],PARAMETER["false_northing",0],UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH],AUTHORITY["EPSG","3112"]]'), 'transform': Affine(25.0, 0.0, 0.0,
       0.0, -25.0, -1406000.0)}


Shape of slope landscape raster:
(2280, 2400)


Number of NA values in the slope raster:
9356

Subset function

As we described the subset function in the deepSSF_simulations.ipynb notebook, and stored it in the deepSSF_functions.py script, we will just import it here.

Code
# Short alias for the padded raster-subsetting helper in deepSSF_utils.
# It is called below as subset_function(raster, x, y, window_size, raster_transform)
# and unpacked as (subset, origin_x, origin_y).
subset_function = deepSSF_utils.subset_raster_with_padding_torch

Testing the subset function

We want to ensure that the function pads the raster when it is outside the landscape extent.

Code
# Use the first location in the buffalo DataFrame as the window centre
x = buffalo_df['x1_'].iloc[0]
y = buffalo_df['y1_'].iloc[0]

# Width/height of the window to cut out, in cells
window_size = 101

# Cut the window out of the normalised slope landscape
slope_subset, origin_x, origin_y = subset_function(slope_landscape_norm, x, y, window_size, raster_transform)

# Sentinel-2 stack for one month, as a float tensor
selected_month = '2019_01'
s2_data = data_dict[selected_month]
s2_tensor = torch.from_numpy(s2_data).float()
print(s2_tensor.shape) # [bands, height, width]

# Cut the same window out of each of the 12 Sentinel-2 bands
s2_band_subsets = []
for band_idx in range(12):
    band_subset, origin_x, origin_y = subset_function(s2_tensor[band_idx,:,:], x, y, window_size, raster_transform)
    s2_band_subsets.append(band_subset)

(s2_b1_subset, s2_b2_subset, s2_b3_subset, s2_b4_subset,
 s2_b5_subset, s2_b6_subset, s2_b7_subset, s2_b8_subset,
 s2_b8a_subset, s2_b9_subset, s2_b11_subset, s2_b12_subset) = s2_band_subsets

# Show the blue, green and red subsets next to the slope subset
fig, axs = plt.subplots(2, 2, figsize=(10, 10))

band_panels = [
    (axs[0, 0], s2_b2_subset, 'Band 2 (blue) Subset'),
    (axs[0, 1], s2_b3_subset, 'Band 3 (green) Subset'),
    (axs[1, 0], s2_b4_subset, 'Band 4 (red) Subset'),
]
for panel, panel_subset, panel_title in band_panels:
    panel.imshow(panel_subset.detach().numpy(), cmap='viridis')
    panel.set_title(panel_title)

axs[1, 1].imshow(slope_subset.detach().numpy(), cmap='viridis')
axs[1, 1].set_title('Slope Subset')
torch.Size([12, 2280, 2400])
Text(0.5, 1.0, 'Slope Subset')

Create a mask for edge cells

Due to the padding at the edges of the covariates, convolutional layers create artifacts that can affect the colour scale of the predictions when plotting. To avoid this, we will create a mask that we can apply to the predictions to remove the edge cells.

Code
# Create masks that blank the edge cells for plotting
# (the edge values otherwise affect the colour scale)
x_mask = np.ones_like(slope_subset)
y_mask = np.ones_like(slope_subset)

# Number of cells to blank along each border
edge_width = 3

# Mask out bordering cells. Negative slicing addresses the last
# 'edge_width' columns/rows, so the masks follow the size of
# 'slope_subset' instead of assuming a 101-cell window.
x_mask[:, :edge_width] = -np.inf
x_mask[:, -edge_width:] = -np.inf
y_mask[:edge_width, :] = -np.inf
y_mask[-edge_width:, :] = -np.inf

Setup validation

To get the validation running we need a few extra functions.

Firstly, we need to index the Sentinel-2 layers correctly, based on the time of the simulated location. We’ll do this by creating a function that takes day of the year of the simulated location and returns the correct index for the Sentinel-2 layers.

This indexing is slightly different from the indexing we used for the deepSSF_simulations.ipynb notebook, which was indexing NDVI layers. In that case we were indexing the layers directly, and therefore the first entry was at 0 (i.e., March was in month_index = 2). Here, we are creating a string that corresponds to the layer name, and therefore the first entry is at 1. (i.e., March will be at month_index = 3)

Code
# Map a day of the year (counted from 1 January 2019) to a 1-based month index
def day_to_month_index(day_of_year):
    """
    Returns the 1-based month index for a day count starting at 1 Jan 2019.

    Day 1 is 1 January 2019. Days that run into later years keep counting
    months upwards (day 366 gives 13). A result of 0, which day 0 produces,
    is bumped up to 1.

    Args:
        day_of_year: Day number; it is truncated with int() before use.

    Returns:
        int: Month index, where 1 is January 2019.
    """
    first_day = datetime(2019, 1, 1)
    target_day = first_day + timedelta(days=int(day_of_year) - 1)

    # Whole years elapsed since 2019 contribute 12 months each
    years_elapsed = target_day.year - first_day.year
    months_index = 12 * years_elapsed + target_day.month

    return 1 if months_index == 0 else months_index

yday = 35
month_index = day_to_month_index(yday)
print(month_index)
2

Check the Sentinel-2 layer indexing

Subset the raster layers at one observed location of the training data (selected with step_index).

Code
# Size (in cells) of the local window cropped around the location
# (should be the same as the one used during training)
window_size = 101

# Which observed step of the buffalo data to inspect
step_index = 1503

# Location of the individual at the start of that step
x = buffalo_df['x1_'].iloc[step_index]
y = buffalo_df['y1_'].iloc[step_index]
print(f'Starting x and y coordinates: {x}, {y}')

yday = buffalo_df['yday_t2'].iloc[step_index]
print(f'Starting day of the year:     {yday}')

# Build the Sentinel-2 layer name from the month that the day falls in
month_index = day_to_month_index(yday)
selected_month = f'2019_{month_index:02d}'

# Normalized Sentinel-2 data for that month, as a float tensor
s2_data = data_dict[selected_month]
s2_tensor = torch.from_numpy(s2_data).float()
print(s2_tensor.shape)

# Crop each of the 12 Sentinel-2 bands around (x, y); only the cropped layer
# (first return value) is kept here, the origin comes from the slope crop below
(s2_b1_subset, s2_b2_subset, s2_b3_subset, s2_b4_subset,
 s2_b5_subset, s2_b6_subset, s2_b7_subset, s2_b8_subset,
 s2_b8a_subset, s2_b9_subset, s2_b11_subset, s2_b12_subset) = (
    subset_function(s2_tensor[band_idx, :, :], x, y, window_size, raster_transform)[0]
    for band_idx in range(12)
)

# Crop the slope layer at the same location
slope_subset, origin_x, origin_y = subset_function(slope_landscape_norm, x, y, window_size, raster_transform)

# Show the visible bands and the slope in a 2 x 2 grid
fig, axs = plt.subplots(2, 2, figsize=(10, 10))

band_panels = (
    (axs[0, 0], s2_b2_subset, 'Band 2 (blue) Subset'),
    (axs[0, 1], s2_b3_subset, 'Band 3 (green) Subset'),
    (axs[1, 0], s2_b4_subset, 'Band 4 (red) Subset'),
)
for panel_ax, panel_layer, panel_title in band_panels:
    panel_ax.imshow(panel_layer.numpy(), cmap='viridis')
    panel_ax.set_title(panel_title)

axs[1, 1].imshow(slope_subset.numpy(), cmap='viridis')
axs[1, 1].set_title('Slope Subset')
Starting x and y coordinates: 43036.55199338104, -1437242.6615283354
Starting day of the year:     271.71
torch.Size([12, 2280, 2400])
Text(0.5, 1.0, 'Slope Subset')

Plot as RGB

Code
# Red, green and blue are Sentinel-2 bands 4, 3 and 2 respectively
r_band, g_band, b_band = (
    band_subset.detach().numpy()
    for band_subset in (s2_b4_subset, s2_b3_subset, s2_b2_subset)
)

# Combine into a (rows, cols, 3) image
rgb_image = np.stack((r_band, g_band, b_band), axis=-1)

# Rescale jointly across the three bands to [0, 1] for display
rgb_lo = rgb_image.min()
rgb_hi = rgb_image.max()
rgb_image = (rgb_image - rgb_lo) / (rgb_hi - rgb_lo)

plt.figure()  # start a fresh figure
plt.imshow(rgb_image)
plt.title('Sentinel 2 RGB')
plt.show()
plt.close()  # release the figure once it has been displayed

Next-step probability values

We can now calculate the next-step probabilities for each observed step. As we generate habitat selection, movement and next-step probability surfaces, we can get the predicted probability values for each one, which can be compared to the respective process in the SSF.

The process for generating the next-step probabilities is as follows:

  1. Get the current location of the individual
  2. Crop out the local layers for the current location
  3. Run the model on the local layers to get the habitat selection, movement and next-step probability surfaces
  4. Get the predicted probability values at the location of the next step
  5. Store the predicted probability values and export them as a csv for comparison with the SSF

First, select the data to generate prediction values for. For testing the function we can select a subset.

Code
# Use every observed step for validation.
# For a quick test of the loop, swap in a small subset instead:
# test_data = buffalo_df.iloc[0:10]
test_data = buffalo_df

# One prediction is stored per row of the test data
n_samples = test_data.shape[0]
print(f'Number of samples: {n_samples}')

# Zero-filled vectors that the prediction loop fills in, one per surface
habitat_probs = np.zeros(n_samples)
move_probs = np.zeros(n_samples)
next_step_probs = np.zeros(n_samples)
Number of samples: 10103

Loop over each step

Code
# For every observed step: crop the local layers around the current location,
# run the model, and store the habitat, movement and next-step probabilities
# at the cell of the observed next location. The first few steps are also
# saved as images.

# Create directory for saving prediction images
os.makedirs(f'{output_dir}/prediction_images', exist_ok=True)

# Temporal covariates for t1, in the order they are passed to the model
temporal_columns = ['hour_t1_sin1', 'hour_t1_cos1', 'hour_t1_sin2', 'hour_t1_cos2',
                    'yday_t1_sin1', 'yday_t1_cos1', 'yday_t1_sin2', 'yday_t1_cos2']

# The Sentinel-2 tensor only changes when the month changes, so the full
# raster is converted once per month rather than once per step
cached_month = None

# Start at 1 so the bearing at t - 1 is available
for i in range(1, n_samples):

    sample = test_data.iloc[i]

    # Current location (x1, y1)
    x = sample['x1_']
    y = sample['y1_']

    # Next-step location (x2, y2), converted from geographic coordinates
    # to pixel coordinates of the full raster
    px2, py2 = ~raster_transform * (sample['x2_'], sample['y2_'])

    # Hour of the day and day of the year (also used in the image filename)
    hour_t2 = sample['hour_t2']
    yday = sample['yday_t2']

    # Sentinel-2 layer name for the month that the day falls in
    month_index = day_to_month_index(yday)
    selected_month = f'2019_{month_index:02d}'

    # Convert the month's Sentinel-2 layers from NumPy to a float tensor
    if selected_month != cached_month:
        s2_data = data_dict[selected_month]
        s2_tensor = torch.from_numpy(s2_data).float()
        cached_month = selected_month

    # Crop out every Sentinel-2 band at the location of x1, y1
    # (only the cropped layer is kept; the origin comes from the slope crop)
    s2_subsets = [
        subset_function(s2_tensor[band, :, :], x, y, window_size, raster_transform)[0]
        for band in range(s2_tensor.shape[0])
    ]

    # Crop out the slope subset at the location of x1, y1
    slope_subset, origin_x, origin_y = subset_function(slope_landscape_norm, x, y, window_size, raster_transform)

    # Cell of the next step within the local window.
    # NOTE(review): assumes the next step lies inside the window; a negative
    # local index would wrap around to the opposite edge - confirm upstream
    # filtering guarantees this.
    px2_subset = px2 - origin_x
    py2_subset = py2 - origin_y
    next_row = int(py2_subset)
    next_col = int(px2_subset)

    # Spatial input: Sentinel-2 bands followed by slope, stacked as channels,
    # with a leading batch dimension
    x1 = torch.stack(s2_subsets + [slope_subset], dim=0).unsqueeze(0).to(device)

    # Temporal input with shape (1, 8)
    x2 = torch.tensor([[float(sample[col]) for col in temporal_columns]],
                      dtype=torch.float32).to(device)

    # Bearing of the previous step (t - 1) with shape (1, 1)
    bearing = torch.tensor([[float(sample['bearing_tm1'])]],
                           dtype=torch.float32).to(device)

    # -------------------------------------------------------------------------
    # Run the model
    # -------------------------------------------------------------------------
    # Gradients are not needed for prediction, so skip building the graph
    with torch.no_grad():
        model_output = model((x1, x2, bearing))

    # Move the output to NumPy once; both surfaces are sliced from it
    model_output_np = model_output.detach().cpu().numpy()

    # -------------------------------------------------------------------------
    # Habitat selection probability
    # -------------------------------------------------------------------------
    hab_density = model_output_np[0, :, :, 0]
    hab_density_exp = np.exp(hab_density)

    # Normalise the probability surface to sum to 1
    hab_density_exp_norm = hab_density_exp / np.sum(hab_density_exp)

    # Store the probability of habitat selection at the location of x2, y2
    habitat_probs[i] = hab_density_exp_norm[next_row, next_col]

    # -------------------------------------------------------------------------
    # Movement probability
    # -------------------------------------------------------------------------
    move_density = model_output_np[0, :, :, 1]
    move_density_exp = np.exp(move_density)

    # Normalise the probability surface to sum to 1
    move_density_exp_norm = move_density_exp / np.sum(move_density_exp)

    # Store the movement probability at the location of x2, y2
    move_probs[i] = move_density_exp_norm[next_row, next_col]

    # -------------------------------------------------------------------------
    # Next step probability
    # -------------------------------------------------------------------------
    # Adding the two surfaces on the log scale multiplies the probabilities
    step_density = hab_density + move_density
    step_density_exp = np.exp(step_density)

    # The product does not sum to 1 on its own, so normalise it
    step_density_exp_norm = step_density_exp / np.sum(step_density_exp)

    # Store the next-step probability at the location of x2, y2
    next_step_probs[i] = step_density_exp_norm[next_row, next_col]

    # -------------------------------------------------------------------------
    # Plot the next-step predictions
    # -------------------------------------------------------------------------

    # Plot the first few probability surfaces - change the condition to
    # i < n_samples to plot all
    if i < 51:

        # Mask out bordering cells
        hab_density_mask = hab_density * x_mask * y_mask
        move_density_mask = move_density * x_mask * y_mask
        step_density_mask = step_density * x_mask * y_mask

        # Mask that marks the cell of the observed next step
        next_step_mask = np.ones_like(hab_density)
        next_step_mask[next_row, next_col] = -np.inf

        # Plot the outputs
        fig_out, axs_out = plt.subplots(2, 2, figsize=(10, 8))

        # RGB composite from Sentinel-2 bands 4, 3 and 2
        # (positions 3, 2 and 1 in the band list)
        r_band = s2_subsets[3].detach().numpy()
        g_band = s2_subsets[2].detach().numpy()
        b_band = s2_subsets[1].detach().numpy()

        # Stack the bands along a new axis
        rgb_image = np.stack([r_band, g_band, b_band], axis=-1)
        # Normalize to the range [0, 1] for display
        rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min())

        # Plot the RGB composite
        axs_out[0, 0].imshow(rgb_image)
        axs_out[0, 0].set_title('Sentinel 2 RGB')

        # Plot habitat selection log-probability
        im2 = axs_out[0, 1].imshow(hab_density_mask * next_step_mask, cmap='viridis')
        axs_out[0, 1].set_title('Habitat selection log-probability')
        fig_out.colorbar(im2, ax=axs_out[0, 1], shrink=0.7)

        # Movement density log-probability
        im3 = axs_out[1, 0].imshow(move_density_mask * next_step_mask, cmap='viridis')
        axs_out[1, 0].set_title('Movement log-probability')
        fig_out.colorbar(im3, ax=axs_out[1, 0], shrink=0.7)

        # Next-step log-probability
        im4 = axs_out[1, 1].imshow(step_density_mask * next_step_mask, cmap='viridis')
        axs_out[1, 1].set_title('Next-step log-probability')
        fig_out.colorbar(im4, ax=axs_out[1, 1], shrink=0.7)

        filename_covs = f'{output_dir}/prediction_images/id{buffalo_id}_step_index{i+1}_yday{yday}_hour{hour_t2}.png'
        fig_out.tight_layout()
        fig_out.savefig(filename_covs, dpi=150)
        plt.close(fig_out)  # Close the figure to free memory
Code
# Inspect the stored next-step probabilities. The first entry stays at 0
# because the prediction loop starts at index 1.
print(next_step_probs)
[0.00000000e+00 2.15532933e-03 4.45809448e-03 ... 4.90309775e-01
 2.74423131e-04 3.73210089e-04]

Make a GIF of the prediction images

Code
# Folder holding the per-step prediction images written by the loop above
image_folder =  f'{output_dir}/prediction_images'
# Output GIF filename
# NOTE(review): yday_t2_integer, hour_t2_integer, bearing_degrees, row and
# column are not set in the prediction loop; presumably they come from an
# earlier section of the notebook - confirm they describe this run.
output_filename = f'{output_dir}/prediction_gif_id{buffalo_id}_yday{yday_t2_integer}_hour{hour_t2_integer}_bearing{bearing_degrees}_next_r{row}_c{column}.gif'
# Create the GIF from the images in the folder at 5 frames per second
create_gif(image_folder, output_filename, fps=5)
<IPython.core.display.Image object>
GIF created successfully: ../Python/outputs/model_training_S2/id2005_scalar_movement_2_2025-06-05/prediction_gif_id2005_yday270_hour18_bearing148_next_r23_c33.gif

Calculate the null probabilities

As each cell has a probability value, we can calculate what the probability would be if the model provided no information at all, and each cell was equally likely to be the next step. This is just 1 divided by the total number of cells.

Code
# With no information every cell of the square local window is equally
# likely, so the null probability is one over the number of cells
n_cells = window_size ** 2
null_prob = 1 / n_cells
print(f'Null probability: {null_prob:.3e}')
Null probability: 9.803e-05

Compute the rolling average of the probabilities

Code
# Number of consecutive steps averaged by the rolling mean
rolling_window_size = 100

# Centred rolling mean of each probability vector. The window must be
# rolling_window_size: window_size is the spatial size of the local layers
# (in cells) and is unrelated to the smoothing of the time series.
rolling_mean_habitat = pd.Series(habitat_probs).rolling(window=rolling_window_size, center=True).mean()
rolling_mean_movement = pd.Series(move_probs).rolling(window=rolling_window_size, center=True).mean()
rolling_mean_next_step = pd.Series(next_step_probs).rolling(window=rolling_window_size, center=True).mean()

Plot the probabilities

We can get an idea of how variable the probabilities are for the habitat selection and movement surfaces, and for the next-step probabilities, by plotting them across the trajectory.

Code
def _plot_probability_series(probs, rolling_mean, series_label, title, filename, ylim=None):
    """Plot one probability series with its rolling mean and the null line.

    Parameters
    ----------
    probs : array-like
        Probability values to draw in blue.
    rolling_mean : array-like
        Rolling mean of the probabilities, drawn in red.
    series_label : str
        Legend label for the probability series.
    title : str
        Title of the plot.
    filename : str
        Path that the figure is saved to.
    ylim : tuple of float, optional
        Lower and upper limits of the y-axis; left automatic when omitted.
    """
    plt.plot(probs, color='blue', label=series_label)
    plt.plot(rolling_mean, color='red', label='Rolling Mean')
    plt.axhline(y=null_prob, color='black', linestyle='--', label='Null Probability')  # null probs
    plt.xlabel('Index')
    plt.ylabel('Probability')
    if ylim is not None:
        plt.ylim(*ylim)
    plt.title(title)
    plt.legend()  # Add legend to differentiate lines

    # Save before showing: saving after plt.show() wrote an empty figure
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.show()


# Habitat probabilities for the first 100 steps
_plot_probability_series(habitat_probs[:100],
                         rolling_mean_habitat.iloc[:100],
                         'Habitat Probabilities - S2',
                         'Habitat Probability',
                         f'{output_dir}/id{buffalo_id}_habitat_probs_100_steps.png')

# Habitat probabilities across the trajectory (entries that are 0 are dropped)
_plot_probability_series(habitat_probs[habitat_probs > 0],
                         rolling_mean_habitat[rolling_mean_habitat > 0],
                         'Habitat Probabilities - S2',
                         'Habitat Probability',
                         f'{output_dir}/id{buffalo_id}_habitat_probs.png',
                         ylim=(0, 5e-4))

# Movement probabilities across the trajectory
_plot_probability_series(move_probs[move_probs > 0],
                         rolling_mean_movement[rolling_mean_movement > 0],
                         'Movement Probabilities - S2',
                         'Movement Probability',
                         f'{output_dir}/id{buffalo_id}_move_probs.png')

# Next-step probabilities across the trajectory
_plot_probability_series(next_step_probs[next_step_probs > 0],
                         rolling_mean_next_step[rolling_mean_next_step > 0],
                         'Next Step Probabilities - S2',
                         'Next Step Probability',
                         f'{output_dir}/id{buffalo_id}_next_step_probs.png')

<Figure size 640x480 with 0 Axes>

Save the probabilities

We can save the probabilities to a csv file to compare with the SSF probabilities.

Code
# Attach the probabilities to the rows they were computed for and export them.
# The vectors have one entry per row of test_data, so when test_data is a
# subset of buffalo_df they are attached to a copy of that subset; assigning
# them to the full dataframe would fail on the length mismatch. When test_data
# is buffalo_df itself, the columns are added to buffalo_df as before.
results_df = buffalo_df if test_data is buffalo_df else test_data.copy()

results_df['habitat_probs'] = habitat_probs
results_df['move_probs'] = move_probs
results_df['next_step_probs'] = next_step_probs

csv_filename = f'{output_dir}/deepSSF_validation_id{buffalo_id}_n{len(test_data)}.csv'
print(csv_filename)
results_df.to_csv(csv_filename, index=True)
../Python/outputs/model_training_S2/id2005_scalar_movement_2_2025-06-05/deepSSF_validation_id2005_n10103.csv