Flights passengers example

Imports

[1]:
import torch
import numpy as np
import calendar
from datetime import timedelta, datetime
from matplotlib import pyplot as plt

from time_series_predictor import TimeSeriesPredictor
from time_series_models import BenchmarkLSTM
from flights_time_series_dataset import FlightsDataset, convert_year_month_array_to_datetime
/home/docs/checkouts/readthedocs.org/user_builds/time-series-predictor/envs/stable/lib/python3.7/site-packages/_distutils_hack/__init__.py:30: UserWarning: Setuptools is replacing distutils.
  warnings.warn("Setuptools is replacing distutils.")

Config

[2]:
plot_config = {}
plot_config['training progress'] = True
plot_config['prediction on training data'] = True
plot_config['prediction on validation data'] = True
plot_config['forecast'] = True

forecast_config = {}
forecast_config['include history'] = True
forecast_config['months ahead'] = 24

predictor_config = {}
predictor_config['epochs'] = 1000
predictor_config['learning rate'] = 1e-2
predictor_config['hidden dim'] = 100
predictor_config['layers num'] = 3

config = {}
config['plot'] = plot_config
config['forecast'] = forecast_config
config['predictor'] = predictor_config
config['predict on training data enabled'] = True
config['forecast enabled'] = True

Time Series Predictor instantiation

[3]:
tsp = TimeSeriesPredictor(
    BenchmarkLSTM(
        hidden_dim=config['predictor']['hidden dim'],
        num_layers=config['predictor']['layers num']
    ),
    lr=config['predictor']['learning rate'],
    max_epochs=predictor_config['epochs'],
    train_split=None,
    optimizer=torch.optim.Adam
)

Training process

[4]:
ds = FlightsDataset()
tsp.fit(ds)
Using device cpu
Re-initializing module because the following parameters were re-set: module__input_dim, module__output_dim.
Re-initializing criterion.
Re-initializing optimizer.
  epoch    train_loss    cp     dur
