来源

TimeSeriesLoader

 TimeSeriesLoader (dataset, **kwargs)

*TimeSeriesLoader DataLoader. 源代码.

对 PyTorch 的 DataLoader 进行了微小改动。它结合了数据集和采样器,并提供了给定数据集的可迭代对象。

~torch.utils.data.DataLoader 支持 map-style 和 iterable-style 数据集,具有单进程或多进程加载、自定义加载顺序以及可选的自动批处理 (collation) 和内存锁定功能。

参数
batch_size: (int, 可选): 每批加载多少样本 (默认: 1)。
shuffle: (bool, 可选): 设置为 True 以在每个 epoch 重新打乱数据 (默认: False)。
sampler: (Sampler 或 Iterable, 可选): 定义从数据集中抽取样本的策略。
可以是任何实现了 __len__ 方法的 Iterable。如果指定了此参数,则不得指定 shuffle
*


来源

BaseTimeSeriesDataset

 BaseTimeSeriesDataset (temporal_cols, max_size:int, min_size:int,
                        y_idx:int, static=None, static_cols=None)

*一个表示 :class:Dataset 的抽象类。

所有表示从键到数据样本的映射的数据集都应该继承此类。所有子类都应该重写 :meth:__getitem__ 方法,以支持根据给定键获取数据样本。子类也可以选择性地重写 :meth:__len__ 方法,许多 :class:~torch.utils.data.Sampler 实现和 :class:~torch.utils.data.DataLoader 的默认选项都期望此方法返回数据集的大小。子类还可以选择性地实现 :meth:__getitems__ 方法,以加快批量样本加载。此方法接受一个批量样本索引列表,并返回样本列表。

.. 注意:: :class:~torch.utils.data.DataLoader 默认构造一个索引采样器,它产生整数索引。要使其适用于带有非整数索引/键的 map-style 数据集,必须提供自定义采样器。*


来源

LocalFilesTimeSeriesDataset

 LocalFilesTimeSeriesDataset (files_ds:List[str], temporal_cols,
                              id_col:str, time_col:str, target_col:str,
                              last_times, indices, max_size:int,
                              min_size:int, y_idx:int, static=None,
                              static_cols=None)

*一个表示 :class:Dataset 的抽象类。

所有表示从键到数据样本的映射的数据集都应该继承此类。所有子类都应该重写 :meth:__getitem__ 方法,以支持根据给定键获取数据样本。子类也可以选择性地重写 :meth:__len__ 方法,许多 :class:~torch.utils.data.Sampler 实现和 :class:~torch.utils.data.DataLoader 的默认选项都期望此方法返回数据集的大小。子类还可以选择性地实现 :meth:__getitems__ 方法,以加快批量样本加载。此方法接受一个批量样本索引列表,并返回样本列表。

.. 注意:: :class:~torch.utils.data.DataLoader 默认构造一个索引采样器,它产生整数索引。要使其适用于带有非整数索引/键的 map-style 数据集,必须提供自定义采样器。*


来源

TimeSeriesDataset

 TimeSeriesDataset (temporal, temporal_cols, indptr, y_idx:int,
                    static=None, static_cols=None)

*一个表示 :class:Dataset 的抽象类。

所有表示从键到数据样本的映射的数据集都应该继承此类。所有子类都应该重写 :meth:__getitem__ 方法,以支持根据给定键获取数据样本。子类也可以选择性地重写 :meth:__len__ 方法,许多 :class:~torch.utils.data.Sampler 实现和 :class:~torch.utils.data.DataLoader 的默认选项都期望此方法返回数据集的大小。子类还可以选择性地实现 :meth:__getitems__ 方法,以加快批量样本加载。此方法接受一个批量样本索引列表,并返回样本列表。

.. 注意:: :class:~torch.utils.data.DataLoader 默认构造一个索引采样器,它产生整数索引。要使其适用于带有非整数索引/键的 map-style 数据集,必须提供自定义采样器。*


来源

TimeSeriesDataModule

 TimeSeriesDataModule (dataset:__main__.BaseTimeSeriesDataset,
                       batch_size=32, valid_batch_size=1024,
                       drop_last=False, shuffle_train=True,
                       **dataloaders_kwargs)

*DataModule 标准化了训练集、验证集、测试集的划分、数据准备和转换。主要优点是在不同模型之间保持数据划分、数据准备和转换的一致性。

示例:

import lightning.pytorch as L
import torch.utils.data as data
from pytorch_lightning.demos.boring_classes import RandomDataset

class MyDataModule(L.LightningDataModule):
    def prepare_data(self):
        # download, IO, etc. Useful with shared filesystems
        # only called on 1 GPU/TPU in distributed
        ...

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        dataset = RandomDataset(1, 100)
        self.train, self.val, self.test = data.random_split(
            dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return data.DataLoader(self.train)

    def val_dataloader(self):
        return data.DataLoader(self.val)

    def test_dataloader(self):
        return data.DataLoader(self.test)

    def on_exception(self, exception):
        # clean up state after the trainer faced an exception
        ...

    def teardown(self):
        # clean up state after the trainer stops, delete files...
        # called on every process in DDP
        ...*
# To test correct future_df wrangling of the `update_df` method
# We are checking that we are able to recover the AirPassengers dataset
# using the dataframe or splitting it into parts and initializing.