libcity.data.dataset.dataset_subclass.stdn_dataset

class libcity.data.dataset.dataset_subclass.stdn_dataset.STDNDataset(config)[source]

Bases: libcity.data.dataset.traffic_state_datatset.TrafficStateDataset

_generate_train_val_test()[source]

加载数据集,并划分训练集、测试集、验证集,并缓存数据集

_load_cache_train_val_test()[source]

加载之前缓存好的训练集、测试集、验证集

_split_train_val_test_stdn(x, y, flatten_att_nbhd_inputs, flatten_att_flow_inputs, att_lstm_inputs, nbhd_inputs, flow_inputs, lstm_inputs)[source]

划分训练集、测试集、验证集,并缓存数据集

Parameters
  • x (np.ndarray) – 输入数据 (num_samples, input_length, …, feature_dim)

  • y (np.ndarray) – 输出数据 (num_samples, input_length, …, feature_dim)

Returns

tuple contains:

x_train: (num_samples, input_length, …, feature_dim)

y_train: (num_samples, input_length, …, feature_dim)

x_val: (num_samples, input_length, …, feature_dim)

y_val: (num_samples, input_length, …, feature_dim)

x_test: (num_samples, input_length, …, feature_dim)

y_test: (num_samples, input_length, …, feature_dim)

Return type

tuple

get_data()[source]

返回数据的DataLoader,包括训练数据、测试数据、验证数据

Returns

tuple contains:

train_dataloader: Dataloader composed of Batch (class)

eval_dataloader: Dataloader composed of Batch (class)

test_dataloader: Dataloader composed of Batch (class)

Return type

tuple

get_data_feature()[source]

返回数据集特征,子类必须实现这个函数,返回必要的特征

Returns

包含数据集的相关特征的字典

Return type

dict