from ditk import logging
import signal
import sys
import traceback
from typing import Callable
import torch
import torch.utils.data  # torch1.1.0 compatibility
from ding.utils import read_file, save_file

logger = logging.getLogger('default_logger')


def build_checkpoint_helper(cfg):
    """
    Overview:
        Use config to build checkpoint helper.
    Arguments:
        - cfg (:obj:`dict`): ckpt_helper config
    Returns:
        - (:obj:`CheckpointHelper`): checkpoint_helper created by this function
    """
    return CheckpointHelper()
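
# Usage sketch (illustrative only; ``MyModel`` is a hypothetical ``nn.Module``):
#
#   ckpt_helper = build_checkpoint_helper(cfg={})  # cfg is currently unused
#   ckpt_helper.save('./ckpt_best.pth.tar', MyModel())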


class CheckpointHelper:
    """
    Overview:
        Help to save or load checkpoint by given args.
    Interfaces:
        ``__init__``, ``save``, ``load``, ``_remove_prefix``, ``_add_prefix``, ``_load_matched_model_state_dict``
    """

    def __init__(self):
        pass

    def _remove_prefix(self, state_dict: dict, prefix: str = 'module.') -> dict:
        """
        Overview:
            Remove prefix from state_dict's keys.
        Arguments:
            - state_dict (:obj:`dict`): model's state_dict
            - prefix (:obj:`str`): this prefix will be removed from keys
        Returns:
            - new_state_dict (:obj:`dict`): new state_dict after removing prefix
        """
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith(prefix):
                new_k = ''.join(k.split(prefix))
            else:
                new_k = k
            new_state_dict[new_k] = v
        return new_state_dict
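
    # Illustrative sketch (assumed keys, e.g. from a ``DataParallel``-wrapped model):
    #
    #   sd = {'module.fc.weight': torch.zeros(2, 2)}
    #   CheckpointHelper()._remove_prefix(sd)   # -> {'fc.weight': tensor(...)}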

    def _add_prefix(self, state_dict: dict, prefix: str = 'module.') -> dict:
        """
        Overview:
            Add prefix to state_dict's keys.
        Arguments:
            - state_dict (:obj:`dict`): model's state_dict
            - prefix (:obj:`str`): this prefix will be added to keys
        Returns:
            - (:obj:`dict`): new state_dict after adding prefix
        """
        return {prefix + k: v for k, v in state_dict.items()}
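
    # Illustrative sketch (inverse of ``_remove_prefix``; assumed key and tensor ``w``):
    #
    #   CheckpointHelper()._add_prefix({'fc.weight': w})   # -> {'module.fc.weight': w}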

    def save(
            self,
            path: str,
            model: torch.nn.Module,
            optimizer: torch.optim.Optimizer = None,
            last_iter: 'CountVar' = None,  # noqa
            last_epoch: 'CountVar' = None,  # noqa
            last_frame: 'CountVar' = None,  # noqa
            dataset: torch.utils.data.Dataset = None,
            collector_info: torch.nn.Module = None,
            prefix_op: str = None,
            prefix: str = None,
    ) -> None:
        """
        Overview:
            Save checkpoint by given args.
        Arguments:
            - path (:obj:`str`): the path of saving checkpoint
            - model (:obj:`torch.nn.Module`): model to be saved
            - optimizer (:obj:`torch.optim.Optimizer`): optimizer obj
            - last_iter (:obj:`CountVar`): iter num, default None
            - last_epoch (:obj:`CountVar`): epoch num, default None
            - last_frame (:obj:`CountVar`): frame num, default None
            - dataset (:obj:`torch.utils.data.Dataset`): dataset, should be replay dataset
            - collector_info (:obj:`torch.nn.Module`): attr of checkpoint, save collector info
            - prefix_op (:obj:`str`): should be one of ['remove', 'add'], applied to state_dict keys
            - prefix (:obj:`str`): prefix to be processed on state_dict keys
        """
        checkpoint = {}
        state_dict = model.state_dict()
        if prefix_op is not None:  # remove or add prefix on state_dict keys
            prefix_func = {'remove': self._remove_prefix, 'add': self._add_prefix}
            if prefix_op not in prefix_func.keys():
                raise KeyError('invalid prefix_op: {}'.format(prefix_op))
            else:
                state_dict = prefix_func[prefix_op](state_dict, prefix)
        checkpoint['model'] = state_dict

        if optimizer is not None:  # save optimizer together with training progress counters
            assert (last_iter is not None or last_epoch is not None)
            if last_iter is not None:
                checkpoint['last_iter'] = last_iter.val
            if last_epoch is not None:
                checkpoint['last_epoch'] = last_epoch.val
            if last_frame is not None:
                checkpoint['last_frame'] = last_frame.val
            checkpoint['optimizer'] = optimizer.state_dict()

        if dataset is not None:
            checkpoint['dataset'] = dataset.state_dict()
        if collector_info is not None:
            checkpoint['collector_info'] = collector_info.state_dict()
        save_file(path, checkpoint)
        logger.info('save checkpoint in {}'.format(path))
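
    # Usage sketch (illustrative only; ``model`` and ``optimizer`` assumed defined elsewhere):
    #
    #   helper = CheckpointHelper()
    #   last_iter = CountVar(init_val=0)
    #   helper.save('iteration_{}.pth.tar'.format(last_iter.val), model,
    #               optimizer=optimizer, last_iter=last_iter)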

    def _load_matched_model_state_dict(self, model: torch.nn.Module, ckpt_state_dict: dict) -> None:
        """
        Overview:
            Load the matched part of the checkpoint's state_dict into the model, and log the mismatched keys \
            (missing, redundant, or shape-mismatched) between the model's state_dict and the checkpoint's state_dict.
        Arguments:
            - model (:obj:`torch.nn.Module`): model
            - ckpt_state_dict (:obj:`dict`): checkpoint's state_dict
        """
        assert isinstance(model, torch.nn.Module)
        diff = {'miss_keys': [], 'redundant_keys': [], 'mismatch_shape_keys': []}
        model_state_dict = model.state_dict()
        model_keys = set(model_state_dict.keys())
        ckpt_keys = set(ckpt_state_dict.keys())
        diff['miss_keys'] = model_keys - ckpt_keys
        diff['redundant_keys'] = ckpt_keys - model_keys

        intersection_keys = model_keys.intersection(ckpt_keys)
        valid_keys = []
        for k in intersection_keys:
            if model_state_dict[k].shape == ckpt_state_dict[k].shape:
                valid_keys.append(k)
            else:
                diff['mismatch_shape_keys'].append(
                    '{}\tmodel_shape: {}\tckpt_shape: {}'.format(
                        k, model_state_dict[k].shape, ckpt_state_dict[k].shape
                    )
                )

        valid_ckpt_state_dict = {k: v for k, v in ckpt_state_dict.items() if k in valid_keys}
        model.load_state_dict(valid_ckpt_state_dict, strict=False)

        for n, keys in diff.items():
            for k in keys:
                logger.info('{}: {}'.format(n, k))
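
    # Illustrative note: given a checkpoint whose ``fc`` layer was resized, this helper loads every
    # other matched parameter and logs something like
    #   mismatch_shape_keys: fc.weight   model_shape: torch.Size([10, 32])   ckpt_shape: torch.Size([5, 32])
    # (hypothetical key and shapes, for illustration only).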

    def load(
            self,
            load_path: str,
            model: torch.nn.Module,
            optimizer: torch.optim.Optimizer = None,
            last_iter: 'CountVar' = None,  # noqa
            last_epoch: 'CountVar' = None,  # noqa
            last_frame: 'CountVar' = None,  # noqa
            lr_schduler: 'Scheduler' = None,  # noqa
            dataset: torch.utils.data.Dataset = None,
            collector_info: torch.nn.Module = None,
            prefix_op: str = None,
            prefix: str = None,
            strict: bool = True,
            logger_prefix: str = '',
            state_dict_mask: list = [],
    ):
        """
        Overview:
            Load checkpoint by given path.
        Arguments:
            - load_path (:obj:`str`): checkpoint's path
            - model (:obj:`torch.nn.Module`): model definition
            - optimizer (:obj:`torch.optim.Optimizer`): optimizer obj
            - last_iter (:obj:`CountVar`): iter num, default None
            - last_epoch (:obj:`CountVar`): epoch num, default None
            - last_frame (:obj:`CountVar`): frame num, default None
            - lr_schduler (:obj:`Scheduler`): lr scheduler obj
            - dataset (:obj:`torch.utils.data.Dataset`): dataset, should be replay dataset
            - collector_info (:obj:`torch.nn.Module`): attr of checkpoint, load collector info
            - prefix_op (:obj:`str`): should be one of ['remove', 'add'], applied to state_dict keys
            - prefix (:obj:`str`): prefix to be processed on state_dict keys
            - strict (:obj:`bool`): argument of model.load_state_dict
            - logger_prefix (:obj:`str`): prefix of logger
            - state_dict_mask (:obj:`list`): a list of state_dict key prefixes, \
                which shouldn't be loaded into model (applied after the prefix op)

        .. note::
            The checkpoint loaded from load_path is a dict, whose format is like '{'model': OrderedDict(), ...}'
        """
        # TODO save config
        # Note: for reducing first GPU memory cost and compatibility with cpu env
        checkpoint = read_file(load_path)
        state_dict = checkpoint['model']
        if prefix_op is not None:
            prefix_func = {'remove': self._remove_prefix, 'add': self._add_prefix}
            if prefix_op not in prefix_func.keys():
                raise KeyError('invalid prefix_op: {}'.format(prefix_op))
            else:
                state_dict = prefix_func[prefix_op](state_dict, prefix)
        if len(state_dict_mask) > 0:
            if strict:
                logger.info(
                    logger_prefix +
                    '[Warning] non-empty state_dict_mask expects strict=False, but finds strict=True in input argument'
                )
                strict = False  # downgrade strict load to partial (matched) load
            for m in state_dict_mask:
                state_dict_keys = list(state_dict.keys())
                for k in state_dict_keys:
                    if k.startswith(m):
                        state_dict.pop(k)  # ignore return value
        if strict:
            model.load_state_dict(state_dict, strict=True)
        else:
            self._load_matched_model_state_dict(model, state_dict)
        logger.info(logger_prefix + 'load model state_dict in {}'.format(load_path))

        if dataset is not None:
            if 'dataset' in checkpoint.keys():
                dataset.load_state_dict(checkpoint['dataset'])
                logger.info(logger_prefix + 'load online data in {}'.format(load_path))
            else:
                logger.info(logger_prefix + "dataset not in checkpoint, ignore load procedure")
        if optimizer is not None:
            if 'optimizer' in checkpoint.keys():
                optimizer.load_state_dict(checkpoint['optimizer'])
                logger.info(logger_prefix + 'load optimizer in {}'.format(load_path))
            else:
                logger.info(logger_prefix + "optimizer not in checkpoint, ignore load procedure")
        if last_iter is not None:
            if 'last_iter' in checkpoint.keys():
                last_iter.update(checkpoint['last_iter'])
                logger.info(
                    logger_prefix + 'load last_iter in {}, current last_iter is {}'.format(load_path, last_iter.val)
                )
            else:
                logger.info(logger_prefix + "last_iter not in checkpoint, ignore load procedure")
        if collector_info is not None:
            collector_info.load_state_dict(checkpoint['collector_info'])
            logger.info(logger_prefix + 'load collector info in {}'.format(load_path))
        if lr_schduler is not None:
            assert last_iter is not None
            raise NotImplementedError
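
    # Usage sketch (illustrative only; ``model``/``optimizer`` and the mask value are assumptions):
    #
    #   helper = CheckpointHelper()
    #   last_iter = CountVar(init_val=0)
    #   helper.load(
    #       './ckpt_best.pth.tar', model, optimizer=optimizer, last_iter=last_iter,
    #       prefix_op='remove', prefix='module.',   # e.g. checkpoint saved from a DataParallel model
    #       strict=False, state_dict_mask=['head']  # skip keys starting with 'head'
    #   )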


class CountVar(object):
    """
    Overview:
        Number counter.
    Interfaces:
        ``__init__``, ``update``, ``add``
    Properties:
        - val (:obj:`int`): the value of the counter
    """

    def __init__(self, init_val: int) -> None:
        """
        Overview:
            Init the var counter.
        Arguments:
            - init_val (:obj:`int`): the init value of the counter
        """
        self._val = init_val

    @property
    def val(self) -> int:
        """
        Overview:
            Get the current value of the counter.
        """
        return self._val

    def update(self, val: int) -> None:
        """
        Overview:
            Update the var counter.
        Arguments:
            - val (:obj:`int`): the update value of the counter
        """
        self._val = val

    def add(self, add_num: int):
        """
        Overview:
            Add the given number to the counter.
        Arguments:
            - add_num (:obj:`int`): the number added to the counter
        """
        self._val += add_num
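
# Usage sketch (illustrative only):
#
#   last_iter = CountVar(init_val=0)
#   last_iter.add(1)        # -> last_iter.val == 1
#   last_iter.update(100)   # -> last_iter.val == 100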


def auto_checkpoint(func: Callable) -> Callable:
    """
    Overview:
        Create a wrapper to wrap function, and the wrapper will call the save_checkpoint method
        whenever an exception or a dead signal happens.
    Arguments:
        - func (:obj:`Callable`): the function to be wrapped
    Returns:
        - wrapper (:obj:`Callable`): the wrapped function
    """
    dead_signals = ['SIGILL', 'SIGINT', 'SIGKILL', 'SIGQUIT', 'SIGSEGV', 'SIGSTOP', 'SIGTERM', 'SIGBUS']
    all_signals = dead_signals + ['SIGUSR1']

    def register_signal_handler(handler):
        valid_sig = []
        invalid_sig = []
        for sig in all_signals:
            try:
                sig = getattr(signal, sig)
                signal.signal(sig, handler)
                valid_sig.append(sig)
            except Exception:
                # some signals (e.g. SIGKILL, SIGSTOP) can't be caught, or don't exist on this platform
                invalid_sig.append(sig)
        logger.info('valid sig: ({})\ninvalid sig: ({})'.format(valid_sig, invalid_sig))

    def wrapper(*args, **kwargs):
        handle = args[0]  # the bound object (``self``), which must provide ``save_checkpoint``
        assert hasattr(handle, 'save_checkpoint')

        def signal_handler(signal_num, frame):
            sig = signal.Signals(signal_num)
            logger.info("SIGNAL: {}({})".format(sig.name, sig.value))
            handle.save_checkpoint('ckpt_interrupt.pth.tar')
            sys.exit(1)

        register_signal_handler(signal_handler)
        try:
            return func(*args, **kwargs)
        except Exception:
            # save an emergency checkpoint, then print the traceback (exception is not re-raised)
            handle.save_checkpoint('ckpt_exception.pth.tar')
            traceback.print_exc()

    return wrapper
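
# Usage sketch (illustrative only; ``MyLearner`` and its methods are hypothetical):
#
#   class MyLearner:
#
#       def save_checkpoint(self, ckpt_path: str) -> None:
#           ...  # persist the current training state to ``ckpt_path``
#
#       @auto_checkpoint
#       def train(self) -> None:
#           ...  # long-running training loop; interrupted/failed runs still produce a checkpoint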