# Source code for easycv.hooks.oss_sync_hook

# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import os

from mmcv.runner import Hook
from mmcv.runner.dist_utils import master_only

from easycv.file import io
from .registry import HOOKS


@HOOKS.register_module
class OSSSyncHook(Hook):
    """Upload log files and checkpoints to OSS when training on PAI.

    Local files under ``work_dir`` are copied to ``oss_work_dir``, keeping
    their path relative to ``work_dir``.
    """

    def __init__(self,
                 work_dir,
                 oss_work_dir,
                 interval=1,
                 ckpt_filename_tmpl='epoch_{}.pth',
                 export_ckpt_filename_tmpl='epoch_{}_export.pt',
                 other_file_list=None,
                 iter_interval=None):
        """
        Args:
            work_dir: work_dir in cfg; the local directory files are read from.
            oss_work_dir: oss directory where to upload local files in
                work_dir.
            interval: upload every ``interval`` epochs (checked in
                ``after_train_epoch``).
            ckpt_filename_tmpl: checkpoint filename template, formatted with
                ``runner.epoch + 1``.
            export_ckpt_filename_tmpl: exported-model filename template,
                formatted with ``runner.epoch`` in ``after_run``; every file in
                work_dir whose name contains the result is uploaded.
            other_file_list: other files to upload to oss, given as glob
                patterns relative to work_dir (``**`` is supported).
                Defaults to no extra files.
            iter_interval: upload every ``iter_interval`` inner iterations.
                ``None`` (default) disables iteration-based uploads.
        """
        self.work_dir = work_dir
        self.oss_work_dir = oss_work_dir
        self.interval = interval
        self.ckpt_filename_tmpl = ckpt_filename_tmpl
        self.export_ckpt_filename_tmpl = export_ckpt_filename_tmpl
        # Copy into a fresh list: avoids the shared-mutable-default pitfall
        # and isolates the hook from later changes to the caller's list.
        self.other_file_list = ([] if other_file_list is None else
                                list(other_file_list))
        self.iter_interval = iter_interval

    def _copy_to_oss(self, runner, local_path, oss_path):
        # Log and copy one local file to its OSS destination.
        runner.logger.info(f'upload {local_path} to {oss_path}')
        io.safe_copy(local_path, oss_path)

    def _upload_matches(self, runner, pattern):
        # Upload every file under work_dir matching glob ``pattern``,
        # mirroring its work_dir-relative path under oss_work_dir.
        # work_dir is escaped so glob metacharacters in the directory name
        # ('[', '*', '?') are matched literally rather than as wildcards.
        local_files = glob.glob(
            os.path.join(glob.escape(self.work_dir), pattern), recursive=True)
        for local_file in local_files:
            rel_path = os.path.relpath(local_file, self.work_dir)
            self._copy_to_oss(runner, local_file,
                              os.path.join(self.oss_work_dir, rel_path))

    def upload_file(self, runner):
        """Upload the checkpoint for ``runner.epoch + 1`` plus extra files.

        Extra files are the patterns in ``runner.file_upload_perepoch`` (when
        the runner defines it) followed by ``self.other_file_list``, with
        duplicates removed.
        """
        patterns = list(getattr(runner, 'file_upload_perepoch', []))
        patterns += list(self.other_file_list)
        # Order-preserving de-duplication (list(set(...)) gave an arbitrary,
        # run-to-run varying upload order).
        patterns = list(dict.fromkeys(patterns))

        # NOTE(review): assumes runner.epoch is 0-based while training, so
        # +1 matches the number in the saved checkpoint's filename — confirm.
        epoch = runner.epoch + 1
        ckpt_fname = self.ckpt_filename_tmpl.format(epoch)
        local_ckpt = os.path.join(self.work_dir, ckpt_fname)
        oss_ckpt = os.path.join(self.oss_work_dir, ckpt_fname)
        if os.path.exists(local_ckpt):
            self._copy_to_oss(runner, local_ckpt, oss_ckpt)
        else:
            runner.logger.warning(f'{local_ckpt} does not exists, skip upload')

        for pattern in patterns:
            self._upload_matches(runner, pattern)

        # NOTE: uploading the local 'tf_logs' directory (io.copytree) is
        # disabled; add a matching pattern to other_file_list if needed.

    @master_only
    def after_train_iter(self, runner):
        """Call ``upload_file`` every ``iter_interval`` inner iterations.

        Does nothing when ``iter_interval`` is ``None``.
        """
        if self.iter_interval is None:
            return
        if not self.every_n_inner_iters(runner, self.iter_interval):
            return
        self.upload_file(runner)

    @master_only
    def after_train_epoch(self, runner):
        """Call ``upload_file`` every ``interval`` epochs."""
        if not self.every_n_epochs(runner, self.interval):
            return
        self.upload_file(runner)

    @master_only
    def after_run(self, runner):
        """Upload the final log files and any exported model files."""
        timestamp = runner.timestamp
        upload_files = [
            '{}.log'.format(timestamp),
            '{}.log.json'.format(timestamp),
        ]
        for log_file in upload_files:
            local_log = os.path.join(self.work_dir, log_file)
            if not os.path.exists(local_log):
                runner.logger.warning(
                    f'{local_log} does not exists, skip upload')
                continue
            self._copy_to_oss(runner, local_log,
                              os.path.join(self.oss_work_dir, log_file))

        # Upload every file in work_dir whose name contains the exported
        # filename; the name is escaped so it is matched literally.
        export_ckpt_fname = self.export_ckpt_filename_tmpl.format(
            runner.epoch)
        self._upload_matches(runner,
                             '*{}*'.format(glob.escape(export_ckpt_fname)))