总而言之,Temporal Fusion Transformer (TFT) 结合了门控层、LSTM 循环编码器和多头注意力层,用于多步预测策略解码器。
TFT 的输入包括静态外生变量 x(s)\mathbf{x}^{(s)}、历史外生变量 x[:t](h)\mathbf{x}^{(h)}_{[:t]}、预测时可用的外生变量 x[:t+H](f)\mathbf{x}^{(f)}_{[:t+H]} 和自回归特征 y[:t]\mathbf{y}_{[:t]},这些输入中的每一个都被进一步分解为类别型和连续型。该网络使用多分位数回归来建模以下条件概率:P(y[t+1:t+H]  y[:t],  x[:t](h),  x[:t+H](f),  x(s))\mathbb{P}(\mathbf{y}_{[t+1:t+H]}|\;\mathbf{y}_{[:t]},\; \mathbf{x}^{(h)}_{[:t]},\; \mathbf{x}^{(f)}_{[:t+H]},\; \mathbf{x}^{(s)})

参考文献
- Jan Golda, Krzysztof Kudrynski. “NVIDIA, 深度学习预测示例”
- Bryan Lim, Sercan O. Arik, Nicolas Loeff, Tomas Pfister,“Temporal Fusion Transformers for interpretable multi-horizon time series forecasting”

1. 辅助函数

1.1 门控机制

门控残差网络 (GRN) 提供自适应深度和网络复杂性,能够适应不同大小的数据集。残差连接允许网络跳过输入 a\mathbf{a} 和上下文 c\mathbf{c} 的非线性变换。

门控线性单元 (GLU) 提供了抑制 GRN 中不必要部分的灵活性。考虑 GRN 的输出 γ\gamma,则 GLU 变换定义为

GLU(γ)=σ(W4γ+b4)(W5γ+b5)\mathrm{GLU}(\gamma) = \sigma(\mathbf{W}_{4}\gamma +b_{4}) \odot (\mathbf{W}_{5}\gamma +b_{5})

1.2 变量选择网络

TFT 通过其变量选择网络 (VSN) 组件包含自动变量选择功能。VSN 接收原始输入 {x(s),x[:t](h),x[:t](f)}\{\mathbf{x}^{(s)}, \mathbf{x}^{(h)}_{[:t]}, \mathbf{x}^{(f)}_{[:t]}\} 并通过嵌入或线性变换将其转换为高维空间 {E(s),E[:t](h),E[:t+H](f)}\{\mathbf{E}^{(s)}, \mathbf{E}^{(h)}_{[:t]}, \mathbf{E}^{(f)}_{[:t+H]}\}

对于观测到的历史数据,时刻 tt 的嵌入矩阵 Et(h)\mathbf{E}^{(h)}_{t}jj 个变量 et,j(h)e^{(h)}_{t,j} 嵌入的拼接

变量选择权重由下式给出: st(h)=SoftMax(GRN(Et(h),E(s)))s^{(h)}_{t}=\mathrm{SoftMax}(\mathrm{GRN}(\mathbf{E}^{(h)}_{t},\mathbf{E}^{(s)}))

然后,VSN 处理后的特征为: E~t(h)=jsj(h)e~t,j(h)\tilde{\mathbf{E}}^{(h)}_{t}= \sum_{j} s^{(h)}_{j} \tilde{e}^{(h)}_{t,j}

1.3. 多头注意力

为了避免经典 Seq2Seq 架构造成的信息瓶颈,TFT 集成了继承自 Transformer 架构的解码器-编码器注意力机制 (Li et. al 2019, Vaswani et. al 2017)。它转换 LSTM 编码的时间特征的输出,并帮助解码器更好地捕获长期关系。

每个组件 HmH_{m} 的原始多头注意力的查询、键和值表示记为 Qm,Km,VmQ_{m}, K_{m}, V_{m},其变换由下式给出

TFT 修改了原始多头注意力以提高其可解释性。为此,它在所有头之间使用共享值 V~\tilde{V} 并采用加法聚合,InterpretableMultiHead(Q,K,V)=H~WM\mathrm{InterpretableMultiHead}(Q,K,V) = \tilde{H} W_{M}。该机制与单个注意力层非常相似,但它允许 MM 个多重注意力权重,因此可以被解释为 MM 个单个注意力层的平均集成。

2. TFT 架构

TFT 的第一步是将原始输入 {x(s),x(h),x(f)}\{\mathbf{x}^{(s)}, \mathbf{x}^{(h)}, \mathbf{x}^{(f)}\} 嵌入到高维空间 {E(s),E(h),E(f)}\{\mathbf{E}^{(s)}, \mathbf{E}^{(h)}, \mathbf{E}^{(f)}\} 中,然后每个嵌入都由变量选择网络 (VSN) 进行门控。静态嵌入 E(s)\mathbf{E}^{(s)} 用作变量选择的上下文以及 LSTM 的初始条件。最后,编码的变量被馈送到多头注意力解码器。

2.1 静态协变量编码器

静态嵌入 E(s)\mathbf{E}^{(s)} 通过 StaticCovariateEncoder 转换为上下文 cs,ce,ch,ccc_{s}, c_{e}, c_{h}, c_{c}。其中 csc_{s} 是时间变量选择上下文,cec_{e} 是 TemporalFusionDecoder 增强上下文,而 ch,ccc_{h}, c_{c} 是 TemporalCovariateEncoder 的 LSTM 隐藏状态/上下文。

2.2 时间协变量编码器

TemporalCovariateEncoder 使用 LSTM 对嵌入 E(h),E(f)\mathbf{E}^{(h)}, \mathbf{E}^{(f)} 和上下文 (ch,cc)(c_{h}, c_{c}) 进行编码。

对未来数据重复类似的过程,主要区别在于 E(f)\mathbf{E}^{(f)} 包含未来可用的信息。

2.3 时间融合解码器

TemporalFusionDecoder 使用 cec_{e} 丰富了 LSTM 的输出,然后使用注意力层和多步适配器。


来源

TFT

 TFT (h, input_size, tgt_size:int=1, stat_exog_list=None,
      hist_exog_list=None, futr_exog_list=None, hidden_size:int=128,
      n_head:int=4, attn_dropout:float=0.0, grn_activation:str='ELU',
      n_rnn_layers:int=1, rnn_type:str='lstm',
      one_rnn_initial_state:bool=False, dropout:float=0.1, loss=MAE(),
      valid_loss=None, max_steps:int=1000, learning_rate:float=0.001,
      num_lr_decays:int=-1, early_stop_patience_steps:int=-1,
      val_check_steps:int=100, batch_size:int=32,
      valid_batch_size:Optional[int]=None, windows_batch_size:int=1024,
      inference_windows_batch_size:int=1024, start_padding_enabled=False,
      step_size:int=1, scaler_type:str='robust', random_seed:int=1,
      drop_last_loader=False, alias:Optional[str]=None, optimizer=None,
      optimizer_kwargs=None, lr_scheduler=None, lr_scheduler_kwargs=None,
      dataloader_kwargs=None, **trainer_kwargs)

*TFT

Temporal Fusion Transformer (TFT) 架构是一种 Sequence-to-Sequence 模型,它结合静态、历史和未来可用数据来预测单变量目标。该方法结合了门控层、LSTM 循环编码器、可解释的多头注意力层以及多步预测策略解码器。

参数
h: int,预测范围。
input_size: int,自回归输入大小,y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2]。
tgt_size: int=1,目标大小。
stat_exog_list: str list,静态连续列。
hist_exog_list: str list,历史连续列。
futr_exog_list: str list,未来连续列。
hidden_size: int,嵌入和编码器的单元数。
n_head: int=4,时间融合解码器中的注意力头数。
attn_dropout: float (0, 1),融合解码器注意力层的 dropout。
grn_activation: str,GRN 模块的激活函数,可选值为 [‘ReLU’, ‘Softplus’, ‘Tanh’, ‘SELU’, ‘LeakyReLU’, ‘Sigmoid’, ‘ELU’, ‘GLU’]。
n_rnn_layers: int=1,RNN 层数。
rnn_type: str=“lstm”,循环神经网络 (RNN) 层类型,可选值为 [“lstm”,“gru”]。
one_rnn_initial_state: str=False,使用由静态协变量计算出的相同初始状态初始化所有 RNN 层。
dropout: float (0, 1),输入 VSN 的 dropout。
loss: PyTorch 模块,从 损失函数集合 中实例化的训练损失类。
valid_loss: PyTorch 模块=loss,从 损失函数集合 中实例化的验证损失类。
max_steps: int=1000,最大训练步数。
learning_rate: float=1e-3,学习率,范围 (0, 1)。
num_lr_decays: int=-1,学习率衰减次数,在最大训练步数中均匀分布。
early_stop_patience_steps: int=-1,早停前等待的验证迭代次数。
val_check_steps: int=100,每次验证损失检查之间的训练步数。
batch_size: int,每个批次中不同时间序列的数量。
valid_batch_size: int=None,每个验证和测试批次中不同时间序列的数量。
windows_batch_size: int=None,从滚动数据中采样的窗口数,默认为全部。
inference_windows_batch_size: int=-1,每次推理批次中采样的窗口数,-1 为全部。
start_padding_enabled: bool=False,如果为 True,模型将在时间序列开头按输入大小填充零。
step_size: int=1,时间数据每个窗口之间的步长。
scaler_type: str=‘robust’,用于时间输入归一化的缩放器类型,参见 时间缩放器
random_seed: int,用于重复性实验的随机种子初始化。
drop_last_loader: bool=False,如果为 True,TimeSeriesDataLoader 将丢弃最后一个非完整批次。
alias: str,可选,模型的自定义名称。
optimizer: ‘torch.optim.Optimizer’ 的子类,可选,用户指定的优化器,而不是默认选项 (Adam)。
optimizer_kwargs: dict,可选,用户指定 optimizer 使用的参数列表。
lr_scheduler: ‘torch.optim.lr_scheduler.LRScheduler’ 的子类,可选,用户指定的 lr_scheduler,而不是默认选项 (StepLR)。
lr_scheduler_kwargs: dict,可选,用户指定 lr_scheduler 使用的参数列表。
dataloader_kwargs: dict,可选,由 TimeSeriesDataLoader 传递给 PyTorch Lightning 数据加载器的参数列表。
**trainer_kwargs: int,继承自 PyTorch Lightning trainer 的关键字 trainer 参数。

参考文献
- Bryan Lim, Sercan O. Arik, Nicolas Loeff, Tomas Pfister,“Temporal Fusion Transformers for interpretable multi-horizon time series forecasting”*

3. TFT 方法


TFT.fit

 TFT.fit (dataset, val_size=0, test_size=0, random_seed=None,
          distributed_config=None)

*拟合。

fit 方法使用初始化参数(learning_ratewindows_batch_size 等)和初始化时定义的 loss 函数来优化神经网络的权重。在 fit 方法内部,我们使用继承了初始化参数 self.trainer_kwargs 的 PyTorch Lightning Trainer 来自定义其输入,参见 PL 的 trainer 参数

该方法设计为与 SKLearn-like 类兼容,特别是与 StatsForecast 库兼容。

默认情况下,model 不保存训练检查点以保护磁盘内存,如需保存,请在 __init__ 中将 enable_checkpointing 设置为 True

参数
dataset: NeuralForecast 的 TimeSeriesDataset,参见文档
val_size: int,时间交叉验证的验证集大小。
random_seed: int=None,用于 pytorch 初始化器和 numpy 生成器的随机种子,会覆盖 model.__init__ 中的设置。
test_size: int,时间交叉验证的测试集大小。
*


TFT.predict

 TFT.predict (dataset, test_size=None, step_size=1, random_seed=None,
              quantiles=None, **data_module_kwargs)

*预测。

使用 PL 的 Trainer 执行 predict_step 进行神经网络预测。

参数
dataset: NeuralForecast 的 TimeSeriesDataset,参见文档
test_size: int=None,时间交叉验证的测试集大小。
step_size: int=1,每个窗口之间的步长。
random_seed: int=None,用于 pytorch 初始化器和 numpy 生成器的随机种子,会覆盖 model.__init__ 中的设置。
quantiles: list of floats,可选 (default=None),要预测的目标分位数。
**data_module_kwargs: PL 的 TimeSeriesDataModule 参数,参见 文档。*


来源

TFT.feature_importances,

 TFT.feature_importances, ()

*计算历史、未来和静态特征的重要性。

返回:dict:包含每种特征类型重要性的字典。键为 ‘hist_vsn’、‘future_vsn’ 和 ‘static_vsn’,对应的值为包含相应特征重要性的 pandas DataFrame。*


来源

TFT.attention_weights

 TFT.attention_weights ()

*批次平均注意力权重

返回:np.ndarray:包含每个时间步注意力权重的 1D 数组。*


来源

TFT.attention_weights

 TFT.attention_weights ()

*批次平均注意力权重

返回:np.ndarray:包含每个时间步注意力权重的 1D 数组。*


来源

