Source code for easycv.models.utils.scale

# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn


class Scale(nn.Module):
    """A learnable scale parameter.

    The module holds a single ``nn.Parameter`` named ``scale`` and
    multiplies its input by that parameter in ``forward``.

    Args:
        scale: Initial value of the scale factor. It is converted with
            ``torch.tensor(scale, dtype=torch.float)``. Defaults to 1.0.
    """

    def __init__(self, scale=1.0):
        """Create the learnable ``scale`` parameter from ``scale``."""
        super(Scale, self).__init__()
        # Wrapping the tensor in nn.Parameter registers it on the module, so
        # it is listed by ``parameters()`` and stored in ``state_dict()``.
        self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))

    def forward(self, x):
        """Return ``x`` multiplied by the learnable scale.

        Args:
            x: Value to scale; presumably a ``torch.Tensor`` (the code only
                requires that ``x * self.scale`` is defined).

        Returns:
            The product ``x * self.scale``.
        """
        return x * self.scale