libcity.model.abstract_model

class libcity.model.abstract_model.AbstractModel(config, data_feature)

Bases: torch.nn.modules.module.Module
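
Because AbstractModel derives from torch.nn.Module, a concrete model subclasses it, calls the parent constructor, and registers its layers. The sketch below is a minimal, self-contained illustration: it defines a hypothetical stand-in AbstractModel that only mirrors the documented signature (it is not LibCity's source), and the LinearModel subclass and the keys it reads from config and data_feature ("dropout", "feature_dim", "output_dim") are assumptions chosen for illustration.

```python
import torch
from torch import nn


class AbstractModel(nn.Module):
    """Hypothetical stand-in mirroring the documented signature; not LibCity's source."""

    def __init__(self, config, data_feature):
        super().__init__()  # nn.Module must be initialized before layers are registered

    def calculate_loss(self, batch):
        raise NotImplementedError  # each concrete model supplies its own loss

    def predict(self, batch):
        raise NotImplementedError  # each concrete model supplies its own prediction


class LinearModel(AbstractModel):
    """Hypothetical subclass; the config/data_feature keys below are assumptions."""

    def __init__(self, config, data_feature):
        super().__init__(config, data_feature)
        self.feature_dim = data_feature.get("feature_dim", 1)
        self.output_dim = data_feature.get("output_dim", 1)
        self.dropout = nn.Dropout(config.get("dropout", 0.0))
        self.proj = nn.Linear(self.feature_dim, self.output_dim)


model = LinearModel({"dropout": 0.1}, {"feature_dim": 3, "output_dim": 1})
print(isinstance(model, nn.Module))  # True
```

Calling super().__init__ first matters: nn.Module refuses to register submodules such as nn.Linear until its own constructor has run.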

calculate_loss(batch)

Compute the training loss for one batch of input data.

Parameters

batch (Batch) – a batch of input data

Returns

the training loss computed on this batch

Return type

torch.Tensor
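
A subclass typically implements calculate_loss by comparing its predictions with the ground truth and returning a tensor that a training loop can call backward() on. The sketch below is an assumption-laden illustration, not LibCity's implementation: it uses a plain dict in place of Batch, the keys "X" and "y" are hypothetical, and mean squared error is chosen only for illustration.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical model; the batch keys "X" and "y" are assumptions."""

    def __init__(self, feature_dim=3, output_dim=1):
        super().__init__()
        self.proj = nn.Linear(feature_dim, output_dim)

    def predict(self, batch):
        return self.proj(batch["X"])

    def calculate_loss(self, batch):
        # Compare predictions with the ground truth; the result is a
        # 0-dim tensor that still carries the autograd graph.
        y_pred = self.predict(batch)
        return F.mse_loss(y_pred, batch["y"])


torch.manual_seed(0)
model = ToyModel()
batch = {"X": torch.randn(4, 3), "y": torch.randn(4, 1)}
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# One training step driven by calculate_loss.
optimizer.zero_grad()
loss = model.calculate_loss(batch)
loss.backward()
optimizer.step()
print(loss.item())
```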

predict(batch)

Compute the model's prediction for one batch of input data.

Parameters

batch (Batch) – a batch of input data

Returns

the prediction result for this batch

Return type

torch.Tensor
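
At inference time, predict is usually called with the model in evaluation mode and with gradient tracking disabled. The sketch below assumes a hypothetical ToyModel whose batch is a plain dict with an "X" key; it is an illustration, not LibCity's evaluation code.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical model; the batch key "X" is an assumption."""

    def __init__(self, feature_dim=3, output_dim=2):
        super().__init__()
        self.dropout = nn.Dropout(0.5)
        self.proj = nn.Linear(feature_dim, output_dim)

    def predict(self, batch):
        return self.proj(self.dropout(batch["X"]))


model = ToyModel()
batch = {"X": torch.randn(5, 3)}

model.eval()  # disables dropout, so repeated calls give identical outputs
with torch.no_grad():  # no autograd graph is built during inference
    y_pred = model.predict(batch)

print(y_pred.shape)  # torch.Size([5, 2])
```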

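training: bool

Inherited from torch.nn.Module: True while the module is in training mode and False after eval() has been called. train() and eval() set this flag on the module and all of its submodules.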
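
The training flag comes from torch.nn.Module, so any module (including an AbstractModel subclass) behaves the same way. A small plain-PyTorch illustration, using nn.Sequential only as a convenient container:

```python
from torch import nn

model = nn.Sequential(nn.Linear(3, 3), nn.Dropout(0.5))

model.train()  # sets training=True on the container and every submodule
print(model.training, model[1].training)  # True True

model.eval()  # equivalent to model.train(False)
print(model.training, model[1].training)  # False False
```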