Flights passengers example

Imports

[1]:
import torch
import numpy as np
import calendar
from datetime import timedelta, datetime
from matplotlib import pyplot as plt

from time_series_predictor import TimeSeriesPredictor
from time_series_models import BenchmarkLSTM
from flights_time_series_dataset import FlightsDataset, convert_year_month_array_to_datetime
/home/docs/checkouts/readthedocs.org/user_builds/time-series-predictor/envs/latest/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Config

[2]:
plot_config = {}
plot_config['training progress'] = True
plot_config['prediction on training data'] = True
plot_config['prediction on validation data'] = True
plot_config['forecast'] = True

forecast_config = {}
forecast_config['include history'] = True
forecast_config['months ahead'] = 24

predictor_config = {}
predictor_config['epochs'] = 1000
predictor_config['learning rate'] = 1e-2
predictor_config['hidden dim'] = 100
predictor_config['layers num'] = 3

config = {}
config['plot'] = plot_config
config['forecast'] = forecast_config
config['predictor'] = predictor_config
config['predict on training data enabled'] = True
config['forecast enabled'] = True

Time Series Predictor instantiation

[3]:
tsp = TimeSeriesPredictor(
    BenchmarkLSTM(
        hidden_dim=config['predictor']['hidden dim'],
        num_layers=config['predictor']['layers num']
    ),
    lr=config['predictor']['learning rate'],
    max_epochs=predictor_config['epochs'],
    train_split=None,
    optimizer=torch.optim.Adam
)

Training process

[4]:
ds = FlightsDataset()
tsp.fit(ds)
Using device cpu
Re-initializing module because the following parameters were re-set: module__input_dim, module__output_dim.
Re-initializing criterion.
Re-initializing optimizer.
  epoch    train_loss    cp     dur
