libcity.model.abstract_model

class libcity.model.abstract_model.AbstractModel(config, data_feature)

Bases: torch.nn.modules.module.Module
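Concrete models subclass AbstractModel and implement calculate_loss and predict. The following is a minimal sketch of that pattern under stated assumptions. The AbstractModel defined here is a local stand-in that only mirrors the documented signatures; it is not the LibCity source. LinearModel, the "feature_dim"/"output_dim" keys of data_feature, and the "X"/"y" keys of the batch (a plain dict standing in for LibCity's Batch) are all hypothetical.

```python
import torch
from torch import nn


class AbstractModel(nn.Module):
    """Local stand-in mirroring the documented interface (not the LibCity source)."""

    def __init__(self, config, data_feature):
        super().__init__()

    def predict(self, batch):
        raise NotImplementedError

    def calculate_loss(self, batch):
        raise NotImplementedError


class LinearModel(AbstractModel):
    """Hypothetical model: one linear layer over the last feature axis."""

    def __init__(self, config, data_feature):
        super().__init__(config, data_feature)
        # Assumed data_feature keys, chosen for illustration only.
        self.linear = nn.Linear(data_feature["feature_dim"], data_feature["output_dim"])

    def forward(self, batch):
        return self.linear(batch["X"])

    def predict(self, batch):
        # Prediction result for this batch.
        return self.forward(batch)

    def calculate_loss(self, batch):
        # Scalar training loss (0-dim torch.Tensor) for this batch.
        return nn.functional.mse_loss(self.predict(batch), batch["y"])


model = LinearModel(config={}, data_feature={"feature_dim": 3, "output_dim": 1})
batch = {"X": torch.randn(4, 3), "y": torch.randn(4, 1)}  # dict stand-in for Batch
loss = model.calculate_loss(batch)
```

Because AbstractModel derives from torch.nn.Module, a subclass gets parameter registration, device moves and train()/eval() for free; only the two batch-level methods need to be supplied.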

calculate_loss(batch)
Parameters

batch (Batch) – a batch of input data

Returns

the training loss for this batch

Return type

torch.Tensor
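A training loop typically calls calculate_loss once per batch and back-propagates through the returned tensor. The sketch below shows one way to drive optimisation through that method alone. ToyModel, train_step, the SGD settings and the dict-based batch with "X"/"y" keys are hypothetical and are not taken from LibCity's own executor.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical model exposing the calculate_loss(batch) interface."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def predict(self, batch):
        return self.linear(batch["X"])

    def calculate_loss(self, batch):
        # Returns a 0-dim torch.Tensor, so .backward() can be called on it.
        return nn.functional.mse_loss(self.predict(batch), batch["y"])


def train_step(model, batch, optimizer):
    """One optimisation step that relies only on calculate_loss()."""
    model.train()  # sets model.training = True
    optimizer.zero_grad()
    loss = model.calculate_loss(batch)
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = ToyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = {"X": torch.randn(8, 3), "y": torch.randn(8, 1)}  # dict stand-in for Batch
losses = [train_step(model, batch, optimizer) for _ in range(20)]
```

Keeping the loss inside the model lets a generic loop train any subclass without knowing its output shape or loss function.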

predict(batch)
Parameters

batch (Batch) – a batch of input data

Returns

the prediction result for this batch

Return type

torch.Tensor
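At evaluation time predict is usually called with the module in eval mode and gradient tracking disabled. The sketch below illustrates this, assuming a hypothetical ToyModel with a Dropout layer and a dict standing in for Batch; predict_batch is an illustrative helper, not a LibCity function.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical model exposing the predict(batch) interface."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)
        self.dropout = nn.Dropout(p=0.5)

    def predict(self, batch):
        return self.linear(self.dropout(batch["X"]))


@torch.no_grad()  # no autograd graph is built for the outputs
def predict_batch(model, batch):
    model.eval()  # model.training becomes False, so Dropout is a no-op
    return model.predict(batch)


torch.manual_seed(0)
model = ToyModel()
batch = {"X": torch.randn(5, 3)}  # dict stand-in for Batch
out1 = predict_batch(model, batch)
out2 = predict_batch(model, batch)
```

Since eval() disables dropout, repeated calls on the same batch return identical tensors.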

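training: bool

Whether the module is in training mode; inherited from torch.nn.Module and toggled by train() and eval().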