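Large collections of time series organized into structures at different aggregation levels often require their forecasts to satisfy the aggregation constraints and to be nonnegative. This poses the challenge of creating new algorithms capable of producing coherent forecasts.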
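To make the two requirements concrete: for a toy hierarchy Total = A + B, a stacked forecast vector [Total, A, B] is coherent when it equals S times its bottom-level part, where S is the summing matrix, and nonnegative when no entry is below zero. The helper `is_coherent_nonnegative` below is hypothetical (not part of `HierarchicalForecast`) and assumes the bottom-level series come last in the vector:

```python
import numpy as np

# Toy summing matrix for Total = A + B; rows ordered [Total, A, B],
# columns ordered [A, B] (the bottom-level series).
S = np.array([
    [1.0, 1.0],
    [1.0, 0.0],
    [0.0, 1.0],
])


def is_coherent_nonnegative(y_hat, S, atol=1e-8):
    """Hypothetical helper: check aggregation constraints and nonnegativity.

    Assumes the last S.shape[1] entries of y_hat are the bottom-level series.
    Values within atol of zero are treated as nonnegative.
    """
    y_hat = np.asarray(y_hat, dtype=float)
    bottom = y_hat[-S.shape[1]:]
    coherent = bool(np.allclose(S @ bottom, y_hat, atol=atol))
    nonnegative = bool(np.all(y_hat >= -atol))
    return coherent, nonnegative


print(is_coherent_nonnegative([10.0, 4.0, 6.0], S))  # (True, True)
print(is_coherent_nonnegative([12.0, 4.0, 6.0], S))  # (False, True)
print(is_coherent_nonnegative([2.0, -1.0, 3.0], S))  # (True, False)
```

The third case shows that coherence alone does not rule out negative values, which is why the two constraints are handled together.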

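The `HierarchicalForecast` package provides a wide collection of Python implementations of hierarchical forecasting algorithms that support nonnegative hierarchical reconciliation.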

In this notebook, we show how to use the `HierarchicalForecast` package to perform nonnegative forecast reconciliation on the `Wiki2` dataset.

You can run these experiments on CPU or GPU with Google Colab.

!pip install hierarchicalforecast statsforecast datasetsforecast

1. Load Data

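In this example we use the `Wiki2` dataset. The following cell loads the time series for the different levels of the hierarchy, the summing dataframe `S_df`, which recovers the full dataset from the bottom-level series, and the indices of each hierarchy level, stored in `tags`.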

import numpy as np
import pandas as pd

from datasetsforecast.hierarchical import HierarchicalData
Y_df, S_df, tags = HierarchicalData.load('./data', 'Wiki2')
Y_df['ds'] = pd.to_datetime(Y_df['ds'])
S_df = S_df.reset_index(names="unique_id")
Y_df.head()
  unique_id          ds       y
0     Total  2016-01-01  156508
1     Total  2016-01-02  129902
2     Total  2016-01-03  138203
3     Total  2016-01-04  115017
4     Total  2016-01-05  126042
S_df.iloc[:5, :5]
  unique_id  de_AAC_AAG_001  de_AAC_AAG_010  de_AAC_AAG_014  de_AAC_AAG_045
0     Total               1               1               1               1
1        de               1               1               1               1
2        en               0               0               0               0
3        fr               0               0               0               0
4        ja               0               0               0               0
tags
{'Views': array(['Total'], dtype=object),
 'Views/Country': array(['de', 'en', 'fr', 'ja', 'ru', 'zh'], dtype=object),
 'Views/Country/Access': array(['de_AAC', 'de_DES', 'de_MOB', 'en_AAC', 'en_DES', 'en_MOB',
        'fr_AAC', 'fr_DES', 'fr_MOB', 'ja_AAC', 'ja_DES', 'ja_MOB',
        'ru_AAC', 'ru_DES', 'ru_MOB', 'zh_AAC', 'zh_DES', 'zh_MOB'],
       dtype=object),
 'Views/Country/Access/Agent': array(['de_AAC_AAG', 'de_AAC_SPD', 'de_DES_AAG', 'de_MOB_AAG',
        'en_AAC_AAG', 'en_AAC_SPD', 'en_DES_AAG', 'en_MOB_AAG',
        'fr_AAC_AAG', 'fr_AAC_SPD', 'fr_DES_AAG', 'fr_MOB_AAG',
        'ja_AAC_AAG', 'ja_AAC_SPD', 'ja_DES_AAG', 'ja_MOB_AAG',
        'ru_AAC_AAG', 'ru_AAC_SPD', 'ru_DES_AAG', 'ru_MOB_AAG',
        'zh_AAC_AAG', 'zh_AAC_SPD', 'zh_DES_AAG', 'zh_MOB_AAG'],
       dtype=object),
 'Views/Country/Access/Agent/Topic': array(['de_AAC_AAG_001', 'de_AAC_AAG_010', 'de_AAC_AAG_014',
        'de_AAC_AAG_045', 'de_AAC_AAG_063', 'de_AAC_AAG_100',
        'de_AAC_AAG_110', 'de_AAC_AAG_123', 'de_AAC_AAG_143',
        'de_AAC_SPD_012', 'de_AAC_SPD_074', 'de_AAC_SPD_080',
        'de_AAC_SPD_105', 'de_AAC_SPD_115', 'de_AAC_SPD_133',
        'de_DES_AAG_064', 'de_DES_AAG_116', 'de_DES_AAG_131',
        'de_MOB_AAG_015', 'de_MOB_AAG_020', 'de_MOB_AAG_032',
        'de_MOB_AAG_059', 'de_MOB_AAG_062', 'de_MOB_AAG_088',
        'de_MOB_AAG_095', 'de_MOB_AAG_109', 'de_MOB_AAG_122',
        'de_MOB_AAG_149', 'en_AAC_AAG_044', 'en_AAC_AAG_049',
        'en_AAC_AAG_075', 'en_AAC_AAG_114', 'en_AAC_AAG_119',
        'en_AAC_AAG_141', 'en_AAC_SPD_004', 'en_AAC_SPD_011',
        'en_AAC_SPD_026', 'en_AAC_SPD_048', 'en_AAC_SPD_067',
        'en_AAC_SPD_126', 'en_AAC_SPD_140', 'en_DES_AAG_016',
        'en_DES_AAG_024', 'en_DES_AAG_042', 'en_DES_AAG_069',
        'en_DES_AAG_082', 'en_DES_AAG_102', 'en_MOB_AAG_018',
        'en_MOB_AAG_022', 'en_MOB_AAG_101', 'en_MOB_AAG_124',
        'fr_AAC_AAG_029', 'fr_AAC_AAG_046', 'fr_AAC_AAG_070',
        'fr_AAC_AAG_087', 'fr_AAC_AAG_098', 'fr_AAC_AAG_104',
        'fr_AAC_AAG_111', 'fr_AAC_AAG_112', 'fr_AAC_AAG_142',
        'fr_AAC_SPD_025', 'fr_AAC_SPD_027', 'fr_AAC_SPD_035',
        'fr_AAC_SPD_077', 'fr_AAC_SPD_084', 'fr_AAC_SPD_097',
        'fr_AAC_SPD_130', 'fr_DES_AAG_023', 'fr_DES_AAG_043',
        'fr_DES_AAG_051', 'fr_DES_AAG_058', 'fr_DES_AAG_061',
        'fr_DES_AAG_091', 'fr_DES_AAG_093', 'fr_DES_AAG_094',
        'fr_DES_AAG_136', 'fr_MOB_AAG_006', 'fr_MOB_AAG_030',
        'fr_MOB_AAG_066', 'fr_MOB_AAG_117', 'fr_MOB_AAG_120',
        'fr_MOB_AAG_121', 'fr_MOB_AAG_135', 'fr_MOB_AAG_147',
        'ja_AAC_AAG_038', 'ja_AAC_AAG_047', 'ja_AAC_AAG_055',
        'ja_AAC_AAG_076', 'ja_AAC_AAG_099', 'ja_AAC_AAG_128',
        'ja_AAC_AAG_132', 'ja_AAC_AAG_134', 'ja_AAC_AAG_137',
        'ja_AAC_SPD_013', 'ja_AAC_SPD_034', 'ja_AAC_SPD_050',
        'ja_AAC_SPD_060', 'ja_AAC_SPD_078', 'ja_AAC_SPD_106',
        'ja_DES_AAG_079', 'ja_DES_AAG_081', 'ja_DES_AAG_113',
        'ja_MOB_AAG_065', 'ja_MOB_AAG_073', 'ja_MOB_AAG_092',
        'ja_MOB_AAG_127', 'ja_MOB_AAG_129', 'ja_MOB_AAG_144',
        'ru_AAC_AAG_008', 'ru_AAC_AAG_145', 'ru_AAC_AAG_146',
        'ru_AAC_SPD_000', 'ru_AAC_SPD_090', 'ru_AAC_SPD_148',
        'ru_DES_AAG_003', 'ru_DES_AAG_007', 'ru_DES_AAG_017',
        'ru_DES_AAG_041', 'ru_DES_AAG_071', 'ru_DES_AAG_072',
        'ru_MOB_AAG_002', 'ru_MOB_AAG_040', 'ru_MOB_AAG_083',
        'ru_MOB_AAG_086', 'ru_MOB_AAG_103', 'ru_MOB_AAG_107',
        'ru_MOB_AAG_118', 'ru_MOB_AAG_125', 'zh_AAC_AAG_021',
        'zh_AAC_AAG_033', 'zh_AAC_AAG_037', 'zh_AAC_AAG_052',
        'zh_AAC_AAG_057', 'zh_AAC_AAG_085', 'zh_AAC_AAG_108',
        'zh_AAC_SPD_039', 'zh_AAC_SPD_096', 'zh_DES_AAG_009',
        'zh_DES_AAG_019', 'zh_DES_AAG_053', 'zh_DES_AAG_054',
        'zh_DES_AAG_056', 'zh_DES_AAG_068', 'zh_DES_AAG_089',
        'zh_DES_AAG_139', 'zh_MOB_AAG_005', 'zh_MOB_AAG_028',
        'zh_MOB_AAG_031', 'zh_MOB_AAG_036', 'zh_MOB_AAG_138'], dtype=object)}
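The layout of `S_df` (one row per series, one 0/1 column per bottom-level series) and the grouping in `tags` can be mimicked on a small hypothetical hierarchy. The sketch below assumes three bottom-level series and a helper `build_summing_matrix` invented for illustration (it is not a `HierarchicalForecast` function); multiplying the matrix by one time step of bottom-level values recovers every level of the hierarchy:

```python
import numpy as np

# Hypothetical toy hierarchy: Total -> {de, en}; de -> {de_A, de_B}; en -> {en_A}.
bottom = ["de_A", "de_B", "en_A"]
aggregates = {
    "Total": ["de_A", "de_B", "en_A"],
    "de": ["de_A", "de_B"],
    "en": ["en_A"],
}


def build_summing_matrix(aggregates, bottom):
    """Stack one 0/1 row per aggregate on top of an identity block for the bottom level."""
    col = {name: j for j, name in enumerate(bottom)}
    rows = []
    for children in aggregates.values():
        row = np.zeros(len(bottom))
        for child in children:
            row[col[child]] = 1.0
        rows.append(row)
    S = np.vstack(rows + [np.eye(len(bottom))])
    ids = list(aggregates) + bottom  # row labels, like the unique_id column of S_df
    return S, ids


S, ids = build_summing_matrix(aggregates, bottom)
y_bottom = np.array([5.0, 3.0, 7.0])  # one time step of bottom-level values
y_full = S @ y_bottom                 # values for every series in the hierarchy
print({k: float(v) for k, v in zip(ids, y_full)})
# {'Total': 15.0, 'de': 8.0, 'en': 7.0, 'de_A': 5.0, 'de_B': 3.0, 'en_A': 7.0}
```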

We split the dataframe into train and test sets, holding out the last 7 days of each series for testing.

Y_test_df = Y_df.groupby('unique_id', as_index=False).tail(7)
Y_train_df = Y_df.drop(Y_test_df.index)
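A minimal sketch of the same split on a hypothetical two-series frame shows what `groupby(...).tail(h)` does: it keeps the last `h` rows of each `unique_id`, so it assumes the rows are already sorted by `ds` within each series.

```python
import pandas as pd

# Hypothetical long-format frame: two series with 10 daily observations each.
Y_df = pd.DataFrame({
    "unique_id": ["A"] * 10 + ["B"] * 10,
    "ds": list(pd.date_range("2016-01-01", periods=10, freq="D")) * 2,
    "y": range(20),
})

h = 7  # forecast horizon, matching the 7-day test window used above
# tail(h) keeps the last h rows of each series (rows must be sorted by ds within each id).
Y_test_df = Y_df.groupby("unique_id", as_index=False).tail(h)
# Dropping those index labels leaves the earlier observations for training.
Y_train_df = Y_df.drop(Y_test_df.index)

print({k: int(v) for k, v in Y_train_df.groupby("unique_id").size().items()})
# {'A': 3, 'B': 3}
```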

2. Base Forecasts

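The following cell computes the *base forecasts* for each time series using the `AutoETS` model, together with a `Naive` benchmark. Observe that `Y_hat_df` contains the forecasts, but they are not coherent.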

from statsforecast.models import AutoETS, Naive
from statsforecast.core import StatsForecast
fcst = StatsForecast(
    models=[AutoETS(season_length=7, model='ZAA'), Naive()], 
    freq='D', 
    n_jobs=-1
)
Y_hat_df = fcst.forecast(df=Y_train_df, h=7)
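Because each series is forecast independently, the base forecast of an aggregate generally differs from the sum of its children's forecasts. The sketch below uses made-up `AutoETS` numbers for a toy hierarchy Total = A + B, laid out like `Y_hat_df`, and measures that coherence gap:

```python
import pandas as pd

# Hypothetical base forecasts in the same long layout as Y_hat_df
# (unique_id, ds, one column per model) for a toy hierarchy Total = A + B.
Y_hat = pd.DataFrame({
    "unique_id": ["Total", "A", "B"] * 2,
    "ds": pd.to_datetime(["2016-12-25"] * 3 + ["2016-12-26"] * 3),
    "AutoETS": [100.0, 55.0, 40.0, 90.0, 50.0, 45.0],
})

# One column per series, one row per forecast date.
wide = Y_hat.pivot(index="ds", columns="unique_id", values="AutoETS")
# Coherence gap: forecast of the total minus the sum of its children's forecasts.
gap = wide["Total"] - (wide["A"] + wide["B"])
print(gap.tolist())  # [5.0, -5.0]
```

A nonzero gap on any date means the base forecasts are incoherent, which is what reconciliation removes.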

Observe that the `AutoETS` model produces negative forecasts for some time series.

Y_hat_df.query('AutoETS < 0')
             unique_id          ds      AutoETS   Naive
28      de_AAC_AAG_001  2016-12-25  -523.766907   340.0
29      de_AAC_AAG_001  2016-12-26  -245.337433   340.0
30      de_AAC_AAG_001  2016-12-27  -194.253815   340.0
33      de_AAC_AAG_001  2016-12-30  -315.425659   340.0
34      de_AAC_AAG_001  2016-12-31  -806.920105   340.0
...
1217    zh_AAC_AAG_033  2016-12-31   -86.466789    37.0
1345            zh_MOB  2016-12-26  -199.534882  1036.0
1346            zh_MOB  2016-12-27   -69.527260  1036.0
1352        zh_MOB_AAG  2016-12-26  -199.534882  1036.0
1353        zh_MOB_AAG  2016-12-27   -69.527260  1036.0

3. Nonnegative Reconciliation

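The following cell uses the `HierarchicalReconciliation` class to reconcile the base forecasts with two variants of `MinTrace(method='ols')`: an unconstrained one and one with `nonnegative=True`. Both make the forecasts coherent; only the second also guarantees that they are nonnegative.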

from hierarchicalforecast.methods import MinTrace
from hierarchicalforecast.core import HierarchicalReconciliation
reconcilers = [
    MinTrace(method='ols'),
    MinTrace(method='ols', nonnegative=True)
]
hrec = HierarchicalReconciliation(reconcilers=reconcilers)
Y_rec_df = hrec.reconcile(Y_hat_df=Y_hat_df, Y_df=Y_train_df,
                          S=S_df, tags=tags)
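`MinTrace(method='ols')` corresponds to minimum-trace reconciliation with an identity weight matrix, which reduces to the projection y_tilde = S (S'S)^(-1) S' y_hat. The numpy sketch below applies that formula to a toy hierarchy Total = A + B with made-up base forecasts; it illustrates the idea and is not the library's internal code:

```python
import numpy as np

S = np.array([[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]])  # Total = A + B
y_hat = np.array([100.0, 55.0, 40.0])               # incoherent base forecasts [Total, A, B]

# OLS reconciliation (MinTrace with W = I): P = (S'S)^(-1) S', y_tilde = S P y_hat.
P = np.linalg.solve(S.T @ S, S.T)
b_tilde = P @ y_hat    # reconciled bottom-level forecasts
y_tilde = S @ b_tilde  # coherent forecasts for every level
print(np.round(y_tilde, 4).tolist())  # [98.3333, 56.6667, 41.6667]
```

The original gap of 5 (100 versus 55 + 40) is spread evenly: the total moves down by 5/3 and each child moves up by 5/3, so the result sums exactly.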

Observe that the nonnegative reconciliation method yields nonnegative forecasts.

Y_rec_df
           unique_id          ds        AutoETS    Naive  AutoETS/MinTrace_method-ols  Naive/MinTrace_method-ols  AutoETS/MinTrace_method-ols_nonnegative-True  Naive/MinTrace_method-ols_nonnegative-True
0              Total  2016-12-25   94523.164062  95743.0                 95852.000421                    95743.0                                  9.664245e+04                                     95743.0
1              Total  2016-12-26   87734.367188  95743.0                 89525.238276                    95743.0                                  9.028857e+04                                     95743.0
2              Total  2016-12-27   87751.125000  95743.0                 89638.119184                    95743.0                                  9.056593e+04                                     95743.0
3              Total  2016-12-28  133237.968750  95743.0                131051.839057                    95743.0                                  1.314028e+05                                     95743.0
4              Total  2016-12-29  126501.796875  95743.0                121214.048604                    95743.0                                  1.218000e+05                                     95743.0
...
1388  zh_MOB_AAG_138  2016-12-27      62.049744     65.0                  -147.399760                       65.0                                  0.000000e+00                                        65.0
1389  zh_MOB_AAG_138  2016-12-28      54.934032     65.0                     7.561682                       65.0                                  4.397229e-15                                        65.0
1390  zh_MOB_AAG_138  2016-12-29      60.452618     65.0                   114.253489                       65.0                                  9.321380e+01                                        65.0
1391  zh_MOB_AAG_138  2016-12-30      50.356693     65.0                    96.446754                       65.0                                  7.565171e+01                                        65.0
1392  zh_MOB_AAG_138  2016-12-31      66.735626     65.0                   208.184648                       65.0                                  1.851130e+02                                        65.0
Y_rec_df.query('`AutoETS/MinTrace_method-ols_nonnegative-True` < 0')
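Empty DataFrame
Columns: [unique_id, ds, AutoETS, Naive, AutoETS/MinTrace_method-ols, Naive/MinTrace_method-ols, AutoETS/MinTrace_method-ols_nonnegative-True, Naive/MinTrace_method-ols_nonnegative-True]
Index: []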
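One way to impose nonnegativity is to solve the OLS reconciliation as a constrained least-squares problem over the bottom-level forecasts b: minimize ||y_hat - S b||^2 subject to b >= 0, then return S b. The library solves its constrained problem with its own optimizer; the sketch below swaps in plain projected gradient descent, uses a hypothetical function name, `nonnegative_ols_reconcile`, and runs on toy data where the unconstrained OLS solution has a negative entry:

```python
import numpy as np


def nonnegative_ols_reconcile(y_hat, S, n_iter=500):
    """Sketch: minimize ||y_hat - S b||^2 subject to b >= 0 by projected gradient.

    Assumes W = I (the 'ols' weighting); returns coherent, nonnegative forecasts S b.
    """
    y_hat = np.asarray(y_hat, dtype=float)
    StS, Sty = S.T @ S, S.T @ y_hat
    step = 1.0 / np.linalg.eigvalsh(StS).max()  # 1/L, L = Lipschitz constant of the gradient
    b = np.clip(np.linalg.solve(StS, Sty), 0.0, None)  # start from the clipped OLS solution
    for _ in range(n_iter):
        grad = StS @ b - Sty                      # gradient of 0.5 * ||y_hat - S b||^2
        b = np.clip(b - step * grad, 0.0, None)   # project back onto b >= 0
    return S @ b


S = np.array([[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]])  # Total = A + B
y_hat = np.array([2.0, -3.0, 4.0])                  # base forecasts [Total, A, B]

P = np.linalg.solve(S.T @ S, S.T)
print(np.round(S @ P @ y_hat, 4).tolist())                        # [1.6667, -2.6667, 4.3333]
print(np.round(nonnegative_ols_reconcile(y_hat, S), 4).tolist())  # [3.0, 0.0, 3.0]
```

Because the constraint is placed on b and the output is S b, the nonnegative result stays coherent by construction (3.0 = 0.0 + 3.0).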

The unconstrained reconciliation method, in contrast, produces negative forecasts.

Y_rec_df.query('`AutoETS/MinTrace_method-ols` < 0')
           unique_id          ds        AutoETS    Naive  AutoETS/MinTrace_method-ols  Naive/MinTrace_method-ols  AutoETS/MinTrace_method-ols_nonnegative-True  Naive/MinTrace_method-ols_nonnegative-True
56            de_DES  2016-12-25   -2553.932861    495.0                 -3818.990043                      495.0                                  0.000000e+00                                       495.0
57            de_DES  2016-12-26   -2155.228271    495.0                 -3309.806933                      495.0                                  1.909922e-30                                       495.0
58            de_DES  2016-12-27   -2720.993896    495.0                 -3965.351121                      495.0                                  1.140223e-13                                       495.0
60            de_DES  2016-12-29   -3429.432617    495.0                 -3042.502484                      495.0                                  3.049601e+02                                       495.0
61            de_DES  2016-12-30   -3963.202637    495.0                 -3476.273292                      495.0                                  2.877829e+02                                       495.0
...
1380  zh_MOB_AAG_036  2016-12-26      75.298317    115.0                  -166.245228                      115.0                                  0.000000e+00                                       115.0
1381  zh_MOB_AAG_036  2016-12-27      72.895554    115.0                  -136.553950                      115.0                                  1.699002e-14                                       115.0
1386  zh_MOB_AAG_138  2016-12-25      94.796623     65.0                   -49.410174                       65.0                                  0.000000e+00                                        65.0
1387  zh_MOB_AAG_138  2016-12-26      71.293983     65.0                  -170.249562                       65.0                                  0.000000e+00                                        65.0
1388  zh_MOB_AAG_138  2016-12-27      62.049744     65.0                  -147.399760                       65.0                                  0.000000e+00                                        65.0
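Simply clipping negative reconciled values to zero is not a substitute for nonnegative reconciliation, because it breaks the aggregation constraints. A small numpy check on toy values for Total = A + B:

```python
import numpy as np

# Coherent but partly negative forecasts [Total, A, B] (toy values, Total = A + B).
y_ols = np.array([5.0, -8.0, 13.0]) / 3.0
y_clipped = np.clip(y_ols, 0.0, None)  # naive post-hoc fix: zero out negatives

# Does Total still equal A + B?
print(bool(np.isclose(y_ols[0], y_ols[1:].sum())))          # True
print(bool(np.isclose(y_clipped[0], y_clipped[1:].sum())))  # False
```

After clipping, A + B grows by the removed negative amount while Total does not, so the forecasts are no longer coherent.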

4. Evaluation

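The `HierarchicalForecast` package includes an `evaluate` function to assess the forecasts at each level of the hierarchy. We use `utilsforecast` to compute the mean squared error (MSE), scaled by the `Naive` benchmark (`benchmark="Naive"`), so values below 1 indicate an improvement over `Naive`.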

from hierarchicalforecast.evaluation import evaluate
from utilsforecast.losses import mse

# Join reconciled forecasts with the held-out actuals and score every level,
# scaling the MSE by the Naive benchmark.
evaluation = evaluate(
    df=Y_rec_df.merge(Y_test_df, on=['unique_id', 'ds']),
    tags=tags,
    train_df=Y_train_df,
    metrics=[mse],
    benchmark="Naive",
)

evaluation.set_index(["level", "metric"]).filter(like='ETS')
                                               AutoETS  AutoETS/MinTrace_method-ols  AutoETS/MinTrace_method-ols_nonnegative-True
level                             metric
Views                             mse-scaled  0.735800                     0.697371                                      0.675672
Views/Country                     mse-scaled  1.190354                     1.053631                                      0.994758
Views/Country/Access              mse-scaled  1.086102                     1.133507                                      1.172270
Views/Country/Access/Agent        mse-scaled  1.067394                     1.100215                                      1.127960
Views/Country/Access/Agent/Topic  mse-scaled  1.435105                     1.381990                                      1.163428
Overall                           mse-scaled  1.010801                     0.977667                                      0.939286
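To read the `mse-scaled` column, here is a minimal sketch assuming the scaled value is a model's MSE averaged over the series in a level, divided by the same quantity for the `Naive` benchmark; the exact aggregation inside `evaluate` may differ, and the numbers are hypothetical. Values below 1 mean the model beats `Naive` at that level.

```python
import numpy as np


def mse(y, y_hat):
    """Mean squared error of one series over the forecast horizon."""
    return float(np.mean((np.asarray(y) - np.asarray(y_hat)) ** 2))


# Hypothetical test data for two series in the same level over a 2-step horizon.
y_true = {"A": [10.0, 12.0], "B": [5.0, 7.0]}
model = {"A": [11.0, 12.0], "B": [5.0, 5.0]}
naive = {"A": [8.0, 8.0], "B": [6.0, 6.0]}

# Average the per-series MSE within the level, then scale by the benchmark's value.
level_mse_model = np.mean([mse(y_true[k], model[k]) for k in y_true])  # (0.5 + 2.0) / 2
level_mse_naive = np.mean([mse(y_true[k], naive[k]) for k in y_true])  # (10.0 + 1.0) / 2
scaled = float(level_mse_model / level_mse_naive)
print(round(scaled, 4))  # 0.2273
```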

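Observe that the nonnegative reconciliation method achieves a lower scaled MSE than its unconstrained counterpart overall and at the Views, Views/Country, and Views/Country/Access/Agent/Topic levels, although it is slightly worse at the Views/Country/Access and Views/Country/Access/Agent levels.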
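The per-level comparison can be checked directly from the scaled MSE values in the evaluation table above (lower is better):

```python
# Scaled MSE values copied from the evaluation table above.
levels = [
    "Views",
    "Views/Country",
    "Views/Country/Access",
    "Views/Country/Access/Agent",
    "Views/Country/Access/Agent/Topic",
    "Overall",
]
ols = [0.697371, 1.053631, 1.133507, 1.100215, 1.381990, 0.977667]
nonneg = [0.675672, 0.994758, 1.172270, 1.127960, 1.163428, 0.939286]

rows = list(zip(levels, ols, nonneg))
# Percentage change in scaled MSE when moving from ols to nonnegative ols.
pct_change = {level: 100.0 * (b - a) / a for level, a, b in rows}
nonneg_wins = [level for level, a, b in rows if b < a]

for level in levels:
    print(f"{level:34s} {pct_change[level]:+6.1f}%")
print(nonneg_wins)
# ['Views', 'Views/Country', 'Views/Country/Access/Agent/Topic', 'Overall']
```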

References