libcity.data.dataset.dataset_subclass.astgcn_dataset

class libcity.data.dataset.dataset_subclass.astgcn_dataset.ASTGCNDataset(config)[source]

Bases: libcity.data.dataset.traffic_state_point_dataset.TrafficStatePointDataset

_generate_input_data(df)[source]

Slices the input according to the global parameters len_closeness/len_period/len_trend and produces the inputs the model needs.

Parameters

df (np.ndarray) – input data, shape: (len_time, …, feature_dim)

Returns

tuple contains:

sources (np.ndarray): model input data, shape: (num_samples, Tw+Td+Th, …, feature_dim)

targets (np.ndarray): model output data, shape: (num_samples, Tp, …, feature_dim)

Return type

tuple
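The shapes above can be illustrated with a minimal sketch. It is not LibCity's implementation: the helper name `stack_samples`, the trend/period/closeness concatenation order, and the convention that an unused segment is passed as `None` are all assumptions made for illustration.

```python
import numpy as np

def stack_samples(samples):
    """Stack per-sample segments into (sources, targets) arrays.

    `samples` is a list of (trend, period, closeness, target) tuples;
    a segment that is not used is given as None (an assumption of this sketch).
    """
    sources, targets = [], []
    for trend, period, closeness, target in samples:
        # Keep only the segments that are present, then join them on the time axis.
        parts = [seg for seg in (trend, period, closeness) if seg is not None]
        sources.append(np.concatenate(parts, axis=0))
        targets.append(target)
    # New leading axis: num_samples.
    return np.stack(sources), np.stack(targets)

# Toy example: 5 samples, 3 nodes, 1 feature, output window of 12 slices,
# no trend segment, one period window and two closeness windows.
rng = np.random.default_rng(0)
toy_samples = [
    (None, rng.random((12, 3, 1)), rng.random((24, 3, 1)), rng.random((12, 3, 1)))
    for _ in range(5)
]
sources, targets = stack_samples(toy_samples)
print(sources.shape, targets.shape)
```

Here the time axis of `sources` has length 12 + 24 = 36, which plays the role of Tw+Td+Th, and the time axis of `targets` plays the role of Tp.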

_get_sample_indices(data_sequence, label_start_idx)[source]

Using the global parameters len_closeness/len_period/len_trend, finds the input data for the prediction target segment: [label_start_idx: label_start_idx+output_window)

Parameters
  • data_sequence (np.ndarray) – input data, shape: (len_time, …, feature_dim)

  • label_start_idx (int) – index of the first time slice of the prediction target

Returns

tuple contains:

trend_sample: input data 1, (len_trend * self.output_window, …, feature_dim)

period_sample: input data 2, (len_period * self.output_window, …, feature_dim)

closeness_sample: input data 3, (len_closeness * self.output_window, …, feature_dim)

target: output data, (self.output_window, …, feature_dim)

Return type

tuple
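As a rough illustration of how one segment and its target could be cut out of `data_sequence`, the sketch below concatenates a list of (start_idx, end_idx) windows along the time axis. The helper `gather_windows` and the hard-coded window list are hypothetical and only serve to show the resulting shapes; the actual method may assemble its samples differently.

```python
import numpy as np

def gather_windows(data_sequence, windows):
    """Concatenate data_sequence[start:end] for each (start, end) along axis 0."""
    return np.concatenate([data_sequence[s:e] for s, e in windows], axis=0)

# Toy data: 100 time slices, 3 nodes, 1 feature.
data = np.arange(100 * 3 * 1, dtype=float).reshape(100, 3, 1)
output_window = 12
label_start_idx = 50

# Hypothetical windows for len_closeness = 2, each output_window slices long.
closeness_windows = [(26, 38), (38, 50)]
closeness_sample = gather_windows(data, closeness_windows)

# The prediction target is the half-open range starting at label_start_idx.
target = data[label_start_idx:label_start_idx + output_window]

print(closeness_sample.shape)  # (len_closeness * output_window, 3, 1)
print(target.shape)            # (output_window, 3, 1)
```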

_search_data(sequence_length, label_start_idx, num_for_predict, num_of_depend, units)[source]

Finds the index positions of the data according to the global parameters len_closeness/len_period/len_trend.

Parameters
  • sequence_length (int) – total length of the historical data

  • label_start_idx (int) – index of the time slice where prediction starts

  • num_for_predict (int) – length of the time-slice sequence to predict

  • num_of_depend (int) – len_trend/len_period/len_closeness

  • units (int) – length of trend/period/closeness (in hours)

Returns

Array of start-end intervals, list[(start_idx, end_idx)]

Return type

list
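One plausible reading of this index search is sketched below as a standalone function. It is an assumption-laden sketch rather than the library's code: the name `search_windows` is hypothetical, the number of time slices per hour is passed explicitly as `points_per_hour` (a parameter not present in the signature above), and returning `None` when there is not enough history is a choice made for this sketch.

```python
def search_windows(sequence_length, label_start_idx, num_for_predict,
                   num_of_depend, units, points_per_hour):
    """Return [(start_idx, end_idx), ...] in chronological order, or None."""
    # The target segment itself must fit inside the sequence.
    if label_start_idx + num_for_predict > sequence_length:
        return None
    windows = []
    for i in range(1, num_of_depend + 1):
        # Step back i whole units; one unit is `units` hours of time slices.
        start_idx = label_start_idx - points_per_hour * units * i
        end_idx = start_idx + num_for_predict
        if start_idx < 0:
            return None  # not enough history for this window
        windows.append((start_idx, end_idx))
    # Windows were collected newest first; reverse to oldest first.
    return windows[::-1]

# Two closeness windows (units = 1 hour) with 12 slices per hour.
print(search_windows(100, 50, 12, 2, 1, 12))
# Two period windows (units = 24 hours) with 12 slices per hour.
print(search_windows(1000, 600, 12, 2, 24, 12))
```

Under these assumptions, each window has the same length as the prediction segment and starts a whole number of units before `label_start_idx`.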

get_data_feature()[source]

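Returns the dataset features: scaler is the normalization method, adj_mx is the adjacency matrix, num_nodes is the number of nodes, feature_dim is the dimension of the input data, output_dim is the dimension of the model output, and len_closeness/len_period/len_trend are the lengths of the three data segments.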

Returns

Dictionary containing the relevant features of the dataset

Return type

dict
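A consumer of this dictionary might read it roughly as follows. The dictionary below is a hand-built stand-in with made-up values, using only the keys named in the description above; `input_length` is a hypothetical helper, and the real dictionary may hold additional entries.

```python
import numpy as np

# Hypothetical stand-in for the dict returned by get_data_feature();
# the values are invented for illustration only.
data_feature = {
    "scaler": None,          # a normalization object in the real dataset
    "adj_mx": np.eye(4),     # adjacency matrix, shape (num_nodes, num_nodes)
    "num_nodes": 4,
    "feature_dim": 1,
    "output_dim": 1,
    "len_closeness": 2,
    "len_period": 1,
    "len_trend": 0,
}

def input_length(feature, output_window):
    """Total input time steps across the three segments for one sample."""
    segments = feature["len_closeness"] + feature["len_period"] + feature["len_trend"]
    return segments * output_window

print(input_length(data_feature, output_window=12))
```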