保存和加载模型
保存和加载训练好的深度学习模型有多种重要用途。训练这些模型通常成本很高;存储预训练的模型可以帮助降低成本,因为它可以被加载并重复用于多次预测。此外,它还支持迁移学习能力,即在大型数据集上预训练一个灵活的模型,然后在其他数据上使用,只需少量或无需训练。这是机器学习领域最杰出的 🚀 成就之一 🧠,并有许多实际应用。
在本 Notebook 中,我们将演示如何保存和加载 NeuralForecast 模型。
要考虑的两种方法是
1. NeuralForecast.save:将模型保存到磁盘,允许保存数据集和配置。
2. NeuralForecast.load:从给定路径加载模型。
重要提示
本指南假定您对 NeuralForecast 库有基本了解。有关最简单的示例,请参阅入门指南。
您可以使用 Google Colab 通过 GPU 运行这些实验。
1. 安装 NeuralForecast
2. 加载 AirPassengers 数据
在此示例中,我们将使用经典的 AirPassenger 数据集。从 utils 导入预处理过的 AirPassenger 数据。
| unique_id | ds | y | |
|---|---|---|---|
| 0 | 1.0 | 1949-01-31 | 112.0 | 
| 1 | 1.0 | 1949-02-28 | 118.0 | 
| 2 | 1.0 | 1949-03-31 | 132.0 | 
| 3 | 1.0 | 1949-04-30 | 129.0 | 
| 4 | 1.0 | 1949-05-31 | 121.0 | 
3. 模型训练
接下来,我们实例化并训练三个模型:NBEATS、NHITS 和 AutoMLP。模型及其超参数在 models 列表中定义。
使用 predict 方法生成预测结果。
| unique_id | ds | NBEATS | NHITS | AutoMLP | |
|---|---|---|---|---|---|
| 0 | 1.0 | 1961-01-31 | 446.882172 | 447.219238 | 454.914154 | 
| 1 | 1.0 | 1961-02-28 | 465.145813 | 464.558014 | 430.188446 | 
| 2 | 1.0 | 1961-03-31 | 469.978424 | 474.637238 | 458.478577 | 
| 3 | 1.0 | 1961-04-30 | 493.650665 | 502.670349 | 477.244507 | 
| 4 | 1.0 | 1961-05-31 | 537.569275 | 559.405212 | 522.252991 | 
我们绘制每个模型的预测结果。
4. 保存模型
要保存所有训练好的模型,请使用 save 方法。此方法将保存超参数和可学习权重(参数)。
save 方法有以下输入参数
- path:模型将保存到的目录。
- model_index:可选列表,用于指定要保存哪些模型。例如,若只保存- NHITS模型,请使用- model_index=[2]。
- overwrite:布尔值,用于覆盖- path中已存在的文件。当为 True 时,该方法只会覆盖名称冲突的模型。
- save_dataset:布尔值,用于保存包含数据集的- Dataset对象。
对于每个模型,会创建并存储两个文件
- [model_name]_[suffix].ckpt:包含模型参数和超参数的 Pytorch Lightning 检查点文件。
- [model_name]_[suffix].pkl:包含配置属性的字典。
其中 model_name 对应模型的名称小写(例如 nhits)。我们使用数字后缀来区分同一类的多个模型。在此示例中,名称将是 automlp_0、nbeats_0 和 nhits_0。
重要提示
Auto 模型将以其基础模型存储。例如,上面训练的
AutoMLP将存储为一个MLP模型,包含在调优过程中找到的最佳超参数。
5. 加载模型
使用 load 方法加载已保存的模型,指定 path,并使用新的 nf2 对象生成预测结果。
| unique_id | ds | NHITS | NBEATS | AutoMLP | |
|---|---|---|---|---|---|
| 0 | 1.0 | 1961-01-31 | 447.219238 | 446.882172 | 454.914154 | 
| 1 | 1.0 | 1961-02-28 | 464.558014 | 465.145813 | 430.188446 | 
| 2 | 1.0 | 1961-03-31 | 474.637238 | 469.978424 | 458.478577 | 
| 3 | 1.0 | 1961-04-30 | 502.670349 | 493.650665 | 477.244507 | 
| 4 | 1.0 | 1961-05-31 | 559.405212 | 537.569275 | 522.252991 | 
最后,绘制预测结果以确认它们与原始预测结果相同。
参考
https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html

