# Source code for libcity.model.abstract_model

import torch.nn as nn


class AbstractModel(nn.Module):
    """Base class for models built on ``torch.nn.Module``.

    Subclasses are meant to override :meth:`predict` and
    :meth:`calculate_loss`. The base implementations here contain only
    docstrings, so calling them on this class returns ``None``.

    Args:
        config: model configuration. Accepted but not used by this base
            class; presumably consumed by subclasses (TODO confirm).
        data_feature: dataset-derived features. Accepted but not used by
            this base class; presumably consumed by subclasses (TODO confirm).
    """

    def __init__(self, config, data_feature):
        # Initialise the nn.Module machinery (parameter/submodule registries).
        # The explicit base-class call is kept as in the original rather than
        # switching to super(), so cooperative-MRO behaviour is unchanged.
        nn.Module.__init__(self)

    def predict(self, batch):
        """Produce predictions for one batch.

        Args:
            batch (Batch): a batch of input

        Returns:
            torch.tensor: predict result of this batch. This base
            implementation has no body and returns ``None``; subclasses
            should override it.
        """

    def calculate_loss(self, batch):
        """Compute the training loss for one batch.

        Args:
            batch (Batch): a batch of input

        Returns:
            torch.tensor: return training loss. This base implementation
            has no body and returns ``None``; subclasses should override it.
        """