StatsForecast 遵循 sklearn 模型 API。对于这个最小示例,您将创建 StatsForecast 类的一个实例,然后调用它的 fitpredict 方法。如果您对速度要求不是特别高,并且想探索拟合值和参数,我们推荐此选项。

提示

如果您想预测多个序列,我们建议使用 forecast 方法。请查看这篇多时间序列入门指南。

StatsForecast 的输入始终是采用长格式的数据帧,包含三列:unique_iddsy

  • unique_id(字符串、整数或类别)表示序列的标识符。

  • ds(日期戳)列应采用 Pandas 期望的格式,日期格式最好为 YYYY-MM-DD,时间戳格式最好为 YYYY-MM-DD HH:MM:SS。

  • y(数值)表示我们希望预测的度量值。

作为一个例子,我们来看看 US Air Passengers 数据集。该时间序列包含 1949 年至 1960 年美国航空公司乘客的月总数。CSV 文件可从此处获取。

我们假设您已经安装了 StatsForecast。请查看这篇指南,了解如何安装 StatsForecast的说明。

首先,我们将导入数据

# uncomment the following line to install the library
# %pip install statsforecast
import pandas as pd
df = pd.read_csv('https://datasets-nixtla.s3.amazonaws.com/air-passengers.csv', parse_dates=['ds'])
df.head()
唯一ID日期戳值 y
0AirPassengers1949-01-01112
1AirPassengers1949-02-01118
2AirPassengers1949-03-01132
3AirPassengers1949-04-01129
4AirPassengers1949-05-01121

我们通过实例化一个新的StatsForecast 对象来拟合模型,该对象需要两个参数:https://nixtla.github.io/statsforecast/src/core/models.html * models:模型列表。从模型中选择您想要使用的模型并导入它们。本例中,我们将使用AutoARIMA 模型。我们将 season_length 设置为 12,因为我们预期每 12 个月出现一次季节性效应。(参见:季节周期

任何设置都会传递给构造函数。然后您调用其 fit 方法并传入历史数据帧。

注意

StatsForecast 通过 Numba 使用 JIT 编译实现了惊人的速度。首次调用 statsforecast 类时,fit 方法可能需要大约 5 秒。第二次调用(Numba 编译设置后)应该不到 0.2 秒。

from statsforecast import StatsForecast
from statsforecast.models import AutoARIMA
sf = StatsForecast(
    models=[AutoARIMA(season_length = 12)],
    freq='MS',
)
sf.fit(df)
StatsForecast(models=[AutoARIMA])

predict 方法接受两个参数:预测未来 h 步(horizon)和 level

  • h (int):表示预测未来 h 步。在本例中,即提前 12 个月。

  • level(浮点数列表):此可选参数用于概率预测。设置预测区间的 level(或置信百分位)。例如,level=[90] 意味着模型预期真实值有 90% 的时间落在此区间内。

这里的 forecast 对象是一个新的数据帧,包含一列模型名称和 y hat 值,以及不确定性区间的列。

forecast_df = sf.predict(h=12, level=[90])
forecast_df.tail()
唯一ID日期戳AutoARIMAAutoARIMA-lo-90AutoARIMA-hi-90
7AirPassengers1961-08-01633.236389590.009033676.463745
8AirPassengers1961-09-01535.236389489.558899580.913940
9AirPassengers1961-10-01488.236389440.233795536.239014
10AirPassengers1961-11-01417.236389367.016205467.456604
11AirPassengers1961-12-01459.236389406.892456511.580322

您可以通过调用 StatsForecast.plot 方法并传入您的 forecast 数据帧来绘制预测结果。

sf.plot(df, forecast_df, level=[90])

后续步骤