TFT.feature_importance_correlations

 TFT.feature_importance_correlations ()

*计算过去和未来特征重要性与平均注意力权重之间的相关性。

返回:pd.DataFrame:包含过去特征重要性与平均注意力权重之间相关系数的 DataFrame。*

使用示例

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from neuralforecast import NeuralForecast

# from neuralforecast.models import TFT
from neuralforecast.losses.pytorch import DistributionLoss
from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic

AirPassengersPanel["month"] = AirPassengersPanel.ds.dt.month
Y_train_df = AirPassengersPanel[
    AirPassengersPanel.ds < AirPassengersPanel["ds"].values[-12]
]  # 132 train
Y_test_df = AirPassengersPanel[
    AirPassengersPanel.ds >= AirPassengersPanel["ds"].values[-12]
].reset_index(drop=True)  # 12 test

nf = NeuralForecast(
    models=[
        TFT(
            h=12,
            input_size=48,
            hidden_size=20,
            grn_activation="ELU",
            rnn_type="lstm",
            n_rnn_layers=1,
            one_rnn_initial_state=False,
            loss=DistributionLoss(distribution="StudentT", level=[80, 90]),
            learning_rate=0.005,
            stat_exog_list=["airline1"],
            futr_exog_list=["y_[lag12]", "month"],
            hist_exog_list=["trend"],
            max_steps=300,
            val_check_steps=10,
            early_stop_patience_steps=10,
            scaler_type="robust",
            windows_batch_size=None,
            enable_progress_bar=True,
        ),
    ],
    freq="ME",
)
nf.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)
Y_hat_df = nf.predict(futr_df=Y_test_df)

# Plot quantile predictions
Y_hat_df = Y_hat_df.reset_index(drop=False).drop(columns=["unique_id", "ds"])
plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)
plot_df = pd.concat([Y_train_df, plot_df])

plot_df = plot_df[plot_df.unique_id == "Airline1"].drop("unique_id", axis=1)
plt.plot(plot_df["ds"], plot_df["y"], c="black", label="True")
plt.plot(plot_df["ds"], plot_df["TFT"], c="purple", label="mean")
plt.plot(plot_df["ds"], plot_df["TFT-median"], c="blue", label="median")
plt.fill_between(
    x=plot_df["ds"][-12:],
    y1=plot_df["TFT-lo-90"][-12:].values,
    y2=plot_df["TFT-hi-90"][-12:].values,
    alpha=0.4,
    label="level 90",
)
plt.legend()
plt.grid()
plt.plot()

可解释性

1. 注意力权重

attention = nf.models[0].attention_weights()
def plot_attention(
    self, plot: str = "time", output: str = "plot", width: int = 800, height: int = 400
):
    """
    Plot the attention weights.

    Args:
        plot (str, optional): The type of plot to generate. Can be one of the following:
            - 'time': Display the mean attention weights over time.
            - 'all': Display the attention weights for each horizon.
            - 'heatmap': Display the attention weights as a heatmap.
            - An integer in the range [1, model.h) to display the attention weights for a specific horizon.
        output (str, optional): The type of output to generate. Can be one of the following:
            - 'plot': Display the plot directly.
            - 'figure': Return the plot as a figure object.
        width (int, optional): Width of the plot in pixels. Default is 800.
        height (int, optional): Height of the plot in pixels. Default is 400.

    Returns:
        matplotlib.figure.Figure: If `output` is 'figure', the function returns the plot as a figure object.
    """

    attention = (
        self.mean_on_batch(self.interpretability_params["attn_wts"])
        .mean(dim=0)
        .cpu()
        .numpy()
    )

    fig, ax = plt.subplots(figsize=(width / 100, height / 100))

    if plot == "time":
        attention = attention[self.input_size :, :].mean(axis=0)
        ax.plot(np.arange(-self.input_size, self.h), attention)
        ax.axvline(
            x=0, color="black", linewidth=3, linestyle="--", label="prediction start"
        )
        ax.set_title("Mean Attention")
        ax.set_xlabel("time")
        ax.set_ylabel("Attention")
        ax.legend()

    elif plot == "all":
        for i in range(self.input_size, attention.shape[0]):
            ax.plot(
                np.arange(-self.input_size, self.h),
                attention[i, :],
                label=f"horizon {i-self.input_size+1}",
            )
        ax.axvline(
            x=0, color="black", linewidth=3, linestyle="--", label="prediction start"
        )
        ax.set_title("Attention per horizon")
        ax.set_xlabel("time")
        ax.set_ylabel("Attention")
        ax.legend()

    elif plot == "heatmap":
        cax = ax.imshow(
            attention,
            aspect="auto",
            cmap="viridis",
            extent=[-self.input_size, self.h, -self.input_size, self.h],
        )
        fig.colorbar(cax)
        ax.set_title("Attention Heatmap")
        ax.set_xlabel("Attention (current time step)")
        ax.set_ylabel("Attention (previous time step)")

    elif isinstance(plot, int) and (plot in np.arange(1, self.h + 1)):
        i = self.input_size + plot - 1
        ax.plot(
            np.arange(-self.input_size, self.h),
            attention[i, :],
            label=f"horizon {plot}",
        )
        ax.axvline(
            x=0, color="black", linewidth=3, linestyle="--", label="prediction start"
        )
        ax.set_title(f"Attention weight for horizon {plot}")
        ax.set_xlabel("time")
        ax.set_ylabel("Attention")
        ax.legend()

    else:
        raise ValueError(
            'plot has to be in ["time","all","heatmap"] or integer in range(1,model.h)'
        )

    plt.tight_layout()

    if output == "plot":
        plt.show()
    elif output == "figure":
        return fig
    else:
        raise ValueError(f"Invalid output: {output}. Expected 'plot' or 'figure'.")

1.1 平均注意力

plot_attention(nf.models[0], plot="time")

1.2 所有未来时间步的注意力

plot_attention(nf.models[0], plot="all")

1.3 特定未来时间步的注意力

plot_attention(nf.models[0], plot=8)

2. 特征重要性

2.1 全局特征重要性

feature_importances = nf.models[0].feature_importances()
feature_importances.keys()

静态变量重要性

feature_importances["Static covariates"].sort_values(by="importance").plot(kind="barh")

过去变量重要性

feature_importances["Past variable importance over time"].mean().sort_values().plot(
    kind="barh"
)

未来变量重要性

feature_importances["Future variable importance over time"].mean().sort_values().plot(
    kind="barh"
)

2.2 随时间变化的变量重要性

未来变量随时间变化的重要性

每个未来时间步中每个未来协变量的重要性

df = feature_importances["Future variable importance over time"]


fig, ax = plt.subplots(figsize=(20, 10))
bottom = np.zeros(len(df.index))
for col in df.columns:
    p = ax.bar(np.arange(-len(df), 0), df[col].values, 0.6, label=col, bottom=bottom)
    bottom += df[col]
ax.set_title("Future variable importance over time ponderated by attention")
ax.set_ylabel("Importance")
ax.set_xlabel("Time")
ax.grid(True)
ax.legend()
plt.show()

2.3

过去变量随时间变化的重要性

df = feature_importances["Past variable importance over time"]

fig, ax = plt.subplots(figsize=(20, 10))
bottom = np.zeros(len(df.index))

for col in df.columns:
    p = ax.bar(np.arange(-len(df), 0), df[col].values, 0.6, label=col, bottom=bottom)
    bottom += df[col]
ax.set_title("Past variable importance over time")
ax.set_ylabel("Importance")
ax.set_xlabel("Time")
ax.legend()
ax.grid(True)

plt.show()

由注意力加权的过去变量随时间变化的重要性

根据每个时间步中每个变量的重要性来分解每个时间步的重要性

df = feature_importances["Past variable importance over time"]
mean_attention = (
    nf.models[0]
    .attention_weights()[nf.models[0].input_size :, :]
    .mean(axis=0)[: nf.models[0].input_size]
)
df = df.multiply(mean_attention, axis=0)

fig, ax = plt.subplots(figsize=(20, 10))
bottom = np.zeros(len(df.index))

for col in df.columns:
    p = ax.bar(np.arange(-len(df), 0), df[col].values, 0.6, label=col, bottom=bottom)
    bottom += df[col]
ax.set_title("Past variable importance over time ponderated by attention")
ax.set_ylabel("Importance")
ax.set_xlabel("Time")
ax.legend()
ax.grid(True)
plt.plot(
    np.arange(-len(df), 0),
    mean_attention,
    color="black",
    marker="o",
    linestyle="-",
    linewidth=2,
    label="mean_attention",
)
plt.legend()
plt.show()

3. 随时间变化的变量重要性相关性

在同一时刻获得和失去重要性的变量

nf.models[0].feature_importance_correlations()