libcity.evaluator.traj_loc_pred_evaluator

class libcity.evaluator.traj_loc_pred_evaluator.TrajLocPredEvaluator(config)[source]

Bases: libcity.evaluator.abstract_evaluator.AbstractEvaluator

clear()[source]

清除之前收集到的 batch 的评估信息,适用于每次评估开始时进行一次清空,排除之前的评估输入的影响。

collect(batch)[source]
Parameters
  • batch (dict) – contains three keys: uid, loc_true, and loc_pred.

  • uid (list) – 来自于 batch 中的 uid,通过索引可以确定 loc_true 与 loc_pred 中每一行(元素)是哪个用户的一次输入。

  • loc_true (list) – 期望地点(target),来自于 batch 中的 target。 对于负样本评估,loc_pred 中第一个点是 target 的置信度,后面的都是负样本的

  • loc_pred (matrix) – 实际上模型的输出,batch_size * output_dim.

evaluate()[source]

返回之前收集到的所有 batch 的评估结果

save_result(save_path, filename=None)[source]

将评估结果保存到 save_path 文件夹下的 filename 文件中

Parameters
  • save_path – 保存路径

  • filename – 保存文件名