libcity.model.road_representation.LINE

class libcity.model.road_representation.LINE.LINE(config, data_feature)

Bases: libcity.model.abstract_traffic_state_model.AbstractTrafficStateModel

calculate_loss(batch)

Takes a batch of data as input and returns the training loss. This is where the model's loss function is defined.

Parameters

batch (Batch) – a batch of input

Returns

the training loss

Return type

torch.Tensor
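A minimal sketch of a LINE-style negative-sampling loss follows. It assumes the model scores each observed pair (i, j) and K sampled negative pairs per edge with forward(), and turns those scores into one scalar loss. The helper name line_negative_sampling_loss and the (B, K) layout of the negative scores are hypothetical. LibCity's calculate_loss may encode negatives in the batch differently.

```python
import torch
import torch.nn.functional as F


def line_negative_sampling_loss(pos_scores: torch.Tensor,
                                neg_scores: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not LibCity's code).

    pos_scores: (B,)   forward() scores for observed edges (i, j)
    neg_scores: (B, K) forward() scores for K sampled negative pairs per edge
    """
    # Observed pairs should score high: log sigmoid(score)
    pos_term = F.logsigmoid(pos_scores)
    # Sampled negatives should score low: log sigmoid(-score), summed over K
    neg_term = F.logsigmoid(-neg_scores).sum(dim=-1)
    # Negate (we minimize) and average over the batch -> scalar loss
    return -(pos_term + neg_term).mean()


# With all-zero scores each term is -log 2, so the loss is (1 + K) * log 2
loss = line_negative_sampling_loss(torch.zeros(3), torch.zeros(3, 4))
print(round(loss.item(), 4))  # 3.4657
```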

forward(I, J)
Parameters
  • I – origin indices of node i; (B,)

  • J – origin indices of node j; (B,)

Returns

if order == 'first':
  [u_j^T * u_i for (i,j) in zip(I, J)]; (B,)

elif order == 'second':
  [u'_j^T * v_i for (i,j) in zip(I, J)]; (B,)

training: bool
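The scores returned by forward are the logits of the two LINE objectives. The block below writes the definitions from the LINE paper (Tang et al., 2015) in this page's symbols. It is a reading aid and assumes LibCity follows the paper; this page does not state which form calculate_loss actually uses.

```latex
% First order: forward returns the logit u_j^\top u_i of the joint probability
p_1(i, j) = \sigma\!\left(u_j^\top u_i\right) = \frac{1}{1 + \exp\!\left(-u_j^\top u_i\right)}

% Second order: forward returns {u'_j}^\top v_i, the logit of the conditional
p_2(j \mid i) = \frac{\exp\!\left({u'_j}^\top v_i\right)}{\sum_{k=1}^{|V|} \exp\!\left({u'_k}^\top v_i\right)}

% Negative sampling replaces \log p_2(j \mid i) with K sampled nodes k_1, \dots, k_K
% (drawn in the paper from P_n(v) \propto d_v^{3/4}):
\log \sigma\!\left({u'_j}^\top v_i\right) + \sum_{n=1}^{K} \log \sigma\!\left(-{u'_{k_n}}^\top v_i\right)
```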

class libcity.model.road_representation.LINE.LINE_FIRST(num_nodes, output_dim)

Bases: torch.nn.modules.module.Module

forward(i, j)
Parameters
  • i – indices of i; (B,)

  • j – indices of j; (B,)

Returns

v_i^T * v_j; (B,)

get_embeddings()
training: bool
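A minimal sketch of a first-order module with this interface, assuming one nn.Embedding table shared by both endpoints so that forward(i, j) returns v_i^T * v_j. The class FirstOrderSketch and attribute vertex_emb are hypothetical. get_embeddings is assumed to return the learned (num_nodes, output_dim) matrix; the real method may return it in another form, such as a NumPy array.

```python
import torch
import torch.nn as nn


class FirstOrderSketch(nn.Module):
    """Hypothetical first-order LINE module (not LibCity's code)."""

    def __init__(self, num_nodes: int, output_dim: int):
        super().__init__()
        # One vector v per node, used for both endpoints of an edge
        self.vertex_emb = nn.Embedding(num_nodes, output_dim)

    def forward(self, i: torch.Tensor, j: torch.Tensor) -> torch.Tensor:
        # i, j: (B,) node indices -> (B,) inner products v_i^T v_j
        return (self.vertex_emb(i) * self.vertex_emb(j)).sum(dim=-1)

    def get_embeddings(self) -> torch.Tensor:
        # Detached (num_nodes, output_dim) matrix, one row per node
        return self.vertex_emb.weight.detach()


model = FirstOrderSketch(num_nodes=10, output_dim=8)
scores = model(torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5]))
print(scores.shape)  # torch.Size([3])
```

Because both endpoints read the same table, the first-order score is symmetric: forward(i, j) equals forward(j, i).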

class libcity.model.road_representation.LINE.LINE_SECOND(num_nodes, output_dim)

Bases: torch.nn.modules.module.Module

forward(I, J)
Parameters
  • I – indices of i; (B,)

  • J – indices of j; (B,)

Returns

[v_i^T * u_j for (i,j) in zip(I, J)]; (B,)

get_embeddings()
training: bool
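A minimal sketch of a second-order module with this interface follows. It assumes two nn.Embedding tables: vertex vectors v for the source node i and separate context vectors u for the neighbor j, so forward(I, J) returns v_i^T * u_j. The class SecondOrderSketch and attributes vertex_emb and context_emb are hypothetical. Exporting only the vertex table from get_embeddings is also an assumption.

```python
import torch
import torch.nn as nn


class SecondOrderSketch(nn.Module):
    """Hypothetical second-order LINE module (not LibCity's code)."""

    def __init__(self, num_nodes: int, output_dim: int):
        super().__init__()
        # Vertex vectors v: used when a node is the source i
        self.vertex_emb = nn.Embedding(num_nodes, output_dim)
        # Context vectors u: used when a node is the neighbor j
        self.context_emb = nn.Embedding(num_nodes, output_dim)

    def forward(self, I: torch.Tensor, J: torch.Tensor) -> torch.Tensor:
        # I, J: (B,) node indices -> (B,) inner products v_i^T u_j
        return (self.vertex_emb(I) * self.context_emb(J)).sum(dim=-1)

    def get_embeddings(self) -> torch.Tensor:
        # Assumption: the vertex vectors serve as the node representation
        return self.vertex_emb.weight.detach()


torch.manual_seed(0)
model = SecondOrderSketch(num_nodes=10, output_dim=8)
I, J = torch.tensor([0, 1]), torch.tensor([1, 0])
print(model(I, J))  # scores for 0 -> 1 and 1 -> 0; generally not equal
```

Using separate vertex and context tables makes the score directional. This matches the second-order idea that nodes sharing many neighbors should end up with similar vertex vectors.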