Introduction

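The AutoTheta model in StatsForecast automatically selects the best Theta model based on the mean squared error (MSE). In this section, we discuss each of the models that AutoTheta considers and then explain how it selects the best one.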

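1. Standard Theta Model (STM)

The Standard Theta Model is the original version of the Theta model introduced by Assimakopoulos and Nikolopoulos (2000). It decomposes the time series into two modified versions of the original series, called theta lines. These lines are created by applying a linear transformation to the second differences of the original series, controlled by a parameter called theta (θ). One theta line captures the long-term trend, while the other captures short-term fluctuations. The two theta lines are then combined to produce the final forecast. STM assumes that the model parameters remain constant over time.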
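To make the idea of theta lines concrete, here is a minimal pure-Python sketch. It assumes the commonly cited closed form of a theta line, θ·y_t + (1 − θ)·(a + b·t), where a + b·t is the least-squares line through the series, and it combines the θ = 0 and θ = 2 lines with equal weights. The function names (linear_trend, theta_line, ses_level, stm_forecast) and the fixed alpha are illustrative assumptions, not StatsForecast's implementation, and seasonality is ignored.

```python
def linear_trend(y):
    """Least-squares line a + b*t fitted to y at t = 0..n-1; returns (a, b)."""
    n = len(y)
    t_mean = (n - 1) / 2
    y_mean = sum(y) / n
    sxx = sum((t - t_mean) ** 2 for t in range(n))
    sxy = sum((t - t_mean) * (v - y_mean) for t, v in enumerate(y))
    b = sxy / sxx
    a = y_mean - b * t_mean
    return a, b


def theta_line(y, theta):
    """Theta line: theta = 0 gives the linear trend, theta = 1 the original series."""
    a, b = linear_trend(y)
    return [theta * v + (1 - theta) * (a + b * t) for t, v in enumerate(y)]


def ses_level(y, alpha):
    """Final level of simple exponential smoothing (its flat forecast)."""
    level = y[0]
    for v in y[1:]:
        level = alpha * v + (1 - alpha) * level
    return level


def stm_forecast(y, h, alpha=0.5):
    """Equal-weight combination of the extrapolated theta=0 line and SES on the theta=2 line."""
    a, b = linear_trend(y)
    n = len(y)
    level = ses_level(theta_line(y, 2.0), alpha)
    return [0.5 * (a + b * (n + k)) + 0.5 * level for k in range(h)]


series = [3.0, 5.0, 4.0, 6.0, 8.0]
line0 = theta_line(series, 0.0)   # long-term trend
line2 = theta_line(series, 2.0)   # exaggerated short-term movements
print(stm_forecast(series, h=3))
```

Averaging the θ = 0 and θ = 2 lines recovers the original series, which is why these two lines are the usual pair in the standard model.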

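2. Optimized Theta Model (OTM)

The Optimized Theta Model extends STM by searching for the best theta parameter instead of using a fixed value. This optimization step allows the model to better fit series with higher variability.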

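3. Dynamic Standard Theta Model (DSTM)

The Dynamic Standard Theta Model allows STM to adapt over time. Instead of keeping the parameters static, it updates them dynamically as new data becomes available. This dynamic behavior can be useful when forecasting series with evolving trends or seasonality.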

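4. Dynamic Optimized Theta Model (DOTM)

The Dynamic Optimized Theta Model combines the features of OTM and DSTM. Like OTM, it optimizes the theta parameter. Like DSTM, it updates the model dynamically with new data.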

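How AutoTheta selects the best model

  1. AutoTheta fits all four variants of the Theta model (STM, OTM, DSTM, and DOTM) to your data.
  2. Each fitted model is scored by its mean squared error (MSE), the value reported in the mse field of the fitted model shown later in this tutorial.
  3. The model with the lowest MSE is selected.
  4. Forecasts are then generated with the selected model.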
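The selection steps above amount to "fit every candidate, score it, keep the lowest score". The sketch below shows that pattern in plain Python. The two candidates (naive and mean) are hypothetical stand-ins chosen so the example runs without any library; they are not the Theta variants, and select_best is not a StatsForecast function.

```python
def mse(actual, predicted):
    """Mean squared error between two equally long sequences."""
    return sum((a - p) ** 2 for a, p in zip(actual, predicted)) / len(actual)


def naive_fit(y):
    """Stand-in model: each fitted value is the previous observation."""
    return [y[0]] + list(y[:-1])


def mean_fit(y):
    """Stand-in model: every fitted value is the overall mean."""
    m = sum(y) / len(y)
    return [m] * len(y)


def select_best(y, candidates):
    """Score each candidate by MSE and return (best_name, all_scores)."""
    scores = {name: mse(y, fit(y)) for name, fit in candidates.items()}
    best = min(scores, key=scores.get)
    return best, scores


y = [10.0, 12.0, 11.0, 13.0, 12.0]
best, scores = select_best(y, {"naive": naive_fit, "mean": mean_fit})
print(best, scores)
```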

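Loading libraries and data

Tip

Statsforecast will be needed. To install, see the instructions.

Next, we import plotting libraries and configure the plotting style.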

import pandas as pd

import scipy.stats as stats
import matplotlib.pyplot as plt
import seaborn as sns
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf

plt.style.use('fivethirtyeight')
plt.rcParams['lines.linewidth'] = 1.5
dark_style = {
    'figure.facecolor': '#212946',
    'axes.facecolor': '#212946',
    'savefig.facecolor':'#212946',
    'axes.grid': True,
    'axes.grid.which': 'both',
    'axes.spines.left': False,
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.spines.bottom': False,
    'grid.color': '#2A3459',
    'grid.linewidth': '1',
    'text.color': '0.9',
    'axes.labelcolor': '0.9',
    'xtick.color': '0.9',
    'ytick.color': '0.9',
    'font.size': 12 }
plt.rcParams.update(dark_style)

from pylab import rcParams
rcParams['figure.figsize'] = (18,7)

Read Data

df = pd.read_csv("https://raw.githubusercontent.com/Naren8520/Serie-de-tiempo-con-Machine-Learning/main/Data/candy_production.csv")
df.head()
  observation_date  IPG3113N
0       1972-01-01   85.6945
1       1972-02-01   71.8200
2       1972-03-01   66.0229
3       1972-04-01   64.5645
4       1972-05-01   65.0100

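The input to StatsForecast is always a data frame in long format with three columns: unique_id, ds and y:

  • The unique_id (string, int or category) represents an identifier for the series.

  • The ds (datestamp) column should be in a format expected by pandas, ideally YYYY-MM-DD for a date or YYYY-MM-DD HH:MM:SS for a timestamp.

  • The y (numeric) represents the measurement we wish to forecast.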
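If your data starts out in wide format (one column per series), it can be reshaped into this long layout with pandas. The following is a small sketch on made-up data. The column names store_a and store_b and their values are purely illustrative.

```python
import pandas as pd

# Hypothetical wide table: one row per month, one column per series.
wide = pd.DataFrame({
    "ds": pd.date_range("2020-01-01", periods=3, freq="MS"),
    "store_a": [10.0, 12.0, 11.0],
    "store_b": [20.0, 19.0, 23.0],
})

# melt stacks the series columns into (unique_id, y) pairs.
long = wide.melt(id_vars="ds", var_name="unique_id", value_name="y")
long = (
    long[["unique_id", "ds", "y"]]
    .sort_values(["unique_id", "ds"])
    .reset_index(drop=True)
)
print(long)
```

Each series now occupies its own block of rows, identified by unique_id, which is the shape described in the list above.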

df["unique_id"]="1"
df.columns=["ds", "y", "unique_id"]
df.head()
           ds        y unique_id
0  1972-01-01  85.6945         1
1  1972-02-01  71.8200         1
2  1972-03-01  66.0229         1
3  1972-04-01  64.5645         1
4  1972-05-01  65.0100         1
print(df.dtypes)
ds            object
y            float64
unique_id     object
dtype: object

We can see that our time variable (ds) is stored with the object dtype, so we need to convert it to a datetime format.

df["ds"] = pd.to_datetime(df["ds"])

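Explore Data with the plot method

Plot some series using the plot method from the StatsForecast class. This method plots random series from the dataset and is useful for basic exploratory data analysis (EDA).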

from statsforecast import StatsForecast

StatsForecast.plot(df)

Autocorrelation plots

fig, axs = plt.subplots(nrows=1, ncols=2)

plot_acf(df["y"],  lags=60, ax=axs[0],color="fuchsia")
axs[0].set_title("Autocorrelation");

plot_pacf(df["y"],  lags=60, ax=axs[1],color="lime")
axs[1].set_title('Partial Autocorrelation')

plt.show();

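Split the data into training and testing

Let's divide our data into two sets: 1. Data to train our AutoTheta model. 2. Data to test our model.

For the test data, we use the last 12 months to test and evaluate the performance of our model.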

train = df[df.ds<='2016-08-01'] 
test = df[df.ds>'2016-08-01']
train.shape, test.shape
((536, 3), (12, 3))
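An equivalent way to hold out the last h observations, without hard-coding a cutoff date, is to slice by position. The sketch below uses a synthetic monthly frame with the same date range as the candy production data (548 months starting in 1972-01). The y values are placeholders and split_last_h is a hypothetical helper that assumes a single series.

```python
import pandas as pd


def split_last_h(df, h):
    """Return (train, test) where test is the last h rows by date (single series only)."""
    df = df.sort_values("ds")
    return df.iloc[:-h], df.iloc[-h:]


# Synthetic stand-in for the tutorial's data: same dates, placeholder values.
dates = pd.date_range("1972-01-01", periods=548, freq="MS")
demo = pd.DataFrame({"unique_id": "1", "ds": dates, "y": range(548)})

train_demo, test_demo = split_last_h(demo, h=12)
print(train_demo.shape, test_demo.shape)
```

With several series in one frame, the same slicing would have to be applied per unique_id, for example through a groupby.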

Now let's plot the training data and the test data.

sns.lineplot(train,x="ds", y="y", label="Train", linewidth=3, linestyle=":")
sns.lineplot(test, x="ds", y="y", label="Test")
plt.ylabel("Candy Production")
plt.xlabel("Month")
plt.show()

Implementation of AutoTheta with StatsForecast

Load libraries

from statsforecast import StatsForecast
from statsforecast.models import AutoTheta

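Instantiating the model

Import and instantiate the model. Setting the parameters is sometimes tricky. The article on seasonal periods by Rob Hyndman can be useful when choosing season_length.

AutoTheta automatically selects the best Theta model (Standard Theta Model ('STM'), Optimized Theta Model ('OTM'), Dynamic Standard Theta Model ('DSTM'), Dynamic Optimized Theta Model ('DOTM')) using the MSE.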

season_length = 12 # Monthly data 
horizon = len(test) # number of predictions

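# We call the model that we are going to use.
# Note: model="STM" restricts AutoTheta to the Standard Theta Model;
# leave the model argument out to let it search over all four variants.
models = [AutoTheta(season_length=season_length,
                     decomposition_type="additive",
                     model="STM")]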

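We fit the models by instantiating a new StatsForecast object with the following parameters:

  • models: a list of models. Select the models you want from models and import them.

  • freq: a string indicating the frequency of the data. (See pandas' available frequencies.)

  • n_jobs: int, the number of jobs used in parallel processing; use -1 for all cores.

  • fallback_model: a model to be used if a model fails.

Any settings are passed into the constructor. Then you call its fit method and pass in the historical data frame.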

sf = StatsForecast(models=models, freq='MS')

Fit the model

sf.fit(df=train)
StatsForecast(models=[AutoTheta])

Let's look at the results of our Theta model. We can inspect them with the following instruction:

result=sf.fitted_[0,0].model_
result
{'mse': 100.57831864069415,
 'amse': array([26.13585578, 38.60211513, 44.70605915]),
 'fit': results(x=array([258.45064973,   0.7664297 ]), fn=100.57831864069415, nit=32, simplex=array([[250.37338496,   0.76970741],
        [232.03915522,   0.76429422],
        [258.45064973,   0.7664297 ]])),
 'residuals': array([-2.14815337e+02, -6.20562800e+01, -2.13256707e+01, -1.25845480e+01,
        -1.19719350e+01, -9.40876632e+00, -8.60141525e+00, -9.00054652e+00,
        -1.98778836e+00,  3.14564857e+01,  1.98519673e+01,  2.04962370e+01,
         4.98120196e+00, -1.08735375e+01, -1.12328024e+01, -8.08115377e+00,
        -9.98197589e+00, -8.39937098e+00, -1.25789505e+01, -1.05952806e+01,
         8.47229127e-01,  2.25644616e+01,  2.54401546e+01,  1.73989716e+01,
         2.40287275e+00, -2.53475866e+00, -8.00591135e+00, -1.79241479e+01,
        -6.36590693e+00, -5.76986468e+00, -2.26766759e+01, -8.95260931e+00,
        -7.19719166e+00,  2.74032238e+01,  2.21368457e+01,  6.43171676e+00,
        -3.51755220e+00, -1.31441941e+01, -6.13166031e+00,  1.51512150e+00,
        -8.05777104e+00, -8.59603388e+00, -1.08617851e+01, -6.72940177e+00,
        -6.24861641e+00,  2.85996828e+01,  2.98048030e+01,  1.90032238e+01,
         5.07597842e+00, -9.59170058e+00, -1.64521034e+01, -7.52212744e+00,
        -5.16540394e+00, -1.27924628e+01, -9.68434625e+00, -8.76758703e+00,
        -8.27475947e-01,  3.08424002e+01,  2.47947352e+01,  2.35867208e+01,
         3.75664716e+00, -4.47305717e+00, -1.48000403e+01, -1.08431546e+01,
        -1.01249972e+01, -1.12379765e+01, -1.28624644e+01, -9.47780103e+00,
        -2.17960841e-01,  2.49398648e+01,  1.66027782e+01,  2.62581230e+01,
        -1.94879264e+00, -8.10877843e+00, -6.93183679e+00, -6.80707596e+00,
        -1.17809892e+01, -1.05320670e+01, -1.59715849e+01, -9.07599923e+00,
         6.11988125e-01,  2.24925163e+01,  2.57389503e+01,  2.38907614e+01,
         4.99776202e+00, -1.07054696e+01, -7.24194672e+00, -1.17412084e+01,
        -1.10031559e+01, -9.10138831e+00, -1.62277209e+01, -1.02585250e+01,
        -2.79431476e+00,  1.96746051e+01,  2.40620700e+01,  2.00041920e+01,
         8.38674843e-01, -3.01708830e-01, -1.10576372e+01, -1.76502404e+01,
        -4.79853028e+00, -7.74057206e+00, -1.55628746e+01, -6.19663664e+00,
        -4.85267830e+00,  2.17819325e+01,  2.48075790e+01,  2.16186207e+01,
         9.21215745e+00, -1.71191202e+00, -1.38314188e+01, -9.44161337e+00,
        -6.35863884e+00, -1.10470671e+01, -1.41408736e+01, -9.60039945e+00,
        -4.80959619e+00,  3.41173952e+01,  2.02685767e+01,  1.65177446e+01,
         1.45004431e+00, -6.65011083e-01, -1.11027939e+01, -1.82545876e+01,
        -1.08637878e+01, -9.67573606e+00, -1.22946714e+01, -1.02064815e+01,
        -2.94225894e+00,  3.21840497e+01,  2.21586046e+01,  2.09073990e+01,
        -2.49862821e-01, -6.05605889e+00, -1.16741825e+01, -1.31096470e+01,
        -1.07043825e+01, -1.25489037e+01, -9.16715807e+00, -7.70278723e+00,
        -2.55657034e+00,  2.69936351e+01,  1.62042780e+01,  1.67614452e+01,
         8.62186552e+00, -3.51518668e+00, -9.27421021e+00, -1.15442848e+01,
        -9.96136043e+00, -1.17898558e+01, -1.13147670e+01, -7.10440489e+00,
        -1.10170600e+00,  2.60646482e+01,  2.32687942e+01,  1.82272063e+01,
         3.98792378e+00, -7.64233782e+00, -1.07945901e+01, -1.16024004e+01,
        -1.10645345e+01, -1.33282245e+01, -1.15534843e+01, -6.76286215e+00,
         3.93786824e+00,  2.37018431e+01,  2.07922131e+01,  2.37645505e+01,
         7.00182907e-01, -1.59605643e+00, -1.62277584e+01, -1.51068271e+01,
        -1.01377645e+01, -1.13639586e+01, -1.38275901e+01, -5.87092572e+00,
         3.43469809e+00,  2.82932175e+01,  2.39510218e+01,  1.71053544e+01,
         6.00992500e-01, -7.61224365e-01, -1.18686664e+01, -1.51989727e+01,
        -1.23352870e+01, -1.09931345e+01, -1.34086766e+01, -4.52127997e+00,
         2.09363525e+00,  3.13825850e+01,  2.43980063e+01,  1.89899567e+01,
        -7.55702038e+00, -2.76893846e-01, -6.52574120e+00, -1.67167241e+01,
        -1.17498886e+01, -7.68050287e+00, -5.60844424e+00, -2.79087739e+00,
        -2.92094111e-01,  2.31896495e+01,  1.70158799e+01,  1.84177113e+01,
        -3.39879920e-01,  1.31241579e+00, -9.65552567e+00, -1.30840488e+01,
        -1.33540036e+01, -9.72077648e+00, -1.09022916e+01, -4.49636288e+00,
        -6.88544858e-01,  1.88878504e+01,  2.15227074e+01,  2.32009723e+01,
        -5.72605223e+00,  1.87746593e+00, -6.95944675e+00, -1.41944248e+01,
        -1.25398544e+01, -8.09461542e+00, -5.46316863e+00, -4.73324533e+00,
         1.12162644e+00,  1.61183526e+01,  2.63470350e+01,  2.28827919e+01,
        -6.75326971e+00,  4.34023844e+00, -6.61711624e+00, -1.64533666e+01,
        -1.44473761e+01, -4.85575583e+00, -1.14659672e+01, -1.83412077e+00,
        -3.17492418e+00,  1.22586060e+01,  2.19162129e+01,  1.62630835e+01,
        -1.99943697e+00,  2.59255529e-03, -8.89996147e+00, -1.10976714e+01,
        -1.43864448e+01, -9.48222409e+00, -1.06785728e+01, -7.24340882e+00,
         2.15092681e+00,  1.53607666e+01,  2.06126854e+01,  1.96076182e+01,
         3.03104699e+00, -8.52358190e-02, -8.52357557e+00, -1.33461589e+01,
        -1.37600247e+01, -6.08841095e+00, -8.32367886e+00, -3.02117555e+00,
         4.08615082e-01,  1.63346143e+01,  1.76259473e+01,  1.75724049e+01,
         1.52688162e+00, -2.23616417e+00, -3.82136854e+00, -1.61943630e+01,
        -1.55739806e+01, -6.10489716e+00, -6.56542955e+00, -3.79160074e+00,
         1.79366664e+00,  1.37690213e+01,  1.71704010e+01,  2.12969028e+01,
         2.55881370e+00, -5.89333549e+00, -5.43867513e+00, -9.34441775e+00,
        -1.23296368e+01, -7.43701484e+00, -9.59827267e+00, -6.98198280e+00,
        -7.94911839e-01,  1.30601062e+01,  2.03392195e+01,  2.52824447e+01,
        -3.95418211e+00,  2.43162216e+00, -3.09611231e+00, -1.49779647e+01,
        -1.07287660e+01, -8.40149898e+00, -1.18887475e+01, -1.74756969e+00,
         2.17909158e+00,  1.20038451e+01,  2.42508083e+01,  2.34572756e+01,
        -5.17568738e+00, -1.96585193e-01, -4.18458348e+00, -1.55118992e+01,
        -1.38833773e+01, -8.29522246e+00, -1.30003245e+01, -1.67001046e-01,
         9.35165464e-01,  1.47274009e+01,  2.29308500e+01,  2.17103726e+01,
         3.68218796e+00,  2.64751368e-01, -7.34442896e+00, -1.25122452e+01,
        -1.14503472e+01, -8.19533891e+00, -1.15456946e+01, -2.81694273e+00,
        -1.50158220e+00,  1.14252490e+01,  2.08253654e+01,  1.93274939e+01,
         7.94218283e-01, -5.10392562e-01, -8.74257956e+00, -9.01561168e+00,
        -1.00192375e+01, -1.10908742e+01, -1.09129057e+01, -6.64424202e+00,
        -1.50482563e+00,  1.46897914e+01,  1.73829656e+01,  2.23508516e+01,
         8.64908482e+00,  6.22670938e-01, -6.68012958e+00, -5.70808463e+00,
        -1.80391974e+01, -7.97569860e+00, -1.19962932e+01, -5.55858916e+00,
         2.35415063e+00,  1.17526337e+01,  1.54009327e+01,  2.21564076e+01,
         3.90926848e+00,  2.21699063e+00, -3.80724386e+00, -1.09345639e+01,
        -1.37938477e+01, -1.00726110e+01, -1.19963696e+01, -5.40000702e+00,
        -1.51910929e+00,  1.69895520e+00,  1.74367921e+01,  2.04883238e+01,
         7.55305367e+00,  7.29570618e-01, -5.09536099e+00, -1.29493298e+01,
        -1.53454372e+01, -2.46711622e+00, -1.01903520e+01, -4.03697494e+00,
        -3.08084548e+00,  3.86928001e+00,  1.92764155e+01,  1.55958052e+01,
         7.35560665e+00,  1.85905286e+00, -5.61647492e-01, -1.23394890e+01,
        -9.90369650e+00, -7.50968724e+00, -1.83651468e+01, -2.77916418e+00,
        -1.07805825e+00,  8.15877162e+00,  2.33477133e+01,  1.69720395e+01,
         6.19355409e+00,  4.92033190e+00, -1.36452236e+01, -1.10382237e+01,
        -4.45625959e+00, -1.37976278e+01, -1.12070229e+01, -1.28293907e+00,
         1.02615489e-01,  1.16373419e+01,  1.73964040e+01,  1.64050904e+01,
         1.32632316e+01,  4.44789857e+00, -1.66636700e+01, -1.04932431e+01,
        -7.27536831e+00, -1.52095878e+01, -8.33331485e+00, -6.12562623e+00,
        -6.19892381e-01,  1.73375856e+01,  1.71076116e+01,  2.30092371e+01,
        -1.39793588e+00,  1.20108534e+00, -1.01506292e+01, -9.35709025e+00,
        -1.72524967e+01, -1.33257487e+01, -1.11436060e+01, -1.07822676e+00,
         2.29723021e+00,  1.15489387e+01,  1.72661557e+01,  2.11762682e+01,
         9.51783705e+00, -1.02191435e+00, -5.14895585e+00, -2.05301479e+01,
        -1.56429911e+01, -1.60412160e+01, -1.50915585e+01, -2.94815119e+00,
         4.61947140e+00,  6.94204531e+00,  1.79378222e+01,  2.19333496e+01,
         8.01926876e+00, -3.09873539e+00, -6.33383956e+00, -1.29668016e+01,
        -1.54450181e+01, -1.27736754e+01, -1.46733580e+01, -8.76927199e+00,
         8.56843050e+00,  1.28259048e+01,  1.86473170e+01,  5.73666651e+00,
         4.33460471e+00,  2.08833654e+00, -3.96959363e+00, -1.29223840e+01,
        -1.19550435e+01, -1.27279210e+01, -8.02537118e+00, -3.92329973e+00,
         7.09140567e+00,  2.42153157e+01,  1.28924451e+01,  1.79711994e+01,
         2.89522816e+00,  1.30474094e+00, -7.77941829e+00, -1.04361458e+01,
        -1.14357321e+01, -1.23868252e+01, -3.73410135e+00,  6.47313429e-01,
         5.14176514e+00,  1.16621376e+01,  8.00349556e+00,  1.83900860e+01,
         3.46846764e+00,  2.29413265e+00, -4.06962578e+00, -8.55164849e+00,
        -1.76399695e+01, -1.50423508e+01, -1.13765532e+01, -9.17973632e+00,
        -4.22254178e+00,  2.19090137e+01,  1.90170614e+01,  1.80606278e+01,
         4.08981599e+00,  2.02346117e+00, -5.45474659e+00, -1.38725716e+01,
        -1.50622791e+01, -1.15367789e+01, -7.55445577e+00, -1.77510788e+00,
         9.46335947e+00,  4.88813367e+00,  1.61490895e+01,  1.93212548e+01,
         1.03075610e+01, -6.46758291e-01, -5.79530543e-01, -1.35917659e+01,
        -1.62148912e+01, -1.29823949e+01, -1.02149087e+01, -3.24211066e+00,
         3.05411201e-01,  1.19385090e+01,  2.08979477e+01,  2.19927470e+01,
         1.32364223e+00,  1.68626515e+00, -3.52030557e+00, -1.50337436e+01,
        -1.75865944e+01, -1.23980840e+01, -1.19670311e+01, -1.59575440e+00,
         4.32015112e+00,  1.39461330e+01,  2.63901690e+01,  2.11431667e+01,
         1.19960552e+00,  1.22769386e+00, -3.12851420e+00, -1.23388328e+01,
        -1.66429432e+01, -9.08277509e+00, -7.92637338e+00,  2.43702321e+00,
        -3.53211182e+00,  1.00606776e+01,  1.39608421e+01,  1.44689452e+01,
         6.50770562e+00,  3.13940836e+00, -4.89894478e-01, -1.05833296e+01,
        -1.34863098e+01, -1.20763793e+01, -1.00738904e+01, -9.39207297e+00]),
 'm': 12,
 'states': array([[1.24021769e+02, 8.30544193e+01, 8.40569047e+01, 6.14692129e-02,
         3.00509837e+02],
        [8.50161150e+01, 7.80917592e+01, 8.40569047e+01, 6.14692129e-02,
         1.33876280e+02],
        [7.65735210e+01, 7.67280499e+01, 8.40569047e+01, 6.14692129e-02,
         8.73485707e+01],
        ...,
        [1.12984989e+02, 1.00517846e+02, 8.40569047e+01, 6.14692129e-02,
         1.14480779e+02],
        [1.14049672e+02, 1.00543746e+02, 8.40569047e+01, 6.14692129e-02,
         1.13025090e+02],
        [1.10946036e+02, 1.00561388e+02, 8.40569047e+01, 6.14692129e-02,
         1.14089773e+02]]),
 'par': {'initial_smoothed': 258.45064973324986,
  'alpha': 0.7664297044277045,
  'theta': 2.0},
 'n': 536,
 'modeltype': 'STM',
 'mean_y': 100.56138830499272,
 'decompose': True,
 'decomposition_type': 'additive',
 'seas_forecast': {'mean': array([  0.08977811,  18.09442035,  20.24848682,  19.4306462 ,
           2.64008067,  -1.30909907,  -7.97773123, -12.32640613,
         -12.02777406, -10.1369666 , -11.42293515,  -5.30249992])},
 'fitted': array([300.50983667, 133.87628004,  87.34857069,  77.149048  ,
         76.98193501,  77.05546632,  77.64431525,  79.83754652,
         77.03398836,  75.47241431,  85.74423271,  85.47106297,
         86.31849804,  88.14353755,  80.8438024 ,  78.37975377,
         81.66417589,  83.26287098,  84.62535048,  83.7700806 ,
         79.74427087,  80.35553844,  83.81224541,  87.82202841,
         86.29562725,  86.14455866,  85.23591135,  85.24504787,
         80.98550693,  85.35566468,  88.73347592,  80.13900931,
         77.37219166,  71.81797618,  78.98325428,  80.46128324,
         70.5292522 ,  65.84059407,  56.80056031,  58.2461785 ,
         68.88547104,  71.95893388,  73.17068509,  73.63150177,
         72.56861641,  67.74141718,  75.823697  ,  83.17867619,
         82.88182158,  84.77950058,  78.46220336,  71.99792744,
         75.71080394,  81.00106285,  78.99654625,  80.35978703,
         77.73477595,  77.0624998 ,  86.86366484,  90.37877922,
         93.59485284,  94.48135717,  92.0871403 ,  86.88905458,
         88.05659723,  89.54567652,  88.73256441,  87.66000103,
         84.49066084,  84.28553517,  89.56282177,  86.79937702,
         92.06289264,  88.57657843,  83.39583679,  84.22817596,
         88.48908916,  88.70896704,  88.43688493,  84.98139923,
         82.12001187,  82.55098375,  85.9525497 ,  90.19133861,
         93.64043798,  95.47816961,  88.30724672,  88.90190843,
         89.38115594,  90.19718831,  91.02162087,  87.36982498,
         83.60211476,  81.42239492,  82.66423005,  85.61780804,
         86.08812516,  84.73820883,  85.54103724,  83.21124039,
         79.16163028,  84.73307206,  86.60047462,  83.45823664,
         82.8036783 ,  79.0463675 ,  81.90332096,  85.42827927,
         87.13594255,  92.20371202,  91.92571881,  87.47001337,
         89.71173884,  94.08746707,  93.42067364,  91.36829945,
         88.10499619,  84.3807048 ,  96.69192329,  96.73805538,
         94.53625569,  93.65491108,  94.17929385,  91.81488763,
         87.30208784,  88.22493606,  88.60917145,  87.97178146,
         84.24395894,  81.95085029,  92.78029537,  94.275001  ,
         95.43756282,  93.25335889,  89.64588248,  86.84354705,
         86.27398255,  87.31900372,  85.50115807,  87.26078723,
         85.45187034,  83.45436489,  90.30572204,  87.23685484,
         85.22183448,  89.83718668,  88.17711021,  87.21418481,
         87.84436043,  89.45885582,  88.22276703,  88.33640489,
         86.986106  ,  86.10365179,  92.24300578,  94.58859369,
         93.69697622,  94.76073782,  89.93749012,  87.8093004 ,
         88.3949345 ,  89.16392452,  86.74878426,  86.67946215,
         85.59093176,  88.57095695,  92.89938688,  93.34684947,
         96.69921709,  95.24315643,  95.05395839,  88.76162712,
         86.66136449,  88.14065857,  87.23099008,  85.41872572,
         85.01380191,  87.60818255,  95.45557821,  98.3240456 ,
         96.5726075 ,  95.04052437,  95.49116642,  92.53977272,
         90.36888696,  90.16393454,  89.5384766 ,  88.04727997,
         88.67676475,  90.24331499, 100.45849371, 103.6695433 ,
        103.36252038,  95.57789385,  96.3997412 ,  97.54332409,
         94.2091886 ,  94.45290287,  96.36634424, 100.85347739,
        102.80919411, 102.5472505 , 106.48312008, 104.03628873,
        103.29067992, 101.03748421, 103.07742567, 101.82224878,
        101.27230355, 100.28657648, 100.63629155, 101.06606288,
        101.71464486, 101.14884962, 101.78769257, 102.7950277 ,
        105.71545223,  99.33413407, 101.80714675, 102.61832483,
        101.21735442, 100.85561542, 102.45166863, 107.05014533,
        107.51717356, 108.33874738, 106.85496498, 111.55980808,
        114.23636971, 107.06776156, 111.42831624, 112.50186659,
        109.36957611, 107.54585583, 111.62426724, 111.62202077,
        114.31102418, 111.83959398, 107.39758714, 108.70651652,
        106.30953697, 102.78440744, 103.82046147, 103.14437142,
        104.11684481, 102.33982409, 102.8723728 , 103.47360882,
        102.01677319, 103.62723339, 101.56281457, 101.87268181,
        102.03905301, 102.36943582, 103.33817557, 102.95055886,
        102.19972468, 100.90281095, 104.03647886, 106.44257555,
        108.22178492, 108.49688565, 107.17885267, 105.19959511,
        103.80611838, 102.98366417, 102.30386854, 105.52016297,
        102.58638056,  99.89919716, 103.02022955, 106.77390074,
        107.96263336, 109.29927875, 106.01489901, 103.6864972 ,
        105.1475863 , 105.11603549, 101.63327513, 103.61001775,
        105.92623683, 105.72561484, 107.82567267, 109.2548828 ,
        107.99841184, 107.35109379, 103.5233805 , 103.62365533,
        108.13938211, 103.11607784, 106.01381231, 109.78596466,
        107.78446605, 108.81079898, 110.17164752, 109.84536969,
        112.60070842, 114.23275493, 109.59549173, 112.69372438,
        115.81058738, 109.85108519, 110.73448348, 113.67239919,
        111.26167729, 109.87022246, 111.31252448, 110.13430105,
        114.10103454, 114.77969912, 112.22984999, 114.31642742,
        116.09441204, 116.92384863, 118.16082896, 118.67694524,
        118.56524723, 119.03853891, 120.55739465, 120.49404273,
        122.4297822 , 121.24085099, 116.16013458, 116.63300608,
        116.58468172, 115.20069256, 115.84357956, 115.28811168,
        117.8563375 , 119.42647419, 118.72610568, 119.14774202,
        118.15012563, 116.95870856, 114.38003444, 112.21454843,
        114.48341518, 119.11962906, 120.63092958, 121.65618463,
        126.75939743, 122.1827986 , 123.8699932 , 123.46128916,
        123.29574937, 125.06196634, 120.23216725, 116.54759242,
        118.66743152, 119.67090937, 122.40414386, 125.63126387,
        126.72874773, 125.40591101, 125.48596965, 125.07720702,
        125.03320929, 123.8308448 , 111.2956079 , 109.17137615,
        110.01274633, 113.80892938, 115.40216099, 117.64202977,
        117.19533719, 114.68331622, 120.59245198, 121.56787494,
        122.56854548, 120.16921999, 109.29738449, 108.58309477,
        105.67469335, 109.31954714, 111.77844749, 117.49308896,
        117.5137965 , 119.17248724, 121.21684679, 115.92686418,
        117.89155825, 117.02722838, 109.44298667, 111.84906053,
        109.99544591, 112.7496681 , 117.55482364, 113.24182371,
        114.25985959, 120.09362779, 117.31872292, 117.51493907,
        120.62638451, 120.66695807, 115.74879597, 113.59360961,
        111.30546837, 119.47810143, 123.92117003, 117.29474313,
        118.73046831, 122.40358785, 118.54651485, 120.94522623,
        120.34509238, 119.83191444, 119.28258839, 116.90606293,
        119.67953588, 116.61541466, 118.57002916, 116.93539025,
        119.24189675, 115.26824869, 112.85500598, 113.09982676,
        116.36816979, 118.09076126, 113.10484433, 110.84983175,
        112.21846295, 117.52051435, 117.77135585, 119.97014793,
        113.71329113, 110.97321598, 106.47875848, 103.69775119,
        105.5329286 , 109.03535469, 100.5185778 ,  98.7783504 ,
        100.72723124, 104.88073539, 103.53983956, 104.83050157,
        104.37041809, 101.78207536,  99.79195805,  97.33147199,
         94.7051695 , 101.23419515,  97.22698298,  96.03053349,
         85.56579529,  86.89526346,  89.52989363,  92.63258395,
         92.20654345,  92.29302095,  90.33797118,  92.97269973,
         94.06049433,  99.45748428, 104.17945492,  98.57230063,
         97.48447184,  97.71075906,  99.74481829,  99.92754582,
        101.40703208, 101.89152524, 100.19790135, 106.12158657,
        110.71243486, 114.61516239, 109.71600444, 100.36181401,
         99.59503236, 100.26066735, 103.05302578, 106.07904849,
        109.00286948, 104.73225081, 101.00335324, 101.06963632,
         98.12874178,  94.85438633,  97.80873857,  96.89567218,
         95.87638401,  97.01823883,  99.60314659, 101.56757157,
        100.41327905,  98.11827889,  97.07615577, 100.07180788,
        102.80604053, 110.02096633,  99.93001054,  96.81884524,
         96.765739  , 102.67305829, 103.21143054, 108.91236591,
        107.9732912 , 104.79489485, 102.64480872, 103.60141066,
        105.2112888 , 105.40729101, 100.7199523 , 101.24845301,
        103.24285777, 102.26463485, 104.59110557, 108.03814361,
        105.99389435, 101.76418396, 100.06193105,  99.6756544 ,
        102.54734888, 105.82036702, 102.67173097, 107.40963325,
        108.75289448, 107.67960614, 109.6546142 , 113.40193278,
        113.42314323, 109.91667509, 110.75537338, 113.46597679,
        119.42851182, 116.6833224 , 110.55675793, 105.76845483,
        101.99639438, 104.99139164, 108.43159448, 114.20122959,
        115.56790983, 114.4807793 , 113.0250904 , 114.08977297])}

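Let us now visualize the residuals of our model.

As we can see, the result obtained above is a dictionary. To extract an element from it, we use the .get() function and then store the result in a pd.DataFrame().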

residual=pd.DataFrame(result.get("residuals"), columns=["residual Model"])
residual
     residual Model
0       -214.815337
1        -62.056280
2        -21.325671
...             ...
533      -12.076379
534      -10.073890
535       -9.392073
fig, axs = plt.subplots(nrows=2, ncols=2)

residual.plot(ax=axs[0,0])
axs[0,0].set_title("Residuals");

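# sns.distplot is deprecated in recent seaborn releases; histplot with a KDE overlay is the replacement
sns.histplot(residual["residual Model"], kde=True, stat="density", ax=axs[0,1]);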
axs[0,1].set_title("Density plot - Residual");

stats.probplot(residual["residual Model"], dist="norm", plot=axs[1,0])
axs[1,0].set_title('Plot Q-Q')

plot_acf(residual,  lags=35, ax=axs[1,1],color="fuchsia")
axs[1,1].set_title("Autocorrelation");

plt.show();

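Forecast Method

If you want to gain speed in production settings where you have multiple series or models, we recommend using the StatsForecast.forecast method instead of .fit and .predict.

The main difference is that .forecast does not store the fitted values and is highly scalable in distributed environments.

The forecast method takes two arguments: the forecast horizon h and level.

  • h (int): represents the forecast h steps into the future. In this case, 12 months ahead.

  • level (list of floats): this optional parameter is used for probabilistic forecasting. Set the level (or confidence percentile) of your prediction interval. For example, level=[90] means that the model expects the real value to fall inside that interval 90% of the time.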
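To illustrate what a level means in practice, the sketch below maps a level to the pair of quantiles that bound a central interval and then measures empirical coverage, meaning the share of actual values that land inside their interval. The numbers are the six rows of actuals and 95% bounds displayed in this tutorial's output tables. Both helper functions are hypothetical and say nothing about how AutoTheta computes its intervals internally.

```python
def interval_quantiles(level):
    """Lower and upper quantiles of a central interval, e.g. 95 -> (0.025, 0.975)."""
    tail = (1 - level / 100) / 2
    return tail, 1 - tail


def empirical_coverage(actual, lower, upper):
    """Fraction of actual values that fall inside [lower, upper]."""
    inside = sum(lo <= a <= hi for a, lo, hi in zip(actual, lower, upper))
    return inside / len(actual)


# Rows 0, 1, 2, 9, 10, 11 as printed in this tutorial (test values and 95% bounds).
actual = [109.3191, 119.0502, 116.8431, 104.2022, 102.5861, 114.0613]
lower = [90.139234, 94.795409, 90.579813, 41.186268, 35.144354, 38.753454]
upper = [136.011109, 160.387128, 168.268538, 159.159903, 152.867267, 166.048584]

q_low, q_high = interval_quantiles(95)
coverage = empirical_coverage(actual, lower, upper)
print(q_low, q_high, coverage)
```

Six rows are far too few to judge calibration; over many forecasts, a well-calibrated 95% interval would be expected to contain roughly 95% of the actual values.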

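The forecast object here is a new data frame that includes a column with the name of the model and the y hat values, as well as columns for the uncertainty intervals. Depending on your computer, this step should take around 1 minute. (If you want to speed things up to a couple of seconds, remove Auto models such as ARIMA and Theta.)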

# Prediction
Y_hat = sf.forecast(df=train, h=horizon, fitted=True)
Y_hat
    unique_id          ds   AutoTheta
0           1  2016-09-01  111.075915
1           1  2016-10-01  129.111292
2           1  2016-11-01  131.296093
...       ...         ...         ...
9           1  2017-06-01  101.125782
10          1  2017-07-01   99.870548
11          1  2017-08-01  106.021718
values=sf.forecast_fitted_values()
values.head()
   unique_id          ds        y   AutoTheta
0          1  1972-01-01  85.6945  300.509837
1          1  1972-02-01  71.8200  133.876280
2          1  1972-03-01  66.0229   87.348571
3          1  1972-04-01  64.5645   77.149048
4          1  1972-05-01  65.0100   76.981935
StatsForecast.plot(values)

Adding 95% confidence interval with the forecast method

sf.forecast(df=train, h=horizon, level=[95])
    unique_id          ds   AutoTheta  AutoTheta-lo-95  AutoTheta-hi-95
0           1  2016-09-01  111.075915        90.139234       136.011109
1           1  2016-10-01  129.111292        94.795409       160.387128
2           1  2016-11-01  131.296093        90.579813       168.268538
...       ...         ...         ...              ...              ...
9           1  2017-06-01  101.125782        41.186268       159.159903
10          1  2017-07-01   99.870548        35.144354       152.867267
11          1  2017-08-01  106.021718        38.753454       166.048584
# Merge the forecasts with the true values
Y_hat1 = test.merge(Y_hat, how='left', on=['unique_id', 'ds'])
Y_hat1
            ds         y unique_id   AutoTheta
0   2016-09-01  109.3191         1  111.075915
1   2016-10-01  119.0502         1  129.111292
2   2016-11-01  116.8431         1  131.296093
...        ...       ...       ...         ...
9   2017-06-01  104.2022         1  101.125782
10  2017-07-01  102.5861         1   99.870548
11  2017-08-01  114.0613         1  106.021718
sf.plot(train, Y_hat1)

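Predict method with confidence interval

To generate forecasts, use the predict method.

The predict method takes two arguments: the forecast horizon h and level.

  • h (int): represents the forecast h steps into the future. In this case, 12 months ahead.

  • level (list of floats): this optional parameter is used for probabilistic forecasting. Set the level (or confidence percentile) of your prediction interval. For example, level=[95] means that the model expects the real value to fall inside that interval 95% of the time.

The forecast object here is a new data frame that includes a column with the name of the model and the y hat values, as well as columns for the uncertainty intervals.

This step should take less than 1 second.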

sf.predict(h=horizon)
    unique_id          ds   AutoTheta
0           1  2016-09-01  111.075915
1           1  2016-10-01  129.111292
2           1  2016-11-01  131.296093
...       ...         ...         ...
9           1  2017-06-01  101.125782
10          1  2017-07-01   99.870548
11          1  2017-08-01  106.021718
forecast_df = sf.predict(h=horizon, level=[95]) 
forecast_df
    unique_id          ds   AutoTheta  AutoTheta-lo-95  AutoTheta-hi-95
0           1  2016-09-01  111.075915        90.139234       136.011109
1           1  2016-10-01  129.111292        94.795409       160.387128
2           1  2016-11-01  131.296093        90.579813       168.268538
...       ...         ...         ...              ...              ...
9           1  2017-06-01  101.125782        41.186268       159.159903
10          1  2017-07-01   99.870548        35.144354       152.867267
11          1  2017-08-01  106.021718        38.753454       166.048584
sf.plot(train, test.merge(forecast_df), level=[95])

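Cross-validation

In the previous steps, we used historical data to predict the future. To assess accuracy, however, we also want to know how the model would have performed in the past. To assess the accuracy and robustness of your models on your data, perform cross-validation.

With time series data, cross-validation is done by defining a sliding window across the historical data and predicting the period that follows it. This form of cross-validation gives a better estimate of the model's predictive ability across a wider range of points in time, while keeping the data in the training set contiguous, as our models require.

The following graph depicts such a cross-validation strategy: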

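Perform time series cross-validation

Cross-validation of time series models is considered a best practice, but most implementations are very slow. The statsforecast library implements cross-validation as a distributed operation, which reduces the time the process takes. If you have big datasets, you can also perform cross-validation on a distributed cluster using Ray, Dask or Spark.

In this case, we want to evaluate the performance of each model over the last 5 years (n_windows=5), producing a 12-month forecast every 12 time steps (step_size=12). Depending on your computer, this step should take around 1 minute.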

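The cross_validation method from the StatsForecast class takes the following arguments.

  • df: training data frame.

  • h (int): represents the forecast h steps into the future. In this case, 12 months ahead.

  • step_size (int): step size between consecutive windows. In other words: how often you want to run the forecasting process.

  • n_windows (int): number of windows used for cross-validation. In other words: how many forecasting processes in the past you want to evaluate.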

crossvalidation_df = sf.cross_validation(
    df=train,
    h=horizon,
    step_size=12,
    n_windows=5
)
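The way h, step_size and n_windows interact can be worked out by hand. The sketch below assumes the windows are laid out backwards from the end of the training data, with the last window's forecasts ending on the final observation; this assumption matches the cutoff dates in the output of this tutorial (2011-08-01 through 2015-08-01). The helpers cv_train_sizes and month_at are hypothetical and are not part of StatsForecast.

```python
def cv_train_sizes(n_obs, h, step_size, n_windows):
    """Training-set length of each window, oldest window first."""
    sizes = [n_obs - h - step_size * k for k in range(n_windows - 1, -1, -1)]
    if sizes[0] <= 0:
        raise ValueError("not enough observations for the requested windows")
    return sizes


def month_at(start_year, start_month, offset):
    """(year, month) that lies `offset` months after the start month."""
    total = (start_month - 1) + offset
    return start_year + total // 12, total % 12 + 1


# 536 monthly training observations starting in January 1972.
sizes = cv_train_sizes(n_obs=536, h=12, step_size=12, n_windows=5)
cutoffs = [month_at(1972, 1, size - 1) for size in sizes]

for size, (year, month) in zip(sizes, cutoffs):
    print(f"train on {size} obs, cutoff {year}-{month:02d}, forecast next 12 months")
```

Five windows of 12 forecasts each give 60 rows in total, which agrees with the index running from 0 to 59 in the cross-validation output.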

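The crossvalidation_df object is a new data frame that includes the following columns:

  • unique_id: series identifier.
  • ds: datestamp or temporal index.
  • cutoff: the last datestamp or temporal index of the training data for each window.
  • y: true value.
  • "model": columns with the model's name and its forecast values.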
crossvalidation_df
    unique_id          ds      cutoff         y   AutoTheta
0           1  2011-09-01  2011-08-01   93.9062   98.167469
1           1  2011-10-01  2011-08-01  116.7634  116.969932
2           1  2011-11-01  2011-08-01  116.8258  119.135142
...       ...         ...         ...       ...         ...
57          1  2016-06-01  2015-08-01  102.4044  109.600469
58          1  2016-07-01  2015-08-01  102.9512  108.260160
59          1  2016-08-01  2015-08-01  104.6977  114.248270

Model Evaluation

Now we evaluate our model using the forecast results. We use several metrics (MAE, MAPE, MASE, RMSE and SMAPE) to assess accuracy.

from functools import partial

import utilsforecast.losses as ufl
from utilsforecast.evaluation import evaluate
evaluate(
    test.merge(Y_hat),
    metrics=[ufl.mae, ufl.mape, partial(ufl.mase, seasonality=season_length), ufl.rmse, ufl.smape],
    train_df=train,
)
   unique_id metric  AutoTheta
0          1    mae   6.281513
1          1   mape   0.055683
2          1   mase   1.212473
3          1   rmse   7.683669
4          1  smape   0.027399
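For readers who want to see what these metrics compute, here is a plain-Python sketch of the usual textbook definitions. It assumes the sMAPE convention that divides by |y| + |ŷ|, so the value is bounded by 1; this is consistent with the table above, where smape is roughly half of mape. MASE is scaled by the in-sample seasonal naive error. These functions are illustrative and are not the utilsforecast implementations.

```python
import math


def mae(y, y_hat):
    return sum(abs(a - p) for a, p in zip(y, y_hat)) / len(y)


def mape(y, y_hat):
    return sum(abs(a - p) / abs(a) for a, p in zip(y, y_hat)) / len(y)


def rmse(y, y_hat):
    return math.sqrt(sum((a - p) ** 2 for a, p in zip(y, y_hat)) / len(y))


def smape(y, y_hat):
    # Assumed convention: |error| / (|y| + |y_hat|), so the result lies in [0, 1].
    return sum(abs(a - p) / (abs(a) + abs(p)) for a, p in zip(y, y_hat)) / len(y)


def mase(y, y_hat, y_train, seasonality):
    # Scale: mean absolute error of the seasonal naive forecast on the training data.
    diffs = [abs(y_train[i] - y_train[i - seasonality])
             for i in range(seasonality, len(y_train))]
    return mae(y, y_hat) / (sum(diffs) / len(diffs))


# Tiny made-up example, chosen so the arithmetic can be checked by hand.
y_true = [100.0, 200.0]
y_pred = [110.0, 180.0]
y_hist = [90.0, 80.0, 100.0, 95.0]

print(mae(y_true, y_pred), mape(y_true, y_pred), rmse(y_true, y_pred))
print(smape(y_true, y_pred), mase(y_true, y_pred, y_hist, seasonality=2))
```

A MASE above 1, as in the table above, indicates that the forecast's mean absolute error on the test set is larger than the in-sample error of a seasonal naive forecast.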

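References

  1. Jose A. Fiorucci, Tiago R. Pellegrini, Francisco Louzada, Fotios Petropoulos, Anne B. Koehler (2016). "Models for optimising the theta method and their relationship to state space models". International Journal of Forecasting.
  2. Nixtla Parameters.
  3. Pandas available frequencies.
  4. Rob J. Hyndman and George Athanasopoulos (2018). "Forecasting: Principles and Practice (3rd ed)".
  5. Seasonal periods - Rob J Hyndman.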