快速入门(分布式)
使用 MLForecast 进行分布式训练的最小示例
DistributedMLForecast 类是一个高层抽象,它封装了管道中的所有步骤(预处理、拟合模型和计算预测)并以分布式方式应用它们。
使用 DistributedMLForecast(相对于 MLForecast)所需的不同之处在于:
- 你需要搭建一个集群。我们目前支持 dask、ray 和 spark。
- 你的数据需要是分布式集合(dask, ray 或 spark dataframe)。
- 你需要使用在所选框架中实现了分布式训练的模型,例如 spark 中 LightGBM 的 SynapseML。
Dask
客户端设置
这里我们定义一个连接到 dask.distributed.LocalCluster 的客户端,但它也可以是任何其他类型的集群。
数据设置
对于 dask,数据必须是 dask.dataframe.DataFrame。你需要确保每个时间序列只在一个分区中,并且建议你的分区数与 worker 数相同。如果分区数多于 worker 数,请确保设置 num_threads=1 以避免嵌套并行。
所需的输入格式与 MLForecast 相同,只是它是一个 dask.dataframe.DataFrame 而不是 pandas.Dataframe。
| unique_id | ds | y | static_0 | static_1 | sin1_7 | sin2_7 | cos1_7 | cos2_7 | |
|---|---|---|---|---|---|---|---|---|---|
| npartitions=10 | |||||||||
| id_00 | object | datetime64[ns] | float64 | int64 | int64 | float32 | float32 | float32 | float32 | 
| id_10 | … | … | … | … | … | … | … | … | … | 
| … | … | … | … | … | … | … | … | … | … | 
| id_90 | … | … | … | … | … | … | … | … | … | 
| id_99 | … | … | … | … | … | … | … | … | … | 
模型
为了执行分布式预测,我们需要使用能够利用 dask 进行分布式训练的模型。目前的实现位于 DaskLGBMForecast 和 DaskXGBForecast 中,它们只是原生实现的包装器。
训练
有了模型后,我们通过定义特征来实例化一个 DistributedMLForecast 对象。然后可以在此对象上调用 fit 方法,并传入 dask dataframe。
有了拟合好的模型后,我们可以计算未来 7 个时间步的预测。
预测
| unique_id | ds | DaskXGBForecast | DaskLGBMForecast | |
|---|---|---|---|---|
| 0 | id_00 | 2002-09-27 00:00:00 | 21.722841 | 21.725511 | 
| 1 | id_00 | 2002-09-28 00:00:00 | 84.918194 | 84.606362 | 
| 2 | id_00 | 2002-09-29 00:00:00 | 162.067624 | 163.36802 | 
| 3 | id_00 | 2002-09-30 00:00:00 | 249.001477 | 246.422894 | 
| 4 | id_00 | 2002-10-01 00:00:00 | 317.149512 | 315.538403 | 
保存和加载
训练好模型后,可以使用 DistributedMLForecast.save 方法保存用于推理的artifact。请记住,如果在远程集群上,应将远程存储(如 S3)设置为目标路径。
mlforecast 使用 fsspec 处理不同的文件系统,因此如果使用 s3,例如,还需要安装 s3fs。如果使用 pip,只需包含 aws extra,例如 pip install 'mlforecast[aws,dask]',这将安装使用 dask 进行分布式训练和保存到 S3 所需的依赖项。如果使用 conda,则需要手动安装它们 (conda install dask fsspec fugue s3fs)。
保存好预测对象后,可以通过指定保存路径以及用于执行分布式计算的引擎(在本例中为 dask 客户端)来重新加载。
我们可以验证此对象是否产生相同的结果。
转换为本地
另一种存储分布式预测对象的方法是先将其转换为本地对象,然后保存。请记住,为了这样做,存储在远程的系列数据都必须拉取到一台机器(dask 中的调度器,spark 中的驱动器等)上,因此必须确保它能容纳在内存中,它应该消耗大约目标列大小的两倍(通过在 fit 方法中使用 keep_last_n 参数可以进一步减少)。
交叉验证
| unique_id | ds | DaskXGBForecast | DaskLGBMForecast | cutoff | y | |
|---|---|---|---|---|---|---|
| 61 | id_04 | 2002-08-21 00:00:00 | 68.3418 | 68.944539 | 2002-08-15 00:00:00 | 69.699857 | 
| 83 | id_15 | 2002-08-29 00:00:00 | 199.315403 | 199.663555 | 2002-08-15 00:00:00 | 206.082864 | 
| 103 | id_17 | 2002-08-21 00:00:00 | 156.822598 | 158.018246 | 2002-08-15 00:00:00 | 152.227984 | 
| 61 | id_24 | 2002-08-21 00:00:00 | 136.598356 | 136.576865 | 2002-08-15 00:00:00 | 138.559945 | 
| 36 | id_33 | 2002-08-24 00:00:00 | 95.6072 | 96.249354 | 2002-08-15 00:00:00 | 102.068997 | 
Spark
会话设置
数据设置
对于 spark,数据必须是 pyspark DataFrame。你需要确保每个时间序列只在一个分区中(例如可以使用 repartitionByRange 来实现),并且建议你的分区数与 worker 数相同。如果分区数多于 worker 数,请确保设置 num_threads=1 以避免嵌套并行。
所需的输入格式与 MLForecast 相同,即它应至少包含一个 id 列、一个时间列和一个目标列。
模型
为了执行分布式预测,我们需要使用能够利用 spark 进行分布式训练的模型。目前的实现位于 SparkLGBMForecast 和 SparkXGBForecast 中,它们只是原生实现的包装器。
训练
预测
| unique_id | ds | SparkLGBMForecast | SparkXGBForecast | |
|---|---|---|---|---|
| 0 | id_00 | 2002-09-27 | 15.053577 | 18.631477 | 
| 1 | id_00 | 2002-09-28 | 93.010037 | 93.796269 | 
| 2 | id_00 | 2002-09-29 | 160.120148 | 159.582315 | 
| 3 | id_00 | 2002-09-30 | 250.445885 | 250.861651 | 
| 4 | id_00 | 2002-10-01 | 323.335956 | 321.564089 | 
保存和加载
训练好模型后,可以使用 DistributedMLForecast.save 方法保存用于推理的artifact。请记住,如果在远程集群上,应将远程存储(如 S3)设置为目标路径。
mlforecast 使用 fsspec 处理不同的文件系统,因此如果使用 s3,例如,还需要安装 s3fs。如果使用 pip,只需包含 aws extra,例如 pip install 'mlforecast[aws,spark]',这将安装使用 spark 进行分布式训练和保存到 S3 所需的依赖项。如果使用 conda,则需要手动安装它们 (conda install fsspec fugue pyspark s3fs)。
保存好预测对象后,可以通过指定保存路径以及用于执行分布式计算的引擎(在本例中为 spark 会话)来重新加载。
我们可以验证此对象是否产生相同的结果。
转换为本地
另一种存储分布式预测对象的方法是先将其转换为本地对象,然后保存。请记住,为了这样做,存储在远程的系列数据都必须拉取到一台机器(dask 中的调度器,spark 中的驱动器等)上,因此必须确保它能容纳在内存中,它应该消耗大约目标列大小的两倍(通过在 fit 方法中使用 keep_last_n 参数可以进一步减少)。
交叉验证
| unique_id | ds | SparkLGBMForecast | SparkXGBForecast | cutoff | y | |
|---|---|---|---|---|---|---|
| 0 | id_03 | 2002-08-18 | 3.272922 | 3.348874 | 2002-08-15 | 3.060194 | 
| 1 | id_09 | 2002-08-20 | 402.718091 | 402.622501 | 2002-08-15 | 398.784459 | 
| 2 | id_25 | 2002-08-22 | 87.189811 | 86.891632 | 2002-08-15 | 82.731377 | 
| 3 | id_06 | 2002-08-21 | 20.416790 | 20.478502 | 2002-08-15 | 19.196394 | 
| 4 | id_22 | 2002-08-23 | 357.718513 | 360.502024 | 2002-08-15 | 394.770699 | 
Ray
会话设置
数据设置
对于 ray,数据必须是 ray DataFrame。建议你的分区数与 worker 数相同。如果分区数多于 worker 数,请确保设置 num_threads=1 以避免嵌套并行。
所需的输入格式与 MLForecast 相同,即它应至少包含一个 id 列、一个时间列和一个目标列。
模型
ray 集成允许包含 lightgbm (RayLGBMRegressor) 和 xgboost (RayXGBRegressor)。
训练
要控制使用 Ray 的分区数量,我们需要在 DistributedMLForecast 中包含 num_partitions。
预测
| unique_id | ds | RayLGBMForecast | RayXGBForecast | |
|---|---|---|---|---|
| 0 | id_00 | 2002-09-27 | 15.232455 | 10.38301 | 
| 1 | id_00 | 2002-09-28 | 92.288994 | 92.531502 | 
| 2 | id_00 | 2002-09-29 | 160.043472 | 160.722885 | 
| 3 | id_00 | 2002-09-30 | 250.03212 | 252.821899 | 
| 4 | id_00 | 2002-10-01 | 322.905182 | 324.387695 | 
保存和加载
训练好模型后,可以使用 DistributedMLForecast.save 方法保存用于推理的artifact。请记住,如果在远程集群上,应将远程存储(如 S3)设置为目标路径。
mlforecast 使用 fsspec 处理不同的文件系统,因此如果使用 s3,例如,还需要安装 s3fs。如果使用 pip,只需包含 aws extra,例如 pip install 'mlforecast[aws,ray]',这将安装使用 ray 进行分布式训练和保存到 S3 所需的依赖项。如果使用 conda,则需要手动安装它们 (conda install fsspec fugue ray s3fs)。
保存好预测对象后,可以通过指定保存路径以及用于执行分布式计算的引擎(在本例中为字符串 'ray')来重新加载。
我们可以验证此对象是否产生相同的结果。
转换为本地
另一种存储分布式预测对象的方法是先将其转换为本地对象,然后保存。请记住,为了这样做,存储在远程的系列数据都必须拉取到一台机器(dask 中的调度器,spark 中的驱动器等)上,因此必须确保它能容纳在内存中,它应该消耗大约目标列大小的两倍(通过在 fit 方法中使用 keep_last_n 参数可以进一步减少)。
交叉验证
| unique_id | ds | RayLGBMForecast | RayXGBForecast | cutoff | y | |
|---|---|---|---|---|---|---|
| 0 | id_05 | 2002-09-21 | 108.285187 | 108.619698 | 2002-09-12 | 108.726387 | 
| 1 | id_08 | 2002-09-16 | 26.287956 | 26.589603 | 2002-09-12 | 27.980670 | 
| 2 | id_08 | 2002-09-25 | 83.210945 | 84.194962 | 2002-09-12 | 86.344885 | 
| 3 | id_11 | 2002-09-22 | 416.994843 | 417.106506 | 2002-09-12 | 425.434661 | 
| 4 | id_16 | 2002-09-14 | 377.916382 | 375.421600 | 2002-09-12 | 400.361977 | 

