# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import inspect
from enum import EnumMeta
import torch
from torchvision import transforms as _transforms
from easycv.datasets.registry import PIPELINES
def is_child_of(obj, cls):
    """Return whether ``cls`` is found among the ancestors of ``obj``.

    ``obj`` may be a class or an instance. Anything without a ``__bases__``
    attribute is treated as an instance and checked through its
    ``__class__``. The starting class itself is never compared with ``cls``;
    only its bases, and their bases in turn, are.

    A base matches when it is ``cls`` itself or when it is an *instance* of
    ``cls`` (which is the case when ``cls`` is that base's metaclass).

    Args:
        obj: Class or instance whose ancestry is inspected.
        cls: Class (or metaclass) to search for.

    Returns:
        bool: True if a direct or indirect base matches, False otherwise.
    """
    try:
        bases = obj.__bases__
    except AttributeError:
        # Instances expose no ``__bases__``; inspect their type instead.
        return is_child_of(obj.__class__, cls)

    # All direct bases are examined before any of them is descended into.
    if any(base is cls or isinstance(base, cls) for base in bases):
        return True
    return any(is_child_of(base, cls) for base in bases)
def get_args(obj):
    """Return the positional parameter names of a callable.

    The names come from ``inspect.getfullargspec(obj).args``, so ``*args``,
    ``**kwargs`` and keyword-only parameters are not included. If the first
    name is ``self`` or ``cls`` it is removed.

    Args:
        obj: Any callable accepted by ``inspect.getfullargspec``.

    Returns:
        list: Parameter names in declaration order; empty if there are none.
    """
    names = list(inspect.getfullargspec(obj).args or [])
    # Only a leading receiver name is stripped; later occurrences are kept.
    if names and names[0] in ('self', 'cls'):
        del names[0]
    return names
def _reset_forward(obj):
    """Swap ``obj.forward`` for a version that works on a results dict.

    The replacement reads ``results['img']``, feeds it to the original
    ``forward``, stores the outcome back under ``'img'`` and returns the
    same dict it was given.

    Args:
        obj: Object whose ``forward`` attribute is overwritten in place.
            The original is invoked with ``self`` passed explicitly, so
            ``obj`` is presumably a class rather than an instance
            (confirm against callers).
    """
    wrapped = obj.forward

    def _new_forward(self, results):
        results['img'] = wrapped(self, results['img'])
        return results

    obj.forward = _new_forward
def _reset_call(obj):
    """Swap ``obj.__call__`` for a version that works on a results dict.

    The replacement reads ``results['img']``, feeds it to the original
    ``__call__``, stores the outcome back under ``'img'`` and returns the
    same dict it was given.

    Args:
        obj: Object whose ``__call__`` attribute is overwritten in place.
            The original is invoked with ``self`` passed explicitly, so
            ``obj`` is presumably a class rather than an instance
            (confirm against callers).
    """
    wrapped = obj.__call__

    def _new_call(self, results):
        results['img'] = wrapped(self, results['img'])
        return results

    obj.__call__ = _new_call
# TODO: find a cleaner way to wrap third-party transforms, or import a fixed API to wrap
# Names of torchvision classes that are neither wrapped nor registered.
skip_list = ['Compose', 'RandomApply']
# Assuming this runs at module scope, ``locals()`` is the module namespace,
# so writing into it publishes each generated class as a module attribute.
_transforms_names = locals()

# Walk every class exposed by torchvision.transforms and register a wrapped
# subclass of each one.
for member in inspect.getmembers(_transforms, predicate=inspect.isclass):
    obj_name, obj = member
    # Excluded names and enum classes (metaclass ``EnumMeta``) are skipped.
    if obj_name in skip_list or isinstance(obj, EnumMeta):
        continue
    # Create an empty subclass carrying the same name and expose it here.
    _transforms_names[obj_name] = type(obj_name, (obj,), {})
    wrap_torchvision_transforms(_transforms_names[obj_name])
    PIPELINES.register_module(_transforms_names[obj_name])