libcity.data.dataset.dataset_subclass.cstn_dataset

class libcity.data.dataset.dataset_subclass.cstn_dataset.CSTNDataset(config)[source]

Bases: libcity.data.dataset.traffic_state_grid_od_dataset.TrafficStateGridOdDataset

_generate_data()[source]

加载数据文件(.gridod)和外部数据(.ext),以X, W, y的形式返回

Returns

tuple contains:

X(np.ndarray): 模型输入数据,(num_samples, input_length, …, feature_dim)

W(np.ndarray): 模型外部数据,(num_samples, input_length, ext_dim) y(np.ndarray): 模型输出数据,(num_samples, output_length, …, feature_dim)

Return type

tuple

get_data()[source]

返回数据的DataLoader,包括训练数据、测试数据、验证数据

Returns

tuple contains:

train_dataloader: Dataloader composed of Batch (class)

eval_dataloader: Dataloader composed of Batch (class)

test_dataloader: Dataloader composed of Batch (class)

Return type

tuple

get_data_feature()[source]

返回数据集特征,scaler是归一化方法,adj_mx是邻接矩阵,num_nodes是网格的个数, len_row是网格的行数,len_column是网格的列数, feature_dim是输入数据的维度,output_dim是模型输出的维度

Returns

包含数据集的相关特征的字典

Return type

dict