-------  ------------  ----  ------
      1        0.2249     +  0.1329
      2        0.0488     +  0.1030
      3        2.3266        0.1009
      4        0.1018        0.0978
      5        0.1441        0.0986
      6        0.1413        0.1018
      7        0.1307        0.1051
      8        0.1174        0.1047
      9        0.1005        0.1028
     10        0.0778        0.0970
     11        0.0520        0.1031
     12        0.0604        0.1017
     13        0.0643        0.0947
     14        0.0490        0.0985
     15        0.0463     +  0.0981
     16        0.0500        0.0967
     17        0.0528        0.0953
     18        0.0528        0.0958
     19        0.0501        0.0992
     20        0.0458     +  0.0963
     21        0.0415     +  0.0953
     22        0.0388     +  0.0954
     23        0.0393        0.0924
     24        0.0397        0.0912
     25        0.0290     +  0.0930
     26        0.0246     +  0.0937
     27        0.0211     +  0.0944
     28        0.0139     +  0.0925
     29        0.0140        0.0931
     30        0.0133     +  0.0918
     31        0.0136        0.0935
     32        0.0162        0.0921
     33        0.0147        0.0945
     34        0.0148        0.0928
     35        0.0171        0.0934
     36        0.0159        0.0933
     37        0.0155        0.0946
     38        0.0167        0.0921
     39        0.0160        0.0918
     40        0.0145        0.0930
     41        0.0148        0.0928
     42        0.0145        0.0921
     43        0.0134        0.0938
     44        0.0130     +  0.0924
     45        0.0132        0.0936
     46        0.0122     +  0.0923
     47        0.0115     +  0.0927
     48        0.0116        0.0953
     49        0.0112     +  0.0937
     50        0.0107     +  0.0926
     51        0.0110        0.0911
     52        0.0110        0.0922
     53        0.0105     +  0.0933
     54        0.0104     +  0.0920
     55        0.0104     +  0.0927
     56        0.0098     +  0.0937
     57        0.0098     +  0.0933
     58        0.0093     +  0.0929
     59        0.0089     +  0.0921
     60        0.0086     +  0.0928
     61        0.0082     +  0.0925
     62        0.0081     +  0.0920
     63        0.0078     +  0.0943
     64        0.0078        0.0919
     65        0.0075     +  0.0921
     66        0.0075     +  0.0928
     67        0.0073     +  0.0929
     68        0.0074        0.0925
     69        0.0072     +  0.0931
     70        0.0072        0.0925
     71        0.0071     +  0.0937
     72        0.0069     +  0.0936
     73        0.0069     +  0.0943
     74        0.0066     +  0.0930
     75        0.0065     +  0.0932
     76        0.0063     +  0.0915
     77        0.0062     +  0.0924
     78        0.0060     +  0.0928
     79        0.0058     +  0.0917
     80        0.0056     +  0.0933
     81        0.0053     +  0.0934
     82        0.0050     +  0.0933
     83        0.0045     +  0.0923
     84        0.0040     +  0.0927
     85        0.0035     +  0.0933
     86        0.0035     +  0.0939
     87        0.0041        0.0946
     88        0.0039        0.0937
     89        0.0035     +  0.0924
     90        0.0038        0.0924
     91        0.0033     +  0.0927
     92        0.0033     +  0.0929
     93        0.0033        0.0929
     94        0.0030     +  0.0913
     95        0.0031        0.0943
     96        0.0028     +  0.0938
     97        0.0029        0.0934
     98        0.0028     +  0.0925
     99        0.0029        0.0928
    100        0.0028        0.0929
    101        0.0028        0.0935
    102        0.0027     +  0.0946
    103        0.0026     +  0.0925
    104        0.0026     +  0.0931
    105        0.0025     +  0.0935
    106        0.0026        0.0947
    107        0.0025     +  0.0954
    108        0.0025        0.0939
    109        0.0024     +  0.0925
    110        0.0024     +  0.0935
    111        0.0024     +  0.0929
    112        0.0023     +  0.0940
    113        0.0023     +  0.0921
    114        0.0023     +  0.0927
    115        0.0023        0.0934
    116        0.0022     +  0.0934
    117        0.0022     +  0.0927
    118        0.0022     +  0.0929
    119        0.0021     +  0.0929
    120        0.0021     +  0.0935
    121        0.0021     +  0.0938
    122        0.0021     +  0.0938
    123        0.0020     +  0.0931
    124        0.0020     +  0.0932
    125        0.0020     +  0.0933
    126        0.0020     +  0.0946
    127        0.0019     +  0.0927
    128        0.0019     +  0.0925
    129        0.0019     +  0.0925
    130        0.0018     +  0.0943
    131        0.0018     +  0.0944
    132        0.0018     +  0.0931
    133        0.0018     +  0.0920
    134        0.0017     +  0.0935
    135        0.0017     +  0.0924
    136        0.0017     +  0.0925
    137        0.0017     +  0.0911
    138        0.0016     +  0.0913
    139        0.0016     +  0.0930
    140        0.0016     +  0.0932
    141        0.0015     +  0.0928
    142        0.0015     +  0.0925
    143        0.0015     +  0.0927
    144        0.0015     +  0.0933
    145        0.0014     +  0.0947
    146        0.0014     +  0.0933
    147        0.0014     +  0.0931
    148        0.0014        0.0926
    149        0.0016        0.0913
    150        0.0020        0.0925
    151        0.0035        0.0921
    152        0.0023        0.0926
    153        0.0015        0.0943
    154        0.0014        0.0924
    155        0.0018        0.0921
    156        0.0017        0.0928
    157        0.0012     +  0.0924
    158        0.0018        0.0933
    159        0.0020        0.0929
    160        0.0012     +  0.0940
    161        0.0020        0.0914
    162        0.0019        0.0916
    163        0.0013        0.0938
    164        0.0022        0.0927
    165        0.0014        0.0948
    166        0.0016        0.0920
    167        0.0016        0.0927
    168        0.0011     +  0.0919
    169        0.0015        0.0928
    170        0.0011     +  0.0915
    171        0.0015        0.0960
    172        0.0013        0.0914
    173        0.0012        0.0926
    174        0.0013        0.0935
    175        0.0011     +  0.0930
    176        0.0013        0.0922
    177        0.0011     +  0.0911
    178        0.0012        0.0913
    179        0.0011        0.0916
    180        0.0011        0.0924
    181        0.0011        0.0935
    182        0.0010     +  0.0917
    183        0.0011        0.0935
    184        0.0010     +  0.0918
    185        0.0011        0.0945
    186        0.0010     +  0.0939
    187        0.0010        0.0940
    188        0.0010        0.0918
    189        0.0010     +  0.0927
    190        0.0010        0.0930
    191        0.0010     +  0.0935
    192        0.0010        0.0925
    193        0.0010     +  0.0918
    194        0.0010        0.0924
    195        0.0010        0.0938
    196        0.0009     +  0.0918
    197        0.0010        0.0920
    198        0.0009     +  0.0915
    199        0.0009        0.0932
    200        0.0009        0.0920
    201        0.0009     +  0.0938
    202        0.0009        0.0921
    203        0.0009     +  0.0924
    204        0.0009        0.0925
    205        0.0009     +  0.0948
    206        0.0009     +  0.0949
    207        0.0009        0.0953
    208        0.0009     +  0.0932
    209        0.0009     +  0.0925
    210        0.0009        0.0928
    211        0.0009     +  0.0937
    212        0.0009     +  0.0933
    213        0.0009     +  0.0931
    214        0.0008     +  0.0929
    215        0.0008     +  0.0916
    216        0.0008     +  0.0921
    217        0.0008     +  0.0930
    218        0.0008     +  0.0944
    219        0.0008     +  0.0929
    220        0.0008     +  0.0943
    221        0.0008     +  0.0940
    222        0.0008     +  0.0939
    223        0.0008     +  0.0929
    224        0.0008     +  0.0950
    225        0.0008     +  0.0929
    226        0.0008     +  0.0929
    227        0.0008     +  0.0920
    228        0.0008     +  0.0943
    229        0.0008     +  0.0938
    230        0.0008     +  0.0934
    231        0.0008     +  0.0935
    232        0.0008     +  0.0939
    233        0.0008     +  0.0937
    234        0.0008     +  0.0924
    235        0.0008     +  0.0924
    236        0.0008     +  0.0943
    237        0.0008     +  0.0927
    238        0.0008     +  0.0942
    239        0.0008     +  0.0938
    240        0.0008     +  0.0937
    241        0.0008     +  0.0934
    242        0.0007     +  0.0952
    243        0.0007     +  0.0943
    244        0.0007     +  0.0945
    245        0.0007     +  0.0925
    246        0.0007     +  0.0932
    247        0.0007     +  0.0924
    248        0.0007     +  0.0938
    249        0.0007     +  0.0937
    250        0.0007        0.0950
    251        0.0007        0.0980
    252        0.0008        0.0919
    253        0.0009        0.0923
    254        0.0015        0.0914
    255        0.0020        0.0926
    256        0.0029        0.0928
    257        0.0009        0.0928
    258        0.0020        0.0932
    259        0.0028        0.0932
    260        0.0010        0.0941
    261        0.0034        0.0935
    262        0.0020        0.0946
    263        0.0020        0.0933
    264        0.0020        0.0944
    265        0.0011        0.0939
    266        0.0018        0.0928
    267        0.0011        0.0923
    268        0.0017        0.0943
    269        0.0009        0.0930
    270        0.0014        0.0934
    271        0.0009        0.0933
    272        0.0013        0.0936
    273        0.0009        0.0929
    274        0.0012        0.0934
    275        0.0009        0.0923
    276        0.0010        0.0942
    277        0.0009        0.0921
    278        0.0010        0.0944
    279        0.0009        0.0945
    280        0.0009        0.0938
    281        0.0009        0.0931
    282        0.0008        0.0929
    283        0.0009        0.0943
    284        0.0008        0.0936
    285        0.0009        0.0940
    286        0.0008        0.0928
    287        0.0008        0.0938
    288        0.0007        0.0925
    289        0.0008        0.0940
    290        0.0007     +  0.0939
    291        0.0008        0.0944
    292        0.0007     +  0.0939
    293        0.0008        0.0938
    294        0.0007     +  0.0923
    295        0.0007        0.0933
    296        0.0007     +  0.0927
    297        0.0007        0.0931
    298        0.0007     +  0.0920
    299        0.0007        0.0943
    300        0.0007     +  0.0927
    301        0.0007        0.0945
    302        0.0007     +  0.0930
    303        0.0007        0.0939
    304        0.0007     +  0.0919
    305        0.0007        0.0940
    306        0.0007     +  0.0933
    307        0.0007        0.0924
    308        0.0006     +  0.0930
    309        0.0007        0.0932
    310        0.0006     +  0.0921
    311        0.0006        0.0930
    312        0.0006     +  0.0937
    313        0.0006        0.0926
    314        0.0006     +  0.0922
    315        0.0006        0.0925
    316        0.0006     +  0.0932
    317        0.0006     +  0.0946
    318        0.0006     +  0.0951
    319        0.0006     +  0.0930
    320        0.0006     +  0.0944
    321        0.0006     +  0.0934
    322        0.0006     +  0.0947
    323        0.0006     +  0.0927
    324        0.0006        0.0943
    325        0.0006     +  0.0916
    326        0.0006     +  0.0981
    327        0.0006     +  0.0925
    328        0.0006     +  0.0928
    329        0.0006     +  0.0923
    330        0.0006     +  0.0939
    331        0.0006     +  0.0926
    332        0.0006     +  0.0920
    333        0.0006     +  0.0920
    334        0.0006     +  0.0909
    335        0.0006     +  0.0909
    336        0.0006     +  0.0916
    337        0.0006     +  0.0922
    338        0.0006     +  0.0928
    339        0.0006     +  0.0914
    340        0.0006     +  0.0932
    341        0.0005     +  0.0917
    342        0.0005     +  0.0933
    343        0.0005     +  0.0911
    344        0.0005     +  0.0908
    345        0.0005     +  0.0911
    346        0.0005        0.0936
    347        0.0005        0.0928
    348        0.0005        0.0944
    349        0.0005     +  0.0931
    350        0.0005        0.0914
    351        0.0005        0.0913
    352        0.0006        0.0912
    353        0.0005        0.0925
    354        0.0005        0.0920
    355        0.0005        0.0924
    356        0.0005        0.0925
    357        0.0005     +  0.0914
    358        0.0005     +  0.0926
    359        0.0005     +  0.0935
    360        0.0005     +  0.0952
    361        0.0005     +  0.0939
    362        0.0005     +  0.0944
    363        0.0005     +  0.0925
    364        0.0005     +  0.0926
    365        0.0005     +  0.0926
    366        0.0005     +  0.0918
    367        0.0005     +  0.0929
    368        0.0005     +  0.0922
    369        0.0005     +  0.0918
    370        0.0005     +  0.0928
    371        0.0005     +  0.0929
    372        0.0005     +  0.0926
    373        0.0005     +  0.0933
    374        0.0005     +  0.0944
    375        0.0005        0.0934
    376        0.0005        0.0922
    377        0.0006        0.0918
    378        0.0009        0.0931
    379        0.0013        0.0936
    380        0.0011        0.0931
    381        0.0007        0.0924
    382        0.0008        0.0929
    383        0.0008        0.0908
    384        0.0007        0.0949
    385        0.0009        0.0919
    386        0.0006        0.0920
    387        0.0008        0.0921
    388        0.0007        0.0911
    389        0.0006        0.0912
    390        0.0007        0.0908
    391        0.0006        0.0909
    392        0.0007        0.0915
    393        0.0006        0.0911
    394        0.0006        0.0917
    395        0.0006        0.0936
    396        0.0006        0.0918
    397        0.0006        0.0931
    398        0.0006        0.0925
    399        0.0006        0.0942
    400        0.0006        0.0939
    401        0.0006        0.0930
    402        0.0006        0.0925
    403        0.0006        0.0920
    404        0.0006        0.0930
    405        0.0005        0.0906
    406        0.0005        0.0935
    407        0.0005        0.0919
    408        0.0005        0.0923
    409        0.0005        0.0939
    410        0.0005        0.0935
    411        0.0005        0.0917
    412        0.0005        0.0921
    413        0.0005        0.0923
    414        0.0005        0.0923
    415        0.0005        0.0929
    416        0.0005        0.0943
    417        0.0005        0.0919
    418        0.0005        0.0941
    419        0.0005        0.0924
    420        0.0005     +  0.0944
    421        0.0005     +  0.0945
    422        0.0005     +  0.0944
    423        0.0005     +  0.0917
    424        0.0005     +  0.0925
    425        0.0004     +  0.0917
    426        0.0004     +  0.0931
    427        0.0004     +  0.0924
    428        0.0004     +  0.0932
    429        0.0004     +  0.0934
    430        0.0004     +  0.0922
    431        0.0004     +  0.0917
    432        0.0004     +  0.0920
    433        0.0004     +  0.0920
    434        0.0004     +  0.0942
    435        0.0004     +  0.0926
    436        0.0004     +  0.0948
    437        0.0004     +  0.0934
    438        0.0004     +  0.0947
    439        0.0004     +  0.0939
    440        0.0004     +  0.0938
    441        0.0004     +  0.0933
    442        0.0004     +  0.0928
    443        0.0004     +  0.0948
    444        0.0004     +  0.0932
    445        0.0004     +  0.0927
    446        0.0004     +  0.0917
    447        0.0004     +  0.0923
    448        0.0004     +  0.0920
    449        0.0004     +  0.0918
    450        0.0004     +  0.0925
    451        0.0004     +  0.0923
    452        0.0004     +  0.0948
    453        0.0004     +  0.0931
    454        0.0004     +  0.0936
    455        0.0004     +  0.0933
    456        0.0004     +  0.0941
    457        0.0004     +  0.0937
    458        0.0004     +  0.0936
    459        0.0004     +  0.0920
    460        0.0004     +  0.0930
    461        0.0004     +  0.0942
    462        0.0004     +  0.0927
    463        0.0004     +  0.0947
    464        0.0004     +  0.0939
    465        0.0004     +  0.0929
    466        0.0003     +  0.0922
    467        0.0003     +  0.0924
    468        0.0003     +  0.0921
    469        0.0003     +  0.1064
    470        0.0003     +  0.0947
    471        0.0003     +  0.0981
    472        0.0003     +  0.0936
    473        0.0003     +  0.0952
    474        0.0003     +  0.0927
    475        0.0003     +  0.0931
    476        0.0003     +  0.0925
    477        0.0003     +  0.0920
    478        0.0003     +  0.0938
    479        0.0003     +  0.0926
    480        0.0003     +  0.0990
    481        0.0003     +  0.0941
    482        0.0003     +  0.0933
    483        0.0003     +  0.0929
    484        0.0003     +  0.0926
    485        0.0003     +  0.0915
    486        0.0003     +  0.0923
    487        0.0003     +  0.0918
    488        0.0003     +  0.0931
    489        0.0003     +  0.0934
    490        0.0003     +  0.0941
    491        0.0003     +  0.0915
    492        0.0003     +  0.0928
    493        0.0003     +  0.0931
    494        0.0003     +  0.0913
    495        0.0003     +  0.0939
    496        0.0003     +  0.0936
    497        0.0003     +  0.0929
    498        0.0003     +  0.0938
    499        0.0003     +  0.0920
    500        0.0003     +  0.0927
    501        0.0003     +  0.0928
    502        0.0003        0.0937
    503        0.0003        0.0927
    504        0.0003        0.0948
    505        0.0003        0.0923
    506        0.0005        0.0934
    507        0.0007        0.0924
    508        0.0015        0.0938
    509        0.0008        0.0944
    510        0.0003        0.0935
    511        0.0007        0.0940
    512        0.0010        0.0922
    513        0.0011        0.0916
    514        0.0004        0.0921
    515        0.0008        0.0943
    516        0.0010        0.0915
    517        0.0003        0.0932
    518        0.0009        0.0926
    519        0.0018        0.0941
    520        0.0004        0.0925
    521        0.0016        0.0929
    522        0.0017        0.0915
    523        0.0007        0.0918
    524        0.0023        0.0912
    525        0.0011        0.0932
    526        0.0012        0.0910
    527        0.0013        0.0932
    528        0.0005        0.0933
    529        0.0012        0.0923
    530        0.0004        0.0937
    531        0.0012        0.0922
    532        0.0004        0.0922
    533        0.0009        0.0912
    534        0.0004        0.0920
    535        0.0007        0.0898
    536        0.0005        0.0932
    537        0.0006        0.0911
    538        0.0005        0.0923
    539        0.0004        0.0943
    540        0.0005        0.0922
    541        0.0003        0.0918
    542        0.0006        0.0914
    543        0.0003        0.0912
    544        0.0005        0.0919
    545        0.0003        0.0913
    546        0.0004        0.0931
    547        0.0004        0.0922
    548        0.0004        0.0921
    549        0.0004        0.0937
    550        0.0003        0.0942
    551        0.0004        0.0930
    552        0.0003        0.0940
    553        0.0004        0.0929
    554        0.0003        0.0929
    555        0.0003        0.0931
    556        0.0003        0.0931
    557        0.0003        0.0929
    558        0.0003        0.0918
    559        0.0003        0.0926
    560        0.0003        0.0925
    561        0.0003     +  0.0938
    562        0.0003        0.0919
    563        0.0003     +  0.0915
    564        0.0003        0.0919
    565        0.0003        0.0936
    566        0.0003        0.0930
    567        0.0003        0.0950
    568        0.0003     +  0.0932
    569        0.0003        0.0933
    570        0.0002     +  0.0924
    571        0.0003        0.0924
    572        0.0002        0.0931
    573        0.0002     +  0.0928
    574        0.0003        0.0925
    575        0.0002     +  0.0908
    576        0.0002        0.0924
    577        0.0002     +  0.0932
    578        0.0002        0.0929
    579        0.0002        0.0925
    580        0.0002     +  0.0927
    581        0.0002        0.0924
    582        0.0002     +  0.0926
    583        0.0002        0.0916
    584        0.0002     +  0.0925
    585        0.0002     +  0.0924
    586        0.0002        0.0931
    587        0.0002     +  0.0907
    588        0.0002        0.0939
    589        0.0002     +  0.0912
    590        0.0002     +  0.0924
    591        0.0002        0.0924
    592        0.0002     +  0.0934
    593        0.0002     +  0.0935
    594        0.0002        0.0923
    595        0.0002     +  0.0928
    596        0.0002     +  0.0938
    597        0.0002     +  0.0927
    598        0.0002     +  0.0932
    599        0.0002     +  0.0930
    600        0.0002     +  0.0926
    601        0.0002     +  0.0926
    602        0.0002     +  0.0929
    603        0.0002     +  0.0929
    604        0.0002     +  0.0942
    605        0.0002     +  0.0931
    606        0.0002     +  0.0935
    607        0.0002     +  0.0939
    608        0.0002     +  0.0949
    609        0.0002     +  0.0935
    610        0.0002     +  0.0922
    611        0.0002     +  0.0927
    612        0.0002     +  0.0926
    613        0.0002     +  0.0927
    614        0.0002     +  0.0928
    615        0.0002     +  0.0922
    616        0.0002     +  0.0935
    617        0.0002     +  0.0923
    618        0.0002     +  0.0915
    619        0.0002     +  0.0917
    620        0.0002     +  0.0917
    621        0.0002     +  0.0931
    622        0.0002     +  0.0915
    623        0.0002     +  0.0920
    624        0.0002     +  0.0933
    625        0.0002     +  0.0932
    626        0.0002     +  0.0931
    627        0.0002     +  0.0929
    628        0.0002     +  0.0918
    629        0.0002     +  0.0917
    630        0.0002     +  0.0919
    631        0.0002     +  0.0920
    632        0.0002     +  0.0921
    633        0.0002     +  0.0931
    634        0.0002     +  0.0928
    635        0.0002     +  0.0924
    636        0.0002     +  0.0922
    637        0.0002     +  0.0921
    638        0.0002     +  0.0932
    639        0.0002     +  0.0918
    640        0.0002     +  0.0930
    641        0.0002     +  0.0940
    642        0.0002     +  0.0930
    643        0.0002     +  0.0936
    644        0.0002     +  0.0933
    645        0.0002     +  0.0930
    646        0.0002     +  0.0931
    647        0.0002     +  0.0920
    648        0.0002     +  0.0925
    649        0.0002     +  0.0919
    650        0.0002     +  0.0918
    651        0.0002     +  0.0928
    652        0.0002     +  0.0915
    653        0.0002     +  0.0916
    654        0.0002     +  0.0932
    655        0.0002     +  0.0933
    656        0.0002     +  0.0930
    657        0.0002     +  0.0922
    658        0.0002     +  0.0931
    659        0.0002     +  0.0927
    660        0.0002     +  0.0932
    661        0.0002     +  0.0919
    662        0.0002     +  0.0909
    663        0.0002     +  0.0931
    664        0.0002     +  0.0925
    665        0.0002     +  0.0938
    666        0.0002     +  0.0941
    667        0.0002     +  0.0924
    668        0.0002     +  0.0920
    669        0.0002     +  0.0936
    670        0.0002     +  0.0920
    671        0.0002     +  0.0937
    672        0.0002     +  0.0922
    673        0.0002     +  0.0933
    674        0.0002     +  0.0934
    675        0.0002     +  0.0950
    676        0.0002     +  0.0935
    677        0.0002     +  0.0930
    678        0.0002     +  0.0929
    679        0.0002     +  0.0928
    680        0.0002     +  0.0939
    681        0.0002     +  0.0935
    682        0.0002     +  0.0924
    683        0.0002     +  0.0927
    684        0.0002     +  0.0918
    685        0.0002     +  0.0925
    686        0.0002     +  0.0924
    687        0.0002     +  0.0914
    688        0.0002     +  0.0935
    689        0.0002     +  0.0915
    690        0.0002     +  0.0936
    691        0.0002     +  0.0935
    692        0.0002     +  0.0936
    693        0.0002     +  0.0935
    694        0.0002     +  0.0927
    695        0.0002     +  0.0931
    696        0.0002     +  0.0916
    697        0.0002     +  0.0922
    698        0.0002     +  0.0919
    699        0.0002     +  0.0936
    700        0.0002     +  0.0935
    701        0.0002     +  0.0926
    702        0.0002     +  0.0930
    703        0.0002     +  0.0935
    704        0.0002     +  0.0936
    705        0.0002     +  0.0929
    706        0.0002     +  0.0936
    707        0.0002     +  0.0917
    708        0.0002     +  0.0935
    709        0.0002     +  0.0933
    710        0.0002     +  0.0931
    711        0.0002     +  0.0925
    712        0.0002     +  0.0922
    713        0.0002     +  0.0931
    714        0.0002     +  0.0925
    715        0.0002     +  0.0935
    716        0.0002     +  0.0935
    717        0.0002     +  0.0911
    718        0.0002     +  0.0926
    719        0.0002     +  0.0921
    720        0.0002     +  0.0919
    721        0.0002     +  0.0925
    722        0.0002     +  0.0928
    723        0.0002     +  0.0930
    724        0.0002     +  0.0926
    725        0.0002     +  0.0946
    726        0.0002     +  0.0917
    727        0.0002     +  0.0937
    728        0.0002     +  0.0918
    729        0.0002     +  0.0925
    730        0.0002     +  0.0928
    731        0.0002     +  0.0923
    732        0.0002     +  0.0936
    733        0.0002     +  0.0931
    734        0.0002     +  0.0927
    735        0.0002     +  0.0926
    736        0.0002     +  0.0927
    737        0.0002     +  0.0928
    738        0.0002     +  0.0935
    739        0.0002     +  0.0921
    740        0.0002     +  0.0936
    741        0.0002     +  0.0932
    742        0.0002     +  0.0939
    743        0.0002     +  0.0938
    744        0.0002     +  0.0926
    745        0.0002     +  0.0919
    746        0.0002     +  0.0943
    747        0.0002     +  0.0932
    748        0.0002     +  0.0936
    749        0.0002     +  0.0938
    750        0.0002     +  0.0927
    751        0.0002     +  0.0928
    752        0.0002     +  0.0928
    753        0.0002     +  0.0922
    754        0.0002     +  0.0942
    755        0.0002     +  0.0928
    756        0.0002     +  0.0945
    757        0.0001     +  0.0936
    758        0.0001     +  0.0950
    759        0.0001     +  0.0938
    760        0.0001     +  0.0929
    761        0.0001     +  0.0930
    762        0.0001     +  0.0930
    763        0.0001     +  0.0930
    764        0.0001     +  0.0927
    765        0.0001     +  0.0929
    766        0.0001     +  0.0926
    767        0.0001     +  0.0924
    768        0.0001     +  0.0925
    769        0.0001     +  0.0944
    770        0.0001     +  0.0935
    771        0.0001     +  0.0939
    772        0.0001     +  0.0928
    773        0.0001     +  0.0942
    774        0.0001     +  0.0931
    775        0.0001     +  0.0928
    776        0.0001     +  0.0927
    777        0.0001     +  0.0934
    778        0.0001     +  0.0926
    779        0.0001     +  0.0944
    780        0.0001     +  0.0925
    781        0.0001     +  0.0942
    782        0.0001     +  0.0937
    783        0.0001     +  0.0937
    784        0.0001     +  0.0926
    785        0.0001     +  0.0930
    786        0.0001     +  0.0936
    787        0.0001     +  0.0952
    788        0.0001     +  0.0932
    789        0.0001     +  0.0933
    790        0.0001     +  0.0929
    791        0.0001     +  0.0946
    792        0.0001     +  0.0940
    793        0.0001     +  0.0935
    794        0.0001     +  0.0915
    795        0.0001     +  0.0933
    796        0.0001     +  0.0922
    797        0.0001     +  0.0944
    798        0.0001     +  0.0929
    799        0.0001     +  0.0914
    800        0.0001     +  0.0914
    801        0.0001     +  0.0940
    802        0.0001     +  0.0923
    803        0.0001     +  0.0941
    804        0.0001     +  0.0931
    805        0.0001     +  0.0932
    806        0.0001     +  0.0932
    807        0.0001     +  0.0940
    808        0.0001     +  0.0924
    809        0.0001     +  0.0933
    810        0.0001     +  0.0933
    811        0.0001     +  0.0927
    812        0.0001     +  0.0919
    813        0.0001     +  0.0935
    814        0.0001     +  0.0928
    815        0.0001     +  0.0916
    816        0.0001     +  0.0925
    817        0.0001     +  0.0924
    818        0.0001     +  0.0919
    819        0.0001     +  0.0937
    820        0.0001     +  0.0933
    821        0.0001     +  0.0942
    822        0.0001     +  0.0943
    823        0.0001     +  0.0944
    824        0.0001     +  0.0938
    825        0.0001     +  0.0936
    826        0.0001     +  0.0920
    827        0.0001     +  0.0924
    828        0.0001     +  0.0937
    829        0.0001     +  0.0927
    830        0.0001     +  0.0933
    831        0.0001     +  0.0919
    832        0.0001     +  0.0920
    833        0.0001     +  0.0932
    834        0.0001        0.0926
    835        0.0001        0.0927
    836        0.0001        0.0949
    837        0.0001        0.0930
    838        0.0002        0.0938
    839        0.0002        0.0923
    840        0.0003        0.0934
    841        0.0003        0.0937
    842        0.0006        0.0924
    843        0.0006        0.0936
    844        0.0006        0.0941
    845        0.0002        0.0918
    846        0.0002        0.0921
    847        0.0004        0.0948
    848        0.0004        0.0930
    849        0.0004        0.0935
    850        0.0002        0.0924
    851        0.0002        0.0927
    852        0.0004        0.0923
    853        0.0004        0.0915
    854        0.0003        0.0928
    855        0.0002        0.0927
    856        0.0003        0.0923
    857        0.0004        0.0942
    858        0.0003        0.0933
    859        0.0001        0.0942
    860        0.0002        0.0912
    861        0.0002        0.0931
    862        0.0002        0.0935
    863        0.0001        0.0936
    864        0.0002        0.0928
    865        0.0002        0.0920
    866        0.0001        0.0920
    867        0.0001        0.0931
    868        0.0002        0.0948
    869        0.0002        0.0940
    870        0.0001        0.0935
    871        0.0001        0.0928
    872        0.0002        0.0939
    873        0.0002        0.0923
    874        0.0001     +  0.0912
    875        0.0001        0.0939
    876        0.0001        0.0927
    877        0.0001        0.0914
    878        0.0001     +  0.0929
    879        0.0001        0.0927
    880        0.0001        0.0938
    881        0.0001        0.0935
    882        0.0001     +  0.0935
    883        0.0001        0.0926
    884        0.0001        0.0926
    885        0.0001        0.0930
    886        0.0001     +  0.0919
    887        0.0001        0.0928
    888        0.0001        0.0941
    889        0.0001     +  0.0929
    890        0.0001        0.0962
    891        0.0001        0.0940
    892        0.0001        0.0918
    893        0.0001     +  0.0913
    894        0.0001        0.0926
    895        0.0001        0.1027
    896        0.0001        0.0925
    897        0.0001     +  0.0934
    898        0.0001        0.0924
    899        0.0001        0.0934
    900        0.0001        0.0934
    901        0.0001     +  0.0932
    902        0.0001        0.0948
    903        0.0001        0.0929
    904        0.0001     +  0.0934
    905        0.0001     +  0.0933
    906        0.0001     +  0.0927
    907        0.0001        0.0939
    908        0.0001     +  0.0921
    909        0.0001     +  0.0937
    910        0.0001     +  0.0941
    911        0.0001     +  0.0929
    912        0.0001     +  0.0920
    913        0.0001     +  0.0930
    914        0.0001     +  0.0941
    915        0.0001     +  0.0933
    916        0.0001     +  0.0944
    917        0.0001     +  0.0933
    918        0.0001     +  0.0943
    919        0.0001     +  0.0938
    920        0.0001     +  0.0932
    921        0.0001     +  0.0941
    922        0.0001     +  0.0939
    923        0.0001     +  0.0932
    924        0.0001     +  0.0927
    925        0.0001     +  0.0939
    926        0.0001     +  0.0922
    927        0.0001     +  0.0944
    928        0.0001     +  0.0932
    929        0.0001     +  0.0922
    930        0.0001     +  0.0931
    931        0.0001     +  0.0929
    932        0.0001     +  0.0941
    933        0.0001     +  0.0942
    934        0.0001     +  0.0952
    935        0.0001     +  0.0943
    936        0.0001     +  0.0943
    937        0.0001     +  0.0932
    938        0.0001     +  0.0935
    939        0.0001     +  0.0929
    940        0.0001     +  0.0933
    941        0.0001     +  0.0937
    942        0.0001     +  0.0924
    943        0.0001     +  0.0924
    944        0.0001     +  0.0935
    945        0.0001     +  0.0925
    946        0.0001     +  0.0951
    947        0.0001     +  0.0929
    948        0.0001     +  0.0939
    949        0.0001     +  0.0946
    950        0.0001     +  0.0943
    951        0.0001     +  0.0942
    952        0.0001     +  0.0920
    953        0.0001     +  0.0924
    954        0.0001     +  0.0941
    955        0.0001     +  0.0926
    956        0.0001     +  0.0938
    957        0.0001     +  0.0941
    958        0.0001     +  0.0932
    959        0.0001        0.0929
    960        0.0001        0.0944
    961        0.0001        0.0926
    962        0.0001        0.0936
    963        0.0001        0.0946
    964        0.0001        0.0929
    965        0.0002        0.0938
    966        0.0003        0.0933
    967        0.0006        0.0938
    968        0.0006        0.0933
    969        0.0006        0.0936
    970        0.0003        0.0928
    971        0.0002        0.0919
    972        0.0003        0.0928
    973        0.0003        0.0935
    974        0.0002        0.0918
    975        0.0002        0.0928
    976        0.0002        0.0934
    977        0.0002        0.0933
    978        0.0002        0.0929
    979        0.0002        0.0925
    980        0.0002        0.0936
    981        0.0002        0.0931
    982        0.0002        0.0929
    983        0.0001        0.0942
    984        0.0002        0.0938
    985        0.0002        0.0926
    986        0.0001        0.0935
    987        0.0001        0.0929
    988        0.0001        0.0943
    989        0.0002        0.0939
    990        0.0001        0.0935
    991        0.0001        0.0938
    992        0.0001        0.0922
    993        0.0001        0.0921
    994        0.0001        0.0935
    995        0.0001        0.0925
    996        0.0001        0.0937
    997        0.0001        0.0947
    998        0.0001        0.0929
    999        0.0001        0.0927
   1000        0.0001        0.0933
