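In this notebook we show how to use StatsForecast with ray to forecast thousands of time series (the M5 dataset) in under 6 minutes. We also show that StatsForecast beats Prophet, run on a Spark cluster using DataBricks, in both run time and accuracy.

In this example we used a ray cluster (AWS) of 11 instances of type m5.2xlarge (8 cores, 32 GB RAM).

Installing the StatsForecast library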

!pip install "statsforecast[ray]" neuralforecast s3fs pyarrow
from time import time

import pandas as pd
from neuralforecast.data.datasets.m5 import M5, M5Evaluation
from statsforecast import StatsForecast
from statsforecast.models import ETS

Download data

The example uses the M5 dataset. It consists of 30,490 bottom-level time series.

Y_df = pd.read_parquet('s3://m5-benchmarks/data/train/target.parquet')
Y_df = Y_df.rename(columns={
    'item_id': 'unique_id', 
    'timestamp': 'ds', 
    'demand': 'y'
})
Y_df['ds'] = pd.to_datetime(Y_df['ds'])
Y_df.head()
   unique_id         ds          y
0  FOODS_1_001_CA_1  2011-01-29  3.0
1  FOODS_1_001_CA_1  2011-01-30  0.0
2  FOODS_1_001_CA_1  2011-01-31  0.0
3  FOODS_1_001_CA_1  2011-02-01  1.0
4  FOODS_1_001_CA_1  2011-02-02  4.0

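Since the M5 dataset contains intermittent time series, we add a constant to avoid problems during the training phase. Later, we will subtract the constant from the forecasts.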

constant = 10
Y_df['y'] += constant
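
The shift-and-restore idea can be sketched on a toy series. This is only an illustration: the series values are made up, and the mean forecast is a hypothetical stand-in for the ETS model, chosen because it makes the round trip easy to check by hand.

```python
import numpy as np
import pandas as pd

# Toy intermittent series (mostly zeros); values are made up for illustration.
y = pd.Series([0.0, 0.0, 3.0, 0.0, 1.0, 0.0, 0.0])

constant = 10
y_shifted = y + constant  # strictly positive series used for fitting

# Stand-in "model" (an assumption, not ETS): forecast the mean of the shifted series.
forecast_shifted = np.full(3, y_shifted.mean())

# Undo the shift so the forecasts are back on the original demand scale.
forecast = forecast_shifted - constant
print(forecast)
```

For this stand-in model the restored forecast equals the mean of the original series, so the shift leaves the result unchanged while keeping the training data away from zero.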

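Train the model

StatsForecast receives a list of models to fit to each time series. Since we are dealing with daily data, a season length of 7 (weekly seasonality) is a sensible choice. Note that we need to pass the ray address to the ray_address argument.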

fcst = StatsForecast(
    df=Y_df, 
    models=[ETS(season_length=7, model='ZNA')], 
    freq='D', 
    #n_jobs=-1
    ray_address='ray://ADDRESS:10001'
)
init = time()
Y_hat = fcst.forecast(28)
end = time()
print(f'Minutes taken by StatsForecast using: {(end - init) / 60}')
/home/ubuntu/miniconda/envs/ray/lib/python3.7/site-packages/ray/util/client/worker.py:618: UserWarning: More than 10MB of messages have been created to schedule tasks on the server. This can be slow on Ray Client due to communication overhead over the network. If you're running many fine-grained tasks, consider running them inside a single remote function. See the section on "Too fine-grained tasks" in the Ray Design Patterns document for more details: https://docs.google.com/document/d/167rnnDFIVRhHhK4mznEIemOtj63IOhtIPvSYaPgI4Fg/edit#heading=h.f7ins22n6nyl. If your functions frequently use large objects, consider storing the objects remotely with ray.put. An example of this is shown in the "Closure capture of large / unserializable object" section of the Ray Design Patterns document, available here: https://docs.google.com/document/d/167rnnDFIVRhHhK4mznEIemOtj63IOhtIPvSYaPgI4Fg/edit#heading=h.1afmymq455wu
  UserWarning,
Minutes taken by StatsForecast using: 5.4817593971888225

StatsForecast with ray took only 5.48 minutes to train 30,490 time series, compared with 18.23 minutes for Prophet on Spark.
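
As a quick sanity check on these figures, the implied speedup and throughput can be worked out directly. The variable names below are only illustrative; the numbers are the ones reported above.

```python
n_series = 30_490
statsforecast_minutes = 5.48
prophet_minutes = 18.23

# How many times faster the StatsForecast run was.
speedup = prophet_minutes / statsforecast_minutes

# Average number of series fitted and forecast per second across the cluster.
series_per_second = n_series / (statsforecast_minutes * 60)

print(f"speedup: {speedup:.2f}x")
print(f"throughput: {series_per_second:.1f} series/s")
```

This works out to roughly a 3.3x speedup, or a little over 90 series per second across the cluster.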

We remove the constant.

Y_hat['ETS'] -= constant

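Evaluating performance

The M5 competition used the weighted root mean squared scaled error (WRMSSE) as its metric. See the M5 competition documentation for details of the metric.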
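
To give a feel for the metric, here is a minimal sketch of the scaled error for a single series and a weighted combination across series. It is a simplification, not the M5Evaluation implementation: the helper names `rmsse` and `wrmsse`, the toy data, and the weights are all assumptions for illustration, and the sketch leaves out competition details such as the 12 aggregation levels and the sales-based weights.

```python
import numpy as np

def rmsse(y_train, y_true, y_pred):
    """Root mean squared scaled error for one series (simplified)."""
    y_train, y_true, y_pred = (np.asarray(a, dtype=float) for a in (y_train, y_true, y_pred))
    # Scale: mean squared error of the one-step naive forecast on the training data.
    scale = np.mean(np.diff(y_train) ** 2)
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2) / scale))

def wrmsse(scores, weights):
    """Weighted average of per-series RMSSE values; weights are normalized to sum to 1."""
    weights = np.asarray(weights, dtype=float)
    return float(np.sum(weights / weights.sum() * np.asarray(scores, dtype=float)))

# Toy data (made up): series A is forecast imperfectly, series B perfectly.
score_a = rmsse(y_train=[0, 2, 0, 2, 0], y_true=[1, 1], y_pred=[0, 2])
score_b = rmsse(y_train=[1, 2, 3, 4], y_true=[5, 6], y_pred=[5, 6])
total = wrmsse([score_a, score_b], weights=[3, 1])
print(score_a, score_b, total)
```

Because each error is divided by the in-sample naive error of its own series, series with very different demand levels can be combined into one score; lower is better.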

Y_hat = Y_hat.reset_index().set_index(['unique_id', 'ds']).unstack()
Y_hat = Y_hat.droplevel(0, 1).reset_index()
*_, S_df = M5.load('./data')
Y_hat = S_df.merge(Y_hat, how='left', on=['unique_id'])
100%|███████████████████████████████████████████████████████████| 50.2M/50.2M [00:00<00:00, 77.1MiB/s]
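
The reshape applied to Y_hat above turns long-format forecasts (one row per series and date) into wide format (one row per series, one column per date). A small sketch on a toy frame shows the effect; the ids, dates and values are made up, and the frame is assumed to be indexed by unique_id as the forecast output is in this notebook.

```python
import pandas as pd

# Toy long-format forecasts: two series, two dates (values are made up).
Y_hat_toy = pd.DataFrame(
    {
        "unique_id": ["A", "A", "B", "B"],
        "ds": pd.to_datetime(["2016-05-23", "2016-05-24"] * 2),
        "ETS": [1.0, 2.0, 3.0, 4.0],
    }
).set_index("unique_id")

# Move the dates into the columns: the result has a (model, date) column MultiIndex.
wide = Y_hat_toy.reset_index().set_index(["unique_id", "ds"]).unstack()

# Drop the model level so the columns are just dates, then restore unique_id as a column.
wide = wide.droplevel(0, 1).reset_index()
print(wide)
```

The wide frame has one forecast column per date, which is the layout that is then merged with S_df on unique_id.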
M5Evaluation.evaluate(y_hat=Y_hat, directory='./data')
         wrmsse
Total    0.677233
Level1   0.435558
Level2   0.522863
Level3   0.582109
Level4   0.488484
Level5   0.567825
Level6   0.587605
Level7   0.662774
Level8   0.647712
Level9   0.732107
Level10  1.013124
Level11  0.970465
Level12  0.916175

StatsForecast is also more accurate than Prophet: its overall WRMSSE is 0.68, compared with 0.77 for Prophet.