-------  ------------  ----  ------
      1        0.1743     +  0.0950
      2        0.0485     +  0.0745
      3        0.0675        0.0738
      4        0.1192        0.0750
      5        0.1004        0.0759
      6        0.0797        0.0759
      7        0.0621        0.0734
      8        0.0527        0.0725
      9        0.0676        0.0746
     10        0.0502        0.0734
     11        0.0525        0.0738
     12        0.0553        0.0729
     13        0.0548        0.0736
     14        0.0525        0.0747
     15        0.0505        0.0728
     16        0.0503        0.0741
     17        0.0502        0.0746
     18        0.0468     +  0.0773
     19        0.0402     +  0.0746
     20        0.0304     +  0.0733
     21        0.0133     +  0.0725
     22        0.0350        0.0737
     23        0.0675        0.0746
     24        0.0440        0.0737
     25        0.0348        0.0740
     26        0.0302        0.0741
     27        0.0300        0.0732
     28        0.0359        0.0715
     29        0.0403        0.0742
     30        0.0365        0.0742
     31        0.0307        0.0774
     32        0.0281        0.0764
     33        0.0271        0.0731
     34        0.0243        0.0748
     35        0.0178        0.0719
     36        0.0109     +  0.0742
     37        0.0120        0.0747
     38        0.0125        0.0752
     39        0.0103     +  0.0755
     40        0.0149        0.0755
     41        0.0146        0.0736
     42        0.0104        0.0741
     43        0.0097     +  0.0722
     44        0.0100        0.0747
     45        0.0083     +  0.0747
     46        0.0080     +  0.0758
     47        0.0091        0.0741
     48        0.0096        0.0738
     49        0.0089        0.0733
     50        0.0082        0.0724
     51        0.0084        0.0734
     52        0.0084        0.0745
     53        0.0074     +  0.0735
     54        0.0068     +  0.0720
     55        0.0071        0.0732
     56        0.0069        0.0731
     57        0.0063     +  0.0756
     58        0.0065        0.0761
     59        0.0067        0.0736
     60        0.0061     +  0.0745
     61        0.0061     +  0.0742
     62        0.0062        0.0740
     63        0.0056     +  0.0728
     64        0.0055     +  0.0740
     65        0.0055        0.0745
     66        0.0052     +  0.0743
     67        0.0051     +  0.0733
     68        0.0052        0.0752
     69        0.0049     +  0.0747
     70        0.0048     +  0.0751
     71        0.0047     +  0.0742
     72        0.0044     +  0.0737
     73        0.0043     +  0.0740
     74        0.0041     +  0.0745
     75        0.0039     +  0.0739
     76        0.0038     +  0.0735
     77        0.0035     +  0.0750
     78        0.0035     +  0.0738
     79        0.0032     +  0.0732
     80        0.0033        0.0737
     81        0.0031     +  0.0766
     82        0.0032        0.0769
     83        0.0032        0.0768
     84        0.0031     +  0.0769
     85        0.0031     +  0.0750
     86        0.0028     +  0.0743
     87        0.0028     +  0.0736
     88        0.0026     +  0.0742
     89        0.0026     +  0.0741
     90        0.0025     +  0.0742
     91        0.0025     +  0.0738
     92        0.0024     +  0.0745
     93        0.0024     +  0.0736
     94        0.0023     +  0.0754
     95        0.0022     +  0.0745
     96        0.0021     +  0.0734
     97        0.0020     +  0.0730
     98        0.0019     +  0.0742
     99        0.0018     +  0.0731
    100        0.0017     +  0.0735
    101        0.0016     +  0.0763
    102        0.0015     +  0.0768
    103        0.0014     +  0.0740
    104        0.0014     +  0.0736
    105        0.0013     +  0.0746
    106        0.0013        0.0748
    107        0.0014        0.0762
    108        0.0014        0.0729
    109        0.0014        0.0753
    110        0.0014        0.0723
    111        0.0013     +  0.0732
    112        0.0013     +  0.0750
    113        0.0012     +  0.0757
    114        0.0012     +  0.0749
    115        0.0011     +  0.0767
    116        0.0011     +  0.0737
    117        0.0011     +  0.0741
    118        0.0011        0.0741
    119        0.0011     +  0.0745
    120        0.0011     +  0.0736
    121        0.0011     +  0.0756
    122        0.0010     +  0.0746
    123        0.0010     +  0.0751
    124        0.0010     +  0.0762
    125        0.0010     +  0.0757
    126        0.0010     +  0.0760
    127        0.0010     +  0.0777
    128        0.0010     +  0.0733
    129        0.0010     +  0.0739
    130        0.0009     +  0.0739
    131        0.0009     +  0.0743
    132        0.0009     +  0.0764
    133        0.0009     +  0.0766
    134        0.0009     +  0.0734
    135        0.0009     +  0.0742
    136        0.0009     +  0.0732
    137        0.0009     +  0.0729
    138        0.0009     +  0.0742
    139        0.0009     +  0.0760
    140        0.0009     +  0.0747
    141        0.0008     +  0.0770
    142        0.0008     +  0.0774
    143        0.0008     +  0.0752
    144        0.0008     +  0.0733
    145        0.0008     +  0.0751
    146        0.0008     +  0.0738
    147        0.0008     +  0.0742
    148        0.0008     +  0.0743
    149        0.0008     +  0.0739
    150        0.0008     +  0.0752
    151        0.0008     +  0.0738
    152        0.0008     +  0.0740
    153        0.0008     +  0.0750
    154        0.0008     +  0.0759
    155        0.0007     +  0.0744
    156        0.0007     +  0.0757
    157        0.0007     +  0.0771
    158        0.0007     +  0.0749
    159        0.0007     +  0.0742
    160        0.0007     +  0.0741
    161        0.0007     +  0.0741
    162        0.0007     +  0.0739
    163        0.0007     +  0.0766
    164        0.0007     +  0.0755
    165        0.0007     +  0.0730
    166        0.0007     +  0.0731
    167        0.0007     +  0.0747
    168        0.0007     +  0.0743
    169        0.0007     +  0.0753
    170        0.0007     +  0.0735
    171        0.0007     +  0.0737
    172        0.0007     +  0.0729
    173        0.0007     +  0.0729
    174        0.0007     +  0.0750
    175        0.0007     +  0.0734
    176        0.0007     +  0.0727
    177        0.0007     +  0.0740
    178        0.0007     +  0.0745
    179        0.0007        0.0747
    180        0.0007        0.0749
    181        0.0008        0.0767
    182        0.0016        0.0760
    183        0.0048        0.0742
    184        0.0007        0.0763
    185        0.0041        0.0759
    186        0.0052        0.0760
    187        0.0046        0.0756
    188        0.0016        0.0768
    189        0.0048        0.0773
    190        0.0016        0.0766
    191        0.0045        0.0758
    192        0.0015        0.0759
    193        0.0026        0.0753
    194        0.0024        0.0755
    195        0.0011        0.0754
    196        0.0027        0.0778
    197        0.0013        0.0757
    198        0.0014        0.0763
    199        0.0021        0.0784
    200        0.0008        0.0758
    201        0.0017        0.0757
    202        0.0014        0.0758
    203        0.0008        0.0756
    204        0.0016        0.0764
    205        0.0009        0.0773
    206        0.0010        0.0759
    207        0.0013        0.0774
    208        0.0007        0.0743
    209        0.0010        0.0741
    210        0.0010        0.0768
    211        0.0007        0.0761
    212        0.0010        0.0759
    213        0.0008        0.0771
    214        0.0007        0.0761
    215        0.0009        0.0778
    216        0.0007        0.0769
    217        0.0008        0.0752
    218        0.0008        0.0766
    219        0.0007        0.0744
    220        0.0008        0.0750
    221        0.0007        0.0752
    222        0.0007        0.0733
    223        0.0007        0.0755
    224        0.0007        0.0752
    225        0.0007        0.0772
    226        0.0007        0.0750
    227        0.0006     +  0.0752
    228        0.0007        0.0758
    229        0.0007        0.0764
    230        0.0006     +  0.0749
    231        0.0007        0.0769
    232        0.0006        0.0762
    233        0.0006     +  0.0761
    234        0.0006        0.0750
    235        0.0006     +  0.0932
    236        0.0006        0.0809
    237        0.0006        0.0743
    238        0.0006     +  0.0748
    239        0.0006        0.0757
    240        0.0006        0.0775
    241        0.0006     +  0.0755
    242        0.0006        0.0750
    243        0.0006     +  0.0759
    244        0.0006        0.0759
    245        0.0006        0.0755
    246        0.0006     +  0.0761
    247        0.0006        0.0760
    248        0.0006     +  0.0756
    249        0.0006     +  0.0765
    250        0.0006        0.0765
    251        0.0006     +  0.0746
    252        0.0006        0.0777
    253        0.0006     +  0.0745
    254        0.0006     +  0.0750
    255        0.0006        0.0769
    256        0.0006     +  0.0758
    257        0.0006     +  0.0749
    258        0.0006     +  0.0759
    259        0.0005     +  0.0760
    260        0.0005        0.0772
    261        0.0005     +  0.0762
    262        0.0005     +  0.0765
    263        0.0005     +  0.0759
    264        0.0005     +  0.0783
    265        0.0005     +  0.0778
    266        0.0005     +  0.0765
    267        0.0005     +  0.0756
    268        0.0005     +  0.0748
    269        0.0005     +  0.0756
    270        0.0005     +  0.0769
    271        0.0005     +  0.0760
    272        0.0005     +  0.0758
    273        0.0005     +  0.0752
    274        0.0005     +  0.0744
    275        0.0005     +  0.0764
    276        0.0005     +  0.0758
    277        0.0005     +  0.0755
    278        0.0005     +  0.0770
    279        0.0005     +  0.0758
    280        0.0005     +  0.0733
    281        0.0005     +  0.0743
    282        0.0005     +  0.0737
    283        0.0005     +  0.0747
    284        0.0005     +  0.0752
    285        0.0005     +  0.0737
    286        0.0005     +  0.0742
    287        0.0005     +  0.0763
    288        0.0005     +  0.0752
    289        0.0005     +  0.0752
    290        0.0005     +  0.0749
    291        0.0005     +  0.0747
    292        0.0005     +  0.0741
    293        0.0005     +  0.0737
    294        0.0005     +  0.0738
    295        0.0005     +  0.0741
    296        0.0005     +  0.0738
    297        0.0005     +  0.0730
    298        0.0005     +  0.0750
    299        0.0005     +  0.0745
    300        0.0005     +  0.0734
    301        0.0005     +  0.0742
    302        0.0005     +  0.0735
    303        0.0005     +  0.0739
    304        0.0005     +  0.0737
    305        0.0005     +  0.0777
    306        0.0005     +  0.0825
    307        0.0005     +  0.0744
    308        0.0005     +  0.0756
    309        0.0005     +  0.0753
    310        0.0005     +  0.0778
    311        0.0005     +  0.0766
    312        0.0005     +  0.0760
    313        0.0005     +  0.0748
    314        0.0005     +  0.0729
    315        0.0005     +  0.0736
    316        0.0005     +  0.0728
    317        0.0005     +  0.0740
    318        0.0005     +  0.0750
    319        0.0005     +  0.0733
    320        0.0005     +  0.0725
    321        0.0005     +  0.0733
    322        0.0005     +  0.0752
    323        0.0005     +  0.0752
    324        0.0005     +  0.0752
    325        0.0005     +  0.0748
    326        0.0005     +  0.0726
    327        0.0005     +  0.0742
    328        0.0005     +  0.0741
    329        0.0005     +  0.0751
    330        0.0005     +  0.0735
    331        0.0005     +  0.0740
    332        0.0005     +  0.0740
    333        0.0005     +  0.0758
    334        0.0005     +  0.0767
    335        0.0005     +  0.0771
    336        0.0005     +  0.0750
    337        0.0005     +  0.0749
    338        0.0005     +  0.0751
    339        0.0005     +  0.0728
    340        0.0004     +  0.0741
    341        0.0004     +  0.0884
    342        0.0004     +  0.0737
    343        0.0004     +  0.0728
    344        0.0004     +  0.0742
    345        0.0004     +  0.0744
    346        0.0004     +  0.0730
    347        0.0004     +  0.0738
    348        0.0004     +  0.0745
    349        0.0004     +  0.0749
    350        0.0004     +  0.0730
    351        0.0004     +  0.0730
    352        0.0004     +  0.0748
    353        0.0004     +  0.0738
    354        0.0004     +  0.0728
    355        0.0004     +  0.0796
    356        0.0004     +  0.0765
    357        0.0004     +  0.0752
    358        0.0004     +  0.0748
    359        0.0004     +  0.0746
    360        0.0004     +  0.0730
    361        0.0004     +  0.0752
    362        0.0004     +  0.0744
    363        0.0004     +  0.0743
    364        0.0004     +  0.0739
    365        0.0004     +  0.0741
    366        0.0004     +  0.0733
    367        0.0004     +  0.0745
    368        0.0004     +  0.0726
    369        0.0004     +  0.0758
    370        0.0004     +  0.0764
    371        0.0004     +  0.0764
    372        0.0004     +  0.0776
    373        0.0004     +  0.0782
    374        0.0004     +  0.0784
    375        0.0004     +  0.0774
    376        0.0004     +  0.0792
    377        0.0004     +  0.0780
    378        0.0004     +  0.0796
    379        0.0004     +  0.0780
    380        0.0004     +  0.0788
    381        0.0004     +  0.0773
    382        0.0004     +  0.0762
    383        0.0004     +  0.0771
    384        0.0004     +  0.0780
    385        0.0004     +  0.0768
    386        0.0004     +  0.0764
    387        0.0004     +  0.0744
    388        0.0004     +  0.0751
    389        0.0004     +  0.0753
    390        0.0004     +  0.0743
    391        0.0004     +  0.0748
    392        0.0004     +  0.0744
    393        0.0004     +  0.0751
    394        0.0004     +  0.0764
    395        0.0004     +  0.0778
    396        0.0004     +  0.0784
    397        0.0004     +  0.0751
    398        0.0004     +  0.0761
    399        0.0004     +  0.0742
    400        0.0004     +  0.0754
    401        0.0004     +  0.0754
    402        0.0004     +  0.0763
    403        0.0004     +  0.0752
    404        0.0004     +  0.0730
    405        0.0004     +  0.0737
    406        0.0004     +  0.0737
    407        0.0004     +  0.0735
    408        0.0004     +  0.0756
    409        0.0004     +  0.0732
    410        0.0004     +  0.0731
    411        0.0004     +  0.0746
    412        0.0004     +  0.0748
    413        0.0004     +  0.0751
    414        0.0004     +  0.0745
    415        0.0004     +  0.0755
    416        0.0004     +  0.0737
    417        0.0004     +  0.0749
    418        0.0004     +  0.0739
    419        0.0004     +  0.0736
    420        0.0004     +  0.0757
    421        0.0004     +  0.0754
    422        0.0004     +  0.0770
    423        0.0004     +  0.0756
    424        0.0004     +  0.0753
    425        0.0004     +  0.0749
    426        0.0004     +  0.0742
    427        0.0004     +  0.0726
    428        0.0004     +  0.0731
    429        0.0004     +  0.0738
    430        0.0004     +  0.0745
    431        0.0004     +  0.0741
    432        0.0004     +  0.0757
    433        0.0004     +  0.0750
    434        0.0004     +  0.0760
    435        0.0004     +  0.0774
    436        0.0004     +  0.0770
    437        0.0004     +  0.0759
    438        0.0004     +  0.0745
    439        0.0004     +  0.0765
    440        0.0004     +  0.0758
    441        0.0004     +  0.0748
    442        0.0004     +  0.0744
    443        0.0004     +  0.0751
    444        0.0004     +  0.0775
    445        0.0004     +  0.0767
    446        0.0004     +  0.0767
    447        0.0004     +  0.0763
    448        0.0004     +  0.0763
    449        0.0004     +  0.0740
    450        0.0004     +  0.0767
    451        0.0004     +  0.0760
    452        0.0003     +  0.0783
    453        0.0003     +  0.0775
    454        0.0003     +  0.0783
    455        0.0003     +  0.0765
    456        0.0003     +  0.0765
    457        0.0003     +  0.0765
    458        0.0003     +  0.0760
    459        0.0003     +  0.0771
    460        0.0003     +  0.0739
    461        0.0003     +  0.0744
    462        0.0003     +  0.0844
    463        0.0003     +  0.0751
    464        0.0003     +  0.0747
    465        0.0003     +  0.0741
    466        0.0003     +  0.0756
    467        0.0003     +  0.0758
    468        0.0003     +  0.0753
    469        0.0003     +  0.0737
    470        0.0003     +  0.0765
    471        0.0003     +  0.0750
    472        0.0003     +  0.0756
    473        0.0003     +  0.0738
    474        0.0003     +  0.0766
    475        0.0003     +  0.0733
    476        0.0003     +  0.0738
    477        0.0003     +  0.0734
    478        0.0003     +  0.0733
    479        0.0003     +  0.0753
    480        0.0003     +  0.0753
    481        0.0003     +  0.0749
    482        0.0003     +  0.0744
    483        0.0003     +  0.0739
    484        0.0003     +  0.0759
    485        0.0003     +  0.0743
    486        0.0003     +  0.0737
    487        0.0003     +  0.0738
    488        0.0003     +  0.0767
    489        0.0003        0.0767
    490        0.0004        0.0783
    491        0.0006        0.0740
    492        0.0017        0.0746
    493        0.0023        0.0737
    494        0.0011        0.0728
    495        0.0020        0.0743
    496        0.0005        0.0735
    497        0.0017        0.0742
    498        0.0009        0.0748
    499        0.0013        0.0723
    500        0.0006        0.0725
    501        0.0011        0.0743
    502        0.0007        0.0743
    503        0.0009        0.0730
    504        0.0006        0.0743
    505        0.0008        0.0725
    506        0.0007        0.0749
    507        0.0007        0.0733
    508        0.0005        0.0726
    509        0.0007        0.0751
    510        0.0006        0.0753
    511        0.0006        0.0736
    512        0.0005        0.0744
    513        0.0005        0.0761
    514        0.0005        0.0758
    515        0.0005        0.0746
    516        0.0005        0.0738
    517        0.0005        0.0750
    518        0.0005        0.0757
    519        0.0005        0.0741
    520        0.0005        0.0724
    521        0.0004        0.0726
    522        0.0005        0.0722
    523        0.0004        0.0730
    524        0.0005        0.0740
    525        0.0004        0.0744
    526        0.0004        0.0745
    527        0.0004        0.0729
    528        0.0004        0.0746
    529        0.0004        0.0734
    530        0.0004        0.0738
    531        0.0004        0.0738
    532        0.0004        0.0742
    533        0.0004        0.0738
    534        0.0004        0.0729
    535        0.0004        0.0740
    536        0.0004        0.0744
    537        0.0004        0.0754
    538        0.0004        0.0747
    539        0.0004        0.0737
    540        0.0004        0.0744
    541        0.0004        0.0754
    542        0.0004        0.0746
    543        0.0004        0.0734
    544        0.0004        0.0753
    545        0.0004        0.0727
    546        0.0004        0.0737
    547        0.0004        0.0728
    548        0.0004        0.0719
    549        0.0004        0.0721
    550        0.0004        0.0727
    551        0.0004        0.0742
    552        0.0004        0.0722
    553        0.0004        0.0733
    554        0.0004        0.0737
    555        0.0004        0.0737
    556        0.0004        0.0746
    557        0.0004        0.0751
    558        0.0004        0.0744
    559        0.0004        0.0741
    560        0.0004        0.0728
    561        0.0004        0.0719
    562        0.0003        0.0719
    563        0.0003        0.0716
    564        0.0003        0.0743
    565        0.0003        0.0735
    566        0.0003        0.0737
    567        0.0003        0.0750
    568        0.0003        0.0772
    569        0.0003        0.0742
    570        0.0003        0.0744
    571        0.0003        0.0739
    572        0.0003        0.0727
    573        0.0003        0.0740
    574        0.0003        0.0731
    575        0.0003        0.0719
    576        0.0003        0.0725
    577        0.0003        0.0745
    578        0.0003        0.0755
    579        0.0003        0.0734
    580        0.0003        0.0732
    581        0.0003        0.0734
    582        0.0003        0.0741
    583        0.0003        0.0761
    584        0.0003        0.0764
    585        0.0003        0.0757
    586        0.0003        0.0738
    587        0.0003        0.0726
    588        0.0003        0.0735
    589        0.0003        0.0773
    590        0.0003        0.0758
    591        0.0003        0.0766
    592        0.0003        0.0767
    593        0.0003        0.0748
    594        0.0003        0.0742
    595        0.0003        0.0743
    596        0.0003        0.0726
    597        0.0003        0.0738
    598        0.0003        0.0729
    599        0.0003        0.0730
    600        0.0003        0.0783
    601        0.0003        0.0743
    602        0.0003        0.0732
    603        0.0003        0.0731
    604        0.0003        0.0752
    605        0.0003        0.0725
    606        0.0003        0.0744
    607        0.0003        0.0740
    608        0.0003        0.0736
    609        0.0003        0.0726
    610        0.0003        0.0723
    611        0.0003        0.0743
    612        0.0003        0.0746
    613        0.0003        0.0716
    614        0.0003        0.0724
    615        0.0003        0.0730
    616        0.0003        0.0753
    617        0.0003        0.0751
    618        0.0003        0.0742
    619        0.0003        0.0731
    620        0.0003        0.0749
    621        0.0003        0.0763
    622        0.0003        0.0754
    623        0.0003        0.0748
    624        0.0003        0.0733
    625        0.0003        0.0730
    626        0.0003        0.0728
    627        0.0003        0.0738
    628        0.0003        0.0739
    629        0.0003        0.0733
    630        0.0003        0.0747
    631        0.0003        0.0743
    632        0.0003        0.0719
    633        0.0003     +  0.0731
    634        0.0003     +  0.0766
    635        0.0003     +  0.0748
    636        0.0003     +  0.0741
    637        0.0003     +  0.0742
    638        0.0003     +  0.0729
    639        0.0003     +  0.0739
    640        0.0003     +  0.0718
    641        0.0003     +  0.0732
    642        0.0003     +  0.0759
    643        0.0003     +  0.0733
    644        0.0003     +  0.0774
    645        0.0003     +  0.0768
    646        0.0003     +  0.0759
    647        0.0003     +  0.0768
    648        0.0003     +  0.0759
    649        0.0003     +  0.0759
    650        0.0003     +  0.0748
    651        0.0003     +  0.0737
    652        0.0003     +  0.0753
    653        0.0003     +  0.0744
    654        0.0003     +  0.0735
    655        0.0003     +  0.0761
    656        0.0003     +  0.0742
    657        0.0003     +  0.0738
    658        0.0003     +  0.0740
    659        0.0003     +  0.0751
    660        0.0003     +  0.0741
    661        0.0003     +  0.0735
    662        0.0003     +  0.0726
    663        0.0003     +  0.0732
    664        0.0003     +  0.0749
    665        0.0003     +  0.0739
    666        0.0003     +  0.0757
    667        0.0003     +  0.0752
    668        0.0003     +  0.0743
    669        0.0003     +  0.0733
    670        0.0003     +  0.0728
    671        0.0003     +  0.0745
    672        0.0003     +  0.0744
    673        0.0003     +  0.0731
    674        0.0003     +  0.0746
    675        0.0003     +  0.0742
    676        0.0003     +  0.0751
    677        0.0003     +  0.0748
    678        0.0003     +  0.0746
    679        0.0003     +  0.0739
    680        0.0003     +  0.0752
    681        0.0003     +  0.0740
    682        0.0003     +  0.0750
    683        0.0003     +  0.0757
    684        0.0003     +  0.0734
    685        0.0003     +  0.0747
    686        0.0003     +  0.0740
    687        0.0003     +  0.0765
    688        0.0003     +  0.0750
    689        0.0003     +  0.0736
    690        0.0003     +  0.0724
    691        0.0003     +  0.0740
    692        0.0003     +  0.0741
    693        0.0003     +  0.0737
    694        0.0003     +  0.0736
    695        0.0003     +  0.0767
    696        0.0003     +  0.0733
    697        0.0003     +  0.0755
    698        0.0003     +  0.0747
    699        0.0003     +  0.0752
    700        0.0003     +  0.0751
    701        0.0003     +  0.0757
    702        0.0003     +  0.0733
    703        0.0003     +  0.0739
    704        0.0003     +  0.0744
    705        0.0003     +  0.0742
    706        0.0003     +  0.0771
    707        0.0003     +  0.0749
    708        0.0003     +  0.0777
    709        0.0003     +  0.0765
    710        0.0003     +  0.0738
    711        0.0003     +  0.0738
    712        0.0003     +  0.0750
    713        0.0003     +  0.0767
    714        0.0003     +  0.0761
    715        0.0003     +  0.0746
    716        0.0003     +  0.0763
    717        0.0003     +  0.0752
    718        0.0003     +  0.0764
    719        0.0003     +  0.0780
    720        0.0003     +  0.0763
    721        0.0003     +  0.0755
    722        0.0003     +  0.0758
    723        0.0003     +  0.0748
    724        0.0003     +  0.0745
    725        0.0003     +  0.0746
    726        0.0003     +  0.0748
    727        0.0003     +  0.0766
    728        0.0003     +  0.0742
    729        0.0003     +  0.0769
    730        0.0003     +  0.0774
    731        0.0003     +  0.0764
    732        0.0003     +  0.0740
    733        0.0003     +  0.0748
    734        0.0003     +  0.0762
    735        0.0003     +  0.0760
    736        0.0003     +  0.0742
    737        0.0003     +  0.0748
    738        0.0003     +  0.0738
    739        0.0003     +  0.0760
    740        0.0003     +  0.0769
    741        0.0003     +  0.0751
    742        0.0003     +  0.0743
    743        0.0003     +  0.0760
    744        0.0003     +  0.0740
    745        0.0003     +  0.0751
    746        0.0003     +  0.0750
    747        0.0003     +  0.0748
    748        0.0003     +  0.0760
    749        0.0003     +  0.0743
    750        0.0003     +  0.0741
    751        0.0003     +  0.0738
    752        0.0003     +  0.0750
    753        0.0003     +  0.0745
    754        0.0003     +  0.0741
    755        0.0003     +  0.0756
    756        0.0003     +  0.0745
    757        0.0003     +  0.0752
    758        0.0002     +  0.0742
    759        0.0002     +  0.0736
    760        0.0003        0.0746
    761        0.0003        0.0755
    762        0.0007        0.0753
    763        0.0007        0.0738
    764        0.0006        0.0729
    765        0.0004        0.0752
    766        0.0005        0.0742
    767        0.0005        0.0734
    768        0.0004        0.0728
    769        0.0005        0.0722
    770        0.0004        0.0721
    771        0.0003        0.0747
    772        0.0004        0.0749
    773        0.0003        0.0770
    774        0.0004        0.0755
    775        0.0004        0.0745
    776        0.0003        0.0740
    777        0.0004        0.0730
    778        0.0003        0.0722
    779        0.0003        0.0740
    780        0.0003        0.0725
    781        0.0003        0.0741
    782        0.0003        0.0730
    783        0.0003        0.0739
    784        0.0003        0.0735
    785        0.0003        0.0734
    786        0.0003        0.0744
    787        0.0003        0.0748
    788        0.0003        0.0755
    789        0.0003        0.0744
    790        0.0003        0.0736
    791        0.0003        0.0745
    792        0.0003        0.0735
    793        0.0003        0.0737
    794        0.0003        0.0728
    795        0.0003        0.0730
    796        0.0003        0.0731
    797        0.0002        0.0745
    798        0.0002     +  0.0731
    799        0.0003        0.0740
    800        0.0003        0.0773
    801        0.0002        0.0751
    802        0.0002     +  0.0739
    803        0.0002        0.0727
    804        0.0002        0.0727
    805        0.0002        0.0740
    806        0.0002        0.0740
    807        0.0002     +  0.0757
    808        0.0002        0.0750
    809        0.0002        0.0771
    810        0.0002        0.0744
    811        0.0002        0.0738
    812        0.0002        0.0738
    813        0.0002        0.0742
    814        0.0003        0.0739
    815        0.0002        0.0734
    816        0.0002        0.0720
    817        0.0002        0.0744
    818        0.0002     +  0.0749
    819        0.0002     +  0.0749
    820        0.0002     +  0.0751
    821        0.0002     +  0.0733
    822        0.0002        0.0748
    823        0.0002        0.0733
    824        0.0002        0.0761
    825        0.0002        0.0776
    826        0.0002        0.0764
    827        0.0003        0.0750
    828        0.0003        0.0748
    829        0.0003        0.0763
    830        0.0002        0.0768
    831        0.0002     +  0.0747
    832        0.0002     +  0.0760
    833        0.0002        0.0774
    834        0.0002        0.0727
    835        0.0002        0.0740
    836        0.0002        0.0737
    837        0.0002     +  0.0738
    838        0.0002        0.0758
    839        0.0002        0.0740
    840        0.0002        0.0727
    841        0.0002        0.0739
    842        0.0002        0.0721
    843        0.0002     +  0.0736
    844        0.0002     +  0.0745
    845        0.0002        0.0783
    846        0.0002        0.0726
    847        0.0002        0.0755
    848        0.0002        0.0741
    849        0.0002     +  0.0783
    850        0.0002     +  0.0780
    851        0.0002     +  0.0762
    852        0.0002     +  0.0738
    853        0.0002        0.0735
    854        0.0002        0.0747
    855        0.0002        0.0734
    856        0.0002     +  0.0741
    857        0.0002     +  0.0754
    858        0.0002     +  0.0729
    859        0.0002     +  0.0738
    860        0.0002     +  0.0755
    861        0.0002     +  0.0758
    862        0.0002     +  0.0750
    863        0.0002     +  0.0733
    864        0.0002     +  0.0732
    865        0.0002     +  0.0740
    866        0.0002     +  0.0735
    867        0.0002     +  0.0747
    868        0.0002     +  0.0777
    869        0.0002     +  0.0762
    870        0.0002     +  0.0774
    871        0.0002     +  0.0750
    872        0.0002     +  0.0776
    873        0.0002     +  0.0744
    874        0.0002     +  0.0732
    875        0.0001     +  0.0758
    876        0.0001     +  0.0757
    877        0.0001     +  0.0746
    878        0.0001     +  0.0750
    879        0.0001     +  0.0738
    880        0.0001     +  0.0740
    881        0.0001        0.0741
    882        0.0002        0.0742
    883        0.0002        0.0748
    884        0.0002        0.0729
    885        0.0002        0.0727
    886        0.0003        0.0729
    887        0.0004        0.0747
    888        0.0003        0.0749
    889        0.0003        0.0734
    890        0.0002        0.0739
    891        0.0002        0.0749
    892        0.0002        0.0738
    893        0.0003        0.0746
    894        0.0003        0.0748
    895        0.0003        0.0761
    896        0.0002        0.0751
    897        0.0003        0.0729
    898        0.0003        0.0717
    899        0.0002        0.0728
    900        0.0002        0.0733
    901        0.0002        0.0724
    902        0.0002        0.0728
    903        0.0002        0.0752
    904        0.0002        0.0720
    905        0.0002        0.0732
    906        0.0002        0.0737
    907        0.0002        0.0740
    908        0.0002        0.0745
    909        0.0002        0.0740
    910        0.0002        0.0740
    911        0.0002        0.0751
    912        0.0002        0.0733
    913        0.0002        0.0739
    914        0.0002        0.0736
    915        0.0002        0.0730
    916        0.0002        0.0726
    917        0.0002        0.0730
    918        0.0002        0.0748
    919        0.0001        0.0754
    920        0.0001        0.0783
    921        0.0002        0.0750
    922        0.0001        0.0760
    923        0.0001     +  0.0750
    924        0.0001     +  0.0755
    925        0.0001     +  0.0767
    926        0.0001        0.0768
    927        0.0001        0.0753
    928        0.0001        0.0764
    929        0.0001        0.0761
    930        0.0001        0.0756
    931        0.0001        0.0762
    932        0.0001        0.0761
    933        0.0002        0.0758
    934        0.0002        0.0748
    935        0.0002        0.0747
    936        0.0002        0.0750
    937        0.0002        0.0746
    938        0.0002        0.0745
    939        0.0002        0.0753
    940        0.0001        0.0741
    941        0.0001     +  0.0732
    942        0.0001     +  0.0766
    943        0.0001        0.0757
    944        0.0001        0.0762
    945        0.0002        0.0762
    946        0.0002        0.0767
    947        0.0002        0.0954
    948        0.0002        0.0738
    949        0.0002        0.0758
    950        0.0001        0.0777
    951        0.0001        0.0757
    952        0.0001     +  0.0770
    953        0.0001        0.0791
    954        0.0001        0.0751
    955        0.0001        0.0759
    956        0.0001        0.0773
    957        0.0002        0.0766
    958        0.0001        0.0775
    959        0.0001        0.0765
    960        0.0001        0.0738
    961        0.0001     +  0.0763
    962        0.0001        0.0769
    963        0.0001        0.0755
    964        0.0001        0.0775
    965        0.0001        0.0735
    966        0.0001        0.0753
    967        0.0001        0.0757
    968        0.0002        0.0754
    969        0.0001        0.0781
    970        0.0001        0.0763
    971        0.0001        0.0771
    972        0.0001        0.0764
    973        0.0001        0.0744
    974        0.0001     +  0.0747
    975        0.0001     +  0.0757
    976        0.0001     +  0.0763
    977        0.0001        0.0764
    978        0.0001        0.0746
    979        0.0001        0.0767
    980        0.0001        0.0780
    981        0.0001        0.0787
    982        0.0001        0.0764
    983        0.0001        0.0766
    984        0.0001        0.0778
    985        0.0002        0.0777
    986        0.0002        0.0754
    987        0.0002        0.0765
    988        0.0002        0.0767
    989        0.0001        0.0747
    990        0.0001        0.0740
    991        0.0001     +  0.0753
    992        0.0001     +  0.0771
    993        0.0001        0.0799
    994        0.0001        0.0774
    995        0.0001        0.0766
    996        0.0001        0.0783
    997        0.0001        0.0747
    998        0.0001        0.0744
    999        0.0001     +  0.0769
   1000        0.0001        0.0785
Loading the best network from the last checkpoint.
[4]:
<skorch.callbacks.training.Checkpoint at 0x7fbfdb73c290>

Plot training evolution

[5]:
if config['plot']['training progress']:
    history_length = len(tsp.ttr.regressor_['regressor'].history)
    train_loss = np.zeros((history_length, 1))
    for epoch in tsp.ttr.regressor_['regressor'].history:
        epoch_number = epoch['epoch']-1
        train_loss[epoch_number] = epoch['train_loss']
    _, axes_one = plt.subplots(figsize=(20, 20))
    axes_one.plot(train_loss, 'o-', label='training')
    axes_one.set_xlabel('Epoch')
    axes_one.set_ylabel('MSE')
    plt.legend()
../_images/notebooks_example_flights_dataset_9_0.png

Prediction on training data

[6]:
if config['predict on training data enabled']:
    dataloader = tsp.ttr.regressor['regressor'].get_iterator(tsp.dataset)
    x, y = dataloader.dataset[:]
    netout = tsp.predict(x)
    d_output = netout.shape[-1]
    if config['plot']['prediction on training data']:
        fig, axs = plt.subplots(d_output, 1, figsize=(20,20))
        axs = [axs]
    idx_range = [np.random.randint(0, len(tsp.dataset))]
    for idx in idx_range:
        if config['plot']['prediction on training data']:
            x_absciss = convert_year_month_array_to_datetime(x[idx, :, :])
        for idx_output_var in range(d_output):
            # Select real passengers data
            y_true = y[idx, :, idx_output_var]

            y_pred = netout[idx, :, idx_output_var]
            if config['plot']['prediction on training data']:
                ax = axs[idx_output_var]
                ax.plot(x_absciss, y_true, label="Truth", color='tab:blue')
                ax.plot(x_absciss, y_pred, label="Prediction", color='tab:orange')
                if idx == idx_range[0]:
                    ax.set_title(tsp.dataset.labels['y'][idx_output_var] + ' over time')
                    ax.set_xlabel('date')
                    ax.set_ylabel(tsp.dataset.labels['y'][idx_output_var])
                    ax.legend()
    if config['plot']['prediction on training data']:
        plt.show()