Loading the best network from the last checkpoint.
[4]:
<skorch.callbacks.training.Checkpoint at 0x7f002b99b910>

Plot training evolution

[5]:
if config['plot']['training progress']:
    history_length = len(tsp.ttr.regressor_['regressor'].history)
    train_loss = np.zeros((history_length, 1))
    for epoch in tsp.ttr.regressor_['regressor'].history:
        epoch_number = epoch['epoch']-1
        train_loss[epoch_number] = epoch['train_loss']
    _, axes_one = plt.subplots(figsize=(20, 20))
    axes_one.plot(train_loss, 'o-', label='training')
    axes_one.set_xlabel('Epoch')
    axes_one.set_ylabel('MSE')
    plt.legend()
../_images/notebooks_example_flights_dataset_9_0.png

Prediction on training data

[6]:
if config['predict on training data enabled']:
    dataloader = tsp.ttr.regressor['regressor'].get_iterator(tsp.dataset)
    x, y = dataloader.dataset[:]
    netout = tsp.predict(x)
    d_output = netout.shape[-1]
    if config['plot']['prediction on training data']:
        fig, axs = plt.subplots(d_output, 1, figsize=(20,20))
        axs = [axs]
    idx_range = [np.random.randint(0, len(tsp.dataset))]
    for idx in idx_range:
        if config['plot']['prediction on training data']:
            x_absciss = convert_year_month_array_to_datetime(x[idx, :, :])
        for idx_output_var in range(d_output):
            # Select real passengers data
            y_true = y[idx, :, idx_output_var]

            y_pred = netout[idx, :, idx_output_var]
            if config['plot']['prediction on training data']:
                ax = axs[idx_output_var]
                ax.plot(x_absciss, y_true, label="Truth", color='tab:blue')
                ax.plot(x_absciss, y_pred, label="Prediction", color='tab:orange')
                if idx == idx_range[0]:
                    ax.set_title(tsp.dataset.labels['y'][idx_output_var] + ' over time')
                    ax.set_xlabel('date')
                    ax.set_ylabel(tsp.dataset.labels['y'][idx_output_var])
                    ax.legend()
    if config['plot']['prediction on training data']:
        plt.show()
