# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import json
import mmcv
import numpy as np
import torch
from mmcv.image import imwrite
from mmcv.utils.path import is_filepath
from mmcv.visualization.image import imshow
from easycv.core.visualization import imshow_bboxes, imshow_keypoints
from easycv.datasets.pose.data_sources.top_down import DatasetInfo
from easycv.datasets.pose.pipelines.transforms import bbox_cs2xyxy
from easycv.file import io
from easycv.predictors.builder import PREDICTORS, build_predictor
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.misc import deprecated
from .base import InputProcessor, OutputProcessor, PredictorV2
# Module-level side effect: suppress scientific notation when numpy arrays
# are printed (affects global numpy print state for the whole process).
np.set_printoptions(suppress=True)
def _box2cs(image_size, box):
"""This encodes bbox(x,y,w,h) into (center, scale)
Args:
x, y, w, h
Returns:
tuple: A tuple containing center and scale.
- np.ndarray[float32](2,): Center of the bbox (x, y).
- np.ndarray[float32](2,): Scale of the bbox w & h.
"""
x, y, w, h = box[:4]
aspect_ratio = image_size[0] / image_size[1]
center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
if w > aspect_ratio * h:
h = w * 1.0 / aspect_ratio
elif w < aspect_ratio * h:
w = h * aspect_ratio
# pixel std is 200.0
scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
scale = scale * 1.25
return center, scale
def vis_pose_result(
    model,
    img,
    result,
    radius=4,
    thickness=1,
    kpt_score_thr=0.3,
    bbox_color='green',
    dataset_info=None,
    out_file=None,
    pose_kpt_color=None,
    pose_link_color=None,
    text_color='white',
    font_scale=0.5,
    bbox_thickness=1,
    win_name='',
    show=False,
    wait_time=0,
):
    """Visualize the detection and pose results on the image.

    Args:
        model (nn.Module): The loaded pose model; only used to look up
            ``model.cfg.dataset_info`` when ``dataset_info`` is None.
        img (str | np.ndarray): Image filename or loaded image.
        result (dict): The results to draw over `img`: key 'keypoints'
            (pose results), optional keys 'bbox' and 'label'.
        radius (int): Radius of keypoint circles.
        thickness (int): Thickness of skeleton lines.
        kpt_score_thr (float): The threshold to visualize the keypoints.
        bbox_color (str): Color of bounding boxes.
        dataset_info (DatasetInfo | None): Provides the skeleton and colors;
            falls back to ``model.cfg.dataset_info`` when None.
        out_file (str | None): The filename of the output visualization
            image. Default: None.
        pose_kpt_color, pose_link_color: Unconditionally overwritten from
            ``dataset_info`` below, so passing them has no effect (kept for
            backward-compatible signature).
        text_color (str): Color of the bbox label text.
        font_scale (float): Font scale of the bbox label text.
        bbox_thickness (int): Thickness of bbox lines.
        win_name (str): Window name used when ``show`` is True.
        show (bool): Whether to show the image. Default: False.
        wait_time (int): Value of waitKey param. Default: 0.

    Returns:
        np.ndarray: The image with drawn results.

    Raises:
        ValueError: If ``dataset_info`` is neither given nor discoverable
            from ``model.cfg``.
    """
    # get dataset info (skeleton definition and per-keypoint/link colors)
    if (dataset_info is None and hasattr(model, 'cfg')
            and 'dataset_info' in model.cfg):
        dataset_info = DatasetInfo(model.cfg.dataset_info)
    if not dataset_info:
        raise ValueError('Please provide `dataset_info`!')

    skeleton = dataset_info.skeleton
    pose_kpt_color = dataset_info.pose_kpt_color
    pose_link_color = dataset_info.pose_link_color

    # Unwrap DataParallel-style wrappers.
    if hasattr(model, 'module'):
        model = model.module

    img = mmcv.imread(img)
    img = img.copy()  # do not draw on the caller's array

    bbox_result = result.get('bbox', [])
    pose_result = result['keypoints']
    if len(bbox_result) > 0:
        bboxes = np.vstack(bbox_result)
        labels = None
        if 'label' in result:
            labels = result['label']
        # draw bounding boxes
        imshow_bboxes(
            img,
            bboxes,
            labels=labels,
            colors=bbox_color,
            text_color=text_color,
            thickness=bbox_thickness,
            font_scale=font_scale,
            show=False)

    imshow_keypoints(img, pose_result, skeleton, kpt_score_thr, pose_kpt_color,
                     pose_link_color, radius, thickness)

    if show:
        imshow(img, win_name, wait_time)

    if out_file is not None:
        imwrite(img, out_file)

    return img
class PoseTopDownOutputProcessor(OutputProcessor):
    """Map raw pose-model outputs to the predictor's output format.

    Renames 'preds' to 'keypoints' and converts 'boxes' to a numpy array
    under the key 'bbox'.
    """

    def __call__(self, inputs):
        """Args:
            inputs (dict): Model outputs with keys 'preds' (decoded keypoint
                predictions) and 'boxes' (person boxes).

        Returns:
            dict: ``{'keypoints': ..., 'bbox': np.ndarray}``.
        """
        output = {}
        output['keypoints'] = inputs['preds']
        output['bbox'] = np.array(inputs['boxes'])  # x1, y1, x2, y2 score
        return output
# TODO: Fix when multi people are detected in each sample,
# all the people results will be passed to the pose model,
# resulting in a dynamic batch_size, which is not supported by jit script model.
@PREDICTORS.register_module()
class PoseTopDownPredictor(PredictorV2):
    """Pose topdown predictor.

    Supports raw (config-built) models as well as torch jit / blade scripted
    models, whose heatmap outputs are decoded on the host.

    Args:
        model_path (str): Path of model path.
        config_file (Optional[str]): Config file path for model and processor
            to init. Defaults to None. Required for 'jit' and 'blade' models.
        detection_predictor_config (dict): Dict of person detection model
            predictor config, example like
            ``dict(type="", model_path="", config_file="", ......)``.
        batch_size (int): Batch size for forward; only 1 is supported.
        bbox_thr (float): Bounding box threshold to filter output results of
            detection model.
        cat_id (int | str): Category id or name to filter target objects.
        device (str | torch.device): Support str('cuda' or 'cpu') or
            torch.device, if is None, detect device automatically.
        pipelines (list[dict]): Data pipeline configs.
        save_results (bool): Whether to save predict results.
        save_path (str): File path for saving results, only valid when
            `save_results` is True.
        mode (str): The image mode into the model.
        model_type (Optional[str]): One of 'raw', 'jit', 'blade'; inferred
            from the model file suffix when None.
    """

    def __init__(self,
                 model_path,
                 config_file=None,
                 detection_predictor_config=None,
                 batch_size=1,
                 bbox_thr=None,
                 cat_id=None,
                 device=None,
                 pipelines=None,
                 save_results=False,
                 save_path=None,
                 mode='BGR',
                 model_type=None,
                 *args,
                 **kwargs):
        assert batch_size == 1, 'Only support batch_size=1 now!'
        self.cat_id = cat_id
        self.bbox_thr = bbox_thr
        self.detection_predictor_config = detection_predictor_config
        self.model_type = model_type
        if self.model_type is None:
            # Infer the runtime backend from the model file suffix.
            if model_path.endswith('jit'):
                assert config_file is not None
                self.model_type = 'jit'
            elif model_path.endswith('blade'):
                # Importing torch_blade registers blade custom ops with torch.
                import torch_blade  # noqa: F401
                assert config_file is not None
                self.model_type = 'blade'
            else:
                self.model_type = 'raw'
        assert self.model_type in ['raw', 'jit', 'blade']

        super(PoseTopDownPredictor, self).__init__(
            model_path,
            config_file=config_file,
            batch_size=batch_size,
            device=device,
            save_results=save_results,
            save_path=save_path,
            pipelines=pipelines,
            input_processor_threads=1,
            mode=mode,
            *args,
            **kwargs)

        # Resolve dataset meta (skeleton/colors); fall back to COCO defaults
        # when the config does not carry a `dataset_info`.
        if hasattr(self.cfg, 'dataset_info'):
            dataset_info = self.cfg.dataset_info
            if is_filepath(dataset_info):
                cfg = mmcv_config_fromfile(dataset_info)
                dataset_info = cfg._cfg_dict['dataset_info']
        else:
            from easycv.datasets.pose.data_sources.coco import COCO_DATASET_INFO
            dataset_info = COCO_DATASET_INFO
        self.dataset_info = DatasetInfo(dataset_info)

    def _build_model(self):
        """Load a torch.jit model for 'jit'/'blade'; build from config for 'raw'."""
        if self.model_type != 'raw':
            with io.open(self.model_path, 'rb') as infile:
                model = torch.jit.load(infile, self.device)
        else:
            model = super()._build_model()
        return model

    def prepare_model(self):
        """Build model from config file by default.

        If the model is not loaded from a configuration file, e.g. torch jit
        model, you need to reimplement it.
        """
        model = self._build_model()
        model.to(self.device)
        model.eval()
        if self.model_type == 'raw':
            # Scripted models already carry weights; only raw models need a
            # separate checkpoint load.
            load_checkpoint(model, self.model_path, map_location='cpu')
        return model

    def model_forward(self, inputs, return_heatmap=False):
        """Run the pose model on one preprocessed batch.

        Args:
            inputs (dict): Batch with keys 'img', 'img_metas' and 'bbox'.
            return_heatmap (bool): Whether a 'raw' model should also return
                heatmaps.

        Returns:
            dict: Decoded keypoint results with the person 'boxes' attached.
        """
        boxes = inputs['bbox'].cpu().numpy()
        if self.model_type == 'raw':
            with torch.no_grad():
                result = self.model(
                    **inputs, mode='test', return_heatmap=return_heatmap)
        else:
            img_metas = inputs['img_metas']
            with torch.no_grad():
                img = inputs['img'].to(self.device)
                # Scripted graphs only accept tensor meta values: drop the
                # non-tensor image path and convert the remaining entries.
                tensor_img_metas = copy.deepcopy(img_metas)
                for meta in tensor_img_metas:
                    meta.pop('image_file')
                    for k, v in meta.items():
                        meta[k] = torch.tensor(v)
                output_heatmap = self.model(img, tensor_img_metas)
            from easycv.models.pose.heads.topdown_heatmap_base_head import decode_heatmap
            output_heatmap = output_heatmap.cpu().numpy()
            result = decode_heatmap(output_heatmap, img_metas,
                                    self.cfg.model.test_cfg)

        result['boxes'] = np.array(boxes)
        return result

    def get_output_processor(self):
        """Return the processor mapping model outputs to predictor outputs."""
        return PoseTopDownOutputProcessor()

    def show_result(self,
                    image,
                    keypoints,
                    radius=4,
                    thickness=3,
                    kpt_score_thr=0.3,
                    bbox_color='green',
                    show=False,
                    save_path=None):
        """Draw pose results on the image; optionally show and/or save it."""
        vis_result = vis_pose_result(
            self.model,
            image,
            keypoints,
            dataset_info=self.dataset_info,
            radius=radius,
            thickness=thickness,
            kpt_score_thr=kpt_score_thr,
            bbox_color=bbox_color,
            show=show,
            out_file=save_path)
        return vis_result
class _TorchPoseTopDownOutputProcessor(PoseTopDownOutputProcessor):
    """Adapt the parent output to the deprecated per-person list format
    ``{'pose_results': [{'bbox': ..., 'keypoints': ...}, ...]}``."""

    def __call__(self, inputs):
        merged = super(_TorchPoseTopDownOutputProcessor, self).__call__(inputs)
        boxes = merged['bbox']
        kpts = merged['keypoints']
        # One entry per detected person, pairing its box with its keypoints.
        pose_results = [{
            'bbox': boxes[idx],
            'keypoints': kpts[idx]
        } for idx in range(len(kpts))]
        return {'pose_results': pose_results}
@deprecated(reason='Please use PoseTopDownPredictor.')
@PREDICTORS.register_module()
class TorchPoseTopDownPredictorWithDetector(PoseTopDownPredictor):
    """Deprecated combined detector + pose predictor.

    Kept for backward compatibility; prefer ``PoseTopDownPredictor``.
    """

    def __init__(self, model_path, model_config=None):
        """Init pose and detection models.

        Args:
            model_path (str): Pose and detection model file path, split with
                `,`; make sure the first is pose model, second is detection
                model.
            model_config (dict | str | None): Config for the models to init,
                a dict or a json string. Defaults to
                ``{'pose': {'bbox_thr': 0.3, 'format': 'xywh'},
                   'detection': {'model_type': None, 'reserved_classes': [],
                                 'score_thresh': 0.0}}``.
        """
        # BUGFIX: the previous signature used a mutable dict as the default
        # argument and mutated it below via pop(), so a second
        # default-argument instantiation raised KeyError on 'model_type'.
        # Use a None sentinel and build a fresh dict per call instead.
        if model_config is None:
            model_config = {
                'pose': {
                    'bbox_thr': 0.3,
                    'format': 'xywh'
                },
                'detection': {
                    'model_type': None,
                    'reserved_classes': [],
                    'score_thresh': 0.0,
                }
            }
        elif isinstance(model_config, str):
            model_config = json.loads(model_config)
        else:
            # Deep-copy so the caller's dict is not mutated by the pops below.
            model_config = copy.deepcopy(model_config)

        reserved_classes = model_config['detection'].pop(
            'reserved_classes', [])
        if len(reserved_classes) == 0:
            reserved_classes = None
        else:
            # Only a single reserved class is supported as the filter cat_id.
            assert len(reserved_classes) == 1
            reserved_classes = reserved_classes[0]

        model_list = model_path.split(',')
        assert len(model_list) == 2
        # first is pose model, second is detection model
        pose_model_path, detection_model_path = model_list

        detection_model_type = model_config['detection'].pop('model_type')
        if detection_model_type == 'TorchYoloXPredictor':
            detection_predictor_config = dict(
                type=detection_model_type,
                model_path=detection_model_path,
                model_config=model_config['detection'])
        else:
            detection_predictor_config = dict(
                model_path=detection_model_path, **model_config['detection'])

        pose_kwargs = model_config['pose']
        pose_kwargs.pop('format', None)  # legacy key, no longer used

        super().__init__(
            model_path=pose_model_path,
            detection_predictor_config=detection_predictor_config,
            cat_id=reserved_classes,
            **pose_kwargs,
        )

    def get_output_processor(self):
        """Use the legacy ``{'pose_results': [...]}`` output format."""
        return _TorchPoseTopDownOutputProcessor()

    def show_result(self,
                    image_path,
                    keypoints,
                    radius=4,
                    thickness=1,
                    kpt_score_thr=0.3,
                    bbox_color='green',
                    show=False,
                    save_path=None):
        """Draw pose results via the underlying model's ``show_result``."""
        dataset_info = self.dataset_info
        # get dataset info
        if (dataset_info is None and hasattr(self.model, 'cfg')
                and 'dataset_info' in self.model.cfg):
            dataset_info = DatasetInfo(self.model.cfg.dataset_info)
        if not dataset_info:
            raise ValueError('Please provide `dataset_info`!')

        skeleton = dataset_info.skeleton
        pose_kpt_color = dataset_info.pose_kpt_color
        pose_link_color = dataset_info.pose_link_color

        # NOTE(review): this permanently unwraps ``self.model`` from its
        # DataParallel-style wrapper; kept for backward compatibility.
        if hasattr(self.model, 'module'):
            self.model = self.model.module

        img = self.model.show_result(
            image_path,
            keypoints,
            skeleton,
            radius=radius,
            thickness=thickness,
            pose_kpt_color=pose_kpt_color,
            pose_link_color=pose_link_color,
            kpt_score_thr=kpt_score_thr,
            bbox_color=bbox_color,
            show=show,
            out_file=save_path)
        return img