使用 Polars 的端到端演练
多时间序列的模型训练、评估和选择
Polars 简介:一个高性能 DataFrame 库
本文档旨在重点介绍 Polars 的最新集成,Polars 是一个使用 Rust 开发的强大且高速的 DataFrame 库,现已集成到 StatsForecast 的功能中。Polars 凭借其灵活强大的功能,已在数据科学社区内迅速建立起良好声誉,进一步巩固了其作为管理和操作大型数据集的可靠工具的地位。
Polars 支持 Rust、Python、Node.js 和 R 等语言,它在处理大型数据集时表现出卓越的效率和速度,超越了许多其他 DataFrame 库(例如 Pandas)。Polars 的开源性质吸引了持续的改进和贡献,增强了其在数据科学领域的吸引力。
Polars 促使其快速普及的最重要特性是
-
性能效率:Polars 使用 Rust 构建,在管理大型数据集时展现出卓越的能力,速度惊人且内存使用量极低。
-
惰性计算:Polars 基于“惰性计算”原则运行,为高效执行创建优化的逻辑操作计划,这一特性类似于 Apache Spark 的功能。
-
并行执行:Polars 能够利用多核 CPU,促进操作的并行执行,显著加速数据处理任务。
前提条件
本指南假定您对 StatsForecast 有基本了解。有关最小示例,请访问快速入门
请按照本文的逐步指南构建适用于多个时间序列的生产级预测管道。
在本指南中,您将熟悉核心的 StatsForecast
类以及一些相关方法,例如 StatsForecast.plot
、StatsForecast.forecast
和 StatsForecast.cross_validation.
我们将使用来自 M4 竞赛的经典基准数据集。该数据集包含来自不同领域的时间序列,如金融、经济和销售。在本例中,我们将使用小时数据集的一个子集。
我们将对每个时间序列进行单独建模。这种级别的预测也称为局部预测。因此,您将为每个独特序列训练一系列模型,然后选择最佳模型。StatsForecast 专注于速度、简洁性和可扩展性,这使得它非常适合这项任务。
大纲
- 安装软件包。
- 读取数据。
- 探索数据。
- 为每个独特的时间序列组合训练多个模型。
- 使用交叉验证评估模型的性能。
- 为每个独特的时间序列选择最佳模型。
本指南未涵盖的内容
- 使用云上的集群进行大规模预测。
- 使用 Ray 集群在 5 分钟内预测 M5 数据集。
- 使用 Spark 集群在 5 分钟内预测 M5 数据集。
- 了解如何在不到 30 分钟内预测 100 万个序列。
- 训练多重季节性模型。
- 在此电力负荷预测教程中学习如何使用多重季节性。
- 使用外部回归量或外生变量
- 按照此教程包含外生变量,例如天气或节假日,或类别或族群等静态变量。
- 将 StatsForecast 与其他常用库进行比较。
- 您可以在此处重现我们的基准测试。
安装库
我们假定您已安装 StatsForecast。有关如何安装 StatsForecast 的说明,请查阅本指南。
读取数据
我们将使用 Polars 读取存储在 Parquet 文件中的 M4 小时数据集以提高效率。您可以使用常规的 Polars 操作读取 .csv
等其他格式的数据。
StatsForecast 的输入始终是一个长格式的 DataFrame,包含三列:unique_id
, ds
和 y
-
unique_id
(字符串、整数或类别)表示序列的标识符。 -
ds
(日期戳或整数)列应为表示时间的整数索引,或理想情况下为日期戳,如日期的 YYYY-MM-DD 或时间戳的 YYYY-MM-DD HH:MM:SS。 -
y
(数值)表示我们希望预测的测量值。
该数据集已满足要求。
根据您的互联网连接,此步骤大约需要 10 秒。
unique_id | ds | y |
---|---|---|
str | i64 | f64 |
“H1” | 1 | 605.0 |
“H1” | 2 | 586.0 |
“H1” | 3 | 586.0 |
“H1” | 4 | 559.0 |
“H1” | 5 | 511.0 |
该数据集包含 414 个独特序列,平均有 900 个观测值。出于此示例和可重现性的目的,我们将仅选择 10 个独特 ID 并仅保留最后一周的数据。根据您的处理基础设施,您可以随意选择更多或更少的序列。
注意
处理时间取决于可用的计算资源。在 AWS 的 c5d.24xlarge(96 核)实例上运行此示例并使用完整数据集大约需要 10 分钟。
使用 plot 方法探索数据
使用 StatsForecast
类中的 plot
方法绘制一些序列。此方法打印数据集中的 8 个随机序列,对于基本 EDA 非常有用。
注意
StatsForecast.plot
方法默认使用 matplotlib 作为引擎。您可以通过设置engine="plotly"
更改为 plotly。
为多个序列训练多个模型
StatsForecast 可以高效地在多个时间序列上训练多个模型。
首先导入并实例化所需的模型。StatsForecast 提供了多种模型,按以下类别分组:
-
自动预测:自动预测工具搜索最佳参数并为一系列时间序列选择最佳模型。这些工具对于大量的单变量时间序列非常有用。包括以下模型的自动版本:Arima、ETS、Theta、CES。
-
指数平滑:使用所有过去观测值的加权平均,权重随着时间向过去呈指数衰减。适用于没有明显趋势或季节性的数据。示例:SES、Holt’s Winters、SSO。
-
基准模型:用于建立基线经典模型。示例:Mean、Naive、Random Walk
-
间歇性或稀疏模型:适用于非零观测值非常少的序列。示例:CROSTON、ADIDA、IMAPA
-
多重季节性:适用于具有一个以上明显季节性的信号。对于电力和日志等低频数据非常有用。示例:MSTL。
-
Theta 模型:将两个 theta 线拟合到去季节化的时间序列,使用不同的技术获取并组合这两条 theta 线以生成最终预测。示例:Theta、DynamicTheta
您可以在此处查看完整的模型列表。
在此示例中,我们将使用
-
HoltWinters
:三重指数平滑,Holt-Winters 方法是指数平滑的扩展,适用于同时包含趋势和季节性的序列。参考:HoltWinters
-
SeasonalNaive
:内存高效的季节性朴素预测。参考:SeasonalNaive
-
HistoricAverage
:算术平均值。参考:HistoricAverage
。 -
DynamicOptimizedTheta
:Theta 模型家族在各种数据集(如 M3)中表现良好。对去季节化的时间序列进行建模。参考:DynamicOptimizedTheta
。
导入并实例化模型。设置 season_length
参数有时很棘手。大师 Rob Hyndmann 关于季节周期的文章可能会有所帮助。
我们通过实例化一个新的 StatsForecast
对象来拟合模型,该对象具有以下参数
-
models
:模型列表。从模型中选择您想要的模型并导入它们。 -
freq
:表示数据频率的字符串。(参见 panda 的可用频率。)Polars 也支持此功能。 -
n_jobs
:n_jobs:整数,并行处理中使用的作业数,使用 -1 表示所有核心。 -
fallback_model
:如果模型失败时使用的模型。
任何设置都会传递给构造函数。然后调用其 fit 方法并传入历史数据帧。
注意
StatsForecast 通过 Numba 进行 JIT 编译,从而实现惊人的速度。首次调用 statsforecast 类时,fit 方法大约需要 5 秒。第二次(Numba 编译完您的设置后)应该不到 0.2 秒。
forecast
方法接受两个参数:预测未来 h
(视野)步和 level
。
-
h
(整数):表示预测未来 h 步。在本例中,是未来 12 个月。 -
level
(浮点数列表):此可选参数用于概率预测。设置预测区间(或置信百分比)的level
。例如,level=[90]
表示模型期望真实值有 90% 的时间位于该区间内。
此处的 forecast 对象是一个新的数据帧,包含模型名称和 y hat 值的列,以及不确定性区间的列。根据您的计算机,此步骤大约需要 1 分钟。(如果您想将速度提高到几秒,请移除 AutoModels,例如 ARIMA 和 Theta)
注意
forecast
方法与分布式集群兼容,因此不存储任何模型参数。如果您想为每个模型存储参数,可以使用fit
和predict
方法。但是,这些方法未针对 Spark、Ray 或 Dask 等分布式引擎定义。
unique_id | ds | HoltWinters | HoltWinters-lo-90 | HoltWinters-hi-90 | CrostonClassic | CrostonClassic-lo-90 | CrostonClassic-hi-90 | SeasonalNaive | SeasonalNaive-lo-90 | SeasonalNaive-hi-90 | HistoricAverage | HistoricAverage-lo-90 | HistoricAverage-hi-90 | DynamicOptimizedTheta | DynamicOptimizedTheta-lo-90 | DynamicOptimizedTheta-hi-90 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
str | i64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 |
“H1” | 749 | 829.0 | 422.549268 | 1235.450732 | 829.0 | 422.549268 | 1235.450732 | 635.0 | 566.036734 | 703.963266 | 660.982143 | 398.037761 | 923.926524 | 592.701851 | 577.67728 | 611.652639 |
“H1” | 750 | 807.0 | 400.549268 | 1213.450732 | 807.0 | 400.549268 | 1213.450732 | 572.0 | 503.036734 | 640.963266 | 660.982143 | 398.037761 | 923.926524 | 525.589116 | 505.449755 | 546.621805 |
“H1” | 751 | 785.0 | 378.549268 | 1191.450732 | 785.0 | 378.549268 | 1191.450732 | 532.0 | 463.036734 | 600.963266 | 660.982143 | 398.037761 | 923.926524 | 489.251814 | 462.072871 | 512.424116 |
“H1” | 752 | 756.0 | 349.549268 | 1162.450732 | 756.0 | 349.549268 | 1162.450732 | 493.0 | 424.036734 | 561.963266 | 660.982143 | 398.037761 | 923.926524 | 456.195032 | 430.554302 | 478.260963 |
“H1” | 753 | 719.0 | 312.549268 | 1125.450732 | 719.0 | 312.549268 | 1125.450732 | 477.0 | 408.036734 | 545.963266 | 660.982143 | 398.037761 | 923.926524 | 436.290514 | 411.051232 | 461.815932 |
使用 StatsForecast.plot
方法绘制 8 个随机序列的结果。
StatsForecast.plot
允许进一步定制。例如,绘制不同模型和独特 ID 的结果。
评估模型的性能
在前面的步骤中,我们使用历史数据来预测未来。然而,为了评估其准确性,我们还想知道模型在过去会如何表现。为了评估模型在您的数据上的准确性和鲁棒性,请执行交叉验证。
对于时间序列数据,交叉验证是通过在历史数据上定义一个滑动窗口并预测其后续周期来完成的。这种形式的交叉验证使我们能够在更广泛的时间实例中更好地估计模型的预测能力,同时保持训练集中的数据连续,这是模型所要求的。
下图描绘了这样的交叉验证策略
时间序列模型的交叉验证被认为是最佳实践,但大多数实现速度非常慢。statsforecast 库将交叉验证实现为分布式操作,从而减少了执行时间。如果您的数据集很大,您也可以使用 Ray、Dask 或 Spark 在分布式集群中执行交叉验证。
在本例中,我们希望评估每个模型在最后 2 天(n_windows=2)的性能,每隔一天(step_size=48)进行预测。根据您的计算机,此步骤大约需要 1 分钟。
提示
设置
n_windows=1
模拟了传统的训练-测试集划分,其中历史数据作为训练集,最后 48 小时作为测试集。
StatsForecast
类中的 cross_validation
方法接受以下参数。
-
df
:训练数据帧 -
h
(整数):表示正在预测的未来 h 步。在本例中,是未来 24 小时。 -
step_size
(整数):每个窗口之间的步长。换句话说:您希望多久运行一次预测过程。 -
n_windows
(整数):用于交叉验证的窗口数。换句话说:您希望评估过去多少个预测过程。
cv_df
对象是一个新的数据帧,包含以下列
-
unique_id
:序列标识符 -
ds
:日期戳或时间索引 -
cutoff
:n_windows
的最后一个日期戳或时间索引。如果n_windows=1
,则只有一个独特的截止值;如果n_windows=2
,则有两个独特的截止值。 -
y
:真实值 -
"model"
:包含模型名称和拟合值的列。
unique_id | ds | cutoff | y | HoltWinters | CrostonClassic | SeasonalNaive | HistoricAverage | DynamicOptimizedTheta |
---|---|---|---|---|---|---|---|---|
str | i64 | i64 | f64 | f64 | f64 | f64 | f64 | f64 |
“H1” | 701 | 700 | 619.0 | 847.0 | 742.668748 | 691.0 | 661.675 | 612.767504 |
“H1” | 702 | 700 | 565.0 | 820.0 | 742.668748 | 618.0 | 661.675 | 536.846278 |
“H1” | 703 | 700 | 532.0 | 790.0 | 742.668748 | 563.0 | 661.675 | 497.824286 |
“H1” | 704 | 700 | 495.0 | 784.0 | 742.668748 | 529.0 | 661.675 | 464.723219 |
“H1” | 705 | 700 | 481.0 | 752.0 | 742.668748 | 504.0 | 661.675 | 440.972336 |
接下来,我们将使用常用的误差度量(如平均绝对误差 (MAE) 或均方误差 (MSE))评估每个模型在每个序列上的性能。定义一个实用函数来评估交叉验证数据帧的不同误差度量。
首先从 utilsforecast.losses
导入所需的误差度量。然后定义一个实用函数,该函数将交叉验证数据帧作为度量,并返回一个评估数据帧,其中包含每个独特 ID、拟合模型和所有截止点的误差度量的平均值。
警告
您也可以使用平均绝对百分比误差 (MAPE),但对于精细预测,MAPE 值极难判断,对评估预测质量没有用。
使用均方误差度量创建包含交叉验证数据帧评估结果的数据帧。
unique_id | HoltWinters | CrostonClassic | SeasonalNaive | HistoricAverage | DynamicOptimizedTheta | best_model |
---|---|---|---|---|---|---|
str | f64 | f64 | f64 | f64 | f64 | str |
“H1” | 44888.020833 | 28038.733985 | 1422.666667 | 20927.664488 | 1296.333977 | “DynamicOptimizedTheta” |
“H10” | 2812.916667 | 1483.483839 | 96.895833 | 1980.367543 | 379.621134 | “SeasonalNaive” |
“H100” | 121625.375 | 91945.139237 | 12019.0 | 78491.191439 | 21699.649325 | “SeasonalNaive” |
“H101” | 28453.395833 | 16183.63434 | 10944.458333 | 18208.4098 | 63698.077266 | “SeasonalNaive” |
“H102” | 232924.854167 | 132655.309136 | 12699.895833 | 309110.475212 | 31393.535274 | “SeasonalNaive” |
创建一个摘要表,其中包含模型列以及该模型表现最佳的序列数量。在这种情况下,Arima 和 Seasonal Naive 是 10 个序列的最佳模型,Theta 模型应用于两个序列。
best_model | count |
---|---|
str | u32 |
“DynamicOptimizedTheta” | 4 |
“SeasonalNaive” | 6 |
您可以通过绘制特定模型获胜的 unique_ids 来进一步探索您的结果。
为每个独特序列选择最佳模型
定义一个实用函数,该函数接受包含预测结果的 forecast 数据帧和 evaluation 数据帧,并返回一个包含每个 unique_id 的最佳预测结果的数据帧。
创建您的生产级数据帧,其中包含每个 unique_id 的最佳预测结果。
unique_id | ds | best_model | best_model-lo-90 | best_model-hi-90 |
---|---|---|---|---|
str | i64 | f64 | f64 | f64 |
“H1” | 749 | 592.701851 | 577.67728 | 611.652639 |
“H1” | 750 | 525.589116 | 505.449755 | 546.621805 |
“H1” | 751 | 489.251814 | 462.072871 | 512.424116 |
“H1” | 752 | 456.195032 | 430.554302 | 478.260963 |
“H1” | 753 | 436.290514 | 411.051232 | 461.815932 |
绘制结果。