../_images/notebooks_example_flights_dataset_11_0.png

Future forecast

[7]:
# def make_x_pred(x, n_months_ahead, include_history=True):
#     def raw_add_months(sourcedate, n_months):
#         month = sourcedate.month - 1 + n_months
#         year = sourcedate.year + month // 12
#         month = month % 12 + 1
#         day = min(sourcedate.day, calendar.monthrange(year,month)[1])
#         return datetime(year, month, day)
#     def add_months(months, n_months):
#         return [raw_add_months(month, n_months) for month in months]
#     def add_months_2(month, n_months):
#         return [raw_add_months(month, n_month) for n_month in n_months]
#     x_dates = convert_year_month_array_to_datetime(x.squeeze())
#     last_month = x_dates[-1:][0]
#     next_n_months = add_months_2(last_month, np.arange(n_months_ahead)+1)
#     if include_history:
#         n_months = x_dates + next_n_months
#     else:
#         n_months = next_n_months
#     x_pred = np.array([[dt.month, dt.year] for dt in n_months]).reshape(1, -1, 2).astype(np.float32)
#     return x_pred

if config['forecast enabled']:
    dataloader = tsp.ttr.regressor['regressor'].get_iterator(tsp.dataset)
    x, y = dataloader.dataset[:]
    # x_pred = make_x_pred(x, config['forecast']['months ahead'])
    # netout = tsp.predict(x_pred)
    netout, x_pred = tsp.forecast(config['forecast']['months ahead'], include_history = config['forecast']['include history'])
    d_output = netout.shape[-1]
    if config['plot']['forecast']:
        fig, axs = plt.subplots(d_output, 1, figsize=(20,20))
        axs = [axs]
    idx_range = [np.random.randint(0, len(tsp.dataset))]
    for idx in idx_range:
        if config['plot']['forecast']:
            x_absciss = convert_year_month_array_to_datetime(x[idx, :, :])
            x_pred_absciss = convert_year_month_array_to_datetime(x_pred[idx, :, :])
        for idx_output_var in range(d_output):
            # Select real passengers data
            y_true = y[idx, :, idx_output_var]

            y_pred = netout[idx, :, idx_output_var]
            if config['plot']['prediction on validation data']:
                ax = axs[idx_output_var]
                ax.plot(x_pred_absciss, y_pred, label="Prediction", color='tab:orange')
                ax.plot(x_absciss, y_true, label="Truth", color='tab:blue')
                if idx == idx_range[0]:
                    ax.set_title(tsp.dataset.labels['y'][idx_output_var] + ' over time')
                    ax.set_xlabel('date')
                    ax.set_ylabel(tsp.dataset.labels['y'][idx_output_var])
                    ax.legend()
    if config['plot']['prediction on validation data']:
        plt.show()
../_images/notebooks_example_flights_dataset_13_0.png
[ ]: