libcity.model.road_representation.LINE¶
-
class
libcity.model.road_representation.LINE.LINE(config, data_feature)[源代码]¶ 基类:
libcity.model.abstract_traffic_state_model.AbstractTrafficStateModel-
calculate_loss(batch)[源代码]¶ 输入一个batch的数据,返回训练过程的loss,也就是需要定义一个loss函数
- 参数
batch (Batch) – a batch of input
- 返回
return training loss
- 返回类型
torch.tensor
-
forward(I, J)[源代码]¶ - 参数
I – origin indices of node i ; (B,)
J – origin indices of node j ; (B,)
- 返回
[u_j^T * u_i for (i,j) in zip(I, J)]; (B,) elif order == ‘second’:
[u’_j^T * v_i for (i,j) in zip(I, J)]; (B,)
- 返回类型
if order == ‘first’
-
training: bool¶
-