Flights passengers example
Imports
[1]:
import torch
import numpy as np
import calendar
from datetime import timedelta, datetime
from matplotlib import pyplot as plt
from time_series_predictor import TimeSeriesPredictor
from time_series_models import BenchmarkLSTM
from flights_time_series_dataset import FlightsDataset, convert_year_month_array_to_datetime
/home/docs/checkouts/readthedocs.org/user_builds/time-series-predictor/envs/latest/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Config
[2]:
plot_config = {}
plot_config['training progress'] = True
plot_config['prediction on training data'] = True
plot_config['prediction on validation data'] = True
plot_config['forecast'] = True
forecast_config = {}
forecast_config['include history'] = True
forecast_config['months ahead'] = 24
predictor_config = {}
predictor_config['epochs'] = 1000
predictor_config['learning rate'] = 1e-2
predictor_config['hidden dim'] = 100
predictor_config['layers num'] = 3
config = {}
config['plot'] = plot_config
config['forecast'] = forecast_config
config['predictor'] = predictor_config
config['predict on training data enabled'] = True
config['forecast enabled'] = True
Time Series Predictor instantiation
[3]:
tsp = TimeSeriesPredictor(
BenchmarkLSTM(
hidden_dim=config['predictor']['hidden dim'],
num_layers=config['predictor']['layers num']
),
lr=config['predictor']['learning rate'],
max_epochs=predictor_config['epochs'],
train_split=None,
optimizer=torch.optim.Adam
)
Training process
[4]:
ds = FlightsDataset()
tsp.fit(ds)
Using device cpu
Re-initializing module because the following parameters were re-set: module__input_dim, module__output_dim.
Re-initializing criterion.
Re-initializing optimizer.
epoch train_loss cp dur
------- ------------ ---- ------
1 0.2249 + 0.1329
2 0.0488 + 0.1030
3 2.3266 0.1009
4 0.1018 0.0978
5 0.1441 0.0986
6 0.1413 0.1018
7 0.1307 0.1051
8 0.1174 0.1047
9 0.1005 0.1028
10 0.0778 0.0970
11 0.0520 0.1031
12 0.0604 0.1017
13 0.0643 0.0947
14 0.0490 0.0985
15 0.0463 + 0.0981
16 0.0500 0.0967
17 0.0528 0.0953
18 0.0528 0.0958
19 0.0501 0.0992
20 0.0458 + 0.0963
21 0.0415 + 0.0953
22 0.0388 + 0.0954
23 0.0393 0.0924
24 0.0397 0.0912
25 0.0290 + 0.0930
26 0.0246 + 0.0937
27 0.0211 + 0.0944
28 0.0139 + 0.0925
29 0.0140 0.0931
30 0.0133 + 0.0918
31 0.0136 0.0935
32 0.0162 0.0921
33 0.0147 0.0945
34 0.0148 0.0928
35 0.0171 0.0934
36 0.0159 0.0933
37 0.0155 0.0946
38 0.0167 0.0921
39 0.0160 0.0918
40 0.0145 0.0930
41 0.0148 0.0928
42 0.0145 0.0921
43 0.0134 0.0938
44 0.0130 + 0.0924
45 0.0132 0.0936
46 0.0122 + 0.0923
47 0.0115 + 0.0927
48 0.0116 0.0953
49 0.0112 + 0.0937
50 0.0107 + 0.0926
51 0.0110 0.0911
52 0.0110 0.0922
53 0.0105 + 0.0933
54 0.0104 + 0.0920
55 0.0104 + 0.0927
56 0.0098 + 0.0937
57 0.0098 + 0.0933
58 0.0093 + 0.0929
59 0.0089 + 0.0921
60 0.0086 + 0.0928
61 0.0082 + 0.0925
62 0.0081 + 0.0920
63 0.0078 + 0.0943
64 0.0078 0.0919
65 0.0075 + 0.0921
66 0.0075 + 0.0928
67 0.0073 + 0.0929
68 0.0074 0.0925
69 0.0072 + 0.0931
70 0.0072 0.0925
71 0.0071 + 0.0937
72 0.0069 + 0.0936
73 0.0069 + 0.0943
74 0.0066 + 0.0930
75 0.0065 + 0.0932
76 0.0063 + 0.0915
77 0.0062 + 0.0924
78 0.0060 + 0.0928
79 0.0058 + 0.0917
80 0.0056 + 0.0933
81 0.0053 + 0.0934
82 0.0050 + 0.0933
83 0.0045 + 0.0923
84 0.0040 + 0.0927
85 0.0035 + 0.0933
86 0.0035 + 0.0939
87 0.0041 0.0946
88 0.0039 0.0937
89 0.0035 + 0.0924
90 0.0038 0.0924
91 0.0033 + 0.0927
92 0.0033 + 0.0929
93 0.0033 0.0929
94 0.0030 + 0.0913
95 0.0031 0.0943
96 0.0028 + 0.0938
97 0.0029 0.0934
98 0.0028 + 0.0925
99 0.0029 0.0928
100 0.0028 0.0929
101 0.0028 0.0935
102 0.0027 + 0.0946
103 0.0026 + 0.0925
104 0.0026 + 0.0931
105 0.0025 + 0.0935
106 0.0026 0.0947
107 0.0025 + 0.0954
108 0.0025 0.0939
109 0.0024 + 0.0925
110 0.0024 + 0.0935
111 0.0024 + 0.0929
112 0.0023 + 0.0940
113 0.0023 + 0.0921
114 0.0023 + 0.0927
115 0.0023 0.0934
116 0.0022 + 0.0934
117 0.0022 + 0.0927
118 0.0022 + 0.0929
119 0.0021 + 0.0929
120 0.0021 + 0.0935
121 0.0021 + 0.0938
122 0.0021 + 0.0938
123 0.0020 + 0.0931
124 0.0020 + 0.0932
125 0.0020 + 0.0933
126 0.0020 + 0.0946
127 0.0019 + 0.0927
128 0.0019 + 0.0925
129 0.0019 + 0.0925
130 0.0018 + 0.0943
131 0.0018 + 0.0944
132 0.0018 + 0.0931
133 0.0018 + 0.0920
134 0.0017 + 0.0935
135 0.0017 + 0.0924
136 0.0017 + 0.0925
137 0.0017 + 0.0911
138 0.0016 + 0.0913
139 0.0016 + 0.0930
140 0.0016 + 0.0932
141 0.0015 + 0.0928
142 0.0015 + 0.0925
143 0.0015 + 0.0927
144 0.0015 + 0.0933
145 0.0014 + 0.0947
146 0.0014 + 0.0933
147 0.0014 + 0.0931
148 0.0014 0.0926
149 0.0016 0.0913
150 0.0020 0.0925
151 0.0035 0.0921
152 0.0023 0.0926
153 0.0015 0.0943
154 0.0014 0.0924
155 0.0018 0.0921
156 0.0017 0.0928
157 0.0012 + 0.0924
158 0.0018 0.0933
159 0.0020 0.0929
160 0.0012 + 0.0940
161 0.0020 0.0914
162 0.0019 0.0916
163 0.0013 0.0938
164 0.0022 0.0927
165 0.0014 0.0948
166 0.0016 0.0920
167 0.0016 0.0927
168 0.0011 + 0.0919
169 0.0015 0.0928
170 0.0011 + 0.0915
171 0.0015 0.0960
172 0.0013 0.0914
173 0.0012 0.0926
174 0.0013 0.0935
175 0.0011 + 0.0930
176 0.0013 0.0922
177 0.0011 + 0.0911
178 0.0012 0.0913
179 0.0011 0.0916
180 0.0011 0.0924
181 0.0011 0.0935
182 0.0010 + 0.0917
183 0.0011 0.0935
184 0.0010 + 0.0918
185 0.0011 0.0945
186 0.0010 + 0.0939
187 0.0010 0.0940
188 0.0010 0.0918
189 0.0010 + 0.0927
190 0.0010 0.0930
191 0.0010 + 0.0935
192 0.0010 0.0925
193 0.0010 + 0.0918
194 0.0010 0.0924
195 0.0010 0.0938
196 0.0009 + 0.0918
197 0.0010 0.0920
198 0.0009 + 0.0915
199 0.0009 0.0932
200 0.0009 0.0920
201 0.0009 + 0.0938
202 0.0009 0.0921
203 0.0009 + 0.0924
204 0.0009 0.0925
205 0.0009 + 0.0948
206 0.0009 + 0.0949
207 0.0009 0.0953
208 0.0009 + 0.0932
209 0.0009 + 0.0925
210 0.0009 0.0928
211 0.0009 + 0.0937
212 0.0009 + 0.0933
213 0.0009 + 0.0931
214 0.0008 + 0.0929
215 0.0008 + 0.0916
216 0.0008 + 0.0921
217 0.0008 + 0.0930
218 0.0008 + 0.0944
219 0.0008 + 0.0929
220 0.0008 + 0.0943
221 0.0008 + 0.0940
222 0.0008 + 0.0939
223 0.0008 + 0.0929
224 0.0008 + 0.0950
225 0.0008 + 0.0929
226 0.0008 + 0.0929
227 0.0008 + 0.0920
228 0.0008 + 0.0943
229 0.0008 + 0.0938
230 0.0008 + 0.0934
231 0.0008 + 0.0935
232 0.0008 + 0.0939
233 0.0008 + 0.0937
234 0.0008 + 0.0924
235 0.0008 + 0.0924
236 0.0008 + 0.0943
237 0.0008 + 0.0927
238 0.0008 + 0.0942
239 0.0008 + 0.0938
240 0.0008 + 0.0937
241 0.0008 + 0.0934
242 0.0007 + 0.0952
243 0.0007 + 0.0943
244 0.0007 + 0.0945
245 0.0007 + 0.0925
246 0.0007 + 0.0932
247 0.0007 + 0.0924
248 0.0007 + 0.0938
249 0.0007 + 0.0937
250 0.0007 0.0950
251 0.0007 0.0980
252 0.0008 0.0919
253 0.0009 0.0923
254 0.0015 0.0914
255 0.0020 0.0926
256 0.0029 0.0928
257 0.0009 0.0928
258 0.0020 0.0932
259 0.0028 0.0932
260 0.0010 0.0941
261 0.0034 0.0935
262 0.0020 0.0946
263 0.0020 0.0933
264 0.0020 0.0944
265 0.0011 0.0939
266 0.0018 0.0928
267 0.0011 0.0923
268 0.0017 0.0943
269 0.0009 0.0930
270 0.0014 0.0934
271 0.0009 0.0933
272 0.0013 0.0936
273 0.0009 0.0929
274 0.0012 0.0934
275 0.0009 0.0923
276 0.0010 0.0942
277 0.0009 0.0921
278 0.0010 0.0944
279 0.0009 0.0945
280 0.0009 0.0938
281 0.0009 0.0931
282 0.0008 0.0929
283 0.0009 0.0943
284 0.0008 0.0936
285 0.0009 0.0940
286 0.0008 0.0928
287 0.0008 0.0938
288 0.0007 0.0925
289 0.0008 0.0940
290 0.0007 + 0.0939
291 0.0008 0.0944
292 0.0007 + 0.0939
293 0.0008 0.0938
294 0.0007 + 0.0923
295 0.0007 0.0933
296 0.0007 + 0.0927
297 0.0007 0.0931
298 0.0007 + 0.0920
299 0.0007 0.0943
300 0.0007 + 0.0927
301 0.0007 0.0945
302 0.0007 + 0.0930
303 0.0007 0.0939
304 0.0007 + 0.0919
305 0.0007 0.0940
306 0.0007 + 0.0933
307 0.0007 0.0924
308 0.0006 + 0.0930
309 0.0007 0.0932
310 0.0006 + 0.0921
311 0.0006 0.0930
312 0.0006 + 0.0937
313 0.0006 0.0926
314 0.0006 + 0.0922
315 0.0006 0.0925
316 0.0006 + 0.0932
317 0.0006 + 0.0946
318 0.0006 + 0.0951
319 0.0006 + 0.0930
320 0.0006 + 0.0944
321 0.0006 + 0.0934
322 0.0006 + 0.0947
323 0.0006 + 0.0927
324 0.0006 0.0943
325 0.0006 + 0.0916
326 0.0006 + 0.0981
327 0.0006 + 0.0925
328 0.0006 + 0.0928
329 0.0006 + 0.0923
330 0.0006 + 0.0939
331 0.0006 + 0.0926
332 0.0006 + 0.0920
333 0.0006 + 0.0920
334 0.0006 + 0.0909
335 0.0006 + 0.0909
336 0.0006 + 0.0916
337 0.0006 + 0.0922
338 0.0006 + 0.0928
339 0.0006 + 0.0914
340 0.0006 + 0.0932
341 0.0005 + 0.0917
342 0.0005 + 0.0933
343 0.0005 + 0.0911
344 0.0005 + 0.0908
345 0.0005 + 0.0911
346 0.0005 0.0936
347 0.0005 0.0928
348 0.0005 0.0944
349 0.0005 + 0.0931
350 0.0005 0.0914
351 0.0005 0.0913
352 0.0006 0.0912
353 0.0005 0.0925
354 0.0005 0.0920
355 0.0005 0.0924
356 0.0005 0.0925
357 0.0005 + 0.0914
358 0.0005 + 0.0926
359 0.0005 + 0.0935
360 0.0005 + 0.0952
361 0.0005 + 0.0939
362 0.0005 + 0.0944
363 0.0005 + 0.0925
364 0.0005 + 0.0926
365 0.0005 + 0.0926
366 0.0005 + 0.0918
367 0.0005 + 0.0929
368 0.0005 + 0.0922
369 0.0005 + 0.0918
370 0.0005 + 0.0928
371 0.0005 + 0.0929
372 0.0005 + 0.0926
373 0.0005 + 0.0933
374 0.0005 + 0.0944
375 0.0005 0.0934
376 0.0005 0.0922
377 0.0006 0.0918
378 0.0009 0.0931
379 0.0013 0.0936
380 0.0011 0.0931
381 0.0007 0.0924
382 0.0008 0.0929
383 0.0008 0.0908
384 0.0007 0.0949
385 0.0009 0.0919
386 0.0006 0.0920
387 0.0008 0.0921
388 0.0007 0.0911
389 0.0006 0.0912
390 0.0007 0.0908
391 0.0006 0.0909
392 0.0007 0.0915
393 0.0006 0.0911
394 0.0006 0.0917
395 0.0006 0.0936
396 0.0006 0.0918
397 0.0006 0.0931
398 0.0006 0.0925
399 0.0006 0.0942
400 0.0006 0.0939
401 0.0006 0.0930
402 0.0006 0.0925
403 0.0006 0.0920
404 0.0006 0.0930
405 0.0005 0.0906
406 0.0005 0.0935
407 0.0005 0.0919
408 0.0005 0.0923
409 0.0005 0.0939
410 0.0005 0.0935
411 0.0005 0.0917
412 0.0005 0.0921
413 0.0005 0.0923
414 0.0005 0.0923
415 0.0005 0.0929
416 0.0005 0.0943
417 0.0005 0.0919
418 0.0005 0.0941
419 0.0005 0.0924
420 0.0005 + 0.0944
421 0.0005 + 0.0945
422 0.0005 + 0.0944
423 0.0005 + 0.0917
424 0.0005 + 0.0925
425 0.0004 + 0.0917
426 0.0004 + 0.0931
427 0.0004 + 0.0924
428 0.0004 + 0.0932
429 0.0004 + 0.0934
430 0.0004 + 0.0922
431 0.0004 + 0.0917
432 0.0004 + 0.0920
433 0.0004 + 0.0920
434 0.0004 + 0.0942
435 0.0004 + 0.0926
436 0.0004 + 0.0948
437 0.0004 + 0.0934
438 0.0004 + 0.0947
439 0.0004 + 0.0939
440 0.0004 + 0.0938
441 0.0004 + 0.0933
442 0.0004 + 0.0928
443 0.0004 + 0.0948
444 0.0004 + 0.0932
445 0.0004 + 0.0927
446 0.0004 + 0.0917
447 0.0004 + 0.0923
448 0.0004 + 0.0920
449 0.0004 + 0.0918
450 0.0004 + 0.0925
451 0.0004 + 0.0923
452 0.0004 + 0.0948
453 0.0004 + 0.0931
454 0.0004 + 0.0936
455 0.0004 + 0.0933
456 0.0004 + 0.0941
457 0.0004 + 0.0937
458 0.0004 + 0.0936
459 0.0004 + 0.0920
460 0.0004 + 0.0930
461 0.0004 + 0.0942
462 0.0004 + 0.0927
463 0.0004 + 0.0947
464 0.0004 + 0.0939
465 0.0004 + 0.0929
466 0.0003 + 0.0922
467 0.0003 + 0.0924
468 0.0003 + 0.0921
469 0.0003 + 0.1064
470 0.0003 + 0.0947
471 0.0003 + 0.0981
472 0.0003 + 0.0936
473 0.0003 + 0.0952
474 0.0003 + 0.0927
475 0.0003 + 0.0931
476 0.0003 + 0.0925
477 0.0003 + 0.0920
478 0.0003 + 0.0938
479 0.0003 + 0.0926
480 0.0003 + 0.0990
481 0.0003 + 0.0941
482 0.0003 + 0.0933
483 0.0003 + 0.0929
484 0.0003 + 0.0926
485 0.0003 + 0.0915
486 0.0003 + 0.0923
487 0.0003 + 0.0918
488 0.0003 + 0.0931
489 0.0003 + 0.0934
490 0.0003 + 0.0941
491 0.0003 + 0.0915
492 0.0003 + 0.0928
493 0.0003 + 0.0931
494 0.0003 + 0.0913
495 0.0003 + 0.0939
496 0.0003 + 0.0936
497 0.0003 + 0.0929
498 0.0003 + 0.0938
499 0.0003 + 0.0920
500 0.0003 + 0.0927
501 0.0003 + 0.0928
502 0.0003 0.0937
503 0.0003 0.0927
504 0.0003 0.0948
505 0.0003 0.0923
506 0.0005 0.0934
507 0.0007 0.0924
508 0.0015 0.0938
509 0.0008 0.0944
510 0.0003 0.0935
511 0.0007 0.0940
512 0.0010 0.0922
513 0.0011 0.0916
514 0.0004 0.0921
515 0.0008 0.0943
516 0.0010 0.0915
517 0.0003 0.0932
518 0.0009 0.0926
519 0.0018 0.0941
520 0.0004 0.0925
521 0.0016 0.0929
522 0.0017 0.0915
523 0.0007 0.0918
524 0.0023 0.0912
525 0.0011 0.0932
526 0.0012 0.0910
527 0.0013 0.0932
528 0.0005 0.0933
529 0.0012 0.0923
530 0.0004 0.0937
531 0.0012 0.0922
532 0.0004 0.0922
533 0.0009 0.0912
534 0.0004 0.0920
535 0.0007 0.0898
536 0.0005 0.0932
537 0.0006 0.0911
538 0.0005 0.0923
539 0.0004 0.0943
540 0.0005 0.0922
541 0.0003 0.0918
542 0.0006 0.0914
543 0.0003 0.0912
544 0.0005 0.0919
545 0.0003 0.0913
546 0.0004 0.0931
547 0.0004 0.0922
548 0.0004 0.0921
549 0.0004 0.0937
550 0.0003 0.0942
551 0.0004 0.0930
552 0.0003 0.0940
553 0.0004 0.0929
554 0.0003 0.0929
555 0.0003 0.0931
556 0.0003 0.0931
557 0.0003 0.0929
558 0.0003 0.0918
559 0.0003 0.0926
560 0.0003 0.0925
561 0.0003 + 0.0938
562 0.0003 0.0919
563 0.0003 + 0.0915
564 0.0003 0.0919
565 0.0003 0.0936
566 0.0003 0.0930
567 0.0003 0.0950
568 0.0003 + 0.0932
569 0.0003 0.0933
570 0.0002 + 0.0924
571 0.0003 0.0924
572 0.0002 0.0931
573 0.0002 + 0.0928
574 0.0003 0.0925
575 0.0002 + 0.0908
576 0.0002 0.0924
577 0.0002 + 0.0932
578 0.0002 0.0929
579 0.0002 0.0925
580 0.0002 + 0.0927
581 0.0002 0.0924
582 0.0002 + 0.0926
583 0.0002 0.0916
584 0.0002 + 0.0925
585 0.0002 + 0.0924
586 0.0002 0.0931
587 0.0002 + 0.0907
588 0.0002 0.0939
589 0.0002 + 0.0912
590 0.0002 + 0.0924
591 0.0002 0.0924
592 0.0002 + 0.0934
593 0.0002 + 0.0935
594 0.0002 0.0923
595 0.0002 + 0.0928
596 0.0002 + 0.0938
597 0.0002 + 0.0927
598 0.0002 + 0.0932
599 0.0002 + 0.0930
600 0.0002 + 0.0926
601 0.0002 + 0.0926
602 0.0002 + 0.0929
603 0.0002 + 0.0929
604 0.0002 + 0.0942
605 0.0002 + 0.0931
606 0.0002 + 0.0935
607 0.0002 + 0.0939
608 0.0002 + 0.0949
609 0.0002 + 0.0935
610 0.0002 + 0.0922
611 0.0002 + 0.0927
612 0.0002 + 0.0926
613 0.0002 + 0.0927
614 0.0002 + 0.0928
615 0.0002 + 0.0922
616 0.0002 + 0.0935
617 0.0002 + 0.0923
618 0.0002 + 0.0915
619 0.0002 + 0.0917
620 0.0002 + 0.0917
621 0.0002 + 0.0931
622 0.0002 + 0.0915
623 0.0002 + 0.0920
624 0.0002 + 0.0933
625 0.0002 + 0.0932
626 0.0002 + 0.0931
627 0.0002 + 0.0929
628 0.0002 + 0.0918
629 0.0002 + 0.0917
630 0.0002 + 0.0919
631 0.0002 + 0.0920
632 0.0002 + 0.0921
633 0.0002 + 0.0931
634 0.0002 + 0.0928
635 0.0002 + 0.0924
636 0.0002 + 0.0922
637 0.0002 + 0.0921
638 0.0002 + 0.0932
639 0.0002 + 0.0918
640 0.0002 + 0.0930
641 0.0002 + 0.0940
642 0.0002 + 0.0930
643 0.0002 + 0.0936
644 0.0002 + 0.0933
645 0.0002 + 0.0930
646 0.0002 + 0.0931
647 0.0002 + 0.0920
648 0.0002 + 0.0925
649 0.0002 + 0.0919
650 0.0002 + 0.0918
651 0.0002 + 0.0928
652 0.0002 + 0.0915
653 0.0002 + 0.0916
654 0.0002 + 0.0932
655 0.0002 + 0.0933
656 0.0002 + 0.0930
657 0.0002 + 0.0922
658 0.0002 + 0.0931
659 0.0002 + 0.0927
660 0.0002 + 0.0932
661 0.0002 + 0.0919
662 0.0002 + 0.0909
663 0.0002 + 0.0931
664 0.0002 + 0.0925
665 0.0002 + 0.0938
666 0.0002 + 0.0941
667 0.0002 + 0.0924
668 0.0002 + 0.0920
669 0.0002 + 0.0936
670 0.0002 + 0.0920
671 0.0002 + 0.0937
672 0.0002 + 0.0922
673 0.0002 + 0.0933
674 0.0002 + 0.0934
675 0.0002 + 0.0950
676 0.0002 + 0.0935
677 0.0002 + 0.0930
678 0.0002 + 0.0929
679 0.0002 + 0.0928
680 0.0002 + 0.0939
681 0.0002 + 0.0935
682 0.0002 + 0.0924
683 0.0002 + 0.0927
684 0.0002 + 0.0918
685 0.0002 + 0.0925
686 0.0002 + 0.0924
687 0.0002 + 0.0914
688 0.0002 + 0.0935
689 0.0002 + 0.0915
690 0.0002 + 0.0936
691 0.0002 + 0.0935
692 0.0002 + 0.0936
693 0.0002 + 0.0935
694 0.0002 + 0.0927
695 0.0002 + 0.0931
696 0.0002 + 0.0916
697 0.0002 + 0.0922
698 0.0002 + 0.0919
699 0.0002 + 0.0936
700 0.0002 + 0.0935
701 0.0002 + 0.0926
702 0.0002 + 0.0930
703 0.0002 + 0.0935
704 0.0002 + 0.0936
705 0.0002 + 0.0929
706 0.0002 + 0.0936
707 0.0002 + 0.0917
708 0.0002 + 0.0935
709 0.0002 + 0.0933
710 0.0002 + 0.0931
711 0.0002 + 0.0925
712 0.0002 + 0.0922
713 0.0002 + 0.0931
714 0.0002 + 0.0925
715 0.0002 + 0.0935
716 0.0002 + 0.0935
717 0.0002 + 0.0911
718 0.0002 + 0.0926
719 0.0002 + 0.0921
720 0.0002 + 0.0919
721 0.0002 + 0.0925
722 0.0002 + 0.0928
723 0.0002 + 0.0930
724 0.0002 + 0.0926
725 0.0002 + 0.0946
726 0.0002 + 0.0917
727 0.0002 + 0.0937
728 0.0002 + 0.0918
729 0.0002 + 0.0925
730 0.0002 + 0.0928
731 0.0002 + 0.0923
732 0.0002 + 0.0936
733 0.0002 + 0.0931
734 0.0002 + 0.0927
735 0.0002 + 0.0926
736 0.0002 + 0.0927
737 0.0002 + 0.0928
738 0.0002 + 0.0935
739 0.0002 + 0.0921
740 0.0002 + 0.0936
741 0.0002 + 0.0932
742 0.0002 + 0.0939
743 0.0002 + 0.0938
744 0.0002 + 0.0926
745 0.0002 + 0.0919
746 0.0002 + 0.0943
747 0.0002 + 0.0932
748 0.0002 + 0.0936
749 0.0002 + 0.0938
750 0.0002 + 0.0927
751 0.0002 + 0.0928
752 0.0002 + 0.0928
753 0.0002 + 0.0922
754 0.0002 + 0.0942
755 0.0002 + 0.0928
756 0.0002 + 0.0945
757 0.0001 + 0.0936
758 0.0001 + 0.0950
759 0.0001 + 0.0938
760 0.0001 + 0.0929
761 0.0001 + 0.0930
762 0.0001 + 0.0930
763 0.0001 + 0.0930
764 0.0001 + 0.0927
765 0.0001 + 0.0929
766 0.0001 + 0.0926
767 0.0001 + 0.0924
768 0.0001 + 0.0925
769 0.0001 + 0.0944
770 0.0001 + 0.0935
771 0.0001 + 0.0939
772 0.0001 + 0.0928
773 0.0001 + 0.0942
774 0.0001 + 0.0931
775 0.0001 + 0.0928
776 0.0001 + 0.0927
777 0.0001 + 0.0934
778 0.0001 + 0.0926
779 0.0001 + 0.0944
780 0.0001 + 0.0925
781 0.0001 + 0.0942
782 0.0001 + 0.0937
783 0.0001 + 0.0937
784 0.0001 + 0.0926
785 0.0001 + 0.0930
786 0.0001 + 0.0936
787 0.0001 + 0.0952
788 0.0001 + 0.0932
789 0.0001 + 0.0933
790 0.0001 + 0.0929
791 0.0001 + 0.0946
792 0.0001 + 0.0940
793 0.0001 + 0.0935
794 0.0001 + 0.0915
795 0.0001 + 0.0933
796 0.0001 + 0.0922
797 0.0001 + 0.0944
798 0.0001 + 0.0929
799 0.0001 + 0.0914
800 0.0001 + 0.0914
801 0.0001 + 0.0940
802 0.0001 + 0.0923
803 0.0001 + 0.0941
804 0.0001 + 0.0931
805 0.0001 + 0.0932
806 0.0001 + 0.0932
807 0.0001 + 0.0940
808 0.0001 + 0.0924
809 0.0001 + 0.0933
810 0.0001 + 0.0933
811 0.0001 + 0.0927
812 0.0001 + 0.0919
813 0.0001 + 0.0935
814 0.0001 + 0.0928
815 0.0001 + 0.0916
816 0.0001 + 0.0925
817 0.0001 + 0.0924
818 0.0001 + 0.0919
819 0.0001 + 0.0937
820 0.0001 + 0.0933
821 0.0001 + 0.0942
822 0.0001 + 0.0943
823 0.0001 + 0.0944
824 0.0001 + 0.0938
825 0.0001 + 0.0936
826 0.0001 + 0.0920
827 0.0001 + 0.0924
828 0.0001 + 0.0937
829 0.0001 + 0.0927
830 0.0001 + 0.0933
831 0.0001 + 0.0919
832 0.0001 + 0.0920
833 0.0001 + 0.0932
834 0.0001 0.0926
835 0.0001 0.0927
836 0.0001 0.0949
837 0.0001 0.0930
838 0.0002 0.0938
839 0.0002 0.0923
840 0.0003 0.0934
841 0.0003 0.0937
842 0.0006 0.0924
843 0.0006 0.0936
844 0.0006 0.0941
845 0.0002 0.0918
846 0.0002 0.0921
847 0.0004 0.0948
848 0.0004 0.0930
849 0.0004 0.0935
850 0.0002 0.0924
851 0.0002 0.0927
852 0.0004 0.0923
853 0.0004 0.0915
854 0.0003 0.0928
855 0.0002 0.0927
856 0.0003 0.0923
857 0.0004 0.0942
858 0.0003 0.0933
859 0.0001 0.0942
860 0.0002 0.0912
861 0.0002 0.0931
862 0.0002 0.0935
863 0.0001 0.0936
864 0.0002 0.0928
865 0.0002 0.0920
866 0.0001 0.0920
867 0.0001 0.0931
868 0.0002 0.0948
869 0.0002 0.0940
870 0.0001 0.0935
871 0.0001 0.0928
872 0.0002 0.0939
873 0.0002 0.0923
874 0.0001 + 0.0912
875 0.0001 0.0939
876 0.0001 0.0927
877 0.0001 0.0914
878 0.0001 + 0.0929
879 0.0001 0.0927
880 0.0001 0.0938
881 0.0001 0.0935
882 0.0001 + 0.0935
883 0.0001 0.0926
884 0.0001 0.0926
885 0.0001 0.0930
886 0.0001 + 0.0919
887 0.0001 0.0928
888 0.0001 0.0941
889 0.0001 + 0.0929
890 0.0001 0.0962
891 0.0001 0.0940
892 0.0001 0.0918
893 0.0001 + 0.0913
894 0.0001 0.0926
895 0.0001 0.1027
896 0.0001 0.0925
897 0.0001 + 0.0934
898 0.0001 0.0924
899 0.0001 0.0934
900 0.0001 0.0934
901 0.0001 + 0.0932
902 0.0001 0.0948
903 0.0001 0.0929
904 0.0001 + 0.0934
905 0.0001 + 0.0933
906 0.0001 + 0.0927
907 0.0001 0.0939
908 0.0001 + 0.0921
909 0.0001 + 0.0937
910 0.0001 + 0.0941
911 0.0001 + 0.0929
912 0.0001 + 0.0920
913 0.0001 + 0.0930
914 0.0001 + 0.0941
915 0.0001 + 0.0933
916 0.0001 + 0.0944
917 0.0001 + 0.0933
918 0.0001 + 0.0943
919 0.0001 + 0.0938
920 0.0001 + 0.0932
921 0.0001 + 0.0941
922 0.0001 + 0.0939
923 0.0001 + 0.0932
924 0.0001 + 0.0927
925 0.0001 + 0.0939
926 0.0001 + 0.0922
927 0.0001 + 0.0944
928 0.0001 + 0.0932
929 0.0001 + 0.0922
930 0.0001 + 0.0931
931 0.0001 + 0.0929
932 0.0001 + 0.0941
933 0.0001 + 0.0942
934 0.0001 + 0.0952
935 0.0001 + 0.0943
936 0.0001 + 0.0943
937 0.0001 + 0.0932
938 0.0001 + 0.0935
939 0.0001 + 0.0929
940 0.0001 + 0.0933
941 0.0001 + 0.0937
942 0.0001 + 0.0924
943 0.0001 + 0.0924
944 0.0001 + 0.0935
945 0.0001 + 0.0925
946 0.0001 + 0.0951
947 0.0001 + 0.0929
948 0.0001 + 0.0939
949 0.0001 + 0.0946
950 0.0001 + 0.0943
951 0.0001 + 0.0942
952 0.0001 + 0.0920
953 0.0001 + 0.0924
954 0.0001 + 0.0941
955 0.0001 + 0.0926
956 0.0001 + 0.0938
957 0.0001 + 0.0941
958 0.0001 + 0.0932
959 0.0001 0.0929
960 0.0001 0.0944
961 0.0001 0.0926
962 0.0001 0.0936
963 0.0001 0.0946
964 0.0001 0.0929
965 0.0002 0.0938
966 0.0003 0.0933
967 0.0006 0.0938
968 0.0006 0.0933
969 0.0006 0.0936
970 0.0003 0.0928
971 0.0002 0.0919
972 0.0003 0.0928
973 0.0003 0.0935
974 0.0002 0.0918
975 0.0002 0.0928
976 0.0002 0.0934
977 0.0002 0.0933
978 0.0002 0.0929
979 0.0002 0.0925
980 0.0002 0.0936
981 0.0002 0.0931
982 0.0002 0.0929
983 0.0001 0.0942
984 0.0002 0.0938
985 0.0002 0.0926
986 0.0001 0.0935
987 0.0001 0.0929
988 0.0001 0.0943
989 0.0002 0.0939
990 0.0001 0.0935
991 0.0001 0.0938
992 0.0001 0.0922
993 0.0001 0.0921
994 0.0001 0.0935
995 0.0001 0.0925
996 0.0001 0.0937
997 0.0001 0.0947
998 0.0001 0.0929
999 0.0001 0.0927
1000 0.0001 0.0933
Loading the best network from the last checkpoint.
[4]:
<skorch.callbacks.training.Checkpoint at 0x7f002b99b910>
Plot training evolution
[5]:
if config['plot']['training progress']:
history_length = len(tsp.ttr.regressor_['regressor'].history)
train_loss = np.zeros((history_length, 1))
for epoch in tsp.ttr.regressor_['regressor'].history:
epoch_number = epoch['epoch']-1
train_loss[epoch_number] = epoch['train_loss']
_, axes_one = plt.subplots(figsize=(20, 20))
axes_one.plot(train_loss, 'o-', label='training')
axes_one.set_xlabel('Epoch')
axes_one.set_ylabel('MSE')
plt.legend()
Prediction on training data
[6]:
if config['predict on training data enabled']:
dataloader = tsp.ttr.regressor['regressor'].get_iterator(tsp.dataset)
x, y = dataloader.dataset[:]
netout = tsp.predict(x)
d_output = netout.shape[-1]
if config['plot']['prediction on training data']:
fig, axs = plt.subplots(d_output, 1, figsize=(20,20))
axs = [axs]
idx_range = [np.random.randint(0, len(tsp.dataset))]
for idx in idx_range:
if config['plot']['prediction on training data']:
x_absciss = convert_year_month_array_to_datetime(x[idx, :, :])
for idx_output_var in range(d_output):
# Select real passengers data
y_true = y[idx, :, idx_output_var]
y_pred = netout[idx, :, idx_output_var]
if config['plot']['prediction on training data']:
ax = axs[idx_output_var]
ax.plot(x_absciss, y_true, label="Truth", color='tab:blue')
ax.plot(x_absciss, y_pred, label="Prediction", color='tab:orange')
if idx == idx_range[0]:
ax.set_title(tsp.dataset.labels['y'][idx_output_var] + ' over time')
ax.set_xlabel('date')
ax.set_ylabel(tsp.dataset.labels['y'][idx_output_var])
ax.legend()
if config['plot']['prediction on training data']:
plt.show()
Future forecast
[7]:
# def make_x_pred(x, n_months_ahead, include_history=True):
# def raw_add_months(sourcedate, n_months):
# month = sourcedate.month - 1 + n_months
# year = sourcedate.year + month // 12
# month = month % 12 + 1
# day = min(sourcedate.day, calendar.monthrange(year,month)[1])
# return datetime(year, month, day)
# def add_months(months, n_months):
# return [raw_add_months(month, n_months) for month in months]
# def add_months_2(month, n_months):
# return [raw_add_months(month, n_month) for n_month in n_months]
# x_dates = convert_year_month_array_to_datetime(x.squeeze())
# last_month = x_dates[-1:][0]
# next_n_months = add_months_2(last_month, np.arange(n_months_ahead)+1)
# if include_history:
# n_months = x_dates + next_n_months
# else:
# n_months = next_n_months
# x_pred = np.array([[dt.month, dt.year] for dt in n_months]).reshape(1, -1, 2).astype(np.float32)
# return x_pred
if config['forecast enabled']:
dataloader = tsp.ttr.regressor['regressor'].get_iterator(tsp.dataset)
x, y = dataloader.dataset[:]
# x_pred = make_x_pred(x, config['forecast']['months ahead'])
# netout = tsp.predict(x_pred)
netout, x_pred = tsp.forecast(config['forecast']['months ahead'], include_history = config['forecast']['include history'])
d_output = netout.shape[-1]
if config['plot']['forecast']:
fig, axs = plt.subplots(d_output, 1, figsize=(20,20))
axs = [axs]
idx_range = [np.random.randint(0, len(tsp.dataset))]
for idx in idx_range:
if config['plot']['forecast']:
x_absciss = convert_year_month_array_to_datetime(x[idx, :, :])
x_pred_absciss = convert_year_month_array_to_datetime(x_pred[idx, :, :])
for idx_output_var in range(d_output):
# Select real passengers data
y_true = y[idx, :, idx_output_var]
y_pred = netout[idx, :, idx_output_var]
if config['plot']['prediction on validation data']:
ax = axs[idx_output_var]
ax.plot(x_pred_absciss, y_pred, label="Prediction", color='tab:orange')
ax.plot(x_absciss, y_true, label="Truth", color='tab:blue')
if idx == idx_range[0]:
ax.set_title(tsp.dataset.labels['y'][idx_output_var] + ' over time')
ax.set_xlabel('date')
ax.set_ylabel(tsp.dataset.labels['y'][idx_output_var])
ax.legend()
if config['plot']['prediction on validation data']:
plt.show()
[ ]: