Source code for easycv.models.utils.conv_ws

# Copyright (c) Alibaba, Inc. and its affiliates.
import torch.nn as nn
import torch.nn.functional as F


def conv_ws_2d(input,
               weight,
               bias=None,
               stride=1,
               padding=0,
               dilation=1,
               groups=1,
               eps=1e-5):
    """Run ``F.conv2d`` with a per-filter standardized copy of ``weight``.

    Every slice ``weight[i]`` along dim 0 is flattened, its mean is
    subtracted, and the result is divided by ``std + eps`` (``Tensor.std``
    with its default settings) before the convolution is evaluated. The
    tensor passed in as ``weight`` is not modified; a new tensor is built.

    Args:
        input: Tensor forwarded unchanged to ``F.conv2d``.
        weight: 4-D filter tensor; statistics are computed over every
            dimension except dim 0.
        bias: Optional bias forwarded to ``F.conv2d``.
        stride: Forwarded to ``F.conv2d``.
        padding: Forwarded to ``F.conv2d``.
        dilation: Forwarded to ``F.conv2d``.
        groups: Forwarded to ``F.conv2d``.
        eps: Added to the standard deviation before dividing so a constant
            filter does not cause a division by zero.

    Returns:
        The result of ``F.conv2d`` evaluated with the standardized weight.
    """
    num_filters = weight.size(0)
    # ``reshape`` rather than ``view``: a weight stored with non-default
    # strides (e.g. channels_last) cannot be flattened as a view, and
    # ``view`` would raise. ``reshape`` only copies when it has to.
    weight_flat = weight.reshape(num_filters, -1)
    # NOTE(review): a filter holding a single element presumably gets a NaN
    # std from the default estimator -- confirm if 1x1 depthwise use matters.
    mean = weight_flat.mean(dim=1, keepdim=True).view(num_filters, 1, 1, 1)
    std = weight_flat.std(dim=1, keepdim=True).view(num_filters, 1, 1, 1)
    weight = (weight - mean) / (std + eps)
    return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
class ConvWS2d(nn.Conv2d):
    """``nn.Conv2d`` subclass whose forward pass is routed via ``conv_ws_2d``.

    Construction is delegated to ``nn.Conv2d``; the only extra state kept on
    the instance is ``eps``, which is handed to ``conv_ws_2d`` on each call.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 eps=1e-5):
        """Build the underlying ``nn.Conv2d`` and remember ``eps``.

        Args:
            in_channels: Forwarded to ``nn.Conv2d``.
            out_channels: Forwarded to ``nn.Conv2d``.
            kernel_size: Forwarded to ``nn.Conv2d``.
            stride: Forwarded to ``nn.Conv2d``.
            padding: Forwarded to ``nn.Conv2d``.
            dilation: Forwarded to ``nn.Conv2d``.
            groups: Forwarded to ``nn.Conv2d``.
            bias: Forwarded to ``nn.Conv2d``.
            eps: Stored as ``self.eps`` and passed to ``conv_ws_2d`` in
                ``forward``.
        """
        conv_kwargs = dict(
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        super().__init__(in_channels, out_channels, kernel_size,
                         **conv_kwargs)
        self.eps = eps

    def forward(self, x):
        """Return ``conv_ws_2d`` applied to ``x`` with this layer's settings."""
        return conv_ws_2d(
            x,
            weight=self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            eps=self.eps,
        )