自定义Executor¶
本文档将介绍如何在LibCity
中开发一个新的执行器。
当一个新添加的模型训练方法复杂,现有的执行器不能用于训练和评估,我们就需要开发一个新的执行器。
创建新的Executor类¶
首先,我们创建的执行器应该继承自AbstractExecutoror
。
例如,如下代码可用于开发一个名为NewExecutor
的交通状态预测执行器,代码被写入libcity/executor/
目录下的newexecutor.py
。
from libcity.executor.abstract_executor import AbstractExecutor
class NewExecutor(AbstractExecutor):
def __init__(self, config, model):
self.evaluator = get_evaluator(config)
pass
重写对应的方法¶
用于训练模型的函数是train()
,它将调用_train_epoch()
来训练模型。
用来评估模型的函数是evaluate()
,它将调用_valid_epoch()
来评估该模型。
剩下的两个接口load_model()
和save_model()
分别用来加载和保存模型。
如果开发的模型需要更复杂的训练或评估方法,那么你可以重写上述接口。
from libcity.executor.abstract_executor import AbstractExecutor
class NewExecutor(AbstractExecutor):
def __init__(self, config, model):
self.evaluator = get_evaluator(config)
pass
def save_model(self, cache_name):
pass
def load_model(self, cache_name):
pass
def evaluate(self, test_dataloader):
pass
def train(self, train_dataloader, eval_dataloader):
pass