长周期概率预测
长周期预测具有挑战性,因为预测结果具有波动性且存在计算复杂度。为了解决这个问题,我们创建了 NHITS 模型,并将代码发布到 NeuralForecast 库中。NHITS
通过分层插值和多速率输入处理,使其部分输出专注于时间序列的不同频率。我们使用 Student's t-分布 对目标时间序列进行建模。NHITS
将为每个时间戳输出分布参数。
在本 Notebook 中,我们将展示如何在 ETTm2 基准数据集上使用 NHITS
进行概率预测。该数据集包含来自 2 个站点的 2 个变压器的数据点,包括负荷、油温。
我们将向您展示如何加载数据、训练模型以及执行自动超参数调优,以实现 SoTA(最先进)性能,其性能甚至超越最新的 Transformer 架构,而计算成本仅为其一小部分(快 50 倍)。
您可以使用 Google Colab 的 GPU 运行这些实验。
1. 库
2. 加载 ETTm2 数据
LongHorizon
类将自动下载完整的 ETTm2 数据集并进行处理。
它返回三个 Dataframe:Y_df
包含目标变量的值,X_df
包含外生日历特征,S_df
包含每个时间序列的静态特征(ETTm2 没有)。在此示例中,我们将仅使用 Y_df
。
如果您想使用自己的数据,只需替换 Y_df
。请确保使用长格式,并且结构与我们的数据集相似。
unique_id | ds | y | |
---|---|---|---|
0 | HUFL | 2016-07-01 00:00:00 | -0.041413 |
1 | HUFL | 2016-07-01 00:15:00 | -0.185467 |
57600 | HULL | 2016-07-01 00:00:00 | 0.040104 |
57601 | HULL | 2016-07-01 00:15:00 | -0.214450 |
115200 | LUFL | 2016-07-01 00:00:00 | 0.695804 |
115201 | LUFL | 2016-07-01 00:15:00 | 0.434685 |
172800 | LULL | 2016-07-01 00:00:00 | 0.434430 |
172801 | LULL | 2016-07-01 00:15:00 | 0.428168 |
230400 | MUFL | 2016-07-01 00:00:00 | -0.599211 |
230401 | MUFL | 2016-07-01 00:15:00 | -0.658068 |
288000 | MULL | 2016-07-01 00:00:00 | -0.393536 |
288001 | MULL | 2016-07-01 00:15:00 | -0.659338 |
345600 | OT | 2016-07-01 00:00:00 | 1.018032 |
345601 | OT | 2016-07-01 00:15:00 | 0.980124 |
重要提示
DataFrames 必须包含所有
['unique_id', 'ds', 'y']
列。确保y
列没有缺失值或非数值。
接下来,绘制 HUFL
变量,并标记验证集和训练集的分割点。
3. 超参数选择和预测
AutoNHITS
类将使用 Tune 库自动执行超参数调优,探索用户定义或默认的搜索空间。模型根据在验证集上的误差进行选择,然后存储最佳模型并在推理时使用。
AutoNHITS.default_config
属性包含一个建议的超参数空间。在这里,我们根据论文中的超参数指定了不同的搜索空间。请注意,*1000 次随机梯度步长*足以实现 SoTA 性能。您可以随意调整此空间。
提示
有关不同空间选项(如列表和连续区间)的更多信息,请参阅 https://docs.rayai.org.cn/en/latest/tune/index.html。
要实例化 AutoNHITS
,您需要定义
h
:预测范围loss
:训练损失。使用DistributionLoss
生成概率预测。config
:超参数搜索空间。如果为None
,AutoNHITS
类将使用预定义的建议超参数空间。num_samples
:探索的配置数量。
通过实例化 NeuralForecast
对象并使用以下必需参数来拟合模型
-
models
:模型列表。 -
freq
:表示数据频率的字符串。(请参阅 panda 可用的频率。)
cross_validation
方法允许您模拟多个历史预测,通过用 fit
和 predict
方法替换 for 循环,极大地简化了管道。
对于时间序列数据,交叉验证是通过在历史数据上定义一个滑动窗口并预测其后续周期来完成的。这种形式的交叉验证使我们能够在更广泛的时间实例范围内更好地估计模型的预测能力,同时也按照模型的需要保持训练集中的数据是连续的。
cross_validation
方法将使用验证集进行超参数选择,然后生成测试集的预测结果。
4. 可视化
最后,我们将预测结果与 Y_df
数据集合并并绘制预测图。
unique_id | ds | cutoff | AutoNHITS | AutoNHITS-median | AutoNHITS-lo-90 | AutoNHITS-lo-80 | AutoNHITS-hi-80 | AutoNHITS-hi-90 | y | |
---|---|---|---|---|---|---|---|---|---|---|
0 | HUFL | 2018-02-11 00:00:00 | 2018-02-10 23:45:00 | -0.922304 | -0.914175 | -1.217987 | -1.138274 | -0.708157 | -0.617799 | -0.849571 |
1 | HUFL | 2018-02-11 00:15:00 | 2018-02-10 23:45:00 | -0.954299 | -0.957198 | -1.403932 | -1.263984 | -0.618467 | -0.442688 | -1.049700 |
2 | HUFL | 2018-02-11 00:30:00 | 2018-02-10 23:45:00 | -0.987538 | -0.972558 | -1.512509 | -1.310191 | -0.621673 | -0.444359 | -1.185730 |
3 | HUFL | 2018-02-11 00:45:00 | 2018-02-10 23:45:00 | -1.067760 | -1.063188 | -1.614276 | -1.475302 | -0.665729 | -0.521775 | -1.329785 |
4 | HUFL | 2018-02-11 01:00:00 | 2018-02-10 23:45:00 | -1.001276 | -1.001494 | -1.508795 | -1.390156 | -0.629212 | -0.470608 | -1.369715 |
… | … | … | … | … | … | … | … | … | … | … |
581275 | OT | 2018-02-20 22:45:00 | 2018-02-19 23:45:00 | -1.200041 | -1.200862 | -1.591271 | -1.490571 | -0.907190 | -0.779424 | -1.581325 |
581276 | OT | 2018-02-20 23:00:00 | 2018-02-19 23:45:00 | -1.237206 | -1.225333 | -1.618691 | -1.518204 | -0.960075 | -0.838512 | -1.581325 |
581277 | OT | 2018-02-20 23:15:00 | 2018-02-19 23:45:00 | -1.232434 | -1.229675 | -1.591164 | -1.481251 | -0.989993 | -0.870404 | -1.581325 |
581278 | OT | 2018-02-20 23:30:00 | 2018-02-19 23:45:00 | -1.259237 | -1.258848 | -1.659239 | -1.536979 | -0.985581 | -0.822370 | -1.562328 |
581279 | OT | 2018-02-20 23:45:00 | 2018-02-19 23:45:00 | -1.247161 | -1.251899 | -1.631909 | -1.520350 | -0.949529 | -0.832602 | -1.562328 |