保存和加载训练好的深度学习模型有多种重要用途。训练这些模型通常成本很高;存储预训练的模型可以帮助降低成本,因为它可以被加载并重复用于多次预测。此外,它还支持迁移学习能力,即在大型数据集上预训练一个灵活的模型,然后在其他数据上使用,只需少量或无需训练。这是机器学习领域最杰出的 🚀 成就之一 🧠,并有许多实际应用。

在本 Notebook 中,我们将演示如何保存和加载 NeuralForecast 模型。

要考虑的两种方法是
1. NeuralForecast.save:将模型保存到磁盘,允许保存数据集和配置。
2. NeuralForecast.load:从给定路径加载模型。

重要提示

本指南假定您对 NeuralForecast 库有基本了解。有关最简单的示例,请参阅入门指南。

您可以使用 Google Colab 通过 GPU 运行这些实验。

1. 安装 NeuralForecast

!pip install neuralforecast

2. 加载 AirPassengers 数据

在此示例中,我们将使用经典的 AirPassenger 数据集。从 utils 导入预处理过的 AirPassenger 数据。

from neuralforecast.utils import AirPassengersDF
Y_df = AirPassengersDF
Y_df.head()
unique_iddsy
01.01949-01-31112.0
11.01949-02-28118.0
21.01949-03-31132.0
31.01949-04-30129.0
41.01949-05-31121.0

3. 模型训练

接下来,我们实例化并训练三个模型:NBEATSNHITSAutoMLP。模型及其超参数在 models 列表中定义。

import logging

from ray import tune

from neuralforecast.core import NeuralForecast
from neuralforecast.auto import AutoMLP
from neuralforecast.models import NBEATS, NHITS
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
horizon = 12
models = [NBEATS(input_size=2 * horizon, h=horizon, max_steps=50),
          NHITS(input_size=2 * horizon, h=horizon, max_steps=50),
          AutoMLP(# Ray tune explore config
                  config=dict(max_steps=100, # Operates with steps not epochs
                              input_size=tune.choice([3*horizon]),
                              learning_rate=tune.choice([1e-3])),
                  h=horizon,
                  num_samples=1, cpus=1)]
Seed set to 1
Seed set to 1
nf = NeuralForecast(models=models, freq='ME')
nf.fit(df=Y_df)

使用 predict 方法生成预测结果。

Y_hat_df = nf.predict()
Y_hat_df.head()
Predicting: |                                                                                                 …
Predicting: |                                                                                                 …
Predicting: |                                                                                                 …
unique_iddsNBEATSNHITSAutoMLP
01.01961-01-31446.882172447.219238454.914154
11.01961-02-28465.145813464.558014430.188446
21.01961-03-31469.978424474.637238458.478577
31.01961-04-30493.650665502.670349477.244507
41.01961-05-31537.569275559.405212522.252991

我们绘制每个模型的预测结果。

from utilsforecast.plotting import plot_series
plot_series(Y_df, Y_hat_df)

4. 保存模型

要保存所有训练好的模型,请使用 save 方法。此方法将保存超参数和可学习权重(参数)。

save 方法有以下输入参数

  • path:模型将保存到的目录。
  • model_index:可选列表,用于指定要保存哪些模型。例如,若只保存 NHITS 模型,请使用 model_index=[2]
  • overwrite:布尔值,用于覆盖 path 中已存在的文件。当为 True 时,该方法只会覆盖名称冲突的模型。
  • save_dataset:布尔值,用于保存包含数据集的 Dataset 对象。
nf.save(path='./checkpoints/test_run/',
        model_index=None, 
        overwrite=True,
        save_dataset=True)

对于每个模型,会创建并存储两个文件

  • [model_name]_[suffix].ckpt:包含模型参数和超参数的 Pytorch Lightning 检查点文件。
  • [model_name]_[suffix].pkl:包含配置属性的字典。

其中 model_name 对应模型的名称小写(例如 nhits)。我们使用数字后缀来区分同一类的多个模型。在此示例中,名称将是 automlp_0nbeats_0nhits_0

重要提示

Auto 模型将以其基础模型存储。例如,上面训练的 AutoMLP 将存储为一个 MLP 模型,包含在调优过程中找到的最佳超参数。

5. 加载模型

使用 load 方法加载已保存的模型,指定 path,并使用新的 nf2 对象生成预测结果。

nf2 = NeuralForecast.load(path='./checkpoints/test_run/')
Y_hat_df2 = nf2.predict()
Y_hat_df2.head()
Seed set to 1
Seed set to 1
Seed set to 1
Predicting: |                                                                                                 …
Predicting: |                                                                                                 …
Predicting: |                                                                                                 …
unique_iddsNHITSNBEATSAutoMLP
01.01961-01-31447.219238446.882172454.914154
11.01961-02-28464.558014465.145813430.188446
21.01961-03-31474.637238469.978424458.478577
31.01961-04-30502.670349493.650665477.244507
41.01961-05-31559.405212537.569275522.252991

最后,绘制预测结果以确认它们与原始预测结果相同。

plot_series(Y_df, Y_hat_df2)

参考

https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html

Oreshkin, B. N., Carpov, D., Chapados, N., & Bengio, Y. (2019). N-BEATS: Neural basis expansion analysis for interpretable time series forecasting. ICLR 2020

Cristian Challu, Kin G. Olivares, Boris N. Oreshkin, Federico Garza, Max Mergenthaler-Canseco, Artur Dubrawski (2021). N-HiTS: Neural Hierarchical Interpolation for Time Series Forecasting. Accepted at AAAI 2023.