1. 导入包

首先,我们导入所需的包并初始化 Nixtla 客户端

import pandas as pd
from nixtla import NixtlaClient
from utilsforecast.losses import mae, mse
from utilsforecast.evaluation import evaluate
nixtla_client = NixtlaClient(
    # defaults to os.environ.get("NIXTLA_API_KEY")
    api_key = 'my_api_key_provided_by_nixtla'
)

👍 使用 Azure AI 端点

要使用 Azure AI 端点,请记住同时设置 base_url 参数

nixtla_client = NixtlaClient(base_url="you azure ai endpoint", api_key="your api_key")

2. 加载数据

df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/air_passengers.csv')
df.head()
时间戳
01949-01-01112
11949-02-01118
21949-03-01132
31949-04-01129
41949-05-01121

现在,我们将数据分割成训练集和测试集,以便在改变 finetune_depth 时衡量模型的性能。

train = df[:-24]
test = df[-24:]

接下来,我们微调 TimeGPT 并改变 finetune_depth 来衡量其对性能的影响。

3. 使用 finetune_depth 进行微调

📘 Azure AI 中可用的模型

如果您正在使用 Azure AI 端点,请确保设置 model="azureai"

nixtla_client.forecast(..., model="azureai")

对于公共 API,我们支持两种模型:timegpt-1timegpt-1-long-horizon

默认使用 timegpt-1。关于如何以及何时使用 timegpt-1-long-horizon,请参阅本教程

如上所述,finetune_depth 控制 TimeGPT 模型中有多少参数会在您的特定数据集上进行微调。如果值设置为 1,则只有少量参数会进行微调。设置为 5 意味着模型的所有参数都将进行微调。

对于具有复杂模式的大型数据集,使用较大的 finetune_depth 值可以带来更好的性能。然而,它也可能导致过拟合,在这种情况下,预测的准确性可能会下降,我们将在下面的小实验中看到这一点。

depths = [1, 2, 3, 4, 5]

test = test.copy()

for depth in depths:
    preds_df = nixtla_client.forecast(
    df=train, 
    h=24, 
    finetune_steps=5,
    finetune_depth=depth,
    time_col='timestamp', 
    target_col='value')

    preds = preds_df['TimeGPT'].values

    test.loc[:,f'TimeGPT_depth{depth}'] = preds
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Querying model metadata...
WARNING:nixtla.nixtla_client:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
WARNING:nixtla.nixtla_client:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
WARNING:nixtla.nixtla_client:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
WARNING:nixtla.nixtla_client:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
WARNING:nixtla.nixtla_client:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
test['unique_id'] = 0

evaluation = evaluate(test, metrics=[mae, mse], time_col="timestamp", target_col="value")
evaluation
唯一 ID指标TimeGPT_depth1TimeGPT_depth2TimeGPT_depth3TimeGPT_depth4TimeGPT_depth5
00mae22.67554017.90896321.31851824.74509628.734302
10mse677.254283461.320852676.202126991.8353591119.722602

从上面的结果可以看出,finetune_depth 为 2 时取得了最好的结果,因为它具有最低的 MAE 和 MSE。

另请注意,当 finetune_depth 为 4 和 5 时,性能会下降,这是过拟合的明显迹象。

因此,请记住微调可能需要一些反复试验。您可能需要根据您的特定需求和数据的复杂性来调整 finetune_steps 的数量和 finetune_depth 的级别。通常,对于大型数据集,较高的 finetune_depth 效果更好。在本教程中,由于我们预测的是一个非常短的单个序列,增加深度导致了过拟合。

建议在微调过程中监控模型的性能并根据需要进行调整。请注意,更多的 finetune_steps 和更大的 finetune_depth 值可能会导致更长的训练时间,如果管理不当,可能会导致过拟合。