libcity.data.dataset.dataset_subclass.ccrnn_dataset

class libcity.data.dataset.dataset_subclass.ccrnn_dataset.CCRNNDataset(config)

Bases: libcity.data.dataset.traffic_state_point_dataset.TrafficStatePointDataset

_generate_data()

Loads the data files (.dyna/.grid/.od/.gridod) and the external data (.ext), merges the two, and returns them as X, y.

Returns

A tuple containing:

x (np.ndarray): model input data, shape (num_samples, input_length, …, feature_dim)

y (np.ndarray): model output data, shape (num_samples, output_length, …, feature_dim)

Return type

tuple
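A minimal sketch of how (x, y) pairs with these shapes can be produced, assuming the traffic states have already been loaded into a (T, N, F) NumPy array and the external data into a (T, E) array. The function name `generate_windows`, its arguments, and the choice to broadcast the external features to every location before concatenating them along the feature axis are illustrative assumptions, not the library's code.

```python
import numpy as np


def generate_windows(data, input_length, output_length, ext=None):
    """Hypothetical helper: slide a window over time to build (x, y) pairs.

    data: (T, N, F) array of traffic states (time, locations, features).
    ext:  optional (T, E) array of external features; if given, it is
          broadcast to every location and appended along the feature axis.
    """
    if ext is not None:
        t, n, _ = data.shape
        # Repeat the per-timestep external vector for each of the N locations.
        ext_b = np.broadcast_to(ext[:, None, :], (t, n, ext.shape[1]))
        data = np.concatenate([data, ext_b], axis=-1)
    xs, ys = [], []
    # The last valid start leaves room for input_length + output_length steps.
    for start in range(data.shape[0] - input_length - output_length + 1):
        mid = start + input_length
        xs.append(data[start:mid])               # past window -> model input
        ys.append(data[mid:mid + output_length])  # following window -> target
    return np.stack(xs), np.stack(ys)
```

With T time steps this yields T - input_length - output_length + 1 samples; for example, 10 steps with input_length=3 and output_length=2 give 6 samples.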

_generate_train_val_test()

Loads the dataset, splits it into training, validation, and test sets, and caches the result.

Returns

A tuple containing:

x_train: (num_samples, input_length, …, feature_dim)

y_train: (num_samples, output_length, …, feature_dim)

x_val: (num_samples, input_length, …, feature_dim)

y_val: (num_samples, output_length, …, feature_dim)

x_test: (num_samples, input_length, …, feature_dim)

y_test: (num_samples, output_length, …, feature_dim)

Return type

tuple
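The caching step can be sketched with NumPy's `np.savez_compressed`, which stores several named arrays in a single `.npz` file. The helper name `cache_splits` and the key names are assumptions chosen to mirror the tuple above; the actual cache path and layout used by the library are not described here.

```python
import os

import numpy as np


def cache_splits(path, x_train, y_train, x_val, y_val, x_test, y_test):
    """Hypothetical helper: save the six split arrays into one .npz file."""
    # Create the cache directory if it does not exist yet.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    # Each keyword becomes a named array inside the compressed archive.
    np.savez_compressed(
        path,
        x_train=x_train, y_train=y_train,
        x_val=x_val, y_val=y_val,
        x_test=x_test, y_test=y_test,
    )
```

Storing all six arrays in one archive keeps the split consistent: they are always written and read back together.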

_load_cache_train_val_test()

Loads the previously cached training, validation, and test sets.

Returns

A tuple containing:

x_train: (num_samples, input_length, …, feature_dim)

y_train: (num_samples, output_length, …, feature_dim)

x_val: (num_samples, input_length, …, feature_dim)

y_val: (num_samples, output_length, …, feature_dim)

x_test: (num_samples, input_length, …, feature_dim)

y_test: (num_samples, output_length, …, feature_dim)

Return type

tuple
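A matching sketch for reading such a cache back with `np.load`, assuming the arrays were stored in a single `.npz` file under the hypothetical keys `x_train`, `y_train`, `x_val`, `y_val`, `x_test`, and `y_test`. The helper name `load_cached_splits` is illustrative; it returns the arrays in the order documented above.

```python
import numpy as np


def load_cached_splits(path):
    """Hypothetical helper: load the six split arrays from a .npz cache."""
    keys = ("x_train", "y_train", "x_val", "y_val", "x_test", "y_test")
    # Indexing the NpzFile reads each array into memory; the context
    # manager then closes the underlying file.
    with np.load(path) as cat_data:
        return tuple(cat_data[k] for k in keys)
```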

_load_rel()

Builds the adjacency matrix from the grid structure: each cell is adjacent to the 8 cells surrounding it.

Returns

self.adj_mx, the N × N adjacency matrix

Return type

np.ndarray
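The 8-neighbour rule can be illustrated as follows. This is a sketch under stated assumptions: cells are numbered row-major (cell (i, j) becomes index i * num_cols + j), self-loops are not included, and the function name `grid_adjacency` is hypothetical.

```python
import numpy as np


def grid_adjacency(num_rows, num_cols):
    """Hypothetical helper: N x N adjacency (N = num_rows * num_cols) for a grid.

    Cell (i, j) gets flat index i * num_cols + j (row-major) and is linked to
    each of its up to 8 surrounding cells; edge and corner cells have fewer.
    """
    n = num_rows * num_cols
    adj = np.zeros((n, n), dtype=np.float32)
    # The 8 relative offsets of the surrounding cells (excluding the cell itself).
    offsets = [(di, dj) for di in (-1, 0, 1) for dj in (-1, 0, 1)
               if (di, dj) != (0, 0)]
    for i in range(num_rows):
        for j in range(num_cols):
            for di, dj in offsets:
                ni, nj = i + di, j + dj
                # Skip neighbours that fall outside the grid.
                if 0 <= ni < num_rows and 0 <= nj < num_cols:
                    adj[i * num_cols + j, ni * num_cols + nj] = 1.0
    return adj
```

On a 3 × 3 grid, the centre cell has 8 neighbours, an edge cell has 5, and a corner cell has 3; the resulting matrix is symmetric.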

_split_train_val_test(x, y, df=None)

Splits the data into training, validation, and test sets, and caches the result.

Parameters
  • x (np.ndarray) – input data (num_samples, input_length, …, feature_dim)

  • y (np.ndarray) – output data (num_samples, output_length, …, feature_dim)

Returns

A tuple containing:

x_train: (num_samples, input_length, …, feature_dim)

y_train: (num_samples, output_length, …, feature_dim)

x_val: (num_samples, input_length, …, feature_dim)

y_val: (num_samples, output_length, …, feature_dim)

x_test: (num_samples, input_length, …, feature_dim)

y_test: (num_samples, output_length, …, feature_dim)

Return type

tuple
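A minimal sketch of a chronological split, assuming the split is controlled by a training ratio and a validation ratio (named `train_rate` and `eval_rate` here for illustration), that the remainder is used for testing, and that samples are not shuffled. The function name is hypothetical, and the sketch ignores the optional `df` argument, whose role is not described here.

```python
import numpy as np


def split_train_val_test(x, y, train_rate=0.7, eval_rate=0.1):
    """Hypothetical helper: split samples chronologically along axis 0.

    The remainder (1 - train_rate - eval_rate) becomes the test set.
    """
    num_samples = x.shape[0]
    test_rate = 1 - train_rate - eval_rate
    num_test = round(num_samples * test_rate)
    num_train = round(num_samples * train_rate)
    num_val = num_samples - num_test - num_train
    # No shuffling: the earliest samples train, the latest samples test.
    x_train, y_train = x[:num_train], y[:num_train]
    x_val = x[num_train:num_train + num_val]
    y_val = y[num_train:num_train + num_val]
    # Slice from an explicit start so num_test == 0 yields an empty set.
    x_test, y_test = x[num_samples - num_test:], y[num_samples - num_test:]
    return x_train, y_train, x_val, y_val, x_test, y_test
```

Splitting by time rather than at random keeps future observations out of the training set, which matters for forecasting tasks.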