# Source code for easycv.models.backbones.hrnet

# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://github.com/open-mmlab/mmpose/blob/master/mmpose/models/backbones/hrnet.py
import copy

import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
                      normal_init)
from torch.nn.modules.batchnorm import _BatchNorm

from easycv.framework.errors import NotImplementedError, TypeError, ValueError
from easycv.models.registry import BACKBONES
from ..modelzoo import hrnet as model_urls
from .resnet import BasicBlock


def get_expansion(block, expansion=None):
    """Resolve the channel expansion ratio of a residual block.

    The ratio is looked up in this order:

    1. If ``expansion`` is given, it is validated and returned as is.
    2. If ``block`` has an ``expansion`` attribute, that value is returned.
    3. Otherwise the default for the block type is used: 1 for
       ``BasicBlock`` subclasses and 4 for ``Bottleneck`` subclasses.

    Args:
        block (class): The block class.
        expansion (int | None): The given expansion ratio.

    Returns:
        int: The expansion of the block.

    Raises:
        TypeError: If ``expansion`` is neither an integer nor ``None``, or
            if it is ``None`` and no ratio can be derived from ``block``.
    """
    if expansion is None:
        # Nothing given explicitly: derive the ratio from the block class.
        if hasattr(block, 'expansion'):
            return block.expansion
        if issubclass(block, BasicBlock):
            return 1
        if issubclass(block, Bottleneck):
            return 4
        raise TypeError(f'expansion is not specified for {block.__name__}')

    if not isinstance(expansion, int):
        raise TypeError('expansion must be an integer or None')

    assert expansion > 0
    return expansion
class Bottleneck(nn.Module):
    """Bottleneck block for ResNet (1x1 -> 3x3 -> 1x1 convolutions).

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        expansion (int): The ratio of ``out_channels/mid_channels`` where
            ``mid_channels`` is the input/output channels of conv2.
            Default: 4.
        stride (int): stride of the block. Default: 1
        dilation (int): dilation of convolution. Default: 1
        downsample (nn.Module): downsample operation on identity branch.
            Default: None.
        style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the
            stride-two layer is the 3x3 conv layer, otherwise the stride-two
            layer is the first 1x1 conv layer. Default: "pytorch".
        with_cp (bool): Use checkpoint or not. Checkpointing is not
            implemented here: ``forward`` raises ``NotImplementedError``
            when this is set and the input requires grad.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expansion=4,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN')):
        # Work on a private copy so the shared default dict is never touched.
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()
        assert style in ['pytorch', 'caffe']

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.expansion = expansion
        assert out_channels % expansion == 0
        self.mid_channels = out_channels // expansion
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        # Exactly one of the first two convs carries the block stride.
        stride_on_conv2 = self.style == 'pytorch'
        self.conv1_stride = 1 if stride_on_conv2 else stride
        self.conv2_stride = stride if stride_on_conv2 else 1

        mid_channels = self.mid_channels
        self.norm1_name, first_norm = build_norm_layer(
            norm_cfg, mid_channels, postfix=1)
        self.norm2_name, second_norm = build_norm_layer(
            norm_cfg, mid_channels, postfix=2)
        self.norm3_name, third_norm = build_norm_layer(
            norm_cfg, out_channels, postfix=3)

        # Each conv is registered right before its norm layer.
        self.conv1 = build_conv_layer(
            conv_cfg,
            in_channels,
            mid_channels,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, first_norm)

        self.conv2 = build_conv_layer(
            conv_cfg,
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=self.conv2_stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm2_name, second_norm)

        self.conv3 = build_conv_layer(
            conv_cfg, mid_channels, out_channels, kernel_size=1, bias=False)
        self.add_module(self.norm3_name, third_norm)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    @property
    def norm1(self):
        """nn.Module: the normalization layer registered as ``norm1_name``."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: the normalization layer registered as ``norm2_name``."""
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        """nn.Module: the normalization layer registered as ``norm3_name``."""
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Forward function.

        Raises:
            NotImplementedError: If ``with_cp`` is set and ``x`` requires
                grad (checkpointing is not implemented).
        """
        if self.with_cp and x.requires_grad:
            raise NotImplementedError

        # Main path: conv -> norm -> relu twice, then conv -> norm.
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.relu(self.norm2(self.conv2(out)))
        out = self.norm3(self.conv3(out))

        # Shortcut path, projected only when a downsample module is set.
        identity = x if self.downsample is None else self.downsample(x)
        out += identity

        return self.relu(out)
class HRModule(nn.Module):
    """High-Resolution Module for HRNet.

    The module holds ``num_branches`` parallel branches, each a stack of
    ``num_blocks[i]`` residual blocks, followed by fusion layers that
    exchange information between the branches.

    Args:
        num_branches (int): Number of parallel branches.
        blocks (class): Residual block class used to build every branch.
        num_blocks (Sequence[int]): Number of blocks per branch.
        in_channels (Sequence[int]): Input channels per branch. The
            sequence is copied, so the caller's object is never modified.
        num_channels (Sequence[int]): Base channels per branch; they are
            multiplied by the block expansion to get the branch width.
        multiscale_output (bool): If True, one fused output is produced
            per branch; otherwise only the first (highest-resolution)
            output is produced. Default: False.
        with_cp (bool): Forwarded to every block. Default: False.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        upsample_cfg (dict): ``mode`` and ``align_corners`` passed to
            ``nn.Upsample`` in the fusion layers.
            Default: dict(mode='nearest', align_corners=None)
    """

    def __init__(self,
                 num_branches,
                 blocks,
                 num_blocks,
                 in_channels,
                 num_channels,
                 multiscale_output=False,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 upsample_cfg=dict(mode='nearest', align_corners=None)):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        upsample_cfg = copy.deepcopy(upsample_cfg)
        super().__init__()
        self._check_branches(num_branches, num_blocks, in_channels,
                             num_channels)

        # ``_make_one_branch`` rewrites entries of ``self.in_channels``.
        # Keep a private list so the caller's sequence is left untouched and
        # immutable sequences (e.g. tuples) are accepted as well.
        self.in_channels = list(in_channels)
        self.num_branches = num_branches

        self.multiscale_output = multiscale_output
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.upsample_cfg = upsample_cfg
        self.with_cp = with_cp
        self.branches = self._make_branches(num_branches, blocks, num_blocks,
                                            num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _check_branches(num_branches, num_blocks, in_channels, num_channels):
        """Check that every per-branch sequence has ``num_branches`` items.

        Raises:
            ValueError: If any of the sequences has a different length.
        """
        if num_branches != len(num_blocks):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                        f'!= NUM_BLOCKS({len(num_blocks)})'
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                        f'!= NUM_CHANNELS({len(num_channels)})'
            raise ValueError(error_msg)

        if num_branches != len(in_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                        f'!= NUM_INCHANNELS({len(in_channels)})'
            raise ValueError(error_msg)

    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        """Make one branch.

        Builds ``num_blocks[branch_index]`` blocks and records the branch
        output width in ``self.in_channels[branch_index]``.
        """
        branch_in = self.in_channels[branch_index]
        branch_out = num_channels[branch_index] * get_expansion(block)

        # A 1x1 projection is needed on the shortcut of the first block
        # whenever it changes the resolution or the channel count.
        downsample = None
        if stride != 1 or branch_in != branch_out:
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    branch_in,
                    branch_out,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, branch_out)[1])

        layers = [
            block(
                branch_in,
                branch_out,
                stride=stride,
                downsample=downsample,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg)
        ]
        self.in_channels[branch_index] = branch_out
        for _ in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    branch_out,
                    branch_out,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Make branches."""
        branches = [
            self._make_one_branch(i, block, num_blocks, num_channels)
            for i in range(num_branches)
        ]
        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Make fuse layer.

        Entry ``[i][j]`` maps the output of branch ``j`` to the resolution
        and width of branch ``i``; the diagonal holds ``None``. Returns
        ``None`` when there is a single branch.
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        fuse_layers = []
        num_out_branches = num_branches if self.multiscale_output else 1

        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Lower-resolution source: 1x1 conv, then upsample.
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels[j],
                                in_channels[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             in_channels[i])[1],
                            nn.Upsample(
                                scale_factor=2**(j - i),
                                mode=self.upsample_cfg['mode'],
                                align_corners=self.
                                upsample_cfg['align_corners'])))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    # Higher-resolution source: chain of stride-2 3x3
                    # convs; only the last one changes the channel count
                    # and it is not followed by a ReLU.
                    conv_downsamples = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[i],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[i])[1]))
                        else:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    nn.ReLU(inplace=True)))
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        """Forward function.

        Args:
            x (list): One input per branch. The list itself is not
                modified.

        Returns:
            list: The fused outputs, one per row of ``fuse_layers`` (a
            single-element list when there is only one branch).
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        # Collect branch outputs in a new list instead of overwriting the
        # entries of the caller's list.
        branch_outs = [
            self.branches[i](x[i]) for i in range(self.num_branches)
        ]

        x_fuse = []
        for i, fuse_row in enumerate(self.fuse_layers):
            # Starting from the int 0 makes the first ``+=`` allocate a new
            # tensor, so the later in-place adds never touch a branch output.
            y = 0
            for j in range(self.num_branches):
                if i == j:
                    y += branch_outs[j]
                else:
                    y += fuse_row[j](branch_outs[j])
            x_fuse.append(self.relu(y))
        return x_fuse
@BACKBONES.register_module()
class HRNet(nn.Module):
    """HRNet backbone.

    `High-Resolution Representations for Labeling Pixels and Regions
    <https://arxiv.org/abs/1904.04514>`__

    Args:
        arch (str): Key into ``arch_zoo`` selecting a predefined stage
            configuration. Ignored when ``extra`` is given. Default: 'w32'.
        extra (dict): detailed configuration for each stage of HRNet.
            When given it is used instead of ``arch``. Default: None.
        in_channels (int): Number of input image channels. Default: 3.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False
        with_cp (bool): Use checkpoint or not. Forwarded to every block.
        zero_init_residual (bool): whether to use zero init for last norm
            layer in resblocks to let them behave as identity.
        multi_scale_output (bool): Whether the last module of stage 4
            outputs every branch instead of only the first one. A
            ``multiscale_output`` key in the stage-4 config takes
            precedence. Default: False.

    Example:
        >>> from easycv.models.backbones.hrnet import HRNet
        >>> import torch
        >>> extra = dict(
        >>>     stage1=dict(
        >>>         num_modules=1,
        >>>         num_branches=1,
        >>>         block='BOTTLENECK',
        >>>         num_blocks=(4, ),
        >>>         num_channels=(64, )),
        >>>     stage2=dict(
        >>>         num_modules=1,
        >>>         num_branches=2,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4),
        >>>         num_channels=(32, 64)),
        >>>     stage3=dict(
        >>>         num_modules=4,
        >>>         num_branches=3,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4, 4),
        >>>         num_channels=(32, 64, 128)),
        >>>     stage4=dict(
        >>>         num_modules=3,
        >>>         num_branches=4,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4, 4, 4),
        >>>         num_channels=(32, 64, 128, 256)))
        >>> self = HRNet(extra=extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 32, 8, 8)
    """

    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}

    arch_zoo = {
        # num_modules, num_branches, block, num_blocks, num_channels
        'w18': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (18, 36)],
                [4, 3, 'BASIC', (4, 4, 4), (18, 36, 72)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (18, 36, 72, 144)]],
        'w30': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (30, 60)],
                [4, 3, 'BASIC', (4, 4, 4), (30, 60, 120)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (30, 60, 120, 240)]],
        'w32': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (32, 64)],
                [4, 3, 'BASIC', (4, 4, 4), (32, 64, 128)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (32, 64, 128, 256)]],
        'w40': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (40, 80)],
                [4, 3, 'BASIC', (4, 4, 4), (40, 80, 160)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (40, 80, 160, 320)]],
        'w44': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (44, 88)],
                [4, 3, 'BASIC', (4, 4, 4), (44, 88, 176)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (44, 88, 176, 352)]],
        'w48': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (48, 96)],
                [4, 3, 'BASIC', (4, 4, 4), (48, 96, 192)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (48, 96, 192, 384)]],
        'w64': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (64, 128)],
                [4, 3, 'BASIC', (4, 4, 4), (64, 128, 256)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (64, 128, 256, 512)]],
    }  # yapf:disable

    def __init__(self,
                 arch='w32',
                 extra=None,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 norm_eval=False,
                 with_cp=False,
                 zero_init_residual=False,
                 multi_scale_output=False):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        extra = self.parse_arch(arch, extra)

        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        # stem net: two stride-2 3x3 convs with 64 output channels each
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)

        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            self.conv_cfg,
            64,
            64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)

        self.add_module(self.norm2_name, norm2)
        self.relu = nn.ReLU(inplace=True)

        self.upsample_cfg = self.extra.get('upsample', {
            'mode': 'nearest',
            'align_corners': None
        })

        # stage 1
        self.stage1_cfg = self.extra['stage1']
        num_channels = self.stage1_cfg['num_channels'][0]
        block_type = self.stage1_cfg['block']
        num_blocks = self.stage1_cfg['num_blocks'][0]

        block = self.blocks_dict[block_type]
        stage1_out_channels = num_channels * get_expansion(block)
        self.layer1 = self._make_layer(block, 64, stage1_out_channels,
                                       num_blocks)

        # stage 2
        self.stage2_cfg = self.extra['stage2']
        num_channels = self.stage2_cfg['num_channels']
        block_type = self.stage2_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [
            channel * get_expansion(block) for channel in num_channels
        ]
        self.transition1 = self._make_transition_layer([stage1_out_channels],
                                                       num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        # stage 3
        self.stage3_cfg = self.extra['stage3']
        num_channels = self.stage3_cfg['num_channels']
        block_type = self.stage3_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [
            channel * get_expansion(block) for channel in num_channels
        ]
        self.transition2 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        # stage 4
        self.stage4_cfg = self.extra['stage4']
        num_channels = self.stage4_cfg['num_channels']
        block_type = self.stage4_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [
            channel * get_expansion(block) for channel in num_channels
        ]
        self.transition3 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)

        # A per-stage ``multiscale_output`` entry overrides the constructor
        # argument.
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg,
            num_channels,
            multiscale_output=self.stage4_cfg.get('multiscale_output',
                                                  multi_scale_output))

        # Formatting (rather than ``+``) keeps this lookup from raising when
        # ``arch`` is not a string, e.g. ``arch=None`` together with ``extra``.
        self.default_pretrained_model_path = model_urls.get(
            f'{self.__class__.__name__}{arch}', None)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: the normalization layer named "norm2" """
        return getattr(self, self.norm2_name)

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Make transition layer.

        For a branch that already exists in the previous stage the entry is
        ``None`` when the channel count is unchanged, otherwise a stride-1
        3x3 conv block. Every new branch gets a chain of stride-2 3x3 conv
        blocks fed by the last branch of the previous stage.
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_cur_layer[i])[1],
                            nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    # Only the last conv of the chain changes the width.
                    out_channels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else in_channels
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, in_channels, out_channels, blocks, stride=1):
        """Make layer: a sequence of ``blocks`` residual blocks."""
        # The first block needs a projected shortcut when it changes the
        # resolution or the channel count.
        downsample = None
        if stride != 1 or in_channels != out_channels:
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, out_channels)[1])

        layers = [
            block(
                in_channels,
                out_channels,
                stride=stride,
                downsample=downsample,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg)
        ]
        for _ in range(1, blocks):
            layers.append(
                block(
                    out_channels,
                    out_channels,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, in_channels, multiscale_output=True):
        """Make stage.

        Returns:
            tuple: ``(nn.Sequential of HRModule, in_channels)`` where
            ``in_channels`` is the ``in_channels`` attribute of the last
            module built.
        """
        num_modules = layer_config['num_modules']
        num_branches = layer_config['num_branches']
        num_blocks = layer_config['num_blocks']
        num_channels = layer_config['num_channels']
        block = self.blocks_dict[layer_config['block']]

        hr_modules = []
        for i in range(num_modules):
            # multi_scale_output is only used for the last module
            is_last = i == num_modules - 1
            reset_multiscale_output = bool(multiscale_output) or not is_last

            hr_modules.append(
                HRModule(
                    num_branches,
                    block,
                    num_blocks,
                    in_channels,
                    num_channels,
                    reset_multiscale_output,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                    upsample_cfg=self.upsample_cfg))

            in_channels = hr_modules[-1].in_channels

        return nn.Sequential(*hr_modules), in_channels

    def init_weights(self):
        """Initialize conv and norm layers.

        Conv layers get a normal init (std=0.001) and norm layers a
        constant 1. With ``zero_init_residual`` the last norm layer of
        every residual block is then reset to 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)

        # Second pass on purpose: it must override the constant-1 init above.
        if self.zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    constant_init(m.norm3, 0)
                elif isinstance(m, BasicBlock):
                    constant_init(m.norm2, 0)

    @staticmethod
    def _apply_transition(transitions, features, num_branches):
        """Build the input list of a stage from the previous outputs.

        A ``None`` transition passes ``features[i]`` through. A transition
        on a branch that already exists is applied to that same branch,
        because it was built for that branch's channel count. Transitions
        that create a new branch are applied to the last existing feature.
        """
        num_existing = len(features)
        outputs = []
        for i in range(num_branches):
            transition = transitions[i]
            if transition is None:
                outputs.append(features[i])
            elif i < num_existing:
                outputs.append(transition(features[i]))
            else:
                outputs.append(transition(features[-1]))
        return outputs

    def forward(self, x):
        """Forward function.

        Returns:
            list: The outputs of stage 4.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = self._apply_transition(self.transition1, [x],
                                        self.stage2_cfg['num_branches'])
        y_list = self.stage2(x_list)

        x_list = self._apply_transition(self.transition2, y_list,
                                        self.stage3_cfg['num_branches'])
        y_list = self.stage3(x_list)

        x_list = self._apply_transition(self.transition3, y_list,
                                        self.stage4_cfg['num_branches'])
        y_list = self.stage4(x_list)

        return y_list

    def train(self, mode=True):
        """Convert the model into training mode.

        With ``norm_eval`` set, batch-norm layers are switched back to eval
        mode after entering training mode.
        """
        super().train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

    def parse_arch(self, arch, extra=None):
        """Return the stage configuration for this network.

        Args:
            arch (str): Key into ``arch_zoo``; only used when ``extra`` is
                None.
            extra (dict | None): Explicit stage configuration; returned
                unchanged when given.

        Returns:
            dict: Maps ``'stage1'`` .. ``'stage4'`` to dicts with the keys
            ``num_modules``, ``num_branches``, ``block``, ``num_blocks``
            and ``num_channels``.
        """
        if extra is not None:
            return extra

        assert arch in self.arch_zoo, \
            ('Invalid arch, please choose arch from '
             f'{list(self.arch_zoo.keys())}, or specify `extra` '
             'argument directly.')

        extra = dict()
        for i, stage_setting in enumerate(self.arch_zoo[arch], start=1):
            extra[f'stage{i}'] = dict(
                num_modules=stage_setting[0],
                num_branches=stage_setting[1],
                block=stage_setting[2],
                num_blocks=stage_setting[3],
                num_channels=stage_setting[4],
            )

        return extra