libcity.model.abstract_traffic_state_model¶
-
class
libcity.model.abstract_traffic_state_model.AbstractTrafficStateModel(config, data_feature)[源代码]¶ 基类:
libcity.model.abstract_model.AbstractModel-
calculate_loss(batch)[源代码]¶ 输入一个batch的数据,返回训练过程的loss,也就是需要定义一个loss函数
- 参数
batch (Batch) – a batch of input
- 返回
return training loss
- 返回类型
torch.tensor
-
predict(batch)[源代码]¶ 输入一个batch的数据,返回对应的预测值,一般应该是**多步预测**的结果,一般会调用nn.Moudle的forward()方法
- 参数
batch (Batch) – a batch of input
- 返回
predict result of this batch
- 返回类型
torch.tensor
-
training: bool¶
-