libcity.model.road_representation.ChebConv

class libcity.model.road_representation.ChebConv.ChebConv(config, data_feature)

Bases: libcity.model.abstract_traffic_state_model.AbstractTrafficStateModel

calculate_loss(batch)
Parameters

batch – dict; must contain the keys ‘node_features’, ‘node_labels’, and ‘mask’

Returns

the training loss for the batch

Return type

torch.Tensor
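To illustrate how a ‘mask’ key can restrict a loss to a subset of nodes, here is a minimal NumPy sketch of a masked mean squared error. The helper name masked_mse, the use of MSE as the criterion, and the assumption that ‘mask’ is a boolean array over nodes are all illustrative choices, not a description of the library's own loss function.

```python
import numpy as np


def masked_mse(y_pred: np.ndarray, y_true: np.ndarray, mask: np.ndarray) -> float:
    """Mean squared error over the nodes selected by a boolean mask.

    Hypothetical helper: MSE and the boolean-mask convention are assumptions.
    y_pred, y_true: arrays of shape (N, feature_dim)
    mask: boolean array of shape (N,); True marks nodes that count toward the loss
    """
    diff = y_pred[mask] - y_true[mask]  # shape (M, feature_dim), M = mask.sum()
    return float(np.mean(diff ** 2))


pred = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
true = np.array([[1.0, 2.0], [3.0, 5.0], [0.0, 0.0]])
mask = np.array([True, True, False])  # the third node is excluded
print(masked_mse(pred, true, mask))  # 0.25
```

Masking before averaging means excluded nodes affect neither the numerator nor the count, so the loss is a mean over the selected nodes only.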

forward(batch)

Autoregressive task.

Parameters

batch – dict; must contain the key ‘node_features’, a tensor of shape (N, feature_dim)

Returns

tensor of shape (N, feature_dim)

Return type

torch.Tensor
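The sketch below builds a batch dict with the keys these methods document. It assumes that, in the autoregressive setting, ‘node_labels’ is simply a copy of ‘node_features’ and that ‘mask’ is a boolean array over nodes; the helper make_batch and its train_ratio and seed parameters are hypothetical. NumPy arrays stand in for the torch tensors the model actually consumes.

```python
import numpy as np


def make_batch(node_features: np.ndarray, train_ratio: float = 0.7, seed: int = 0) -> dict:
    """Build a batch dict with 'node_features', 'node_labels', and 'mask'.

    Assumption: labels equal the input features (autoregressive reconstruction).
    train_ratio and seed are hypothetical knobs for choosing which nodes are masked in.
    """
    n = node_features.shape[0]
    rng = np.random.default_rng(seed)
    num_selected = int(round(n * train_ratio))
    mask = np.zeros(n, dtype=bool)
    mask[rng.permutation(n)[:num_selected]] = True  # random subset of nodes
    return {
        "node_features": node_features,       # (N, feature_dim)
        "node_labels": node_features.copy(),  # (N, feature_dim), same values as input
        "mask": mask,                         # (N,) boolean
    }


feats = np.arange(20, dtype=float).reshape(10, 2)  # N=10 nodes, feature_dim=2
batch = make_batch(feats)
print(batch["node_features"].shape, batch["mask"].sum())  # (10, 2) 7
```

Copying the features into ‘node_labels’ keeps the targets independent of any in-place changes to the inputs.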

predict(batch)
Parameters

batch – dict; must contain the key ‘node_features’

Returns

torch.Tensor

training: bool
class libcity.model.road_representation.ChebConv.ChebConvModule(num_nodes, max_diffusion_step, adj_mx, device, input_dim, output_dim, filter_type)

Bases: torch.nn.modules.module.Module

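Road representation models do not share a unified base class. This module is a graph convolution that maps an input matrix of shape (N, C) to an output matrix of shape (N, F), where the adjacency matrix has shape (N, N).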
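A minimal NumPy sketch of a ChebNet-style Chebyshev graph convolution that maps (N, C) to (N, F). It assumes a symmetric adjacency matrix, the scaled Laplacian L~ = 2L/λ_max - I built from the symmetric normalized Laplacian L, and one (C, F) weight matrix per polynomial term, so the output is the sum over k of T_k(L~) X W_k with T_0 X = X, T_1 X = L~X, and T_k X = 2L~ T_{k-1} X - T_{k-2} X. The helper names scaled_laplacian and cheb_conv are hypothetical, and treating the number of terms K as roughly what max_diffusion_step controls (with filter_type choosing the support matrix) is an assumption, not a description of ChebConvModule's internals.

```python
import numpy as np


def scaled_laplacian(adj: np.ndarray) -> np.ndarray:
    """Return L~ = 2 L / lambda_max - I for the symmetric normalized Laplacian L.

    Assumes adj is symmetric (an undirected graph) with at least one edge.
    """
    n = adj.shape[0]
    deg = adj.sum(axis=1)
    # D^{-1/2}, with isolated nodes (degree 0) mapped to 0 instead of dividing by zero.
    d_inv_sqrt = np.where(deg > 0, 1.0 / np.sqrt(np.maximum(deg, 1e-12)), 0.0)
    lap = np.eye(n) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    lambda_max = np.linalg.eigvalsh(lap).max()
    return 2.0 * lap / lambda_max - np.eye(n)


def cheb_conv(x: np.ndarray, adj: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Chebyshev graph convolution.

    x: (N, C) node features; adj: (N, N) adjacency;
    weights: (K, C, F), one weight matrix per Chebyshev term (K >= 1).
    Returns an (N, F) array.
    """
    l_tilde = scaled_laplacian(adj)
    k = weights.shape[0]
    t_prev, t_curr = x, l_tilde @ x  # T_0 x = x, T_1 x = L~ x
    out = t_prev @ weights[0]
    if k > 1:
        out = out + t_curr @ weights[1]
    for i in range(2, k):
        t_next = 2.0 * l_tilde @ t_curr - t_prev  # T_k x = 2 L~ T_{k-1} x - T_{k-2} x
        out = out + t_next @ weights[i]
        t_prev, t_curr = t_curr, t_next
    return out


adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)  # 3-node path graph
x = np.random.default_rng(0).normal(size=(3, 4))     # N=3, C=4
w = np.random.default_rng(1).normal(size=(3, 4, 2))  # K=3, F=2
print(cheb_conv(x, adj, w).shape)  # (3, 2)
```

Because each term needs only the two previous ones, the recurrence avoids forming explicit powers of L~, and rescaling the spectrum into [-1, 1] is what makes the Chebyshev basis well behaved.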

forward(x)

GCONV: applies the graph convolution.

Parameters

x – tensor of shape (N, input_dim)

Returns

tensor of shape (N, output_dim)

training: bool