# Source code for easycv.models.pose.top_down

# Copyright (c) OpenMMLab. All rights reserved.
# Adapt from https://raw.githubusercontent.com/open-mmlab/mmpose/master/mmpose/models/detectors/top_down.py

import warnings

import mmcv
import numpy as np
import torch
from mmcv.image import imwrite
from mmcv.utils.misc import deprecated_api_warning
from mmcv.visualization.image import imshow

from easycv.core.visualization import imshow_bboxes, imshow_keypoints
from easycv.models import builder
from easycv.models.base import BaseModel
from easycv.models.builder import MODELS
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger


@MODELS.register_module()
class TopDown(BaseModel):
    """Top-down pose detectors.

    Args:
        backbone (dict): Backbone modules to extract feature.
        neck (dict, optional): Neck modules applied to the backbone output.
            Default: None.
        keypoint_head (dict): Keypoint head to process feature.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path to the pretrained models.
        loss_pose (None): Deprecated arguments. Please use
            `loss_keypoint` for heads instead.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 keypoint_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 loss_pose=None):
        super().__init__()
        self.pretrained = pretrained
        self.backbone = builder.build_backbone(backbone)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if neck is not None:
            self.neck = builder.build_neck(neck)

        if keypoint_head is not None:
            # The head receives the detector-level train/test configs.
            # NOTE: this writes into the caller's ``keypoint_head`` dict.
            keypoint_head['train_cfg'] = train_cfg
            keypoint_head['test_cfg'] = test_cfg

            if 'loss_keypoint' not in keypoint_head and loss_pose is not None:
                warnings.warn(
                    '`loss_pose` for TopDown is deprecated, '
                    'use `loss_keypoint` for heads instead. See '
                    'https://github.com/open-mmlab/mmpose/pull/382'
                    ' for more information.', DeprecationWarning)
                keypoint_head['loss_keypoint'] = loss_pose

            self.keypoint_head = builder.build_head(keypoint_head)

        self.init_weights()

    @property
    def with_neck(self):
        """Check if has neck."""
        return hasattr(self, 'neck')

    @property
    def with_keypoint(self):
        """Check if has keypoint_head."""
        return hasattr(self, 'keypoint_head')

    def init_weights(self):
        """Weight initialization for model.

        If ``self.pretrained`` is a path string, the backbone weights are
        loaded from it (non-strict); otherwise the backbone runs its own
        ``init_weights``. The neck and keypoint head, when present, are
        always initialised with their own ``init_weights``.
        """
        if isinstance(self.pretrained, str):
            logger = get_root_logger()
            load_checkpoint(
                self.backbone, self.pretrained, strict=False, logger=logger)
        else:
            self.backbone.init_weights()
        if self.with_neck:
            self.neck.init_weights()
        if self.with_keypoint:
            self.keypoint_head.init_weights()

    def forward_train(self, img, target, target_weight, img_metas, **kwargs):
        """Defines the computation performed at every call when training.

        Args:
            img: Batched input images fed to the backbone.
            target: Training targets passed to the keypoint head's
                ``get_loss`` / ``get_accuracy``.
            target_weight: Per-target weights passed alongside ``target``.
            img_metas: Image meta information (unused here).

        Returns:
            dict: Keypoint losses merged with keypoint accuracy entries, or an
            empty dict when there is no keypoint head.
        """
        output = self.backbone(img)
        if self.with_neck:
            output = self.neck(output)
        if self.with_keypoint:
            output = self.keypoint_head(output)

        losses = dict()
        if self.with_keypoint:
            keypoint_losses = self.keypoint_head.get_loss(
                output, target, target_weight)
            losses.update(keypoint_losses)
            keypoint_accuracy = self.keypoint_head.get_accuracy(
                output, target, target_weight)
            losses.update(keypoint_accuracy)

        return losses

    def forward_test(self, img, img_metas, return_heatmap=False, **kwargs):
        """Defines the computation performed at every call when testing.

        Args:
            img: 4-D image batch; ``img.shape`` is unpacked as
                ``(batch, _, height, width)`` and the flip test flips dim 3.
            img_metas (list[dict]): One meta dict per image. ``'bbox_id'`` is
                required when the batch holds more than one image, and
                ``img_metas[0]['flip_pairs']`` is read when flip test is on.
            return_heatmap (bool): Whether to add the (numpy) heatmap to the
                result under ``'output_heatmap'``. Default: False.

        Returns:
            dict: The keypoint head's decoded results, plus
            ``'output_heatmap'`` when requested. While scripting or tracing
            with ``torch.jit`` the raw heatmap is returned instead.
        """
        assert img.size(0) == len(img_metas)
        batch_size, _, img_height, img_width = img.shape
        if batch_size > 1:
            assert 'bbox_id' in img_metas[0]

        # ``test_cfg`` defaults to None in ``__init__``; treat that as an
        # empty config instead of failing on ``None.get``.
        test_cfg = self.test_cfg or {}

        result = {}
        # Stays None when the model has no keypoint head.
        output_heatmap = None

        features = self.backbone(img)
        if self.with_neck:
            features = self.neck(features)
        if self.with_keypoint:
            output_heatmap = self.keypoint_head.inference_model(
                features, flip_pairs=None)

        if test_cfg.get('flip_test', True):
            img_flipped = img.flip(3)
            features_flipped = self.backbone(img_flipped)
            if self.with_neck:
                features_flipped = self.neck(features_flipped)
            if self.with_keypoint:
                output_flipped_heatmap = self.keypoint_head.inference_model(
                    features_flipped, img_metas[0]['flip_pairs'])
                # Keep this out-of-place: an in-place update causes
                # calculation errors under blade.
                output_heatmap = (output_heatmap +
                                  output_flipped_heatmap) * 0.5

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return output_heatmap

        if output_heatmap is not None:
            output_heatmap = output_heatmap.numpy()

        if self.with_keypoint:
            keypoint_result = self.keypoint_head.decode(
                img_metas, output_heatmap, img_size=[img_width, img_height])
            result.update(keypoint_result)

        # Compare against None explicitly: the truth value of a
        # multi-element array/tensor is ambiguous and raises.
        if return_heatmap and output_heatmap is not None:
            result['output_heatmap'] = output_heatmap

        return result

    def forward_export(self, img, img_metas, return_heatmap=False):
        """Export entry point; delegates to :meth:`forward_test`."""
        return self.forward_test(
            img, img_metas, return_heatmap=return_heatmap)

    @deprecated_api_warning({'pose_limb_color': 'pose_link_color'},
                            cls_name='TopDown')
    def show_result(self,
                    img,
                    result,
                    skeleton=None,
                    kpt_score_thr=0.3,
                    bbox_color='green',
                    pose_kpt_color=None,
                    pose_link_color=None,
                    text_color='white',
                    radius=4,
                    thickness=1,
                    font_scale=0.5,
                    bbox_thickness=1,
                    win_name='',
                    show=False,
                    show_keypoint_weight=False,
                    wait_time=0,
                    out_file=None):
        """Draw `result` over `img`.

        Args:
            img (str or np.ndarray): The image (or its path) to draw on. It
                is loaded with ``mmcv.imread`` and copied, so an array passed
                in is not modified.
            result (list[dict]): Per-instance results. Each dict must contain
                ``'keypoints'``; ``'bbox'`` is optional. If the first dict has
                ``'label'``, every dict is expected to have one.
            skeleton (list[list]): The connection of keypoints.
                skeleton is 0-based indexing.
            kpt_score_thr (float, optional): Minimum score of keypoints
                to be shown. Default: 0.3.
            bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
            pose_kpt_color (np.array[Nx3]): Color of N keypoints.
                If None, do not draw keypoints.
            pose_link_color (np.array[Mx3]): Color of M links.
                If None, do not draw links.
            text_color (str or tuple or :obj:`Color`): Color of texts.
            radius (int): Radius of circles.
            thickness (int): Thickness of lines.
            font_scale (float): Font scales of texts.
            bbox_thickness (int): Thickness of bbox lines. Default: 1.
            win_name (str): The window name.
            show (bool): Whether to show the image. Default: False.
            show_keypoint_weight (bool): Whether to change the transparency
                using the predicted confidence scores of keypoints.
                NOTE: currently not used by this method.
            wait_time (int): Value of waitKey param. Default: 0.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            np.ndarray: The visualized image. It is returned in every case,
            including when ``show`` is True or ``out_file`` is given.
        """
        img = mmcv.imread(img)
        img = img.copy()

        bbox_result = []
        pose_result = []
        for res in result:
            if 'bbox' in res:
                bbox_result.append(res['bbox'])
            pose_result.append(res['keypoints'])

        if bbox_result:
            bboxes = np.vstack(bbox_result)
            labels = None
            if 'label' in result[0]:
                labels = [res['label'] for res in result]
            # draw bounding boxes
            imshow_bboxes(
                img,
                bboxes,
                labels=labels,
                colors=bbox_color,
                text_color=text_color,
                thickness=bbox_thickness,
                font_scale=font_scale,
                show=False)

        imshow_keypoints(img, pose_result, skeleton, kpt_score_thr,
                         pose_kpt_color, pose_link_color, radius, thickness)

        if show:
            imshow(img, win_name, wait_time)

        if out_file is not None:
            imwrite(img, out_file)

        return img