libcity.data.dataset.dataset_subclass.stresnet_dataset

class libcity.data.dataset.dataset_subclass.stresnet_dataset.STResNetDataset(config)[source]

Bases: libcity.data.dataset.traffic_state_grid_dataset.TrafficStateGridDataset, libcity.data.dataset.traffic_state_cpt_dataset.TrafficStateCPTDataset

STResNet外部数据源代码只用了ext_y, 没有用到ext_x!

_get_external_array(timestamp_list, ext_data=None, previous_ext=False, ext_time=True)[source]

根据时间戳数组,获取对应时间的外部特征

Parameters
  • timestamp_list – 时间戳序列

  • ext_data – 外部数据

  • previous_ext – 是否是用过去时间段的外部数据,因为对于预测的时间段Y, 一般没有真实的外部数据,所以用前一个时刻的数据,多步预测则用提前多步的数据

Returns

External data shape is (len(timestamp_list), ext_dim)

Return type

np.ndarray

_load_ext_data(ts_x, ts_y)[source]

加载对应时间的外部数据(.ext)

Parameters
  • ts_x – 输入数据X对应的时间戳,shape: (num_samples, T_c+T_p+T_t)

  • ts_y – 输出数据Y对应的时间戳,shape:(num_samples, )

Returns

tuple contains:

ext_x(np.ndarray): 对应时间的外部数据, shape: (num_samples, T_c+T_p+T_t, ext_dim), ext_y(np.ndarray): 对应时间的外部数据, shape: (num_samples, ext_dim)

Return type

tuple

get_data_feature()[source]

返回数据集特征,scaler是归一化方法,adj_mx是邻接矩阵,num_nodes是网格的个数, len_row是网格的行数,len_column是网格的列数, feature_dim是输入数据的维度,output_dim是模型输出的维度

Returns

包含数据集的相关特征的字典

Return type

dict