Source code for easycv.datasets.shared.pipelines.third_transforms_wrapper

# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import inspect
from enum import EnumMeta

import torch
from torchvision import transforms as _transforms

from easycv.datasets.registry import PIPELINES


def is_child_of(obj, cls):
    """Tell whether ``cls`` appears among the ancestors of ``obj``.

    For a class ``obj``, its ``__bases__`` are walked recursively.
    The function returns True as soon as a base either *is* ``cls`` or
    is an instance of ``cls``; the latter case matters when ``cls`` is a
    metaclass. ``obj`` itself is never compared, so ``is_child_of(A, A)``
    is False.

    When ``obj`` has no ``__bases__`` (i.e. it is not a class), the check
    is applied to ``obj.__class__`` instead.

    Args:
        obj: A class or an instance.
        cls: The class (or metaclass) to look for.

    Returns:
        bool: True if a matching ancestor is found, otherwise False.
    """
    try:
        parents = obj.__bases__
    except AttributeError:
        # Not a class: inspect the ancestry of its type instead.
        return is_child_of(obj.__class__, cls)

    # Check every direct base first, then descend breadth-before-depth
    # exactly like the original two-pass loop.
    if any(parent is cls or isinstance(parent, cls) for parent in parents):
        return True
    return any(is_child_of(parent, cls) for parent in parents)
def get_args(obj):
    """List the positional parameter names of the callable ``obj``.

    Names come from ``inspect.getfullargspec(obj).args``, which excludes
    ``*args``, keyword-only parameters and ``**kwargs``. A leading
    ``self`` or ``cls`` parameter is removed, so only the arguments a
    caller passes explicitly remain.

    Args:
        obj: Any callable accepted by ``inspect.getfullargspec``.

    Returns:
        list[str]: Parameter names, possibly empty.
    """
    names = inspect.getfullargspec(obj).args or []
    if names and names[0] in ('self', 'cls'):
        names = names[1:]
    return names
def _reset_forward(obj): original_forward = obj.forward def _new_forward(self, results): img = results['img'] img = original_forward(self, img) results['img'] = img return results setattr(obj, 'forward', _new_forward) def _reset_call(obj): original_call = obj.__call__ def _new_call(self, results): img = results['img'] img = original_call(self, img) results['img'] = img return results setattr(obj, '__call__', _new_call) # TODO: find a more pretty way to wrap third transfomrs or import fixed api to warp
def wrap_torchvision_transforms(transform_obj):
    """Patch a torchvision transform class so it consumes a ``results`` dict.

    Which method gets patched depends on the transform type:

    * ``torch.nn.Module`` subclasses have ``forward`` patched.
    * Any other callable has ``__call__`` patched.

    Patching happens only when the method takes exactly one argument
    besides ``self``/``cls``. The patched method reads ``results['img']``,
    applies the original method, and writes the output back to
    ``results['img']``. Methods with any other arity are left untouched.

    Args:
        transform_obj: The transform *class* to patch. Instances are not
            supported: the patch installs a plain function via ``setattr``,
            which only binds ``self`` when set on a class.

    Returns:
        The patched object. ``copy.deepcopy`` returns classes unchanged, so
        for a class argument this is ``transform_obj`` itself and the patch
        has been applied in place.
    """
    # deepcopy is a no-op for classes (they are returned as-is), so class
    # arguments are patched in place. Kept for backward compatibility.
    transform_obj = copy.deepcopy(transform_obj)
    # The parameter name is deliberately not checked (e.g. 'img' vs 'pic');
    # any single-argument method is wrapped.
    if is_child_of(transform_obj, torch.nn.Module):
        if len(get_args(transform_obj.forward)) == 1:
            _reset_forward(transform_obj)
    elif hasattr(transform_obj, '__call__'):
        if len(get_args(transform_obj.__call__)) == 1:
            _reset_call(transform_obj)
    return transform_obj
# torchvision transform classes that are not wrapped or registered.
# NOTE(review): presumably skipped because they compose other transforms
# rather than acting on a single image — confirm.
skip_list = ['Compose', 'RandomApply']
# At module scope ``locals()`` is the module namespace; ``globals()`` makes
# that explicit. Classes stored here become attributes of this module.
_transforms_names = globals()

# Register a dict-aware subclass of every class exported by
# torchvision.transforms with the PIPELINES registry.
for obj_name, obj in inspect.getmembers(_transforms, inspect.isclass):
    # Skip the excluded transforms and Enum classes, which are not transforms.
    if obj_name in skip_list or isinstance(obj, EnumMeta):
        continue
    # Patch a same-named subclass rather than torchvision's own class, so
    # the original transform keeps its image-only interface.
    _transforms_names[obj_name] = type(obj_name, (obj, ), {})
    wrap_torchvision_transforms(_transforms_names[obj_name])
    PIPELINES.register_module(_transforms_names[obj_name])