保存和加载模型
保存和加载训练好的深度学习模型有多种重要用途。训练这些模型通常成本很高;存储预训练的模型可以帮助降低成本,因为它可以被加载并重复用于多次预测。此外,它还支持迁移学习能力,即在大型数据集上预训练一个灵活的模型,然后在其他数据上使用,只需少量或无需训练。这是机器学习领域最杰出的 🚀 成就之一 🧠,并有许多实际应用。
在本 Notebook 中,我们将演示如何保存和加载 NeuralForecast
模型。
要考虑的两种方法是
1. NeuralForecast.save
:将模型保存到磁盘,允许保存数据集和配置。
2. NeuralForecast.load
:从给定路径加载模型。
重要提示
本指南假定您对 NeuralForecast 库有基本了解。有关最简单的示例,请参阅入门指南。
您可以使用 Google Colab 通过 GPU 运行这些实验。
1. 安装 NeuralForecast
2. 加载 AirPassengers 数据
在此示例中,我们将使用经典的 AirPassenger 数据集。从 utils
导入预处理过的 AirPassenger 数据。
unique_id | ds | y | |
---|---|---|---|
0 | 1.0 | 1949-01-31 | 112.0 |
1 | 1.0 | 1949-02-28 | 118.0 |
2 | 1.0 | 1949-03-31 | 132.0 |
3 | 1.0 | 1949-04-30 | 129.0 |
4 | 1.0 | 1949-05-31 | 121.0 |
3. 模型训练
接下来,我们实例化并训练三个模型:NBEATS
、NHITS
和 AutoMLP
。模型及其超参数在 models
列表中定义。
使用 predict
方法生成预测结果。
unique_id | ds | NBEATS | NHITS | AutoMLP | |
---|---|---|---|---|---|
0 | 1.0 | 1961-01-31 | 446.882172 | 447.219238 | 454.914154 |
1 | 1.0 | 1961-02-28 | 465.145813 | 464.558014 | 430.188446 |
2 | 1.0 | 1961-03-31 | 469.978424 | 474.637238 | 458.478577 |
3 | 1.0 | 1961-04-30 | 493.650665 | 502.670349 | 477.244507 |
4 | 1.0 | 1961-05-31 | 537.569275 | 559.405212 | 522.252991 |
我们绘制每个模型的预测结果。
4. 保存模型
要保存所有训练好的模型,请使用 save
方法。此方法将保存超参数和可学习权重(参数)。
save
方法有以下输入参数
path
:模型将保存到的目录。model_index
:可选列表,用于指定要保存哪些模型。例如,若只保存NHITS
模型,请使用model_index=[2]
。overwrite
:布尔值,用于覆盖path
中已存在的文件。当为 True 时,该方法只会覆盖名称冲突的模型。save_dataset
:布尔值,用于保存包含数据集的Dataset
对象。
对于每个模型,会创建并存储两个文件
[model_name]_[suffix].ckpt
:包含模型参数和超参数的 Pytorch Lightning 检查点文件。[model_name]_[suffix].pkl
:包含配置属性的字典。
其中 model_name
对应模型的名称小写(例如 nhits
)。我们使用数字后缀来区分同一类的多个模型。在此示例中,名称将是 automlp_0
、nbeats_0
和 nhits_0
。
重要提示
Auto 模型将以其基础模型存储。例如,上面训练的
AutoMLP
将存储为一个MLP
模型,包含在调优过程中找到的最佳超参数。
5. 加载模型
使用 load
方法加载已保存的模型,指定 path
,并使用新的 nf2
对象生成预测结果。
unique_id | ds | NHITS | NBEATS | AutoMLP | |
---|---|---|---|---|---|
0 | 1.0 | 1961-01-31 | 447.219238 | 446.882172 | 454.914154 |
1 | 1.0 | 1961-02-28 | 464.558014 | 465.145813 | 430.188446 |
2 | 1.0 | 1961-03-31 | 474.637238 | 469.978424 | 458.478577 |
3 | 1.0 | 1961-04-30 | 502.670349 | 493.650665 | 477.244507 |
4 | 1.0 | 1961-05-31 | 559.405212 | 537.569275 | 522.252991 |
最后,绘制预测结果以确认它们与原始预测结果相同。
参考
https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html