Source code for easycv.hooks.export_hook

# Copyright (c) Alibaba, Inc. and its affiliates.
import os

from mmcv.runner import Hook
from mmcv.runner.dist_utils import master_only

from easycv.utils.config_tools import validate_export_config
from .registry import HOOKS


@HOOKS.register_module
class ExportHook(Hook):
    """Export the trained model when training on PAI.

    The model is always exported once training finishes (``after_run``) and,
    optionally, after every training epoch (``export_after_each_ckpt``).
    Exported files are written to ``cfg.work_dir``.
    """

    def __init__(
        self,
        cfg,
        ckpt_filename_tmpl='epoch_{}.pth',
        export_ckpt_filename_tmpl='epoch_{}_export.pt',
        export_after_each_ckpt=False,
    ):
        """
        Args:
            cfg: config dict. It is passed through ``validate_export_config``
                and must provide ``work_dir``.
            ckpt_filename_tmpl: checkpoint filename template. Stored on the
                hook; not used by this class itself.
            export_ckpt_filename_tmpl: filename template for exported models,
                formatted with an epoch number and joined onto ``work_dir``.
            export_after_each_ckpt: if True, also export after every training
                epoch. Exporting per epoch is enabled as well when
                ``cfg.export_after_each_ckpt`` is truthy.
        """
        self.cfg = validate_export_config(cfg)
        self.work_dir = cfg.work_dir
        self.ckpt_filename_tmpl = ckpt_filename_tmpl
        self.export_ckpt_filename_tmpl = export_ckpt_filename_tmpl
        self.export_after_each_ckpt = export_after_each_ckpt or cfg.get(
            'export_after_each_ckpt', False)

    def export_model(self, runner, epoch):
        """Export the runner's current model into ``work_dir``.

        Args:
            runner: training runner; its ``model`` is exported and its
                ``logger`` reports the destination path.
            epoch: epoch number substituted into
                ``export_ckpt_filename_tmpl``.
        """
        export_ckpt_fname = self.export_ckpt_filename_tmpl.format(epoch)
        export_local_ckpt = os.path.join(self.work_dir, export_ckpt_fname)
        runner.logger.info(f'export model to {export_local_ckpt}')

        # Imported lazily; presumably to avoid a circular import between
        # easycv.apis and easycv.hooks (NOTE(review): confirm).
        from easycv.apis.export import export

        # Unwrap a parallel wrapper that exposes the real model as `.module`.
        model = getattr(runner.model, 'module', runner.model)
        # The in-memory model is handed over directly, so 'dummy' looks like a
        # placeholder checkpoint path (NOTE(review): confirm `export` ignores
        # ckpt_path when `model` is given).
        export(
            self.cfg,
            ckpt_path='dummy',
            filename=export_local_ckpt,
            model=model)

    @master_only
    def after_train_iter(self, runner):
        """No-op: nothing is exported per iteration."""
        pass

    @master_only
    def after_train_epoch(self, runner):
        """Export the model after each training epoch, if enabled."""
        if self.export_after_each_ckpt:
            # mmcv's EpochBasedRunner increments `runner.epoch` only after the
            # `after_train_epoch` hooks have run, so here it is still the
            # 0-based index of the epoch that just finished. Use `epoch + 1`
            # (as mmcv's checkpoint saving does) so the export for epoch N
            # pairs with `epoch_N.pth`, and the last per-epoch export uses the
            # same name as the final export written by `after_run`.
            self.export_model(runner, runner.epoch + 1)

    @master_only
    def after_run(self, runner):
        """Export the final model once training has finished."""
        # At this point `runner.epoch` is the number of completed epochs, so
        # this name matches the last checkpoint. With export_after_each_ckpt
        # enabled, this rewrites the file exported after the final epoch.
        self.export_model(runner, runner.epoch)