libcity.model.abstract_model 源代码
import torch.nn as nn
class AbstractModel(nn.Module):
    """Base class for models: an ``nn.Module`` with a predict/loss interface.

    Concrete models must override :meth:`predict` and :meth:`calculate_loss`.
    The base implementations raise ``NotImplementedError`` so that a missing
    override fails immediately with a clear message, instead of silently
    returning ``None``.
    """

    def __init__(self, config, data_feature):
        """Initialise the underlying ``nn.Module``.

        Args:
            config: model configuration. Not used by this base class; kept so
                subclasses share one constructor signature.
            data_feature: data-related features supplied by the caller. Not
                used by this base class; kept for the same reason.
        """
        # Cooperative init so multiple-inheritance subclasses (mixins) work.
        super().__init__()

    def predict(self, batch):
        """Run inference on one batch.

        Args:
            batch (Batch): a batch of input

        Returns:
            torch.tensor: predict result of this batch

        Raises:
            NotImplementedError: if the subclass does not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement predict()"
        )

    def calculate_loss(self, batch):
        """Compute the training loss on one batch.

        Args:
            batch (Batch): a batch of input

        Returns:
            torch.tensor: return training loss

        Raises:
            NotImplementedError: if the subclass does not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement calculate_loss()"
        )