使用 TFT 进行预测:时序融合 Transformer
Temporal Fusion Transformer (TFT) 模型由 Lim 等人 [1] 提出,是时间序列预测中最流行的基于 Transformer 的模型之一。总而言之,TFT 结合了门控层、LSTM 循环编码器以及多头注意力层,用于多步预测策略的解码器。有关 Nixtla 的 TFT 实现的更多详细信息,请访问此链接。
在本 Notebook 中,我们将展示如何在德州电力市场负荷数据 (ERCOT) 上训练 TFT 模型。准确预测电力市场具有重要意义,因为它对于规划电力分配和消费非常有用。
我们将向您展示如何加载数据、训练 TFT 并进行自动超参数调整以及生成预测。然后,我们将向您展示如何执行多次历史预测以进行交叉验证。
您可以使用 Google Colab 通过 GPU 运行这些实验。
1. 库
2. 加载 ERCOT 数据
NeuralForecast 的输入始终是一个采用长格式的数据框,包含三列:unique_id
、ds
和 y
-
unique_id
(字符串、整数或类别)表示序列的标识符。 -
ds
(日期戳或整数)列应为索引时间的整数,或理想情况下为 YYYY-MM-DD 格式的日期戳或 YYYY-MM-DD HH:MM:SS 格式的时间戳。 -
y
(数值)表示我们希望预测的测量值。我们将重命名
首先,读取 ERCOT 市场的 2022 年历史总需求数据。我们处理了原始数据(可在此处获取),通过添加因夏令时缺失的小时数据,将日期解析为 datetime 格式,并筛选出感兴趣的列。
unique_id | ds | y | |
---|---|---|---|
0 | ERCOT | 2021-01-01 00:00:00 | 43719.849616 |
1 | ERCOT | 2021-01-01 01:00:00 | 43321.050347 |
2 | ERCOT | 2021-01-01 02:00:00 | 43063.067063 |
3 | ERCOT | 2021-01-01 03:00:00 | 43090.059203 |
4 | ERCOT | 2021-01-01 04:00:00 | 43486.590073 |
3. 模型训练和预测
首先,实例化 AutoTFT
模型。 AutoTFT
类将使用 Tune 库自动执行超参数调优,探索用户定义或默认的搜索空间。模型根据在验证集上的误差进行选择,然后存储最佳模型并在推理时使用。
要实例化 AutoTFT
,您需要定义
h
: 预测范围loss
: 训练损失config
: 超参数搜索空间。如果为None
,则AutoTFT
类将使用预定义的建议超参数空间。num_samples
: 探索的配置数量。
提示
增加
num_samples
参数以探索所选模型的更广泛配置集。根据经验法则,选择大于15
的值。当
num_samples=3
时,此示例应在约 20 分钟内运行完成。
提示
我们所有的模型都可以用于点预测和概率预测。要生成概率输出,只需将损失函数修改为我们的
DistributionLoss
之一。完整的损失函数列表可在此链接中找到
重要提示
TFT 是一个非常大的模型,可能需要大量内存!如果您遇到 GPU 内存不足的问题,请尝试声明您的配置搜索空间并减小
hidden_size
、n_heads
和windows_batch_size
参数。这些是 config 的所有参数
NeuralForecast
类具有内置方法来简化预测流程,例如 fit
、predit
和 cross_validation
。使用以下必需参数实例化一个 NeuralForecast
对象
-
models
: 模型列表。 -
freq
: 指示数据频率的字符串。(参见 panda 的可用频率列表。)
然后,使用 fit
方法在 ERCOT 数据上训练 AutoTFT
模型。总训练时间将取决于硬件和探索的配置,应在 10 到 30 分钟之间。
最后,使用 predict
方法预测训练数据后未来 24 小时的数据并绘制预测结果。
unique_id | ds | AutoTFT | |
---|---|---|---|
0 | ERCOT | 2022-10-01 00:00:00 | 38600.757812 |
1 | ERCOT | 2022-10-01 01:00:00 | 36871.199219 |
2 | ERCOT | 2022-10-01 02:00:00 | 35505.500000 |
3 | ERCOT | 2022-10-01 03:00:00 | 34781.691406 |
4 | ERCOT | 2022-10-01 04:00:00 | 34647.484375 |
使用 matplot lib 绘制结果
4. 对多次历史预测进行交叉验证
cross_validation
方法允许您模拟多次历史预测,通过用 fit
和 predict
方法替换 for 循环来大大简化流程。请参阅此教程以查看如何定义窗口的动画演示。
对于时间序列数据,交叉验证是通过在历史数据上定义一个滑动窗口并预测其后续周期来完成的。这种形式的交叉验证使我们能够在更广泛的时间实例范围内更好地估计模型的预测能力,同时保持训练集中的数据是连续的,这是我们的模型所要求的。cross_validation
方法将使用验证集进行超参数选择,然后为测试集生成预测结果。
使用 cross_validation
方法生成九月份的所有每日预测。设置验证集和测试集的大小。要生成每日预测,请将窗口之间的步长设置为 24,以便每天只生成一个预测。
最后,我们将预测结果与 Y_df
数据集合并并绘制预测图。
后续步骤
在 Challu 等人 [2] 的工作中,我们证明 N-HiTS 模型以少 50 倍的计算量,性能优于最新的 Transformer 模型 20% 以上。
在此教程中了解如何使用 N-HiTS 和 NeuralForecast 库。