# Source code for easycv.models.backbones.lighthrnet

# Copyright (c) OpenMMLab. All rights reserved.
# Adapt from https://github.com/open-mmlab/mmpose/blob/master/mmpose/models/backbones/litehrnet.py

import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
                      build_conv_layer, build_norm_layer, constant_init,
                      normal_init)
from torch.nn.modules.batchnorm import _BatchNorm

from easycv.framework.errors import ValueError
from easycv.models.registry import BACKBONES


def channel_shuffle(x, groups):
    """Channel Shuffle operation.

    Interleaves channels across groups so that information can flow
    between the groups of a grouped convolution.

    Args:
        x (Tensor): The input tensor, shape (N, C, H, W).
        groups (int): The number of groups to divide the input tensor
            in the channel dimension.

    Returns:
        Tensor: The output tensor after channel shuffle operation.
    """
    batch, num_channels, height, width = x.size()
    assert num_channels % groups == 0, ('num_channels should be '
                                        'divisible by groups')
    group_size = num_channels // groups

    # (N, C, H, W) -> (N, g, C/g, H, W), swap the group axes, flatten back.
    shuffled = x.view(batch, groups, group_size, height, width)
    shuffled = torch.transpose(shuffled, 1, 2).contiguous()
    return shuffled.view(batch, -1, height, width)
class SpatialWeighting(nn.Module):
    """Spatial weighting module (squeeze-and-excitation style).

    Globally pools the input, passes the pooled vector through two 1x1
    conv layers, and rescales the input channel-wise with the result.

    Args:
        channels (int): The channels of the module.
        ratio (int): channel reduction ratio.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: (dict(type='ReLU'), dict(type='Sigmoid')).
            The last ConvModule uses Sigmoid by default.
    """

    def __init__(self,
                 channels,
                 ratio=16,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
        super().__init__()
        # A single act_cfg dict is duplicated for both conv layers.
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)

        reduced_channels = int(channels / ratio)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=reduced_channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[0])
        self.conv2 = ConvModule(
            in_channels=reduced_channels,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x):
        # Squeeze spatially, excite channel-wise, then rescale the input.
        weight = self.global_avgpool(x)
        weight = self.conv1(weight)
        weight = self.conv2(weight)
        return x * weight
class CrossResolutionWeighting(nn.Module):
    """Cross-resolution channel weighting module.

    Pools every branch to the smallest branch's resolution, computes a
    joint channel attention over the concatenated features, then rescales
    each branch by its slice of the attention (upsampled back to the
    branch's own resolution).

    Args:
        channels (int): The channels of the module.
        ratio (int): channel reduction ratio.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: (dict(type='ReLU'), dict(type='Sigmoid')).
            The last ConvModule uses Sigmoid by default.
    """

    def __init__(self,
                 channels,
                 ratio=16,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
        super().__init__()
        # A single act_cfg dict is duplicated for both conv layers.
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)

        self.channels = channels
        total_channel = sum(channels)
        self.conv1 = ConvModule(
            in_channels=total_channel,
            out_channels=int(total_channel / ratio),
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[0])
        self.conv2 = ConvModule(
            in_channels=int(total_channel / ratio),
            out_channels=total_channel,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x):
        # All branches are pooled down to the resolution of the last
        # (smallest) branch before the joint attention is computed.
        mini_size = x[-1].size()[-2:]
        pooled = [F.adaptive_avg_pool2d(s, mini_size) for s in x[:-1]] + [x[-1]]
        weight = torch.cat(pooled, dim=1)
        weight = self.conv1(weight)
        weight = self.conv2(weight)
        # Split the attention back into per-branch slices and rescale each
        # branch at its native resolution.
        weights = torch.split(weight, self.channels, dim=1)
        return [
            s * F.interpolate(a, size=s.size()[-2:], mode='nearest')
            for s, a in zip(x, weights)
        ]
class ConditionalChannelWeighting(nn.Module):
    """Conditional channel weighting block.

    ShuffleNetV2-style block: each branch's channels are split in half;
    only the second half is transformed (cross-resolution weighting,
    depthwise conv, spatial weighting) before both halves are
    concatenated and shuffled.

    Args:
        in_channels (int): The input channels of the block.
        stride (int): Stride of the 3x3 convolution layer.
        reduce_ratio (int): channel reduction ratio.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """

    def __init__(self,
                 in_channels,
                 stride,
                 reduce_ratio,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.stride = stride
        assert stride in [1, 2]

        # Only half of each branch's channels goes through the weighted path.
        branch_channels = [channel // 2 for channel in in_channels]

        self.cross_resolution_weighting = CrossResolutionWeighting(
            branch_channels,
            ratio=reduce_ratio,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)
        self.depthwise_convs = nn.ModuleList([
            ConvModule(
                channel,
                channel,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                groups=channel,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None) for channel in branch_channels
        ])
        self.spatial_weighting = nn.ModuleList([
            SpatialWeighting(channels=channel, ratio=4)
            for channel in branch_channels
        ])

    def forward(self, x):
        # x is a list of per-branch feature maps (one per resolution).

        def _inner_forward(x):
            # Split every branch's channels into an identity half (x1)
            # and a transformed half (x2).
            x = [s.chunk(2, dim=1) for s in x]
            x1 = [s[0] for s in x]
            x2 = [s[1] for s in x]

            x2 = self.cross_resolution_weighting(x2)
            x2 = [dw(s) for s, dw in zip(x2, self.depthwise_convs)]
            x2 = [sw(s) for s, sw in zip(x2, self.spatial_weighting)]

            out = [torch.cat([s1, s2], dim=1) for s1, s2 in zip(x1, x2)]
            # Shuffle so information mixes between the two halves.
            out = [channel_shuffle(s, 2) for s in out]

            return out

        # NOTE(review): x is a list of tensors here, so `x.requires_grad`
        # would raise AttributeError when with_cp=True — confirm that the
        # checkpointing path is actually exercised anywhere.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
class Stem(nn.Module):
    """Stem network block.

    A strided 3x3 conv followed by a ShuffleNetV2-style split: one half
    goes through a downsampling branch (branch1), the other half through
    an inverted-residual expand/depthwise/linear path, after which the
    halves are concatenated and channel-shuffled. Overall the stem
    downsamples the input by a factor of 4.

    Args:
        in_channels (int): The input channels of the block.
        stem_channels (int): Output channels of the stem layer.
        out_channels (int): The output channels of the block.
        expand_ratio (int): adjusts number of channels of the hidden layer
            in InvertedResidual by this amount.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """

    def __init__(self,
                 in_channels,
                 stem_channels,
                 out_channels,
                 expand_ratio,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 with_cp=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp

        self.conv1 = ConvModule(
            in_channels=in_channels,
            out_channels=stem_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='ReLU'))

        mid_channels = int(round(stem_channels * expand_ratio))
        branch_channels = stem_channels // 2
        # Channel bookkeeping so that branch1 + linear_conv outputs
        # concatenate to exactly out_channels.
        if stem_channels == self.out_channels:
            inc_channels = self.out_channels - branch_channels
        else:
            inc_channels = self.out_channels - stem_channels

        self.branch1 = nn.Sequential(
            ConvModule(
                branch_channels,
                branch_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=branch_channels,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None),
            ConvModule(
                branch_channels,
                inc_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU')),
        )

        self.expand_conv = ConvModule(
            branch_channels,
            mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))
        self.depthwise_conv = ConvModule(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=mid_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.linear_conv = ConvModule(
            mid_channels,
            branch_channels
            if stem_channels == self.out_channels else stem_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))

    def forward(self, x):

        def _inner_forward(x):
            x = self.conv1(x)
            # Split channels: x1 -> branch1 (downsample), x2 -> inverted
            # residual path; both end at half the spatial resolution.
            x1, x2 = x.chunk(2, dim=1)

            x2 = self.expand_conv(x2)
            x2 = self.depthwise_conv(x2)
            x2 = self.linear_conv(x2)

            out = torch.cat((self.branch1(x1), x2), dim=1)

            out = channel_shuffle(out, 2)

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
class IterativeHead(nn.Module):
    """Extra iterative head for feature learning.

    Refines the multi-resolution features from low resolution to high:
    each branch's feature is summed with the (upsampled) refined feature
    of the next-lower resolution before being projected.

    Args:
        in_channels (int): The input channels of the block.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
    """

    def __init__(self, in_channels, norm_cfg=dict(type='BN')):
        super().__init__()
        projects = []
        num_branchs = len(in_channels)
        # Process branches from lowest resolution to highest.
        self.in_channels = in_channels[::-1]

        for i in range(num_branchs):
            if i != num_branchs - 1:
                # Project down to the next branch's channel count so the
                # refined feature can be added to it.
                projects.append(
                    DepthwiseSeparableConvModule(
                        in_channels=self.in_channels[i],
                        out_channels=self.in_channels[i + 1],
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        norm_cfg=norm_cfg,
                        act_cfg=dict(type='ReLU'),
                        dw_act_cfg=None,
                        pw_act_cfg=dict(type='ReLU')))
            else:
                projects.append(
                    DepthwiseSeparableConvModule(
                        in_channels=self.in_channels[i],
                        out_channels=self.in_channels[i],
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        norm_cfg=norm_cfg,
                        act_cfg=dict(type='ReLU'),
                        dw_act_cfg=None,
                        pw_act_cfg=dict(type='ReLU')))
        self.projects = nn.ModuleList(projects)

    def forward(self, x):
        # Iterate from the lowest-resolution branch upwards, carrying the
        # previously refined feature along.
        x = x[::-1]

        y = []
        last_x = None
        for i, s in enumerate(x):
            if last_x is not None:
                last_x = F.interpolate(
                    last_x,
                    size=s.size()[-2:],
                    mode='bilinear',
                    align_corners=True)
                s = s + last_x
            s = self.projects[i](s)
            y.append(s)
            last_x = s

        # Restore the original high-to-low resolution ordering.
        return y[::-1]
class ShuffleUnit(nn.Module):
    """InvertedResidual block for ShuffleNetV2 backbone.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): Stride of the 3x3 convolution layer. Default: 1
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        super().__init__()
        self.stride = stride
        self.with_cp = with_cp

        branch_features = out_channels // 2
        # With stride 1 the input is split in half, so in_channels must
        # match the two branch halves exactly.
        if self.stride == 1:
            assert in_channels == branch_features * 2, (
                f'in_channels ({in_channels}) should equal to '
                f'branch_features * 2 ({branch_features * 2}) '
                'when stride is 1')

        if in_channels != branch_features * 2:
            assert self.stride != 1, (
                f'stride ({self.stride}) should not equal 1 when '
                f'in_channels != branch_features * 2')

        # branch1 only exists in the downsampling (stride > 1) variant;
        # in that case both branches see the full input.
        if self.stride > 1:
            self.branch1 = nn.Sequential(
                ConvModule(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=self.stride,
                    padding=1,
                    groups=in_channels,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=None),
                ConvModule(
                    in_channels,
                    branch_features,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg),
            )

        self.branch2 = nn.Sequential(
            ConvModule(
                in_channels if (self.stride > 1) else branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                branch_features,
                branch_features,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                groups=branch_features,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None),
            ConvModule(
                branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))

    def forward(self, x):

        def _inner_forward(x):
            if self.stride > 1:
                # Downsampling: both branches transform the full input.
                out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
            else:
                # Stride 1: identity half + transformed half.
                x1, x2 = x.chunk(2, dim=1)
                out = torch.cat((x1, self.branch2(x2)), dim=1)

            out = channel_shuffle(out, 2)

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
class LiteHRModule(nn.Module):
    """High-Resolution Module for LiteHRNet.

    It contains conditional channel weighting blocks and shuffle blocks.

    Args:
        num_branches (int): Number of branches in the module.
        num_blocks (int): Number of blocks in the module.
        in_channels (list(int)): Number of input image channels.
        reduce_ratio (int): Channel reduction ratio.
        module_type (str): 'LITE' or 'NAIVE'
        multiscale_output (bool): Whether to output multi-scale features.
        with_fuse (bool): Whether to use fuse layers.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
    """

    def __init__(
        self,
        num_branches,
        num_blocks,
        in_channels,
        reduce_ratio,
        module_type,
        multiscale_output=False,
        with_fuse=True,
        conv_cfg=None,
        norm_cfg=dict(type='BN'),
        with_cp=False,
    ):
        super().__init__()
        self._check_branches(num_branches, in_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches

        self.module_type = module_type
        self.multiscale_output = multiscale_output
        self.with_fuse = with_fuse
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp

        # 'LITE' uses conditional channel weighting across all branches;
        # 'NAIVE' uses independent per-branch shuffle blocks.
        if self.module_type.upper() == 'LITE':
            self.layers = self._make_weighting_blocks(num_blocks, reduce_ratio)
        elif self.module_type.upper() == 'NAIVE':
            self.layers = self._make_naive_branches(num_branches, num_blocks)
        else:
            raise ValueError("module_type should be either 'LITE' or 'NAIVE'.")
        if self.with_fuse:
            self.fuse_layers = self._make_fuse_layers()
            self.relu = nn.ReLU()

    def _check_branches(self, num_branches, in_channels):
        """Check input to avoid ValueError."""
        if num_branches != len(in_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                f'!= NUM_INCHANNELS({len(in_channels)})'
            raise ValueError(error_msg)

    def _make_weighting_blocks(self, num_blocks, reduce_ratio, stride=1):
        """Make channel weighting blocks."""
        layers = []
        for i in range(num_blocks):
            layers.append(
                ConditionalChannelWeighting(
                    self.in_channels,
                    stride=stride,
                    reduce_ratio=reduce_ratio,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    with_cp=self.with_cp))

        return nn.Sequential(*layers)

    def _make_one_branch(self, branch_index, num_blocks, stride=1):
        """Make one branch."""
        layers = []
        # Only the first block may downsample (via `stride`).
        layers.append(
            ShuffleUnit(
                self.in_channels[branch_index],
                self.in_channels[branch_index],
                stride=stride,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type='ReLU'),
                with_cp=self.with_cp))
        for i in range(1, num_blocks):
            layers.append(
                ShuffleUnit(
                    self.in_channels[branch_index],
                    self.in_channels[branch_index],
                    stride=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=dict(type='ReLU'),
                    with_cp=self.with_cp))

        return nn.Sequential(*layers)

    def _make_naive_branches(self, num_branches, num_blocks):
        """Make branches."""
        branches = []
        for i in range(num_branches):
            branches.append(self._make_one_branch(i, num_blocks))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Make fuse layer."""
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        fuse_layers = []
        num_out_branches = num_branches if self.multiscale_output else 1
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Lower-resolution source: 1x1 conv to match channels,
                    # then nearest-neighbor upsample by 2**(j - i).
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels[j],
                                in_channels[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             in_channels[i])[1],
                            nn.Upsample(
                                scale_factor=2**(j - i), mode='nearest')))
                elif j == i:
                    # Same resolution: identity (handled in forward).
                    fuse_layer.append(None)
                else:
                    # Higher-resolution source: chain of stride-2
                    # depthwise+pointwise downsamples; only the last one
                    # changes the channel count (and has no ReLU).
                    conv_downsamples = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        groups=in_channels[j],
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[i],
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[i])[1]))
                        else:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        groups=in_channels[j],
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    nn.ReLU(inplace=True)))
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        """Forward function."""
        if self.num_branches == 1:
            return [self.layers[0](x[0])]

        if self.module_type.upper() == 'LITE':
            out = self.layers(x)
        elif self.module_type.upper() == 'NAIVE':
            for i in range(self.num_branches):
                x[i] = self.layers[i](x[i])
            out = x

        if self.with_fuse:
            out_fuse = []
            for i in range(len(self.fuse_layers)):
                # `y = 0` will lead to decreased accuracy (0.5~1 mAP)
                y = out[0] if i == 0 else self.fuse_layers[i][0](out[0])
                for j in range(self.num_branches):
                    if i == j:
                        y += out[j]
                    else:
                        y += self.fuse_layers[i][j](out[j])
                out_fuse.append(self.relu(y))
            out = out_fuse
        if not self.multiscale_output:
            out = [out[0]]
        return out
@BACKBONES.register_module()
class LiteHRNet(nn.Module):
    """Lite-HRNet backbone.

    `Lite-HRNet: A Lightweight High-Resolution Network
    <https://arxiv.org/abs/2104.06403>`__

    Code adapted from 'https://github.com/HRNet/Lite-HRNet/'
    'blob/hrnet/models/backbones/litehrnet.py'

    Args:
        extra (dict): detailed configuration for each stage of HRNet.
        in_channels (int): Number of input image channels. Default: 3.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.

    Example:
        >>> from mmpose.models import LiteHRNet
        >>> import torch
        >>> extra=dict(
        >>>    stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
        >>>    num_stages=3,
        >>>    stages_spec=dict(
        >>>        num_modules=(2, 4, 2),
        >>>        num_branches=(2, 3, 4),
        >>>        num_blocks=(2, 2, 2),
        >>>        module_type=('LITE', 'LITE', 'LITE'),
        >>>        with_fuse=(True, True, True),
        >>>        reduce_ratios=(8, 8, 8),
        >>>        num_channels=(
        >>>            (40, 80),
        >>>            (40, 80, 160),
        >>>            (40, 80, 160, 320),
        >>>        )),
        >>>    with_head=False)
        >>> self = LiteHRNet(extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 40, 8, 8)
    """

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 norm_eval=False,
                 with_cp=False):
        super().__init__()
        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        self.stem = Stem(
            in_channels,
            stem_channels=self.extra['stem']['stem_channels'],
            out_channels=self.extra['stem']['out_channels'],
            expand_ratio=self.extra['stem']['expand_ratio'],
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)

        self.num_stages = self.extra['num_stages']
        self.stages_spec = self.extra['stages_spec']

        num_channels_last = [
            self.stem.out_channels,
        ]
        for i in range(self.num_stages):
            num_channels = self.stages_spec['num_channels'][i]
            # Copy the (possibly tuple) channel spec into a plain list.
            num_channels = [num_channels[i] for i in range(len(num_channels))]
            setattr(
                self, f'transition{i}',
                self._make_transition_layer(num_channels_last, num_channels))

            stage, num_channels_last = self._make_stage(
                self.stages_spec, i, num_channels, multiscale_output=True)
            setattr(self, f'stage{i}', stage)

        self.with_head = self.extra['with_head']
        if self.with_head:
            self.head_layer = IterativeHead(
                in_channels=num_channels_last,
                norm_cfg=self.norm_cfg,
            )

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Make transition layer."""
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                # Existing branch: only add a conv if channels change.
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_pre_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                groups=num_channels_pre_layer[i],
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_pre_layer[i])[1],
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_cur_layer[i])[1],
                            nn.ReLU()))
                else:
                    transition_layers.append(None)
            else:
                # New (lower-resolution) branch: downsample from the last
                # previous branch, one stride-2 step per resolution gap.
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    out_channels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else in_channels
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                in_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                groups=in_channels,
                                bias=False),
                            build_norm_layer(self.norm_cfg, in_channels)[1],
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU()))
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)

    def _make_stage(self,
                    stages_spec,
                    stage_index,
                    in_channels,
                    multiscale_output=True):
        """Build one stage as a sequence of LiteHRModules.

        Returns:
            tuple: (nn.Sequential of modules, output channel list).
        """
        num_modules = stages_spec['num_modules'][stage_index]
        num_branches = stages_spec['num_branches'][stage_index]
        num_blocks = stages_spec['num_blocks'][stage_index]
        reduce_ratio = stages_spec['reduce_ratios'][stage_index]
        with_fuse = stages_spec['with_fuse'][stage_index]
        module_type = stages_spec['module_type'][stage_index]

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used last module
            if not multiscale_output and i == num_modules - 1:
                reset_multiscale_output = False
            else:
                reset_multiscale_output = True

            modules.append(
                LiteHRModule(
                    num_branches,
                    num_blocks,
                    in_channels,
                    reduce_ratio,
                    module_type,
                    multiscale_output=reset_multiscale_output,
                    with_fuse=with_fuse,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    with_cp=self.with_cp))
            in_channels = modules[-1].in_channels

        return nn.Sequential(*modules), in_channels

    def init_weights(self):
        """Initialize the weights in backbone.

        Conv weights are drawn from a normal distribution (std=0.001);
        norm-layer weights are set to 1.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)

    def forward(self, x):
        """Forward function."""
        x = self.stem(x)

        y_list = [x]
        for i in range(self.num_stages):
            x_list = []
            transition = getattr(self, f'transition{i}')
            for j in range(self.stages_spec['num_branches'][i]):
                if transition[j]:
                    # Branches beyond the previous stage's count are
                    # created by downsampling the last existing branch.
                    if j >= len(y_list):
                        x_list.append(transition[j](y_list[-1]))
                    else:
                        x_list.append(transition[j](y_list[j]))
                else:
                    x_list.append(y_list[j])
            y_list = getattr(self, f'stage{i}')(x_list)

        x = y_list
        if self.with_head:
            x = self.head_layer(x)

        # Only the highest-resolution feature map is returned.
        return [x[0]]

    def train(self, mode=True):
        """Convert the model into training mode."""
        super().train(mode)
        if mode and self.norm_eval:
            # Keep BatchNorm running stats frozen during training.
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()