libcity.model.trajectory_loc_prediction.TemplateTLP

class libcity.model.trajectory_loc_prediction.TemplateTLP.TemplateTLP(config, data_feature)[source]

Bases: libcity.model.abstract_model.AbstractModel

请参考开源模型代码,完成本文件的编写。请务必重写 __init__, predict, calculate_loss 三个方法。

calculate_loss(batch)[source]
参数说明:

batch (libcity.data.batch): 类 dict 文件,其中包含的键值参见任务说明文件。

返回值:
loss (pytorch.tensor): 可以调用 pytorch 实现的 loss 函数与 batch[‘target’]

目标值进行 loss 计算,并将计算结果返回。如模型有自己独特的 loss 计算方式则自行参考实现。

predict(batch)[source]
参数说明:

batch (libcity.data.batch): 类 dict 文件,其中包含的键值参见任务说明文件。

返回值:
score (pytorch.tensor): 对应张量 shape 应为 batch_size *

loc_size。这里返回的是模型对于输入当前轨迹的下一跳位置的预测值。

training: bool