Oze challenge example

Imports

[1]:
import os
from pathlib import Path
import numpy as np
import skorch
import torch
from skorch.callbacks import EarlyStopping
from matplotlib import pyplot as plt

from time_series_predictor import TimeSeriesPredictor
from time_series_models import BenchmarkLSTM
from oze_dataset import npz_check, OzeNPZDataset

Config

[2]:
plot_config = {}
plot_config['training progress'] = True
plot_config['prediction on training data'] = True
plot_config['forecast'] = True

forecast_config = {}
forecast_config['include history'] = True
forecast_config['steps ahead'] = 500

predictor_config = {}
predictor_config['epochs'] = 300
predictor_config['learning rate'] = 1e-2
predictor_config['hidden dim'] = 200
predictor_config['layers num'] = 3
predictor_config['patience'] = 40
predictor_config['dropout'] = 0.2
predictor_config['train shuffle'] = True
predictor_config['weight decay'] = 0  # alternatively 1e-5
predictor_config['bidirectional'] = False  # alternatively True
predictor_config['train split'] = 20

config = {}
config['plot'] = plot_config
config['forecast'] = forecast_config
config['predictor'] = predictor_config
config['predict on training data enabled'] = True
config['forecast enabled'] = True
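
A note on the 'train split' entry: it is passed to skorch's CVSplit in the next cell. With an integer argument, CVSplit makes a single train/validation split and holds out one fold, so 20 keeps roughly 5% of the training data for validation (CVSplit was renamed ValidSplit in later skorch releases). A standalone illustration:

from skorch.dataset import CVSplit

# One split only; about 1/20 (5%) of the samples go to the validation set.
splitter = CVSplit(20)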

Time Series Predictor instantiation

[3]:
tsp = TimeSeriesPredictor(
    BenchmarkLSTM(
        hidden_dim=config['predictor']['hidden dim'],
        num_layers=config['predictor']['layers num'],
        dropout=config['predictor']['dropout'],
        bidirectional=config['predictor']['bidirectional']
    ),
    # Shuffle training data on each epoch
    iterator_train__shuffle=config['predictor']['train shuffle'],
    optimizer__weight_decay=config['predictor']['weight decay'],
    early_stopping=EarlyStopping(patience=config['predictor']['patience']),
    lr=config['predictor']['learning rate'],
    max_epochs=config['predictor']['epochs'],
    train_split=skorch.dataset.CVSplit(config['predictor']['train split']),
    optimizer=torch.optim.Adam
)
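
For intuition about what the hyperparameters above control, here is a minimal plain-PyTorch sketch of a comparable LSTM regressor. It is an illustration only, not BenchmarkLSTM's actual implementation, whose interface and internals may differ:

from torch import nn

class SketchLSTMRegressor(nn.Module):
    # Hypothetical stand-in: an LSTM followed by a per-timestep linear head.
    def __init__(self, input_dim, output_dim, hidden_dim=200, num_layers=3,
                 dropout=0.2, bidirectional=False):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, dropout=dropout,
                            bidirectional=bidirectional)
        directions = 2 if bidirectional else 1
        self.head = nn.Linear(hidden_dim * directions, output_dim)

    def forward(self, x):
        # x: (batch, seq_len, input_dim) -> (batch, seq_len, output_dim)
        out, _ = self.lstm(x)
        return self.head(out)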

Training process
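
The dataset download in the next cell reads the Oze challenge credentials from the CHALLENGE_USER_NAME and CHALLENGE_USER_PASSWORD environment variables. A minimal way to provide them from within Python is sketched below with placeholder values; in practice, export them in the environment before launching the notebook rather than hard-coding secrets.

import os

# Placeholder credentials for the Oze challenge download; replace them or, better,
# set these variables in the environment before starting the notebook.
os.environ.setdefault('CHALLENGE_USER_NAME', 'your-user-name')
os.environ.setdefault('CHALLENGE_USER_PASSWORD', 'your-password')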

[4]:
credentials = {
    'user_name': os.environ.get('CHALLENGE_USER_NAME'),
    'user_password': os.environ.get('CHALLENGE_USER_PASSWORD')
}
ds = OzeNPZDataset(
    dataset_path=npz_check(
        Path(os.path.abspath(''), 'datasets'),
        'dataset',
        credentials=credentials
    )
)
tsp.fit(ds)

Using device cuda
Re-initializing module because the following parameters were re-set: input_dim, output_dim.
Re-initializing optimizer.
  epoch    train_loss    valid_loss    cp      dur
-------  ------------  ------------  ----  -------
      1        0.0598        0.0172     +  16.5771
      2        0.0158        0.0141     +  16.2039
      3        0.0125        0.0102     +  16.3423
      4        0.0092        0.0079     +  17.2209
      5        0.0073        0.0068     +  17.0013
      6        0.0065        0.0056     +  17.0640
      7        0.0056        0.0109        17.0670
      8        0.0057        0.0042     +  17.1090
      9        0.0040        0.0037     +  17.1370
     10        0.0035        0.0031     +  17.9320
     11        0.0033        0.0029     +  18.1180
     12        0.0029        0.0028     +  17.7663
     13        0.0028        0.0028     +  17.9782
     14        0.0026        0.0027     +  17.7470
     15        0.0027        0.0026     +  17.6875
     16        0.0025        0.0028        17.6994
     17        0.0025        0.0023     +  17.9310
     18        0.0023        0.0026        17.9590
     19        0.0024        0.0023     +  18.3705
     20        0.0023        0.0022     +  18.1116
     21        0.0022        0.0023        17.7300
     22        0.0023        0.0023        17.9180
     23        0.0026        0.0022     +  17.6660
     24        0.0023        0.0022     +  17.5430
     25        0.0021        0.0022        17.8630
     26        0.0029        0.0024        18.1196
     27        0.0022        0.0021     +  18.4433
     28        0.0021        0.0021     +  17.6281
     29        0.0022        0.0021        20.1274
     30        0.0021        0.0021        18.8374
     31        0.0020        0.0022        18.8767
     32        0.0021        0.0023        18.6730
     33        0.0021        0.0021        18.1687
     34        0.0020        0.0022        17.9550
     35        0.0023        0.0029        18.0541
     36        0.0024        0.0021        17.9760
     37        0.0021        0.0023        18.1625
     38        0.0021        0.0021     +  18.2970
     39        0.0020        0.0021        18.3930
     40        0.0020        0.0022        18.7891
     41        0.0021        0.0021     +  18.2757
     42        0.0020        0.0020     +  18.0570
     43        0.0020        0.0020     +  20.3404
     44        0.0020        0.0020     +  18.3443
     45        0.0020        0.0021        18.2001
     46        0.0036        0.0023        18.2080
     47        0.0021        0.0024        18.3500
     48        0.0021        0.0020        18.3178
     49        0.0020        0.0020     +  18.2160
     50        0.0020        0.0020     +  18.3598
     51        0.0020        0.0020     +  18.1183
     52        0.0019        0.0019     +  18.4071
     53        0.0020        0.0019        18.3092
     54        0.0019        0.0020        18.6570
     55        0.0019        0.0019     +  18.6850
     56        0.0020        0.0019        18.3450
     57        0.0019        0.0020        18.5480
     58        0.0023        0.0020        18.5968
     59        0.0020        0.0020        18.2156
     60        0.0019        0.0020        18.4245
     61        0.0019        0.0020        17.5182
     62        0.0019        0.0019     +  18.0560
     63        0.0019        0.0019        17.7231
     64        0.0019        0.0019     +  17.7588
     65        0.0019        0.0019        17.9545
     66        0.0019        0.0019        18.0695
     67        0.0019        0.0019        18.2617
     68        0.0019        0.0019        18.1225
     69        0.0018        0.0019        17.5935
     70        0.0019        0.0020        17.8453
     71        0.0018        0.0019        17.8726
     72        0.0019        0.0019        18.0304
     73        0.0018        0.0019     +  18.2397
     74        0.0019        0.0019     +  17.6092
     75        0.0019        0.0019        19.4706
     76        0.0018        0.0019        19.3744
     77        0.0019        0.0019        18.6761
     78        0.0019        0.0019        18.5124
     79        0.0018        0.0019     +  18.2067
     80        0.0023        0.0099        18.0633
     81        0.0034        0.0020        17.7588
     82        0.0019        0.0019        17.9262
     83        0.0019        0.0019        18.1253
     84        0.0019        0.0019        18.3857
     85        0.0018        0.0018     +  18.3150
     86        0.0018        0.0019        17.6030
     87        0.0018        0.0019        19.3958
     88        0.0018        0.0021        19.3822
     89        0.0018        0.0018     +  18.2132
     90        0.0018        0.0019        18.2757
     91        0.0018        0.0018        18.8198
     92        0.0018        0.0018        18.6442
     93        0.0018        0.0018        18.2281
     94        0.0018        0.0018     +  18.2434
     95        0.0018        0.0018     +  18.3947
     96        0.0017        0.0018     +  18.5891
     97        0.0018        0.0018        18.0081
     98        0.0018        0.0019        18.9160
     99        0.0018        0.0018        18.1352
    100        0.0018        0.0018        17.8232
    101        0.0018        0.0018        17.5794
    102        0.0021        0.0024        17.9071
    103        0.0019        0.0018        18.1371
    104        0.0017        0.0018        18.3236
    105        0.0017        0.0017     +  19.3759
    106        0.0017        0.0018        18.7773
    107        0.0017        0.0017     +  18.7611
    108        0.0017        0.0018        19.0639
    109        0.0017        0.0017     +  18.3907
    110        0.0017        0.0018        18.5630
    111        0.0017        0.0017     +  18.8217
    112        0.0017        0.0017        18.8646
    113        0.0017        0.0017        18.4320
    114        0.0017        0.0017        18.3737
    115        0.0017        0.0017        18.2317
    116        0.0017        0.0018        18.0524
    117        0.0017        0.0017        18.2034
    118        0.0017        0.0017        17.9165
    119        0.0017        0.0017     +  17.5662
    120        0.0016        0.0017        17.4541
    121        0.0016        0.0017        18.0072
    122        0.0017        0.0017        18.3076
    123        0.0017        0.0017        18.6552
    124        0.0016        0.0017        18.7788
    125        0.0016        0.0017        18.2889
    126        0.0017        0.0017        17.6763
    127        0.0016        0.0017        19.4493
    128        0.0016        0.0017        18.8597
    129        0.0016        0.0017     +  18.3305
    130        0.0016        0.0017        17.8498
    131        0.0016        0.0017        17.7710
    132        0.0017        0.0017        18.1878
    133        0.0016        0.0017        17.6042
    134        0.0017        0.0017        18.2504
    135        0.0016        0.0018        18.7013
    136        0.0017        0.0017        17.3087
    137        0.0016        0.0017        18.0370
    138        0.0016        0.0017        18.1528
    139        0.0016        0.0017        18.3397
    140        0.0016        0.0017        18.2780
    141        0.0016        0.0017        18.0155
    142        0.0016        0.0017        18.2716
    143        0.0016        0.0017        18.2486
    144        0.0016        0.0017        18.2717
    145        0.0017        0.0019        18.2916
    146        0.0016        0.0016     +  18.0735
    147        0.0016        0.0017        17.8215
    148        0.0015        0.0016     +  21.0219
    149        0.0015        0.0016        19.0031
    150        0.0016        0.0016        18.4742
    151        0.0015        0.0016     +  18.3874
    152        0.0015        0.0016        18.6737
    153        0.0015        0.0016     +  17.8465
    154        0.0015        0.0016        17.7676
    155        0.0016        0.0017        17.9031
    156        0.0017        0.0017        18.0056
    157        0.0016        0.0017        17.9833
    158        0.0015        0.0016     +  18.4605
    159        0.0015        0.0016        18.1056
    160        0.0015        0.0016        18.4370
    161        0.0015        0.0018        18.0088
    162        0.0015        0.0017        18.1680
    163        0.0015        0.0017        17.5063
    164        0.0029        0.0027        19.5434
    165        0.0020        0.0018        18.5370
    166        0.0017        0.0017        17.4618
    167        0.0016        0.0016        17.7114
    168        0.0017        0.0018        17.7470
    169        0.0017        0.0017        17.8873
    170        0.0016        0.0016        18.5114
    171        0.0019        0.0070        18.1091
    172        0.0166        0.0119        18.5273
    173        0.0089        0.0064        19.8884
    174        0.0056        0.0049        17.9124
    175        0.0045        0.0044        17.6294
    176        0.0040        0.0043        18.5569
    177        0.0037        0.0036        18.4536
    178        0.0034        0.0033        18.5129
    179        0.0031        0.0030        18.4247
    180        0.0029        0.0030        18.0297
    181        0.0028        0.0027        18.2043
    182        0.0028        0.0026        17.6103
    183        0.0030        0.0028        17.7594
    184        0.0027        0.0027        17.8397
    185        0.0026        0.0025        17.7770
    186        0.0027        0.0027        18.0442
    187        0.0025        0.0025        17.8620
    188        0.0025        0.0027        17.8000
    189        0.0027        0.0025        17.8609
    190        0.0024        0.0026        17.9180
    191        0.0026        0.0025        18.1246
    192        0.0024        0.0024        18.1221
    193        0.0023        0.0023        18.1935
    194        0.0025        0.0026        18.0658
    195        0.0024        0.0024        18.0395
    196        0.0023        0.0023        20.7955
    197        0.0023        0.0022        19.5824
Stopping since valid_loss has not improved in the last 40 epochs.
Loading the best network from the last checkpoint.

Plot training evolution

[5]:
if config['plot']['training progress']:
    history_length = len(tsp.ttr.regressor_['regressor'].history)
    train_loss = np.zeros((history_length, 1))
    valid_loss = np.zeros((history_length, 1))
    for epoch in tsp.ttr.regressor_['regressor'].history:
        epoch_number = epoch['epoch']-1
        train_loss[epoch_number] = epoch['train_loss']
        valid_loss[epoch_number] = epoch['valid_loss']
    _, axes_one = plt.subplots(figsize=(20, 20))
    plt.plot(train_loss, 'o-', label='training')
    plt.plot(valid_loss, 'o-', label='validation')
    axes_one.set_xlabel('Epoch')
    axes_one.set_ylabel('MSE')
    plt.legend()
[Figure: training and validation loss (MSE) per epoch]
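
skorch's History object also supports column slicing, so the loop above can be shortened; an equivalent way to extract the per-epoch losses (same attribute path as above):

history = tsp.ttr.regressor_['regressor'].history
train_loss = np.asarray(history[:, 'train_loss'])
valid_loss = np.asarray(history[:, 'valid_loss'])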

Prediction on training data

[6]:
if config['predict on training data enabled']:
    # Select training example
    idx = np.random.randint(0, len(tsp.dataset))
    dataloader = tsp.ttr.regressor['regressor'].get_iterator(tsp.dataset)
    x, y = dataloader.dataset[idx]

    # Run predictions
    netout = tsp.sample_predict(x)

    d_output = netout.shape[1]
    if config['plot']['prediction on training data']:
        fig, axs = plt.subplots(d_output, 1, figsize=(20,20))
    for idx_output_var in range(d_output):
        # Ground truth and prediction for this output variable
        y_true = y[:, idx_output_var]
        y_pred = netout[:, idx_output_var]

        if config['plot']['prediction on training data']:
            ax = axs[idx_output_var]
            ax.plot(y_true, label="Truth")
            ax.plot(y_pred, label="Prediction")
            ax.set_title(tsp.dataset.labels['X'][idx_output_var])
            ax.legend()
[Figure: truth vs. prediction for each output variable on the selected training sample]
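
Beyond the visual comparison, a quick numeric check on the same training sample can be useful. A minimal sketch, assuming y and netout behave like NumPy arrays of shape (seq_len, d_output) and that tsp.dataset.labels['X'] is a sequence of output-variable names, as the ax.set_title call above suggests:

if config['predict on training data enabled']:
    y_true_all = np.asarray(y)
    y_pred_all = np.asarray(netout)
    # Mean squared error per output variable over the whole sequence
    mse_per_output = ((y_true_all - y_pred_all) ** 2).mean(axis=0)
    for label, mse in zip(tsp.dataset.labels['X'][:d_output], mse_per_output):
        print(f'{label}: MSE = {mse:.6f}')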

Future forecast

[7]:
# Run forecast
if config['forecast enabled']:
    netout, _ = tsp.forecast(
        config['forecast']['steps ahead'],
        include_history=config['forecast']['include history']
    )

    d_output = netout.shape[-1]
    # Select any training example just for comparison
    idx = np.random.randint(0, len(ds))
    dataloader = tsp.ttr.regressor['regressor'].get_iterator(tsp.dataset)
    x, y = dataloader.dataset[idx]
    if config['plot']['forecast']:
        fig, axs = plt.subplots(d_output, 1, figsize=(20,20))
    for idx_output_var in range(d_output):
        # Ground truth for this output variable
        y_true = y[:, idx_output_var]

        y_pred = netout[idx, :, idx_output_var]

        if config['plot']['forecast']:
            ax = axs[idx_output_var]
            if config['forecast']['include history']:
                plot_args = [y_pred]
            else:
                y_pred_index = [i+tsp.dataset.get_x_shape()[1]+1 for i in range(len(y_pred))]
                plot_args = [y_pred_index, y_pred]
            ax.plot(y_true, label="Truth")
            ax.plot(*plot_args, label="Forecast")
            ax.set_title(tsp.dataset.labels['X'][idx_output_var])
            ax.legend()
[Figure: truth vs. forecast for each output variable on the selected example]
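
For a tabular view of the forecast instead of plots, the selected example can be wrapped in a pandas DataFrame. A sketch, assuming netout behaves like a NumPy array of shape (n_samples, horizon, d_output) and that the 'X' labels name the output variables, as the plotting code above assumes:

import pandas as pd

if config['forecast enabled']:
    forecast_df = pd.DataFrame(
        np.asarray(netout)[idx],
        columns=list(tsp.dataset.labels['X'][:d_output]),
    )
    print(forecast_df.head())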