Source code for easycv.utils.checkpoint

# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os

import torch
from mmcv.parallel import is_module_wrapper
from mmcv.runner import load_checkpoint as mmcv_load_checkpoint
from mmcv.runner.checkpoint import get_state_dict, weights_to_cpu
from torch.optim import Optimizer

from easycv.file import io
from easycv.file.utils import is_url_path
from easycv.framework.errors import TypeError
from easycv.utils.constant import CACHE_DIR


def get_checkpoint(filename):
    """Resolve ``filename`` to a local file path.

    ``oss://`` paths and URL paths are downloaded into ``CACHE_DIR`` (the
    download is skipped when a cached copy already exists); local paths are
    returned unchanged.
    """
    if filename.startswith('oss://'):
        _, fname = os.path.split(filename)
        cache_file = os.path.join(CACHE_DIR, fname)
        if not os.path.exists(CACHE_DIR):
            os.makedirs(CACHE_DIR)
        if not os.path.exists(cache_file):
            logging.info(
                f'download checkpoint from {filename} to {cache_file}')
            io.copy(filename, cache_file)
            if (torch.distributed.is_available()
                    and torch.distributed.is_initialized()):
                torch.distributed.barrier()
        filename = cache_file
    elif is_url_path(filename):
        from torch.hub import urlparse, download_url_to_file
        parts = urlparse(filename)
        base_name = os.path.basename(parts.path)
        cache_file = os.path.join(CACHE_DIR, base_name)
        if not os.path.exists(CACHE_DIR):
            os.makedirs(CACHE_DIR)
        if not os.path.exists(cache_file):
            logging.info(
                f'download checkpoint from {filename} to {cache_file}')
            download_url_to_file(filename, cache_file)
            if (torch.distributed.is_available()
                    and torch.distributed.is_initialized()):
                torch.distributed.barrier()
        filename = cache_file
    return filename
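
# Usage sketch (illustration only, not part of the original module; the URL
# below is hypothetical): get_checkpoint resolves a remote checkpoint to its
# cached local path, downloading on first use, and passes local paths through.
#
#   local_path = get_checkpoint('http://example.com/weights/resnet50.pth')
#   # -> os.path.join(CACHE_DIR, 'resnet50.pth'); a second call finds the
#   #    cached file and skips the download.
#   state = torch.load(local_path, map_location='cpu')
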
def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None,
                    revise_keys=[(r'^module\.', '')]):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load the checkpoint into.
        filename (str): Accepts a local filepath, URL, ``torchvision://xxx``
            or ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md``
            for details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to strictly enforce that the keys in the
            checkpoint match the keys of the model's state dict.
        logger (:mod:`logging.Logger` or None): The logger for error
            messages.
        revise_keys (list): A list of customized keywords to modify the
            state_dict in the checkpoint. Each item is a
            (pattern, replacement) pair of regular expression operations.
            Default: strip the prefix 'module.' by [(r'^module\\.', '')].

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    filename = get_checkpoint(filename)
    return mmcv_load_checkpoint(
        model,
        filename,
        map_location=map_location,
        strict=strict,
        logger=logger,
        revise_keys=revise_keys)
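
# Usage sketch (illustration only; the model and checkpoint path are
# hypothetical): load pretrained weights into a torchvision model. With
# strict=False, mismatched keys are tolerated, and the default revise_keys
# strips the 'module.' prefix left by DataParallel-style wrappers.
#
#   import torchvision
#   model = torchvision.models.resnet50()
#   ckpt = load_checkpoint(
#       model, 'oss://my-bucket/models/epoch_10.pth', map_location='cpu')
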
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have up to 3 fields: ``meta``, ``state_dict`` and,
    if an optimizer is given, ``optimizer``. By default ``meta`` will
    contain version and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')

    out_dir = os.path.dirname(filename)
    out_dir = out_dir + '/' if out_dir[-1] != '/' else out_dir
    if not io.isdir(out_dir):
        io.makedirs(out_dir)

    # Unwrap DataParallel/DistributedDataParallel before saving so that
    # parameter names are stored without the 'module.' prefix.
    if is_module_wrapper(model):
        model = model.module

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # Support both a single optimizer and a dict of named optimizers.
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    with io.open(filename, 'wb') as ofile:
        torch.save(checkpoint, ofile)
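
# Round-trip sketch (illustration only; paths and names are hypothetical):
# save model and optimizer state, then restore the weights with
# load_checkpoint above.
#
#   import torchvision
#   model = torchvision.models.resnet50()
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
#   save_checkpoint(model, '/tmp/work_dir/epoch_1.pth',
#                   optimizer=optimizer, meta={'epoch': 1})
#   load_checkpoint(model, '/tmp/work_dir/epoch_1.pth')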