libcity.data.dataset.dataset_subclass.stdn_dataset

class libcity.data.dataset.dataset_subclass.stdn_dataset.STDNDataset(config)[源代码]

基类:libcity.data.dataset.traffic_state_datatset.TrafficStateDataset

_generate_train_val_test()[源代码]

加载数据集,并划分训练集、测试集、验证集,并缓存数据集

_load_cache_train_val_test()[源代码]

加载之前缓存好的训练集、测试集、验证集

_split_train_val_test_stdn(x, y, flatten_att_nbhd_inputs, flatten_att_flow_inputs, att_lstm_inputs, nbhd_inputs, flow_inputs, lstm_inputs)[源代码]

划分训练集、测试集、验证集,并缓存数据集

参数
  • x (np.ndarray) – 输入数据 (num_samples, input_length, …, feature_dim)

  • y (np.ndarray) – 输出数据 (num_samples, input_length, …, feature_dim)

返回

tuple contains:

x_train: (num_samples, input_length, …, feature_dim)

y_train: (num_samples, input_length, …, feature_dim)

x_val: (num_samples, input_length, …, feature_dim)

y_val: (num_samples, input_length, …, feature_dim)

x_test: (num_samples, input_length, …, feature_dim)

y_test: (num_samples, input_length, …, feature_dim)

返回类型

tuple

get_data()[源代码]

返回数据的DataLoader,包括训练数据、测试数据、验证数据

返回

tuple contains:

train_dataloader: Dataloader composed of Batch (class)

eval_dataloader: Dataloader composed of Batch (class)

test_dataloader: Dataloader composed of Batch (class)

返回类型

tuple

get_data_feature()[源代码]

返回数据集特征,子类必须实现这个函数,返回必要的特征

返回

包含数据集的相关特征的字典

返回类型

dict