This example notebook demonstrates the compatibility of HierarchicalForecast's reconciliation methods with popular machine-learning libraries, in particular NeuralForecast and MLForecast.

The notebook uses the NBEATS and XGBRegressor models to create base forecasts for the TourismLarge hierarchical dataset. We then use HierarchicalForecast to reconcile the base forecasts.

References
- Boris N. Oreshkin, Dmitri Carpov, Nicolas Chapados, Yoshua Bengio (2019). "N-BEATS: Neural basis expansion analysis for interpretable time series forecasting". URL: https://arxiv.org/abs/1905.10437
- Tianqi Chen, Carlos Guestrin (2016). "XGBoost: A Scalable Tree Boosting System". In: Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD '16), pp. 785–794. URL: https://doi.org/10.1145/2939672.2939785
- Syama Sundar Rangapuram, Lucien D. Werner, Konstantinos Benidis, Pedro Mercado, Jan Gasthaus, Tim Januschowski (2021). "End-to-End Learning of Coherent Probabilistic Forecasts for Hierarchical Time Series". In: Proceedings of the 38th International Conference on Machine Learning (ICML 2021).

You can run these experiments on Google Colab, using either a CPU or a GPU.

1. Installing packages

!pip install datasetsforecast hierarchicalforecast mlforecast neuralforecast
import numpy as np
import pandas as pd

from datasetsforecast.hierarchical import HierarchicalData

from neuralforecast import NeuralForecast
from neuralforecast.models import NBEATS
from neuralforecast.losses.pytorch import GMM

from mlforecast import MLForecast
from mlforecast.utils import PredictionIntervals
import xgboost as xgb

#obtain hierarchical reconciliation methods and evaluation
from hierarchicalforecast.methods import BottomUp, ERM, MinTrace
from hierarchicalforecast.utils import HierarchicalPlot
from hierarchicalforecast.core import HierarchicalReconciliation
from hierarchicalforecast.evaluation import evaluate

2. Load hierarchical dataset

This detailed Australian tourism dataset comes from the National Visitor Survey, managed by Tourism Research Australia. It consists of 555 monthly series from 1998 to 2016, organized by geography and purpose of travel. The natural geographical hierarchy comprises seven states, which are further divided into 27 regions and 76 zones. Purpose-of-travel categories include holiday, visiting friends and relatives (VFR), business, and other. MinT (Wickramasuriya et al., 2019), along with other hierarchical forecasting studies, has used this dataset in the past. The dataset is available from the MinT reconciliation webpage, although other sources exist as well.

| Geographical division | Series per division | Series per purpose | Total |
|---|---|---|---|
| Australia | 1 | 4 | 5 |
| State | 7 | 28 | 35 |
| Region | 27 | 108 | 135 |
| Zone | 76 | 304 | 380 |
| Total | 111 | 444 | 555 |
Y_df, S_df, tags = HierarchicalData.load('./data', 'TourismLarge')
Y_df['ds'] = pd.to_datetime(Y_df['ds'])
S_df = S_df.reset_index(names="unique_id")
Y_df.head()
| | unique_id | ds | y |
|---|---|---|---|
| 0 | TotalAll | 1998-01-01 | 45151.071280 |
| 1 | TotalAll | 1998-02-01 | 17294.699551 |
| 2 | TotalAll | 1998-03-01 | 20725.114184 |
| 3 | TotalAll | 1998-04-01 | 25388.612353 |
| 4 | TotalAll | 1998-05-01 | 20330.035211 |

Visualize the aggregation matrix.

hplot = HierarchicalPlot(S=S_df, tags=tags)
hplot.plot_summing_matrix()
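To build intuition for what the summing matrix encodes, here is a toy sketch (this is not the TourismLarge matrix, just an illustration): a hierarchy of three series in which a total aggregates two bottom-level series. Each row of S expresses one series as a linear combination of the bottom level, so multiplying S by the bottom-level values yields coherent values for every level.

```python
import numpy as np

# Toy summing matrix for the series [total, a, b], where total = a + b.
S = np.array([
    [1, 1],  # total aggregates both bottom series
    [1, 0],  # bottom series a
    [0, 1],  # bottom series b
], dtype=float)

y_bottom = np.array([40.0, 60.0])  # bottom-level values
y_all = S @ y_bottom               # values for every level of the hierarchy
print(y_all)  # [100.  40.  60.]
```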

Split the dataframe into train/test sets.

horizon = 12
Y_test_df = Y_df.groupby('unique_id', as_index=False).tail(horizon)
Y_train_df = Y_df.drop(Y_test_df.index)
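The split above keeps the last `horizon` observations of every series for testing. A minimal self-contained sketch of the same pattern on synthetic data (the toy series names are illustrative):

```python
import pandas as pd

# Two toy monthly series of 24 observations each.
df = pd.DataFrame({
    'unique_id': ['a'] * 24 + ['b'] * 24,
    'ds': list(pd.date_range('1998-01-01', periods=24, freq='MS')) * 2,
    'y': range(48),
})

horizon = 12
test = df.groupby('unique_id', as_index=False).tail(horizon)  # last 12 rows per series
train = df.drop(test.index)                                   # everything else

# Every series contributes exactly `horizon` rows to the test set,
# and train/test partition the original dataframe.
assert test.groupby('unique_id').size().eq(horizon).all()
assert len(train) + len(test) == len(df)
```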

3. Fit and predict models

HierarchicalForecast is compatible with many different ML models. Here, we show two examples:
1. NBEATS, an MLP-based deep neural network architecture.
2. XGBRegressor, a tree-based architecture.

level = np.arange(0, 100, 2)
qs = [[50-lv/2, 50+lv/2] for lv in level]
quantiles = np.sort(np.concatenate(qs)[1:]/100)
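As a sanity check on this construction (a verification sketch, not part of the original notebook): the 50 confidence levels 0, 2, …, 98 expand into 99 unique quantiles 0.01, 0.02, …, 0.99, symmetric around the median.

```python
import numpy as np

level = np.arange(0, 100, 2)                       # 50 confidence levels: 0, 2, ..., 98
qs = [[50 - lv / 2, 50 + lv / 2] for lv in level]  # lower/upper percentile per level
quantiles = np.sort(np.concatenate(qs)[1:] / 100)  # [1:] drops the duplicated median

assert len(quantiles) == 99
assert np.allclose(quantiles, np.arange(1, 100) / 100)   # 0.01, 0.02, ..., 0.99
assert np.allclose(quantiles + quantiles[::-1], 1.0)     # q and 1-q are paired
```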

#fit/predict NBEATS from NeuralForecast
nbeats = NBEATS(h=horizon,
                input_size=2*horizon,
                loss=GMM(n_components=10, quantiles=quantiles),
                scaler_type='robust',
                max_steps=2000)
nf = NeuralForecast(models=[nbeats], freq='MS')
nf.fit(df=Y_train_df)
Y_hat_nf = nf.predict()
insample_nf = nf.predict_insample(step_size=horizon)

#fit/predict XGBRegressor from MLForecast
mf = MLForecast(models=[xgb.XGBRegressor()], 
                freq='MS',
                lags=[1,2,12,24],
                date_features=['month'],
                )
mf.fit(Y_train_df, fitted=True, prediction_intervals=PredictionIntervals(n_windows=10, h=horizon)) 
Y_hat_mf = mf.predict(horizon, level=level)
insample_mf = mf.forecast_fitted_values()
Y_hat_nf
| | unique_id | ds | NBEATS | NBEATS-lo-98.0 | … | NBEATS-hi-98.0 |
|---|---|---|---|---|---|---|
| 0 | AAAAll | 2016-01-01 | 2843.298584 | 1764.249023 | … | 3991.594238 |
| 1 | AAAAll | 2016-02-01 | 1753.340698 | 1394.245850 | … | 2441.167480 |
| … | … | … | … | … | … | … |
| 6658 | TotalVis | 2016-11-01 | 8251.816406 | 6471.753906 | … | 11394.652344 |
| 6659 | TotalVis | 2016-12-01 | 9023.334961 | 6798.515625 | … | 11933.357422 |
Y_hat_mf
| | unique_id | ds | XGBRegressor | XGBRegressor-lo-98 | … | XGBRegressor-hi-98 |
|---|---|---|---|---|---|---|
| 0 | AAAAll | 2016-01-01 | 3240.743164 | 2566.404620 | … | 3915.081708 |
| 1 | AAAAll | 2016-02-01 | 1583.065063 | 1247.414469 | … | 1918.715658 |
| … | … | … | … | … | … | … |
| 6658 | TotalVis | 2016-11-01 | 7432.722168 | 5703.395148 | … | 9162.049187 |
| 6659 | TotalVis | 2016-12-01 | 9624.172852 | 8115.705498 | … | 11132.640205 |

4. Reconcile forecasts

With minimal parsing, we can reconcile the raw output forecasts using HierarchicalForecast's different reconciliation methods.

reconcilers = [
    ERM(method='closed'),
    BottomUp(),
    MinTrace('mint_shrink'),
]
hrec = HierarchicalReconciliation(reconcilers=reconcilers)

Y_rec_nf = hrec.reconcile(Y_hat_df=Y_hat_nf, Y_df=insample_nf, S=S_df, tags=tags, level=level)
Y_rec_mf = hrec.reconcile(Y_hat_df=Y_hat_mf, Y_df=insample_mf, S=S_df, tags=tags, level=level)
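Conceptually, each reconciliation method maps the (generally incoherent) base forecasts onto the coherent subspace via ŷ_rec = S P ŷ, where the matrix P differs per method. A toy sketch of BottomUp on a three-series hierarchy (illustrative only, not HierarchicalForecast's internal implementation):

```python
import numpy as np

# Hierarchy [total, a, b] with the coherence constraint total = a + b.
S = np.array([[1, 1],
              [1, 0],
              [0, 1]], dtype=float)

# BottomUp's P simply selects the bottom-level base forecasts.
P_bu = np.array([[0, 1, 0],
                 [0, 0, 1]], dtype=float)

y_hat = np.array([110.0, 40.0, 60.0])  # incoherent base forecasts: 40 + 60 != 110
y_rec = S @ P_bu @ y_hat               # reconciled, coherent forecasts

assert y_rec[0] == y_rec[1] + y_rec[2]  # coherence restored
```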

5. Evaluation

To evaluate, we use a scaled variant of CRPS, as proposed by Rangapuram (2021), to measure the accuracy of the predicted quantiles y_hat against the observed values y.

$$\mathrm{sCRPS}(\hat{F}_{\tau}, \mathbf{y}_{\tau}) = \frac{2}{N} \sum_{i} \int^{1}_{0} \frac{\mathrm{QL}(\hat{F}_{i,\tau}, y_{i,\tau})_{q}}{\sum_{i} | y_{i,\tau} |} \, dq$$
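A minimal numpy sketch of sCRPS, approximating the integral by averaging the pinball (quantile) loss over a quantile grid (illustrative only; below, the notebook uses the ready-made `utilsforecast.losses.scaled_crps`):

```python
import numpy as np

def scrps_sketch(y, q_preds, q_levels):
    """Approximate sCRPS at one timestamp.

    y:        (N,)   observed values, one per series
    q_preds:  (N, Q) predicted quantiles per series
    q_levels: (Q,)   quantile probabilities in (0, 1)
    """
    diff = y[:, None] - q_preds
    # pinball / quantile loss QL at each quantile level
    ql = np.maximum(q_levels * diff, (q_levels - 1.0) * diff)
    integral = ql.mean(axis=1)  # approximate the integral over q on the grid
    return 2.0 * integral.sum() / (len(y) * np.abs(y).sum())

y = np.array([100.0, 40.0, 60.0])
q_levels = np.arange(1, 100) / 100
perfect = np.repeat(y[:, None], len(q_levels), axis=1)  # every quantile at the truth
print(scrps_sketch(y, perfect, q_levels))  # 0.0 for a perfect forecast
```

Lower is better: a forecast that concentrates all quantiles exactly on the observations scores zero, and any miscalibration or bias increases the score.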

We find that XGB combined with the MinTrace(mint_shrink) reconciliation method produces the lowest CRPS score on the test set, thus giving us the best probabilistic forecasts.

from utilsforecast.losses import scaled_crps
rec_model_names_nf = ['NBEATS/BottomUp', 'NBEATS/MinTrace_method-mint_shrink', 'NBEATS/ERM_method-closed_lambda_reg-0.01']

evaluation_nf = evaluate(df = Y_rec_nf.merge(Y_test_df, on=['unique_id', 'ds']),
                      tags = tags,
                      metrics = [scaled_crps],
                      models= rec_model_names_nf,
                      level = list(range(0, 100, 2)),
                      )

rec_model_names_mf = ['XGBRegressor/BottomUp', 'XGBRegressor/MinTrace_method-mint_shrink', 'XGBRegressor/ERM_method-closed_lambda_reg-0.01']

evaluation_mf = evaluate(df = Y_rec_mf.merge(Y_test_df, on=['unique_id', 'ds']),
                      tags = tags,
                      metrics = [scaled_crps],
                      models= rec_model_names_mf,
                      level = list(range(0, 100, 2)),
                      )
evaluation_nf.query("level == 'Overall'")
| | level | metric | NBEATS/BottomUp | NBEATS/MinTrace_method-mint_shrink | NBEATS/ERM_method-closed_lambda_reg-0.01 |
|---|---|---|---|---|---|
| 8 | Overall | scaled_crps | 2.523212 | 2.43205 | 2.645045 |
evaluation_mf.query("level == 'Overall'")
| | level | metric | XGBRegressor/BottomUp | XGBRegressor/MinTrace_method-mint_shrink | XGBRegressor/ERM_method-closed_lambda_reg-0.01 |
|---|---|---|---|---|---|
| 8 | Overall | scaled_crps | 1.98255 | 1.44981 | 1.910014 |

6. Visualization

plot_nf = Y_df.merge(Y_rec_nf, on=['unique_id', 'ds'], how="outer")

plot_mf = Y_df.merge(Y_rec_mf, on=['unique_id', 'ds'], how="outer")
hplot.plot_series(
    series='TotalVis',
    Y_df=plot_nf, 
    models=['y', 'NBEATS', 'NBEATS/BottomUp', 'NBEATS/MinTrace_method-mint_shrink', 'NBEATS/ERM_method-closed_lambda_reg-0.01'],
    level=[80]
)

hplot.plot_series(
    series='TotalVis',
    Y_df=plot_mf, 
    models=['y', 'XGBRegressor', 'XGBRegressor/BottomUp', 'XGBRegressor/MinTrace_method-mint_shrink', 'XGBRegressor/ERM_method-closed_lambda_reg-0.01'],
    level=[80]
)