libcity.model.traffic_speed_prediction.TemplateTSP

class libcity.model.traffic_speed_prediction.TemplateTSP.TemplateTSP(config, data_feature)[source]

Bases: libcity.model.abstract_traffic_state_model.AbstractTrafficStateModel

calculate_loss(batch)[source]

输入一个batch的数据,返回训练过程这个batch数据的loss,也就是需要定义一个loss函数。 :param batch: 输入数据,类字典,可以按字典的方法取数据 :return: training loss (tensor)

forward(batch)[source]

调用模型计算这个batch输入对应的输出,nn.Module必须实现的接口 :param batch: 输入数据,类字典,可以按字典的方法取数据 :return:

predict(batch)[source]

输入一个batch的数据,返回对应的预测值,一般应该是**多步预测**的结果 一般会调用上边定义的forward()方法 :param batch: 输入数据,类字典,可以按字典的方法取数据 :return: predict result of this batch (tensor)

training: bool