deepSSF Training - S2

Author
Affiliation

Queensland University of Technology, CSIRO

Published

February 27, 2025

Abstract

In this script, we will train a deepSSF model on Sentinel-2 data directly (instead of derived covariates such as NDVI). The training data was generated using the deepSSF_data_prep_S2_id.qmd script, which crops out local images for each step of the observed telemetry data.

This script has fewer comments and less explanatory detail about the model, so please refer to the deepSSF_train script for more detail.

Import packages

Code
# If using Google Colab, uncomment the following line
# !pip install rasterio

import sys
print(sys.version)  # Print Python version in use

import numpy as np                                      # Array operations
import matplotlib.pyplot as plt                         # Plotting library
import torch                                            # Main PyTorch library
import torch.optim as optim                             # Optimization algorithms
from torch import nn                                    # Neural network modules
from torch.utils.data import Dataset, DataLoader        # Dataset and batch data loading
from datetime import datetime                           # Date/time utilities
import os                                               # Operating system utilities
import pandas as pd                                     # Data manipulation
import rasterio                                         # Geospatial raster data

import deepSSF_model                                    # Import the .py file containing the deepSSF model     
import deepSSF_loss                                     # Import the .py file containing the deepSSF loss function
import deepSSF_early_stopping                           # Import the .py file containing the early stopping function                                     

# Get today's date as a 'YYYY-MM-DD' string
today_date = datetime.today().strftime('%Y-%m-%d')

# Set random seed for reproducibility.
# Assigning `seed` alone has no effect: the PyTorch and NumPy random number
# generators must be seeded explicitly for random operations to be repeatable.
seed = 42
torch.manual_seed(seed)
# Seed NumPy last: it returns None, so nothing is echoed as notebook cell output
np.random.seed(seed)
3.12.5 | packaged by Anaconda, Inc. | (main, Sep 12 2024, 18:18:29) [MSC v.1929 64 bit (AMD64)]

If using Google Colab, uncomment the following lines

The file directories will also need to be changed to match the location of the files in your Google Drive.

Code
# from google.colab import drive
# drive.mount('/content/drive')

Import data

Set paths to data

Code
# Identifier of the individual and the sample-size tag used in the file names
buffalo_id = 2005
n_samples = 10297

# File-name fragments shared by all of the local-layer rasters
_layers_dir = '../buffalo_local_layers_id'
_stem = f'buffalo_{buffalo_id}'
_tail = f'cent101x101_lag_1hr_n{n_samples}.tif'

# CSV file holding the tabular data
csv_file_path = f'../buffalo_local_data_id/{_stem}_data_df_lag_1hr_n{n_samples}.csv'

# Slope raster
slope_path = f'{_layers_dir}/{_stem}_slope_{_tail}'

# Sentinel-2 band rasters
b1_path = f'{_layers_dir}/{_stem}_s2_b1_{_tail}'
b2_path = f'{_layers_dir}/{_stem}_s2_b2_{_tail}'
b3_path = f'{_layers_dir}/{_stem}_s2_b3_{_tail}'
b4_path = f'{_layers_dir}/{_stem}_s2_b4_{_tail}'
b5_path = f'{_layers_dir}/{_stem}_s2_b5_{_tail}'
b6_path = f'{_layers_dir}/{_stem}_s2_b6_{_tail}'
b7_path = f'{_layers_dir}/{_stem}_s2_b7_{_tail}'
b8_path = f'{_layers_dir}/{_stem}_s2_b8_{_tail}'
b8a_path = f'{_layers_dir}/{_stem}_s2_b8a_{_tail}'
b9_path = f'{_layers_dir}/{_stem}_s2_b9_{_tail}'
b11_path = f'{_layers_dir}/{_stem}_s2_b11_{_tail}'
b12_path = f'{_layers_dir}/{_stem}_s2_b12_{_tail}'

# Presence raster (target variable)
pres_path = f'{_layers_dir}/{_stem}_pres_{_tail}'

Read buffalo data

Code
# Load the tabular data from the CSV file
buffalo_df = pd.read_csv(csv_file_path)
print(buffalo_df.shape)

# Bearing of the previous step: shift the 'bearing' column down by one row,
# then fill the resulting missing values (at least the first row) with 0
buffalo_df['bearing_tm1'] = buffalo_df['bearing'].shift(1).fillna(0)

# Preview the first rows of the DataFrame
print(buffalo_df.head())
(10103, 35)
             x_            y_                    t_    id           x1_  \
0  41969.310875 -1.435671e+06  2018-07-25T01:04:23Z  2005  41969.310875   
1  41921.521939 -1.435654e+06  2018-07-25T02:04:39Z  2005  41921.521939   
2  41779.439594 -1.435601e+06  2018-07-25T03:04:17Z  2005  41779.439594   
3  41841.203272 -1.435635e+06  2018-07-25T04:04:39Z  2005  41841.203272   
4  41655.463332 -1.435604e+06  2018-07-25T05:04:27Z  2005  41655.463332   

            y1_           x2_           y2_     x2_cent    y2_cent  ...  \
0 -1.435671e+06  41921.521939 -1.435654e+06  -47.788936  16.857110  ...   
1 -1.435654e+06  41779.439594 -1.435601e+06 -142.082345  53.568427  ...   
2 -1.435601e+06  41841.203272 -1.435635e+06   61.763677 -34.322938  ...   
3 -1.435635e+06  41655.463332 -1.435604e+06 -185.739939  31.003534  ...   
4 -1.435604e+06  41618.651923 -1.435608e+06  -36.811409  -4.438037  ...   

     cos_ta         x_min         x_max         y_min         y_max  s2_index  \
0  0.201466  40706.810875  43231.810875 -1.436934e+06 -1.434409e+06         7   
1  0.999770  40659.021939  43184.021939 -1.436917e+06 -1.434392e+06         7   
2 -0.989262  40516.939594  43041.939594 -1.436863e+06 -1.434338e+06         7   
3 -0.942144  40578.703272  43103.703272 -1.436898e+06 -1.434373e+06         7   
4  0.959556  40392.963332  42917.963332 -1.436867e+06 -1.434342e+06         7   

   points_vect_cent  year_t2  yday_t2_2018_base  bearing_tm1  
0               NaN     2018                206     0.000000  
1               NaN     2018                206     2.802478  
2               NaN     2018                206     2.781049  
3               NaN     2018                206    -0.507220  
4               NaN     2018                206     2.976198  

[5 rows x 36 columns]

Importing spatial data

Slope

Code
# Read every layer of the slope raster into a single numpy array
with rasterio.open(slope_path) as slope:
    # rasterio numbers layers from 1, so request indexes 1..count inclusive
    slope_stack = slope.read(list(range(1, slope.count + 1)))

print(slope_stack.shape)
(10103, 101, 101)
Code
# Wrap the numpy array as a PyTorch tensor, the format used for model training
slope_tens = torch.from_numpy(slope_stack)
print(slope_tens.shape)

# Summary statistics before rescaling (max and min are reused below)
slope_max = slope_tens.max()
slope_min = slope_tens.min()
print("Mean = ", slope_tens.mean())
print("Max = ", slope_max)
print("Min = ", slope_min)

# Min-max rescaling: subtract the minimum, divide by the range
slope_tens = (slope_tens - slope_min) / (slope_max - slope_min)
print("Mean = ", slope_tens.mean())
print("Max = ", slope_tens.max())
print("Min = ", slope_tens.min())
torch.Size([10103, 101, 101])
Mean =  tensor(0.7779)
Max =  tensor(12.2981)
Min =  tensor(0.0006)
Mean =  tensor(0.0632)
Max =  tensor(1.)
Min =  tensor(0.)
Code
# Plot the first slope layer only, as a quick visual check
for layer_idx in range(1):
    plt.imshow(slope_tens[layer_idx])
    plt.colorbar()
    plt.show()

Sentinel-2 bands

During the data preparation (in the deepSSF_data_prep_id_S2 script) for the Sentinel-2 bands, we scaled them by 10,000, so we do not need to scale them again here (as we did for the other covariates).

Band 1

Code
# Read every layer of the band 1 raster into a single numpy array
with rasterio.open(b1_path) as b1:
    # rasterio numbers layers from 1, so request indexes 1..count inclusive
    b1_stack = b1.read(list(range(1, b1.count + 1)))
Code
# Shape of the band 1 array as read from the raster
print(b1_stack.shape)

# Recode missing pixels (NaN) to -1
b1_stack = np.nan_to_num(b1_stack, nan=-1.0)

# Wrap the numpy array as a PyTorch tensor
b1_tens = torch.from_numpy(b1_stack)
print(b1_tens.shape)

