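In forecasting, we are often interested in the distribution of predictions rather than just a point prediction, because we want to understand the uncertainty around the forecast.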

To this end, we can create quantile forecasts.

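Quantile forecasts have an intuitive interpretation, because they represent specific percentiles of the forecast distribution. This lets us make statements such as "we expect 90% of the air passenger observations to be above 100", which is exactly what a 10th-percentile forecast of 100 means. This approach supports planning under uncertainty: it provides a range of plausible future values and helps users make more informed decisions by considering the full spread of potential outcomes.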

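With TimeGPT, we can create a forecast distribution and extract quantile forecasts at specified percentiles. For instance, the 25th and 75th percentiles (the 0.25 and 0.75 quantiles) give the lower and upper quartiles of the expected outcomes, while the 50th percentile, the median, gives a central estimate.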
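To make the percentile reading concrete, here is a small NumPy sketch. It treats the integers 1 to 100 as a stand-in for a forecast distribution at one timestamp (an illustrative assumption, not TimeGPT output), reads off the 10th, 25th, 50th and 75th percentiles, and checks that 90% of the values lie above the 10th percentile.

```python
import numpy as np

# Illustrative stand-in for a forecast distribution at one future timestamp
samples = np.arange(1, 101)

# np.quantile takes levels in [0, 1]; 0.10 is the 10th percentile, and so on
q10, q25, q50, q75 = np.quantile(samples, [0.10, 0.25, 0.50, 0.75])

# A 10th-percentile forecast means roughly 90% of outcomes lie above it
share_above_q10 = np.mean(samples > q10)

# The 0.5 quantile is the median; q25 and q75 bound the middle half (IQR)
iqr = q75 - q25

print(f"q10={q10:.2f}, share above q10={share_above_q10:.2f}")
print(f"q25={q25:.2f}, median={q50:.2f}, q75={q75:.2f}, IQR={iqr:.2f}")
```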

TimeGPT uses conformal prediction to produce the quantiles.
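The tutorial does not spell out TimeGPT's exact conformal procedure. As a rough illustration only, the following sketch applies textbook split conformal prediction: absolute errors from past calibration windows serve as nonconformity scores, the finite-sample-corrected rank is taken for the central interval whose edge is q, and the point forecast is shifted down (q < 0.5) or up (q > 0.5) by that score. The helper name `conformal_quantiles` and the symmetric construction are assumptions. The symmetry does mirror the output tables in this tutorial, where `TimeGPT-q-50` equals `TimeGPT` and each pair q, 1 - q sits equally far from it.

```python
import math
import numpy as np

def conformal_quantiles(point_forecast, calib_residuals, quantiles):
    """Hypothetical helper: symmetric split-conformal quantiles around a point forecast.

    point_forecast: point predictions for the forecast horizon.
    calib_residuals: actual minus predicted errors from past (calibration) windows.
    quantiles: quantile levels q strictly between 0 and 1.
    """
    # Nonconformity scores: absolute calibration errors, sorted ascending
    scores = np.sort(np.abs(np.asarray(calib_residuals, dtype=float)))
    n = len(scores)
    y_hat = np.asarray(point_forecast, dtype=float)
    result = {}
    for q in quantiles:
        # Coverage of the central interval whose lower (q < 0.5) or upper (q > 0.5) edge is q
        level = abs(2 * q - 1)
        if level == 0:
            result[q] = y_hat.copy()  # q = 0.5: the point forecast itself
            continue
        # Split-conformal rank ceil((n + 1) * level); strictly, a rank above n means an
        # unbounded interval, so capping at n is a simplification for small calibration sets
        k = min(math.ceil((n + 1) * level), n)
        width = scores[k - 1]
        result[q] = y_hat - width if q < 0.5 else y_hat + width
    return result

# Toy usage: one-step point forecast of 100 and five past errors (made-up numbers)
bands = conformal_quantiles([100.0], [1.0, -2.0, 3.0, -4.0, 5.0], [0.1, 0.3, 0.5, 0.7, 0.9])
for q, values in bands.items():
    print(q, values)
```

With more calibration errors the ranks become finer, so the gap between neighbouring quantiles shrinks toward the spread of the true error distribution.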

1. Import packages

First, we import the required packages and initialize the Nixtla client.

import pandas as pd
from IPython.display import display  # used to show the cross-validation plots
from nixtla import NixtlaClient

nixtla_client = NixtlaClient(
    # defaults to os.environ.get("NIXTLA_API_KEY")
    api_key='my_api_key_provided_by_nixtla'
)

👍 Use an Azure AI endpoint

To use an Azure AI endpoint, set the base_url argument:

nixtla_client = NixtlaClient(base_url="your azure ai endpoint", api_key="your api_key")

2. Load data

df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/air_passengers.csv')
df.head()
    timestamp  value
0  1949-01-01    112
1  1949-02-01    118
2  1949-03-01    132
3  1949-04-01    129
4  1949-05-01    121
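The forecast logs report `Inferred freq: MS`, i.e. month-start data. A quick way to confirm locally that the timestamps form a regular monthly series is `pandas.infer_freq`. This is a sanity-check sketch using the five rows shown above, not a claim about how TimeGPT infers the frequency internally.

```python
import pandas as pd

# The five rows printed by df.head() above
sample = pd.DataFrame({
    "timestamp": ["1949-01-01", "1949-02-01", "1949-03-01", "1949-04-01", "1949-05-01"],
    "value": [112, 118, 132, 129, 121],
})

# read_csv leaves the dates as strings; convert before inferring the frequency
sample["timestamp"] = pd.to_datetime(sample["timestamp"])

# "MS" is the pandas alias for month-start frequency
freq = pd.infer_freq(sample["timestamp"])
print(freq)
```

On the full dataset the same check is `pd.infer_freq(pd.to_datetime(df["timestamp"]))`; it returns `None` if the dates are irregular.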

3. Forecast with quantiles

When using TimeGPT for time series forecasting, you can set the quantiles you want to predict. Here's how:

quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
timegpt_quantile_fcst_df = nixtla_client.forecast(
    df=df, h=12, 
    quantiles=quantiles, 
    time_col='timestamp', target_col='value',
)
timegpt_quantile_fcst_df.head()
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
    timestamp     TimeGPT  TimeGPT-q-10  TimeGPT-q-20  TimeGPT-q-30  TimeGPT-q-40  TimeGPT-q-50  TimeGPT-q-60  TimeGPT-q-70  TimeGPT-q-80  TimeGPT-q-90
0  1961-01-01  437.837952    431.987091    435.043799    435.384363    436.402155    437.837952    439.273749    440.291541    440.632104    443.688812
1  1961-02-01  426.062744    412.704956    414.832837    416.042432    421.719196    426.062744    430.406293    436.083057    437.292651    439.420532
2  1961-03-01  463.116577    437.412564    444.234985    446.420233    450.705762    463.116577    475.527393    479.812921    481.998169    488.820590
3  1961-04-01  478.244507    448.726837    455.428375    465.570038    469.879114    478.244507    486.609900    490.918976    501.060638    507.762177
4  1961-05-01  505.646484    478.409872    493.154315    497.990848    499.138708    505.646484    512.154260    513.302121    518.138654    532.883096

📘 Available models in Azure AI

If you are using an Azure AI endpoint, be sure to set model="azureai":

nixtla_client.forecast(..., model="azureai")

For the public API, we support two models: timegpt-1 and timegpt-1-long-horizon.

timegpt-1 is used by default. For how and when to use timegpt-1-long-horizon, see this tutorial.

TimeGPT returns the forecast for each quantile q in a column named TimeGPT-q-{int(100 * q)}.
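Because the column names follow this pattern, you can build them from the same `quantiles` list instead of typing them out. The sketch below does that and computes the width of the central 80% band (between q-10 and q-90) for the first two forecast rows copied from the output table above; the `band_80_width` column name is made up for illustration.

```python
import pandas as pd

quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# Column name for each quantile, following the TimeGPT-q-{int(100 * q)} pattern
q_cols = [f"TimeGPT-q-{int(100 * q)}" for q in quantiles]

# First two forecast rows copied from the output above (only the columns we need)
fcst = pd.DataFrame({
    "timestamp": pd.to_datetime(["1961-01-01", "1961-02-01"]),
    "TimeGPT": [437.837952, 426.062744],
    "TimeGPT-q-10": [431.987091, 412.704956],
    "TimeGPT-q-90": [443.688812, 439.420532],
})

# The 10th and 90th quantiles bound a central 80% band
fcst["band_80_width"] = fcst["TimeGPT-q-90"] - fcst["TimeGPT-q-10"]

print(q_cols)
print(fcst[["timestamp", "band_80_width"]])
```

In the rows shown in the table, TimeGPT-q-50 equals the TimeGPT point forecast, and each pair such as q-10 and q-90 lies the same distance from it (about 5.85 on 1961-01-01), so the band width is twice that distance.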

nixtla_client.plot(
    df, timegpt_quantile_fcst_df, 
    time_col='timestamp', target_col='value',
)

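Note that the right choice of quantile (or quantiles) depends on your specific use case. For high-stakes predictions, you might lean towards more conservative quantiles in the tail you are worried about, such as the 10th or 20th percentile when unusually low values are the risk (or the 80th or 90th when unusually high values are), so that you are prepared for worst-case scenarios. On the other hand, if over-preparing is costly, you might choose a quantile closer to the median, such as the 50th percentile, to balance caution and efficiency.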

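For example, if you manage inventory for a retail business during a big sale event, stocking to a higher quantile of forecast demand (such as the 80th or 90th percentile) can help you avoid running out of stock, even if it means slightly overstocking. If, however, you are scheduling staff for a restaurant, you might opt for a quantile closer to the median, so that you have enough staff on hand without significantly overstaffing.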

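Ultimately, the choice comes down to understanding the balance between risk and cost in your particular context, and TimeGPT's quantile forecasts let you tailor your strategy to that balance.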
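One standard way to turn that risk/cost balance into a number is the newsvendor critical ratio from inventory theory: if each unit of shortfall costs `cost_under` and each unit of excess costs `cost_over`, the cost-minimising stock level is the `cost_under / (cost_under + cost_over)` quantile of demand. This is a textbook rule that assumes linear per-unit costs, not a TimeGPT feature, and the helper names and cost figures below are hypothetical. The sketch maps the ratio to the nearest requested quantile column.

```python
def critical_ratio(cost_under: float, cost_over: float) -> float:
    """Newsvendor critical ratio: the cost-minimising quantile level of demand."""
    return cost_under / (cost_under + cost_over)

def pick_quantile_column(cost_under: float, cost_over: float,
                         quantiles: list[float]) -> tuple[float, str]:
    """Hypothetical helper: nearest requested quantile and its TimeGPT column name."""
    target = critical_ratio(cost_under, cost_over)
    q = min(quantiles, key=lambda level: abs(level - target))
    return q, f"TimeGPT-q-{int(100 * q)}"

quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# Retail sale: a lost sale (cost 4) hurts more than an unsold unit (cost 1)
print(pick_quantile_column(4.0, 1.0, quantiles))

# Staffing with roughly symmetric costs of being short or over
print(pick_quantile_column(1.0, 1.0, quantiles))
```

The first call lands on the 80th percentile and the second on the median, which matches the intuition in the retail and restaurant examples.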

Historical forecast

You can also compute quantile forecasts for the historical period by adding the add_history=True argument, as follows:

timegpt_quantile_fcst_df = nixtla_client.forecast(
    df=df, h=12, 
    quantiles=quantiles, 
    time_col='timestamp', target_col='value',
    add_history=True,
)
timegpt_quantile_fcst_df.head()
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Calling Historical Forecast Endpoint...
    timestamp     TimeGPT  TimeGPT-q-10  TimeGPT-q-20  TimeGPT-q-30  TimeGPT-q-40  TimeGPT-q-50  TimeGPT-q-60  TimeGPT-q-70  TimeGPT-q-80  TimeGPT-q-90
0  1951-01-01  135.483673    111.937768    120.020593    125.848879    130.828935    135.483673    140.138411    145.118467    150.946753    159.029579
1  1951-02-01  144.442398    120.896493    128.979318    134.807604    139.787660    144.442398    149.097136    154.077192    159.905478    167.988304
2  1951-03-01  157.191910    133.646004    141.728830    147.557116    152.537172    157.191910    161.846648    166.826703    172.654990    180.737815
3  1951-04-01  148.769363    125.223458    133.306284    139.134570    144.114625    148.769363    153.424102    158.404157    164.232443    172.315269
4  1951-05-01  140.472946    116.927041    125.009866    130.838152    135.818208    140.472946    145.127684    150.107740    155.936026    164.018852
nixtla_client.plot(
    df, timegpt_quantile_fcst_df, 
    time_col='timestamp', target_col='value',
)
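With add_history=True, the returned frame has quantile columns for past timestamps but, judging from the columns printed above, no `value` column, so checking calibration means joining back to `df`. A minimal sketch, assuming that layout: for each quantile q, compute the share of actual values at or below the `TimeGPT-q-{int(100 * q)}` column. For a well-calibrated forecast that share should be close to q. The helper `empirical_coverage` and the toy frames are hypothetical.

```python
import pandas as pd

def empirical_coverage(df, fcst_df, quantiles, time_col="timestamp", target_col="value"):
    """Hypothetical helper: share of observed values at or below each quantile forecast."""
    actuals = df[[time_col, target_col]].copy()
    actuals[time_col] = pd.to_datetime(actuals[time_col])  # read_csv gives strings
    fcst = fcst_df.copy()
    fcst[time_col] = pd.to_datetime(fcst[time_col])
    # Inner join keeps only timestamps that have both an actual and a forecast
    merged = actuals.merge(fcst, on=time_col, how="inner")
    rows = []
    for q in quantiles:
        col = f"TimeGPT-q-{int(100 * q)}"
        covered = (merged[target_col] <= merged[col]).mean()
        rows.append({"quantile": q, "coverage": float(covered)})
    return pd.DataFrame(rows)

# Toy frames (made-up numbers) with the same column layout as the tutorial
toy_df = pd.DataFrame({
    "timestamp": ["2000-01-01", "2000-02-01", "2000-03-01"],
    "value": [10.0, 20.0, 30.0],
})
toy_fcst = pd.DataFrame({
    "timestamp": pd.to_datetime(["2000-01-01", "2000-02-01", "2000-03-01", "2000-04-01"]),
    "TimeGPT-q-10": [5.0, 25.0, 20.0, 0.0],
    "TimeGPT-q-90": [15.0, 25.0, 35.0, 40.0],
})
coverage = empirical_coverage(toy_df, toy_fcst, [0.1, 0.9])
print(coverage)
```

On the tutorial's data the call would be `empirical_coverage(df, timegpt_quantile_fcst_df, quantiles)`; future rows without an actual value are dropped by the inner join.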

Cross-validation

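The quantiles argument can also be passed to the cross_validation method, which lets you compare TimeGPT's performance across different time windows and different quantiles.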

timegpt_cv_quantile_fcst_df = nixtla_client.cross_validation(
    df=df, 
    h=12, 
    n_windows=5,
    quantiles=quantiles, 
    time_col='timestamp', 
    target_col='value',
)
timegpt_cv_quantile_fcst_df.head()
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
cutoffs = timegpt_cv_quantile_fcst_df['cutoff'].unique()
for cutoff in cutoffs:
    fig = nixtla_client.plot(
        df.tail(100), 
        timegpt_cv_quantile_fcst_df.query('cutoff == @cutoff').drop(columns=['cutoff', 'value']),
        time_col='timestamp', 
        target_col='value'
    )
    display(fig)
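To compare performance across windows and quantiles, a common score for quantile forecasts is the pinball (quantile) loss: for level q and error y - ŷ_q it is max(q(y - ŷ_q), (q - 1)(y - ŷ_q)), averaged over points, and lower is better. The sketch below assumes the cross-validation output carries `cutoff` and `value` columns next to the quantile columns, as the plotting loop above implies; the helper names and toy frame are hypothetical.

```python
import numpy as np
import pandas as pd

def pinball_loss(y, y_q, q):
    """Mean pinball (quantile) loss of quantile-q forecasts y_q against actuals y."""
    diff = np.asarray(y, dtype=float) - np.asarray(y_q, dtype=float)
    # Under-forecasts (diff > 0) are weighted by q, over-forecasts by 1 - q
    return float(np.mean(np.maximum(q * diff, (q - 1) * diff)))

def pinball_by_cutoff(cv_df, quantiles, target_col="value"):
    """Hypothetical helper: mean pinball loss for each (cutoff, quantile) pair."""
    rows = []
    for cutoff, window in cv_df.groupby("cutoff"):
        for q in quantiles:
            col = f"TimeGPT-q-{int(100 * q)}"
            rows.append({
                "cutoff": cutoff,
                "quantile": q,
                "pinball_loss": pinball_loss(window[target_col], window[col], q),
            })
    return pd.DataFrame(rows)

# Toy cross-validation frame (made-up numbers) with two cutoffs
toy_cv = pd.DataFrame({
    "cutoff": ["c1", "c1", "c2"],
    "value": [10.0, 10.0, 10.0],
    "TimeGPT-q-90": [8.0, 12.0, 10.0],
})
scores = pinball_by_cutoff(toy_cv, [0.9])
print(scores)
```

On the tutorial's output the call would be `pinball_by_cutoff(timegpt_cv_quantile_fcst_df, quantiles)`, giving one row per window and quantile.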