../_images/notebooks_example_flights_dataset_11_0.png

Future forecast

[7]:
# def make_x_pred(x, n_months_ahead, include_history=True):
#     def raw_add_months(sourcedate, n_months):
#         month = sourcedate.month - 1 + n_months
#         year = sourcedate.year + month // 12
#         month = month % 12 + 1
#         day = min(sourcedate.day, calendar.monthrange(year,month)[1])
#         return datetime(year, month, day)
#     def add_months(months, n_months):
#         return [raw_add_months(month, n_months) for month in months]
#     def add_months_2(month, n_months):
#         return [raw_add_months(month, n_month) for n_month in n_months]
#     x_dates = convert_year_month_array_to_datetime(x.squeeze())
#     last_month = x_dates[-1:][0]
#     next_n_months = add_months_2(last_month, np.arange(n_months_ahead)+1)
#     if include_history:
#         n_months = x_dates + next_n_months
#     else:
#         n_months = next_n_months
#     x_pred = np.array([[dt.month, dt.year] for dt in n_months]).reshape(1, -1, 2).astype(np.float32)
#     return x_pred

if config['forecast enabled']:
    dataloader = tsp.ttr.regressor['regressor'].get_iterator(tsp.dataset)
    x, y = dataloader.dataset[:]
    # x_pred = make_x_pred(x, config['forecast']['months ahead'])
    # netout = tsp.predict(x_pred)
    netout, x_pred = tsp.forecast(config['forecast']['months ahead'], include_history = config['forecast']['include history'])
    d_output = netout.shape[-1]
    if config['plot']['forecast']:
        fig, axs = plt.subplots(d_output, 1, figsize=(20,20))
        axs = [axs]
    idx_range = [np.random.randint(0, len(tsp.dataset))]
    for idx in idx_range:
        if config['plot']['forecast']:
            x_absciss = convert_year_month_array_to_datetime(x[idx, :, :])
            x_pred_absciss = convert_year_month_array_to_datetime(x_pred[idx, :, :])
        for idx_output_var in range(d_output):
            # Select real passengers data
            y_true = y[idx, :, idx_output_var]

            y_pred = netout[idx, :, idx_output_var]
            if config['plot']['prediction on validation data']:
                ax = axs[idx_output_var]
                ax.plot(x_pred_absciss, y_pred, label="Prediction", color='tab:orange')
                ax.plot(x_absciss, y_true, label="Truth", color='tab:blue')
                if idx == idx_range[0]:
                    ax.set_title(tsp.dataset.labels['y'][idx_output_var] + ' over time')
                    ax.set_xlabel('date')
                    ax.set_ylabel(tsp.dataset.labels['y'][idx_output_var])
                    ax.legend()
    if config['plot']['prediction on validation data']:
        plt.show()
../_images/notebooks_example_flights_dataset_13_0.png
[ ]: