Flights passengers example
Imports
[1]:
import torch
import numpy as np
import calendar
from datetime import timedelta, datetime
from matplotlib import pyplot as plt
from time_series_predictor import TimeSeriesPredictor
from time_series_models import BenchmarkLSTM
from flights_time_series_dataset import FlightsDataset, convert_year_month_array_to_datetime
/home/docs/checkouts/readthedocs.org/user_builds/time-series-predictor/envs/stable/lib/python3.7/site-packages/_distutils_hack/__init__.py:30: UserWarning: Setuptools is replacing distutils.
warnings.warn("Setuptools is replacing distutils.")
Config
[2]:
plot_config = {}
plot_config['training progress'] = True
plot_config['prediction on training data'] = True
plot_config['prediction on validation data'] = True
plot_config['forecast'] = True
forecast_config = {}
forecast_config['include history'] = True
forecast_config['months ahead'] = 24
predictor_config = {}
predictor_config['epochs'] = 1000
predictor_config['learning rate'] = 1e-2
predictor_config['hidden dim'] = 100
predictor_config['layers num'] = 3
config = {}
config['plot'] = plot_config
config['forecast'] = forecast_config
config['predictor'] = predictor_config
config['predict on training data enabled'] = True
config['forecast enabled'] = True
Time Series Predictor instantiation
[3]:
tsp = TimeSeriesPredictor(
BenchmarkLSTM(
hidden_dim=config['predictor']['hidden dim'],
num_layers=config['predictor']['layers num']
),
lr=config['predictor']['learning rate'],
max_epochs=predictor_config['epochs'],
train_split=None,
optimizer=torch.optim.Adam
)
Training process
[4]:
ds = FlightsDataset()
tsp.fit(ds)
Using device cpu
Re-initializing module because the following parameters were re-set: module__input_dim, module__output_dim.
Re-initializing criterion.
Re-initializing optimizer.
epoch train_loss cp dur
------- ------------ ---- ------
1 0.1743 + 0.0950
2 0.0485 + 0.0745
3 0.0675 0.0738
4 0.1192 0.0750
5 0.1004 0.0759
6 0.0797 0.0759
7 0.0621 0.0734
8 0.0527 0.0725
9 0.0676 0.0746
10 0.0502 0.0734
11 0.0525 0.0738
12 0.0553 0.0729
13 0.0548 0.0736
14 0.0525 0.0747
15 0.0505 0.0728
16 0.0503 0.0741
17 0.0502 0.0746
18 0.0468 + 0.0773
19 0.0402 + 0.0746
20 0.0304 + 0.0733
21 0.0133 + 0.0725
22 0.0350 0.0737
23 0.0675 0.0746
24 0.0440 0.0737
25 0.0348 0.0740
26 0.0302 0.0741
27 0.0300 0.0732
28 0.0359 0.0715
29 0.0403 0.0742
30 0.0365 0.0742
31 0.0307 0.0774
32 0.0281 0.0764
33 0.0271 0.0731
34 0.0243 0.0748
35 0.0178 0.0719
36 0.0109 + 0.0742
37 0.0120 0.0747
38 0.0125 0.0752
39 0.0103 + 0.0755
40 0.0149 0.0755
41 0.0146 0.0736
42 0.0104 0.0741
43 0.0097 + 0.0722
44 0.0100 0.0747
45 0.0083 + 0.0747
46 0.0080 + 0.0758
47 0.0091 0.0741
48 0.0096 0.0738
49 0.0089 0.0733
50 0.0082 0.0724
51 0.0084 0.0734
52 0.0084 0.0745
53 0.0074 + 0.0735
54 0.0068 + 0.0720
55 0.0071 0.0732
56 0.0069 0.0731
57 0.0063 + 0.0756
58 0.0065 0.0761
59 0.0067 0.0736
60 0.0061 + 0.0745
61 0.0061 + 0.0742
62 0.0062 0.0740
63 0.0056 + 0.0728
64 0.0055 + 0.0740
65 0.0055 0.0745
66 0.0052 + 0.0743
67 0.0051 + 0.0733
68 0.0052 0.0752
69 0.0049 + 0.0747
70 0.0048 + 0.0751
71 0.0047 + 0.0742
72 0.0044 + 0.0737
73 0.0043 + 0.0740
74 0.0041 + 0.0745
75 0.0039 + 0.0739
76 0.0038 + 0.0735
77 0.0035 + 0.0750
78 0.0035 + 0.0738
79 0.0032 + 0.0732
80 0.0033 0.0737
81 0.0031 + 0.0766
82 0.0032 0.0769
83 0.0032 0.0768
84 0.0031 + 0.0769
85 0.0031 + 0.0750
86 0.0028 + 0.0743
87 0.0028 + 0.0736
88 0.0026 + 0.0742
89 0.0026 + 0.0741
90 0.0025 + 0.0742
91 0.0025 + 0.0738
92 0.0024 + 0.0745
93 0.0024 + 0.0736
94 0.0023 + 0.0754
95 0.0022 + 0.0745
96 0.0021 + 0.0734
97 0.0020 + 0.0730
98 0.0019 + 0.0742
99 0.0018 + 0.0731
100 0.0017 + 0.0735
101 0.0016 + 0.0763
102 0.0015 + 0.0768
103 0.0014 + 0.0740
104 0.0014 + 0.0736
105 0.0013 + 0.0746
106 0.0013 0.0748
107 0.0014 0.0762
108 0.0014 0.0729
109 0.0014 0.0753
110 0.0014 0.0723
111 0.0013 + 0.0732
112 0.0013 + 0.0750
113 0.0012 + 0.0757
114 0.0012 + 0.0749
115 0.0011 + 0.0767
116 0.0011 + 0.0737
117 0.0011 + 0.0741
118 0.0011 0.0741
119 0.0011 + 0.0745
120 0.0011 + 0.0736
121 0.0011 + 0.0756
122 0.0010 + 0.0746
123 0.0010 + 0.0751
124 0.0010 + 0.0762
125 0.0010 + 0.0757
126 0.0010 + 0.0760
127 0.0010 + 0.0777
128 0.0010 + 0.0733
129 0.0010 + 0.0739
130 0.0009 + 0.0739
131 0.0009 + 0.0743
132 0.0009 + 0.0764
133 0.0009 + 0.0766
134 0.0009 + 0.0734
135 0.0009 + 0.0742
136 0.0009 + 0.0732
137 0.0009 + 0.0729
138 0.0009 + 0.0742
139 0.0009 + 0.0760
140 0.0009 + 0.0747
141 0.0008 + 0.0770
142 0.0008 + 0.0774
143 0.0008 + 0.0752
144 0.0008 + 0.0733
145 0.0008 + 0.0751
146 0.0008 + 0.0738
147 0.0008 + 0.0742
148 0.0008 + 0.0743
149 0.0008 + 0.0739
150 0.0008 + 0.0752
151 0.0008 + 0.0738
152 0.0008 + 0.0740
153 0.0008 + 0.0750
154 0.0008 + 0.0759
155 0.0007 + 0.0744
156 0.0007 + 0.0757
157 0.0007 + 0.0771
158 0.0007 + 0.0749
159 0.0007 + 0.0742
160 0.0007 + 0.0741
161 0.0007 + 0.0741
162 0.0007 + 0.0739
163 0.0007 + 0.0766
164 0.0007 + 0.0755
165 0.0007 + 0.0730
166 0.0007 + 0.0731
167 0.0007 + 0.0747
168 0.0007 + 0.0743
169 0.0007 + 0.0753
170 0.0007 + 0.0735
171 0.0007 + 0.0737
172 0.0007 + 0.0729
173 0.0007 + 0.0729
174 0.0007 + 0.0750
175 0.0007 + 0.0734
176 0.0007 + 0.0727
177 0.0007 + 0.0740
178 0.0007 + 0.0745
179 0.0007 0.0747
180 0.0007 0.0749
181 0.0008 0.0767
182 0.0016 0.0760
183 0.0048 0.0742
184 0.0007 0.0763
185 0.0041 0.0759
186 0.0052 0.0760
187 0.0046 0.0756
188 0.0016 0.0768
189 0.0048 0.0773
190 0.0016 0.0766
191 0.0045 0.0758
192 0.0015 0.0759
193 0.0026 0.0753
194 0.0024 0.0755
195 0.0011 0.0754
196 0.0027 0.0778
197 0.0013 0.0757
198 0.0014 0.0763
199 0.0021 0.0784
200 0.0008 0.0758
201 0.0017 0.0757
202 0.0014 0.0758
203 0.0008 0.0756
204 0.0016 0.0764
205 0.0009 0.0773
206 0.0010 0.0759
207 0.0013 0.0774
208 0.0007 0.0743
209 0.0010 0.0741
210 0.0010 0.0768
211 0.0007 0.0761
212 0.0010 0.0759
213 0.0008 0.0771
214 0.0007 0.0761
215 0.0009 0.0778
216 0.0007 0.0769
217 0.0008 0.0752
218 0.0008 0.0766
219 0.0007 0.0744
220 0.0008 0.0750
221 0.0007 0.0752
222 0.0007 0.0733
223 0.0007 0.0755
224 0.0007 0.0752
225 0.0007 0.0772
226 0.0007 0.0750
227 0.0006 + 0.0752
228 0.0007 0.0758
229 0.0007 0.0764
230 0.0006 + 0.0749
231 0.0007 0.0769
232 0.0006 0.0762
233 0.0006 + 0.0761
234 0.0006 0.0750
235 0.0006 + 0.0932
236 0.0006 0.0809
237 0.0006 0.0743
238 0.0006 + 0.0748
239 0.0006 0.0757
240 0.0006 0.0775
241 0.0006 + 0.0755
242 0.0006 0.0750
243 0.0006 + 0.0759
244 0.0006 0.0759
245 0.0006 0.0755
246 0.0006 + 0.0761
247 0.0006 0.0760
248 0.0006 + 0.0756
249 0.0006 + 0.0765
250 0.0006 0.0765
251 0.0006 + 0.0746
252 0.0006 0.0777
253 0.0006 + 0.0745
254 0.0006 + 0.0750
255 0.0006 0.0769
256 0.0006 + 0.0758
257 0.0006 + 0.0749
258 0.0006 + 0.0759
259 0.0005 + 0.0760
260 0.0005 0.0772
261 0.0005 + 0.0762
262 0.0005 + 0.0765
263 0.0005 + 0.0759
264 0.0005 + 0.0783
265 0.0005 + 0.0778
266 0.0005 + 0.0765
267 0.0005 + 0.0756
268 0.0005 + 0.0748
269 0.0005 + 0.0756
270 0.0005 + 0.0769
271 0.0005 + 0.0760
272 0.0005 + 0.0758
273 0.0005 + 0.0752
274 0.0005 + 0.0744
275 0.0005 + 0.0764
276 0.0005 + 0.0758
277 0.0005 + 0.0755
278 0.0005 + 0.0770
279 0.0005 + 0.0758
280 0.0005 + 0.0733
281 0.0005 + 0.0743
282 0.0005 + 0.0737
283 0.0005 + 0.0747
284 0.0005 + 0.0752
285 0.0005 + 0.0737
286 0.0005 + 0.0742
287 0.0005 + 0.0763
288 0.0005 + 0.0752
289 0.0005 + 0.0752
290 0.0005 + 0.0749
291 0.0005 + 0.0747
292 0.0005 + 0.0741
293 0.0005 + 0.0737
294 0.0005 + 0.0738
295 0.0005 + 0.0741
296 0.0005 + 0.0738
297 0.0005 + 0.0730
298 0.0005 + 0.0750
299 0.0005 + 0.0745
300 0.0005 + 0.0734
301 0.0005 + 0.0742
302 0.0005 + 0.0735
303 0.0005 + 0.0739
304 0.0005 + 0.0737
305 0.0005 + 0.0777
306 0.0005 + 0.0825
307 0.0005 + 0.0744
308 0.0005 + 0.0756
309 0.0005 + 0.0753
310 0.0005 + 0.0778
311 0.0005 + 0.0766
312 0.0005 + 0.0760
313 0.0005 + 0.0748
314 0.0005 + 0.0729
315 0.0005 + 0.0736
316 0.0005 + 0.0728
317 0.0005 + 0.0740
318 0.0005 + 0.0750
319 0.0005 + 0.0733
320 0.0005 + 0.0725
321 0.0005 + 0.0733
322 0.0005 + 0.0752
323 0.0005 + 0.0752
324 0.0005 + 0.0752
325 0.0005 + 0.0748
326 0.0005 + 0.0726
327 0.0005 + 0.0742
328 0.0005 + 0.0741
329 0.0005 + 0.0751
330 0.0005 + 0.0735
331 0.0005 + 0.0740
332 0.0005 + 0.0740
333 0.0005 + 0.0758
334 0.0005 + 0.0767
335 0.0005 + 0.0771
336 0.0005 + 0.0750
337 0.0005 + 0.0749
338 0.0005 + 0.0751
339 0.0005 + 0.0728
340 0.0004 + 0.0741
341 0.0004 + 0.0884
342 0.0004 + 0.0737
343 0.0004 + 0.0728
344 0.0004 + 0.0742
345 0.0004 + 0.0744
346 0.0004 + 0.0730
347 0.0004 + 0.0738
348 0.0004 + 0.0745
349 0.0004 + 0.0749
350 0.0004 + 0.0730
351 0.0004 + 0.0730
352 0.0004 + 0.0748
353 0.0004 + 0.0738
354 0.0004 + 0.0728
355 0.0004 + 0.0796
356 0.0004 + 0.0765
357 0.0004 + 0.0752
358 0.0004 + 0.0748
359 0.0004 + 0.0746
360 0.0004 + 0.0730
361 0.0004 + 0.0752
362 0.0004 + 0.0744
363 0.0004 + 0.0743
364 0.0004 + 0.0739
365 0.0004 + 0.0741
366 0.0004 + 0.0733
367 0.0004 + 0.0745
368 0.0004 + 0.0726
369 0.0004 + 0.0758
370 0.0004 + 0.0764
371 0.0004 + 0.0764
372 0.0004 + 0.0776
373 0.0004 + 0.0782
374 0.0004 + 0.0784
375 0.0004 + 0.0774
376 0.0004 + 0.0792
377 0.0004 + 0.0780
378 0.0004 + 0.0796
379 0.0004 + 0.0780
380 0.0004 + 0.0788
381 0.0004 + 0.0773
382 0.0004 + 0.0762
383 0.0004 + 0.0771
384 0.0004 + 0.0780
385 0.0004 + 0.0768
386 0.0004 + 0.0764
387 0.0004 + 0.0744
388 0.0004 + 0.0751
389 0.0004 + 0.0753
390 0.0004 + 0.0743
391 0.0004 + 0.0748
392 0.0004 + 0.0744
393 0.0004 + 0.0751
394 0.0004 + 0.0764
395 0.0004 + 0.0778
396 0.0004 + 0.0784
397 0.0004 + 0.0751
398 0.0004 + 0.0761
399 0.0004 + 0.0742
400 0.0004 + 0.0754
401 0.0004 + 0.0754
402 0.0004 + 0.0763
403 0.0004 + 0.0752
404 0.0004 + 0.0730
405 0.0004 + 0.0737
406 0.0004 + 0.0737
407 0.0004 + 0.0735
408 0.0004 + 0.0756
409 0.0004 + 0.0732
410 0.0004 + 0.0731
411 0.0004 + 0.0746
412 0.0004 + 0.0748
413 0.0004 + 0.0751
414 0.0004 + 0.0745
415 0.0004 + 0.0755
416 0.0004 + 0.0737
417 0.0004 + 0.0749
418 0.0004 + 0.0739
419 0.0004 + 0.0736
420 0.0004 + 0.0757
421 0.0004 + 0.0754
422 0.0004 + 0.0770
423 0.0004 + 0.0756
424 0.0004 + 0.0753
425 0.0004 + 0.0749
426 0.0004 + 0.0742
427 0.0004 + 0.0726
428 0.0004 + 0.0731
429 0.0004 + 0.0738
430 0.0004 + 0.0745
431 0.0004 + 0.0741
432 0.0004 + 0.0757
433 0.0004 + 0.0750
434 0.0004 + 0.0760
435 0.0004 + 0.0774
436 0.0004 + 0.0770
437 0.0004 + 0.0759
438 0.0004 + 0.0745
439 0.0004 + 0.0765
440 0.0004 + 0.0758
441 0.0004 + 0.0748
442 0.0004 + 0.0744
443 0.0004 + 0.0751
444 0.0004 + 0.0775
445 0.0004 + 0.0767
446 0.0004 + 0.0767
447 0.0004 + 0.0763
448 0.0004 + 0.0763
449 0.0004 + 0.0740
450 0.0004 + 0.0767
451 0.0004 + 0.0760
452 0.0003 + 0.0783
453 0.0003 + 0.0775
454 0.0003 + 0.0783
455 0.0003 + 0.0765
456 0.0003 + 0.0765
457 0.0003 + 0.0765
458 0.0003 + 0.0760
459 0.0003 + 0.0771
460 0.0003 + 0.0739
461 0.0003 + 0.0744
462 0.0003 + 0.0844
463 0.0003 + 0.0751
464 0.0003 + 0.0747
465 0.0003 + 0.0741
466 0.0003 + 0.0756
467 0.0003 + 0.0758
468 0.0003 + 0.0753
469 0.0003 + 0.0737
470 0.0003 + 0.0765
471 0.0003 + 0.0750
472 0.0003 + 0.0756
473 0.0003 + 0.0738
474 0.0003 + 0.0766
475 0.0003 + 0.0733
476 0.0003 + 0.0738
477 0.0003 + 0.0734
478 0.0003 + 0.0733
479 0.0003 + 0.0753
480 0.0003 + 0.0753
481 0.0003 + 0.0749
482 0.0003 + 0.0744
483 0.0003 + 0.0739
484 0.0003 + 0.0759
485 0.0003 + 0.0743
486 0.0003 + 0.0737
487 0.0003 + 0.0738
488 0.0003 + 0.0767
489 0.0003 0.0767
490 0.0004 0.0783
491 0.0006 0.0740
492 0.0017 0.0746
493 0.0023 0.0737
494 0.0011 0.0728
495 0.0020 0.0743
496 0.0005 0.0735
497 0.0017 0.0742
498 0.0009 0.0748
499 0.0013 0.0723
500 0.0006 0.0725
501 0.0011 0.0743
502 0.0007 0.0743
503 0.0009 0.0730
504 0.0006 0.0743
505 0.0008 0.0725
506 0.0007 0.0749
507 0.0007 0.0733
508 0.0005 0.0726
509 0.0007 0.0751
510 0.0006 0.0753
511 0.0006 0.0736
512 0.0005 0.0744
513 0.0005 0.0761
514 0.0005 0.0758
515 0.0005 0.0746
516 0.0005 0.0738
517 0.0005 0.0750
518 0.0005 0.0757
519 0.0005 0.0741
520 0.0005 0.0724
521 0.0004 0.0726
522 0.0005 0.0722
523 0.0004 0.0730
524 0.0005 0.0740
525 0.0004 0.0744
526 0.0004 0.0745
527 0.0004 0.0729
528 0.0004 0.0746
529 0.0004 0.0734
530 0.0004 0.0738
531 0.0004 0.0738
532 0.0004 0.0742
533 0.0004 0.0738
534 0.0004 0.0729
535 0.0004 0.0740
536 0.0004 0.0744
537 0.0004 0.0754
538 0.0004 0.0747
539 0.0004 0.0737
540 0.0004 0.0744
541 0.0004 0.0754
542 0.0004 0.0746
543 0.0004 0.0734
544 0.0004 0.0753
545 0.0004 0.0727
546 0.0004 0.0737
547 0.0004 0.0728
548 0.0004 0.0719
549 0.0004 0.0721
550 0.0004 0.0727
551 0.0004 0.0742
552 0.0004 0.0722
553 0.0004 0.0733
554 0.0004 0.0737
555 0.0004 0.0737
556 0.0004 0.0746
557 0.0004 0.0751
558 0.0004 0.0744
559 0.0004 0.0741
560 0.0004 0.0728
561 0.0004 0.0719
562 0.0003 0.0719
563 0.0003 0.0716
564 0.0003 0.0743
565 0.0003 0.0735
566 0.0003 0.0737
567 0.0003 0.0750
568 0.0003 0.0772
569 0.0003 0.0742
570 0.0003 0.0744
571 0.0003 0.0739
572 0.0003 0.0727
573 0.0003 0.0740
574 0.0003 0.0731
575 0.0003 0.0719
576 0.0003 0.0725
577 0.0003 0.0745
578 0.0003 0.0755
579 0.0003 0.0734
580 0.0003 0.0732
581 0.0003 0.0734
582 0.0003 0.0741
583 0.0003 0.0761
584 0.0003 0.0764
585 0.0003 0.0757
586 0.0003 0.0738
587 0.0003 0.0726
588 0.0003 0.0735
589 0.0003 0.0773
590 0.0003 0.0758
591 0.0003 0.0766
592 0.0003 0.0767
593 0.0003 0.0748
594 0.0003 0.0742
595 0.0003 0.0743
596 0.0003 0.0726
597 0.0003 0.0738
598 0.0003 0.0729
599 0.0003 0.0730
600 0.0003 0.0783
601 0.0003 0.0743
602 0.0003 0.0732
603 0.0003 0.0731
604 0.0003 0.0752
605 0.0003 0.0725
606 0.0003 0.0744
607 0.0003 0.0740
608 0.0003 0.0736
609 0.0003 0.0726
610 0.0003 0.0723
611 0.0003 0.0743
612 0.0003 0.0746
613 0.0003 0.0716
614 0.0003 0.0724
615 0.0003 0.0730
616 0.0003 0.0753
617 0.0003 0.0751
618 0.0003 0.0742
619 0.0003 0.0731
620 0.0003 0.0749
621 0.0003 0.0763
622 0.0003 0.0754
623 0.0003 0.0748
624 0.0003 0.0733
625 0.0003 0.0730
626 0.0003 0.0728
627 0.0003 0.0738
628 0.0003 0.0739
629 0.0003 0.0733
630 0.0003 0.0747
631 0.0003 0.0743
632 0.0003 0.0719
633 0.0003 + 0.0731
634 0.0003 + 0.0766
635 0.0003 + 0.0748
636 0.0003 + 0.0741
637 0.0003 + 0.0742
638 0.0003 + 0.0729
639 0.0003 + 0.0739
640 0.0003 + 0.0718
641 0.0003 + 0.0732
642 0.0003 + 0.0759
643 0.0003 + 0.0733
644 0.0003 + 0.0774
645 0.0003 + 0.0768
646 0.0003 + 0.0759
647 0.0003 + 0.0768
648 0.0003 + 0.0759
649 0.0003 + 0.0759
650 0.0003 + 0.0748
651 0.0003 + 0.0737
652 0.0003 + 0.0753
653 0.0003 + 0.0744
654 0.0003 + 0.0735
655 0.0003 + 0.0761
656 0.0003 + 0.0742
657 0.0003 + 0.0738
658 0.0003 + 0.0740
659 0.0003 + 0.0751
660 0.0003 + 0.0741
661 0.0003 + 0.0735
662 0.0003 + 0.0726
663 0.0003 + 0.0732
664 0.0003 + 0.0749
665 0.0003 + 0.0739
666 0.0003 + 0.0757
667 0.0003 + 0.0752
668 0.0003 + 0.0743
669 0.0003 + 0.0733
670 0.0003 + 0.0728
671 0.0003 + 0.0745
672 0.0003 + 0.0744
673 0.0003 + 0.0731
674 0.0003 + 0.0746
675 0.0003 + 0.0742
676 0.0003 + 0.0751
677 0.0003 + 0.0748
678 0.0003 + 0.0746
679 0.0003 + 0.0739
680 0.0003 + 0.0752
681 0.0003 + 0.0740
682 0.0003 + 0.0750
683 0.0003 + 0.0757
684 0.0003 + 0.0734
685 0.0003 + 0.0747
686 0.0003 + 0.0740
687 0.0003 + 0.0765
688 0.0003 + 0.0750
689 0.0003 + 0.0736
690 0.0003 + 0.0724
691 0.0003 + 0.0740
692 0.0003 + 0.0741
693 0.0003 + 0.0737
694 0.0003 + 0.0736
695 0.0003 + 0.0767
696 0.0003 + 0.0733
697 0.0003 + 0.0755
698 0.0003 + 0.0747
699 0.0003 + 0.0752
700 0.0003 + 0.0751
701 0.0003 + 0.0757
702 0.0003 + 0.0733
703 0.0003 + 0.0739
704 0.0003 + 0.0744
705 0.0003 + 0.0742
706 0.0003 + 0.0771
707 0.0003 + 0.0749
708 0.0003 + 0.0777
709 0.0003 + 0.0765
710 0.0003 + 0.0738
711 0.0003 + 0.0738
712 0.0003 + 0.0750
713 0.0003 + 0.0767
714 0.0003 + 0.0761
715 0.0003 + 0.0746
716 0.0003 + 0.0763
717 0.0003 + 0.0752
718 0.0003 + 0.0764
719 0.0003 + 0.0780
720 0.0003 + 0.0763
721 0.0003 + 0.0755
722 0.0003 + 0.0758
723 0.0003 + 0.0748
724 0.0003 + 0.0745
725 0.0003 + 0.0746
726 0.0003 + 0.0748
727 0.0003 + 0.0766
728 0.0003 + 0.0742
729 0.0003 + 0.0769
730 0.0003 + 0.0774
731 0.0003 + 0.0764
732 0.0003 + 0.0740
733 0.0003 + 0.0748
734 0.0003 + 0.0762
735 0.0003 + 0.0760
736 0.0003 + 0.0742
737 0.0003 + 0.0748
738 0.0003 + 0.0738
739 0.0003 + 0.0760
740 0.0003 + 0.0769
741 0.0003 + 0.0751
742 0.0003 + 0.0743
743 0.0003 + 0.0760
744 0.0003 + 0.0740
745 0.0003 + 0.0751
746 0.0003 + 0.0750
747 0.0003 + 0.0748
748 0.0003 + 0.0760
749 0.0003 + 0.0743
750 0.0003 + 0.0741
751 0.0003 + 0.0738
752 0.0003 + 0.0750
753 0.0003 + 0.0745
754 0.0003 + 0.0741
755 0.0003 + 0.0756
756 0.0003 + 0.0745
757 0.0003 + 0.0752
758 0.0002 + 0.0742
759 0.0002 + 0.0736
760 0.0003 0.0746
761 0.0003 0.0755
762 0.0007 0.0753
763 0.0007 0.0738
764 0.0006 0.0729
765 0.0004 0.0752
766 0.0005 0.0742
767 0.0005 0.0734
768 0.0004 0.0728
769 0.0005 0.0722
770 0.0004 0.0721
771 0.0003 0.0747
772 0.0004 0.0749
773 0.0003 0.0770
774 0.0004 0.0755
775 0.0004 0.0745
776 0.0003 0.0740
777 0.0004 0.0730
778 0.0003 0.0722
779 0.0003 0.0740
780 0.0003 0.0725
781 0.0003 0.0741
782 0.0003 0.0730
783 0.0003 0.0739
784 0.0003 0.0735
785 0.0003 0.0734
786 0.0003 0.0744
787 0.0003 0.0748
788 0.0003 0.0755
789 0.0003 0.0744
790 0.0003 0.0736
791 0.0003 0.0745
792 0.0003 0.0735
793 0.0003 0.0737
794 0.0003 0.0728
795 0.0003 0.0730
796 0.0003 0.0731
797 0.0002 0.0745
798 0.0002 + 0.0731
799 0.0003 0.0740
800 0.0003 0.0773
801 0.0002 0.0751
802 0.0002 + 0.0739
803 0.0002 0.0727
804 0.0002 0.0727
805 0.0002 0.0740
806 0.0002 0.0740
807 0.0002 + 0.0757
808 0.0002 0.0750
809 0.0002 0.0771
810 0.0002 0.0744
811 0.0002 0.0738
812 0.0002 0.0738
813 0.0002 0.0742
814 0.0003 0.0739
815 0.0002 0.0734
816 0.0002 0.0720
817 0.0002 0.0744
818 0.0002 + 0.0749
819 0.0002 + 0.0749
820 0.0002 + 0.0751
821 0.0002 + 0.0733
822 0.0002 0.0748
823 0.0002 0.0733
824 0.0002 0.0761
825 0.0002 0.0776
826 0.0002 0.0764
827 0.0003 0.0750
828 0.0003 0.0748
829 0.0003 0.0763
830 0.0002 0.0768
831 0.0002 + 0.0747
832 0.0002 + 0.0760
833 0.0002 0.0774
834 0.0002 0.0727
835 0.0002 0.0740
836 0.0002 0.0737
837 0.0002 + 0.0738
838 0.0002 0.0758
839 0.0002 0.0740
840 0.0002 0.0727
841 0.0002 0.0739
842 0.0002 0.0721
843 0.0002 + 0.0736
844 0.0002 + 0.0745
845 0.0002 0.0783
846 0.0002 0.0726
847 0.0002 0.0755
848 0.0002 0.0741
849 0.0002 + 0.0783
850 0.0002 + 0.0780
851 0.0002 + 0.0762
852 0.0002 + 0.0738
853 0.0002 0.0735
854 0.0002 0.0747
855 0.0002 0.0734
856 0.0002 + 0.0741
857 0.0002 + 0.0754
858 0.0002 + 0.0729
859 0.0002 + 0.0738
860 0.0002 + 0.0755
861 0.0002 + 0.0758
862 0.0002 + 0.0750
863 0.0002 + 0.0733
864 0.0002 + 0.0732
865 0.0002 + 0.0740
866 0.0002 + 0.0735
867 0.0002 + 0.0747
868 0.0002 + 0.0777
869 0.0002 + 0.0762
870 0.0002 + 0.0774
871 0.0002 + 0.0750
872 0.0002 + 0.0776
873 0.0002 + 0.0744
874 0.0002 + 0.0732
875 0.0001 + 0.0758
876 0.0001 + 0.0757
877 0.0001 + 0.0746
878 0.0001 + 0.0750
879 0.0001 + 0.0738
880 0.0001 + 0.0740
881 0.0001 0.0741
882 0.0002 0.0742
883 0.0002 0.0748
884 0.0002 0.0729
885 0.0002 0.0727
886 0.0003 0.0729
887 0.0004 0.0747
888 0.0003 0.0749
889 0.0003 0.0734
890 0.0002 0.0739
891 0.0002 0.0749
892 0.0002 0.0738
893 0.0003 0.0746
894 0.0003 0.0748
895 0.0003 0.0761
896 0.0002 0.0751
897 0.0003 0.0729
898 0.0003 0.0717
899 0.0002 0.0728
900 0.0002 0.0733
901 0.0002 0.0724
902 0.0002 0.0728
903 0.0002 0.0752
904 0.0002 0.0720
905 0.0002 0.0732
906 0.0002 0.0737
907 0.0002 0.0740
908 0.0002 0.0745
909 0.0002 0.0740
910 0.0002 0.0740
911 0.0002 0.0751
912 0.0002 0.0733
913 0.0002 0.0739
914 0.0002 0.0736
915 0.0002 0.0730
916 0.0002 0.0726
917 0.0002 0.0730
918 0.0002 0.0748
919 0.0001 0.0754
920 0.0001 0.0783
921 0.0002 0.0750
922 0.0001 0.0760
923 0.0001 + 0.0750
924 0.0001 + 0.0755
925 0.0001 + 0.0767
926 0.0001 0.0768
927 0.0001 0.0753
928 0.0001 0.0764
929 0.0001 0.0761
930 0.0001 0.0756
931 0.0001 0.0762
932 0.0001 0.0761
933 0.0002 0.0758
934 0.0002 0.0748
935 0.0002 0.0747
936 0.0002 0.0750
937 0.0002 0.0746
938 0.0002 0.0745
939 0.0002 0.0753
940 0.0001 0.0741
941 0.0001 + 0.0732
942 0.0001 + 0.0766
943 0.0001 0.0757
944 0.0001 0.0762
945 0.0002 0.0762
946 0.0002 0.0767
947 0.0002 0.0954
948 0.0002 0.0738
949 0.0002 0.0758
950 0.0001 0.0777
951 0.0001 0.0757
952 0.0001 + 0.0770
953 0.0001 0.0791
954 0.0001 0.0751
955 0.0001 0.0759
956 0.0001 0.0773
957 0.0002 0.0766
958 0.0001 0.0775
959 0.0001 0.0765
960 0.0001 0.0738
961 0.0001 + 0.0763
962 0.0001 0.0769
963 0.0001 0.0755
964 0.0001 0.0775
965 0.0001 0.0735
966 0.0001 0.0753
967 0.0001 0.0757
968 0.0002 0.0754
969 0.0001 0.0781
970 0.0001 0.0763
971 0.0001 0.0771
972 0.0001 0.0764
973 0.0001 0.0744
974 0.0001 + 0.0747
975 0.0001 + 0.0757
976 0.0001 + 0.0763
977 0.0001 0.0764
978 0.0001 0.0746
979 0.0001 0.0767
980 0.0001 0.0780
981 0.0001 0.0787
982 0.0001 0.0764
983 0.0001 0.0766
984 0.0001 0.0778
985 0.0002 0.0777
986 0.0002 0.0754
987 0.0002 0.0765
988 0.0002 0.0767
989 0.0001 0.0747
990 0.0001 0.0740
991 0.0001 + 0.0753
992 0.0001 + 0.0771
993 0.0001 0.0799
994 0.0001 0.0774
995 0.0001 0.0766
996 0.0001 0.0783
997 0.0001 0.0747
998 0.0001 0.0744
999 0.0001 + 0.0769
1000 0.0001 0.0785
Loading the best network from the last checkpoint.
[4]:
<skorch.callbacks.training.Checkpoint at 0x7fbfdb73c290>
Plot training evolution
[5]:
if config['plot']['training progress']:
history_length = len(tsp.ttr.regressor_['regressor'].history)
train_loss = np.zeros((history_length, 1))
for epoch in tsp.ttr.regressor_['regressor'].history:
epoch_number = epoch['epoch']-1
train_loss[epoch_number] = epoch['train_loss']
_, axes_one = plt.subplots(figsize=(20, 20))
axes_one.plot(train_loss, 'o-', label='training')
axes_one.set_xlabel('Epoch')
axes_one.set_ylabel('MSE')
plt.legend()
Prediction on training data
[6]:
if config['predict on training data enabled']:
dataloader = tsp.ttr.regressor['regressor'].get_iterator(tsp.dataset)
x, y = dataloader.dataset[:]
netout = tsp.predict(x)
d_output = netout.shape[-1]
if config['plot']['prediction on training data']:
fig, axs = plt.subplots(d_output, 1, figsize=(20,20))
axs = [axs]
idx_range = [np.random.randint(0, len(tsp.dataset))]
for idx in idx_range:
if config['plot']['prediction on training data']:
x_absciss = convert_year_month_array_to_datetime(x[idx, :, :])
for idx_output_var in range(d_output):
# Select real passengers data
y_true = y[idx, :, idx_output_var]
y_pred = netout[idx, :, idx_output_var]
if config['plot']['prediction on training data']:
ax = axs[idx_output_var]
ax.plot(x_absciss, y_true, label="Truth", color='tab:blue')
ax.plot(x_absciss, y_pred, label="Prediction", color='tab:orange')
if idx == idx_range[0]:
ax.set_title(tsp.dataset.labels['y'][idx_output_var] + ' over time')
ax.set_xlabel('date')
ax.set_ylabel(tsp.dataset.labels['y'][idx_output_var])
ax.legend()
if config['plot']['prediction on training data']:
plt.show()
Future forecast
[7]:
# def make_x_pred(x, n_months_ahead, include_history=True):
# def raw_add_months(sourcedate, n_months):
# month = sourcedate.month - 1 + n_months
# year = sourcedate.year + month // 12
# month = month % 12 + 1
# day = min(sourcedate.day, calendar.monthrange(year,month)[1])
# return datetime(year, month, day)
# def add_months(months, n_months):
# return [raw_add_months(month, n_months) for month in months]
# def add_months_2(month, n_months):
# return [raw_add_months(month, n_month) for n_month in n_months]
# x_dates = convert_year_month_array_to_datetime(x.squeeze())
# last_month = x_dates[-1:][0]
# next_n_months = add_months_2(last_month, np.arange(n_months_ahead)+1)
# if include_history:
# n_months = x_dates + next_n_months
# else:
# n_months = next_n_months
# x_pred = np.array([[dt.month, dt.year] for dt in n_months]).reshape(1, -1, 2).astype(np.float32)
# return x_pred
if config['forecast enabled']:
dataloader = tsp.ttr.regressor['regressor'].get_iterator(tsp.dataset)
x, y = dataloader.dataset[:]
# x_pred = make_x_pred(x, config['forecast']['months ahead'])
# netout = tsp.predict(x_pred)
netout, x_pred = tsp.forecast(config['forecast']['months ahead'], include_history = config['forecast']['include history'])
d_output = netout.shape[-1]
if config['plot']['forecast']:
fig, axs = plt.subplots(d_output, 1, figsize=(20,20))
axs = [axs]
idx_range = [np.random.randint(0, len(tsp.dataset))]
for idx in idx_range:
if config['plot']['forecast']:
x_absciss = convert_year_month_array_to_datetime(x[idx, :, :])
x_pred_absciss = convert_year_month_array_to_datetime(x_pred[idx, :, :])
for idx_output_var in range(d_output):
# Select real passengers data
y_true = y[idx, :, idx_output_var]
y_pred = netout[idx, :, idx_output_var]
if config['plot']['prediction on validation data']:
ax = axs[idx_output_var]
ax.plot(x_pred_absciss, y_pred, label="Prediction", color='tab:orange')
ax.plot(x_absciss, y_true, label="Truth", color='tab:blue')
if idx == idx_range[0]:
ax.set_title(tsp.dataset.labels['y'][idx_output_var] + ' over time')
ax.set_xlabel('date')
ax.set_ylabel(tsp.dataset.labels['y'][idx_output_var])
ax.legend()
if config['plot']['prediction on validation data']:
plt.show()
[ ]: