Neural/MLForecast
本示例 notebook 演示了 HierarchicalForecast 的协调方法与流行的机器学习库的兼容性,特别是 NeuralForecast 和 MLForecast。
该 notebook 使用 NBEATS 和 XGBRegressor 模型为 TourismLarge 分层数据集创建基础预测。之后,我们使用 HierarchicalForecast 来协调基础预测。
参考文献
- Boris N. Oreshkin, Dmitri Carpov, Nicolas Chapados, Yoshua Bengio (2019). “N-BEATS: Neural basis expansion analysis for interpretable time series forecasting”. url: https://arxiv.org/abs/1905.10437
- Tianqi Chen and Carlos Guestrin. “XGBoost: A Scalable Tree Boosting System”. In: Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. KDD ’16. San Francisco, California, USA: Association for Computing Machinery, 2016, pp. 785–794. isbn: 9781450342322. doi: 10.1145/2939672.2939785. url: https://doi.org/10.1145/2939672.2939785 (cit. on p. 26).
您可以使用 Google Colab 通过 CPU 或 GPU 运行这些实验。
1. 安装包
2. 加载分层数据集
这个详细的澳大利亚旅游数据集来自澳大利亚旅游研究局管理的国家游客调查,由 1998 年至 2016 年的 555 个月度序列组成,按地理位置和旅行目的进行组织。自然地理层级包括七个州,进一步划分为 27 个区域和 76 个地区。旅行目的类别包括度假、探亲访友 (VFR)、商务及其他。MinT (Wickramasuriya et al., 2019) 以及其他分层预测研究过去也使用过该数据集。该数据集可在 MinT 协调网页获取,尽管也有其他来源。
地理划分 | 各划分的序列数 | 各目的的序列数 | 总计 |
---|---|---|---|
澳大利亚 | 1 | 4 | 5 |
州 | 7 | 28 | 35 |
区域 | 27 | 108 | 135 |
地区 | 76 | 304 | 380 |
总计 | 111 | 444 | 555 |
unique_id | ds | y | |
---|---|---|---|
0 | TotalAll | 1998-01-01 | 45151.071280 |
1 | TotalAll | 1998-02-01 | 17294.699551 |
2 | TotalAll | 1998-03-01 | 20725.114184 |
3 | TotalAll | 1998-04-01 | 25388.612353 |
4 | TotalAll | 1998-05-01 | 20330.035211 |
可视化聚合矩阵。
将 dataframe 拆分为训练集/测试集。
3. 拟合和预测模型
HierarchicalForecast 与许多不同的 ML 模型兼容。这里,我们展示两个例子
1. NBEATS,一种基于 MLP 的深度神经网络架构。
2. XGBRegressor,一种基于树的架构。
unique_id | ds | NBEATS | NBEATS-lo-98.0 | NBEATS-lo-96.0 | NBEATS-lo-94.0 | NBEATS-lo-92.0 | NBEATS-lo-90.0 | NBEATS-lo-88.0 | NBEATS-lo-86.0 | … | NBEATS-hi-80.0 | NBEATS-hi-82.0 | NBEATS-hi-84.0 | NBEATS-hi-86.0 | NBEATS-hi-88.0 | NBEATS-hi-90.0 | NBEATS-hi-92.0 | NBEATS-hi-94.0 | NBEATS-hi-96.0 | NBEATS-hi-98.0 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | AAAAll | 2016-01-01 | 2843.298584 | 1764.249023 | 1806.885132 | 1864.019043 | 1906.171021 | 1945.994629 | 1965.081421 | 1998.606812 | … | 3497.682373 | 3520.107666 | 3561.643799 | 3600.121094 | 3646.954346 | 3703.382324 | 3774.084473 | 3813.719238 | 3902.713867 | 3991.594238 |
1 | AAAAll | 2016-02-01 | 1753.340698 | 1394.245850 | 1414.474976 | 1439.167480 | 1458.228394 | 1474.655640 | 1480.433472 | 1489.651245 | … | 2024.560791 | 2049.965576 | 2066.480957 | 2090.285156 | 2120.172852 | 2145.964844 | 2201.716064 | 2253.415039 | 2364.905029 | 2441.167480 |
2 | AAAAll | 2016-03-01 | 1878.675171 | 1446.630371 | 1491.637817 | 1513.890137 | 1524.787842 | 1532.539917 | 1547.460205 | 1559.098389 | … | 2172.270996 | 2189.489990 | 2216.255859 | 2236.661377 | 2286.617676 | 2370.431152 | 2411.910156 | 2477.557373 | 2579.611084 | 2722.415283 |
3 | AAAAll | 2016-04-01 | 2140.948486 | 1661.737793 | 1706.259399 | 1724.914551 | 1736.446045 | 1754.887695 | 1765.482056 | 1772.123901 | … | 2470.206543 | 2483.571045 | 2493.527588 | 2517.062744 | 2547.355713 | 2577.867676 | 2610.180908 | 2637.010498 | 2700.801758 | 2864.596924 |
4 | AAAAll | 2016-05-01 | 1834.694946 | 1466.314209 | 1485.427002 | 1500.715210 | 1518.462036 | 1535.386475 | 1543.525635 | 1554.429810 | … | 2093.700684 | 2120.782471 | 2137.882812 | 2154.052002 | 2164.069824 | 2189.309326 | 2234.271973 | 2311.157715 | 2436.267090 | 2659.653809 |
… | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … |
6655 | TotalVis | 2016-08-01 | 7362.455078 | 5799.121582 | 5960.676270 | 6073.553223 | 6230.090820 | 6294.191406 | 6365.950684 | 6400.492676 | … | 8120.279785 | 8144.139648 | 8185.699219 | 8212.809570 | 8255.871094 | 8291.191406 | 8374.907227 | 8435.806641 | 8568.060547 | 8770.566406 |
6656 | TotalVis | 2016-09-01 | 7803.098145 | 6455.050293 | 6612.847168 | 6690.960938 | 6804.897461 | 6848.432617 | 6873.607422 | 6904.770020 | … | 8562.215820 | 8594.000000 | 8642.083984 | 8715.201172 | 8795.628906 | 8924.573242 | 9053.747070 | 9250.514648 | 9410.338867 | 9818.623047 |
6657 | TotalVis | 2016-10-01 | 8478.570312 | 6592.350098 | 6818.883789 | 7075.323730 | 7223.682129 | 7300.230957 | 7336.740723 | 7391.779785 | … | 9558.611328 | 9586.333984 | 9658.816406 | 9761.448242 | 9802.087891 | 9870.294922 | 9956.144531 | 10070.672852 | 10195.408203 | 10342.619141 |
6658 | TotalVis | 2016-11-01 | 8251.816406 | 6471.753906 | 6551.861328 | 6621.647461 | 6694.992188 | 6740.827148 | 6798.824707 | 6825.794434 | … | 9519.825195 | 9557.507812 | 9624.822266 | 9720.269531 | 9811.011719 | 9907.259766 | 10132.628906 | 10362.583984 | 10896.478516 | 11394.652344 |
6659 | TotalVis | 2016-12-01 | 9023.334961 | 6798.515625 | 6978.411621 | 7165.805176 | 7250.106934 | 7333.168457 | 7395.183594 | 7457.470215 | … | 10221.937500 | 10290.527344 | 10334.883789 | 10399.726562 | 10553.360352 | 10645.852539 | 10806.295898 | 10992.416016 | 11328.151367 | 11933.357422 |
unique_id | ds | XGBRegressor | XGBRegressor-lo-98 | XGBRegressor-lo-96 | XGBRegressor-lo-94 | XGBRegressor-lo-92 | XGBRegressor-lo-90 | XGBRegressor-lo-88 | XGBRegressor-lo-86 | … | XGBRegressor-hi-80 | XGBRegressor-hi-82 | XGBRegressor-hi-84 | XGBRegressor-hi-86 | XGBRegressor-hi-88 | XGBRegressor-hi-90 | XGBRegressor-hi-92 | XGBRegressor-hi-94 | XGBRegressor-hi-96 | XGBRegressor-hi-98 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | AAAAll | 2016-01-01 | 3240.743164 | 2566.404620 | 2638.984995 | 2711.565370 | 2784.145745 | 2856.726120 | 2876.514198 | 2877.447884 | … | 3601.237386 | 3602.171072 | 3603.104758 | 3604.038444 | 3604.972130 | 3624.760208 | 3697.340583 | 3769.920958 | 3842.501333 | 3915.081708 |
1 | AAAAll | 2016-02-01 | 1583.065063 | 1247.414469 | 1248.895343 | 1250.376217 | 1251.857091 | 1253.337965 | 1263.627340 | 1277.062610 | … | 1848.761709 | 1862.196978 | 1875.632248 | 1889.067517 | 1902.502787 | 1912.792162 | 1914.273036 | 1915.753910 | 1917.234784 | 1918.715658 |
2 | AAAAll | 2016-03-01 | 2030.168213 | 1345.896497 | 1386.655046 | 1427.413595 | 1468.172144 | 1508.930693 | 1546.207337 | 1582.240444 | … | 2369.996660 | 2406.029767 | 2442.062874 | 2478.095981 | 2514.129089 | 2551.405733 | 2592.164282 | 2632.922831 | 2673.681380 | 2714.439928 |
3 | AAAAll | 2016-04-01 | 2152.282227 | 1767.276611 | 1772.956049 | 1778.635487 | 1784.314926 | 1789.994364 | 1798.503584 | 1808.023439 | … | 2467.981448 | 2477.501303 | 2487.021159 | 2496.541014 | 2506.060870 | 2514.570089 | 2520.249527 | 2525.928966 | 2531.608404 | 2537.287842 |
4 | AAAAll | 2016-05-01 | 1970.894775 | 1476.761973 | 1510.667430 | 1544.572887 | 1578.478344 | 1612.383801 | 1625.448072 | 1631.069062 | … | 2293.857519 | 2299.478509 | 2305.099499 | 2310.720489 | 2316.341479 | 2329.405750 | 2363.311207 | 2397.216664 | 2431.122121 | 2465.027578 |
… | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … |
6655 | TotalVis | 2016-08-01 | 7810.465820 | 6251.079674 | 6268.924727 | 6286.769780 | 6304.614833 | 6322.459886 | 6375.977772 | 6442.235956 | … | 8979.921135 | 9046.179318 | 9112.437501 | 9178.695685 | 9244.953868 | 9298.471754 | 9316.316807 | 9334.161860 | 9352.006913 | 9369.851967 |
6656 | TotalVis | 2016-09-01 | 6887.893555 | 5346.477959 | 5397.795065 | 5449.112170 | 5500.429275 | 5551.746380 | 5604.124112 | 5656.880638 | … | 7960.636893 | 8013.393419 | 8066.149945 | 8118.906472 | 8171.662998 | 8224.040729 | 8275.357834 | 8326.674940 | 8377.992045 | 8429.309150 |
6657 | TotalVis | 2016-10-01 | 7763.275879 | 6138.534738 | 6267.740281 | 6396.945824 | 6526.151367 | 6655.356910 | 6706.009194 | 6728.606744 | … | 8730.152366 | 8752.749916 | 8775.347465 | 8797.945014 | 8820.542563 | 8871.194848 | 9000.400391 | 9129.605934 | 9258.811477 | 9388.017020 |
6658 | TotalVis | 2016-11-01 | 7432.722168 | 5703.395148 | 5726.926242 | 5750.457336 | 5773.988430 | 5797.519524 | 5929.164698 | 6099.422043 | … | 8255.250258 | 8425.507603 | 8595.764948 | 8766.022293 | 8936.279638 | 9067.924811 | 9091.455905 | 9114.986999 | 9138.518093 | 9162.049187 |
6659 | TotalVis | 2016-12-01 | 9624.172852 | 8115.705498 | 8217.381077 | 8319.056655 | 8420.732234 | 8522.407812 | 8566.581883 | 8590.219701 | … | 10587.212548 | 10610.850366 | 10634.488184 | 10658.126002 | 10681.763820 | 10725.937891 | 10827.613470 | 10929.289048 | 11030.964626 | 11132.640205 |
4. 协调预测
只需少量解析,我们就可以使用不同的 HierarchicalForecast 协调方法来协调原始输出预测。
5. 评估
为了进行评估,我们使用 CRPS 的缩放变体,正如 Rangapuram (2021) 所提出的,以衡量预测分位数 y_hat
相对于观测值 y
的准确性。
我们发现,XGB 与 MinTrace(mint_shrink) 协调方法在测试集上产生了最低的 CRPS 分数,从而为我们提供了最佳的概率预测。
层级 | 指标 | NBEATS/BottomUp | NBEATS/MinTrace_method-mint_shrink | NBEATS/ERM_method-closed_lambda_reg-0.01 | |
---|---|---|---|---|---|
8 | 总体 | scaled_crps | 2.523212 | 2.43205 | 2.645045 |
层级 | 指标 | XGBRegressor/BottomUp | XGBRegressor/MinTrace_method-mint_shrink | XGBRegressor/ERM_method-closed_lambda_reg-0.01 | |
---|---|---|---|---|---|
8 | 总体 | scaled_crps | 1.98255 | 1.44981 | 1.910014 |