libcity.model.road_representation.ChebConv

class libcity.model.road_representation.ChebConv.ChebConv(config, data_feature)[源代码]

基类:libcity.model.abstract_traffic_state_model.AbstractTrafficStateModel

calculate_loss(batch)[源代码]
参数

batch – dict, need key ‘node_features’, ‘node_labels’, ‘mask’

Returns:

forward(batch)[源代码]

自回归任务

参数

batch – dict, need key ‘node_features’ contains tensor shape=(N, feature_dim)

返回

N, feature_dim

返回类型

torch.tensor

predict(batch)[源代码]
参数

batch – dict, need key ‘node_features’

返回

torch.tensor

training: bool
class libcity.model.road_representation.ChebConv.ChebConvModule(num_nodes, max_diffusion_step, adj_mx, device, input_dim, output_dim, filter_type)[源代码]

基类:torch.nn.modules.module.Module

路网表征模型的基类并不统一 图卷积,将N*C的输入矩阵映射成N*F的输出矩阵,其中邻接矩阵形状N*N。

forward(x)[源代码]

GONV :param x: (N, input_dim) :return: (N, output_dim)

training: bool