自定义Dataset

对于一个新的模型,如果没有合适的已实现数据集供使用,则需要设计一个新的数据集。本文档用于介绍如何在LibCity中开发一个新的数据集。

创建新的Dataset类

首先,我们创建的数据集应该继承自AbstractDataset或它的子类。

例如,若想为交通状态预测任务开发一个名为NewDataset的数据集,需要将代码写入libcity/data/dataset/目录下的newdataset.py文件中。

在如下代码中,我们的NewDataset继承了AbstractDataset类的子类TrafficStatePointDataset

from libcity.data.dataset import TrafficStatePointDataset

class NewDatasets(TrafficStatePointDataset):
    def __init__(self, config):
        super().__init__(config)
        pass

或者可以直接继承AbstractDataset 类,代码如下。

from libcity.data.dataset import AbstractDataset

class NewDatasets(AbstractDataset):
    def __init__(self, config):
        pass

重写对应方法

AbstractDataset中的函数get_data()被用来分割数据,并获取3个数据加载器train_dataloadereval_dataloadertest_dataloader。若想要获取数据加载器,需要调用函数libcity.data.utils.generate_dataloader来从输入数据的列表中获取,其中生成的数据加载器包含Batch对象。

AbstractDataset中的函数get_data_feature()被用来返回一些数据集的特征,这些特征将被模型和执行器使用。

AbstractDataset的子类中定义的其他接口将不在此描述。

如果没有合适的数据集类,那么你可以重写上面提到的相应接口。

样例 1

样例1用来演示如何直接继承AbstractDataset并重写函数get_data_feature()以返回我们想要的一些值。

from libcity.data.dataset import AbstractDataset

class NewDatasets(AbstractDataset):
    def __init__(self, config):
        pass
    
    def get_data_feature(self):
        return {"scaler": self.scaler, "adj_mx": self.adj_mx,
                "num_nodes": self.num_nodes, "feature_dim": self.feature_dim,
                "output_dim": self.output_dim}

样例 2

样例2用来演示如何继承AbstractDataset的子类并重写其中的一个方法(_load_rel)。

from libcity.data.dataset import TrafficStatePointDataset

class NewDatasets(TrafficStatePointDataset):
    def __init__(self, config):
        super().__init__(config)
        pass
    
    # We will rewrite this method which is used to calculate `self.adj_mx` based on the atmoic file `rel_file.rel`.
    def _load_rel(self):
        relfile = pd.read_csv(self.data_path + self.rel_file + '.rel')
        self.adj_mx = np.zeros((len(self.geo_ids), len(self.geo_ids)))
        self.adj_mx[:] = 1  # set all one

样例 3

样例3用来解释如何继承AbstractDataset的子类,并返回从原始数据文件转化而来的,含有不同键的Batch。具体来说,我们打算返回三个键值对,其中键包括:XYZ。这只是一个例子,更多的细节,你可以参考TrafficStateCPTDataset,它有四个键的Batch

from libcity.data.dataset import TrafficStateDataset

class NewDatasets(TrafficStateDataset):
    def __init__(self, config):
        super().__init__(config)
        # the origin code
        # self.feature_name = {'X': 'float', 'y': 'float'}
        # the modified code
        self.feature_name = {'X': 'float', 'Y': 'float', 'Z': 'int'}
        pass
    
    def get_data(self):
        # Load datset for the keys x,y,z, generate [x|y|z]_[train|val|test].
        # ... (implement it yourself)
        # Data normalization using self.scaler.
        # ... (implement it yourself)
        # Aggregate X, Y, Z into a list.
        # The i-th element in train_data(a list) is a tuple, consists of x_train[i], y_train[i] and z_train[i].
        train_data = list(zip(x_train, y_train, z_train))
        eval_data = list(zip(x_val, y_val, z_val))
        test_data = list(zip(x_test, y_test, z_test))
        # Get dataloader by libcity.data.utils.generate_dataloader.
        self.train_dataloader, self.eval_dataloader, self.test_dataloader = \
            generate_dataloader(train_data, eval_data, test_data, self.feature_name,
                                self.batch_size, self.num_workers)
        # Return the dataloader
        return self.train_dataloader, self.eval_dataloader, self.test_dataloader