总而言之,Temporal Fusion Transformer (TFT) 结合了门控层、LSTM 循环编码器和多头注意力层,用于多步预测策略解码器。
TFT 的输入包括静态外生变量 x(s)、历史外生变量 x[:t](h)、预测时可用的外生变量 x[:t+H](f) 和自回归特征 y[:t],这些输入中的每一个都被进一步分解为类别型和连续型。该网络使用多分位数回归来建模以下条件概率:P(y[t+1:t+H]∣y[:t],x[:t](h),x[:t+H](f),x(s))
参考文献
- Jan Golda, Krzysztof Kudrynski. “NVIDIA, 深度学习预测示例”
- Bryan Lim, Sercan O. Arik, Nicolas Loeff, Tomas Pfister,“Temporal Fusion Transformers for interpretable multi-horizon time series forecasting”
1. 辅助函数
1.1 门控机制
门控残差网络 (GRN) 提供自适应深度和网络复杂性,能够适应不同大小的数据集。残差连接允许网络跳过输入 a 和上下文 c 的非线性变换。
门控线性单元 (GLU) 提供了抑制 GRN 中不必要部分的灵活性。考虑 GRN 的输出 γ,则 GLU 变换定义为
GLU(γ)=σ(W4γ+b4)⊙(W5γ+b5)
1.2 变量选择网络
TFT 通过其变量选择网络 (VSN) 组件包含自动变量选择功能。VSN 接收原始输入 {x(s),x[:t](h),x[:t](f)} 并通过嵌入或线性变换将其转换为高维空间 {E(s),E[:t](h),E[:t+H](f)}。
对于观测到的历史数据,时刻 t 的嵌入矩阵 Et(h) 是 j 个变量 et,j(h) 嵌入的拼接
变量选择权重由下式给出: st(h)=SoftMax(GRN(Et(h),E(s)))
然后,VSN 处理后的特征为: E~t(h)=∑jsj(h)e~t,j(h)
1.3. 多头注意力
为了避免经典 Seq2Seq 架构造成的信息瓶颈,TFT 集成了继承自 Transformer 架构的解码器-编码器注意力机制 (Li et. al 2019, Vaswani et. al 2017)。它转换 LSTM 编码的时间特征的输出,并帮助解码器更好地捕获长期关系。
每个组件 Hm 的原始多头注意力的查询、键和值表示记为 Qm,Km,Vm,其变换由下式给出
TFT 修改了原始多头注意力以提高其可解释性。为此,它在所有头之间使用共享值 V~ 并采用加法聚合,InterpretableMultiHead(Q,K,V)=H~WM。该机制与单个注意力层非常相似,但它允许 M 个多重注意力权重,因此可以被解释为 M 个单个注意力层的平均集成。
2. TFT 架构
TFT 的第一步是将原始输入 {x(s),x(h),x(f)} 嵌入到高维空间 {E(s),E(h),E(f)} 中,然后每个嵌入都由变量选择网络 (VSN) 进行门控。静态嵌入 E(s) 用作变量选择的上下文以及 LSTM 的初始条件。最后,编码的变量被馈送到多头注意力解码器。
2.1 静态协变量编码器
静态嵌入 E(s) 通过 StaticCovariateEncoder 转换为上下文 cs,ce,ch,cc。其中 cs 是时间变量选择上下文,ce 是 TemporalFusionDecoder 增强上下文,而 ch,cc 是 TemporalCovariateEncoder 的 LSTM 隐藏状态/上下文。
2.2 时间协变量编码器
TemporalCovariateEncoder 使用 LSTM 对嵌入 E(h),E(f) 和上下文 (ch,cc) 进行编码。
对未来数据重复类似的过程,主要区别在于 E(f) 包含未来可用的信息。
2.3 时间融合解码器
TemporalFusionDecoder 使用 ce 丰富了 LSTM 的输出,然后使用注意力层和多步适配器。
来源
TFT
*TFT
Temporal Fusion Transformer (TFT) 架构是一种 Sequence-to-Sequence 模型,它结合静态、历史和未来可用数据来预测单变量目标。该方法结合了门控层、LSTM 循环编码器、可解释的多头注意力层以及多步预测策略解码器。
参数
h
: int,预测范围。
input_size
: int,自回归输入大小,y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2]。
tgt_size
: int=1,目标大小。
stat_exog_list
: str list,静态连续列。
hist_exog_list
: str list,历史连续列。
futr_exog_list
: str list,未来连续列。
hidden_size
: int,嵌入和编码器的单元数。
n_head
: int=4,时间融合解码器中的注意力头数。
attn_dropout
: float (0, 1),融合解码器注意力层的 dropout。
grn_activation
: str,GRN 模块的激活函数,可选值为 [‘ReLU’, ‘Softplus’, ‘Tanh’, ‘SELU’, ‘LeakyReLU’, ‘Sigmoid’, ‘ELU’, ‘GLU’]。
n_rnn_layers
: int=1,RNN 层数。
rnn_type
: str=“lstm”,循环神经网络 (RNN) 层类型,可选值为 [“lstm”,“gru”]。
one_rnn_initial_state
: str=False,使用由静态协变量计算出的相同初始状态初始化所有 RNN 层。
dropout
: float (0, 1),输入 VSN 的 dropout。
loss
: PyTorch 模块,从 损失函数集合 中实例化的训练损失类。
valid_loss
: PyTorch 模块=loss
,从 损失函数集合 中实例化的验证损失类。
max_steps
: int=1000,最大训练步数。
learning_rate
: float=1e-3,学习率,范围 (0, 1)。
num_lr_decays
: int=-1,学习率衰减次数,在最大训练步数中均匀分布。
early_stop_patience_steps
: int=-1,早停前等待的验证迭代次数。
val_check_steps
: int=100,每次验证损失检查之间的训练步数。
batch_size
: int,每个批次中不同时间序列的数量。
valid_batch_size
: int=None,每个验证和测试批次中不同时间序列的数量。
windows_batch_size
: int=None,从滚动数据中采样的窗口数,默认为全部。
inference_windows_batch_size
: int=-1,每次推理批次中采样的窗口数,-1 为全部。
start_padding_enabled
: bool=False,如果为 True,模型将在时间序列开头按输入大小填充零。
step_size
: int=1,时间数据每个窗口之间的步长。
scaler_type
: str=‘robust’,用于时间输入归一化的缩放器类型,参见 时间缩放器。
random_seed
: int,用于重复性实验的随机种子初始化。
drop_last_loader
: bool=False,如果为 True,TimeSeriesDataLoader
将丢弃最后一个非完整批次。
alias
: str,可选,模型的自定义名称。
optimizer
: ‘torch.optim.Optimizer’ 的子类,可选,用户指定的优化器,而不是默认选项 (Adam)。
optimizer_kwargs
: dict,可选,用户指定 optimizer
使用的参数列表。
lr_scheduler
: ‘torch.optim.lr_scheduler.LRScheduler’ 的子类,可选,用户指定的 lr_scheduler,而不是默认选项 (StepLR)。
lr_scheduler_kwargs
: dict,可选,用户指定 lr_scheduler
使用的参数列表。
dataloader_kwargs
: dict,可选,由 TimeSeriesDataLoader
传递给 PyTorch Lightning 数据加载器的参数列表。
**trainer_kwargs
: int,继承自 PyTorch Lightning trainer 的关键字 trainer 参数。
参考文献
- Bryan Lim, Sercan O. Arik, Nicolas Loeff, Tomas Pfister,“Temporal Fusion Transformers for interpretable multi-horizon time series forecasting”*
3. TFT 方法
TFT.fit
*拟合。
fit
方法使用初始化参数(learning_rate
、windows_batch_size
等)和初始化时定义的 loss
函数来优化神经网络的权重。在 fit
方法内部,我们使用继承了初始化参数 self.trainer_kwargs
的 PyTorch Lightning Trainer
来自定义其输入,参见 PL 的 trainer 参数。
该方法设计为与 SKLearn-like 类兼容,特别是与 StatsForecast 库兼容。
默认情况下,model
不保存训练检查点以保护磁盘内存,如需保存,请在 __init__
中将 enable_checkpointing
设置为 True
。
参数
dataset
: NeuralForecast 的 TimeSeriesDataset
,参见文档。
val_size
: int,时间交叉验证的验证集大小。
random_seed
: int=None,用于 pytorch 初始化器和 numpy 生成器的随机种子,会覆盖 model.__init__ 中的设置。
test_size
: int,时间交叉验证的测试集大小。
*
TFT.predict
*预测。
使用 PL 的 Trainer
执行 predict_step
进行神经网络预测。
参数
dataset
: NeuralForecast 的 TimeSeriesDataset
,参见文档。
test_size
: int=None,时间交叉验证的测试集大小。
step_size
: int=1,每个窗口之间的步长。
random_seed
: int=None,用于 pytorch 初始化器和 numpy 生成器的随机种子,会覆盖 model.__init__ 中的设置。
quantiles
: list of floats,可选 (default=None),要预测的目标分位数。
**data_module_kwargs
: PL 的 TimeSeriesDataModule 参数,参见 文档。*
来源
TFT.feature_importances,
*计算历史、未来和静态特征的重要性。
返回:dict:包含每种特征类型重要性的字典。键为 ‘hist_vsn’、‘future_vsn’ 和 ‘static_vsn’,对应的值为包含相应特征重要性的 pandas DataFrame。*
来源
TFT.attention_weights
*批次平均注意力权重
返回:np.ndarray:包含每个时间步注意力权重的 1D 数组。*
来源
TFT.attention_weights
*批次平均注意力权重
返回:np.ndarray:包含每个时间步注意力权重的 1D 数组。*
来源
TFT.feature_importance_correlations
*计算过去和未来特征重要性与平均注意力权重之间的相关性。
返回:pd.DataFrame:包含过去特征重要性与平均注意力权重之间相关系数的 DataFrame。*
使用示例
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from neuralforecast import NeuralForecast
from neuralforecast.losses.pytorch import DistributionLoss
from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic
AirPassengersPanel["month"] = AirPassengersPanel.ds.dt.month
Y_train_df = AirPassengersPanel[
AirPassengersPanel.ds < AirPassengersPanel["ds"].values[-12]
]
Y_test_df = AirPassengersPanel[
AirPassengersPanel.ds >= AirPassengersPanel["ds"].values[-12]
].reset_index(drop=True)
nf = NeuralForecast(
models=[
TFT(
h=12,
input_size=48,
hidden_size=20,
grn_activation="ELU",
rnn_type="lstm",
n_rnn_layers=1,
one_rnn_initial_state=False,
loss=DistributionLoss(distribution="StudentT", level=[80, 90]),
learning_rate=0.005,
stat_exog_list=["airline1"],
futr_exog_list=["y_[lag12]", "month"],
hist_exog_list=["trend"],
max_steps=300,
val_check_steps=10,
early_stop_patience_steps=10,
scaler_type="robust",
windows_batch_size=None,
enable_progress_bar=True,
),
],
freq="ME",
)
nf.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)
Y_hat_df = nf.predict(futr_df=Y_test_df)
Y_hat_df = Y_hat_df.reset_index(drop=False).drop(columns=["unique_id", "ds"])
plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)
plot_df = pd.concat([Y_train_df, plot_df])
plot_df = plot_df[plot_df.unique_id == "Airline1"].drop("unique_id", axis=1)
plt.plot(plot_df["ds"], plot_df["y"], c="black", label="True")
plt.plot(plot_df["ds"], plot_df["TFT"], c="purple", label="mean")
plt.plot(plot_df["ds"], plot_df["TFT-median"], c="blue", label="median")
plt.fill_between(
x=plot_df["ds"][-12:],
y1=plot_df["TFT-lo-90"][-12:].values,
y2=plot_df["TFT-hi-90"][-12:].values,
alpha=0.4,
label="level 90",
)
plt.legend()
plt.grid()
plt.plot()
可解释性
1. 注意力权重
def plot_attention(
self, plot: str = "time", output: str = "plot", width: int = 800, height: int = 400
):
"""
Plot the attention weights.
Args:
plot (str, optional): The type of plot to generate. Can be one of the following:
- 'time': Display the mean attention weights over time.
- 'all': Display the attention weights for each horizon.
- 'heatmap': Display the attention weights as a heatmap.
- An integer in the range [1, model.h) to display the attention weights for a specific horizon.
output (str, optional): The type of output to generate. Can be one of the following:
- 'plot': Display the plot directly.
- 'figure': Return the plot as a figure object.
width (int, optional): Width of the plot in pixels. Default is 800.
height (int, optional): Height of the plot in pixels. Default is 400.
Returns:
matplotlib.figure.Figure: If `output` is 'figure', the function returns the plot as a figure object.
"""
attention = (
self.mean_on_batch(self.interpretability_params["attn_wts"])
.mean(dim=0)
.cpu()
.numpy()
)
fig, ax = plt.subplots(figsize=(width / 100, height / 100))
if plot == "time":
attention = attention[self.input_size :, :].mean(axis=0)
ax.plot(np.arange(-self.input_size, self.h), attention)
ax.axvline(
x=0, color="black", linewidth=3, linestyle="--", label="prediction start"
)
ax.set_title("Mean Attention")
ax.set_xlabel("time")
ax.set_ylabel("Attention")
ax.legend()
elif plot == "all":
for i in range(self.input_size, attention.shape[0]):
ax.plot(
np.arange(-self.input_size, self.h),
attention[i, :],
label=f"horizon {i-self.input_size+1}",
)
ax.axvline(
x=0, color="black", linewidth=3, linestyle="--", label="prediction start"
)
ax.set_title("Attention per horizon")
ax.set_xlabel("time")
ax.set_ylabel("Attention")
ax.legend()
elif plot == "heatmap":
cax = ax.imshow(
attention,
aspect="auto",
cmap="viridis",
extent=[-self.input_size, self.h, -self.input_size, self.h],
)
fig.colorbar(cax)
ax.set_title("Attention Heatmap")
ax.set_xlabel("Attention (current time step)")
ax.set_ylabel("Attention (previous time step)")
elif isinstance(plot, int) and (plot in np.arange(1, self.h + 1)):
i = self.input_size + plot - 1
ax.plot(
np.arange(-self.input_size, self.h),
attention[i, :],
label=f"horizon {plot}",
)
ax.axvline(
x=0, color="black", linewidth=3, linestyle="--", label="prediction start"
)
ax.set_title(f"Attention weight for horizon {plot}")
ax.set_xlabel("time")
ax.set_ylabel("Attention")
ax.legend()
else:
raise ValueError(
'plot has to be in ["time","all","heatmap"] or integer in range(1,model.h)'
)
plt.tight_layout()
if output == "plot":
plt.show()
elif output == "figure":
return fig
else:
raise ValueError(f"Invalid output: {output}. Expected 'plot' or 'figure'.")
1.1 平均注意力
1.2 所有未来时间步的注意力
1.3 特定未来时间步的注意力
2. 特征重要性
2.1 全局特征重要性
静态变量重要性
过去变量重要性
未来变量重要性
2.2 随时间变化的变量重要性
未来变量随时间变化的重要性
每个未来时间步中每个未来协变量的重要性
2.3
过去变量随时间变化的重要性
由注意力加权的过去变量随时间变化的重要性
根据每个时间步中每个变量的重要性来分解每个时间步的重要性
3. 随时间变化的变量重要性相关性
在同一时刻获得和失去重要性的变量