USER GUIDE
Tutorials
Changelog
API Doc
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
import torch.nn as nn


class Scale(nn.Module):
    """A learnable scale parameter.

    The module holds one trainable value, ``self.scale``, and multiplies
    its input by that value in :meth:`forward`.
    """

    def __init__(self, scale: float = 1.0) -> None:
        """Create the module with an initial scale value.

        Args:
            scale: Initial value of the learnable factor. It is converted
                to a ``torch.float`` tensor and registered as an
                ``nn.Parameter``.
        """
        super().__init__()
        # Wrapping the tensor in nn.Parameter registers it with the module,
        # so it is returned by ``parameters()`` and receives gradients.
        self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied by the learnable scale.

        Args:
            x: Input to be scaled.

        Returns:
            The product ``x * self.scale``.
        """
        return x * self.scale