# Source code for easycv.models.backbones.genet

# Copyright (c) Alibaba, Inc. and its affiliates.
import uuid

import numpy as np
import torch
import torch.nn.functional as F
from mmcv.cnn import constant_init, kaiming_init
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

from ..modelzoo import genet as model_urls
from ..registry import BACKBONES

# Serialized PlainNet structure strings for the three GENet variants.
# Each token encodes one block as BlockType(block_name|param,param,...), e.g.
# ConvKX(<name>|in_ch,out_ch,kernel,stride) or
# SuperRes*(<name>|in_ch,out_ch,kernel,stride,expansion,sublayers).
# The trailing backslashes are line continuations inside the string literal.
GENET_LARGE = 'ConvKX(uuid9d1dca0f098143aaa1a947acf1100787|3,32,3,2)\
BN(uuid7d10ba10dc524ffb8863ae97c4a21797|32)RELU(uuidccd810d3d10a48158ccfa48ca975915c|32)\
SuperResKXKX(uuid5ba1db21fce64b16a34ad577c258fd6c|32,128,3,2,1.0,1)\
SuperResKXKX(uuida09fc4e4946444bf9b912f8c666c4b12|128,192,3,2,1.0,2)\
SuperResK1KXK1(uuidfa45c5f5cc96435dbd54801f31c83ca8|192,640,3,2,0.25,6)\
SuperResK1DWK1(uuid99bf6442b33643579dc680045da7549d|640,640,3,2,3.0,5)\
SuperResK1DWK1(uuid615cbfd4ed284cbc8589d84cbe9b0e92|640,640,3,1,3.0,4)\
ConvKX(uuid002fa25f74f14cdeb89a5aacd6ce64ff|640,2560,1,1)\
BN(uuidc5d6c88c326343efa2a8700907f87732|2560)RELU(uuidd2b39caab4cb4ac2b6905b18858c0037|2560)AdaptiveAvgPool(2560,1)'

GENET_NORMAL = 'ConvKX(uuid70de938099844017bd745349f7a1d35a|3,32,3,2)\
BN(uuid10f8a99f83294067bfdf5fc5a5c9bffd|32)\
RELU(uuideffe03bd73254e7c8027364ba71d25cd|32)\
SuperResKXKX(uuidb023bea8c7b34c22a1650e07dfc8e2c1|32,128,3,2,1.0,1)\
SuperResKXKX(uuidf829740023044b879eefaf7fc7d1ad8e|128,192,3,2,1.0,2)\
SuperResK1KXK1(uuid33bfe77cb8864357a840ca3341ea629a|192,640,3,2,0.25,6)\
SuperResK1DWK1(uuide2c948d819fb4869980e30d67a773244|640,640,3,2,3.0,4)\
SuperResK1DWK1(uuid53c308e481c24154b7a81fcbaf99edbf|640,640,3,1,3.0,1)\
ConvKX(uuidbc6953bfd8de45fc8534787a66b96430|640,2560,1,1)\
BN(uuida8acaaae74ed47a4a7514b41c643eb23|2560)RELU(uuida5d71c4fd5d24a7b848472f0383df467|2560)AdaptiveAvgPool(2560,1)'

GENET_SMALL = 'ConvKX(uuid46ff2328b77f40ff88aed69a5318d771|3,13,3,2)\
BN(uuid43b72f65311c42d9a1af485c594a6ab4|13)RELU(uuid282901aaa7f84b028e3c5bd7d37ae056|13)\
SuperResKXKX(uuiddb56d6f9a60b4455966e13b06a8ff723|13,48,3,2,1.0,1)\
SuperResKXKX(uuidd964406e6fdf4e9abac225afaeb1fe0b|48,48,3,2,1.0,3)\
SuperResK1KXK1(uuid39819ad4f4da405583de614af437b568|48,384,3,2,0.25,7)\
SuperResK1DWK1(uuid420593fe7b1e46f690b76bac3786d4b7|384,560,3,2,3.0,2)\
SuperResK1DWK1(uuid96236b3c50774f1ab2d3049d6aca6d85|560,256,3,1,3.0,1)\
ConvKX(uuid89ed263767a14f21b7426cccb120ad1d|256,1920,1,1)\
BN(uuidd6ad568b290544be9f4b47dc3fa271c9|1920)RELU(uuid823ced7441394fb9b3a96a5f7c40da2b|1920)AdaptiveAvgPool(1920,1)'

# Lookup from model-size keyword to its serialized structure string.
plainnet_struct_dict = {
    'normal': GENET_NORMAL,
    'large': GENET_LARGE,
    'small': GENET_SMALL
}

# ------------ Fuse BN ------


def _fuse_convkx_and_bn_(convkx, bn):
    the_weight_scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    convkx.weight[:] = convkx.weight * the_weight_scale.view((-1, 1, 1, 1))
    the_bias_shift = (bn.weight * bn.running_mean) / \
        torch.sqrt(bn.running_var + bn.eps)
    bn.weight[:] = 1
    bn.bias[:] = bn.bias - the_bias_shift
    bn.running_var[:] = 1.0 - bn.eps
    bn.running_mean[:] = 0.0
    convkx.bias = nn.Parameter(bn.bias)


def _fuse_bn_in_seq_list_(seq_module_list):
    """Fold every BatchNorm2d in each nn.Sequential into the layer right
    before it and return a new nn.ModuleList without the BN layers."""
    fused_list = []
    for the_seq_list in seq_module_list:
        assert isinstance(the_seq_list, nn.Sequential)
        new_seq_list = []
        last_block = None
        for block in the_seq_list:
            if isinstance(block, nn.BatchNorm2d):
                # BN directly follows the layer it normalizes.
                _fuse_convkx_and_bn_(last_block, block)
            else:
                new_seq_list.append(block)
                last_block = block
        fused_list.append(nn.Sequential(*new_seq_list))
    return nn.ModuleList(fused_list)


def remove_bn_in_superblock(super_block):
    """Remove all BatchNorm2d layers inside a SuperRes* block by fusing them
    into the preceding convolutions.

    Both the shortcut branches and the main conv branches are processed; the
    originally duplicated loops are factored into ``_fuse_bn_in_seq_list_``.

    Args:
        super_block: block exposing ``shortcut_list`` and ``conv_list``
            (ModuleLists of nn.Sequential); modified in place.
    """
    super_block.shortcut_list = _fuse_bn_in_seq_list_(super_block.shortcut_list)
    super_block.conv_list = _fuse_bn_in_seq_list_(super_block.conv_list)
def fuse_bn(model):
    """Fuse all BatchNorm layers of a PlainNet-style ``model`` into their
    preceding convolutions and drop them from the block list.

    Args:
        model: network exposing ``block_list`` (list of PlainNet blocks);
            modified in place and also returned.
    """
    the_block_list = model.block_list
    last_block = the_block_list[0]
    new_block_list = [last_block]
    # First pass: fold every standalone BN block into the block before it.
    for the_block in the_block_list[1:]:
        if isinstance(the_block, BN):
            _fuse_convkx_and_bn_(last_block.netblock, the_block.netblock)
        else:
            new_block_list.append(the_block)
            last_block = the_block
        pass
    the_block_list = new_block_list

    # Second pass: fuse the BNs hidden inside SuperRes* blocks.
    for the_block in the_block_list:
        if hasattr(the_block, 'shortcut_list'):
            remove_bn_in_superblock(the_block)
        else:
            continue

    model.block_list = new_block_list
    model.module_list = nn.ModuleList(new_block_list)
    return model
# ------------ end of fuse bn -------- def _create_netblock_list_from_str_(s, no_create=False): block_list = [] while len(s) > 0: is_found_block_class = False for the_block_class_name in _all_netblocks_dict_.keys(): if s.startswith(the_block_class_name): is_found_block_class = True the_block_class = _all_netblocks_dict_[the_block_class_name] the_block, remaining_s = the_block_class.create_from_str( s, no_create=no_create) if the_block is not None: block_list.append(the_block) s = remaining_s if len(s) > 0 and s[0] == ';': return block_list, s[1:] break pass # end if pass # end for assert is_found_block_class pass # end while return block_list, '' def _get_right_parentheses_index_(s): # assert s[0] == '(' left_paren_count = 0 for index, x in enumerate(s): if x == '(': left_paren_count += 1 elif x == ')': left_paren_count -= 1 if left_paren_count == 0: return index else: pass return None ''' -------------------- GENet Blocks -------------------- '''
class PlainNetBasicBlockClass(nn.Module):
    """Base class of all PlainNet blocks.

    Behaves as an identity mapping while recording channel counts, stride
    and a unique block name.  Subclasses override ``forward`` and the two
    string-parsing helpers.
    """

    def __init__(self,
                 in_channels=0,
                 out_channels=0,
                 stride=1,
                 no_create=False,
                 block_name=None,
                 **kwargs):
        super(PlainNetBasicBlockClass, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.no_create = no_create
        self.block_name = block_name

    def forward(self, x):
        # The base block is a no-op.
        return x

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse 'PlainNetBasicBlockClass(name|in,out,stride)...' and return
        (block, remaining_string)."""
        assert PlainNetBasicBlockClass.is_instance_from_str(s)
        close_idx = _get_right_parentheses_index_(s)
        assert close_idx is not None
        inner = s[len('PlainNetBasicBlockClass('):close_idx]

        # Optional 'name|' prefix; generate a fresh uuid name when absent.
        sep = inner.find('|')
        if sep < 0:
            name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            name = inner[0:sep]
            inner = inner[sep + 1:]

        fields = inner.split(',')
        block = PlainNetBasicBlockClass(
            in_channels=int(fields[0]),
            out_channels=int(fields[1]),
            stride=int(fields[2]),
            block_name=name,
            no_create=no_create)
        return block, s[close_idx + 1:]

    @staticmethod
    def is_instance_from_str(s):
        return s.startswith('PlainNetBasicBlockClass(') and s[-1] == ')'
class AdaptiveAvgPool(PlainNetBasicBlockClass):
    """Adaptive average pooling to a fixed (output_size x output_size) map."""

    def __init__(self,
                 out_channels,
                 output_size,
                 no_create=False,
                 block_name=None,
                 **kwargs):
        super(AdaptiveAvgPool, self).__init__(**kwargs)
        self.in_channels = out_channels
        # Bookkeeping: flattened feature count after pooling.
        self.out_channels = out_channels * output_size**2
        self.output_size = output_size
        self.block_name = block_name
        if not no_create:
            self.netblock = nn.AdaptiveAvgPool2d(
                output_size=(self.output_size, self.output_size))

    def forward(self, x):
        return self.netblock(x)

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse 'AdaptiveAvgPool(name|channels,output_size)...'."""
        assert AdaptiveAvgPool.is_instance_from_str(s)
        close_idx = _get_right_parentheses_index_(s)
        assert close_idx is not None
        inner = s[len('AdaptiveAvgPool('):close_idx]

        sep = inner.find('|')
        if sep < 0:
            name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            name = inner[0:sep]
            inner = inner[sep + 1:]

        fields = inner.split(',')
        block = AdaptiveAvgPool(
            out_channels=int(fields[0]),
            output_size=int(fields[1]),
            block_name=name,
            no_create=no_create)
        return block, s[close_idx + 1:]

    @staticmethod
    def is_instance_from_str(s):
        return s.startswith('AdaptiveAvgPool(') and s[-1] == ')'
class BN(PlainNetBasicBlockClass):
    """Standalone BatchNorm2d block; channel count is unchanged."""

    def __init__(self,
                 out_channels=None,
                 copy_from=None,
                 no_create=False,
                 block_name=None,
                 **kwargs):
        super(BN, self).__init__(**kwargs)
        self.block_name = block_name
        if copy_from is not None:
            # Wrap an existing BatchNorm2d instead of creating a new one.
            assert isinstance(copy_from, nn.BatchNorm2d)
            self.in_channels = copy_from.weight.shape[0]
            self.out_channels = copy_from.weight.shape[0]
            assert out_channels is None or out_channels == self.out_channels
            self.netblock = copy_from
        else:
            self.in_channels = out_channels
            self.out_channels = out_channels
            if not no_create:
                self.netblock = nn.BatchNorm2d(num_features=self.out_channels)

    def forward(self, x):
        return self.netblock(x)

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse 'BN(name|channels)...'."""
        assert BN.is_instance_from_str(s)
        close_idx = _get_right_parentheses_index_(s)
        assert close_idx is not None
        inner = s[len('BN('):close_idx]

        sep = inner.find('|')
        if sep < 0:
            name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            name = inner[0:sep]
            inner = inner[sep + 1:]

        return BN(
            out_channels=int(inner),
            block_name=name,
            no_create=no_create), s[close_idx + 1:]

    @staticmethod
    def is_instance_from_str(s):
        return s.startswith('BN(') and s[-1] == ')'
class ConvDW(PlainNetBasicBlockClass):
    """Depthwise KxK convolution (groups == channels), bias-free."""

    def __init__(self,
                 out_channels=None,
                 kernel_size=None,
                 stride=None,
                 copy_from=None,
                 no_create=False,
                 block_name=None,
                 **kwargs):
        super(ConvDW, self).__init__(**kwargs)
        self.block_name = block_name
        self.use_weight_mean_zero_constrain = False
        if copy_from is not None:
            # Adopt an existing depthwise Conv2d; given args must agree.
            assert isinstance(copy_from, nn.Conv2d)
            self.in_channels = copy_from.in_channels
            self.out_channels = copy_from.out_channels
            self.kernel_size = copy_from.kernel_size[0]
            self.stride = copy_from.stride[0]
            assert self.in_channels == self.out_channels
            assert out_channels is None or out_channels == self.out_channels
            assert kernel_size is None or kernel_size == self.kernel_size
            assert stride is None or stride == self.stride
            self.netblock = copy_from
        else:
            self.in_channels = out_channels
            self.out_channels = out_channels
            self.stride = stride
            self.kernel_size = kernel_size
            self.padding = (self.kernel_size - 1) // 2
            # Degenerate dimensions mean "placeholder only": skip creation.
            skip_create = (no_create or self.in_channels == 0
                           or self.out_channels == 0 or self.kernel_size == 0
                           or self.stride == 0)
            if not skip_create:
                self.netblock = nn.Conv2d(
                    in_channels=self.in_channels,
                    out_channels=self.out_channels,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    padding=self.padding,
                    bias=False,
                    groups=self.in_channels)

    def forward(self, x):
        return self.netblock(x)

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse 'ConvDW(name|channels,kernel,stride)...'."""
        assert ConvDW.is_instance_from_str(s)
        close_idx = _get_right_parentheses_index_(s)
        assert close_idx is not None
        inner = s[len('ConvDW('):close_idx]

        sep = inner.find('|')
        if sep < 0:
            name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            name = inner[0:sep]
            inner = inner[sep + 1:]

        fields = inner.split(',')
        return ConvDW(
            out_channels=int(fields[0]),
            kernel_size=int(fields[1]),
            stride=int(fields[2]),
            no_create=no_create,
            block_name=name), s[close_idx + 1:]

    @staticmethod
    def is_instance_from_str(s):
        return s.startswith('ConvDW(') and s[-1] == ')'
class ConvKX(PlainNetBasicBlockClass):
    """Plain KxK convolution (no BN / activation), bias-free."""

    def __init__(self,
                 in_channels=None,
                 out_channels=None,
                 kernel_size=None,
                 stride=None,
                 copy_from=None,
                 no_create=False,
                 block_name=None,
                 **kwargs):
        super(ConvKX, self).__init__(**kwargs)
        self.block_name = block_name
        self.use_weight_mean_zero_constrain = False
        if copy_from is not None:
            # Adopt an existing Conv2d; given args must agree with it.
            assert isinstance(copy_from, nn.Conv2d)
            self.in_channels = copy_from.in_channels
            self.out_channels = copy_from.out_channels
            self.kernel_size = copy_from.kernel_size[0]
            self.stride = copy_from.stride[0]
            assert in_channels is None or in_channels == self.in_channels
            assert out_channels is None or out_channels == self.out_channels
            assert kernel_size is None or kernel_size == self.kernel_size
            assert stride is None or stride == self.stride
            self.netblock = copy_from
        else:
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.stride = stride
            self.kernel_size = kernel_size
            self.padding = (self.kernel_size - 1) // 2
            # Degenerate dimensions mean "placeholder only": skip creation.
            skip_create = (no_create or self.in_channels == 0
                           or self.out_channels == 0 or self.kernel_size == 0
                           or self.stride == 0)
            if not skip_create:
                self.netblock = nn.Conv2d(
                    in_channels=self.in_channels,
                    out_channels=self.out_channels,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    padding=self.padding,
                    bias=False)

    def forward(self, x):
        return self.netblock(x)

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse 'ConvKX(name|in,out,kernel,stride)...'."""
        assert ConvKX.is_instance_from_str(s)
        close_idx = _get_right_parentheses_index_(s)
        assert close_idx is not None
        inner = s[len('ConvKX('):close_idx]

        sep = inner.find('|')
        if sep < 0:
            name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            name = inner[0:sep]
            inner = inner[sep + 1:]

        fields = inner.split(',')
        return ConvKX(
            in_channels=int(fields[0]),
            out_channels=int(fields[1]),
            kernel_size=int(fields[2]),
            stride=int(fields[3]),
            no_create=no_create,
            block_name=name), s[close_idx + 1:]

    @staticmethod
    def is_instance_from_str(s):
        return s.startswith('ConvKX(') and s[-1] == ')'
class Flatten(PlainNetBasicBlockClass):
    """Flatten NCHW features to (N, -1); channel bookkeeping only."""

    def __init__(self, out_channels, no_create=False, block_name=None,
                 **kwargs):
        super(Flatten, self).__init__(**kwargs)
        self.block_name = block_name
        self.in_channels = out_channels
        self.out_channels = out_channels

    def forward(self, x):
        return torch.flatten(x, 1)

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse 'Flatten(name|channels)...'."""
        assert Flatten.is_instance_from_str(s)
        close_idx = _get_right_parentheses_index_(s)
        assert close_idx is not None
        inner = s[len('Flatten('):close_idx]

        sep = inner.find('|')
        if sep < 0:
            name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            name = inner[0:sep]
            inner = inner[sep + 1:]

        return Flatten(
            out_channels=int(inner),
            no_create=no_create,
            block_name=name), s[close_idx + 1:]

    @staticmethod
    def is_instance_from_str(s):
        return s.startswith('Flatten(') and s[-1] == ')'
class Linear(PlainNetBasicBlockClass):
    """Fully-connected block wrapping ``nn.Linear``."""

    def __init__(self,
                 in_channels=None,
                 out_channels=None,
                 bias=None,
                 copy_from=None,
                 no_create=False,
                 block_name=None,
                 **kwargs):
        super(Linear, self).__init__(**kwargs)
        self.block_name = block_name
        if copy_from is not None:
            assert isinstance(copy_from, nn.Linear)
            # Bug fix: nn.Linear exposes in_features/out_features, not
            # in_channels/out_channels -- the old attribute reads raised
            # AttributeError whenever copy_from was supplied.  The bias flag
            # is normalized to a bool to match the created-from-args branch.
            self.in_channels = copy_from.in_features
            self.out_channels = copy_from.out_features
            self.bias = copy_from.bias is not None
            assert in_channels is None or in_channels == self.in_channels
            assert out_channels is None or out_channels == self.out_channels
            assert bias is None or bias == self.bias
            self.netblock = copy_from
        else:
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.bias = bias
            if not no_create:
                self.netblock = nn.Linear(
                    self.in_channels, self.out_channels, bias=self.bias)

    def forward(self, x):
        return self.netblock(x)

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse 'Linear(name|in,out,bias)...' where bias is 0 or 1."""
        assert Linear.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len('Linear('):idx]

        # find block_name
        tmp_idx = param_str.find('|')
        if tmp_idx < 0:
            tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1:]

        split_str = param_str.split(',')
        in_channels = int(split_str[0])
        out_channels = int(split_str[1])
        bias = int(split_str[2])

        return Linear(
            in_channels=in_channels,
            out_channels=out_channels,
            bias=bias == 1,
            block_name=tmp_block_name,
            no_create=no_create), s[idx + 1:]

    @staticmethod
    def is_instance_from_str(s):
        if s.startswith('Linear(') and s[-1] == ')':
            return True
        else:
            return False
class MaxPool(PlainNetBasicBlockClass):
    """Max pooling with half-kernel padding; channels unchanged."""

    def __init__(self,
                 out_channels,
                 kernel_size,
                 stride,
                 no_create=False,
                 block_name=None,
                 **kwargs):
        super(MaxPool, self).__init__(**kwargs)
        self.block_name = block_name
        self.in_channels = out_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = (kernel_size - 1) // 2
        if not no_create:
            self.netblock = nn.MaxPool2d(
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding)

    def forward(self, x):
        return self.netblock(x)

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse 'MaxPool(name|channels,kernel,stride)...'."""
        assert MaxPool.is_instance_from_str(s)
        close_idx = _get_right_parentheses_index_(s)
        assert close_idx is not None
        inner = s[len('MaxPool('):close_idx]

        sep = inner.find('|')
        if sep < 0:
            name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            name = inner[0:sep]
            inner = inner[sep + 1:]

        fields = inner.split(',')
        return MaxPool(
            out_channels=int(fields[0]),
            kernel_size=int(fields[1]),
            stride=int(fields[2]),
            no_create=no_create,
            block_name=name), s[close_idx + 1:]

    @staticmethod
    def is_instance_from_str(s):
        return s.startswith('MaxPool(') and s[-1] == ')'
class MultiSumBlock(PlainNetBasicBlockClass):
    """Element-wise sum of several parallel inner branches."""

    def __init__(self, inner_block_list, no_create=False, block_name=None,
                 **kwargs):
        super(MultiSumBlock, self).__init__(**kwargs)
        self.block_name = block_name
        self.inner_block_list = inner_block_list
        if not no_create:
            self.inner_module_list = nn.ModuleList(inner_block_list)
        # Channel counts are the max over all branches.
        self.in_channels = np.max([x.in_channels for x in inner_block_list])
        self.out_channels = np.max([x.out_channels for x in inner_block_list])

        # Derive the overall stride by probing the first branch with a
        # reference resolution of 1024.
        res = 1024
        res = self.inner_block_list[0].get_output_resolution(res)
        self.stride = 1024 // res

    def forward(self, x):
        # Sum of all branch outputs.
        output = self.inner_block_list[0](x)
        for inner_block in self.inner_block_list[1:]:
            output2 = inner_block(x)
            output = output + output2
        return output

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse 'MultiSumBlock(name|branch1;branch2;...)...'."""
        assert MultiSumBlock.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len('MultiSumBlock('):idx]

        # find block_name
        tmp_idx = param_str.find('|')
        if tmp_idx < 0:
            tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1:]

        the_s = param_str
        the_inner_block_list = []
        # Each ';'-separated segment becomes one branch; multi-block
        # segments are wrapped in a Sequential.
        while len(the_s) > 0:
            tmp_block_list, remaining_s = _create_netblock_list_from_str_(
                the_s, no_create=no_create)
            the_s = remaining_s
            if tmp_block_list is None:
                pass
            elif len(tmp_block_list) == 1:
                the_inner_block_list.append(tmp_block_list[0])
            else:
                the_inner_block_list.append(
                    Sequential(
                        inner_block_list=tmp_block_list, no_create=no_create))
            pass  # end while

        if len(the_inner_block_list) == 0:
            return None, s[idx + 1:]

        return MultiSumBlock(
            inner_block_list=the_inner_block_list,
            block_name=tmp_block_name,
            no_create=no_create), s[idx + 1:]

    @staticmethod
    def is_instance_from_str(s):
        if s.startswith('MultiSumBlock(') and s[-1] == ')':
            return True
        else:
            return False
class RELU(PlainNetBasicBlockClass):
    """Parameter-free ReLU activation block."""

    def __init__(self, out_channels, no_create=False, block_name=None,
                 **kwargs):
        super(RELU, self).__init__(**kwargs)
        self.block_name = block_name
        self.in_channels = out_channels
        self.out_channels = out_channels

    def forward(self, x):
        return F.relu(x)

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse 'RELU(name|channels)...'."""
        assert RELU.is_instance_from_str(s)
        close_idx = _get_right_parentheses_index_(s)
        assert close_idx is not None
        inner = s[len('RELU('):close_idx]

        sep = inner.find('|')
        if sep < 0:
            name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            name = inner[0:sep]
            inner = inner[sep + 1:]

        return RELU(
            out_channels=int(inner),
            no_create=no_create,
            block_name=name), s[close_idx + 1:]

    @staticmethod
    def is_instance_from_str(s):
        return s.startswith('RELU(') and s[-1] == ')'
class ResBlock(PlainNetBasicBlockClass):
    '''
    ResBlock(in_channles, inner_blocks_str). If in_channels is missing, use
    inner_block_list[0].in_channels as in_channels.
    '''

    def __init__(self, inner_block_list, in_channels=None, stride=None,
                 no_create=False, block_name=None, **kwargs):
        super(ResBlock, self).__init__(**kwargs)
        self.block_name = block_name
        self.inner_block_list = inner_block_list
        self.stride = stride
        if not no_create:
            self.inner_module_list = nn.ModuleList(inner_block_list)

        if in_channels is None:
            self.in_channels = inner_block_list[0].in_channels
        else:
            self.in_channels = in_channels
        self.out_channels = max(self.in_channels,
                                inner_block_list[-1].out_channels)

        if self.stride is None:
            # Infer stride by probing with a reference resolution of 1024.
            tmp_input_res = 1024
            tmp_output_res = self.get_output_resolution(tmp_input_res)
            self.stride = tmp_input_res // tmp_output_res

    def forward(self, x):
        # Residual connection; when stride > 1 the identity path is
        # average-pooled so spatial sizes match before the addition.
        if self.stride > 1:
            downsampled_x = F.avg_pool2d(
                x,
                kernel_size=self.stride + 1,
                stride=self.stride,
                padding=self.stride // 2)
        else:
            downsampled_x = x

        if len(self.inner_block_list) == 0:
            return downsampled_x

        output = x
        for inner_block in self.inner_block_list:
            output = inner_block(output)
        output = output + downsampled_x
        return output

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse 'ResBlock(name|[in_channels,[stride,]]inner_blocks)...';
        in_channels and stride are optional leading integers."""
        assert ResBlock.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        the_stride = None
        param_str = s[len('ResBlock('):idx]
        # find block_name
        tmp_idx = param_str.find('|')
        if tmp_idx < 0:
            tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1:]

        first_comma_index = param_str.find(',')
        if first_comma_index < 0 or not param_str[0:first_comma_index].isdigit(
        ):  # cannot parse in_channels, missing, use default
            in_channels = None
            the_inner_block_list, remaining_s = _create_netblock_list_from_str_(
                param_str, no_create=no_create)
        else:
            in_channels = int(param_str[0:first_comma_index])
            param_str = param_str[first_comma_index + 1:]
            # An optional stride may precede the inner block string.
            second_comma_index = param_str.find(',')
            if second_comma_index < 0 or not param_str[
                    0:second_comma_index].isdigit():
                the_inner_block_list, remaining_s = _create_netblock_list_from_str_(
                    param_str, no_create=no_create)
            else:
                the_stride = int(param_str[0:second_comma_index])
                param_str = param_str[second_comma_index + 1:]
                the_inner_block_list, remaining_s = _create_netblock_list_from_str_(
                    param_str, no_create=no_create)
                pass
            pass

        assert len(remaining_s) == 0
        if the_inner_block_list is None or len(the_inner_block_list) == 0:
            return None, s[idx + 1:]
        return ResBlock(
            inner_block_list=the_inner_block_list,
            in_channels=in_channels,
            stride=the_stride,
            no_create=no_create,
            block_name=tmp_block_name), s[idx + 1:]

    @staticmethod
    def is_instance_from_str(s):
        if s.startswith('ResBlock(') and s[-1] == ')':
            return True
        else:
            return False
class Sequential(PlainNetBasicBlockClass):
    """Chain of inner blocks applied one after another."""

    def __init__(self, inner_block_list, no_create=False, block_name=None,
                 **kwargs):
        super(Sequential, self).__init__(**kwargs)
        self.block_name = block_name
        self.inner_block_list = inner_block_list
        if not no_create:
            self.inner_module_list = nn.ModuleList(inner_block_list)
        self.in_channels = inner_block_list[0].in_channels
        self.out_channels = inner_block_list[-1].out_channels
        # Overall stride = product of inner strides, probed at resolution 1024.
        res = 1024
        for block in self.inner_block_list:
            res = block.get_output_resolution(res)
        self.stride = 1024 // res

    def forward(self, x):
        output = x
        for inner_block in self.inner_block_list:
            output = inner_block(output)
        return output

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse 'Sequential(name|inner_blocks)' and consume the whole string."""
        assert Sequential.is_instance_from_str(s)
        the_right_paraen_idx = _get_right_parentheses_index_(s)
        # NOTE(review): the '+ 1' skips one character beyond the opening
        # parenthesis, unlike every other create_from_str here -- looks
        # suspicious; confirm against the serializer format.
        param_str = s[len('Sequential(') + 1:the_right_paraen_idx]
        # find block_name
        tmp_idx = param_str.find('|')
        if tmp_idx < 0:
            tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1:]
        the_inner_block_list, remaining_s = _create_netblock_list_from_str_(
            param_str, no_create=no_create)
        assert len(remaining_s) == 0
        if the_inner_block_list is None or len(the_inner_block_list) == 0:
            return None, ''
        return Sequential(
            inner_block_list=the_inner_block_list,
            no_create=no_create,
            block_name=tmp_block_name), ''

    @staticmethod
    def is_instance_from_str(s):
        # Note: unlike the other blocks, no trailing ')' check here.
        if s.startswith('Sequential('):
            return True
        else:
            return False
''' Super Blocks '''
class SuperResKXKX(PlainNetBasicBlockClass):
    """Stack of ``sublayers`` residual units, each built as
    KxK conv -> BN -> ReLU -> KxK conv -> BN, summed with a (possibly
    1x1-projected) shortcut and passed through ReLU."""

    def __init__(self,
                 in_channels=0,
                 out_channels=0,
                 kernel_size=3,
                 stride=1,
                 expansion=1.0,
                 sublayers=1,
                 no_create=False,
                 block_name=None,
                 **kwargs):
        super(SuperResKXKX, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.expansion = expansion
        self.stride = stride
        self.sublayers = sublayers
        self.no_create = no_create
        self.block_name = block_name

        self.shortcut_list = nn.ModuleList()
        self.conv_list = nn.ModuleList()

        for layerID in range(self.sublayers):
            # Only the first sub-layer changes channels / resolution.
            if layerID == 0:
                current_in_channels = self.in_channels
                current_out_channels = self.out_channels
                current_stride = self.stride
                current_kernel_size = self.kernel_size
            else:
                current_in_channels = self.out_channels
                current_out_channels = self.out_channels
                current_stride = 1
                current_kernel_size = self.kernel_size

            # Hidden width scaled by the expansion ratio.
            current_expansion_channel = int(
                round(current_out_channels * self.expansion))

            the_conv_block = nn.Sequential(
                nn.Conv2d(
                    current_in_channels,
                    current_expansion_channel,
                    kernel_size=current_kernel_size,
                    stride=current_stride,
                    padding=(current_kernel_size - 1) // 2,
                    bias=False),
                nn.BatchNorm2d(current_expansion_channel),
                nn.ReLU(),
                nn.Conv2d(
                    current_expansion_channel,
                    current_out_channels,
                    kernel_size=current_kernel_size,
                    stride=1,
                    padding=(current_kernel_size - 1) // 2,
                    bias=False),
                nn.BatchNorm2d(current_out_channels),
            )
            self.conv_list.append(the_conv_block)

            # Identity shortcut when shapes match, 1x1 projection otherwise.
            if current_stride == 1 and current_in_channels == current_out_channels:
                shortcut = nn.Sequential()
            else:
                shortcut = nn.Sequential(
                    nn.Conv2d(
                        current_in_channels,
                        current_out_channels,
                        kernel_size=1,
                        stride=current_stride,
                        padding=0,
                        bias=False), nn.BatchNorm2d(current_out_channels))
            self.shortcut_list.append(shortcut)
            pass  # end for

    def forward(self, x):
        output = x
        for block, shortcut in zip(self.conv_list, self.shortcut_list):
            conv_output = block(output)
            output = conv_output + shortcut(output)
            output = F.relu(output)
        return output

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse 'SuperResKXKX(name|in,out,kernel,stride,expansion,sublayers)...'."""
        assert SuperResKXKX.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len('SuperResKXKX('):idx]

        # find block_name
        tmp_idx = param_str.find('|')
        if tmp_idx < 0:
            tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1:]

        param_str_split = param_str.split(',')
        in_channels = int(param_str_split[0])
        out_channels = int(param_str_split[1])
        kernel_size = int(param_str_split[2])
        stride = int(param_str_split[3])
        expansion = float(param_str_split[4])
        sublayers = int(param_str_split[5])
        return SuperResKXKX(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            expansion=expansion,
            sublayers=sublayers,
            block_name=tmp_block_name,
            no_create=no_create), s[idx + 1:]

    @staticmethod
    def is_instance_from_str(s):
        if s.startswith('SuperResKXKX(') and s[-1] == ')':
            return True
        else:
            return False
class SuperResK1KX(PlainNetBasicBlockClass):
    """Stack of ``sublayers`` residual units, each built as
    1x1 conv -> BN -> ReLU -> KxK conv -> BN, summed with a (possibly
    1x1-projected) shortcut and passed through ReLU."""

    def __init__(self,
                 in_channels=0,
                 out_channels=0,
                 kernel_size=3,
                 stride=1,
                 expansion=1.0,
                 sublayers=1,
                 no_create=False,
                 block_name=None,
                 **kwargs):
        super(SuperResK1KX, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.expansion = expansion
        self.stride = stride
        self.sublayers = sublayers
        self.no_create = no_create
        self.block_name = block_name

        self.shortcut_list = nn.ModuleList()
        self.conv_list = nn.ModuleList()

        for layerID in range(self.sublayers):
            # Only the first sub-layer changes channels / resolution.
            if layerID == 0:
                current_in_channels = self.in_channels
                current_out_channels = self.out_channels
                current_stride = self.stride
                current_kernel_size = self.kernel_size
            else:
                current_in_channels = self.out_channels
                current_out_channels = self.out_channels
                current_stride = 1
                current_kernel_size = self.kernel_size

            # Hidden width scaled by the expansion ratio.
            current_expansion_channel = int(
                round(current_out_channels * self.expansion))

            the_conv_block = nn.Sequential(
                nn.Conv2d(
                    current_in_channels,
                    current_expansion_channel,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False),
                nn.BatchNorm2d(current_expansion_channel),
                nn.ReLU(),
                nn.Conv2d(
                    current_expansion_channel,
                    current_out_channels,
                    kernel_size=current_kernel_size,
                    stride=current_stride,
                    padding=(current_kernel_size - 1) // 2,
                    bias=False),
                nn.BatchNorm2d(current_out_channels),
            )
            self.conv_list.append(the_conv_block)

            # Identity shortcut when shapes match, 1x1 projection otherwise.
            if current_stride == 1 and current_in_channels == current_out_channels:
                shortcut = nn.Sequential()
            else:
                shortcut = nn.Sequential(
                    nn.Conv2d(
                        current_in_channels,
                        current_out_channels,
                        kernel_size=1,
                        stride=current_stride,
                        padding=0,
                        bias=False), nn.BatchNorm2d(current_out_channels))
            self.shortcut_list.append(shortcut)
            pass  # end for

    def forward(self, x):
        output = x
        for block, shortcut in zip(self.conv_list, self.shortcut_list):
            conv_output = block(output)
            output = conv_output + shortcut(output)
            output = F.relu(output)
        return output

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse 'SuperResK1KX(name|in,out,kernel,stride,expansion,sublayers)...'."""
        assert SuperResK1KX.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len('SuperResK1KX('):idx]

        # find block_name
        tmp_idx = param_str.find('|')
        if tmp_idx < 0:
            tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1:]

        param_str_split = param_str.split(',')
        in_channels = int(param_str_split[0])
        out_channels = int(param_str_split[1])
        kernel_size = int(param_str_split[2])
        stride = int(param_str_split[3])
        expansion = float(param_str_split[4])
        sublayers = int(param_str_split[5])
        return SuperResK1KX(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            expansion=expansion,
            sublayers=sublayers,
            block_name=tmp_block_name,
            no_create=no_create), s[idx + 1:]

    @staticmethod
    def is_instance_from_str(s):
        if s.startswith('SuperResK1KX(') and s[-1] == ')':
            return True
        else:
            return False
class SuperResK1KXK1(PlainNetBasicBlockClass):
    """Stack of ``sublayers`` bottleneck residual units, each built as
    1x1 conv -> BN -> ReLU -> KxK conv -> BN -> ReLU -> 1x1 conv -> BN,
    summed with a (possibly 1x1-projected) shortcut and passed through ReLU."""

    def __init__(self,
                 in_channels=0,
                 out_channels=0,
                 kernel_size=3,
                 stride=1,
                 expansion=1.0,
                 sublayers=1,
                 no_create=False,
                 block_name=None,
                 **kwargs):
        super(SuperResK1KXK1, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.expansion = expansion
        self.stride = stride
        self.sublayers = sublayers
        self.no_create = no_create
        self.block_name = block_name

        self.shortcut_list = nn.ModuleList()
        self.conv_list = nn.ModuleList()

        for layerID in range(self.sublayers):
            # Only the first sub-layer changes channels / resolution.
            if layerID == 0:
                current_in_channels = self.in_channels
                current_out_channels = self.out_channels
                current_stride = self.stride
                current_kernel_size = self.kernel_size
            else:
                current_in_channels = self.out_channels
                current_out_channels = self.out_channels
                current_stride = 1
                current_kernel_size = self.kernel_size

            # Bottleneck width scaled by the expansion ratio.
            current_expansion_channel = int(
                round(current_out_channels * self.expansion))

            the_conv_block = nn.Sequential(
                nn.Conv2d(
                    current_in_channels,
                    current_expansion_channel,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False),
                nn.BatchNorm2d(current_expansion_channel),
                nn.ReLU(),
                nn.Conv2d(
                    current_expansion_channel,
                    current_expansion_channel,
                    kernel_size=current_kernel_size,
                    stride=current_stride,
                    padding=(current_kernel_size - 1) // 2,
                    bias=False),
                nn.BatchNorm2d(current_expansion_channel),
                nn.ReLU(),
                nn.Conv2d(
                    current_expansion_channel,
                    current_out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False),
                nn.BatchNorm2d(current_out_channels),
            )
            self.conv_list.append(the_conv_block)

            # Identity shortcut when shapes match, 1x1 projection otherwise.
            if current_stride == 1 and current_in_channels == current_out_channels:
                shortcut = nn.Sequential()
            else:
                shortcut = nn.Sequential(
                    nn.Conv2d(
                        current_in_channels,
                        current_out_channels,
                        kernel_size=1,
                        stride=current_stride,
                        padding=0,
                        bias=False), nn.BatchNorm2d(current_out_channels))
            self.shortcut_list.append(shortcut)
            pass  # end for

    def forward(self, x):
        output = x
        for block, shortcut in zip(self.conv_list, self.shortcut_list):
            conv_output = block(output)
            output = conv_output + shortcut(output)
            output = F.relu(output)
        return output

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse 'SuperResK1KXK1(name|in,out,kernel,stride,expansion,sublayers)...'."""
        assert SuperResK1KXK1.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len('SuperResK1KXK1('):idx]

        # find block_name
        tmp_idx = param_str.find('|')
        if tmp_idx < 0:
            tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1:]

        param_str_split = param_str.split(',')
        in_channels = int(param_str_split[0])
        out_channels = int(param_str_split[1])
        kernel_size = int(param_str_split[2])
        stride = int(param_str_split[3])
        expansion = float(param_str_split[4])
        sublayers = int(param_str_split[5])
        return SuperResK1KXK1(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            expansion=expansion,
            sublayers=sublayers,
            block_name=tmp_block_name,
            no_create=no_create), s[idx + 1:]

    @staticmethod
    def is_instance_from_str(s):
        if s.startswith('SuperResK1KXK1(') and s[-1] == ')':
            return True
        else:
            return False
class SuperResK1DWK1(PlainNetBasicBlockClass):
    """Stack of residual inverted-bottleneck sub-layers: 1x1 -> depthwise KxK -> 1x1.

    Identical layout to ``SuperResK1KXK1`` except the middle KxK conv is
    depthwise (``groups`` equals its channel count). Only the first
    sub-layer applies ``stride`` and maps ``in_channels`` to
    ``out_channels``; the rest preserve shape.
    """

    def __init__(self,
                 in_channels=0,
                 out_channels=0,
                 kernel_size=3,
                 stride=1,
                 expansion=1.0,
                 sublayers=1,
                 no_create=False,
                 block_name=None,
                 **kwargs):
        super(SuperResK1DWK1, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.expansion = expansion
        self.stride = stride
        self.sublayers = sublayers
        self.no_create = no_create
        self.block_name = block_name

        self.shortcut_list = nn.ModuleList()
        self.conv_list = nn.ModuleList()
        # Expanded width; loop-invariant since every sub-layer outputs
        # out_channels.
        mid_channels = int(round(out_channels * expansion))
        padding = (kernel_size - 1) // 2
        for idx in range(sublayers):
            # Only the first sub-layer downsamples / changes channel count.
            cur_in = in_channels if idx == 0 else out_channels
            cur_stride = stride if idx == 0 else 1

            branch = nn.Sequential(
                nn.Conv2d(
                    cur_in,
                    mid_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(),
                # groups == channels makes this a depthwise convolution.
                nn.Conv2d(
                    mid_channels,
                    mid_channels,
                    kernel_size=kernel_size,
                    stride=cur_stride,
                    padding=padding,
                    bias=False,
                    groups=mid_channels),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(),
                nn.Conv2d(
                    mid_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False),
                nn.BatchNorm2d(out_channels),
            )
            self.conv_list.append(branch)

            if cur_stride == 1 and cur_in == out_channels:
                # Shape preserved: identity shortcut.
                shortcut = nn.Sequential()
            else:
                # Project the residual to the new shape.
                shortcut = nn.Sequential(
                    nn.Conv2d(
                        cur_in,
                        out_channels,
                        kernel_size=1,
                        stride=cur_stride,
                        padding=0,
                        bias=False), nn.BatchNorm2d(out_channels))
            self.shortcut_list.append(shortcut)

    def forward(self, x):
        """Run all sub-layers; each is relu(conv_branch(x) + shortcut(x))."""
        out = x
        for branch, shortcut in zip(self.conv_list, self.shortcut_list):
            out = F.relu(branch(out) + shortcut(out))
        return out

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse ``SuperResK1DWK1(name|in,out,k,stride,expansion,sublayers)``.

        Returns a ``(block, remaining_string)`` tuple, where
        ``remaining_string`` is everything after the closing parenthesis.
        """
        assert SuperResK1DWK1.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len('SuperResK1DWK1('):idx]

        # Optional "name|" prefix; generate a fresh uuid name when absent.
        sep = param_str.find('|')
        if sep < 0:
            block_name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            block_name = param_str[0:sep]
            param_str = param_str[sep + 1:]

        fields = param_str.split(',')
        block = SuperResK1DWK1(
            in_channels=int(fields[0]),
            out_channels=int(fields[1]),
            kernel_size=int(fields[2]),
            stride=int(fields[3]),
            expansion=float(fields[4]),
            sublayers=int(fields[5]),
            block_name=block_name,
            no_create=no_create)
        return block, s[idx + 1:]

    @staticmethod
    def is_instance_from_str(s):
        """Return True iff *s* is a serialized ``SuperResK1DWK1(...)`` string."""
        return s.startswith('SuperResK1DWK1(') and s.endswith(')')
class SuperResK1DW(PlainNetBasicBlockClass):
    """Stack of residual sub-layers: 1x1 conv -> depthwise KxK conv.

    Each sub-layer is ``conv1x1/BN/ReLU -> depthwise convKxK/BN`` added to a
    shortcut (identity when shape is preserved, otherwise a strided 1x1
    conv + BN projection) followed by ReLU. Only the first sub-layer applies
    ``stride`` and maps ``in_channels`` to ``out_channels``; the rest keep
    the shape. This block supports no channel expansion (``expansion`` must
    be 1).
    """

    def __init__(self,
                 in_channels=0,
                 out_channels=0,
                 kernel_size=3,
                 stride=1,
                 expansion=1.0,
                 sublayers=1,
                 no_create=False,
                 block_name=None,
                 **kwargs):
        super(SuperResK1DW, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.expansion = expansion
        # The depthwise stage cannot widen channels, so expansion must be 1.
        assert abs(expansion - 1) < 1e-6
        self.stride = stride
        self.sublayers = sublayers
        self.no_create = no_create
        self.block_name = block_name

        self.shortcut_list = nn.ModuleList()
        self.conv_list = nn.ModuleList()
        for layerID in range(self.sublayers):
            if layerID == 0:
                # Only the first sub-layer downsamples / changes channels.
                current_in_channels = self.in_channels
                current_out_channels = self.out_channels
                current_stride = self.stride
                current_kernel_size = self.kernel_size
            else:
                current_in_channels = self.out_channels
                current_out_channels = self.out_channels
                current_stride = 1
                current_kernel_size = self.kernel_size

            the_conv_block = nn.Sequential(
                nn.Conv2d(
                    current_in_channels,
                    current_out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False),
                # Fix: normalize over the conv's actual output width
                # (current_out_channels). The original code passed
                # int(round(current_out_channels * expansion)) here, which
                # only matched because expansion is asserted to equal 1.
                nn.BatchNorm2d(current_out_channels),
                nn.ReLU(),
                # groups == channels makes this a depthwise convolution.
                nn.Conv2d(
                    current_out_channels,
                    current_out_channels,
                    kernel_size=current_kernel_size,
                    stride=current_stride,
                    padding=(current_kernel_size - 1) // 2,
                    bias=False,
                    groups=current_out_channels),
                nn.BatchNorm2d(current_out_channels),
            )
            self.conv_list.append(the_conv_block)

            if current_stride == 1 and current_in_channels == current_out_channels:
                # Shape preserved: identity shortcut.
                shortcut = nn.Sequential()
            else:
                # Project the residual to the new shape.
                shortcut = nn.Sequential(
                    nn.Conv2d(
                        current_in_channels,
                        current_out_channels,
                        kernel_size=1,
                        stride=current_stride,
                        padding=0,
                        bias=False), nn.BatchNorm2d(current_out_channels))
            self.shortcut_list.append(shortcut)

    def forward(self, x):
        """Run all sub-layers; each is relu(conv_block(x) + shortcut(x))."""
        output = x
        for block, shortcut in zip(self.conv_list, self.shortcut_list):
            conv_output = block(output)
            output = conv_output + shortcut(output)
            output = F.relu(output)
        return output

    @staticmethod
    def create_from_str(s, no_create=False):
        """Parse ``SuperResK1DW(name|in,out,k,stride,expansion,sublayers)``.

        Returns a ``(block, remaining_string)`` tuple, where
        ``remaining_string`` is everything after the closing parenthesis.
        """
        assert SuperResK1DW.is_instance_from_str(s)
        idx = _get_right_parentheses_index_(s)
        assert idx is not None
        param_str = s[len('SuperResK1DW('):idx]

        # Optional "name|" prefix; generate a fresh uuid name when absent.
        tmp_idx = param_str.find('|')
        if tmp_idx < 0:
            tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex)
        else:
            tmp_block_name = param_str[0:tmp_idx]
            param_str = param_str[tmp_idx + 1:]

        param_str_split = param_str.split(',')
        in_channels = int(param_str_split[0])
        out_channels = int(param_str_split[1])
        kernel_size = int(param_str_split[2])
        stride = int(param_str_split[3])
        expansion = float(param_str_split[4])
        sublayers = int(param_str_split[5])
        return SuperResK1DW(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            expansion=expansion,
            sublayers=sublayers,
            block_name=tmp_block_name,
            no_create=no_create), s[idx + 1:]

    @staticmethod
    def is_instance_from_str(s):
        """Return True iff *s* is a serialized ``SuperResK1DW(...)`` string."""
        if s.startswith('SuperResK1DW(') and s[-1] == ')':
            return True
        else:
            return False
# Registry mapping serialized block-type names (the token before '(' in a
# PlainNet structure string) to their classes; the structure-string parser
# dispatches through this to each class's create_from_str().
_all_netblocks_dict_ = {
    'AdaptiveAvgPool': AdaptiveAvgPool,
    'BN': BN,
    'ConvDW': ConvDW,
    'ConvKX': ConvKX,
    'Flatten': Flatten,
    'Linear': Linear,
    'MaxPool': MaxPool,
    'MultiSumBlock': MultiSumBlock,
    'PlainNetBasicBlockClass': PlainNetBasicBlockClass,
    'RELU': RELU,
    'ResBlock': ResBlock,
    'Sequential': Sequential,
    'SuperResKXKX': SuperResKXKX,
    'SuperResK1KXK1': SuperResK1KXK1,
    'SuperResK1DWK1': SuperResK1DWK1,
    'SuperResK1KX': SuperResK1KX,
    'SuperResK1DW': SuperResK1DW,
}
# GENet backbone assembled by parsing a PlainNet structure string
# (e.g. GENET_LARGE / GENET_NORMAL / GENET_SMALL above) into a list of
# network blocks; registered as an EasyCV backbone.
@BACKBONES.register_module
class PlainNet(nn.Module):
    def __init__(self,
                 plainnet_struct_idx=None,
                 num_classes=0,
                 no_create=False,
                 **kwargs):
        """Build the network from a registered PlainNet structure string.

        Args:
            plainnet_struct_idx: key into ``plainnet_struct_dict`` selecting
                the structure string to parse.
            num_classes: when > 0, a final Linear classification head is
                created; otherwise ``self.fc_linear`` is None.
            no_create: forwarded to the block parser; when True the
                nn.ModuleList registration below is also skipped.
        """
        super(PlainNet, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.plainnet_struct = plainnet_struct_dict[plainnet_struct_idx]
        the_s = self.plainnet_struct  # type: str

        # Parse the whole structure string into block instances; nothing may
        # be left over.
        block_list, remaining_s = _create_netblock_list_from_str_(
            the_s, no_create=no_create)
        assert len(remaining_s) == 0

        # Split a trailing AdaptiveAvgPool off the block list; fall back to a
        # plain 1x1 global pool when the string did not end with one.
        # NOTE: the attribute name 'adptive_avg_pool' (sic) is kept as-is —
        # renaming it would break existing checkpoints' state-dict keys.
        if isinstance(block_list[-1], AdaptiveAvgPool):
            self.adptive_avg_pool = block_list[-1]
            block_list.pop(-1)
        else:
            self.adptive_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.block_list = block_list
        if not no_create:
            # Plain-list blocks are invisible to nn.Module; wrap in a
            # ModuleList so their parameters are registered.
            self.module_list = nn.ModuleList(block_list)  # register

        self.last_channels = self.adptive_avg_pool.out_channels
        if num_classes > 0:
            self.fc_linear = nn.Linear(
                self.last_channels, self.num_classes, bias=True)
        else:
            self.fc_linear = None

        # Re-serialize the (parsed) structure, pool included, as the
        # canonical structure string.
        self.plainnet_struct = str(self) + str(self.adptive_avg_pool)
        self.zero_init_residual = False
        # e.g. model_urls['PlainNet' + plainnet_struct_idx]
        self.default_pretrained_model_path = model_urls[self.__class__.__name__
                                                        + plainnet_struct_idx]
[docs] def init_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): kaiming_init(m, mode='fan_in', nonlinearity='relu') elif isinstance(m, (_BatchNorm, nn.GroupNorm)): constant_init(m, 1)
[docs] def forward(self, x): output = x for the_block in self.block_list: output = the_block(output) if self.fc_linear is not None: bs = output.size(0) output = self.adptive_avg_pool(output) output = output.view(bs, -1) output = self.fc_linear(output) return [output]