使用大型数据集
关于如何在无法完全载入内存的数据集上训练 neuralforecast 模型的教程
NeuralForecast 使用的标准 DataLoader 类期望数据集由一个单一的 DataFrame 表示,该 DataFrame 在拟合模型时会完全加载到内存中。然而,当数据集过大无法满足此要求时,我们可以转而使用自定义的大规模 DataLoader。这种自定义加载器假定每个时间序列都分散存储在一系列 Parquet 文件中,并确保在给定时间点只有一批数据被加载到内存中。
在本笔记本中,我们将演示这些文件的预期格式、如何训练模型以及如何使用这种大规模 DataLoader 执行推理。
加载库
数据
每个时间序列应存储在一个名为 unique_id=timeseries_id 的目录中。在此目录内,时间序列可以完全包含在一个 Parquet 文件中,或者分散在多个 Parquet 文件中。无论格式如何,时间序列都必须按时间排序。
例如,以下代码将 AirPassengers DataFrame(其中每个时间序列已按时间排序)拆分成以下格式
> data
> unique_id=Airline1
- a59945617fdb40d1bc6caa4aadad881c-0.parquet
> unique_id=Airline2
- a59945617fdb40d1bc6caa4aadad881c-0.parquet
然后,我们只需输入这些目录路径的列表。
unique_id | ds | y | 趋势 | y_[lag12] | |
---|---|---|---|---|---|
0 | Airline1 | 1949-01-31 | 112.0 | 0 | 112.0 |
1 | Airline1 | 1949-02-28 | 118.0 | 1 | 118.0 |
2 | Airline1 | 1949-03-31 | 132.0 | 2 | 132.0 |
3 | Airline1 | 1949-04-30 | 129.0 | 3 | 129.0 |
4 | Airline1 | 1949-05-31 | 121.0 | 4 | 121.0 |
… | … | … | … | … | … |
283 | Airline2 | 1960-08-31 | 906.0 | 283 | 859.0 |
284 | Airline2 | 1960-09-30 | 808.0 | 284 | 763.0 |
285 | Airline2 | 1960-10-31 | 761.0 | 285 | 707.0 |
286 | Airline2 | 1960-11-30 | 690.0 | 286 | 662.0 |
287 | Airline2 | 1960-12-31 | 732.0 | 287 | 705.0 |
您也可以使用 spark dataframe 创建此目录结构,如下所示:
DataLoader 类仍然期望静态数据作为单个 DataFrame 传入,其中每行对应一个时间序列。
id_col | airline1 | airline2 | |
---|---|---|---|
0 | Airline1 | 0 | 1 |
1 | Airline2 | 1 | 0 |
模型训练
现在我们在上述数据集上训练一个 NHITS 模型。值得注意的是,NeuralForecast 目前在使用此 DataLoader 时不支持缩放。如果您想缩放时间序列,应在将其传递给 fit
方法之前完成。
预测
处理大型数据集时,我们需要提供一个单一的 DataFrame,其中包含所有我们希望为其生成预测的时间序列的输入时间步。如果存在未来的外生特征,我们也应将这些特征的未来值包含在单独的 futr_df
DataFrame 中。
对于下面的预测,我们假设只想预测 Airline2 的接下来的 12 个时间步。
id_col | ds | NHITS | |
---|---|---|---|
0 | Airline2 | 1960-01-31 | 713.441406 |
1 | Airline2 | 1960-02-29 | 688.176880 |
2 | Airline2 | 1960-03-31 | 763.382935 |
3 | Airline2 | 1960-04-30 | 745.478027 |
4 | Airline2 | 1960-05-31 | 758.036438 |
5 | Airline2 | 1960-06-30 | 806.288574 |
6 | Airline2 | 1960-07-31 | 869.563782 |
7 | Airline2 | 1960-08-31 | 858.105896 |
8 | Airline2 | 1960-09-30 | 803.531555 |
9 | Airline2 | 1960-10-31 | 751.093079 |
10 | Airline2 | 1960-11-30 | 700.435852 |
11 | Airline2 | 1960-12-31 | 746.640259 |
评估
指标 | NHITS | |
---|---|---|
0 | mae | 20.728617 |
1 | rmse | 26.980698 |
2 | smape | 0.012879 |