# Summary statistics (min, mean, max) of the band 1 tensor
print(f'Min =  {b1_tens.min()}')
print(f'Mean = {b1_tens.mean()}')
print(f'Max =  {b1_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  9.999999747378752e-05
Mean = 0.04444880783557892
Max =  0.1517084836959839
Code
# Plot the first band 1 layer only, as a quick visual check
for layer_idx in range(1):
    plt.imshow(b1_tens[layer_idx].numpy())
    plt.colorbar()
    plt.show()

Band 2

Code
# Read every layer of the band 2 raster into a single numpy array
with rasterio.open(b2_path) as b2:
    # rasterio numbers layers from 1, so request indexes 1..count inclusive
    b2_stack = b2.read(list(range(1, b2.count + 1)))
Code
# Shape of the band 2 array as read from the raster
print(b2_stack.shape)

# Recode missing pixels (NaN) to -1
b2_stack = np.nan_to_num(b2_stack, nan=-1.0)

# Wrap the numpy array as a PyTorch tensor
b2_tens = torch.from_numpy(b2_stack)
print(b2_tens.shape)

# Summary statistics (min, mean, max) of the band 2 tensor
print(f'Min =  {b2_tens.min()}')
print(f'Mean = {b2_tens.mean()}')
print(f'Max =  {b2_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.002810720121487975
Mean = 0.05629923567175865
Max =  0.1931755244731903
Code
# Plot the first band 2 layer only, as a quick visual check
for layer_idx in range(1):
    plt.imshow(b2_tens[layer_idx].numpy())
    plt.colorbar()
    plt.show()

Band 3

Code
# Read every layer of the band 3 raster into a single numpy array
with rasterio.open(b3_path) as b3:
    # rasterio numbers layers from 1, so request indexes 1..count inclusive
    b3_stack = b3.read(list(range(1, b3.count + 1)))
Code
# Shape of the band 3 array as read from the raster
print(b3_stack.shape)

# Recode missing pixels (NaN) to -1
b3_stack = np.nan_to_num(b3_stack, nan=-1.0)

# Wrap the numpy array as a PyTorch tensor
b3_tens = torch.from_numpy(b3_stack)
print(b3_tens.shape)

# Summary statistics (min, mean, max) of the band 3 tensor
print(f'Min =  {b3_tens.min()}')
print(f'Mean = {b3_tens.mean()}')
print(f'Max =  {b3_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.02109863981604576
Mean = 0.08027872443199158
Max =  0.2795756757259369
Code
# Plot the first band 3 layer only, as a quick visual check
for layer_idx in range(1):
    plt.imshow(b3_tens[layer_idx].numpy())
    plt.colorbar()
    plt.show()

Band 4

Code
# Read every layer of the band 4 raster into a single numpy array
with rasterio.open(b4_path) as b4:
    # rasterio numbers layers from 1, so request indexes 1..count inclusive
    b4_stack = b4.read(list(range(1, b4.count + 1)))
Code
# Shape of the band 4 array as read from the raster
print(b4_stack.shape)

# Recode missing pixels (NaN) to -1
b4_stack = np.nan_to_num(b4_stack, nan=-1.0)

# Wrap the numpy array as a PyTorch tensor
b4_tens = torch.from_numpy(b4_stack)
print(b4_tens.shape)

# Summary statistics (min, mean, max) of the band 4 tensor
print(f'Min =  {b4_tens.min()}')
print(f'Mean = {b4_tens.mean()}')
print(f'Max =  {b4_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.006578320171684027
Mean = 0.09937984496355057
Max =  0.43867969512939453
Code
# Plot the first band 4 layer only, as a quick visual check
for layer_idx in range(1):
    plt.imshow(b4_tens[layer_idx].numpy())
    plt.colorbar()
    plt.show()

Band 5

Code
# Read every layer of the band 5 raster into a single numpy array
with rasterio.open(b5_path) as b5:
    # rasterio numbers layers from 1, so request indexes 1..count inclusive
    b5_stack = b5.read(list(range(1, b5.count + 1)))
Code
# Shape of the band 5 array as read from the raster
print(b5_stack.shape)

# Recode missing pixels (NaN) to -1
b5_stack = np.nan_to_num(b5_stack, nan=-1.0)

# Wrap the numpy array as a PyTorch tensor
b5_tens = torch.from_numpy(b5_stack)
print(b5_tens.shape)

# Summary statistics (min, mean, max) of the band 5 tensor
print(f'Min =  {b5_tens.min()}')
print(f'Mean = {b5_tens.mean()}')
print(f'Max =  {b5_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.03587600216269493
Mean = 0.136901393532753
Max =  0.4592735767364502
Code
# Plot the first band 5 layer only, as a quick visual check
for layer_idx in range(1):
    plt.imshow(b5_tens[layer_idx].numpy())
    plt.colorbar()
    plt.show()

Band 6

Code
# Read every layer of the band 6 raster into a single numpy array
with rasterio.open(b6_path) as b6:
    # rasterio numbers layers from 1, so request indexes 1..count inclusive
    b6_stack = b6.read(list(range(1, b6.count + 1)))
Code
# Shape of the band 6 array as read from the raster
print(b6_stack.shape)

# Recode missing pixels (NaN) to -1
b6_stack = np.nan_to_num(b6_stack, nan=-1.0)

# Wrap the numpy array as a PyTorch tensor
b6_tens = torch.from_numpy(b6_stack)
print(b6_tens.shape)

# Summary statistics (min, mean, max) of the band 6 tensor
print(f'Min =  {b6_tens.min()}')
print(f'Mean = {b6_tens.mean()}')
print(f'Max =  {b6_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.038534000515937805
Mean = 0.17727355659008026
Max =  0.5120914578437805
Code
# Plot the first band 6 layer only, as a quick visual check
for layer_idx in range(1):
    plt.imshow(b6_tens[layer_idx].numpy())
    plt.colorbar()
    plt.show()

Band 7

Code
# Read every layer of the band 7 raster into a single numpy array
with rasterio.open(b7_path) as b7:
    # rasterio numbers layers from 1, so request indexes 1..count inclusive
    b7_stack = b7.read(list(range(1, b7.count + 1)))
Code
# Shape of the band 7 array as read from the raster
print(b7_stack.shape)

# Recode missing pixels (NaN) to -1
b7_stack = np.nan_to_num(b7_stack, nan=-1.0)

# Wrap the numpy array as a PyTorch tensor
b7_tens = torch.from_numpy(b7_stack)
print(b7_tens.shape)

# Summary statistics (min, mean, max) of the band 7 tensor
print(f'Min =  {b7_tens.min()}')
print(f'Mean = {b7_tens.mean()}')
print(f'Max =  {b7_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.04165744036436081
Mean = 0.19983430206775665
Max =  0.6045699119567871
Code
# Plot the first band 7 layer only, as a quick visual check
for layer_idx in range(1):
    plt.imshow(b7_tens[layer_idx].numpy())
    plt.colorbar()
    plt.show()

Band 8

Code
# Read every layer of the band 8 raster into a single numpy array
with rasterio.open(b8_path) as b8:
    # rasterio numbers layers from 1, so request indexes 1..count inclusive
    b8_stack = b8.read(list(range(1, b8.count + 1)))
Code
# Shape of the band 8 array as read from the raster
print(b8_stack.shape)

# Recode missing pixels (NaN) to -1
b8_stack = np.nan_to_num(b8_stack, nan=-1.0)

# Wrap the numpy array as a PyTorch tensor
b8_tens = torch.from_numpy(b8_stack)
print(b8_tens.shape)

# Summary statistics (min, mean, max) of the band 8 tensor
print(f'Min =  {b8_tens.min()}')
print(f'Mean = {b8_tens.mean()}')
print(f'Max =  {b8_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.03680320084095001
Mean = 0.2095790058374405
Max =  0.6004582643508911
Code
# Plot the first band 8 layer only, as a quick visual check
for layer_idx in range(1):
    plt.imshow(b8_tens[layer_idx].numpy())
    plt.colorbar()
    plt.show()

Band 8a

Code
# Read every layer of the band 8a raster into a single numpy array
with rasterio.open(b8a_path) as b8a:
    # rasterio numbers layers from 1, so request indexes 1..count inclusive
    b8a_stack = b8a.read(list(range(1, b8a.count + 1)))
Code
# Shape of the band 8a array as read from the raster
print(b8a_stack.shape)

# Recode missing pixels (NaN) to -1
b8a_stack = np.nan_to_num(b8a_stack, nan=-1.0)

# Wrap the numpy array as a PyTorch tensor
b8a_tens = torch.from_numpy(b8a_stack)
print(b8a_tens.shape)

# Summary statistics (min, mean, max) of the band 8a tensor
print(f'Min =  {b8a_tens.min()}')
print(f'Mean = {b8a_tens.mean()}')
print(f'Max =  {b8a_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.03570704162120819
Mean = 0.22782425582408905
Max =  0.6218413710594177
Code
# Plot the first band 8a layer only, as a quick visual check
for layer_idx in range(1):
    plt.imshow(b8a_tens[layer_idx].numpy())
    plt.colorbar()
    plt.show()

Band 9

Code
# Read every layer of the band 9 raster into a single numpy array
with rasterio.open(b9_path) as b9:
    # rasterio numbers layers from 1, so request indexes 1..count inclusive
    b9_stack = b9.read(list(range(1, b9.count + 1)))
Code
# Shape of the band 9 array as read from the raster
print(b9_stack.shape)

# Recode missing pixels (NaN) to -1
b9_stack = np.nan_to_num(b9_stack, nan=-1.0)

# Wrap the numpy array as a PyTorch tensor
b9_tens = torch.from_numpy(b9_stack)
print(b9_tens.shape)

# Summary statistics (min, mean, max) of the band 9 tensor
print(f'Min =  {b9_tens.min()}')
print(f'Mean = {b9_tens.mean()}')
print(f'Max =  {b9_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.012299999594688416
Mean = 0.22701694071292877
Max =  0.5680500268936157
Code
# Plot the first band 9 layer only, as a quick visual check
for layer_idx in range(1):
    plt.imshow(b9_tens[layer_idx].numpy())
    plt.colorbar()
    plt.show()

Band 11

Code
# Read every layer of the band 11 raster into a single numpy array
with rasterio.open(b11_path) as b11:
    # rasterio numbers layers from 1, so request indexes 1..count inclusive
    b11_stack = b11.read(list(range(1, b11.count + 1)))
Code
# Shape of the band 11 array as read from the raster
print(b11_stack.shape)

# Recode missing pixels (NaN) to -1
b11_stack = np.nan_to_num(b11_stack, nan=-1.0)

# Wrap the numpy array as a PyTorch tensor
b11_tens = torch.from_numpy(b11_stack)
print(b11_tens.shape)

# Summary statistics (min, mean, max) of the band 11 tensor
print(f'Min =  {b11_tens.min()}')
print(f'Mean = {b11_tens.mean()}')
print(f'Max =  {b11_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.01741199940443039
Mean = 0.27866700291633606
Max =  0.657039225101471
Code
# Plot the first band 11 layer only, as a quick visual check
for layer_idx in range(1):
    plt.imshow(b11_tens[layer_idx].numpy())
    plt.colorbar()
    plt.show()

Band 12

Code
# Read every layer of the band 12 raster into a single numpy array
with rasterio.open(b12_path) as b12:
    # rasterio numbers layers from 1, so request indexes 1..count inclusive.
    # The layer count must come from this file's own handle (b12); the band 1
    # handle (b1) is already closed and belongs to a different raster.
    b12_stack = b12.read(list(range(1, b12.count + 1)))
Code
# Shape of the band 12 array as read from the raster
print(b12_stack.shape)

# Recode missing pixels (NaN) to -1
b12_stack = np.nan_to_num(b12_stack, nan=-1.0)

# Wrap the numpy array as a PyTorch tensor
b12_tens = torch.from_numpy(b12_stack)
print(b12_tens.shape)

# Summary statistics (min, mean, max) of the band 12 tensor
print(f'Min =  {b12_tens.min()}')
print(f'Mean = {b12_tens.mean()}')
print(f'Max =  {b12_tens.max()}')
(10103, 101, 101)
torch.Size([10103, 101, 101])
Min =  0.012337599880993366
Mean = 0.19245100021362305
Max =  0.5119996666908264
Code
# Plot the first band 12 layer only, as a quick visual check
for layer_idx in range(1):
    plt.imshow(b12_tens[layer_idx].numpy())
    plt.colorbar()
    plt.show()

View as RGB

Given the Red (B4), Green (B3), and Blue (B2) bands, we can create an RGB image.

Code
# Stack the red (B4), green (B3) and blue (B2) tensors along a new last axis
rgb_image = torch.stack([b4_tens, b3_tens, b2_tens], dim=-1)

# Take the first sample and convert it to a NumPy array
rgb_image_np = rgb_image[0].cpu().numpy()

# Rescale to the range [0, 1] for display (min-max over all three channels)
rgb_min = rgb_image_np.min()
rgb_max = rgb_image_np.max()
rgb_image_np = (rgb_image_np - rgb_min) / (rgb_max - rgb_min)

# Show the image
plt.imshow(rgb_image_np)
plt.title('Sentinel-2 RGB Image')
plt.show()

Presence records - target of model

The target is what we are trying to predict with the deepSSF model, which is the location of the observed next step.

Code
# Read every layer of the presence raster into a single numpy array
with rasterio.open(pres_path) as pres:
    # rasterio numbers layers from 1, so request indexes 1..count inclusive
    pres_stack = pres.read(list(range(1, pres.count + 1)))

# Report the array's shape and type
print(pres_stack.shape)
print(type(pres_stack))
(10103, 101, 101)
<class 'numpy.ndarray'>
Code
# Plot the first presence layer only, as a quick visual check
for layer_idx in range(1):
    plt.imshow(pres_stack[layer_idx])
    plt.show()

Code
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device")
Using cpu device

Combine the spatial layers into channels

Code
# Layers to use as channels: the 12 Sentinel-2 bands followed by slope
channel_layers = [
    b1_tens, b2_tens, b3_tens, b4_tens,
    b5_tens, b6_tens, b7_tens, b8_tens,
    b8a_tens, b9_tens, b11_tens, b12_tens,
    slope_tens,
]

# Stack along a new dimension 1 so each sample carries one channel per layer
combined_stack = torch.stack(channel_layers, dim=1)

print(combined_stack.shape)
torch.Size([10103, 13, 101, 101])

Defining data sets and data loaders

Creating a dataset class

This custom PyTorch Dataset organizes all your input (spatial data, scalar covariates, bearing, and target) in a single object, allowing you to neatly manage how samples are accessed. The __init__ method prepares and stores all the data, __len__ returns the total number of samples, and __getitem__ retrieves a single sample by index—enabling straightforward batching and iteration when used with a DataLoader.

Code
class buffalo_data(Dataset):
    """
    Dataset bundling the spatial inputs, scalar covariates, previous bearing
    and target layer for each sample.

    The constructor takes no arguments: it reads the module-level objects
    `combined_stack`, `buffalo_df` and `pres_stack`, which must already exist.
    """

    def __init__(self):
        # Columns pulled from the DataFrame
        scalar_columns = ['hour_t2_sin', 'hour_t2_cos', 'yday_t2_sin', 'yday_t2_cos']
        bearing_columns = ['bearing_tm1']

        # Spatial covariates (already stacked into channels)
        self.spatial_data_x = combined_stack

        # Scalar data that will be converted to grid data and added to the
        # spatial covariates for the CNN components
        scalar_values = buffalo_df[scalar_columns].values
        self.scalar_to_grid_data = torch.from_numpy(scalar_values).float()

        # Bearing of the previous step, kept as a single-column 2D array
        bearing_values = buffalo_df[bearing_columns].values
        self.bearing_x = torch.from_numpy(bearing_values).float()

        # Target data
        self.target = torch.tensor(pres_stack)

        # Number of samples = size of the first dimension of the spatial data
        self.n_samples = self.spatial_data_x.shape[0]

    def __len__(self):
        """Return the number of samples, so that len(dataset) works."""
        return self.n_samples

    def __getitem__(self, index):
        """Return (spatial data, scalar-to-grid data, bearing, target) at `index`."""
        spatial = self.spatial_data_x[index]
        scalars = self.scalar_to_grid_data[index]
        bearing = self.bearing_x[index]
        target = self.target[index]
        return spatial, scalars, bearing, target

Now we can create an instance of the dataset class and check that it is working as expected.

Code
# Build the dataset from the objects prepared above
dataset = buffalo_data()

# Total number of samples held by the dataset
print(dataset.n_samples)

# Slicing with [:] calls __getitem__ with a slice covering every index and
# returns the tuple (spatial data, scalar-to-grid data, bearing data, target labels)
features1, features2, features3, labels = dataset[:]

# Print the shape of each returned tensor, in order:
# spatial data, scalar-to-grid data, bearing data, target labels
for returned_tensor in (features1, features2, features3, labels):
    print(returned_tensor.shape)
10103
torch.Size([10103, 13, 101, 101])
torch.Size([10103, 4])
torch.Size([10103, 1])
torch.Size([10103, 101, 101])

Split into training, validation and test sets

Code
# Fractions of the data assigned to each subset
training_split = 0.8    # used for training
validation_split = 0.1  # used for validation (deciding when to stop training)
test_split = 0.1        # used for testing (model evaluation)

# Randomly partition the dataset according to those fractions
split_fractions = [training_split, validation_split, test_split]
dataset_train, dataset_val, dataset_test = torch.utils.data.random_split(
    dataset, split_fractions
)

print("Number of training samples: ", len(dataset_train))
print("Number of validation samples: ", len(dataset_val))
print("Number of testing samples: ", len(dataset_test))
Number of training samples:  8083
Number of validation samples:  1010
Number of testing samples:  1010

Create dataloaders

The DataLoader in PyTorch wraps an iterable around the Dataset to enable easy access to the samples.

Code
# Define the batch size for how many samples to process at once in each step:
bs = 32

# Training DataLoader: batches of size bs, reshuffled every epoch so that
# the model doesn't see the data in the same order each time.
dataloader_train = DataLoader(dataset=dataset_train,
                              batch_size=bs,
                              shuffle=True)

# Validation DataLoader: batches of size bs, with shuffling.
# Shuffling validation data is not mandatory; the setting is kept the same
# as for training.
dataloader_val = DataLoader(dataset=dataset_val,
                            batch_size=bs,
                            shuffle=True)

# Test DataLoader: batches of size bs, NOT shuffled.
# We want to index the testing data for plotting, so the batches must come
# out in the same order every time the loader is iterated.
dataloader_test = DataLoader(dataset=dataset_test,
                             batch_size=bs,
                             shuffle=False)

Check that the data loader is working as expected.

Code
# Pull one batch from the training DataLoader and report its shapes.
# iter() creates a fresh iterator; next() returns its first batch.
train_batches = iter(dataloader_train)
first_batch = next(train_batches)
features1, features2, features3, labels = first_batch

print(f"Feature 1 batch shape: {features1.size()}")
print(f"Feature 2 batch shape: {features2.size()}")
print(f"Feature 3 batch shape: {features3.size()}")
print(f"Labels batch shape: {labels.size()}")
Feature 1 batch shape: torch.Size([32, 13, 101, 101])
Feature 2 batch shape: torch.Size([32, 4])
Feature 3 batch shape: torch.Size([32, 1])
Labels batch shape: torch.Size([32, 101, 101])

Load the model

As we have already described the model in detail in the deepSSF_model script, we can simply import the model here.

We will use the same model architecture as in the previous script, except that we will need to use a slightly edited dictionary to account for the additional input channels.

Code
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device")
Using cpu device

Define the parameters for the model

Here we enter the specific parameter values and hyperparameters for the model. These are the values that will be used to instantiate the model.

Code
# Number of spatial layers: the 12 Sentinel-2 layers + slope
num_spatial_covs = 13

params_dict = {
    "batch_size": 32,
    # number of pixels along the edge of each local patch/image
    "image_dim": 101,
    # number of metres along the edge of a pixel
    "pixel_size": 25,
    # number of scalar predictors that are converted to a grid and appended
    # to the spatial features
    "dim_in_nonspatial_to_grid": 4,
    # change this to however many other scalar predictors you have
    # (bearing, velocity etc)
    "dense_dim_in_nonspatial": 4,
    # number of nodes in the hidden layers
    "dense_dim_hidden": 128,
    # number of nodes in the output of the fully connected block (FCN)
    "dense_dim_out": 128,
    # number of inputs entering the fully connected block once the
    # nonspatial features have been concatenated to the spatial features
    "dense_dim_in_all": 2500,
    # number of spatial layers in each image + number of scalar layers that
    # are converted to a grid
    "input_channels": num_spatial_covs + 4,
    # number of filters to learn
    "output_channels": 4,
    # size of the 2D moving windows / kernels that are being learned
    "kernel_size": 3,
    # stride used when applying the kernel; reduces the dimension of the
    # output if set to greater than 1
    "stride": 1,
    # size of the kernel used in max pooling operations
    "kernel_size_mp": 2,
    # stride used in max pooling operations
    "stride_mp": 2,
    # amount of padding applied to images prior to the 2D convolution
    "padding": 1,
    # number of parameters used to parameterise the movement kernel
    "num_movement_params": 12,
    "dropout": 0.1,
    "device": device,
}

Instantiate the model

As described in the deepSSF_train.ipynb script, we saved the model definition into a file named deepSSF_model.py. We can instantiate the model by importing that file (done above with the other packages) and passing the parameter dictionary to the classes it defines.

Code
# Wrap the parameter dictionary in the model's parameter class
params = deepSSF_model.ModelParams(params_dict)

# Build the model, then move it to the selected device
model = deepSSF_model.ConvJointModel(params)
model = model.to(device)

print(model)
ConvJointModel(
  (scalar_grid_output): Scalar_to_Grid_Block()
  (conv_habitat): Conv2d_block_spatial(
    (conv2d): Sequential(
      (0): Conv2d(17, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU()
      (4): Conv2d(4, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (conv_movement): Conv2d_block_toFC(
    (conv2d): Sequential(
      (0): Conv2d(17, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU()
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Flatten(start_dim=1, end_dim=-1)
    )
  )
  (fcn_movement_all): FCN_block_all_movement(
    (ffn): Sequential(
      (0): Linear(in_features=2500, out_features=128, bias=True)
      (1): Dropout(p=0.1, inplace=False)
      (2): ReLU()
      (3): Linear(in_features=128, out_features=128, bias=True)
      (4): Dropout(p=0.1, inplace=False)
      (5): ReLU()
      (6): Linear(in_features=128, out_features=12, bias=True)
    )
  )
  (movement_grid_output): Params_to_Grid_Block()
)

Set model hyperparameters

Set the learning rate, loss function, optimizer, scheduler and early stopping.

Code
learning_rate = 1e-3

# Define the negative log-likelihood loss function with mean reduction
loss_fn = deepSSF_loss.negativeLogLikeLoss(reduction='mean')

# Path to save the model weights
path_save_weights = f'model_checkpoints/deepSSF_S2_slope_buffalo{buffalo_id}_{today_date}.pt'

# Create the checkpoint directory up front so that writing to
# `path_save_weights` cannot fail merely because the folder is missing
os.makedirs(os.path.dirname(path_save_weights), exist_ok=True)

# Set up the Adam optimizer for updating the model's parameters
optimiser = optim.Adam(model.parameters(), lr=learning_rate)

# Create a learning rate scheduler that reduces the LR by a factor of 0.1
#    if validation loss has not improved for 'patience=5' epochs
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimiser,  # The optimizer whose learning rate will be adjusted
    mode='min', # The metric to be minimized (e.g., validation loss)
    factor=0.1, # Factor by which the learning rate will be reduced
    patience=5  # Number of epochs with no improvement before learning rate reduces
)

# EarlyStopping stops training after 'patience=20' epochs with no improvement,
#    optionally saving the best model weights
early_stopping = deepSSF_early_stopping.EarlyStopping(patience=20, verbose=True, path=path_save_weights)

Training loop

This code defines the main training loop for a single epoch. It iterates over batches from the training dataloader, moves the data to the correct device (e.g., CPU or GPU), calculates the loss, and performs backpropagation to update the model parameters. It also prints periodic updates of the current loss.

Code
def train_loop(dataloader_train, model, loss_fn, optimiser):
    """
    Train the model for one epoch and print periodic progress updates.

    Args:
        dataloader_train: Iterable yielding (x1, x2, x3, y) batches. Must also
            expose a sized `.dataset`, used only for the progress display.
        model: Module called as `model((x1, x2, x3))`.
        loss_fn: Callable taking `(model_output, y)` and returning a scalar
            loss tensor.
        optimiser: Optimiser that updates the model's parameters.

    Returns:
        float: Mean of the per-batch losses over the epoch, or NaN when the
        dataloader yields no batches.

    Note:
        Batches are moved to the module-level `device` defined elsewhere in
        this file.
    """

    # 1. Total number of training examples (for the progress display only)
    size = len(dataloader_train.dataset)

    # 2. Put model in training mode (affects layers like dropout, batchnorm)
    model.train()

    # 3. Running totals for the epoch
    epoch_loss = 0.0    # Sum of detached per-batch losses
    samples_seen = 0    # Examples processed so far
    num_batches = 0     # Batches processed so far

    # 4. Loop over batches in the training dataloader
    for batch, (x1, x2, x3, y) in enumerate(dataloader_train):

        # Move the batch of data to the specified device (CPU/GPU)
        x1 = x1.to(device)
        x2 = x2.to(device)
        x3 = x3.to(device)
        y = y.to(device)

        # Forward pass: compute the model output and loss
        loss = loss_fn(model((x1, x2, x3)), y)

        # Backpropagation: compute gradients and update parameters
        loss.backward()
        optimiser.step()

        # Reset gradients before the next iteration
        optimiser.zero_grad()

        # Accumulate the *detached* loss so the running total does not carry
        # autograd history from every batch of the epoch
        epoch_loss += loss.detach()
        # Count samples directly rather than assuming a global batch size
        samples_seen += len(x1)
        num_batches += 1

        # Print an update every 5 batches to keep track of training progress
        if batch % 5 == 0:
            loss_val = loss.item()
            print(f"loss: {loss_val:>15f}  [{samples_seen:>5d}/{size:>5d}]")

    # 5. Mean per-batch loss for the epoch (NaN if there was nothing to train on)
    if num_batches == 0:
        return float("nan")
    return float(epoch_loss) / num_batches

Test loop

The test loop is similar to the training loop, but it does not perform backpropagation. It computes the loss for each batch of the test set, averages it over the number of batches, and prints the average test loss.

Code
def test_loop(dataloader_test, model, loss_fn):
    """
    Evaluates the model on the provided test dataset by computing 
    the average loss over all batches. 
    No gradients are computed during this process (torch.no_grad()).
    """

    # 1. Set the model to evaluation mode (affects layers like dropout, batchnorm).
    model.eval()

    size = len(dataloader_test.dataset)
    num_batches = len(dataloader_test)
    
    test_loss = 0

    # 2. Disable gradient computation to speed up evaluation and reduce memory usage
    with torch.no_grad():
        # 3. Loop through each batch in the test dataloader
        for x1, x2, x3, y in dataloader_test:

            # Move the batch of data to the appropriate device (CPU/GPU)
            x1 = x1.to(device)
            x2 = x2.to(device)
            x3 = x3.to(device)
            y = y.to(device)

            # Compute the loss on the test set (no backward pass needed)
            test_loss += loss_fn(model((x1, x2, x3)), y)

    # 4. Compute average test loss over all batches
    test_loss /= num_batches

    # Print the average test loss
    print(f"Avg test loss: {test_loss:>15f} \n")

Train the model

Code
# Maximum number of epochs; early stopping may end training sooner
epochs = 100
val_losses = []   # Track the average validation loss of each epoch

# Number of validation batches, used to average the summed validation loss.
# This must come from the validation dataloader (the one iterated below),
# not the test dataloader, or the average is wrong whenever they differ in size.
num_val_batches = len(dataloader_val)

for t in range(epochs):
    val_loss = 0.0

    print(f"Epoch {t+1}\n-------------------------------")

    # 1. Run the training loop for one epoch using the training dataloader
    train_loop(dataloader_train, model, loss_fn, optimiser)

    # 2. Evaluate model performance on the validation dataset
    model.eval()  # Switch to evaluation mode for proper layer behavior
    with torch.no_grad():

        for x1, x2, x3, y in dataloader_val:
            # Move data to the chosen device (CPU/GPU)
            x1 = x1.to(device)
            x2 = x2.to(device)
            x3 = x3.to(device)
            y = y.to(device)

            # Accumulate validation loss
            val_loss += loss_fn(model((x1, x2, x3)), y)

    # 3. Compute the average validation loss over the validation batches
    val_loss /= num_val_batches

    # 4. Step the scheduler on the average validation loss (adjusts learning
    #    rate if needed), then report the loss and the current learning rate
    scheduler.step(val_loss)
    print(f"\nAvg validation loss: {val_loss:>15f}")
    print(f"Learning rate: {scheduler.get_last_lr()}")

    # 5. Track the validation loss for plotting or monitoring
    val_losses.append(val_loss)

    # 6. Early stopping: if no improvement in validation loss for a set patience, stop training
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        # Load the weights stored at the early-stopping checkpoint path
        model.load_state_dict(torch.load(path_save_weights, weights_only=True))
        test_loop(dataloader_test, model, loss_fn)  # Evaluate on test set once training stops
        break
    else:
        model.eval()
        print("\n")

print("Done!")
Epoch 1
-------------------------------
loss:        0.000820  [   32/ 8083]
loss:        0.000711  [  192/ 8083]
loss:        0.000606  [  352/ 8083]
loss:        0.000546  [  512/ 8083]
loss:        0.000617  [  672/ 8083]
loss:        0.000587  [  832/ 8083]
loss:        0.000603  [  992/ 8083]
loss:        0.000676  [ 1152/ 8083]
loss:        0.000518  [ 1312/ 8083]
loss:        0.000564  [ 1472/ 8083]
loss:        0.000654  [ 1632/ 8083]
loss:        0.000506  [ 1792/ 8083]
loss:        0.000603  [ 1952/ 8083]
loss:        0.000650  [ 2112/ 8083]
loss:        0.000649  [ 2272/ 8083]
loss:        0.000556  [ 2432/ 8083]
loss:        0.000539  [ 2592/ 8083]
loss:        0.000555  [ 2752/ 8083]
loss:        0.000558  [ 2912/ 8083]
loss:        0.000582  [ 3072/ 8083]
loss:        0.000570  [ 3232/ 8083]
loss:        0.000524  [ 3392/ 8083]
loss:        0.000517  [ 3552/ 8083]
loss:        0.000587  [ 3712/ 8083]
loss:        0.000572  [ 3872/ 8083]
loss:        0.000633  [ 4032/ 8083]
loss:        0.000585  [ 4192/ 8083]
loss:        0.000583  [ 4352/ 8083]
loss:        0.000514  [ 4512/ 8083]
loss:        0.000503  [ 4672/ 8083]
loss:        0.000611  [ 4832/ 8083]
loss:        0.000568  [ 4992/ 8083]
loss:        0.000574  [ 5152/ 8083]
loss:        0.000620  [ 5312/ 8083]
loss:        0.000561  [ 5472/ 8083]
loss:        0.000537  [ 5632/ 8083]
loss:        0.000562  [ 5792/ 8083]
loss:        0.000643  [ 5952/ 8083]
loss:        0.000585  [ 6112/ 8083]
loss:        0.000611  [ 6272/ 8083]
loss:        0.000564  [ 6432/ 8083]
loss:        0.000629  [ 6592/ 8083]
loss:        0.000493  [ 6752/ 8083]
loss:        0.000522  [ 6912/ 8083]
loss:        0.000711  [ 7072/ 8083]
loss:        0.000575  [ 7232/ 8083]
loss:        0.000603  [ 7392/ 8083]
loss:        0.000582  [ 7552/ 8083]
loss:        0.000562  [ 7712/ 8083]
loss:        0.000586  [ 7872/ 8083]
loss:        0.000644  [ 8032/ 8083]

Avg validation loss:        0.000532
Learning rate: [0.001]
Validation loss decreased (inf --> 0.000532).  Saving model ...


Epoch 2
-------------------------------
loss:        0.000521  [   32/ 8083]
loss:        0.000573  [  192/ 8083]
loss:        0.000459  [  352/ 8083]
loss:        0.000565  [  512/ 8083]
loss:        0.000557  [  672/ 8083]
loss:        0.000533  [  832/ 8083]
loss:        0.000599  [  992/ 8083]
loss:        0.000614  [ 1152/ 8083]
loss:        0.000453  [ 1312/ 8083]
loss:        0.000692  [ 1472/ 8083]
loss:        0.000511  [ 1632/ 8083]
loss:        0.000563  [ 1792/ 8083]
loss:        0.000493  [ 1952/ 8083]
loss:        0.000593  [ 2112/ 8083]
loss:        0.000544  [ 2272/ 8083]
loss:        0.000613  [ 2432/ 8083]
loss:        0.000584  [ 2592/ 8083]
loss:        0.000631  [ 2752/ 8083]
loss:        0.000551  [ 2912/ 8083]
loss:        0.000578  [ 3072/ 8083]
loss:        0.000587  [ 3232/ 8083]
loss:        0.000622  [ 3392/ 8083]
loss:        0.000550  [ 3552/ 8083]
loss:        0.000534  [ 3712/ 8083]
loss:        0.000548  [ 3872/ 8083]
loss:        0.000567  [ 4032/ 8083]
loss:        0.000624  [ 4192/ 8083]
loss:        0.000559  [ 4352/ 8083]
loss:        0.000455  [ 4512/ 8083]
loss:        0.000612  [ 4672/ 8083]
loss:        0.000512  [ 4832/ 8083]
loss:        0.000585  [ 4992/ 8083]
loss:        0.000578  [ 5152/ 8083]
loss:        0.000428  [ 5312/ 8083]
loss:        0.000475  [ 5472/ 8083]
loss:        0.000504  [ 5632/ 8083]
loss:        0.000505  [ 5792/ 8083]
loss:        0.000521  [ 5952/ 8083]
loss:        0.000558  [ 6112/ 8083]
loss:        0.000407  [ 6272/ 8083]
loss:        0.000507  [ 6432/ 8083]
loss:        0.000511  [ 6592/ 8083]
loss:        0.000455  [ 6752/ 8083]
loss:        0.000535  [ 6912/ 8083]
loss:        0.000610  [ 7072/ 8083]
loss:        0.000635  [ 7232/ 8083]
loss:        0.000546  [ 7392/ 8083]
loss:        0.000557  [ 7552/ 8083]
loss:        0.000608  [ 7712/ 8083]
loss:        0.000525  [ 7872/ 8083]
loss:        0.000421  [ 8032/ 8083]

Avg validation loss:        0.000518
Learning rate: [0.001]
Validation loss decreased (0.000532 --> 0.000518).  Saving model ...


Epoch 3
-------------------------------
loss:        0.000410  [   32/ 8083]
loss:        0.000544  [  192/ 8083]
loss:        0.000506  [  352/ 8083]
loss:        0.000443  [  512/ 8083]
loss:        0.000638  [  672/ 8083]
loss:        0.000502  [  832/ 8083]
loss:        0.000523  [  992/ 8083]
loss:        0.000594  [ 1152/ 8083]
loss:        0.000546  [ 1312/ 8083]
loss:        0.000494  [ 1472/ 8083]
loss:        0.000503  [ 1632/ 8083]
loss:        0.000637  [ 1792/ 8083]
loss:        0.000501  [ 1952/ 8083]
loss:        0.000593  [ 2112/ 8083]
loss:        0.000488  [ 2272/ 8083]
loss:        0.000481  [ 2432/ 8083]
loss:        0.000490  [ 2592/ 8083]
loss:        0.000448  [ 2752/ 8083]
loss:        0.000625  [ 2912/ 8083]
loss:        0.000573  [ 3072/ 8083]
loss:        0.000558  [ 3232/ 8083]
loss:        0.000521  [ 3392/ 8083]
loss:        0.000545  [ 3552/ 8083]
loss:        0.000495  [ 3712/ 8083]
loss:        0.000457  [ 3872/ 8083]
loss:        0.000540  [ 4032/ 8083]
loss:        0.000563  [ 4192/ 8083]
loss:        0.000546  [ 4352/ 8083]
loss:        0.000575  [ 4512/ 8083]
loss:        0.000636  [ 4672/ 8083]
loss:        0.000487  [ 4832/ 8083]
loss:        0.000565  [ 4992/ 8083]
loss:        0.000503  [ 5152/ 8083]
loss:        0.000584  [ 5312/ 8083]
loss:        0.000559  [ 5472/ 8083]
loss:        0.000600  [ 5632/ 8083]
loss:        0.000544  [ 5792/ 8083]
loss:        0.000440  [ 5952/ 8083]
loss:        0.000484  [ 6112/ 8083]
loss:        0.000510  [ 6272/ 8083]
loss:        0.000542  [ 6432/ 8083]
loss:        0.000605  [ 6592/ 8083]
loss:        0.000567  [ 6752/ 8083]
loss:        0.000632  [ 6912/ 8083]
loss:        0.000458  [ 7072/ 8083]
loss:        0.000479  [ 7232/ 8083]
loss:        0.000486  [ 7392/ 8083]
loss:        0.000505  [ 7552/ 8083]
loss:        0.000493  [ 7712/ 8083]
loss:        0.000622  [ 7872/ 8083]
loss:        0.000546  [ 8032/ 8083]

Avg validation loss:        0.000513
Learning rate: [0.001]
Validation loss decreased (0.000518 --> 0.000513).  Saving model ...


Epoch 4
-------------------------------
loss:        0.000541  [   32/ 8083]
loss:        0.000519  [  192/ 8083]
loss:        0.000566  [  352/ 8083]
loss:        0.000441  [  512/ 8083]
loss:        0.000581  [  672/ 8083]
loss:        0.000480  [  832/ 8083]
loss:        0.000427  [  992/ 8083]
loss:        0.000628  [ 1152/ 8083]
loss:        0.000560  [ 1312/ 8083]
loss:        0.000594  [ 1472/ 8083]
loss:        0.000682  [ 1632/ 8083]
loss:        0.000565  [ 1792/ 8083]
loss:        0.000538  [ 1952/ 8083]
loss:        0.000448  [ 2112/ 8083]
loss:        0.000546  [ 2272/ 8083]
loss:        0.000532  [ 2432/ 8083]
loss:        0.000628  [ 2592/ 8083]
loss:        0.000463  [ 2752/ 8083]
loss:        0.000512  [ 2912/ 8083]
loss:        0.000582  [ 3072/ 8083]
loss:        0.000668  [ 3232/ 8083]
loss:        0.000533  [ 3392/ 8083]
loss:        0.000573  [ 3552/ 8083]
loss:        0.000443  [ 3712/ 8083]
loss:        0.000578  [ 3872/ 8083]
loss:        0.000444  [ 4032/ 8083]
loss:        0.000583  [ 4192/ 8083]
loss:        0.000570  [ 4352/ 8083]
loss:        0.000586  [ 4512/ 8083]
loss:        0.000544  [ 4672/ 8083]
loss:        0.000491  [ 4832/ 8083]
loss:        0.000487  [ 4992/ 8083]
loss:        0.000468  [ 5152/ 8083]
loss:        0.000574  [ 5312/ 8083]
loss:        0.000534  [ 5472/ 8083]
loss:        0.000675  [ 5632/ 8083]
loss:        0.000531  [ 5792/ 8083]
loss:        0.000568  [ 5952/ 8083]
loss:        0.000504  [ 6112/ 8083]
loss:        0.000632  [ 6272/ 8083]
loss:        0.000543  [ 6432/ 8083]
loss:        0.000546  [ 6592/ 8083]
loss:        0.000558  [ 6752/ 8083]
loss:        0.000598  [ 6912/ 8083]
loss:        0.000499  [ 7072/ 8083]
loss:        0.000459  [ 7232/ 8083]
loss:        0.000577  [ 7392/ 8083]
loss:        0.000560  [ 7552/ 8083]
loss:        0.000508  [ 7712/ 8083]
loss:        0.000427  [ 7872/ 8083]
loss:        0.000488  [ 8032/ 8083]

Avg validation loss:        0.000513
Learning rate: [0.001]
EarlyStopping counter: 1 out of 20


Epoch 5
-------------------------------
loss:        0.000481  [   32/ 8083]
loss:        0.000549  [  192/ 8083]
loss:        0.000586  [  352/ 8083]
loss:        0.000581  [  512/ 8083]
loss:        0.000515  [  672/ 8083]
loss:        0.000596  [  832/ 8083]
loss:        0.000516  [  992/ 8083]
loss:        0.000563  [ 1152/ 8083]
loss:        0.000478  [ 1312/ 8083]
loss:        0.000605  [ 1472/ 8083]
loss:        0.000457  [ 1632/ 8083]
loss:        0.000446  [ 1792/ 8083]
loss:        0.000484  [ 1952/ 8083]
loss:        0.000551  [ 2112/ 8083]
loss:        0.000584  [ 2272/ 8083]
loss:        0.000656  [ 2432/ 8083]
loss:        0.000468  [ 2592/ 8083]
loss:        0.000508  [ 2752/ 8083]
loss:        0.000586  [ 2912/ 8083]
loss:        0.000547  [ 3072/ 8083]
loss:        0.000507  [ 3232/ 8083]
loss:        0.000485  [ 3392/ 8083]
loss:        0.000563  [ 3552/ 8083]
loss:        0.000600  [ 3712/ 8083]
loss:        0.000576  [ 3872/ 8083]
loss:        0.000523  [ 4032/ 8083]
loss:        0.000450  [ 4192/ 8083]
loss:        0.000614  [ 4352/ 8083]
loss:        0.000490  [ 4512/ 8083]
loss:        0.000519  [ 4672/ 8083]
loss:        0.000545  [ 4832/ 8083]
loss:        0.000634  [ 4992/ 8083]
loss:        0.000494  [ 5152/ 8083]
loss:        0.000606  [ 5312/ 8083]
loss:        0.000474  [ 5472/ 8083]
loss:        0.000438  [ 5632/ 8083]
loss:        0.000545  [ 5792/ 8083]
loss:        0.000537  [ 5952/ 8083]
loss:        0.000499  [ 6112/ 8083]
loss:        0.000483  [ 6272/ 8083]
loss:        0.000597  [ 6432/ 8083]
loss:        0.000519  [ 6592/ 8083]
loss:        0.000511  [ 6752/ 8083]
loss:        0.000542  [ 6912/ 8083]
loss:        0.000586  [ 7072/ 8083]
loss:        0.000462  [ 7232/ 8083]
loss:        0.000389  [ 7392/ 8083]
loss:        0.000559  [ 7552/ 8083]
loss:        0.000524  [ 7712/ 8083]
loss:        0.000610  [ 7872/ 8083]
loss:        0.000511  [ 8032/ 8083]

Avg validation loss:        0.000513
Learning rate: [0.001]
Validation loss decreased (0.000513 --> 0.000513).  Saving model ...


Epoch 6
-------------------------------
loss:        0.000497  [   32/ 8083]
loss:        0.000540  [  192/ 8083]
loss:        0.000621  [  352/ 8083]
loss:        0.000466  [  512/ 8083]
loss:        0.000472  [  672/ 8083]
loss:        0.000605  [  832/ 8083]
loss:        0.000526  [  992/ 8083]
loss:        0.000560  [ 1152/ 8083]
loss:        0.000524  [ 1312/ 8083]
loss:        0.000526  [ 1472/ 8083]
loss:        0.000532  [ 1632/ 8083]
loss:        0.000481  [ 1792/ 8083]
loss:        0.000587  [ 1952/ 8083]
loss:        0.000498  [ 2112/ 8083]
loss:        0.000565  [ 2272/ 8083]
loss:        0.000623  [ 2432/ 8083]
loss:        0.000633  [ 2592/ 8083]
loss:        0.000510  [ 2752/ 8083]
loss:        0.000431  [ 2912/ 8083]
loss:        0.000477  [ 3072/ 8083]
loss:        0.000480  [ 3232/ 8083]
loss:        0.000536  [ 3392/ 8083]
loss:        0.000466  [ 3552/ 8083]
loss:        0.000479  [ 3712/ 8083]
loss:        0.000618  [ 3872/ 8083]
loss:        0.000492  [ 4032/ 8083]
loss:        0.000547  [ 4192/ 8083]
loss:        0.000486  [ 4352/ 8083]
loss:        0.000506  [ 4512/ 8083]
loss:        0.000531  [ 4672/ 8083]
loss:        0.000588  [ 4832/ 8083]
loss:        0.000484  [ 4992/ 8083]
loss:        0.000494  [ 5152/ 8083]
loss:        0.000476  [ 5312/ 8083]
loss:        0.000587  [ 5472/ 8083]
loss:        0.000522  [ 5632/ 8083]
loss:        0.000487  [ 5792/ 8083]
loss:        0.000470  [ 5952/ 8083]
loss:        0.000558  [ 6112/ 8083]
loss:        0.000648  [ 6272/ 8083]
loss:        0.000480  [ 6432/ 8083]
loss:        0.000437  [ 6592/ 8083]
loss:        0.000498  [ 6752/ 8083]
loss:        0.000387  [ 6912/ 8083]
loss:        0.000553  [ 7072/ 8083]
loss:        0.000658  [ 7232/ 8083]
loss:        0.000516  [ 7392/ 8083]
loss:        0.000544  [ 7552/ 8083]
loss:        0.000566  [ 7712/ 8083]
loss:        0.000450  [ 7872/ 8083]
loss:        0.000638  [ 8032/ 8083]

Avg validation loss:        0.000516
Learning rate: [0.001]
EarlyStopping counter: 1 out of 20


Epoch 7
-------------------------------
loss:        0.000603  [   32/ 8083]
loss:        0.000587  [  192/ 8083]
loss:        0.000605  [  352/ 8083]
loss:        0.000608  [  512/ 8083]
loss:        0.000543  [  672/ 8083]
loss:        0.000513  [  832/ 8083]
loss:        0.000457  [  992/ 8083]
loss:        0.000534  [ 1152/ 8083]
loss:        0.000447  [ 1312/ 8083]
loss:        0.000555  [ 1472/ 8083]
loss:        0.000515  [ 1632/ 8083]
loss:        0.000543  [ 1792/ 8083]
loss:        0.000447  [ 1952/ 8083]
loss:        0.000639  [ 2112/ 8083]
loss:        0.000488  [ 2272/ 8083]
loss:        0.000585  [ 2432/ 8083]
loss:        0.000459  [ 2592/ 8083]
loss:        0.000611  [ 2752/ 8083]
loss:        0.000465  [ 2912/ 8083]
loss:        0.000480  [ 3072/ 8083]
loss:        0.000427  [ 3232/ 8083]
loss:        0.000563  [ 3392/ 8083]
loss:        0.000447  [ 3552/ 8083]
loss:        0.000591  [ 3712/ 8083]
loss:        0.000496  [ 3872/ 8083]
loss:        0.000528  [ 4032/ 8083]
loss:        0.000624  [ 4192/ 8083]
loss:        0.000562  [ 4352/ 8083]
loss:        0.000590  [ 4512/ 8083]
loss:        0.000409  [ 4672/ 8083]
loss:        0.000545  [ 4832/ 8083]
loss:        0.000394  [ 4992/ 8083]
loss:        0.000626  [ 5152/ 8083]
loss:        0.000420  [ 5312/ 8083]
loss:        0.000488  [ 5472/ 8083]
loss:        0.000456  [ 5632/ 8083]
loss:        0.000575  [ 5792/ 8083]
loss:        0.000553  [ 5952/ 8083]
loss:        0.000417  [ 6112/ 8083]
loss:        0.000498  [ 6272/ 8083]
loss:        0.000477  [ 6432/ 8083]
loss:        0.000541  [ 6592/ 8083]
loss:        0.000361  [ 6752/ 8083]
loss:        0.000537  [ 6912/ 8083]
loss:        0.000578  [ 7072/ 8083]
loss:        0.000566  [ 7232/ 8083]
loss:        0.000552  [ 7392/ 8083]
loss:        0.000520  [ 7552/ 8083]
loss:        0.000585  [ 7712/ 8083]
loss:        0.000537  [ 7872/ 8083]
loss:        0.000520  [ 8032/ 8083]

Avg validation loss:        0.000511
Learning rate: [0.001]
Validation loss decreased (0.000513 --> 0.000511).  Saving model ...


Epoch 8
-------------------------------
loss:        0.000494  [   32/ 8083]
loss:        0.000475  [  192/ 8083]
loss:        0.000557  [  352/ 8083]
loss:        0.000510  [  512/ 8083]
loss:        0.000542  [  672/ 8083]
loss:        0.000571  [  832/ 8083]
loss:        0.000579  [  992/ 8083]
loss:        0.000464  [ 1152/ 8083]
loss:        0.000581  [ 1312/ 8083]
loss:        0.000581  [ 1472/ 8083]
loss:        0.000533  [ 1632/ 8083]
loss:        0.000574  [ 1792/ 8083]
loss:        0.000607  [ 1952/ 8083]
loss:        0.000514  [ 2112/ 8083]
loss:        0.000561  [ 2272/ 8083]
loss:        0.000520  [ 2432/ 8083]
loss:        0.000538  [ 2592/ 8083]
loss:        0.000585  [ 2752/ 8083]
loss:        0.000621  [ 2912/ 8083]
loss:        0.000546  [ 3072/ 8083]
loss:        0.000530  [ 3232/ 8083]
loss:        0.000583  [ 3392/ 8083]
loss:        0.000590  [ 3552/ 8083]
loss:        0.000543  [ 3712/ 8083]
loss:        0.000638  [ 3872/ 8083]
loss:        0.000524  [ 4032/ 8083]
loss:        0.000443  [ 4192/ 8083]
loss:        0.000557  [ 4352/ 8083]
loss:        0.000617  [ 4512/ 8083]
loss:        0.000559  [ 4672/ 8083]
loss:        0.000445  [ 4832/ 8083]
loss:        0.000444  [ 4992/ 8083]
loss:        0.000514  [ 5152/ 8083]
loss:        0.000441  [ 5312/ 8083]
loss:        0.000433  [ 5472/ 8083]
loss:        0.000417  [ 5632/ 8083]
loss:        0.000612  [ 5792/ 8083]
loss:        0.000599  [ 5952/ 8083]
loss:        0.000502  [ 6112/ 8083]
loss:        0.000530  [ 6272/ 8083]
loss:        0.000624  [ 6432/ 8083]
loss:        0.000525  [ 6592/ 8083]
loss:        0.000522  [ 6752/ 8083]
loss:        0.000548  [ 6912/ 8083]
loss:        0.000589  [ 7072/ 8083]
loss:        0.000532  [ 7232/ 8083]
loss:        0.000524  [ 7392/ 8083]
loss:        0.000567  [ 7552/ 8083]
loss:        0.000536  [ 7712/ 8083]
loss:        0.000613  [ 7872/ 8083]
loss:        0.000595  [ 8032/ 8083]

Avg validation loss:        0.000511
Learning rate: [0.001]
EarlyStopping counter: 1 out of 20


Epoch 9
-------------------------------
loss:        0.000512  [   32/ 8083]
loss:        0.000407  [  192/ 8083]
loss:        0.000467  [  352/ 8083]
loss:        0.000537  [  512/ 8083]
loss:        0.000488  [  672/ 8083]
loss:        0.000467  [  832/ 8083]
loss:        0.000544  [  992/ 8083]
loss:        0.000509  [ 1152/ 8083]
loss:        0.000425  [ 1312/ 8083]
loss:        0.000544  [ 1472/ 8083]
loss:        0.000522  [ 1632/ 8083]
loss:        0.000516  [ 1792/ 8083]
loss:        0.000534  [ 1952/ 8083]
loss:        0.000489  [ 2112/ 8083]
loss:        0.000465  [ 2272/ 8083]
loss:        0.000518  [ 2432/ 8083]
loss:        0.000504  [ 2592/ 8083]
loss:        0.000551  [ 2752/ 8083]
loss:        0.000533  [ 2912/ 8083]
loss:        0.000479  [ 3072/ 8083]
loss:        0.000501  [ 3232/ 8083]
loss:        0.000612  [ 3392/ 8083]
loss:        0.000535  [ 3552/ 8083]
loss:        0.000447  [ 3712/ 8083]
loss:        0.000580  [ 3872/ 8083]
loss:        0.000469  [ 4032/ 8083]
loss:        0.000452  [ 4192/ 8083]
loss:        0.000493  [ 4352/ 8083]
loss:        0.000527  [ 4512/ 8083]
loss:        0.000541  [ 4672/ 8083]
loss:        0.000452  [ 4832/ 8083]
loss:        0.000513  [ 4992/ 8083]
loss:        0.000536  [ 5152/ 8083]
loss:        0.000546  [ 5312/ 8083]
loss:        0.000588  [ 5472/ 8083]
loss:        0.000553  [ 5632/ 8083]
loss:        0.000379  [ 5792/ 8083]
loss:        0.000518  [ 5952/ 8083]
loss:        0.000525  [ 6112/ 8083]
loss:        0.000568  [ 6272/ 8083]
loss:        0.000565  [ 6432/ 8083]
loss:        0.000653  [ 6592/ 8083]
loss:        0.000532  [ 6752/ 8083]
loss:        0.000476  [ 6912/ 8083]
loss:        0.000424  [ 7072/ 8083]
loss:        0.000546  [ 7232/ 8083]
loss:        0.000517  [ 7392/ 8083]
loss:        0.000508  [ 7552/ 8083]
loss:        0.000545  [ 7712/ 8083]
loss:        0.000438  [ 7872/ 8083]
loss:        0.000501  [ 8032/ 8083]

Avg validation loss:        0.000514
Learning rate: [0.001]
EarlyStopping counter: 2 out of 20


Epoch 10
-------------------------------
loss:        0.000574  [   32/ 8083]
loss:        0.000545  [  192/ 8083]
loss:        0.000439  [  352/ 8083]
loss:        0.000500  [  512/ 8083]
loss:        0.000551  [  672/ 8083]
loss:        0.000587  [  832/ 8083]
loss:        0.000484  [  992/ 8083]
loss:        0.000514  [ 1152/ 8083]
loss:        0.000511  [ 1312/ 8083]
loss:        0.000587  [ 1472/ 8083]
loss:        0.000654  [ 1632/ 8083]
loss:        0.000620  [ 1792/ 8083]
loss:        0.000495  [ 1952/ 8083]
loss:        0.000459  [ 2112/ 8083]
loss:        0.000582  [ 2272/ 8083]
loss:        0.000532  [ 2432/ 8083]
loss:        0.000384  [ 2592/ 8083]
loss:        0.000532  [ 2752/ 8083]
loss:        0.000508  [ 2912/ 8083]
loss:        0.000437  [ 3072/ 8083]
loss:        0.000556  [ 3232/ 8083]
loss:        0.000491  [ 3392/ 8083]
loss:        0.000585  [ 3552/ 8083]
loss:        0.000422  [ 3712/ 8083]
loss:        0.000475  [ 3872/ 8083]
loss:        0.000569  [ 4032/ 8083]
loss:        0.000499  [ 4192/ 8083]
loss:        0.000599  [ 4352/ 8083]
loss:        0.000556  [ 4512/ 8083]
loss:        0.000548  [ 4672/ 8083]
loss:        0.000606  [ 4832/ 8083]
loss:        0.000553  [ 4992/ 8083]
loss:        0.000555  [ 5152/ 8083]
loss:        0.000606  [ 5312/ 8083]
loss:        0.000492  [ 5472/ 8083]
loss:        0.000509  [ 5632/ 8083]
loss:        0.000529  [ 5792/ 8083]
loss:        0.000516  [ 5952/ 8083]
loss:        0.000504  [ 6112/ 8083]
loss:        0.000658  [ 6272/ 8083]
loss:        0.000539  [ 6432/ 8083]
loss:        0.000470  [ 6592/ 8083]
loss:        0.000551  [ 6752/ 8083]
loss:        0.000517  [ 6912/ 8083]
loss:        0.000472  [ 7072/ 8083]
loss:        0.000584  [ 7232/ 8083]
loss:        0.000522  [ 7392/ 8083]
loss:        0.000564  [ 7552/ 8083]
loss:        0.000506  [ 7712/ 8083]
loss:        0.000511  [ 7872/ 8083]
loss:        0.000527  [ 8032/ 8083]

Avg validation loss:        0.000509
Learning rate: [0.001]
Validation loss decreased (0.000511 --> 0.000509).  Saving model ...


Epoch 11
-------------------------------
loss:        0.000648  [   32/ 8083]
loss:        0.000522  [  192/ 8083]
loss:        0.000590  [  352/ 8083]
loss:        0.000542  [  512/ 8083]
loss:        0.000434  [  672/ 8083]
loss:        0.000528  [  832/ 8083]
loss:        0.000498  [  992/ 8083]
loss:        0.000487  [ 1152/ 8083]
loss:        0.000523  [ 1312/ 8083]
loss:        0.000510  [ 1472/ 8083]
loss:        0.000521  [ 1632/ 8083]
loss:        0.000497  [ 1792/ 8083]
loss:        0.000521  [ 1952/ 8083]
loss:        0.000439  [ 2112/ 8083]
loss:        0.000532  [ 2272/ 8083]
loss:        0.000461  [ 2432/ 8083]
loss:        0.000429  [ 2592/ 8083]
loss:        0.000550  [ 2752/ 8083]
loss:        0.000539  [ 2912/ 8083]
loss:        0.000449  [ 3072/ 8083]
loss:        0.000472  [ 3232/ 8083]
loss:        0.000699  [ 3392/ 8083]
loss:        0.000511  [ 3552/ 8083]
loss:        0.000438  [ 3712/ 8083]
loss:        0.000521  [ 3872/ 8083]
loss:        0.000493  [ 4032/ 8083]
loss:        0.000475  [ 4192/ 8083]
loss:        0.000600  [ 4352/ 8083]
loss:        0.000538  [ 4512/ 8083]
loss:        0.000475  [ 4672/ 8083]
loss:        0.000623  [ 4832/ 8083]
loss:        0.000588  [ 4992/ 8083]
loss:        0.000519  [ 5152/ 8083]
loss:        0.000487  [ 5312/ 8083]
loss:        0.000480  [ 5472/ 8083]
loss:        0.000515  [ 5632/ 8083]
loss:        0.000586  [ 5792/ 8083]
loss:        0.000569  [ 5952/ 8083]
loss:        0.000546  [ 6112/ 8083]
loss:        0.000532  [ 6272/ 8083]
loss:        0.000510  [ 6432/ 8083]
loss:        0.000602  [ 6592/ 8083]
loss:        0.000480  [ 6752/ 8083]
loss:        0.000463  [ 6912/ 8083]
loss:        0.000481  [ 7072/ 8083]
loss:        0.000447  [ 7232/ 8083]
loss:        0.000563  [ 7392/ 8083]
loss:        0.000533  [ 7552/ 8083]
loss:        0.000599  [ 7712/ 8083]
loss:        0.000517  [ 7872/ 8083]
loss:        0.000482  [ 8032/ 8083]

Avg validation loss:        0.000510
Learning rate: [0.001]
EarlyStopping counter: 1 out of 20


Epoch 12
-------------------------------
loss:        0.000494  [   32/ 8083]
loss:        0.000537  [  192/ 8083]
loss:        0.000512  [  352/ 8083]
loss:        0.000491  [  512/ 8083]
loss:        0.000516  [  672/ 8083]
loss:        0.000496  [  832/ 8083]
loss:        0.000544  [  992/ 8083]
loss:        0.000490  [ 1152/ 8083]
loss:        0.000669  [ 1312/ 8083]
loss:        0.000526  [ 1472/ 8083]
loss:        0.000511  [ 1632/ 8083]
loss:        0.000515  [ 1792/ 8083]
loss:        0.000494  [ 1952/ 8083]
loss:        0.000532  [ 2112/ 8083]
loss:        0.000472  [ 2272/ 8083]
loss:        0.000483  [ 2432/ 8083]
loss:        0.000521  [ 2592/ 8083]
loss:        0.000467  [ 2752/ 8083]
loss:        0.000485  [ 2912/ 8083]
loss:        0.000686  [ 3072/ 8083]
loss:        0.000510  [ 3232/ 8083]
loss:        0.000546  [ 3392/ 8083]
loss:        0.000605  [ 3552/ 8083]
loss:        0.000646  [ 3712/ 8083]
loss:        0.000414  [ 3872/ 8083]
loss:        0.000494  [ 4032/ 8083]
loss:        0.000463  [ 4192/ 8083]
loss:        0.000570  [ 4352/ 8083]
loss:        0.000511  [ 4512/ 8083]
loss:        0.000439  [ 4672/ 8083]
loss:        0.000476  [ 4832/ 8083]
loss:        0.000533  [ 4992/ 8083]
loss:        0.000489  [ 5152/ 8083]
loss:        0.000494  [ 5312/ 8083]
loss:        0.000486  [ 5472/ 8083]
loss:        0.000528  [ 5632/ 8083]
loss:        0.000490  [ 5792/ 8083]
loss:        0.000577  [ 5952/ 8083]
loss:        0.000509  [ 6112/ 8083]
loss:        0.000520  [ 6272/ 8083]
loss:        0.000476  [ 6432/ 8083]
loss:        0.000594  [ 6592/ 8083]
loss:        0.000518  [ 6752/ 8083]
loss:        0.000654  [ 6912/ 8083]
loss:        0.000574  [ 7072/ 8083]
loss:        0.000430  [ 7232/ 8083]
loss:        0.000490  [ 7392/ 8083]
loss:        0.000458  [ 7552/ 8083]
loss:        0.000541  [ 7712/ 8083]
loss:        0.000589  [ 7872/ 8083]
loss:        0.000630  [ 8032/ 8083]

Avg validation loss:        0.000514
Learning rate: [0.001]
EarlyStopping counter: 2 out of 20


Epoch 13
-------------------------------
loss:        0.000546  [   32/ 8083]
loss:        0.000496  [  192/ 8083]
loss:        0.000472  [  352/ 8083]
loss:        0.000683  [  512/ 8083]
loss:        0.000511  [  672/ 8083]
loss:        0.000470  [  832/ 8083]
loss:        0.000506  [  992/ 8083]
loss:        0.000540  [ 1152/ 8083]
loss:        0.000574  [ 1312/ 8083]
loss:        0.000365  [ 1472/ 8083]
loss:        0.000746  [ 1632/ 8083]
loss:        0.000537  [ 1792/ 8083]
loss:        0.000554  [ 1952/ 8083]
loss:        0.000547  [ 2112/ 8083]
loss:        0.000517  [ 2272/ 8083]
loss:        0.000546  [ 2432/ 8083]
loss:        0.000504  [ 2592/ 8083]
loss:        0.000570  [ 2752/ 8083]
loss:        0.000545  [ 2912/ 8083]
loss:        0.000503  [ 3072/ 8083]
loss:        0.000451  [ 3232/ 8083]
loss:        0.000615  [ 3392/ 8083]
loss:        0.000474  [ 3552/ 8083]
loss:        0.000531  [ 3712/ 8083]
loss:        0.000594  [ 3872/ 8083]
loss:        0.000532  [ 4032/ 8083]
loss:        0.000452  [ 4192/ 8083]
loss:        0.000475  [ 4352/ 8083]
loss:        0.000596  [ 4512/ 8083]
loss:        0.000474  [ 4672/ 8083]
loss:        0.000473  [ 4832/ 8083]
loss:        0.000610  [ 4992/ 8083]
loss:        0.000492  [ 5152/ 8083]
loss:        0.000508  [ 5312/ 8083]
loss:        0.000506  [ 5472/ 8083]
loss:        0.000640  [ 5632/ 8083]
loss:        0.000517  [ 5792/ 8083]
loss:        0.000532  [ 5952/ 8083]
loss:        0.000488  [ 6112/ 8083]
loss:        0.000500  [ 6272/ 8083]
loss:        0.000483  [ 6432/ 8083]
loss:        0.000554  [ 6592/ 8083]
loss:        0.000498  [ 6752/ 8083]
loss:        0.000474  [ 6912/ 8083]
loss:        0.000539  [ 7072/ 8083]
loss:        0.000576  [ 7232/ 8083]
loss:        0.000453  [ 7392/ 8083]
loss:        0.000446  [ 7552/ 8083]
loss:        0.000575  [ 7712/ 8083]
loss:        0.000562  [ 7872/ 8083]
loss:        0.000511  [ 8032/ 8083]

Avg validation loss:        0.000507
Learning rate: [0.001]
Validation loss decreased (0.000509 --> 0.000507).  Saving model ...


Epoch 14
-------------------------------
loss:        0.000599  [   32/ 8083]
loss:        0.000500  [  192/ 8083]
loss:        0.000543  [  352/ 8083]
loss:        0.000592  [  512/ 8083]
loss:        0.000488  [  672/ 8083]
loss:        0.000563  [  832/ 8083]
loss:        0.000545  [  992/ 8083]
loss:        0.000608  [ 1152/ 8083]
loss:        0.000519  [ 1312/ 8083]
loss:        0.000566  [ 1472/ 8083]
loss:        0.000593  [ 1632/ 8083]
loss:        0.000477  [ 1792/ 8083]
loss:        0.000649  [ 1952/ 8083]
loss:        0.000548  [ 2112/ 8083]
loss:        0.000544  [ 2272/ 8083]
loss:        0.000616  [ 2432/ 8083]
loss:        0.000492  [ 2592/ 8083]
loss:        0.000396  [ 2752/ 8083]
loss:        0.000499  [ 2912/ 8083]
loss:        0.000544  [ 3072/ 8083]
loss:        0.000512  [ 3232/ 8083]
loss:        0.000504  [ 3392/ 8083]
loss:        0.000435  [ 3552/ 8083]
loss:        0.000526  [ 3712/ 8083]
loss:        0.000483  [ 3872/ 8083]
loss:        0.000487  [ 4032/ 8083]
loss:        0.000485  [ 4192/ 8083]
loss:        0.000511  [ 4352/ 8083]
loss:        0.000547  [ 4512/ 8083]
loss:        0.000493  [ 4672/ 8083]
loss:        0.000534  [ 4832/ 8083]
loss:        0.000474  [ 4992/ 8083]
loss:        0.000593  [ 5152/ 8083]
loss:        0.000512  [ 5312/ 8083]
loss:        0.000451  [ 5472/ 8083]
loss:        0.000588  [ 5632/ 8083]
loss:        0.000524  [ 5792/ 8083]
loss:        0.000547  [ 5952/ 8083]
loss:        0.000517  [ 6112/ 8083]
loss:        0.000538  [ 6272/ 8083]
loss:        0.000568  [ 6432/ 8083]
loss:        0.000550  [ 6592/ 8083]
loss:        0.000403  [ 6752/ 8083]
loss:        0.000469  [ 6912/ 8083]
loss:        0.000473  [ 7072/ 8083]
loss:        0.000435  [ 7232/ 8083]
loss:        0.000528  [ 7392/ 8083]
loss:        0.000567  [ 7552/ 8083]
loss:        0.000538  [ 7712/ 8083]
loss:        0.000520  [ 7872/ 8083]
loss:        0.000504  [ 8032/ 8083]

Avg validation loss:        0.000511
Learning rate: [0.001]
EarlyStopping counter: 1 out of 20


Epoch 15
-------------------------------
loss:        0.000469  [   32/ 8083]
loss:        0.000538  [  192/ 8083]
loss:        0.000497  [  352/ 8083]
loss:        0.000440  [  512/ 8083]
loss:        0.000480  [  672/ 8083]
loss:        0.000505  [  832/ 8083]
loss:        0.000497  [  992/ 8083]
loss:        0.000581  [ 1152/ 8083]
loss:        0.000532  [ 1312/ 8083]
loss:        0.000509  [ 1472/ 8083]
loss:        0.000420  [ 1632/ 8083]
loss:        0.000519  [ 1792/ 8083]
loss:        0.000459  [ 1952/ 8083]
loss:        0.000415  [ 2112/ 8083]
loss:        0.000509  [ 2272/ 8083]
loss:        0.000494  [ 2432/ 8083]
loss:        0.000460  [ 2592/ 8083]
loss:        0.000511  [ 2752/ 8083]
loss:        0.000574  [ 2912/ 8083]
loss:        0.000602  [ 3072/ 8083]
loss:        0.000476  [ 3232/ 8083]
loss:        0.000456  [ 3392/ 8083]
loss:        0.000489  [ 3552/ 8083]
loss:        0.000559  [ 3712/ 8083]
loss:        0.000576  [ 3872/ 8083]
loss:        0.000551  [ 4032/ 8083]
loss:        0.000600  [ 4192/ 8083]
loss:        0.000545  [ 4352/ 8083]
loss:        0.000520  [ 4512/ 8083]
loss:        0.000459  [ 4672/ 8083]
loss:        0.000521  [ 4832/ 8083]
loss:        0.000589  [ 4992/ 8083]
loss:        0.000478  [ 5152/ 8083]
loss:        0.000583  [ 5312/ 8083]
loss:        0.000478  [ 5472/ 8083]
loss:        0.000589  [ 5632/ 8083]
loss:        0.000576  [ 5792/ 8083]
loss:        0.000543  [ 5952/ 8083]
loss:        0.000503  [ 6112/ 8083]
loss:        0.000525  [ 6272/ 8083]
loss:        0.000462  [ 6432/ 8083]
loss:        0.000501  [ 6592/ 8083]
loss:        0.000566  [ 6752/ 8083]
loss:        0.000551  [ 6912/ 8083]
loss:        0.000451  [ 7072/ 8083]
loss:        0.000628  [ 7232/ 8083]
loss:        0.000594  [ 7392/ 8083]
loss:        0.000588  [ 7552/ 8083]
loss:        0.000531  [ 7712/ 8083]
loss:        0.000567  [ 7872/ 8083]
loss:        0.000508  [ 8032/ 8083]

Avg validation loss:        0.000512
Learning rate: [0.001]
EarlyStopping counter: 2 out of 20


Epoch 16
-------------------------------
loss:        0.000602  [   32/ 8083]
loss:        0.000605  [  192/ 8083]
loss:        0.000420  [  352/ 8083]
loss:        0.000489  [  512/ 8083]
loss:        0.000578  [  672/ 8083]
loss:        0.000484  [  832/ 8083]
loss:        0.000564  [  992/ 8083]
loss:        0.000489  [ 1152/ 8083]
loss:        0.000612  [ 1312/ 8083]
loss:        0.000402  [ 1472/ 8083]
loss:        0.000464  [ 1632/ 8083]
loss:        0.000518  [ 1792/ 8083]
loss:        0.000516  [ 1952/ 8083]
loss:        0.000465  [ 2112/ 8083]
loss:        0.000591  [ 2272/ 8083]
loss:        0.000547  [ 2432/ 8083]
loss:        0.000494  [ 2592/ 8083]
loss:        0.000546  [ 2752/ 8083]
loss:        0.000540  [ 2912/ 8083]
loss:        0.000514  [ 3072/ 8083]
loss:        0.000554  [ 3232/ 8083]
loss:        0.000530  [ 3392/ 8083]
loss:        0.000596  [ 3552/ 8083]
loss:        0.000509  [ 3712/ 8083]
loss:        0.000534  [ 3872/ 8083]
loss:        0.000532  [ 4032/ 8083]
loss:        0.000627  [ 4192/ 8083]
loss:        0.000669  [ 4352/ 8083]
loss:        0.000464  [ 4512/ 8083]
loss:        0.000484  [ 4672/ 8083]
loss:        0.000623  [ 4832/ 8083]
loss:        0.000503  [ 4992/ 8083]
loss:        0.000450  [ 5152/ 8083]
loss:        0.000438  [ 5312/ 8083]
loss:        0.000470  [ 5472/ 8083]
loss:        0.000523  [ 5632/ 8083]
loss:        0.000623  [ 5792/ 8083]
loss:        0.000488  [ 5952/ 8083]
loss:        0.000439  [ 6112/ 8083]
loss:        0.000441  [ 6272/ 8083]
loss:        0.000547  [ 6432/ 8083]
loss:        0.000549  [ 6592/ 8083]
loss:        0.000546  [ 6752/ 8083]
loss:        0.000495  [ 6912/ 8083]
loss:        0.000520  [ 7072/ 8083]
loss:        0.000571  [ 7232/ 8083]
loss:        0.000510  [ 7392/ 8083]
loss:        0.000566  [ 7552/ 8083]
loss:        0.000487  [ 7712/ 8083]
loss:        0.000553  [ 7872/ 8083]
loss:        0.000437  [ 8032/ 8083]

Avg validation loss:        0.000505
Learning rate: [0.001]
Validation loss decreased (0.000507 --> 0.000505).  Saving model ...


Epoch 17
-------------------------------
loss:        0.000482  [   32/ 8083]
loss:        0.000443  [  192/ 8083]
loss:        0.000438  [  352/ 8083]
loss:        0.000439  [  512/ 8083]
loss:        0.000478  [  672/ 8083]
loss:        0.000484  [  832/ 8083]
loss:        0.000525  [  992/ 8083]
loss:        0.000458  [ 1152/ 8083]
loss:        0.000540  [ 1312/ 8083]
loss:        0.000511  [ 1472/ 8083]
loss:        0.000523  [ 1632/ 8083]
loss:        0.000497  [ 1792/ 8083]
loss:        0.000498  [ 1952/ 8083]
loss:        0.000564  [ 2112/ 8083]
loss:        0.000441  [ 2272/ 8083]
loss:        0.000519  [ 2432/ 8083]
loss:        0.000598  [ 2592/ 8083]
loss:        0.000492  [ 2752/ 8083]
loss:        0.000543  [ 2912/ 8083]
loss:        0.000574  [ 3072/ 8083]
loss:        0.000642  [ 3232/ 8083]
loss:        0.000598  [ 3392/ 8083]
loss:        0.000422  [ 3552/ 8083]
loss:        0.000417  [ 3712/ 8083]
loss:        0.000496  [ 3872/ 8083]
loss:        0.000545  [ 4032/ 8083]
loss:        0.000519  [ 4192/ 8083]
loss:        0.000509  [ 4352/ 8083]
loss:        0.000487  [ 4512/ 8083]
loss:        0.000559  [ 4672/ 8083]
loss:        0.000523  [ 4832/ 8083]
loss:        0.000526  [ 4992/ 8083]
loss:        0.000426  [ 5152/ 8083]
loss:        0.000441  [ 5312/ 8083]
loss:        0.000584  [ 5472/ 8083]
loss:        0.000533  [ 5632/ 8083]
loss:        0.000553  [ 5792/ 8083]
loss:        0.000441  [ 5952/ 8083]
loss:        0.000542  [ 6112/ 8083]
loss:        0.000434  [ 6272/ 8083]
loss:        0.000593  [ 6432/ 8083]
loss:        0.000556  [ 6592/ 8083]
loss:        0.000570  [ 6752/ 8083]
loss:        0.000526  [ 6912/ 8083]
loss:        0.000500  [ 7072/ 8083]
loss:        0.000547  [ 7232/ 8083]
loss:        0.000525  [ 7392/ 8083]
loss:        0.000601  [ 7552/ 8083]
loss:        0.000567  [ 7712/ 8083]
loss:        0.000525  [ 7872/ 8083]
loss:        0.000524  [ 8032/ 8083]

Avg validation loss:        0.000508
Learning rate: [0.001]
EarlyStopping counter: 1 out of 20


Epoch 18
-------------------------------
loss:        0.000484  [   32/ 8083]
loss:        0.000500  [  192/ 8083]
loss:        0.000450  [  352/ 8083]
loss:        0.000672  [  512/ 8083]
loss:        0.000542  [  672/ 8083]
loss:        0.000528  [  832/ 8083]
loss:        0.000618  [  992/ 8083]
loss:        0.000514  [ 1152/ 8083]
loss:        0.000533  [ 1312/ 8083]
loss:        0.000502  [ 1472/ 8083]
loss:        0.000472  [ 1632/ 8083]
loss:        0.000540  [ 1792/ 8083]
loss:        0.000618  [ 1952/ 8083]
loss:        0.000478  [ 2112/ 8083]
loss:        0.000552  [ 2272/ 8083]
loss:        0.000533  [ 2432/ 8083]
loss:        0.000554  [ 2592/ 8083]
loss:        0.000498  [ 2752/ 8083]
loss:        0.000488  [ 2912/ 8083]
loss:        0.000514  [ 3072/ 8083]
loss:        0.000684  [ 3232/ 8083]
loss:        0.000533  [ 3392/ 8083]
loss:        0.000496  [ 3552/ 8083]
loss:        0.000521  [ 3712/ 8083]
loss:        0.000528  [ 3872/ 8083]
loss:        0.000480  [ 4032/ 8083]
loss:        0.000550  [ 4192/ 8083]
loss:        0.000414  [ 4352/ 8083]
loss:        0.000625  [ 4512/ 8083]
loss:        0.000562  [ 4672/ 8083]
loss:        0.000539  [ 4832/ 8083]
loss:        0.000547  [ 4992/ 8083]
loss:        0.000557  [ 5152/ 8083]
loss:        0.000508  [ 5312/ 8083]
loss:        0.000586  [ 5472/ 8083]
loss:        0.000683  [ 5632/ 8083]
loss:        0.000481  [ 5792/ 8083]
loss:        0.000561  [ 5952/ 8083]
loss:        0.000506  [ 6112/ 8083]
loss:        0.000496  [ 6272/ 8083]
loss:        0.000499  [ 6432/ 8083]
loss:        0.000478  [ 6592/ 8083]
loss:        0.000547  [ 6752/ 8083]
loss:        0.000485  [ 6912/ 8083]
loss:        0.000570  [ 7072/ 8083]
loss:        0.000500  [ 7232/ 8083]
loss:        0.000551  [ 7392/ 8083]
loss:        0.000477  [ 7552/ 8083]
loss:        0.000475  [ 7712/ 8083]
loss:        0.000579  [ 7872/ 8083]
loss:        0.000499  [ 8032/ 8083]

Avg validation loss:        0.000503
Learning rate: [0.001]
Validation loss decreased (0.000505 --> 0.000503).  Saving model ...


Epoch 19
-------------------------------
loss:        0.000466  [   32/ 8083]
loss:        0.000475  [  192/ 8083]
loss:        0.000488  [  352/ 8083]
loss:        0.000459  [  512/ 8083]
loss:        0.000461  [  672/ 8083]
loss:        0.000470  [  832/ 8083]
loss:        0.000540  [  992/ 8083]
loss:        0.000567  [ 1152/ 8083]
loss:        0.000481  [ 1312/ 8083]
loss:        0.000474  [ 1472/ 8083]
loss:        0.000503  [ 1632/ 8083]
loss:        0.000540  [ 1792/ 8083]
loss:        0.000510  [ 1952/ 8083]
loss:        0.000513  [ 2112/ 8083]
loss:        0.000483  [ 2272/ 8083]
loss:        0.000508  [ 2432/ 8083]
loss:        0.000456  [ 2592/ 8083]
loss:        0.000618  [ 2752/ 8083]
loss:        0.000496  [ 2912/ 8083]
loss:        0.000515  [ 3072/ 8083]
loss:        0.000567  [ 3232/ 8083]
loss:        0.000444  [ 3392/ 8083]
loss:        0.000320  [ 3552/ 8083]
loss:        0.000583  [ 3712/ 8083]
loss:        0.000577  [ 3872/ 8083]
loss:        0.000492  [ 4032/ 8083]
loss:        0.000606  [ 4192/ 8083]
loss:        0.000474  [ 4352/ 8083]
loss:        0.000532  [ 4512/ 8083]
loss:        0.000513  [ 4672/ 8083]
loss:        0.000428  [ 4832/ 8083]
loss:        0.000496  [ 4992/ 8083]
loss:        0.000488  [ 5152/ 8083]
loss:        0.000549  [ 5312/ 8083]
loss:        0.000558  [ 5472/ 8083]
loss:        0.000581  [ 5632/ 8083]
loss:        0.000568  [ 5792/ 8083]
loss:        0.000567  [ 5952/ 8083]
loss:        0.000471  [ 6112/ 8083]
loss:        0.000572  [ 6272/ 8083]
loss:        0.000562  [ 6432/ 8083]
loss:        0.000480  [ 6592/ 8083]
loss:        0.000599  [ 6752/ 8083]
loss:        0.000480  [ 6912/ 8083]
loss:        0.000527  [ 7072/ 8083]
loss:        0.000472  [ 7232/ 8083]
loss:        0.000522  [ 7392/ 8083]
loss:        0.000602  [ 7552/ 8083]
loss:        0.000425  [ 7712/ 8083]
loss:        0.000492  [ 7872/ 8083]
loss:        0.000517  [ 8032/ 8083]

Avg validation loss:        0.000507
Learning rate: [0.001]
EarlyStopping counter: 1 out of 20


Epoch 20
-------------------------------
loss:        0.000468  [   32/ 8083]
loss:        0.000545  [  192/ 8083]
loss:        0.000559  [  352/ 8083]
loss:        0.000477  [  512/ 8083]
loss:        0.000522  [  672/ 8083]
loss:        0.000507  [  832/ 8083]
loss:        0.000535  [  992/ 8083]
loss:        0.000557  [ 1152/ 8083]
loss:        0.000500  [ 1312/ 8083]
loss:        0.000597  [ 1472/ 8083]
loss:        0.000538  [ 1632/ 8083]
loss:        0.000611  [ 1792/ 8083]
loss:        0.000491  [ 1952/ 8083]
loss:        0.000449  [ 2112/ 8083]
loss:        0.000498  [ 2272/ 8083]
loss:        0.000507  [ 2432/ 8083]
loss:        0.000558  [ 2592/ 8083]
loss:        0.000434  [ 2752/ 8083]
loss:        0.000499  [ 2912/ 8083]
loss:        0.000461  [ 3072/ 8083]
loss:        0.000539  [ 3232/ 8083]
loss:        0.000468  [ 3392/ 8083]
loss:        0.000481  [ 3552/ 8083]
loss:        0.000585  [ 3712/ 8083]
loss:        0.000500  [ 3872/ 8083]
loss:        0.000581  [ 4032/ 8083]
loss:        0.000434  [ 4192/ 8083]
loss:        0.000455  [ 4352/ 8083]
loss:        0.000537  [ 4512/ 8083]
loss:        0.000525  [ 4672/ 8083]
loss:        0.000542  [ 4832/ 8083]
loss:        0.000461  [ 4992/ 8083]
loss:        0.000468  [ 5152/ 8083]
loss:        0.000444  [ 5312/ 8083]
loss:        0.000526  [ 5472/ 8083]
loss:        0.000585  [ 5632/ 8083]
loss:        0.000586  [ 5792/ 8083]
loss:        0.000597  [ 5952/ 8083]
loss:        0.000351  [ 6112/ 8083]
loss:        0.000473  [ 6272/ 8083]
loss:        0.000465  [ 6432/ 8083]
loss:        0.000364  [ 6592/ 8083]
loss:        0.000478  [ 6752/ 8083]
loss:        0.000454  [ 6912/ 8083]
loss:        0.000522  [ 7072/ 8083]
loss:        0.000526  [ 7232/ 8083]
loss:        0.000501  [ 7392/ 8083]
loss:        0.000590  [ 7552/ 8083]
loss:        0.000498  [ 7712/ 8083]
loss:        0.000528  [ 7872/ 8083]
loss:        0.000442  [ 8032/ 8083]

Avg validation loss:        0.000512
Learning rate: [0.001]
EarlyStopping counter: 2 out of 20


Epoch 21
-------------------------------
loss:        0.000412  [   32/ 8083]
loss:        0.000465  [  192/ 8083]
loss:        0.000383  [  352/ 8083]
loss:        0.000594  [  512/ 8083]
loss:        0.000563  [  672/ 8083]
loss:        0.000445  [  832/ 8083]
loss:        0.000524  [  992/ 8083]
loss:        0.000521  [ 1152/ 8083]
loss:        0.000487  [ 1312/ 8083]
loss:        0.000502  [ 1472/ 8083]
loss:        0.000632  [ 1632/ 8083]
loss:        0.000468  [ 1792/ 8083]
loss:        0.000633  [ 1952/ 8083]
loss:        0.000550  [ 2112/ 8083]
loss:        0.000477  [ 2272/ 8083]
loss:        0.000563  [ 2432/ 8083]
loss:        0.000516  [ 2592/ 8083]
loss:        0.000519  [ 2752/ 8083]
loss:        0.000641  [ 2912/ 8083]
loss:        0.000441  [ 3072/ 8083]
loss:        0.000582  [ 3232/ 8083]
loss:        0.000544  [ 3392/ 8083]
loss:        0.000494  [ 3552/ 8083]
loss:        0.000537  [ 3712/ 8083]
loss:        0.000556  [ 3872/ 8083]
loss:        0.000517  [ 4032/ 8083]
loss:        0.000488  [ 4192/ 8083]
loss:        0.000447  [ 4352/ 8083]
loss:        0.000488  [ 4512/ 8083]
loss:        0.000515  [ 4672/ 8083]
loss:        0.000537  [ 4832/ 8083]
loss:        0.000518  [ 4992/ 8083]
loss:        0.000595  [ 5152/ 8083]
loss:        0.000570  [ 5312/ 8083]
loss:        0.000583  [ 5472/ 8083]
loss:        0.000534  [ 5632/ 8083]
loss:        0.000474  [ 5792/ 8083]
loss:        0.000505  [ 5952/ 8083]
loss:        0.000654  [ 6112/ 8083]
loss:        0.000432  [ 6272/ 8083]
loss:        0.000467  [ 6432/ 8083]
loss:        0.000556  [ 6592/ 8083]
loss:        0.000527  [ 6752/ 8083]
loss:        0.000551  [ 6912/ 8083]
loss:        0.000443  [ 7072/ 8083]
loss:        0.000508  [ 7232/ 8083]
loss:        0.000527  [ 7392/ 8083]
loss:        0.000405  [ 7552/ 8083]
loss:        0.000347  [ 7712/ 8083]
loss:        0.000501  [ 7872/ 8083]
loss:        0.000489  [ 8032/ 8083]

Avg validation loss:        0.000504
Learning rate: [0.001]
EarlyStopping counter: 3 out of 20


Epoch 22
-------------------------------
loss:        0.000654  [   32/ 8083]
loss:        0.000526  [  192/ 8083]
loss:        0.000575  [  352/ 8083]
loss:        0.000539  [  512/ 8083]
loss:        0.000496  [  672/ 8083]
loss:        0.000471  [  832/ 8083]
loss:        0.000563  [  992/ 8083]
loss:        0.000521  [ 1152/ 8083]
loss:        0.000683  [ 1312/ 8083]
loss:        0.000515  [ 1472/ 8083]
loss:        0.000589  [ 1632/ 8083]
loss:        0.000548  [ 1792/ 8083]
loss:        0.000450  [ 1952/ 8083]
loss:        0.000544  [ 2112/ 8083]
loss:        0.000453  [ 2272/ 8083]
loss:        0.000477  [ 2432/ 8083]
loss:        0.000492  [ 2592/ 8083]
loss:        0.000576  [ 2752/ 8083]
loss:        0.000512  [ 2912/ 8083]
loss:        0.000521  [ 3072/ 8083]
loss:        0.000432  [ 3232/ 8083]
loss:        0.000499  [ 3392/ 8083]
loss:        0.000507  [ 3552/ 8083]
loss:        0.000550  [ 3712/ 8083]
loss:        0.000606  [ 3872/ 8083]
loss:        0.000575  [ 4032/ 8083]
loss:        0.000427  [ 4192/ 8083]
loss:        0.000588  [ 4352/ 8083]
loss:        0.000642  [ 4512/ 8083]
loss:        0.000388  [ 4672/ 8083]
loss:        0.000472  [ 4832/ 8083]
loss:        0.000474  [ 4992/ 8083]
loss:        0.000533  [ 5152/ 8083]
loss:        0.000495  [ 5312/ 8083]
loss:        0.000466  [ 5472/ 8083]
loss:        0.000458  [ 5632/ 8083]
loss:        0.000404  [ 5792/ 8083]
loss:        0.000463  [ 5952/ 8083]
loss:        0.000465  [ 6112/ 8083]
loss:        0.000524  [ 6272/ 8083]
loss:        0.000535  [ 6432/ 8083]
loss:        0.000544  [ 6592/ 8083]
loss:        0.000509  [ 6752/ 8083]
loss:        0.000545  [ 6912/ 8083]
loss:        0.000542  [ 7072/ 8083]
loss:        0.000468  [ 7232/ 8083]
loss:        0.000556  [ 7392/ 8083]
loss:        0.000548  [ 7552/ 8083]
loss:        0.000531  [ 7712/ 8083]
loss:        0.000613  [ 7872/ 8083]
loss:        0.000535  [ 8032/ 8083]

Avg validation loss:        0.000513
Learning rate: [0.001]
EarlyStopping counter: 4 out of 20


Epoch 23
-------------------------------
loss:        0.000478  [   32/ 8083]
loss:        0.000464  [  192/ 8083]
loss:        0.000514  [  352/ 8083]
loss:        0.000449  [  512/ 8083]
loss:        0.000460  [  672/ 8083]
loss:        0.000490  [  832/ 8083]
loss:        0.000509  [  992/ 8083]
loss:        0.000500  [ 1152/ 8083]
loss:        0.000508  [ 1312/ 8083]
loss:        0.000504  [ 1472/ 8083]
loss:        0.000505  [ 1632/ 8083]
loss:        0.000466  [ 1792/ 8083]
loss:        0.000485  [ 1952/ 8083]
loss:        0.000481  [ 2112/ 8083]
loss:        0.000507  [ 2272/ 8083]
loss:        0.000492  [ 2432/ 8083]
loss:        0.000508  [ 2592/ 8083]
loss:        0.000505  [ 2752/ 8083]
loss:        0.000564  [ 2912/ 8083]
loss:        0.000393  [ 3072/ 8083]
loss:        0.000419  [ 3232/ 8083]
loss:        0.000560  [ 3392/ 8083]
loss:        0.000444  [ 3552/ 8083]
loss:        0.000546  [ 3712/ 8083]
loss:        0.000507  [ 3872/ 8083]
loss:        0.000542  [ 4032/ 8083]
loss:        0.000514  [ 4192/ 8083]
loss:        0.000438  [ 4352/ 8083]
loss:        0.000517  [ 4512/ 8083]
loss:        0.000462  [ 4672/ 8083]
loss:        0.000491  [ 4832/ 8083]
loss:        0.000529  [ 4992/ 8083]
loss:        0.000493  [ 5152/ 8083]
loss:        0.000534  [ 5312/ 8083]
loss:        0.000561  [ 5472/ 8083]
loss:        0.000568  [ 5632/ 8083]
loss:        0.000513  [ 5792/ 8083]
loss:        0.000491  [ 5952/ 8083]
loss:        0.000603  [ 6112/ 8083]
loss:        0.000398  [ 6272/ 8083]
loss:        0.000505  [ 6432/ 8083]
loss:        0.000494  [ 6592/ 8083]
loss:        0.000531  [ 6752/ 8083]
loss:        0.000560  [ 6912/ 8083]
loss:        0.000517  [ 7072/ 8083]
loss:        0.000489  [ 7232/ 8083]
loss:        0.000532  [ 7392/ 8083]
loss:        0.000635  [ 7552/ 8083]
loss:        0.000480  [ 7712/ 8083]
loss:        0.000426  [ 7872/ 8083]
loss:        0.000497  [ 8032/ 8083]

Avg validation loss:        0.000504
Learning rate: [0.001]
EarlyStopping counter: 5 out of 20


Epoch 24
-------------------------------
loss:        0.000485  [   32/ 8083]
loss:        0.000520  [  192/ 8083]
loss:        0.000500  [  352/ 8083]
loss:        0.000447  [  512/ 8083]
loss:        0.000507  [  672/ 8083]
loss:        0.000558  [  832/ 8083]
loss:        0.000566  [  992/ 8083]
loss:        0.000605  [ 1152/ 8083]
loss:        0.000507  [ 1312/ 8083]
loss:        0.000501  [ 1472/ 8083]
loss:        0.000570  [ 1632/ 8083]
loss:        0.000536  [ 1792/ 8083]
loss:        0.000529  [ 1952/ 8083]
loss:        0.000490  [ 2112/ 8083]
loss:        0.000602  [ 2272/ 8083]
loss:        0.000480  [ 2432/ 8083]
loss:        0.000463  [ 2592/ 8083]
loss:        0.000521  [ 2752/ 8083]
loss:        0.000488  [ 2912/ 8083]
loss:        0.000546  [ 3072/ 8083]
loss:        0.000618  [ 3232/ 8083]
loss:        0.000542  [ 3392/ 8083]
loss:        0.000410  [ 3552/ 8083]
loss:        0.000456  [ 3712/ 8083]
loss:        0.000538  [ 3872/ 8083]
loss:        0.000530  [ 4032/ 8083]
loss:        0.000587  [ 4192/ 8083]
loss:        0.000526  [ 4352/ 8083]
loss:        0.000404  [ 4512/ 8083]
loss:        0.000535  [ 4672/ 8083]
loss:        0.000560  [ 4832/ 8083]
loss:        0.000496  [ 4992/ 8083]
loss:        0.000588  [ 5152/ 8083]
loss:        0.000540  [ 5312/ 8083]
loss:        0.000510  [ 5472/ 8083]
loss:        0.000545  [ 5632/ 8083]
loss:        0.000505  [ 5792/ 8083]
loss:        0.000565  [ 5952/ 8083]
loss:        0.000608  [ 6112/ 8083]
loss:        0.000542  [ 6272/ 8083]
loss:        0.000545  [ 6432/ 8083]
loss:        0.000540  [ 6592/ 8083]
loss:        0.000513  [ 6752/ 8083]
loss:        0.000490  [ 6912/ 8083]
loss:        0.000440  [ 7072/ 8083]
loss:        0.000621  [ 7232/ 8083]
loss:        0.000550  [ 7392/ 8083]
loss:        0.000503  [ 7552/ 8083]
loss:        0.000490  [ 7712/ 8083]
loss:        0.000528  [ 7872/ 8083]
loss:        0.000472  [ 8032/ 8083]

Avg validation loss:        0.000504
Learning rate: [0.0001]
EarlyStopping counter: 6 out of 20


Epoch 25
-------------------------------
loss:        0.000389  [   32/ 8083]
loss:        0.000435  [  192/ 8083]
loss:        0.000592  [  352/ 8083]
loss:        0.000564  [  512/ 8083]
loss:        0.000381  [  672/ 8083]
loss:        0.000538  [  832/ 8083]
loss:        0.000587  [  992/ 8083]
loss:        0.000545  [ 1152/ 8083]
loss:        0.000491  [ 1312/ 8083]
loss:        0.000495  [ 1472/ 8083]
loss:        0.000508  [ 1632/ 8083]
loss:        0.000591  [ 1792/ 8083]
loss:        0.000555  [ 1952/ 8083]
loss:        0.000452  [ 2112/ 8083]
loss:        0.000462  [ 2272/ 8083]
loss:        0.000516  [ 2432/ 8083]
loss:        0.000408  [ 2592/ 8083]
loss:        0.000517  [ 2752/ 8083]
loss:        0.000465  [ 2912/ 8083]
loss:        0.000453  [ 3072/ 8083]
loss:        0.000504  [ 3232/ 8083]
loss:        0.000518  [ 3392/ 8083]
loss:        0.000595  [ 3552/ 8083]
loss:        0.000458  [ 3712/ 8083]
loss:        0.000475  [ 3872/ 8083]
loss:        0.000600  [ 4032/ 8083]
loss:        0.000547  [ 4192/ 8083]
loss:        0.000495  [ 4352/ 8083]
loss:        0.000542  [ 4512/ 8083]
loss:        0.000529  [ 4672/ 8083]
loss:        0.000404  [ 4832/ 8083]
loss:        0.000358  [ 4992/ 8083]
loss:        0.000377  [ 5152/ 8083]
loss:        0.000461  [ 5312/ 8083]
loss:        0.000514  [ 5472/ 8083]
loss:        0.000449  [ 5632/ 8083]
loss:        0.000527  [ 5792/ 8083]
loss:        0.000439  [ 5952/ 8083]
loss:        0.000588  [ 6112/ 8083]
loss:        0.000485  [ 6272/ 8083]
loss:        0.000478  [ 6432/ 8083]
loss:        0.000486  [ 6592/ 8083]
loss:        0.000570  [ 6752/ 8083]
loss:        0.000517  [ 6912/ 8083]
loss:        0.000530  [ 7072/ 8083]
loss:        0.000512  [ 7232/ 8083]
loss:        0.000412  [ 7392/ 8083]
loss:        0.000474  [ 7552/ 8083]
loss:        0.000545  [ 7712/ 8083]
loss:        0.000436  [ 7872/ 8083]
loss:        0.000399  [ 8032/ 8083]

Avg validation loss:        0.000501
Learning rate: [0.0001]
Validation loss decreased (0.000503 --> 0.000501).  Saving model ...


Epoch 26
-------------------------------
loss:        0.000563  [   32/ 8083]
loss:        0.000504  [  192/ 8083]
loss:        0.000499  [  352/ 8083]
loss:        0.000480  [  512/ 8083]
loss:        0.000427  [  672/ 8083]
loss:        0.000476  [  832/ 8083]
loss:        0.000364  [  992/ 8083]
loss:        0.000623  [ 1152/ 8083]
loss:        0.000505  [ 1312/ 8083]
loss:        0.000571  [ 1472/ 8083]
loss:        0.000587  [ 1632/ 8083]
loss:        0.000448  [ 1792/ 8083]
loss:        0.000495  [ 1952/ 8083]
loss:        0.000549  [ 2112/ 8083]
loss:        0.000470  [ 2272/ 8083]
loss:        0.000466  [ 2432/ 8083]
loss:        0.000477  [ 2592/ 8083]
loss:        0.000442  [ 2752/ 8083]
loss:        0.000550  [ 2912/ 8083]
loss:        0.000366  [ 3072/ 8083]
loss:        0.000469  [ 3232/ 8083]
loss:        0.000464  [ 3392/ 8083]
loss:        0.000403  [ 3552/ 8083]
loss:        0.000379  [ 3712/ 8083]
loss:        0.000542  [ 3872/ 8083]
loss:        0.000457  [ 4032/ 8083]
loss:        0.000443  [ 4192/ 8083]
loss:        0.000545  [ 4352/ 8083]
loss:        0.000444  [ 4512/ 8083]
loss:        0.000498  [ 4672/ 8083]
loss:        0.000372  [ 4832/ 8083]
loss:        0.000564  [ 4992/ 8083]
loss:        0.000546  [ 5152/ 8083]
loss:        0.000502  [ 5312/ 8083]
loss:        0.000457  [ 5472/ 8083]
loss:        0.000490  [ 5632/ 8083]
loss:        0.000548  [ 5792/ 8083]
loss:        0.000459  [ 5952/ 8083]
loss:        0.000531  [ 6112/ 8083]
loss:        0.000520  [ 6272/ 8083]
loss:        0.000463  [ 6432/ 8083]
loss:        0.000475  [ 6592/ 8083]
loss:        0.000517  [ 6752/ 8083]
loss:        0.000471  [ 6912/ 8083]
loss:        0.000532  [ 7072/ 8083]
loss:        0.000599  [ 7232/ 8083]
loss:        0.000569  [ 7392/ 8083]
loss:        0.000416  [ 7552/ 8083]
loss:        0.000499  [ 7712/ 8083]
loss:        0.000469  [ 7872/ 8083]
loss:        0.000443  [ 8032/ 8083]

Avg validation loss:        0.000506
Learning rate: [0.0001]
EarlyStopping counter: 1 out of 20


Epoch 27
-------------------------------
loss:        0.000596  [   32/ 8083]
loss:        0.000498  [  192/ 8083]
loss:        0.000490  [  352/ 8083]
loss:        0.000518  [  512/ 8083]
loss:        0.000527  [  672/ 8083]
loss:        0.000549  [  832/ 8083]
loss:        0.000459  [  992/ 8083]
loss:        0.000527  [ 1152/ 8083]
loss:        0.000567  [ 1312/ 8083]
loss:        0.000497  [ 1472/ 8083]
loss:        0.000507  [ 1632/ 8083]
loss:        0.000535  [ 1792/ 8083]
loss:        0.000534  [ 1952/ 8083]
loss:        0.000447  [ 2112/ 8083]
loss:        0.000494  [ 2272/ 8083]
loss:        0.000482  [ 2432/ 8083]
loss:        0.000427  [ 2592/ 8083]
loss:        0.000458  [ 2752/ 8083]
loss:        0.000482  [ 2912/ 8083]
loss:        0.000501  [ 3072/ 8083]
loss:        0.000442  [ 3232/ 8083]
loss:        0.000443  [ 3392/ 8083]
loss:        0.000423  [ 3552/ 8083]
loss:        0.000406  [ 3712/ 8083]
loss:        0.000383  [ 3872/ 8083]
loss:        0.000443  [ 4032/ 8083]
loss:        0.000453  [ 4192/ 8083]
loss:        0.000552  [ 4352/ 8083]
loss:        0.000438  [ 4512/ 8083]
loss:        0.000414  [ 4672/ 8083]
loss:        0.000529  [ 4832/ 8083]
loss:        0.000431  [ 4992/ 8083]
loss:        0.000449  [ 5152/ 8083]
loss:        0.000522  [ 5312/ 8083]
loss:        0.000526  [ 5472/ 8083]
loss:        0.000387  [ 5632/ 8083]
loss:        0.000546  [ 5792/ 8083]
loss:        0.000562  [ 5952/ 8083]
loss:        0.000425  [ 6112/ 8083]
loss:        0.000552  [ 6272/ 8083]
loss:        0.000506  [ 6432/ 8083]
loss:        0.000533  [ 6592/ 8083]
loss:        0.000555  [ 6752/ 8083]
loss:        0.000592  [ 6912/ 8083]
loss:        0.000422  [ 7072/ 8083]
loss:        0.000416  [ 7232/ 8083]
loss:        0.000487  [ 7392/ 8083]
loss:        0.000477  [ 7552/ 8083]
loss:        0.000485  [ 7712/ 8083]
loss:        0.000455  [ 7872/ 8083]
loss:        0.000623  [ 8032/ 8083]

Avg validation loss:        0.000501
Learning rate: [0.0001]
Validation loss decreased (0.000501 --> 0.000501).  Saving model ...


Epoch 28
-------------------------------
loss:        0.000419  [   32/ 8083]
loss:        0.000552  [  192/ 8083]
loss:        0.000510  [  352/ 8083]
loss:        0.000601  [  512/ 8083]
loss:        0.000547  [  672/ 8083]
loss:        0.000640  [  832/ 8083]
loss:        0.000538  [  992/ 8083]
loss:        0.000507  [ 1152/ 8083]
loss:        0.000513  [ 1312/ 8083]
loss:        0.000562  [ 1472/ 8083]
loss:        0.000605  [ 1632/ 8083]
loss:        0.000583  [ 1792/ 8083]
loss:        0.000408  [ 1952/ 8083]
loss:        0.000479  [ 2112/ 8083]
loss:        0.000387  [ 2272/ 8083]
loss:        0.000536  [ 2432/ 8083]
loss:        0.000600  [ 2592/ 8083]
loss:        0.000390  [ 2752/ 8083]
loss:        0.000522  [ 2912/ 8083]
loss:        0.000423  [ 3072/ 8083]
loss:        0.000463  [ 3232/ 8083]
loss:        0.000491  [ 3392/ 8083]
loss:        0.000574  [ 3552/ 8083]
loss:        0.000448  [ 3712/ 8083]
loss:        0.000436  [ 3872/ 8083]
loss:        0.000569  [ 4032/ 8083]
loss:        0.000469  [ 4192/ 8083]
loss:        0.000418  [ 4352/ 8083]
loss:        0.000492  [ 4512/ 8083]
loss:        0.000528  [ 4672/ 8083]
loss:        0.000503  [ 4832/ 8083]
loss:        0.000496  [ 4992/ 8083]
loss:        0.000588  [ 5152/ 8083]
loss:        0.000442  [ 5312/ 8083]
loss:        0.000501  [ 5472/ 8083]
loss:        0.000489  [ 5632/ 8083]
loss:        0.000480  [ 5792/ 8083]
loss:        0.000513  [ 5952/ 8083]
loss:        0.000481  [ 6112/ 8083]
loss:        0.000430  [ 6272/ 8083]
loss:        0.000539  [ 6432/ 8083]
loss:        0.000480  [ 6592/ 8083]
loss:        0.000545  [ 6752/ 8083]
loss:        0.000551  [ 6912/ 8083]
loss:        0.000443  [ 7072/ 8083]
loss:        0.000479  [ 7232/ 8083]
loss:        0.000440  [ 7392/ 8083]
loss:        0.000447  [ 7552/ 8083]
loss:        0.000533  [ 7712/ 8083]
loss:        0.000543  [ 7872/ 8083]
loss:        0.000443  [ 8032/ 8083]

Avg validation loss:        0.000502
Learning rate: [0.0001]
EarlyStopping counter: 1 out of 20


Epoch 29
-------------------------------
loss:        0.000421  [   32/ 8083]
loss:        0.000490  [  192/ 8083]
loss:        0.000516  [  352/ 8083]
loss:        0.000493  [  512/ 8083]
loss:        0.000518  [  672/ 8083]
loss:        0.000418  [  832/ 8083]
loss:        0.000470  [  992/ 8083]
loss:        0.000502  [ 1152/ 8083]
loss:        0.000458  [ 1312/ 8083]
loss:        0.000440  [ 1472/ 8083]
loss:        0.000547  [ 1632/ 8083]
loss:        0.000528  [ 1792/ 8083]
loss:        0.000555  [ 1952/ 8083]
loss:        0.000478  [ 2112/ 8083]
loss:        0.000518  [ 2272/ 8083]
loss:        0.000404  [ 2432/ 8083]
loss:        0.000540  [ 2592/ 8083]
loss:        0.000522  [ 2752/ 8083]
loss:        0.000542  [ 2912/ 8083]
loss:        0.000523  [ 3072/ 8083]
loss:        0.000542  [ 3232/ 8083]
loss:        0.000485  [ 3392/ 8083]
loss:        0.000508  [ 3552/ 8083]
loss:        0.000490  [ 3712/ 8083]
loss:        0.000677  [ 3872/ 8083]
loss:        0.000507  [ 4032/ 8083]
loss:        0.000512  [ 4192/ 8083]
loss:        0.000534  [ 4352/ 8083]
loss:        0.000598  [ 4512/ 8083]
loss:        0.000385  [ 4672/ 8083]
loss:        0.000492  [ 4832/ 8083]
loss:        0.000519  [ 4992/ 8083]
loss:        0.000434  [ 5152/ 8083]
loss:        0.000608  [ 5312/ 8083]
loss:        0.000448  [ 5472/ 8083]
loss:        0.000580  [ 5632/ 8083]
loss:        0.000499  [ 5792/ 8083]
loss:        0.000394  [ 5952/ 8083]
loss:        0.000473  [ 6112/ 8083]
loss:        0.000498  [ 6272/ 8083]
loss:        0.000494  [ 6432/ 8083]
loss:        0.000493  [ 6592/ 8083]
loss:        0.000551  [ 6752/ 8083]
loss:        0.000472  [ 6912/ 8083]
loss:        0.000462  [ 7072/ 8083]
loss:        0.000566  [ 7232/ 8083]
loss:        0.000469  [ 7392/ 8083]
loss:        0.000463  [ 7552/ 8083]
loss:        0.000431  [ 7712/ 8083]
loss:        0.000561  [ 7872/ 8083]
loss:        0.000459  [ 8032/ 8083]

Avg validation loss:        0.000501
Learning rate: [0.0001]
EarlyStopping counter: 2 out of 20


Epoch 30
-------------------------------
loss:        0.000487  [   32/ 8083]
loss:        0.000511  [  192/ 8083]
loss:        0.000483  [  352/ 8083]
loss:        0.000465  [  512/ 8083]
loss:        0.000471  [  672/ 8083]
loss:        0.000494  [  832/ 8083]
loss:        0.000454  [  992/ 8083]
loss:        0.000468  [ 1152/ 8083]
loss:        0.000626  [ 1312/ 8083]
loss:        0.000490  [ 1472/ 8083]
loss:        0.000517  [ 1632/ 8083]
loss:        0.000510  [ 1792/ 8083]
loss:        0.000473  [ 1952/ 8083]
loss:        0.000489  [ 2112/ 8083]
loss:        0.000438  [ 2272/ 8083]
loss:        0.000479  [ 2432/ 8083]
loss:        0.000394  [ 2592/ 8083]
loss:        0.000546  [ 2752/ 8083]
loss:        0.000390  [ 2912/ 8083]
loss:        0.000501  [ 3072/ 8083]
loss:        0.000452  [ 3232/ 8083]
loss:        0.000559  [ 3392/ 8083]
loss:        0.000426  [ 3552/ 8083]
loss:        0.000393  [ 3712/ 8083]
loss:        0.000483  [ 3872/ 8083]
loss:        0.000564  [ 4032/ 8083]
loss:        0.000525  [ 4192/ 8083]
loss:        0.000501  [ 4352/ 8083]
loss:        0.000507  [ 4512/ 8083]
loss:        0.000503  [ 4672/ 8083]
loss:        0.000515  [ 4832/ 8083]
loss:        0.000504  [ 4992/ 8083]
loss:        0.000553  [ 5152/ 8083]
loss:        0.000478  [ 5312/ 8083]
loss:        0.000497  [ 5472/ 8083]
loss:        0.000469  [ 5632/ 8083]
loss:        0.000495  [ 5792/ 8083]
loss:        0.000537  [ 5952/ 8083]
loss:        0.000459  [ 6112/ 8083]
loss:        0.000478  [ 6272/ 8083]
loss:        0.000398  [ 6432/ 8083]
loss:        0.000546  [ 6592/ 8083]
loss:        0.000466  [ 6752/ 8083]
loss:        0.000532  [ 6912/ 8083]
loss:        0.000512  [ 7072/ 8083]
loss:        0.000454  [ 7232/ 8083]
loss:        0.000435  [ 7392/ 8083]
loss:        0.000540  [ 7552/ 8083]
loss:        0.000495  [ 7712/ 8083]
loss:        0.000540  [ 7872/ 8083]
loss:        0.000515  [ 8032/ 8083]

Avg validation loss:        0.000500
Learning rate: [0.0001]
Validation loss decreased (0.000501 --> 0.000500).  Saving model ...


Epoch 31
-------------------------------
loss:        0.000470  [   32/ 8083]
loss:        0.000546  [  192/ 8083]
loss:        0.000557  [  352/ 8083]
loss:        0.000527  [  512/ 8083]
loss:        0.000589  [  672/ 8083]
loss:        0.000452  [  832/ 8083]
loss:        0.000511  [  992/ 8083]
loss:        0.000565  [ 1152/ 8083]
loss:        0.000472  [ 1312/ 8083]
loss:        0.000523  [ 1472/ 8083]
loss:        0.000463  [ 1632/ 8083]
loss:        0.000394  [ 1792/ 8083]
loss:        0.000469  [ 1952/ 8083]
loss:        0.000519  [ 2112/ 8083]
loss:        0.000386  [ 2272/ 8083]
loss:        0.000393  [ 2432/ 8083]
loss:        0.000464  [ 2592/ 8083]
loss:        0.000409  [ 2752/ 8083]
loss:        0.000445  [ 2912/ 8083]
loss:        0.000463  [ 3072/ 8083]
loss:        0.000503  [ 3232/ 8083]
loss:        0.000510  [ 3392/ 8083]
loss:        0.000527  [ 3552/ 8083]
loss:        0.000531  [ 3712/ 8083]
loss:        0.000449  [ 3872/ 8083]
loss:        0.000621  [ 4032/ 8083]
loss:        0.000460  [ 4192/ 8083]
loss:        0.000371  [ 4352/ 8083]
loss:        0.000573  [ 4512/ 8083]
loss:        0.000590  [ 4672/ 8083]
loss:        0.000449  [ 4832/ 8083]
loss:        0.000522  [ 4992/ 8083]
loss:        0.000584  [ 5152/ 8083]
loss:        0.000622  [ 5312/ 8083]
loss:        0.000504  [ 5472/ 8083]
loss:        0.000477  [ 5632/ 8083]
loss:        0.000484  [ 5792/ 8083]
loss:        0.000566  [ 5952/ 8083]
loss:        0.000488  [ 6112/ 8083]
loss:        0.000444  [ 6272/ 8083]
loss:        0.000560  [ 6432/ 8083]
loss:        0.000478  [ 6592/ 8083]
loss:        0.000422  [ 6752/ 8083]
loss:        0.000519  [ 6912/ 8083]
loss:        0.000472  [ 7072/ 8083]
loss:        0.000396  [ 7232/ 8083]
loss:        0.000553  [ 7392/ 8083]
loss:        0.000468  [ 7552/ 8083]
loss:        0.000421  [ 7712/ 8083]
loss:        0.000551  [ 7872/ 8083]
loss:        0.000445  [ 8032/ 8083]

Avg validation loss:        0.000502
Learning rate: [0.0001]
EarlyStopping counter: 1 out of 20


Epoch 32
-------------------------------
loss:        0.000510  [   32/ 8083]
loss:        0.000469  [  192/ 8083]
loss:        0.000485  [  352/ 8083]
loss:        0.000532  [  512/ 8083]
loss:        0.000480  [  672/ 8083]
loss:        0.000496  [  832/ 8083]
loss:        0.000527  [  992/ 8083]
loss:        0.000502  [ 1152/ 8083]
loss:        0.000551  [ 1312/ 8083]
loss:        0.000582  [ 1472/ 8083]
loss:        0.000563  [ 1632/ 8083]
loss:        0.000466  [ 1792/ 8083]
loss:        0.000466  [ 1952/ 8083]
loss:        0.000422  [ 2112/ 8083]
loss:        0.000564  [ 2272/ 8083]
loss:        0.000480  [ 2432/ 8083]
loss:        0.000563  [ 2592/ 8083]
loss:        0.000502  [ 2752/ 8083]
loss:        0.000537  [ 2912/ 8083]
loss:        0.000506  [ 3072/ 8083]
loss:        0.000486  [ 3232/ 8083]
loss:        0.000539  [ 3392/ 8083]
loss:        0.000507  [ 3552/ 8083]
loss:        0.000441  [ 3712/ 8083]
loss:        0.000504  [ 3872/ 8083]
loss:        0.000462  [ 4032/ 8083]
loss:        0.000522  [ 4192/ 8083]
loss:        0.000456  [ 4352/ 8083]
loss:        0.000425  [ 4512/ 8083]
loss:        0.000502  [ 4672/ 8083]
loss:        0.000452  [ 4832/ 8083]
loss:        0.000392  [ 4992/ 8083]
loss:        0.000401  [ 5152/ 8083]
loss:        0.000456  [ 5312/ 8083]
loss:        0.000480  [ 5472/ 8083]
loss:        0.000497  [ 5632/ 8083]
loss:        0.000468  [ 5792/ 8083]
loss:        0.000542  [ 5952/ 8083]
loss:        0.000434  [ 6112/ 8083]
loss:        0.000513  [ 6272/ 8083]
loss:        0.000537  [ 6432/ 8083]
loss:        0.000474  [ 6592/ 8083]
loss:        0.000450  [ 6752/ 8083]
loss:        0.000487  [ 6912/ 8083]
loss:        0.000522  [ 7072/ 8083]
loss:        0.000506  [ 7232/ 8083]
loss:        0.000486  [ 7392/ 8083]
loss:        0.000554  [ 7552/ 8083]
loss:        0.000521  [ 7712/ 8083]
loss:        0.000428  [ 7872/ 8083]
loss:        0.000514  [ 8032/ 8083]

Avg validation loss:        0.000502
Learning rate: [0.0001]
EarlyStopping counter: 2 out of 20


Epoch 33
-------------------------------
loss:        0.000541  [   32/ 8083]
loss:        0.000679  [  192/ 8083]
loss:        0.000508  [  352/ 8083]
loss:        0.000474  [  512/ 8083]
loss:        0.000498  [  672/ 8083]
loss:        0.000464  [  832/ 8083]
loss:        0.000549  [  992/ 8083]
loss:        0.000544  [ 1152/ 8083]
loss:        0.000452  [ 1312/ 8083]
loss:        0.000532  [ 1472/ 8083]
loss:        0.000430  [ 1632/ 8083]
loss:        0.000523  [ 1792/ 8083]
loss:        0.000487  [ 1952/ 8083]
loss:        0.000540  [ 2112/ 8083]
loss:        0.000437  [ 2272/ 8083]
loss:        0.000593  [ 2432/ 8083]
loss:        0.000483  [ 2592/ 8083]
loss:        0.000512  [ 2752/ 8083]
loss:        0.000442  [ 2912/ 8083]
loss:        0.000558  [ 3072/ 8083]
loss:        0.000447  [ 3232/ 8083]
loss:        0.000436  [ 3392/ 8083]
loss:        0.000433  [ 3552/ 8083]
loss:        0.000509  [ 3712/ 8083]
loss:        0.000511  [ 3872/ 8083]
loss:        0.000433  [ 4032/ 8083]
loss:        0.000543  [ 4192/ 8083]
loss:        0.000528  [ 4352/ 8083]
loss:        0.000557  [ 4512/ 8083]
loss:        0.000656  [ 4672/ 8083]
loss:        0.000546  [ 4832/ 8083]
loss:        0.000535  [ 4992/ 8083]
loss:        0.000504  [ 5152/ 8083]
loss:        0.000503  [ 5312/ 8083]
loss:        0.000514  [ 5472/ 8083]
loss:        0.000419  [ 5632/ 8083]
loss:        0.000458  [ 5792/ 8083]
loss:        0.000394  [ 5952/ 8083]
loss:        0.000459  [ 6112/ 8083]
loss:        0.000515  [ 6272/ 8083]
loss:        0.000497  [ 6432/ 8083]
loss:        0.000471  [ 6592/ 8083]
loss:        0.000486  [ 6752/ 8083]
loss:        0.000537  [ 6912/ 8083]
loss:        0.000373  [ 7072/ 8083]
loss:        0.000397  [ 7232/ 8083]
loss:        0.000442  [ 7392/ 8083]
loss:        0.000584  [ 7552/ 8083]
loss:        0.000599  [ 7712/ 8083]
loss:        0.000471  [ 7872/ 8083]
loss:        0.000582  [ 8032/ 8083]

Avg validation loss:        0.000502
Learning rate: [0.0001]
EarlyStopping counter: 3 out of 20


Epoch 34
-------------------------------
loss:        0.000439  [   32/ 8083]
loss:        0.000479  [  192/ 8083]
loss:        0.000538  [  352/ 8083]
loss:        0.000519  [  512/ 8083]
loss:        0.000468  [  672/ 8083]
loss:        0.000505  [  832/ 8083]
loss:        0.000497  [  992/ 8083]
loss:        0.000482  [ 1152/ 8083]
loss:        0.000589  [ 1312/ 8083]
loss:        0.000566  [ 1472/ 8083]
loss:        0.000444  [ 1632/ 8083]
loss:        0.000588  [ 1792/ 8083]
loss:        0.000519  [ 1952/ 8083]
loss:        0.000500  [ 2112/ 8083]
loss:        0.000460  [ 2272/ 8083]
loss:        0.000551  [ 2432/ 8083]
loss:        0.000468  [ 2592/ 8083]
loss:        0.000508  [ 2752/ 8083]
loss:        0.000424  [ 2912/ 8083]
loss:        0.000473  [ 3072/ 8083]
loss:        0.000512  [ 3232/ 8083]
loss:        0.000537  [ 3392/ 8083]
loss:        0.000443  [ 3552/ 8083]
loss:        0.000427  [ 3712/ 8083]
loss:        0.000627  [ 3872/ 8083]
loss:        0.000440  [ 4032/ 8083]
loss:        0.000464  [ 4192/ 8083]
loss:        0.000507  [ 4352/ 8083]
loss:        0.000378  [ 4512/ 8083]
loss:        0.000449  [ 4672/ 8083]
loss:        0.000551  [ 4832/ 8083]
loss:        0.000516  [ 4992/ 8083]
loss:        0.000514  [ 5152/ 8083]
loss:        0.000509  [ 5312/ 8083]
loss:        0.000359  [ 5472/ 8083]
loss:        0.000513  [ 5632/ 8083]
loss:        0.000443  [ 5792/ 8083]
loss:        0.000607  [ 5952/ 8083]
loss:        0.000393  [ 6112/ 8083]
loss:        0.000536  [ 6272/ 8083]
loss:        0.000447  [ 6432/ 8083]
loss:        0.000501  [ 6592/ 8083]
loss:        0.000526  [ 6752/ 8083]
loss:        0.000631  [ 6912/ 8083]
loss:        0.000607  [ 7072/ 8083]
loss:        0.000503  [ 7232/ 8083]
loss:        0.000428  [ 7392/ 8083]
loss:        0.000518  [ 7552/ 8083]
loss:        0.000453  [ 7712/ 8083]
loss:        0.000496  [ 7872/ 8083]
loss:        0.000475  [ 8032/ 8083]

Avg validation loss:        0.000502
Learning rate: [0.0001]
EarlyStopping counter: 4 out of 20


Epoch 35
-------------------------------
loss:        0.000385  [   32/ 8083]
loss:        0.000538  [  192/ 8083]
loss:        0.000432  [  352/ 8083]
loss:        0.000489  [  512/ 8083]
loss:        0.000589  [  672/ 8083]
loss:        0.000468  [  832/ 8083]
loss:        0.000489  [  992/ 8083]
loss:        0.000426  [ 1152/ 8083]
loss:        0.000485  [ 1312/ 8083]
loss:        0.000530  [ 1472/ 8083]
loss:        0.000493  [ 1632/ 8083]
loss:        0.000603  [ 1792/ 8083]
loss:        0.000446  [ 1952/ 8083]
loss:        0.000580  [ 2112/ 8083]
loss:        0.000484  [ 2272/ 8083]
loss:        0.000492  [ 2432/ 8083]
loss:        0.000597  [ 2592/ 8083]
loss:        0.000586  [ 2752/ 8083]
loss:        0.000542  [ 2912/ 8083]
loss:        0.000414  [ 3072/ 8083]
loss:        0.000544  [ 3232/ 8083]
loss:        0.000473  [ 3392/ 8083]
loss:        0.000491  [ 3552/ 8083]
loss:        0.000487  [ 3712/ 8083]
loss:        0.000428  [ 3872/ 8083]
loss:        0.000616  [ 4032/ 8083]
loss:        0.000533  [ 4192/ 8083]
loss:        0.000477  [ 4352/ 8083]
loss:        0.000476  [ 4512/ 8083]
loss:        0.000492  [ 4672/ 8083]
loss:        0.000512  [ 4832/ 8083]
loss:        0.000458  [ 4992/ 8083]
loss:        0.000531  [ 5152/ 8083]
loss:        0.000531  [ 5312/ 8083]
loss:        0.000456  [ 5472/ 8083]
loss:        0.000445  [ 5632/ 8083]
loss:        0.000466  [ 5792/ 8083]
loss:        0.000512  [ 5952/ 8083]
loss:        0.000552  [ 6112/ 8083]
loss:        0.000415  [ 6272/ 8083]
loss:        0.000467  [ 6432/ 8083]
loss:        0.000567  [ 6592/ 8083]
loss:        0.000523  [ 6752/ 8083]
loss:        0.000506  [ 6912/ 8083]
loss:        0.000447  [ 7072/ 8083]
loss:        0.000502  [ 7232/ 8083]
loss:        0.000534  [ 7392/ 8083]
loss:        0.000432  [ 7552/ 8083]
loss:        0.000578  [ 7712/ 8083]
loss:        0.000467  [ 7872/ 8083]
loss:        0.000628  [ 8032/ 8083]

Avg validation loss:        0.000500
Learning rate: [0.0001]
Validation loss decreased (0.000500 --> 0.000500).  Saving model ...


Epoch 36
-------------------------------
loss:        0.000504  [   32/ 8083]
loss:        0.000439  [  192/ 8083]
loss:        0.000546  [  352/ 8083]
loss:        0.000497  [  512/ 8083]
loss:        0.000376  [  672/ 8083]
loss:        0.000528  [  832/ 8083]
loss:        0.000463  [  992/ 8083]
loss:        0.000541  [ 1152/ 8083]
loss:        0.000487  [ 1312/ 8083]
loss:        0.000558  [ 1472/ 8083]
loss:        0.000510  [ 1632/ 8083]
loss:        0.000503  [ 1792/ 8083]
loss:        0.000482  [ 1952/ 8083]
loss:        0.000491  [ 2112/ 8083]
loss:        0.000479  [ 2272/ 8083]
loss:        0.000537  [ 2432/ 8083]
loss:        0.000453  [ 2592/ 8083]
loss:        0.000530  [ 2752/ 8083]
loss:        0.000518  [ 2912/ 8083]
loss:        0.000533  [ 3072/ 8083]
loss:        0.000525  [ 3232/ 8083]
loss:        0.000485  [ 3392/ 8083]
loss:        0.000496  [ 3552/ 8083]
loss:        0.000520  [ 3712/ 8083]
loss:        0.000582  [ 3872/ 8083]
loss:        0.000481  [ 4032/ 8083]
loss:        0.000538  [ 4192/ 8083]
loss:        0.000412  [ 4352/ 8083]
loss:        0.000451  [ 4512/ 8083]
loss:        0.000555  [ 4672/ 8083]
loss:        0.000527  [ 4832/ 8083]
loss:        0.000620  [ 4992/ 8083]
loss:        0.000403  [ 5152/ 8083]
loss:        0.000460  [ 5312/ 8083]
loss:        0.000484  [ 5472/ 8083]
loss:        0.000421  [ 5632/ 8083]
loss:        0.000466  [ 5792/ 8083]
loss:        0.000531  [ 5952/ 8083]
loss:        0.000485  [ 6112/ 8083]
loss:        0.000463  [ 6272/ 8083]
loss:        0.000446  [ 6432/ 8083]
loss:        0.000466  [ 6592/ 8083]
loss:        0.000519  [ 6752/ 8083]
loss:        0.000479  [ 6912/ 8083]
loss:        0.000508  [ 7072/ 8083]
loss:        0.000473  [ 7232/ 8083]
loss:        0.000490  [ 7392/ 8083]
loss:        0.000464  [ 7552/ 8083]
loss:        0.000454  [ 7712/ 8083]
loss:        0.000491  [ 7872/ 8083]
loss:        0.000384  [ 8032/ 8083]

Avg validation loss:        0.000503
Learning rate: [0.0001]
EarlyStopping counter: 1 out of 20


Epoch 37
-------------------------------
loss:        0.000484  [   32/ 8083]
loss:        0.000521  [  192/ 8083]
loss:        0.000466  [  352/ 8083]
loss:        0.000493  [  512/ 8083]
loss:        0.000495  [  672/ 8083]
loss:        0.000473  [  832/ 8083]
loss:        0.000592  [  992/ 8083]
loss:        0.000489  [ 1152/ 8083]
loss:        0.000549  [ 1312/ 8083]
loss:        0.000507  [ 1472/ 8083]
loss:        0.000519  [ 1632/ 8083]
loss:        0.000433  [ 1792/ 8083]
loss:        0.000498  [ 1952/ 8083]
loss:        0.000547  [ 2112/ 8083]
loss:        0.000553  [ 2272/ 8083]
loss:        0.000464  [ 2432/ 8083]
loss:        0.000480  [ 2592/ 8083]
loss:        0.000599  [ 2752/ 8083]
loss:        0.000495  [ 2912/ 8083]
loss:        0.000547  [ 3072/ 8083]
loss:        0.000496  [ 3232/ 8083]
loss:        0.000526  [ 3392/ 8083]
loss:        0.000464  [ 3552/ 8083]
loss:        0.000521  [ 3712/ 8083]
loss:        0.000495  [ 3872/ 8083]
loss:        0.000535  [ 4032/ 8083]
loss:        0.000484  [ 4192/ 8083]
loss:        0.000463  [ 4352/ 8083]
loss:        0.000499  [ 4512/ 8083]
loss:        0.000580  [ 4672/ 8083]
loss:        0.000452  [ 4832/ 8083]
loss:        0.000451  [ 4992/ 8083]
loss:        0.000680  [ 5152/ 8083]
loss:        0.000477  [ 5312/ 8083]
loss:        0.000568  [ 5472/ 8083]
loss:        0.000552  [ 5632/ 8083]
loss:        0.000444  [ 5792/ 8083]
loss:        0.000446  [ 5952/ 8083]
loss:        0.000478  [ 6112/ 8083]
loss:        0.000390  [ 6272/ 8083]
loss:        0.000475  [ 6432/ 8083]
loss:        0.000496  [ 6592/ 8083]
loss:        0.000561  [ 6752/ 8083]
loss:        0.000462  [ 6912/ 8083]
loss:        0.000621  [ 7072/ 8083]
loss:        0.000505  [ 7232/ 8083]
loss:        0.000483  [ 7392/ 8083]
loss:        0.000485  [ 7552/ 8083]
loss:        0.000465  [ 7712/ 8083]
loss:        0.000478  [ 7872/ 8083]
loss:        0.000327  [ 8032/ 8083]

Avg validation loss:        0.000503
Learning rate: [0.0001]
EarlyStopping counter: 2 out of 20


Epoch 38
-------------------------------
loss:        0.000465  [   32/ 8083]
loss:        0.000555  [  192/ 8083]
loss:        0.000520  [  352/ 8083]
loss:        0.000462  [  512/ 8083]
loss:        0.000445  [  672/ 8083]
loss:        0.000381  [  832/ 8083]
loss:        0.000432  [  992/ 8083]
loss:        0.000481  [ 1152/ 8083]
loss:        0.000578  [ 1312/ 8083]
loss:        0.000580  [ 1472/ 8083]
loss:        0.000440  [ 1632/ 8083]
loss:        0.000515  [ 1792/ 8083]
loss:        0.000479  [ 1952/ 8083]
loss:        0.000472  [ 2112/ 8083]
loss:        0.000483  [ 2272/ 8083]
loss:        0.000530  [ 2432/ 8083]
loss:        0.000419  [ 2592/ 8083]
loss:        0.000534  [ 2752/ 8083]
loss:        0.000485  [ 2912/ 8083]
loss:        0.000535  [ 3072/ 8083]
loss:        0.000454  [ 3232/ 8083]
loss:        0.000441  [ 3392/ 8083]
loss:        0.000575  [ 3552/ 8083]
loss:        0.000544  [ 3712/ 8083]
loss:        0.000535  [ 3872/ 8083]
loss:        0.000409  [ 4032/ 8083]
loss:        0.000483  [ 4192/ 8083]
loss:        0.000502  [ 4352/ 8083]
loss:        0.000497  [ 4512/ 8083]
loss:        0.000483  [ 4672/ 8083]
loss:        0.000439  [ 4832/ 8083]
loss:        0.000530  [ 4992/ 8083]
loss:        0.000623  [ 5152/ 8083]
loss:        0.000512  [ 5312/ 8083]
loss:        0.000554  [ 5472/ 8083]
loss:        0.000437  [ 5632/ 8083]
loss:        0.000463  [ 5792/ 8083]
loss:        0.000505  [ 5952/ 8083]
loss:        0.000455  [ 6112/ 8083]
loss:        0.000452  [ 6272/ 8083]
loss:        0.000599  [ 6432/ 8083]
loss:        0.000491  [ 6592/ 8083]
loss:        0.000492  [ 6752/ 8083]
loss:        0.000459  [ 6912/ 8083]
loss:        0.000472  [ 7072/ 8083]
loss:        0.000510  [ 7232/ 8083]
loss:        0.000443  [ 7392/ 8083]
loss:        0.000488  [ 7552/ 8083]
loss:        0.000525  [ 7712/ 8083]
loss:        0.000571  [ 7872/ 8083]
loss:        0.000495  [ 8032/ 8083]

Avg validation loss:        0.000502
Learning rate: [0.0001]
EarlyStopping counter: 3 out of 20


Epoch 39
-------------------------------
loss:        0.000488  [   32/ 8083]
loss:        0.000507  [  192/ 8083]
loss:        0.000375  [  352/ 8083]
loss:        0.000587  [  512/ 8083]
loss:        0.000538  [  672/ 8083]
loss:        0.000479  [  832/ 8083]
loss:        0.000590  [  992/ 8083]
loss:        0.000441  [ 1152/ 8083]
loss:        0.000481  [ 1312/ 8083]
loss:        0.000467  [ 1472/ 8083]
loss:        0.000476  [ 1632/ 8083]
loss:        0.000401  [ 1792/ 8083]
loss:        0.000481  [ 1952/ 8083]
loss:        0.000530  [ 2112/ 8083]
loss:        0.000464  [ 2272/ 8083]
loss:        0.000509  [ 2432/ 8083]
loss:        0.000505  [ 2592/ 8083]
loss:        0.000545  [ 2752/ 8083]
loss:        0.000501  [ 2912/ 8083]
loss:        0.000535  [ 3072/ 8083]
loss:        0.000500  [ 3232/ 8083]
loss:        0.000486  [ 3392/ 8083]
loss:        0.000390  [ 3552/ 8083]
loss:        0.000542  [ 3712/ 8083]
loss:        0.000568  [ 3872/ 8083]
loss:        0.000575  [ 4032/ 8083]
loss:        0.000457  [ 4192/ 8083]
loss:        0.000451  [ 4352/ 8083]
loss:        0.000532  [ 4512/ 8083]
loss:        0.000580  [ 4672/ 8083]
loss:        0.000563  [ 4832/ 8083]
loss:        0.000494  [ 4992/ 8083]
loss:        0.000512  [ 5152/ 8083]
loss:        0.000606  [ 5312/ 8083]
loss:        0.000514  [ 5472/ 8083]
loss:        0.000466  [ 5632/ 8083]
loss:        0.000437  [ 5792/ 8083]
loss:        0.000441  [ 5952/ 8083]
loss:        0.000637  [ 6112/ 8083]
loss:        0.000435  [ 6272/ 8083]
loss:        0.000364  [ 6432/ 8083]
loss:        0.000543  [ 6592/ 8083]
loss:        0.000619  [ 6752/ 8083]
loss:        0.000484  [ 6912/ 8083]
loss:        0.000513  [ 7072/ 8083]
loss:        0.000590  [ 7232/ 8083]
loss:        0.000468  [ 7392/ 8083]
loss:        0.000492  [ 7552/ 8083]
loss:        0.000551  [ 7712/ 8083]
loss:        0.000431  [ 7872/ 8083]
loss:        0.000502  [ 8032/ 8083]

Avg validation loss:        0.000501
Learning rate: [0.0001]
EarlyStopping counter: 4 out of 20


Epoch 40
-------------------------------
loss:        0.000517  [   32/ 8083]
loss:        0.000537  [  192/ 8083]
loss:        0.000438  [  352/ 8083]
loss:        0.000525  [  512/ 8083]
loss:        0.000415  [  672/ 8083]
loss:        0.000468  [  832/ 8083]
loss:        0.000355  [  992/ 8083]
loss:        0.000451  [ 1152/ 8083]
loss:        0.000423  [ 1312/ 8083]
loss:        0.000488  [ 1472/ 8083]
loss:        0.000477  [ 1632/ 8083]
loss:        0.000528  [ 1792/ 8083]
loss:        0.000443  [ 1952/ 8083]
loss:        0.000496  [ 2112/ 8083]
loss:        0.000420  [ 2272/ 8083]
loss:        0.000604  [ 2432/ 8083]
loss:        0.000479  [ 2592/ 8083]
loss:        0.000517  [ 2752/ 8083]
loss:        0.000465  [ 2912/ 8083]
loss:        0.000415  [ 3072/ 8083]
loss:        0.000509  [ 3232/ 8083]
loss:        0.000389  [ 3392/ 8083]
loss:        0.000494  [ 3552/ 8083]
loss:        0.000546  [ 3712/ 8083]
loss:        0.000528  [ 3872/ 8083]
loss:        0.000465  [ 4032/ 8083]
loss:        0.000528  [ 4192/ 8083]
loss:        0.000524  [ 4352/ 8083]
loss:        0.000540  [ 4512/ 8083]
loss:        0.000412  [ 4672/ 8083]
loss:        0.000453  [ 4832/ 8083]
loss:        0.000512  [ 4992/ 8083]
loss:        0.000510  [ 5152/ 8083]
loss:        0.000489  [ 5312/ 8083]
loss:        0.000547  [ 5472/ 8083]
loss:        0.000504  [ 5632/ 8083]
loss:        0.000474  [ 5792/ 8083]
loss:        0.000511  [ 5952/ 8083]
loss:        0.000430  [ 6112/ 8083]
loss:        0.000536  [ 6272/ 8083]
loss:        0.000434  [ 6432/ 8083]
loss:        0.000457  [ 6592/ 8083]
loss:        0.000548  [ 6752/ 8083]
loss:        0.000374  [ 6912/ 8083]
loss:        0.000408  [ 7072/ 8083]
loss:        0.000499  [ 7232/ 8083]
loss:        0.000438  [ 7392/ 8083]
loss:        0.000526  [ 7552/ 8083]
loss:        0.000483  [ 7712/ 8083]
loss:        0.000642  [ 7872/ 8083]
loss:        0.000510  [ 8032/ 8083]

Avg validation loss:        0.000503
Learning rate: [0.0001]
EarlyStopping counter: 5 out of 20


Epoch 41
-------------------------------
loss:        0.000549  [   32/ 8083]
loss:        0.000363  [  192/ 8083]
loss:        0.000471  [  352/ 8083]
loss:        0.000519  [  512/ 8083]
loss:        0.000539  [  672/ 8083]
loss:        0.000484  [  832/ 8083]
loss:        0.000507  [  992/ 8083]
loss:        0.000585  [ 1152/ 8083]
loss:        0.000363  [ 1312/ 8083]
loss:        0.000447  [ 1472/ 8083]
loss:        0.000447  [ 1632/ 8083]
loss:        0.000385  [ 1792/ 8083]
loss:        0.000399  [ 1952/ 8083]
loss:        0.000615  [ 2112/ 8083]
loss:        0.000524  [ 2272/ 8083]
loss:        0.000506  [ 2432/ 8083]
loss:        0.000539  [ 2592/ 8083]
loss:        0.000602  [ 2752/ 8083]
loss:        0.000481  [ 2912/ 8083]
loss:        0.000438  [ 3072/ 8083]
loss:        0.000473  [ 3232/ 8083]
loss:        0.000450  [ 3392/ 8083]
loss:        0.000573  [ 3552/ 8083]
loss:        0.000481  [ 3712/ 8083]
loss:        0.000509  [ 3872/ 8083]
loss:        0.000496  [ 4032/ 8083]
loss:        0.000506  [ 4192/ 8083]
loss:        0.000462  [ 4352/ 8083]
loss:        0.000546  [ 4512/ 8083]
loss:        0.000499  [ 4672/ 8083]
loss:        0.000585  [ 4832/ 8083]
loss:        0.000581  [ 4992/ 8083]
loss:        0.000473  [ 5152/ 8083]
loss:        0.000513  [ 5312/ 8083]
loss:        0.000495  [ 5472/ 8083]
loss:        0.000513  [ 5632/ 8083]
loss:        0.000500  [ 5792/ 8083]
loss:        0.000486  [ 5952/ 8083]
loss:        0.000584  [ 6112/ 8083]
loss:        0.000540  [ 6272/ 8083]
loss:        0.000489  [ 6432/ 8083]
loss:        0.000491  [ 6592/ 8083]
loss:        0.000506  [ 6752/ 8083]
loss:        0.000452  [ 6912/ 8083]
loss:        0.000468  [ 7072/ 8083]
loss:        0.000447  [ 7232/ 8083]
loss:        0.000562  [ 7392/ 8083]
loss:        0.000448  [ 7552/ 8083]
loss:        0.000606  [ 7712/ 8083]
loss:        0.000537  [ 7872/ 8083]
loss:        0.000472  [ 8032/ 8083]

Avg validation loss:        0.000501
Learning rate: [1e-05]
EarlyStopping counter: 6 out of 20


Epoch 42
-------------------------------
loss:        0.000510  [   32/ 8083]
loss:        0.000587  [  192/ 8083]
loss:        0.000484  [  352/ 8083]
loss:        0.000426  [  512/ 8083]
loss:        0.000469  [  672/ 8083]
loss:        0.000466  [  832/ 8083]
loss:        0.000535  [  992/ 8083]
loss:        0.000438  [ 1152/ 8083]
loss:        0.000492  [ 1312/ 8083]
loss:        0.000436  [ 1472/ 8083]
loss:        0.000585  [ 1632/ 8083]
loss:        0.000547  [ 1792/ 8083]
loss:        0.000497  [ 1952/ 8083]
loss:        0.000458  [ 2112/ 8083]
loss:        0.000526  [ 2272/ 8083]
loss:        0.000514  [ 2432/ 8083]
loss:        0.000521  [ 2592/ 8083]
loss:        0.000530  [ 2752/ 8083]
loss:        0.000415  [ 2912/ 8083]
loss:        0.000578  [ 3072/ 8083]
loss:        0.000567  [ 3232/ 8083]
loss:        0.000624  [ 3392/ 8083]
loss:        0.000453  [ 3552/ 8083]
loss:        0.000484  [ 3712/ 8083]
loss:        0.000475  [ 3872/ 8083]
loss:        0.000541  [ 4032/ 8083]
loss:        0.000524  [ 4192/ 8083]
loss:        0.000530  [ 4352/ 8083]
loss:        0.000437  [ 4512/ 8083]
loss:        0.000431  [ 4672/ 8083]
loss:        0.000609  [ 4832/ 8083]
loss:        0.000593  [ 4992/ 8083]
loss:        0.000526  [ 5152/ 8083]
loss:        0.000584  [ 5312/ 8083]
loss:        0.000474  [ 5472/ 8083]
loss:        0.000541  [ 5632/ 8083]
loss:        0.000560  [ 5792/ 8083]
loss:        0.000452  [ 5952/ 8083]
loss:        0.000506  [ 6112/ 8083]
loss:        0.000541  [ 6272/ 8083]
loss:        0.000493  [ 6432/ 8083]
loss:        0.000501  [ 6592/ 8083]
loss:        0.000401  [ 6752/ 8083]
loss:        0.000491  [ 6912/ 8083]
loss:        0.000420  [ 7072/ 8083]
loss:        0.000524  [ 7232/ 8083]
loss:        0.000532  [ 7392/ 8083]
loss:        0.000466  [ 7552/ 8083]
loss:        0.000499  [ 7712/ 8083]
loss:        0.000477  [ 7872/ 8083]
loss:        0.000506  [ 8032/ 8083]

Avg validation loss:        0.000501
Learning rate: [1e-05]
EarlyStopping counter: 7 out of 20


Epoch 43
-------------------------------
loss:        0.000408  [   32/ 8083]
loss:        0.000487  [  192/ 8083]
loss:        0.000474  [  352/ 8083]
loss:        0.000431  [  512/ 8083]
loss:        0.000592  [  672/ 8083]
loss:        0.000523  [  832/ 8083]
loss:        0.000595  [  992/ 8083]
loss:        0.000410  [ 1152/ 8083]
loss:        0.000478  [ 1312/ 8083]
loss:        0.000566  [ 1472/ 8083]
loss:        0.000531  [ 1632/ 8083]
loss:        0.000427  [ 1792/ 8083]
loss:        0.000530  [ 1952/ 8083]
loss:        0.000473  [ 2112/ 8083]
loss:        0.000530  [ 2272/ 8083]
loss:        0.000461  [ 2432/ 8083]
loss:        0.000503  [ 2592/ 8083]
loss:        0.000493  [ 2752/ 8083]
loss:        0.000505  [ 2912/ 8083]
loss:        0.000609  [ 3072/ 8083]
loss:        0.000519  [ 3232/ 8083]
loss:        0.000468  [ 3392/ 8083]
loss:        0.000399  [ 3552/ 8083]
loss:        0.000513  [ 3712/ 8083]
loss:        0.000567  [ 3872/ 8083]
loss:        0.000511  [ 4032/ 8083]
loss:        0.000519  [ 4192/ 8083]
loss:        0.000512  [ 4352/ 8083]
loss:        0.000477  [ 4512/ 8083]
loss:        0.000478  [ 4672/ 8083]
loss:        0.000480  [ 4832/ 8083]
loss:        0.000483  [ 4992/ 8083]
loss:        0.000279  [ 5152/ 8083]
loss:        0.000455  [ 5312/ 8083]
loss:        0.000501  [ 5472/ 8083]
loss:        0.000477  [ 5632/ 8083]
loss:        0.000528  [ 5792/ 8083]
loss:        0.000409  [ 5952/ 8083]
loss:        0.000511  [ 6112/ 8083]
loss:        0.000514  [ 6272/ 8083]
loss:        0.000470  [ 6432/ 8083]
loss:        0.000480  [ 6592/ 8083]
loss:        0.000545  [ 6752/ 8083]
loss:        0.000442  [ 6912/ 8083]
loss:        0.000452  [ 7072/ 8083]
loss:        0.000431  [ 7232/ 8083]
loss:        0.000563  [ 7392/ 8083]
loss:        0.000545  [ 7552/ 8083]
loss:        0.000524  [ 7712/ 8083]
loss:        0.000351  [ 7872/ 8083]
loss:        0.000492  [ 8032/ 8083]

Avg validation loss:        0.000502
Learning rate: [1e-05]
EarlyStopping counter: 8 out of 20


Epoch 44
-------------------------------
loss:        0.000386  [   32/ 8083]
loss:        0.000627  [  192/ 8083]
loss:        0.000525  [  352/ 8083]
loss:        0.000544  [  512/ 8083]
loss:        0.000515  [  672/ 8083]
loss:        0.000491  [  832/ 8083]
loss:        0.000537  [  992/ 8083]
loss:        0.000493  [ 1152/ 8083]
loss:        0.000429  [ 1312/ 8083]
loss:        0.000544  [ 1472/ 8083]
loss:        0.000447  [ 1632/ 8083]
loss:        0.000486  [ 1792/ 8083]
loss:        0.000492  [ 1952/ 8083]
loss:        0.000511  [ 2112/ 8083]
loss:        0.000522  [ 2272/ 8083]
loss:        0.000334  [ 2432/ 8083]
loss:        0.000443  [ 2592/ 8083]
loss:        0.000518  [ 2752/ 8083]
loss:        0.000500  [ 2912/ 8083]
loss:        0.000510  [ 3072/ 8083]
loss:        0.000391  [ 3232/ 8083]
loss:        0.000512  [ 3392/ 8083]
loss:        0.000527  [ 3552/ 8083]
loss:        0.000487  [ 3712/ 8083]
loss:        0.000515  [ 3872/ 8083]
loss:        0.000532  [ 4032/ 8083]
loss:        0.000475  [ 4192/ 8083]
loss:        0.000435  [ 4352/ 8083]
loss:        0.000485  [ 4512/ 8083]
loss:        0.000499  [ 4672/ 8083]
loss:        0.000495  [ 4832/ 8083]
loss:        0.000575  [ 4992/ 8083]
loss:        0.000533  [ 5152/ 8083]
loss:        0.000537  [ 5312/ 8083]
loss:        0.000423  [ 5472/ 8083]
loss:        0.000477  [ 5632/ 8083]
loss:        0.000480  [ 5792/ 8083]
loss:        0.000384  [ 5952/ 8083]
loss:        0.000528  [ 6112/ 8083]
loss:        0.000500  [ 6272/ 8083]
loss:        0.000421  [ 6432/ 8083]
loss:        0.000526  [ 6592/ 8083]
loss:        0.000353  [ 6752/ 8083]
loss:        0.000462  [ 6912/ 8083]
loss:        0.000623  [ 7072/ 8083]
loss:        0.000496  [ 7232/ 8083]
loss:        0.000436  [ 7392/ 8083]
loss:        0.000551  [ 7552/ 8083]
loss:        0.000518  [ 7712/ 8083]
loss:        0.000527  [ 7872/ 8083]
loss:        0.000498  [ 8032/ 8083]

Avg validation loss:        0.000502
Learning rate: [1e-05]
EarlyStopping counter: 9 out of 20


Epoch 45
-------------------------------
loss:        0.000386  [   32/ 8083]
loss:        0.000383  [  192/ 8083]
loss:        0.000529  [  352/ 8083]
loss:        0.000536  [  512/ 8083]
loss:        0.000547  [  672/ 8083]
loss:        0.000535  [  832/ 8083]
loss:        0.000390  [  992/ 8083]
loss:        0.000550  [ 1152/ 8083]
loss:        0.000519  [ 1312/ 8083]
loss:        0.000513  [ 1472/ 8083]
loss:        0.000431  [ 1632/ 8083]
loss:        0.000498  [ 1792/ 8083]
loss:        0.000450  [ 1952/ 8083]
loss:        0.000509  [ 2112/ 8083]
loss:        0.000554  [ 2272/ 8083]
loss:        0.000386  [ 2432/ 8083]
loss:        0.000427  [ 2592/ 8083]
loss:        0.000561  [ 2752/ 8083]
loss:        0.000488  [ 2912/ 8083]
loss:        0.000507  [ 3072/ 8083]
loss:        0.000491  [ 3232/ 8083]
loss:        0.000511  [ 3392/ 8083]
loss:        0.000547  [ 3552/ 8083]
loss:        0.000471  [ 3712/ 8083]
loss:        0.000419  [ 3872/ 8083]
loss:        0.000527  [ 4032/ 8083]
loss:        0.000550  [ 4192/ 8083]
loss:        0.000538  [ 4352/ 8083]
loss:        0.000468  [ 4512/ 8083]
loss:        0.000631  [ 4672/ 8083]
loss:        0.000456  [ 4832/ 8083]
loss:        0.000522  [ 4992/ 8083]
loss:        0.000544  [ 5152/ 8083]
loss:        0.000492  [ 5312/ 8083]
loss:        0.000440  [ 5472/ 8083]
loss:        0.000550  [ 5632/ 8083]
loss:        0.000512  [ 5792/ 8083]
loss:        0.000491  [ 5952/ 8083]
loss:        0.000638  [ 6112/ 8083]
loss:        0.000566  [ 6272/ 8083]
loss:        0.000546  [ 6432/ 8083]
loss:        0.000467  [ 6592/ 8083]
loss:        0.000528  [ 6752/ 8083]
loss:        0.000481  [ 6912/ 8083]
loss:        0.000518  [ 7072/ 8083]
loss:        0.000477  [ 7232/ 8083]
loss:        0.000436  [ 7392/ 8083]
loss:        0.000416  [ 7552/ 8083]
loss:        0.000537  [ 7712/ 8083]
loss:        0.000587  [ 7872/ 8083]
loss:        0.000444  [ 8032/ 8083]

Avg validation loss:        0.000501
Learning rate: [1e-05]
EarlyStopping counter: 10 out of 20


Epoch 46
-------------------------------
loss:        0.000349  [   32/ 8083]
loss:        0.000427  [  192/ 8083]
loss:        0.000635  [  352/ 8083]
loss:        0.000672  [  512/ 8083]
loss:        0.000564  [  672/ 8083]
loss:        0.000503  [  832/ 8083]
loss:        0.000489  [  992/ 8083]
loss:        0.000553  [ 1152/ 8083]
loss:        0.000483  [ 1312/ 8083]
loss:        0.000419  [ 1472/ 8083]
loss:        0.000481  [ 1632/ 8083]
loss:        0.000445  [ 1792/ 8083]
loss:        0.000613  [ 1952/ 8083]
loss:        0.000492  [ 2112/ 8083]
loss:        0.000471  [ 2272/ 8083]
loss:        0.000547  [ 2432/ 8083]
loss:        0.000311  [ 2592/ 8083]
loss:        0.000400  [ 2752/ 8083]
loss:        0.000380  [ 2912/ 8083]
loss:        0.000515  [ 3072/ 8083]
loss:        0.000449  [ 3232/ 8083]
loss:        0.000560  [ 3392/ 8083]
loss:        0.000426  [ 3552/ 8083]
loss:        0.000559  [ 3712/ 8083]
loss:        0.000515  [ 3872/ 8083]
loss:        0.000530  [ 4032/ 8083]
loss:        0.000568  [ 4192/ 8083]
loss:        0.000439  [ 4352/ 8083]
loss:        0.000443  [ 4512/ 8083]
loss:        0.000549  [ 4672/ 8083]
loss:        0.000505  [ 4832/ 8083]
loss:        0.000470  [ 4992/ 8083]
loss:        0.000451  [ 5152/ 8083]
loss:        0.000555  [ 5312/ 8083]
loss:        0.000403  [ 5472/ 8083]
loss:        0.000538  [ 5632/ 8083]
loss:        0.000523  [ 5792/ 8083]
loss:        0.000483  [ 5952/ 8083]
loss:        0.000528  [ 6112/ 8083]
loss:        0.000444  [ 6272/ 8083]
loss:        0.000586  [ 6432/ 8083]
loss:        0.000407  [ 6592/ 8083]
loss:        0.000533  [ 6752/ 8083]
loss:        0.000538  [ 6912/ 8083]
loss:        0.000446  [ 7072/ 8083]
loss:        0.000402  [ 7232/ 8083]
loss:        0.000553  [ 7392/ 8083]
loss:        0.000538  [ 7552/ 8083]
loss:        0.000605  [ 7712/ 8083]
loss:        0.000528  [ 7872/ 8083]
loss:        0.000545  [ 8032/ 8083]

Avg validation loss:        0.000502
Learning rate: [1e-05]
EarlyStopping counter: 11 out of 20


Epoch 47
-------------------------------
loss:        0.000524  [   32/ 8083]
loss:        0.000559  [  192/ 8083]
loss:        0.000440  [  352/ 8083]
loss:        0.000498  [  512/ 8083]
loss:        0.000485  [  672/ 8083]
loss:        0.000472  [  832/ 8083]
loss:        0.000531  [  992/ 8083]
loss:        0.000488  [ 1152/ 8083]
loss:        0.000553  [ 1312/ 8083]
loss:        0.000431  [ 1472/ 8083]
loss:        0.000537  [ 1632/ 8083]
loss:        0.000469  [ 1792/ 8083]
loss:        0.000535  [ 1952/ 8083]
loss:        0.000439  [ 2112/ 8083]
loss:        0.000469  [ 2272/ 8083]
loss:        0.000585  [ 2432/ 8083]
loss:        0.000447  [ 2592/ 8083]
loss:        0.000547  [ 2752/ 8083]
loss:        0.000465  [ 2912/ 8083]
loss:        0.000487  [ 3072/ 8083]
loss:        0.000527  [ 3232/ 8083]
loss:        0.000478  [ 3392/ 8083]
loss:        0.000516  [ 3552/ 8083]
loss:        0.000370  [ 3712/ 8083]
loss:        0.000509  [ 3872/ 8083]
loss:        0.000445  [ 4032/ 8083]
loss:        0.000533  [ 4192/ 8083]
loss:        0.000543  [ 4352/ 8083]
loss:        0.000521  [ 4512/ 8083]
loss:        0.000517  [ 4672/ 8083]
loss:        0.000507  [ 4832/ 8083]
loss:        0.000557  [ 4992/ 8083]
loss:        0.000406  [ 5152/ 8083]
loss:        0.000420  [ 5312/ 8083]
loss:        0.000454  [ 5472/ 8083]
loss:        0.000601  [ 5632/ 8083]
loss:        0.000435  [ 5792/ 8083]
loss:        0.000488  [ 5952/ 8083]
loss:        0.000509  [ 6112/ 8083]
loss:        0.000493  [ 6272/ 8083]
loss:        0.000532  [ 6432/ 8083]
loss:        0.000514  [ 6592/ 8083]
loss:        0.000574  [ 6752/ 8083]
loss:        0.000469  [ 6912/ 8083]
loss:        0.000505  [ 7072/ 8083]
loss:        0.000478  [ 7232/ 8083]
loss:        0.000391  [ 7392/ 8083]
loss:        0.000502  [ 7552/ 8083]
loss:        0.000456  [ 7712/ 8083]
loss:        0.000396  [ 7872/ 8083]
loss:        0.000568  [ 8032/ 8083]

Avg validation loss:        0.000503
Learning rate: [1.0000000000000002e-06]
EarlyStopping counter: 12 out of 20


Epoch 48
-------------------------------
loss:        0.000442  [   32/ 8083]
loss:        0.000585  [  192/ 8083]
loss:        0.000536  [  352/ 8083]
loss:        0.000395  [  512/ 8083]
loss:        0.000489  [  672/ 8083]
loss:        0.000448  [  832/ 8083]
loss:        0.000501  [  992/ 8083]
loss:        0.000516  [ 1152/ 8083]
loss:        0.000482  [ 1312/ 8083]
loss:        0.000537  [ 1472/ 8083]
loss:        0.000536  [ 1632/ 8083]
loss:        0.000409  [ 1792/ 8083]
loss:        0.000480  [ 1952/ 8083]
loss:        0.000604  [ 2112/ 8083]
loss:        0.000476  [ 2272/ 8083]
loss:        0.000457  [ 2432/ 8083]
loss:        0.000447  [ 2592/ 8083]
loss:        0.000462  [ 2752/ 8083]
loss:        0.000391  [ 2912/ 8083]
loss:        0.000457  [ 3072/ 8083]
loss:        0.000528  [ 3232/ 8083]
loss:        0.000305  [ 3392/ 8083]
loss:        0.000582  [ 3552/ 8083]
loss:        0.000590  [ 3712/ 8083]
loss:        0.000401  [ 3872/ 8083]
loss:        0.000493  [ 4032/ 8083]
loss:        0.000405  [ 4192/ 8083]
loss:        0.000559  [ 4352/ 8083]
loss:        0.000423  [ 4512/ 8083]
loss:        0.000486  [ 4672/ 8083]
loss:        0.000457  [ 4832/ 8083]
loss:        0.000502  [ 4992/ 8083]
loss:        0.000402  [ 5152/ 8083]
loss:        0.000506  [ 5312/ 8083]
loss:        0.000509  [ 5472/ 8083]
loss:        0.000493  [ 5632/ 8083]
loss:        0.000554  [ 5792/ 8083]
loss:        0.000475  [ 5952/ 8083]
loss:        0.000620  [ 6112/ 8083]
loss:        0.000500  [ 6272/ 8083]
loss:        0.000476  [ 6432/ 8083]
loss:        0.000534  [ 6592/ 8083]
loss:        0.000491  [ 6752/ 8083]
loss:        0.000404  [ 6912/ 8083]
loss:        0.000521  [ 7072/ 8083]
loss:        0.000454  [ 7232/ 8083]
loss:        0.000414  [ 7392/ 8083]
loss:        0.000569  [ 7552/ 8083]
loss:        0.000441  [ 7712/ 8083]
loss:        0.000526  [ 7872/ 8083]
loss:        0.000417  [ 8032/ 8083]

Avg validation loss:        0.000503
Learning rate: [1.0000000000000002e-06]
EarlyStopping counter: 13 out of 20


Epoch 49
-------------------------------
loss:        0.000427  [   32/ 8083]
loss:        0.000498  [  192/ 8083]
loss:        0.000472  [  352/ 8083]
loss:        0.000413  [  512/ 8083]
loss:        0.000460  [  672/ 8083]
loss:        0.000525  [  832/ 8083]
loss:        0.000542  [  992/ 8083]
loss:        0.000465  [ 1152/ 8083]
loss:        0.000533  [ 1312/ 8083]
loss:        0.000464  [ 1472/ 8083]
loss:        0.000488  [ 1632/ 8083]
loss:        0.000557  [ 1792/ 8083]
loss:        0.000449  [ 1952/ 8083]
loss:        0.000430  [ 2112/ 8083]
loss:        0.000461  [ 2272/ 8083]
loss:        0.000463  [ 2432/ 8083]
loss:        0.000632  [ 2592/ 8083]
loss:        0.000498  [ 2752/ 8083]
loss:        0.000425  [ 2912/ 8083]
loss:        0.000399  [ 3072/ 8083]
loss:        0.000477  [ 3232/ 8083]
loss:        0.000496  [ 3392/ 8083]
loss:        0.000479  [ 3552/ 8083]
loss:        0.000517  [ 3712/ 8083]
loss:        0.000427  [ 3872/ 8083]
loss:        0.000481  [ 4032/ 8083]
loss:        0.000399  [ 4192/ 8083]
loss:        0.000552  [ 4352/ 8083]
loss:        0.000535  [ 4512/ 8083]
loss:        0.000565  [ 4672/ 8083]
loss:        0.000479  [ 4832/ 8083]
loss:        0.000520  [ 4992/ 8083]
loss:        0.000522  [ 5152/ 8083]
loss:        0.000596  [ 5312/ 8083]
loss:        0.000530  [ 5472/ 8083]
loss:        0.000424  [ 5632/ 8083]
loss:        0.000357  [ 5792/ 8083]
loss:        0.000525  [ 5952/ 8083]
loss:        0.000526  [ 6112/ 8083]
loss:        0.000454  [ 6272/ 8083]
loss:        0.000512  [ 6432/ 8083]
loss:        0.000522  [ 6592/ 8083]
loss:        0.000508  [ 6752/ 8083]
loss:        0.000505  [ 6912/ 8083]
loss:        0.000553  [ 7072/ 8083]
loss:        0.000426  [ 7232/ 8083]
loss:        0.000440  [ 7392/ 8083]
loss:        0.000427  [ 7552/ 8083]
loss:        0.000471  [ 7712/ 8083]
loss:        0.000476  [ 7872/ 8083]
loss:        0.000583  [ 8032/ 8083]

Avg validation loss:        0.000502
Learning rate: [1.0000000000000002e-06]
EarlyStopping counter: 14 out of 20


Epoch 50
-------------------------------
loss:        0.000517  [   32/ 8083]
loss:        0.000440  [  192/ 8083]
loss:        0.000465  [  352/ 8083]
loss:        0.000507  [  512/ 8083]
loss:        0.000544  [  672/ 8083]
loss:        0.000518  [  832/ 8083]
loss:        0.000531  [  992/ 8083]
loss:        0.000483  [ 1152/ 8083]
loss:        0.000564  [ 1312/ 8083]
loss:        0.000667  [ 1472/ 8083]
loss:        0.000515  [ 1632/ 8083]
loss:        0.000374  [ 1792/ 8083]
loss:        0.000433  [ 1952/ 8083]
loss:        0.000532  [ 2112/ 8083]
loss:        0.000595  [ 2272/ 8083]
loss:        0.000576  [ 2432/ 8083]
loss:        0.000492  [ 2592/ 8083]
loss:        0.000397  [ 2752/ 8083]
loss:        0.000499  [ 2912/ 8083]
loss:        0.000474  [ 3072/ 8083]
loss:        0.000465  [ 3232/ 8083]
loss:        0.000537  [ 3392/ 8083]
loss:        0.000486  [ 3552/ 8083]
loss:        0.000514  [ 3712/ 8083]
loss:        0.000521  [ 3872/ 8083]
loss:        0.000459  [ 4032/ 8083]
loss:        0.000474  [ 4192/ 8083]
loss:        0.000455  [ 4352/ 8083]
loss:        0.000477  [ 4512/ 8083]
loss:        0.000455  [ 4672/ 8083]
loss:        0.000526  [ 4832/ 8083]
loss:        0.000473  [ 4992/ 8083]
loss:        0.000410  [ 5152/ 8083]
loss:        0.000497  [ 5312/ 8083]
loss:        0.000477  [ 5472/ 8083]
loss:        0.000514  [ 5632/ 8083]
loss:        0.000445  [ 5792/ 8083]
loss:        0.000485  [ 5952/ 8083]
loss:        0.000575  [ 6112/ 8083]
loss:        0.000441  [ 6272/ 8083]
loss:        0.000496  [ 6432/ 8083]
loss:        0.000526  [ 6592/ 8083]
loss:        0.000541  [ 6752/ 8083]
loss:        0.000407  [ 6912/ 8083]
loss:        0.000449  [ 7072/ 8083]
loss:        0.000522  [ 7232/ 8083]
loss:        0.000464  [ 7392/ 8083]
loss:        0.000425  [ 7552/ 8083]
loss:        0.000537  [ 7712/ 8083]
loss:        0.000554  [ 7872/ 8083]
loss:        0.000503  [ 8032/ 8083]

Avg validation loss:        0.000503
Learning rate: [1.0000000000000002e-06]
EarlyStopping counter: 15 out of 20


Epoch 51
-------------------------------
loss:        0.000481  [   32/ 8083]
loss:        0.000491  [  192/ 8083]
loss:        0.000507  [  352/ 8083]
loss:        0.000411  [  512/ 8083]
loss:        0.000501  [  672/ 8083]
loss:        0.000372  [  832/ 8083]
loss:        0.000654  [  992/ 8083]
loss:        0.000435  [ 1152/ 8083]
loss:        0.000517  [ 1312/ 8083]
loss:        0.000457  [ 1472/ 8083]
loss:        0.000408  [ 1632/ 8083]
loss:        0.000503  [ 1792/ 8083]
loss:        0.000536  [ 1952/ 8083]
loss:        0.000522  [ 2112/ 8083]
loss:        0.000413  [ 2272/ 8083]
loss:        0.000440  [ 2432/ 8083]
loss:        0.000494  [ 2592/ 8083]
loss:        0.000512  [ 2752/ 8083]
loss:        0.000438  [ 2912/ 8083]
loss:        0.000382  [ 3072/ 8083]
loss:        0.000587  [ 3232/ 8083]
loss:        0.000429  [ 3392/ 8083]
loss:        0.000527  [ 3552/ 8083]
loss:        0.000581  [ 3712/ 8083]
loss:        0.000402  [ 3872/ 8083]
loss:        0.000470  [ 4032/ 8083]
loss:        0.000518  [ 4192/ 8083]
loss:        0.000510  [ 4352/ 8083]
loss:        0.000496  [ 4512/ 8083]
loss:        0.000516  [ 4672/ 8083]
loss:        0.000546  [ 4832/ 8083]
loss:        0.000510  [ 4992/ 8083]
loss:        0.000514  [ 5152/ 8083]
loss:        0.000558  [ 5312/ 8083]
loss:        0.000526  [ 5472/ 8083]
loss:        0.000446  [ 5632/ 8083]
loss:        0.000497  [ 5792/ 8083]
loss:        0.000379  [ 5952/ 8083]
loss:        0.000501  [ 6112/ 8083]
loss:        0.000460  [ 6272/ 8083]
loss:        0.000572  [ 6432/ 8083]
loss:        0.000532  [ 6592/ 8083]
loss:        0.000482  [ 6752/ 8083]
loss:        0.000511  [ 6912/ 8083]
loss:        0.000577  [ 7072/ 8083]
loss:        0.000528  [ 7232/ 8083]
loss:        0.000463  [ 7392/ 8083]
loss:        0.000596  [ 7552/ 8083]
loss:        0.000549  [ 7712/ 8083]
loss:        0.000531  [ 7872/ 8083]
loss:        0.000438  [ 8032/ 8083]

Avg validation loss:        0.000501
Learning rate: [1.0000000000000002e-06]
EarlyStopping counter: 16 out of 20


Epoch 52
-------------------------------
loss:        0.000570  [   32/ 8083]
loss:        0.000522  [  192/ 8083]
loss:        0.000404  [  352/ 8083]
loss:        0.000474  [  512/ 8083]
loss:        0.000569  [  672/ 8083]
loss:        0.000449  [  832/ 8083]
loss:        0.000533  [  992/ 8083]
loss:        0.000481  [ 1152/ 8083]
loss:        0.000534  [ 1312/ 8083]
loss:        0.000543  [ 1472/ 8083]
loss:        0.000560  [ 1632/ 8083]
loss:        0.000537  [ 1792/ 8083]
loss:        0.000510  [ 1952/ 8083]
loss:        0.000398  [ 2112/ 8083]
loss:        0.000592  [ 2272/ 8083]
loss:        0.000545  [ 2432/ 8083]
loss:        0.000491  [ 2592/ 8083]
loss:        0.000519  [ 2752/ 8083]
loss:        0.000399  [ 2912/ 8083]
loss:        0.000562  [ 3072/ 8083]
loss:        0.000439  [ 3232/ 8083]
loss:        0.000515  [ 3392/ 8083]
loss:        0.000518  [ 3552/ 8083]
loss:        0.000582  [ 3712/ 8083]
loss:        0.000493  [ 3872/ 8083]
loss:        0.000504  [ 4032/ 8083]
loss:        0.000522  [ 4192/ 8083]
loss:        0.000425  [ 4352/ 8083]
loss:        0.000490  [ 4512/ 8083]
loss:        0.000491  [ 4672/ 8083]
loss:        0.000566  [ 4832/ 8083]
loss:        0.000455  [ 4992/ 8083]
loss:        0.000549  [ 5152/ 8083]
loss:        0.000413  [ 5312/ 8083]
loss:        0.000556  [ 5472/ 8083]
loss:        0.000517  [ 5632/ 8083]
loss:        0.000481  [ 5792/ 8083]
loss:        0.000532  [ 5952/ 8083]
loss:        0.000591  [ 6112/ 8083]
loss:        0.000501  [ 6272/ 8083]
loss:        0.000522  [ 6432/ 8083]
loss:        0.000573  [ 6592/ 8083]
loss:        0.000539  [ 6752/ 8083]
loss:        0.000423  [ 6912/ 8083]
loss:        0.000373  [ 7072/ 8083]
loss:        0.000390  [ 7232/ 8083]
loss:        0.000431  [ 7392/ 8083]
loss:        0.000538  [ 7552/ 8083]
loss:        0.000390  [ 7712/ 8083]
loss:        0.000525  [ 7872/ 8083]
loss:        0.000631  [ 8032/ 8083]

Avg validation loss:        0.000502
Learning rate: [1.0000000000000002e-06]
EarlyStopping counter: 17 out of 20


Epoch 53
-------------------------------
loss:        0.000498  [   32/ 8083]
loss:        0.000528  [  192/ 8083]
loss:        0.000456  [  352/ 8083]
loss:        0.000444  [  512/ 8083]
loss:        0.000569  [  672/ 8083]
loss:        0.000513  [  832/ 8083]
loss:        0.000422  [  992/ 8083]
loss:        0.000493  [ 1152/ 8083]
loss:        0.000407  [ 1312/ 8083]
loss:        0.000455  [ 1472/ 8083]
loss:        0.000642  [ 1632/ 8083]
loss:        0.000433  [ 1792/ 8083]
loss:        0.000475  [ 1952/ 8083]
loss:        0.000506  [ 2112/ 8083]
loss:        0.000394  [ 2272/ 8083]
loss:        0.000499  [ 2432/ 8083]
loss:        0.000433  [ 2592/ 8083]
loss:        0.000454  [ 2752/ 8083]
loss:        0.000460  [ 2912/ 8083]
loss:        0.000429  [ 3072/ 8083]
loss:        0.000438  [ 3232/ 8083]
loss:        0.000475  [ 3392/ 8083]
loss:        0.000465  [ 3552/ 8083]
loss:        0.000571  [ 3712/ 8083]
loss:        0.000416  [ 3872/ 8083]
loss:        0.000570  [ 4032/ 8083]
loss:        0.000518  [ 4192/ 8083]
loss:        0.000528  [ 4352/ 8083]
loss:        0.000392  [ 4512/ 8083]
loss:        0.000529  [ 4672/ 8083]
loss:        0.000552  [ 4832/ 8083]
loss:        0.000557  [ 4992/ 8083]
loss:        0.000553  [ 5152/ 8083]
loss:        0.000538  [ 5312/ 8083]
loss:        0.000600  [ 5472/ 8083]
loss:        0.000503  [ 5632/ 8083]
loss:        0.000480  [ 5792/ 8083]
loss:        0.000424  [ 5952/ 8083]
loss:        0.000398  [ 6112/ 8083]
loss:        0.000490  [ 6272/ 8083]
loss:        0.000442  [ 6432/ 8083]
loss:        0.000488  [ 6592/ 8083]
loss:        0.000400  [ 6752/ 8083]
loss:        0.000537  [ 6912/ 8083]
loss:        0.000400  [ 7072/ 8083]
loss:        0.000501  [ 7232/ 8083]
loss:        0.000421  [ 7392/ 8083]
loss:        0.000468  [ 7552/ 8083]
loss:        0.000422  [ 7712/ 8083]
loss:        0.000522  [ 7872/ 8083]
loss:        0.000523  [ 8032/ 8083]

Avg validation loss:        0.000505
Learning rate: [1.0000000000000002e-07]
EarlyStopping counter: 18 out of 20


Epoch 54
-------------------------------
loss:        0.000504  [   32/ 8083]
loss:        0.000479  [  192/ 8083]
loss:        0.000500  [  352/ 8083]
loss:        0.000512  [  512/ 8083]
loss:        0.000498  [  672/ 8083]
loss:        0.000518  [  832/ 8083]
loss:        0.000556  [  992/ 8083]
loss:        0.000425  [ 1152/ 8083]
loss:        0.000444  [ 1312/ 8083]
loss:        0.000416  [ 1472/ 8083]
loss:        0.000489  [ 1632/ 8083]
loss:        0.000477  [ 1792/ 8083]
loss:        0.000499  [ 1952/ 8083]
loss:        0.000482  [ 2112/ 8083]
loss:        0.000449  [ 2272/ 8083]
loss:        0.000629  [ 2432/ 8083]
loss:        0.000500  [ 2592/ 8083]
loss:        0.000440  [ 2752/ 8083]
loss:        0.000474  [ 2912/ 8083]
loss:        0.000445  [ 3072/ 8083]
loss:        0.000480  [ 3232/ 8083]
loss:        0.000374  [ 3392/ 8083]
loss:        0.000455  [ 3552/ 8083]
loss:        0.000595  [ 3712/ 8083]
loss:        0.000521  [ 3872/ 8083]
loss:        0.000418  [ 4032/ 8083]
loss:        0.000484  [ 4192/ 8083]
loss:        0.000417  [ 4352/ 8083]
loss:        0.000490  [ 4512/ 8083]
loss:        0.000524  [ 4672/ 8083]
loss:        0.000474  [ 4832/ 8083]
loss:        0.000562  [ 4992/ 8083]
loss:        0.000536  [ 5152/ 8083]
loss:        0.000496  [ 5312/ 8083]
loss:        0.000488  [ 5472/ 8083]
loss:        0.000522  [ 5632/ 8083]
loss:        0.000478  [ 5792/ 8083]
loss:        0.000536  [ 5952/ 8083]
loss:        0.000438  [ 6112/ 8083]
loss:        0.000481  [ 6272/ 8083]
loss:        0.000465  [ 6432/ 8083]
loss:        0.000513  [ 6592/ 8083]
loss:        0.000557  [ 6752/ 8083]
loss:        0.000542  [ 6912/ 8083]
loss:        0.000432  [ 7072/ 8083]
loss:        0.000451  [ 7232/ 8083]
loss:        0.000502  [ 7392/ 8083]
loss:        0.000355  [ 7552/ 8083]
loss:        0.000487  [ 7712/ 8083]
loss:        0.000436  [ 7872/ 8083]
loss:        0.000514  [ 8032/ 8083]

Avg validation loss:        0.000504
Learning rate: [1.0000000000000002e-07]
EarlyStopping counter: 19 out of 20


Epoch 55
-------------------------------
loss:        0.000418  [   32/ 8083]
loss:        0.000527  [  192/ 8083]
loss:        0.000599  [  352/ 8083]
loss:        0.000568  [  512/ 8083]
loss:        0.000490  [  672/ 8083]
loss:        0.000577  [  832/ 8083]
loss:        0.000434  [  992/ 8083]
loss:        0.000611  [ 1152/ 8083]
loss:        0.000456  [ 1312/ 8083]
loss:        0.000504  [ 1472/ 8083]
loss:        0.000480  [ 1632/ 8083]
loss:        0.000411  [ 1792/ 8083]
loss:        0.000479  [ 1952/ 8083]
loss:        0.000486  [ 2112/ 8083]
loss:        0.000431  [ 2272/ 8083]
loss:        0.000537  [ 2432/ 8083]
loss:        0.000573  [ 2592/ 8083]
loss:        0.000457  [ 2752/ 8083]
loss:        0.000529  [ 2912/ 8083]
loss:        0.000508  [ 3072/ 8083]
loss:        0.000566  [ 3232/ 8083]
loss:        0.000409  [ 3392/ 8083]
loss:        0.000539  [ 3552/ 8083]
loss:        0.000481  [ 3712/ 8083]
loss:        0.000405  [ 3872/ 8083]
loss:        0.000393  [ 4032/ 8083]
loss:        0.000494  [ 4192/ 8083]
loss:        0.000469  [ 4352/ 8083]
loss:        0.000513  [ 4512/ 8083]
loss:        0.000477  [ 4672/ 8083]
loss:        0.000462  [ 4832/ 8083]
loss:        0.000474  [ 4992/ 8083]
loss:        0.000563  [ 5152/ 8083]
loss:        0.000480  [ 5312/ 8083]
loss:        0.000520  [ 5472/ 8083]
loss:        0.000550  [ 5632/ 8083]
loss:        0.000480  [ 5792/ 8083]
loss:        0.000483  [ 5952/ 8083]
loss:        0.000511  [ 6112/ 8083]
loss:        0.000470  [ 6272/ 8083]
loss:        0.000570  [ 6432/ 8083]
loss:        0.000446  [ 6592/ 8083]
loss:        0.000549  [ 6752/ 8083]
loss:        0.000507  [ 6912/ 8083]
loss:        0.000487  [ 7072/ 8083]
loss:        0.000502  [ 7232/ 8083]
loss:        0.000464  [ 7392/ 8083]
loss:        0.000537  [ 7552/ 8083]
loss:        0.000403  [ 7712/ 8083]
loss:        0.000463  [ 7872/ 8083]
loss:        0.000586  [ 8032/ 8083]

Avg validation loss:        0.000503
Learning rate: [1.0000000000000002e-07]
EarlyStopping counter: 20 out of 20
Early stopping
Avg test loss:        0.000520 

Done!

Create an output directory to save the outputs and plots.

Code
# Per-individual folder (keyed by buffalo_id) for the saved model outputs and plots.
# exist_ok=True means re-running this cell does not raise if the folder already exists.
output_dir = f'outputs/model_outputs/deepSSF_S2/id{buffalo_id}'
os.makedirs(output_dir, exist_ok=True)

Save the validation loss as a dataframe

Code
# Destination CSV for the per-epoch validation losses
filename_loss_csv = f'{output_dir}/deepSSF_S2_val_loss_buffalo{buffalo_id}_{today_date}.csv'

# One row per epoch; epochs are numbered from 1
epoch_numbers = range(1, len(val_losses) + 1)
val_losses_df = pd.DataFrame({"epoch": epoch_numbers, "val_losses": val_losses})

# Write the table to disk without the pandas row index
val_losses_df.to_csv(filename_loss_csv, index=False)

Plot the validation loss

Code
# File path for the saved validation-loss figure
filename_loss_png = f'{output_dir}/deepSSF_S2_val_loss_buffalo{buffalo_id}_{today_date}.png'

# Plot the validation losses against 1-based epoch numbers so that the x-axis
# agrees with the 'epoch' column written to the CSV (plotting the list on its
# own would label the first epoch as 0)
epochs = range(1, len(val_losses) + 1)
plt.plot(epochs, val_losses, label='Validation Loss', color='red')  # Plot validation loss in red
plt.title('Validation Losses')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()  # Show legend to distinguish lines
plt.savefig(filename_loss_png, dpi=600, bbox_inches='tight')  # Save before show()
plt.show()

Check model parameters

Code
# to look at the parameters (weights and biases) of the model
# print(model.state_dict())

Loading in previous models

Because we have just trained the model, its parameters are already stored in the model object. During training we also saved the model to file, so that saved model (or any other previously trained model) can be loaded instead.

The model parameters that are being loaded must match the model object that has been defined above. If the model object has changed, the model parameters will not be able to be loaded.

Code
# Echo the checkpoint path currently held in path_save_weights (defined earlier in the notebook)
path_save_weights
'model_checkpoints/deepSSF_S2_slope_buffalo2005_2025-02-09.pt'

If loading a previously trained model

Code
# To load a different set of previously saved weights, uncomment and edit the path below
# path_save_weights = f'model_checkpoints/deepSSF_S2_slope_buffalo2005_2025-02-09.pt'

# Read the checkpoint from disk, mapping all tensors onto the CPU and
# restricting unpickling to weights only
saved_state = torch.load(path_save_weights,
                         map_location=torch.device('cpu'),
                         weights_only=True)

# Copy the stored parameters into the existing model object; kept as the last
# expression so the notebook still echoes the key-matching report
model.load_state_dict(saved_state)
<All keys matched successfully>

Test model

Take some random samples from the test dataset and generate predictions for them. We loop through the samples (which are shuffled randomly), make predictions, and plot the results.

Helper functions

To return the hour and day of the year to their original values, we can use the following functions.

Code
def recover_hour(sin_term, cos_term):
    """Convert a sine/cosine pair back into an hour of the day.

    Args:
        sin_term: Sine component of the encoded hour.
        cos_term: Cosine component of the encoded hour.

    Returns:
        The hour, obtained by scaling the recovered angle so that a full
        turn corresponds to 24 hours and wrapping with modulo 24.
    """
    # Angle of the (cos, sin) point, in radians within [-pi, pi]
    angle = np.arctan2(sin_term, cos_term)
    # pi radians corresponds to 12 hours
    hours_signed = (12 * angle) / np.pi
    # Wrap negative angles round to the end of the day
    return hours_signed % 24

def recover_yday(sin_term, cos_term):
    """Convert a sine/cosine pair back into a day of the year.

    Args:
        sin_term: Sine component of the encoded day of the year.
        cos_term: Cosine component of the encoded day of the year.

    Returns:
        The day of the year, obtained by scaling the recovered angle so that
        a full turn corresponds to 365 days and wrapping with modulo 365.
    """
    # Angle of the (cos, sin) point, in radians within [-pi, pi]
    angle = np.arctan2(sin_term, cos_term)
    # 2*pi radians corresponds to 365 days
    days_signed = (365 * angle) / (2 * np.pi)
    # Wrap negative angles round to the end of the year
    return days_signed % 365
Code
# 1. Set the model in evaluation mode
model.eval()

# Loop over samples in the test dataset.
# Only the first sample of each batch is inspected and plotted.
for i in range(0, 5):

  # Draw a batch. A fresh iterator is created on every pass, so a different
  # batch is only obtained each time if the dataloader shuffles.
  x1, x2, x3, labels = next(iter(dataloader_test))

  # Pull out the scalars (moved to the CPU first, consistent with x1 and labels below)
  hour_t2_sin = x2.detach().cpu().numpy()[0,0]
  hour_t2_cos = x2.detach().cpu().numpy()[0,1]
  yday_t2_sin = x2.detach().cpu().numpy()[0,2]
  yday_t2_cos = x2.detach().cpu().numpy()[0,3]
  bearing = x3.detach().cpu().numpy()[0,0]

  # Recover the hour
  hour_t2 = recover_hour(hour_t2_sin, hour_t2_cos)
  hour_t2_integer = int(hour_t2)  # Convert to integer
  print(f'Hour:                        {hour_t2_integer}')

  # Recover the day of the year
  yday_t2 = recover_yday(yday_t2_sin, yday_t2_cos)
  yday_t2_integer = int(yday_t2)  # Convert to integer
  print(f'Day of the year:             {yday_t2_integer}')

  # Recover the bearing (radians -> degrees, wrapped by modulo 360)
  bearing_degrees = np.degrees(bearing) % 360
  bearing_degrees = round(bearing_degrees, 1)  # Round to 1 decimal place
  bearing_degrees = int(bearing_degrees)  # Convert to integer
  print(f'Bearing (radians):           {bearing}')
  print(f'Bearing (degrees):           {bearing_degrees}')

  # Pull out the RGB layers for plotting (channels 1, 2 and 3 of the spatial input)
  blue_layer = x1.detach().cpu().numpy()[0,1,:,:]
  green_layer = x1.detach().cpu().numpy()[0,2,:,:]
  red_layer = x1.detach().cpu().numpy()[0,3,:,:]

  # Stack the RGB layers
  rgb_image_np = np.stack([red_layer, green_layer, blue_layer], axis=-1)

  # Normalize to the range [0, 1] for display
  rgb_image_np = (rgb_image_np - rgb_image_np.min()) / (rgb_image_np.max() - rgb_image_np.min())

  # Find the coordinates of the element that is 1
  target = labels.detach().cpu().numpy()[0,:,:]
  coordinates = np.where(target == 1)

  # Extract the coordinates
  row, column = coordinates[0][0], coordinates[1][0]
  print(f"Next step is (row, column):  ({row}, {column})")


  # -------------------------------------------------------------------------
  # Run the model on the input data
  # -------------------------------------------------------------------------

  # Move input tensors to the GPU if available
  x1 = x1.to(device)
  x2 = x2.to(device)
  x3 = x3.to(device)

  # Inference only: no_grad avoids building an autograd graph for the forward pass
  with torch.no_grad():
    test = model((x1, x2, x3))
  # print(test.shape)

  # Extract and exponentiate the habitat density channel
  hab_density = test.detach().cpu().numpy()[0, :, :, 0]
  hab_density_exp = np.exp(hab_density)
  # print(np.sum(hab_density_exp))  # Debug: check the sum of exponentiated values

  # Create masks to remove unwanted edge cells from visualization
  #    (setting them to -∞ affects the color scale in plots)
  x_mask = np.ones_like(hab_density)
  y_mask = np.ones_like(hab_density)

  # mask out cells on the edges that affect the colour scale
  # (first three rows/columns, and everything from index 98 onwards)
  x_mask[:, :3] = -np.inf
  x_mask[:, 98:] = -np.inf
  y_mask[:3, :] = -np.inf
  y_mask[98:, :] = -np.inf

  # Apply the masks to the habitat density (log scale) and exponentiated version
  hab_density_mask = hab_density * x_mask * y_mask
  hab_density_exp_mask = hab_density_exp * x_mask * y_mask

  # Extract and exponentiate the movement density channel
  move_density = test.detach().cpu().numpy()[0,:,:,1]
  move_density_exp = np.exp(move_density)

  # Apply the same masking strategy to movement densities
  move_density_mask = move_density * x_mask * y_mask
  move_density_exp_mask = move_density_exp * x_mask * y_mask

  # Compute the next-step density by adding habitat + movement (log-space)
  step_density = test[0, :, :, 0] + test[0, :, :, 1]
  step_density = step_density.detach().cpu().numpy()
  step_density_exp = np.exp(step_density)

  # Apply masks to the step densities (log and exponentiated)
  step_density_mask = step_density * x_mask * y_mask
  step_density_exp_mask = step_density_exp * x_mask * y_mask

  # -------------------------------------------------------------------------
  # Plot the RGB image, slope, habitat selection, and next-step log-probability
  #   Change the panels to visualize different layers (the movement and
  #   exponentiated arrays computed above are kept for that purpose)
  # -------------------------------------------------------------------------
  fig, axs = plt.subplots(2, 2, figsize=(10, 10))

  # Plot RGB
  im1 = axs[0, 0].imshow(rgb_image_np)
  axs[0, 0].set_title('Sentinel-2 RGB')

  # Plot slope (channel 12 of the spatial input)
  im2 = axs[0, 1].imshow(x1.detach().cpu().numpy()[0,12,:,:], cmap='viridis')
  axs[0, 1].set_title('Slope')
  fig.colorbar(im2, ax=axs[0, 1], shrink=0.7)

  # Plot habitat selection
  im3 = axs[1, 0].imshow(hab_density_mask, cmap='viridis')
  axs[1, 0].set_title('Habitat selection log-probability')
  fig.colorbar(im3, ax=axs[1, 0], shrink=0.7)

  # # Movement density (change the axis and uncomment one of the other panels)
  # im3 = axs[1, 0].imshow(move_density_mask, cmap='viridis')
  # axs[1, 0].set_title('Movement log-probability')
  # fig.colorbar(im3, ax=axs[0, 1], shrink=0.7)

  # Next-step probability
  im4 = axs[1, 1].imshow(step_density_mask, cmap='viridis')
  axs[1, 1].set_title('Next-step log-probability')
  fig.colorbar(im4, ax=axs[1, 1], shrink=0.7)

  # Save the figure (the file name encodes the sample's day, hour, bearing and next-step cell)
  filename_covs = f'{output_dir}/deepSSF_S2_slope_id{buffalo_id}_yday{yday_t2_integer}_hour{hour_t2_integer}_bearing{bearing_degrees}_next_r{row}_c{column}.png'
  plt.tight_layout()
  plt.savefig(filename_covs, dpi=600, bbox_inches='tight')
  plt.show()
  plt.close()  # Close the figure to free memory
Hour:                        8
Day of the year:             175
Bearing (radians):           2.2515599727630615
Bearing (degrees):           129
Next step is (row, column):  (52, 45)

Hour:                        17
Day of the year:             217
Bearing (radians):           -0.1754153072834015
Bearing (degrees):           349
Next step is (row, column):  (48, 50)

Hour:                        3
Day of the year:             9
Bearing (radians):           -1.8308851718902588
Bearing (degrees):           255
Next step is (row, column):  (50, 50)

Hour:                        10
Day of the year:             224
Bearing (radians):           2.1292378902435303
Bearing (degrees):           122
Next step is (row, column):  (49, 62)

Hour:                        14
Day of the year:             7
Bearing (radians):           2.781324625015259
Bearing (degrees):           159
Next step is (row, column):  (49, 50)

Extracting convolution layer outputs

In the convolutional blocks, each convolutional layer learns a set of filters (kernels) that extract different features from the input data. In the habitat selection subnetwork, the convolution filters (and their associated bias parameters - not shown below) are the only parameters that are trained, and it is the filters that transform the set of input covariates into the habitat selection probabilities. They do this by maximising features of the inputs that correlate with observed next-steps.

For each convolutional layer, there are typically a number of filters. For the habitat selection subnetwork, we used 4 filters in the first two layers, and a single filter in the last layer. Each of these filters has a number of channels which correspond one-to-one with the input layers. The outputs of the filter channels are then combined to produce a feature map, with a single feature map produced for each filter. In successive layers, the feature maps become the input layers, and the filters operate on these layers. Because there are multiple filters in each layer, they can ‘specialise’ in extracting different features from the input layers.

By visualizing and inspecting these filters, and the corresponding feature maps, we can:

  • Gain interpretability: Understand what kind of features the network is detecting—e.g., edges, shapes, or textures.
  • Debug: Check if the filters have meaningful patterns or if something went wrong (e.g., all zeros or random noise).
  • Compare layers: See how early layers often learn low-level patterns while deeper layers learn more abstract features.

We will first set up some activation hooks for storing the feature maps. Activation hooks are placed at certain points within the model’s forward pass and store intermediate results. We will also extract the convolution filters (which are weights of the model and as such don’t require hooks - we can access them directly).

We will then run the sample covariates through the model and extract the feature maps from the habitat selection convolutional block, and plot them along with the covariates and convolution filters.

Note that there are also ReLU activation functions in the convolutional blocks, which are not shown below. These are applied to the feature maps, and set all negative values to zero. They are not learned parameters, but are part of the forward pass of the model.

Create scalar grids for plotting

Using the Scalar_to_Grid_Block class from the deepSSF_model script, we can convert the scalar covariates into grids for plotting.

Code
# Create an instance of the scalar-to-grid block using model parameters
scalar_to_grid_block = deepSSF_model.Scalar_to_Grid_Block(params)

# Convert scalars into spatial grid representation
# NOTE: x2 here is the batch of scalar covariates left over from the last
# iteration of the test loop above
scalar_maps = scalar_to_grid_block(x2)
print(scalar_maps.shape)  # Check the shape of the generated spatial maps
torch.Size([32, 4, 101, 101])

Convolutional layer 1

Activation hook

Code
# -----------------------------------------------------------
# Module-level store for captured layer outputs, keyed by name
# -----------------------------------------------------------
activation = {}

def get_activation(name):
    """
    Build a forward hook that records a layer's output.

    The returned callable has the (module, inputs, output) signature of a
    forward hook. Each time it is called it saves a detached copy of the
    output in the module-level 'activation' dict, overwriting any earlier
    entry stored under the same key. It returns None.

    Args:
        name (str): Key under which the output is stored in 'activation'.

    Returns:
        The hook function to pass to a layer's register_forward_hook.
    """
    def _record(module, inputs, outputs):
        # Detach so the stored feature maps carry no autograd history
        detached = outputs.detach()
        activation[name] = detached
    return _record

# -----------------------------------------------------------
# Register a forward hook on the first convolution layer
#    in the model's 'conv_habitat' block. The handle is kept so
#    the hook can be removed once the feature maps are captured;
#    otherwise it would keep firing (and storing tensors) on every
#    later forward pass, and re-running this cell would stack hooks.
# -----------------------------------------------------------
hab_conv1_hook_handle = model.conv_habitat.conv2d[0].register_forward_hook(get_activation("hab_conv1"))

# -----------------------------------------------------------
# Perform a forward pass through the model with the desired input
#    The feature maps from the hooked layer will be stored in 'activation'.
#    Evaluation mode and no_grad are used because this pass is for
#    inspection only, so no autograd graph needs to be built.
# -----------------------------------------------------------
model.eval()
try:
    with torch.no_grad():
        out = model((x1, x2, x3))  # e.g., model((spatial_data_x, scalars_to_grid, bearing_x))
finally:
    # Detach the hook even if the forward pass raises
    hab_conv1_hook_handle.remove()

# -----------------------------------------------------------
# Retrieve the captured feature maps from the dictionary
#    and move them to the CPU for inspection
# -----------------------------------------------------------
feat_maps1 = activation["hab_conv1"].cpu()
print("Feature map shape:", feat_maps1.shape)
# Printed shape is (batch_size, out_channels, height, width)

# -----------------------------------------------------------
# Select the feature maps for the first sample in the batch
# -----------------------------------------------------------
feat_maps1_sample = feat_maps1[0]  # Shape: (out_channels, H, W)
num_maps1 = feat_maps1_sample.shape[0]
print("Number of feature maps:", num_maps1)
Feature map shape: torch.Size([32, 4, 101, 101])
Number of feature maps: 4

Stack spatial and scalar (as grid) covariates

For plotting. Also create a vector of names to index over.

Code
# Concatenate the spatial covariates (x1) and the gridded scalars along the
# channel dimension (dim=1) so both can be indexed with a single channel index
covariate_stack = torch.cat([x1, scalar_maps], dim=1)
print(covariate_stack.shape)

# Display names used as subplot titles; entry i labels channel i of
# 'covariate_stack' in the plotting code.
sentinel2_band_ids = ('1', '2', '3', '4', '5', '6', '7', '8', '8a', '9', '11', '12')
covariate_names = [f'S2 B{band_id}' for band_id in sentinel2_band_ids]
covariate_names += ['Slope', 'Hour sin', 'Hour cos', 'yday sin', 'yday cos']
torch.Size([32, 17, 101, 101])

Extract filters and plot

Code
# -------------------------------------------------------------------------
# Print the layers of the conv_habitat block (for reference/debugging)
# -------------------------------------------------------------------------
print(model.conv_habitat.conv2d)

# -------------------------------------------------------------------------
# Set the model to evaluation mode (disables dropout, etc.)
# -------------------------------------------------------------------------
model.eval()

# -------------------------------------------------------------------------
# Extract the weights (filters) from the first convolution layer in conv_habitat
# -------------------------------------------------------------------------
filters_c1 = model.conv_habitat.conv2d[0].weight.data.clone().cpu()
print("Filters shape:", filters_c1.shape)
# Printed shape is (out_channels, in_channels, kernel_height, kernel_width)

# -------------------------------------------------------------------------
# For every filter, plot each input covariate (top row) above the filter
# channel that is applied to it (bottom row), then plot the feature map
# that the filter produces.
# NOTE: despite its name, num_filters_c1 is the number of input channels
# per filter (dimension 1 of the weights), i.e. the number of columns.
# -------------------------------------------------------------------------
num_filters_c1 = filters_c1.shape[1]
print(num_filters_c1)

for z in range(num_maps1):

    # squeeze=False keeps 'axes' two-dimensional even with a single column
    fig, axes = plt.subplots(2, num_filters_c1, figsize=(2*num_filters_c1, 4), squeeze=False)
    for i in range(num_filters_c1):

        # Top row: the covariate feeding channel i (drawn once)
        im1 = axes[0,i].imshow(covariate_stack[0, i].detach().cpu().numpy(), cmap='viridis')
        axes[0,i].axis('off')
        axes[0,i].set_title(f'{covariate_names[i]}')
        if i > x1.shape[1] - 1:
            # Channels after the x1 layers are the gridded scalars: use a fixed
            # colour range and print the scalar value on top of the panel
            im1.set_clim(-1, 1)
            axes[0,i].text(scalar_maps.shape[2] // 2, scalar_maps.shape[3] // 2, 
                f'Value: {round(x2[0, i-x1.shape[1]].item(), 2)}', 
                ha='center', va='center', color='white', fontsize=12)

        # Bottom row: channel i of filter z
        kernel = filters_c1[z, i, :, :]
        im = axes[1,i].imshow(kernel, cmap='viridis')
        axes[1,i].axis('off')
        axes[1,i].set_title(f'Layer 1, Filter {z+1}')
        # Annotate each cell with the numeric value
        for (j, k), val in np.ndenumerate(kernel):
            axes[1,i].text(k, j, f'{val:.2f}', ha='center', va='center', color='white')

    plt.tight_layout()
    plt.savefig(f'{output_dir}/id{buffalo_id}_conv_layer1_filters{z}_{today_date}.png', dpi=600, bbox_inches='tight')
    plt.show()

    # -----------------------------------------------------------
    # Plot and save the feature map produced by filter z,
    #    multiplied by x_mask * y_mask to mask out the edges.
    # -----------------------------------------------------------
    plt.figure()
    plt.imshow(feat_maps1_sample[z].numpy() * x_mask * y_mask, cmap='viridis')
    plt.title(f"Layer 1, Feature Map {z+1}")
    # Hide axis if you prefer: plt.axis('off')
    plt.savefig(f'{output_dir}/id{buffalo_id}_conv_layer1_feature_map{z}_{today_date}.png', dpi=600, bbox_inches='tight')
    plt.show()
Sequential(
  (0): Conv2d(17, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU()
  (2): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU()
  (4): Conv2d(4, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
Filters shape: torch.Size([4, 17, 3, 3])
17

Convolutional layer 2

Activation hook

Code
# -----------------------------------------------------------
# Register a forward hook on the second convolution layer
#    in the model's 'conv_habitat' block. The handle is kept so
#    the hook can be removed once the feature maps are captured;
#    otherwise it would keep firing (and storing tensors) on every
#    later forward pass, and re-running this cell would stack hooks.
# -----------------------------------------------------------
hab_conv2_hook_handle = model.conv_habitat.conv2d[2].register_forward_hook(get_activation("hab_conv2"))

# -----------------------------------------------------------
# Perform a forward pass through the model with the desired input
#    The feature maps from the hooked layer will be stored in 'activation'.
#    Evaluation mode and no_grad are used because this pass is for
#    inspection only, so no autograd graph needs to be built.
# -----------------------------------------------------------
model.eval()
try:
    with torch.no_grad():
        out = model((x1, x2, x3))  # e.g., model((spatial_data_x, scalars_to_grid, bearing_x))
finally:
    # Detach the hook even if the forward pass raises
    hab_conv2_hook_handle.remove()

# -----------------------------------------------------------
# Retrieve the captured feature maps from the dictionary
#    and move them to the CPU for inspection
# -----------------------------------------------------------
feat_maps2 = activation["hab_conv2"].cpu()
print("Feature map shape:", feat_maps2.shape)
# Printed shape is (batch_size, out_channels, height, width)

# -----------------------------------------------------------
# Select the feature maps for the first sample in the batch
# -----------------------------------------------------------
feat_maps2_sample = feat_maps2[0]  # Shape: (out_channels, H, W)
num_maps2 = feat_maps2_sample.shape[0]
print("Number of feature maps:", num_maps2)
Feature map shape: torch.Size([32, 4, 101, 101])
Number of feature maps: 4

Extract filters and plot

Code
# -------------------------------------------------------------------------
# Extract the weights (filters) from the second convolution layer in conv_habitat
# -------------------------------------------------------------------------
filters_c2 = model.conv_habitat.conv2d[2].weight.data.clone().cpu()
print("Filters shape:", filters_c2.shape)
# Printed shape is (out_channels, in_channels, kernel_height, kernel_width)

# -------------------------------------------------------------------------
# For every filter, plot each layer-1 feature map (top row) above the filter
# channel that is applied to it (bottom row), then plot the feature map
# that the filter produces.
# NOTE: despite its name, num_filters_c2 is the number of input channels
# per filter (dimension 1 of the weights), i.e. the number of columns.
# -------------------------------------------------------------------------
num_filters_c2 = filters_c2.shape[1]
print(num_filters_c2)

for z in range(num_maps2):

    # squeeze=False keeps 'axes' two-dimensional even with a single column
    fig, axes = plt.subplots(2, num_filters_c2, figsize=(2*num_filters_c2, 4), squeeze=False)
    for i in range(num_filters_c2):

        # Top row: layer-1 feature map i, the input to channel i of this filter.
        # The title uses the map index i (not the filter index z) so each
        # panel is labelled with the map it actually shows.
        axes[0,i].imshow(feat_maps1_sample[i].numpy() * x_mask * y_mask, cmap='viridis')
        axes[0,i].axis('off')
        axes[0,i].set_title(f"Layer 1, Map {i+1}")

        # Bottom row: channel i of filter z
        kernel = filters_c2[z, i, :, :]
        im = axes[1,i].imshow(kernel, cmap='viridis')
        axes[1,i].axis('off')
        axes[1,i].set_title(f'Layer 2, Filter {z+1}')
        # Annotate each cell with the numeric value
        for (j, k), val in np.ndenumerate(kernel):
            axes[1,i].text(k, j, f'{val:.2f}', ha='center', va='center', color='white')

    plt.tight_layout()
    plt.savefig(f'{output_dir}/id{buffalo_id}_conv_layer2_filters{z}_{today_date}.png', dpi=600, bbox_inches='tight')
    plt.show()

    # -----------------------------------------------------------
    # Plot and save the feature map produced by filter z,
    #    multiplied by x_mask * y_mask to mask out the edges.
    # -----------------------------------------------------------
    plt.figure()
    plt.imshow(feat_maps2_sample[z].numpy() * x_mask * y_mask, cmap='viridis')
    plt.title(f"Layer 2, Feature Map {z+1}")
    # Hide axis if you prefer: plt.axis('off')
    plt.savefig(f'{output_dir}/id{buffalo_id}_conv_layer2_feature_map{z}_{today_date}.png', dpi=600, bbox_inches='tight')
    plt.show()
Filters shape: torch.Size([4, 4, 3, 3])
4

Convolutional layer 3

Activation hook

Code
# -----------------------------------------------------------
# Register a forward hook on the third convolution layer
#    in the model's 'conv_habitat' block. The handle is kept so
#    the hook can be removed once the feature maps are captured;
#    otherwise it would keep firing (and storing tensors) on every
#    later forward pass, and re-running this cell would stack hooks.
# -----------------------------------------------------------
hab_conv3_hook_handle = model.conv_habitat.conv2d[4].register_forward_hook(get_activation("hab_conv3"))

# -----------------------------------------------------------
# Perform a forward pass through the model with the desired input
#    The feature maps from the hooked layer will be stored in 'activation'.
#    Evaluation mode and no_grad are used because this pass is for
#    inspection only, so no autograd graph needs to be built.
# -----------------------------------------------------------
model.eval()
try:
    with torch.no_grad():
        out = model((x1, x2, x3))  # e.g., model((spatial_data_x, scalars_to_grid, bearing_x))
finally:
    # Detach the hook even if the forward pass raises
    hab_conv3_hook_handle.remove()

# -----------------------------------------------------------
# Retrieve the captured feature maps from the dictionary
#    and move them to the CPU for inspection
# -----------------------------------------------------------
feat_maps3 = activation["hab_conv3"].cpu()
print("Feature map shape:", feat_maps3.shape)
# Printed shape is (batch_size, out_channels, height, width)

# -----------------------------------------------------------
# Select the feature maps for the first sample in the batch
# -----------------------------------------------------------
feat_maps3_sample = feat_maps3[0]  # Shape: (out_channels, H, W)
num_maps3 = feat_maps3_sample.shape[0]
print("Number of feature maps:", num_maps3)
Feature map shape: torch.Size([32, 1, 101, 101])
Number of feature maps: 1

Extract filters and plot

Code
# -------------------------------------------------------------------------
# Extract the weights (filters) from the third convolution layer in conv_habitat
# -------------------------------------------------------------------------
filters_c3 = model.conv_habitat.conv2d[4].weight.data.clone().cpu()
print("Filters shape:", filters_c3.shape)
# Printed shape is (out_channels, in_channels, kernel_height, kernel_width)

# -------------------------------------------------------------------------
# For every filter, plot each layer-2 feature map (top row) above the filter
# channel that is applied to it (bottom row), then plot the feature map
# that the filter produces.
# NOTE: despite its name, num_filters_c3 is the number of input channels
# per filter (dimension 1 of the weights), i.e. the number of columns.
# -------------------------------------------------------------------------
num_filters_c3 = filters_c3.shape[1]
print(num_filters_c3)

for z in range(num_maps3):

    # squeeze=False keeps 'axes' two-dimensional even with a single column
    fig, axes = plt.subplots(2, num_filters_c3, figsize=(2*num_filters_c3, 4), squeeze=False)
    for i in range(num_filters_c3):

        # Top row: layer-2 feature map i, the input to channel i of this filter.
        # The title uses the map index i (not the filter index z) so each
        # panel is labelled with the map it actually shows.
        axes[0,i].imshow(feat_maps2_sample[i].numpy() * x_mask * y_mask, cmap='viridis')
        axes[0,i].axis('off')
        axes[0,i].set_title(f"Layer 2, Map {i+1}")

        # Bottom row: channel i of filter z
        kernel = filters_c3[z, i, :, :]
        im = axes[1,i].imshow(kernel, cmap='viridis')
        axes[1,i].axis('off')
        axes[1,i].set_title(f'Layer 3, Filter {z+1}')
        # Annotate each cell with the numeric value
        for (j, k), val in np.ndenumerate(kernel):
            axes[1,i].text(k, j, f'{val:.2f}', ha='center', va='center', color='white')

    plt.tight_layout()
    plt.savefig(f'{output_dir}/id{buffalo_id}_conv_layer3_filters{z}_{today_date}.png', dpi=600, bbox_inches='tight')
    plt.show()

    # -----------------------------------------------------------
    # Plot and save the output of the final convolution layer,
    #    multiplied by x_mask * y_mask to mask out the edges.
    # -----------------------------------------------------------
    plt.figure()
    plt.imshow(feat_maps3_sample[z].numpy() * x_mask * y_mask, cmap='viridis')
    plt.title(f"Habitat selection log probability")
    # Hide axis if you prefer: plt.axis('off')
    plt.savefig(f'{output_dir}/id{buffalo_id}_conv_layer3_feature_map{z}_{today_date}.png', dpi=600, bbox_inches='tight')
    plt.show()
Filters shape: torch.Size([1, 4, 3, 3])
4

Checking estimated movement parameters

Similarly to the convolutional layers, we can set hooks to extract the predicted movement parameters from the model, and assess how variable they are across samples.

Code
# -------------------------------------------------------------------------
# Create a list to store the intermediate output from the fully connected
#    movement sub-network (fcn_movement_all)
# -------------------------------------------------------------------------
intermediate_output = []

def hook(module, input, output):
    """
    Forward hook that records the output of the layer it is registered on.

    The output is detached before being appended to 'intermediate_output',
    so the stored tensor does not keep an autograd graph alive when a
    forward pass is run outside of torch.no_grad(). One entry is appended
    per forward pass.

    Args:
        module: The layer the hook is registered on (unused).
        input: The inputs passed to the layer (unused).
        output (Tensor): The layer's output for this forward pass.
    """
    intermediate_output.append(output.detach())

# -------------------------------------------------------------------------
# Attach the capture hook to 'fcn_movement_all'; while attached, every
#    forward pass appends that sub-network's output to 'intermediate_output'.
# -------------------------------------------------------------------------
hook_handle = model.fcn_movement_all.register_forward_hook(hook)

# -------------------------------------------------------------------------
# Run one forward pass in evaluation mode without tracking gradients
# -------------------------------------------------------------------------
model.eval()
with torch.no_grad():
    final_output = model((x1, x2, x3))

# -------------------------------------------------------------------------
# Look at what was captured: entry 0 of the list is read as the output of
#    this forward pass, and its first row is the first sample in the batch.
# -------------------------------------------------------------------------
print("Intermediate output shape:", intermediate_output[0].shape)
print("Intermediate output values:", intermediate_output[0][0])

# -------------------------------------------------------------------------
# Detach the hook so later forward passes are not recorded
# -------------------------------------------------------------------------
hook_handle.remove()

# -------------------------------------------------------------------------
# Split the first sample's 12 values into named parameters
#    (the ordering is assumed to match the model's output layout)
# -------------------------------------------------------------------------
(gamma_shape1, gamma_scale1, gamma_weight1,
 gamma_shape2, gamma_scale2, gamma_weight2,
 vonmises_mu1, vonmises_kappa1, vonmises_weight1,
 vonmises_mu2, vonmises_kappa2, vonmises_weight2) = intermediate_output[0][0]

# -------------------------------------------------------------------------
# Transform the raw values and print them. Shapes, scales, kappas and
#    weights are exponentiated; weights are then normalised to sum to one
#    within each mixture; the second Gamma scale is multiplied by 500;
#    the von Mises means are wrapped into [0, 2π).
# -------------------------------------------------------------------------
_gamma_weight_total = torch.exp(gamma_weight1) + torch.exp(gamma_weight2)
_vonmises_weight_total = torch.exp(vonmises_weight1) + torch.exp(vonmises_weight2)

for _param_label, _param_value in (
    # --- Gamma #1 ---
    ("Gamma shape 1:", torch.exp(gamma_shape1)),
    ("Gamma scale 1:", torch.exp(gamma_scale1)),
    ("Gamma weight 1:", torch.exp(gamma_weight1) / _gamma_weight_total),
    # --- Gamma #2 ---
    ("Gamma shape 2:", torch.exp(gamma_shape2)),
    ("Gamma scale 2:", torch.exp(gamma_scale2) * 500),  # scale factor 500
    ("Gamma weight 2:", torch.exp(gamma_weight2) / _gamma_weight_total),
    # --- von Mises #1 ---
    ("Von Mises mu 1:", vonmises_mu1 % (2*np.pi)),
    ("Von Mises kappa 1:", torch.exp(vonmises_kappa1)),
    ("Von Mises weight 1:", torch.exp(vonmises_weight1) / _vonmises_weight_total),
    # --- von Mises #2 ---
    ("Von Mises mu 2:", vonmises_mu2 % (2*np.pi)),
    ("Von Mises kappa 2:", torch.exp(vonmises_kappa2)),
    ("Von Mises weight 2:", torch.exp(vonmises_weight2) / _vonmises_weight_total),
):
    print(_param_label, _param_value)
Intermediate output shape: torch.Size([32, 12])
Intermediate output values: tensor([ 0.6753,  0.7900,  0.5984, -0.4723, -0.6171, -3.2411,  0.0966, -1.1834,
         0.4872,  0.0490,  1.1332,  0.0900])
Gamma shape 1: tensor(1.9647)
Gamma scale 1: tensor(2.2034)
Gamma weight 1: tensor(0.9789)
Gamma shape 2: tensor(0.6236)
Gamma scale 2: tensor(269.7635)
Gamma weight 2: tensor(0.0211)
Von Mises mu 1: tensor(0.0966)
Von Mises kappa 1: tensor(0.3062)
Von Mises weight 1: tensor(0.5980)
Von Mises mu 2: tensor(0.0490)
Von Mises kappa 2: tensor(3.1056)
Von Mises weight 2: tensor(0.4020)

Plot the movement distributions

We can use the movement parameters to plot the step length and turning angle distributions for the sample covariates.

Code
# -------------------------------------------------------------------------
# Define helper functions for calculating Gamma and von Mises log-densities
# -------------------------------------------------------------------------
def gamma_density(x, shape, scale):
    """
    Computes the log of the Gamma density for each value in x.

    Despite the function name, the returned values are log-densities;
    exponentiate them to obtain densities.

    Args:
      x (Tensor): Input values for which to compute the density.
      shape (float or Tensor): Gamma shape parameter
      scale (float or Tensor): Gamma scale parameter

    Returns:
      Tensor: The log of the Gamma probability density at each x.
    """
    # torch.lgamma / torch.log only accept tensors, so plain Python numbers
    # are converted here; tensors are passed through unchanged.
    shape = torch.as_tensor(shape)
    scale = torch.as_tensor(scale)
    return -1*torch.lgamma(shape) - shape*torch.log(scale) \
           + (shape - 1)*torch.log(x) - x/scale

def vonmises_density(x, kappa, vm_mu):
    """
    Computes the log of the von Mises density for each value in x.

    Despite the function name, the returned values are log-densities;
    exponentiate them to obtain densities.

    Args:
      x (Tensor): Input angles in radians.
      kappa (float or Tensor): Concentration parameter (kappa)
      vm_mu (float or Tensor): Mean direction parameter (mu)

    Returns:
      Tensor: The log of the von Mises probability density at each x.
    """
    # torch.special functions only accept tensors; tensors pass through unchanged
    kappa = torch.as_tensor(kappa)
    # log(I0(kappa)) is computed from the exponentially scaled Bessel function,
    # i0e(kappa) = exp(-|kappa|) * I0(kappa), because I0 itself overflows to
    # inf in float32 for large kappa, which would turn the result into -inf.
    log_i0 = torch.log(torch.special.i0e(kappa)) + torch.abs(kappa)
    log_two_pi = float(np.log(2*np.pi))
    return kappa*torch.cos(x - vm_mu) - 1*(log_two_pi + log_i0)


# -------------------------------------------------------------------------
# Recover the Gamma mixture weights (exponentiate, then normalise so the
# two weights sum to one) and round them for use in the legend
# -------------------------------------------------------------------------
_gamma_exp_weight_sum = torch.exp(gamma_weight1) + torch.exp(gamma_weight2)
gamma_weight1_recovered = torch.exp(gamma_weight1)/_gamma_exp_weight_sum
gamma_weight2_recovered = torch.exp(gamma_weight2)/_gamma_exp_weight_sum
rounded_gamma_weight1 = round(gamma_weight1_recovered.item(), 2)
rounded_gamma_weight2 = round(gamma_weight2_recovered.item(), 2)

# -------------------------------------------------------------------------
# Same for the von Mises mixture weights
# -------------------------------------------------------------------------
_vonmises_exp_weight_sum = torch.exp(vonmises_weight1) + torch.exp(vonmises_weight2)
vonmises_weight1_recovered = torch.exp(vonmises_weight1)/_vonmises_exp_weight_sum
vonmises_weight2_recovered = torch.exp(vonmises_weight2)/_vonmises_exp_weight_sum
rounded_vm_weight1 = round(vonmises_weight1_recovered.item(), 2)
rounded_vm_weight2 = round(vonmises_weight2_recovered.item(), 2)


# -------------------------------------------------------------------------
# 1. Gamma mixture
#    a) Evaluate both component log-densities on a grid from 1 to 101
#    b) Exponentiate and combine them with the recovered weights
#    (the second component's scale carries the factor of 500)
# -------------------------------------------------------------------------
x_values = torch.linspace(1, 101, 1000).to(device)
gamma1_density = gamma_density(x_values, torch.exp(gamma_shape1), torch.exp(gamma_scale1))
gamma2_density = gamma_density(x_values, torch.exp(gamma_shape2), torch.exp(gamma_scale2)*500)
gamma_mixture_density = gamma_weight1_recovered*torch.exp(gamma1_density) \
                        + gamma_weight2_recovered*torch.exp(gamma2_density)

# NumPy copies on the CPU for matplotlib (components are exponentiated here)
x_values_np = x_values.cpu().numpy()
gamma1_density_np = np.exp(gamma1_density.cpu().numpy())
gamma2_density_np = np.exp(gamma2_density.cpu().numpy())
gamma_mixture_density_np = gamma_mixture_density.cpu().numpy()

# -------------------------------------------------------------------------
# 2. Draw the two Gamma components and the mixture on one set of axes
# -------------------------------------------------------------------------
for _curve, _legend_entry in (
    (gamma1_density_np, f'Gamma 1 Density: weight = {rounded_gamma_weight1}'),
    (gamma2_density_np, f'Gamma 2 Density: weight = {rounded_gamma_weight2}'),
    (gamma_mixture_density_np, 'Gamma Mixture Density'),
):
    plt.plot(x_values_np, _curve, label=_legend_entry)
plt.xlabel('x')
plt.ylabel('Density')
plt.title('Gamma Density Function')
plt.legend()
plt.show()


# -------------------------------------------------------------------------
# 3. von Mises mixture
#    a) Evaluate both component log-densities on a grid from -π to π
#    b) Exponentiate and combine them with the recovered weights
# -------------------------------------------------------------------------
x_values = torch.linspace(-np.pi, np.pi, 1000).to(device)
vonmises1_density = vonmises_density(x_values, torch.exp(vonmises_kappa1), vonmises_mu1)
vonmises2_density = vonmises_density(x_values, torch.exp(vonmises_kappa2), vonmises_mu2)
vonmises_mixture_density = vonmises_weight1_recovered*torch.exp(vonmises1_density) \
                           + vonmises_weight2_recovered*torch.exp(vonmises2_density)

# NumPy copies on the CPU for matplotlib (components are exponentiated here)
x_values_np = x_values.cpu().numpy()
vonmises1_density_np = np.exp(vonmises1_density.cpu().numpy())
vonmises2_density_np = np.exp(vonmises2_density.cpu().numpy())
vonmises_mixture_density_np = vonmises_mixture_density.cpu().numpy()

# -------------------------------------------------------------------------
# 4. Draw the two von Mises components and the mixture on one set of axes
# -------------------------------------------------------------------------
for _curve, _legend_entry in (
    (vonmises1_density_np, f'Von Mises 1 Density: weight = {rounded_vm_weight1}'),
    (vonmises2_density_np, f'Von Mises 2 Density: weight = {rounded_vm_weight2}'),
    (vonmises_mixture_density_np, 'Von Mises Mixture Density'),
):
    plt.plot(x_values_np, _curve, label=_legend_entry)
plt.xlabel('x (radians)')
plt.ylabel('Density')
plt.title('Von Mises Density Function')
plt.ylim(0, 1.2)  # Set a limit for the y-axis
plt.legend()
plt.show()

Generate a distribution of movement parameters

To see how variable the movement parameters are across samples, we can generate a distribution of movement parameters from a batch of samples.

We take the code from above that we used to create the DataLoader for the test data and increase the batch size (to get more samples to create the distribution from).

As we’re not using the test dataset any more, we’ll just put all of the samples in the same batch, and generate movement parameters for all of them.

Code
# Use the whole test dataset as a single batch
bs = len(dataset_test)  # batch size
print(f'There are {bs} samples in the test dataset')
dataloader_test = DataLoader(dataset=dataset_test, batch_size=bs, shuffle=True)
There are 1010 samples in the test dataset

Take all of the samples from the test dataset and put them in a single batch.

Code
# -----------------------------------------------------------
# Fetch a batch of data from the test dataloader
#    (batch size equals the dataset size, so this is every sample)
# -----------------------------------------------------------
x1_batch, x2_batch, x3_batch, labels = next(iter(dataloader_test))

# -----------------------------------------------------------
# Discard outputs captured by earlier forward passes.
#    'intermediate_output' is only ever appended to, so without this
#    the samples from the previous single-batch inspection would be
#    mixed into the parameter distributions below.
# -----------------------------------------------------------
intermediate_output.clear()

# -----------------------------------------------------------
# Register a forward hook to capture the outputs 
#    from 'fcn_movement_all' during the forward pass
# -----------------------------------------------------------
hook_handle = model.fcn_movement_all.register_forward_hook(hook)

# -----------------------------------------------------------
# Perform a forward pass in evaluation mode to generate 
#    and capture the sub-network's outputs in 'intermediate_output'.
#    Gradients are not needed, so the pass runs under no_grad to
#    avoid building an autograd graph for the whole dataset.
# -----------------------------------------------------------
model.eval()  # Disables certain layers like dropout

try:
    with torch.no_grad():
        # Pass the batch through the model
        final_output = model((x1_batch, x2_batch, x3_batch))
finally:
    # Stop capturing as soon as this pass is done (also on failure).
    # Calling hook_handle.remove() again later is harmless.
    hook_handle.remove()

# -----------------------------------------------------------
# Prepare lists to store the distribution parameters 
#    for each sample in the batch
# -----------------------------------------------------------
gamma_shape1_list = []
gamma_scale1_list = []
gamma_weight1_list = []
gamma_shape2_list = []
gamma_scale2_list = []
gamma_weight2_list = []
vonmises_mu1_list = []
vonmises_kappa1_list = []
vonmises_weight1_list = []
vonmises_mu2_list = []
vonmises_kappa2_list = []
vonmises_weight2_list = []

# -----------------------------------------------------------
# Extract parameters from 'intermediate_output' 
#    for every sample in the batch
# -----------------------------------------------------------
for batch_output in intermediate_output:
    # Each 'batch_output' corresponds to one forward pass;
    # each of its rows holds the parameters for one sample
    for sample_output in batch_output:
        # Unpack the 12 parameters of the Gamma and von Mises mixtures
        gamma_shape1, gamma_scale1, gamma_weight1, \
        gamma_shape2, gamma_scale2, gamma_weight2, \
        vonmises_mu1, vonmises_kappa1, vonmises_weight1, \
        vonmises_mu2, vonmises_kappa2, vonmises_weight2 = sample_output

        # Normalising constants so each pair of mixture weights sums to one
        gamma_weight_sum = torch.exp(gamma_weight1) + torch.exp(gamma_weight2)
        vonmises_weight_sum = torch.exp(vonmises_weight1) + torch.exp(vonmises_weight2)

        # Convert log-space parameters to real space, then store
        gamma_shape1_list.append(torch.exp(gamma_shape1).item())
        gamma_scale1_list.append(torch.exp(gamma_scale1).item())
        gamma_weight1_list.append((torch.exp(gamma_weight1)/gamma_weight_sum).item())
        gamma_shape2_list.append(torch.exp(gamma_shape2).item())
        gamma_scale2_list.append((torch.exp(gamma_scale2)*500).item())  # scale factor 500
        gamma_weight2_list.append((torch.exp(gamma_weight2)/gamma_weight_sum).item())
        # % (2*np.pi) wraps the mean angle into [0, 2π)
        vonmises_mu1_list.append((vonmises_mu1 % (2*np.pi)).item())
        vonmises_kappa1_list.append(torch.exp(vonmises_kappa1).item())
        vonmises_weight1_list.append((torch.exp(vonmises_weight1)/vonmises_weight_sum).item())
        vonmises_mu2_list.append((vonmises_mu2 % (2*np.pi)).item())
        vonmises_kappa2_list.append(torch.exp(vonmises_kappa2).item())
        vonmises_weight2_list.append((torch.exp(vonmises_weight2)/vonmises_weight_sum).item())

Plot the distribution of movement parameters

Code
# -----------------------------------------------------------
# Define a helper function to plot histograms 
#    for the collected parameters
# -----------------------------------------------------------
def plot_histogram(data, title, xlabel):
    """
    Draw a 30-bin histogram of 'data' on a new figure and show it.

    Args:
        data (list): Values to bin.
        title (str): Text placed above the plot.
        xlabel (str): Label for the horizontal axis; the vertical axis
            is always labelled 'Frequency'.
    """
    # Create the figure and its single axes up front, then draw onto the axes
    _, hist_axes = plt.subplots()
    hist_axes.hist(data, bins=30, alpha=0.75)
    hist_axes.set_title(title)
    hist_axes.set_xlabel(xlabel)
    hist_axes.set_ylabel('Frequency')
    plt.show()

# -----------------------------------------------------------
# One histogram per movement parameter, in the order:
#    Gamma 1, Gamma 2, von Mises 1, von Mises 2
# -----------------------------------------------------------
for _values, _plot_title, _axis_label in (
    (gamma_shape1_list, 'Gamma Shape 1 Distribution', 'Shape 1'),
    (gamma_scale1_list, 'Gamma Scale 1 Distribution', 'Scale 1'),
    (gamma_weight1_list, 'Gamma Weight 1 Distribution', 'Weight 1'),
    (gamma_shape2_list, 'Gamma Shape 2 Distribution', 'Shape 2'),
    (gamma_scale2_list, 'Gamma Scale 2 Distribution', 'Scale 2'),
    (gamma_weight2_list, 'Gamma Weight 2 Distribution', 'Weight 2'),
    (vonmises_mu1_list, 'Von Mises Mu 1 Distribution', 'Mu 1'),
    (vonmises_kappa1_list, 'Von Mises Kappa 1 Distribution', 'Kappa 1'),
    (vonmises_weight1_list, 'Von Mises Weight 1 Distribution', 'Weight 1'),
    (vonmises_mu2_list, 'Von Mises Mu 2 Distribution', 'Mu 2'),
    (vonmises_kappa2_list, 'Von Mises Kappa 2 Distribution', 'Kappa 2'),
    (vonmises_weight2_list, 'Von Mises Weight 2 Distribution', 'Weight 2'),
):
    plot_histogram(_values, _plot_title, _axis_label)

# -----------------------------------------------------------
# Detach the capture hook so later forward passes
#    are no longer recorded
# -----------------------------------------------------------
hook_handle.remove()