forked from mindspore-Ecosystem/mindspore
quantization aware training frontend operators define.
This commit is contained in:
parent
c75f75a3e1
commit
d64f662c76
|
@ -17,7 +17,7 @@ Layer.
|
|||
|
||||
The high-level components(Cells) used to construct the neural network.
|
||||
"""
|
||||
from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU
|
||||
from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU, HSigmoid, HSwish
|
||||
from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm
|
||||
from .container import SequentialCell, CellList
|
||||
from .conv import Conv2d, Conv2dTranspose
|
||||
|
@ -26,8 +26,9 @@ from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, ImageGradi
|
|||
from .embedding import Embedding
|
||||
from .pooling import AvgPool2d, MaxPool2d
|
||||
|
||||
__all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid', 'PReLU', 'get_activation', 'LeakyReLU',
|
||||
'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'ELU',
|
||||
__all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
|
||||
'PReLU', 'get_activation', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU',
|
||||
'BatchNorm1d', 'BatchNorm2d', 'LayerNorm',
|
||||
'SequentialCell', 'CellList',
|
||||
'Conv2d', 'Conv2dTranspose',
|
||||
'LSTM',
|
||||
|
|
|
@ -0,0 +1,703 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Aware quantization."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import check_int_positive, check_bool, twice
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.nn.layer.conv import _Conv
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
|
||||
__all__ = [
|
||||
'FakeQuantWithMinMax',
|
||||
'Conv2dBatchNormQuant',
|
||||
'Conv2dQuant',
|
||||
'DenseQuant',
|
||||
'ReLUQuant',
|
||||
'ReLU6Quant',
|
||||
'HSwishQuant',
|
||||
'HSigmoidQuant',
|
||||
'TensorAddQuant',
|
||||
]
|
||||
|
||||
|
||||
class FakeQuantWithMinMax(Cell):
|
||||
r"""
|
||||
Aware Quantization training op. This OP provide Fake quantization observer function on data with min and max.
|
||||
|
||||
Args:
|
||||
min_init (int, list): The dimension of channel or 1(layer). Default: -6.
|
||||
max_init (int, list): The dimension of channel or 1(layer). Default: 6.
|
||||
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
||||
ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
|
||||
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.9999.
|
||||
per_channel (bool): Quantization by layer or channel. Default: False.
|
||||
channel_size (int): declarate the min and max channel size, Default: 1.
|
||||
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input of FakeQuantWithMinMax.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
min_init=-6,
|
||||
max_init=6,
|
||||
num_bits=8,
|
||||
ema=False,
|
||||
ema_decay=0.999,
|
||||
per_channel=False,
|
||||
channel_size=1,
|
||||
quant_delay=0,
|
||||
symmetric=False,
|
||||
narrow_range=False):
|
||||
super(FakeQuantWithMinMax, self).__init__()
|
||||
|
||||
self.min_init = min_init
|
||||
self.num_bits = num_bits
|
||||
self.max_init = max_init
|
||||
self.ema = ema
|
||||
self.ema_decay = ema_decay
|
||||
self.per_channel = per_channel
|
||||
self.channel_size = channel_size
|
||||
self.quant_delay = quant_delay
|
||||
self.symmetric = symmetric
|
||||
self.narrow_range = narrow_range
|
||||
|
||||
if per_channel:
|
||||
min_array = np.array([self.min_init for i in range(
|
||||
0, self.channel_size)]).astype(np.float32)
|
||||
max_array = np.array([self.max_init for i in range(
|
||||
0, self.channel_size)]).astype(np.float32)
|
||||
self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
|
||||
ema=self.ema,
|
||||
ema_decay=self.ema_decay,
|
||||
quant_delay=self.quant_delay,
|
||||
symmetric=self.symmetric,
|
||||
narrow_range=self.narrow_range,
|
||||
training=True)
|
||||
self.fake_quant_infer = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
|
||||
ema=self.ema,
|
||||
ema_decay=ema_decay,
|
||||
quant_delay=quant_delay,
|
||||
symmetric=self.symmetric,
|
||||
narrow_range=self.narrow_range,
|
||||
training=False)
|
||||
else:
|
||||
min_array = np.array([min_init]).reshape(1).astype(np.float32)
|
||||
max_array = np.array([max_init]).reshape(1).astype(np.float32)
|
||||
self.fake_quant_train = P.FakeQuantWithMinMax(num_bits=self.num_bits,
|
||||
ema=self.ema,
|
||||
ema_decay=self.ema_decay,
|
||||
quant_delay=self.quant_delay,
|
||||
symmetric=self.symmetric,
|
||||
narrow_range=self.narrow_range,
|
||||
training=True)
|
||||
self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits,
|
||||
ema=self.ema,
|
||||
ema_decay=ema_decay,
|
||||
quant_delay=quant_delay,
|
||||
symmetric=self.symmetric,
|
||||
narrow_range=self.narrow_range,
|
||||
training=False)
|
||||
|
||||
self.min = Parameter(
|
||||
Tensor(min_array), name='quant_min', requires_grad=False)
|
||||
self.max = Parameter(
|
||||
Tensor(max_array), name='quant_max', requires_grad=False)
|
||||
|
||||
def extend_repr(self):
|
||||
s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format(
|
||||
self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size,
|
||||
self.quant_delay)
|
||||
return s
|
||||
|
||||
def construct(self, x):
|
||||
if self.training:
|
||||
out = self.fake_quant_train(x, self.min, self.max)
|
||||
else:
|
||||
out = self.fake_quant_infer(x, self.min, self.max)
|
||||
return out
|
||||
|
||||
|
||||
class Conv2dBatchNormQuant(Cell):
|
||||
r"""
|
||||
2D convolution with BatchNormal op folded layer.
|
||||
|
||||
For a more Detailed overview of Conv2d op.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channel :math:`C_{in}`.
|
||||
out_channels (int): The number of output channel :math:`C_{out}`.
|
||||
kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
|
||||
stride (int): Specifies stride for all spatial dimensions with the same value.
|
||||
pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
|
||||
padding: (int): Implicit paddings on both sides of the input. Default: 0.
|
||||
eps (int): Parameters for BatchNormal. Default: 1e-5.
|
||||
momentum (int): Parameters for BatchNormal op. Default: 0.9.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||
convolution kernel. Default: 'None'.
|
||||
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||
beta vector. Default: 'None'.
|
||||
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||
gamma vector. Default: 'None'.
|
||||
mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||
mean vector. Default: 'None'.
|
||||
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||
variance vector. Default: 'None'.
|
||||
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
||||
freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
|
||||
fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True.
|
||||
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
||||
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
||||
|
||||
Outputs:
|
||||
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
pad_mode,
|
||||
padding=0,
|
||||
eps=1e-5,
|
||||
momentum=0.9,
|
||||
weight_init=None,
|
||||
beta_init=None,
|
||||
gamma_init=None,
|
||||
mean_init=None,
|
||||
var_init=None,
|
||||
group=1,
|
||||
quant_delay=0,
|
||||
freeze_bn=100000,
|
||||
fake=True,
|
||||
num_bits=8,
|
||||
per_channel=False,
|
||||
symmetric=False,
|
||||
narrow_range=False):
|
||||
super(Conv2dBatchNormQuant, self).__init__()
|
||||
self.stride = stride
|
||||
self.conv = P.Conv2D(out_channel=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
mode=1,
|
||||
pad_mode=pad_mode,
|
||||
pad=padding,
|
||||
stride=stride,
|
||||
dilation=1,
|
||||
group=group)
|
||||
self.fake = fake
|
||||
self.freeze_bn = freeze_bn
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
|
||||
if weight_init is None:
|
||||
weight_init = initializer(
|
||||
'normal', [out_channels, in_channels // group, *kernel_size])
|
||||
self.weight = Parameter(weight_init, name='weight')
|
||||
if gamma_init is None:
|
||||
gamma_init = initializer('ones', [out_channels])
|
||||
self.gamma = Parameter(gamma_init, name='gamma')
|
||||
if beta_init is None:
|
||||
beta_init = initializer('zeros', [out_channels])
|
||||
self.beta = Parameter(beta_init, name='beta')
|
||||
if mean_init is None:
|
||||
mean_init = initializer('zeros', [out_channels])
|
||||
self.moving_mean = Parameter(
|
||||
mean_init, name='moving_mean', requires_grad=False)
|
||||
if var_init is None:
|
||||
var_init = initializer('ones', [out_channels])
|
||||
self.moving_variance = Parameter(
|
||||
var_init, name='moving_variance', requires_grad=False)
|
||||
|
||||
self.step = Parameter(initializer(
|
||||
'normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
|
||||
|
||||
self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
|
||||
max_init=6,
|
||||
ema=False,
|
||||
num_bits=num_bits,
|
||||
quant_delay=quant_delay,
|
||||
per_channel=per_channel,
|
||||
channel_size=out_channels,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range)
|
||||
|
||||
self.batchnorm_fold_train = P.BatchNormFold(epsilon=eps,
|
||||
momentum=momentum,
|
||||
is_training=True,
|
||||
freeze_bn=freeze_bn)
|
||||
self.batchnorm_fold_infer = P.BatchNormFold(epsilon=eps,
|
||||
momentum=momentum,
|
||||
is_training=False,
|
||||
freeze_bn=freeze_bn)
|
||||
self.correct_mul = P.CorrectionMul()
|
||||
self.relu = P.ReLU()
|
||||
self.batchnorm_fold2 = P.BatchNormFold2(freeze_bn=freeze_bn)
|
||||
self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0)
|
||||
self.one = Tensor(1, mstype.int32)
|
||||
self.assignadd = P.AssignAdd()
|
||||
|
||||
def extend_repr(self):
|
||||
s = 'fake={}, freeze_bn={}'.format(self.fake, self.freeze_bn)
|
||||
return s
|
||||
|
||||
def construct(self, x):
|
||||
if self.training:
|
||||
beta = self.beta
|
||||
gamma = self.gamma
|
||||
gmean = self.moving_mean
|
||||
gvar = self.moving_variance
|
||||
step = self.step
|
||||
out_conv = self.conv(x, self.weight)
|
||||
batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_train(
|
||||
out_conv, gmean, gvar, step)
|
||||
# BN fold1
|
||||
weight = self.correct_mul(self.weight, gamma, running_std)
|
||||
if self.fake:
|
||||
weight = self.fake_quant_weight(weight)
|
||||
out = self.conv(x, weight)
|
||||
# BN fold2
|
||||
out = self.batchnorm_fold2(
|
||||
out, beta, gamma, batch_std, batch_mean, running_std, running_mean, step)
|
||||
F.control_depend(out, self.assignadd(self.step, self.one))
|
||||
else:
|
||||
step = self.step
|
||||
out_conv = self.conv(x, self.weight)
|
||||
batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_infer(
|
||||
out_conv, self.moving_mean, self.moving_variance, step)
|
||||
weight = self.correct_mul(self.weight, self.gamma, running_std)
|
||||
if self.fake:
|
||||
weight = self.fake_quant_weight(weight)
|
||||
out = self.conv(x, weight)
|
||||
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean,
|
||||
running_std, running_mean, step)
|
||||
return out
|
||||
|
||||
|
||||
class Conv2dQuant(_Conv):
|
||||
r"""
|
||||
2D convolution with fake quant op layer.
|
||||
|
||||
For a more Detailed overview of Conv2d op.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channel :math:`C_{in}`.
|
||||
out_channels (int): The number of output channel :math:`C_{out}`.
|
||||
kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
|
||||
stride (int): Specifies stride for all spatial dimensions with the same value. Default: 1.
|
||||
pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
|
||||
padding: (int): Implicit paddings on both sides of the input. Default: 0.
|
||||
dilation (int): Specifying the dilation rate to use for dilated convolution. Default: 1.
|
||||
group (int): Split filter into groups, `in_ channels` and `out_channels` should be
|
||||
divisible by the number of groups. Default: 1.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
|
||||
Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
|
||||
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
||||
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
||||
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
||||
|
||||
Outputs:
|
||||
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
pad_mode='same',
|
||||
padding=0,
|
||||
dilation=1,
|
||||
group=1,
|
||||
has_bias=False,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
quant_delay=0,
|
||||
num_bits=8,
|
||||
per_channel=False,
|
||||
symmetric=False,
|
||||
narrow_range=False):
|
||||
kernel_size = twice(kernel_size)
|
||||
super(Conv2dQuant, self).__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding, dilation,
|
||||
group, has_bias, weight_init, bias_init)
|
||||
self.conv2d = P.Conv2D(out_channel=self.out_channels, kernel_size=self.kernel_size, mode=1,
|
||||
pad_mode=self.pad_mode, pad=self.padding, stride=self.stride, dilation=self.dilation,
|
||||
group=self.group)
|
||||
self.bias_add = P.BiasAdd()
|
||||
if pad_mode not in ('valid', 'same', 'pad'):
|
||||
raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
|
||||
+ str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
|
||||
self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
|
||||
max_init=6,
|
||||
ema=False,
|
||||
num_bits=num_bits,
|
||||
quant_delay=quant_delay,
|
||||
per_channel=per_channel,
|
||||
channel_size=out_channels,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range)
|
||||
|
||||
def construct(self, x):
|
||||
weight_q = self.fake_quant_weight(self.weight)
|
||||
out = self.conv2d(x, weight_q)
|
||||
if self.has_bias:
|
||||
return self.bias_add(out, self.bias)
|
||||
return out
|
||||
|
||||
|
||||
class DenseQuant(Cell):
|
||||
r"""
|
||||
The fully connected layer with fake quant op.
|
||||
|
||||
For a more Detailed overview of Dense op.
|
||||
|
||||
Args:
|
||||
in_channels (int): The dimension of the input space.
|
||||
out_channels (int): The dimension of the output space.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
|
||||
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
|
||||
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
|
||||
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
||||
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
||||
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
||||
|
||||
Outputs:
|
||||
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
has_bias=True,
|
||||
activation=None,
|
||||
num_bits=8,
|
||||
quant_delay=0,
|
||||
per_channel=False,
|
||||
symmetric=False,
|
||||
narrow_range=False):
|
||||
super(DenseQuant, self).__init__()
|
||||
self.in_channels = check_int_positive(in_channels)
|
||||
self.out_channels = check_int_positive(out_channels)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
|
||||
weight_init.shape()[1] != in_channels:
|
||||
raise ValueError("weight_init shape error")
|
||||
|
||||
self.weight = Parameter(initializer(
|
||||
weight_init, [out_channels, in_channels]), name="weight")
|
||||
|
||||
if self.has_bias:
|
||||
if isinstance(bias_init, Tensor):
|
||||
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
|
||||
raise ValueError("bias_init shape error")
|
||||
|
||||
self.bias = Parameter(initializer(
|
||||
bias_init, [out_channels]), name="bias")
|
||||
|
||||
self.matmul = P.MatMul(transpose_b=True)
|
||||
self.bias_add = P.BiasAdd()
|
||||
|
||||
self.activation = get_activation(activation)
|
||||
self.activation_flag = self.activation is not None
|
||||
self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
|
||||
max_init=6,
|
||||
ema=False,
|
||||
num_bits=num_bits,
|
||||
quant_delay=quant_delay,
|
||||
per_channel=per_channel,
|
||||
channel_size=out_channels,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range)
|
||||
|
||||
def construct(self, x):
|
||||
"""Use operators to construct to Dense layer."""
|
||||
output = self.fake_quant_weight(self.weight)
|
||||
output = self.matmul(x, output)
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
if self.activation_flag:
|
||||
return self.activation(output)
|
||||
return output
|
||||
|
||||
def extend_repr(self):
|
||||
"""A pretty print for Dense layer."""
|
||||
str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}'.format(
|
||||
self.in_channels, self.out_channels, self.weight, self.has_bias)
|
||||
if self.has_bias:
|
||||
str_info = str_info + ', bias={}'.format(self.bias)
|
||||
if self.activation_flag:
|
||||
str_info = str_info + ', activation={}'.format(self.activation)
|
||||
|
||||
return str_info
|
||||
|
||||
|
||||
class ReLUQuant(Cell):
|
||||
r"""
|
||||
ReLUQuant activation function. Add Fake Quant OP after Relu OP.
|
||||
|
||||
For a more Detailed overview of ReLU op.
|
||||
|
||||
Args:
|
||||
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
||||
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input of ReLUQuant.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_bits=8,
|
||||
quant_delay=0,
|
||||
symmetric=False,
|
||||
narrow_range=False):
|
||||
super(ReLUQuant, self).__init__()
|
||||
self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0,
|
||||
max_init=6,
|
||||
num_bits=num_bits,
|
||||
quant_delay=quant_delay,
|
||||
ema=True,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range)
|
||||
self.relu = P.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.relu(x)
|
||||
x = self.fake_quant_act(x)
|
||||
return x
|
||||
|
||||
|
||||
class ReLU6Quant(Cell):
|
||||
r"""
|
||||
ReLU6Quant activation function.
|
||||
|
||||
Add Fake Quant OP after Relu6. Not Recommand to used these cell for Fake Quant Op
|
||||
Will climp the max range of the activation and the relu6 do the same operation.
|
||||
For a more Detailed overview of ReLU6 op.
|
||||
|
||||
Args:
|
||||
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
||||
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input of ReLU6Quant.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, num_bits=8, quant_delay=0, symmetric=False,
|
||||
narrow_range=False):
|
||||
super(ReLU6Quant, self).__init__()
|
||||
self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0,
|
||||
max_init=6,
|
||||
num_bits=num_bits,
|
||||
quant_delay=quant_delay,
|
||||
ema=True,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range)
|
||||
self.relu6 = P.ReLU6()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.relu6(x)
|
||||
x = self.fake_quant_act(x)
|
||||
return x
|
||||
|
||||
|
||||
class HSwishQuant(Cell):
|
||||
r"""
|
||||
HSwishQuant activation function. Add Fake Quant OP after HSwish OP.
|
||||
|
||||
For a more Detailed overview of HSwish op.
|
||||
|
||||
Args:
|
||||
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
||||
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input of HSwishQuant.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_bits=8,
|
||||
quant_delay=0,
|
||||
symmetric=False,
|
||||
narrow_range=False):
|
||||
super(HSwishQuant, self).__init__()
|
||||
self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0,
|
||||
max_init=6,
|
||||
num_bits=num_bits,
|
||||
quant_delay=quant_delay,
|
||||
ema=True,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range)
|
||||
self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0,
|
||||
max_init=6,
|
||||
num_bits=num_bits,
|
||||
quant_delay=quant_delay,
|
||||
ema=True,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range)
|
||||
self.act = P.HSwish()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.fake_quant_act_before(x)
|
||||
x = self.act(x)
|
||||
x = self.fake_quant_act_after(x)
|
||||
return x
|
||||
|
||||
|
||||
class HSigmoidQuant(Cell):
|
||||
r"""
|
||||
HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP.
|
||||
|
||||
For a more Detailed overview of HSigmoid op.
|
||||
|
||||
Args:
|
||||
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
||||
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input of HSigmoidQuant.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_bits=8,
|
||||
quant_delay=0,
|
||||
symmetric=False,
|
||||
narrow_range=False):
|
||||
super(HSigmoidQuant, self).__init__()
|
||||
self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0,
|
||||
max_init=6,
|
||||
num_bits=num_bits,
|
||||
quant_delay=quant_delay,
|
||||
ema=True,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range)
|
||||
self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0,
|
||||
max_init=6,
|
||||
num_bits=num_bits,
|
||||
quant_delay=quant_delay,
|
||||
ema=True,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range)
|
||||
self.act = P.HSigmoid()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.fake_quant_act_before(x)
|
||||
x = self.act(x)
|
||||
x = self.fake_quant_act_after(x)
|
||||
return x
|
||||
|
||||
|
||||
class TensorAddQuant(Cell):
|
||||
r"""
|
||||
Add Fake Quant OP after TensorAdd OP.
|
||||
|
||||
For a more Detailed overview of TensorAdd op.
|
||||
|
||||
Args:
|
||||
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
||||
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input of TensorAddQuant.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_bits=8,
|
||||
quant_delay=0,
|
||||
symmetric=False,
|
||||
narrow_range=False):
|
||||
super(TensorAddQuant, self).__init__()
|
||||
self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=-6,
|
||||
max_init=6,
|
||||
num_bits=num_bits,
|
||||
quant_delay=quant_delay,
|
||||
ema=True,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range)
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
def construct(self, x1, x2):
|
||||
x = self.add(x1, x2)
|
||||
x = self.fake_quant_act(x)
|
||||
return x
|
|
@ -234,7 +234,7 @@ class Tanh(Cell):
|
|||
|
||||
|
||||
class GELU(Cell):
|
||||
"""
|
||||
r"""
|
||||
Gaussian error linear unit activation function.
|
||||
|
||||
Applies GELU function to each element of the input. The input is a Tensor with any valid shape.
|
||||
|
@ -332,15 +332,74 @@ class PReLU(Cell):
|
|||
return v
|
||||
|
||||
|
||||
class HSwish(Cell):
|
||||
r"""
|
||||
rHard swish activation function.
|
||||
|
||||
Applies hswish-type activation element-wise. The input is a Tensor with any valid shape.
|
||||
|
||||
Hard swish is defined as:
|
||||
|
||||
.. math::
|
||||
\text{hswish}(x_{i}) = x_{i} * \frac{ReLU6(x_{i} + 3)}{6},
|
||||
|
||||
where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor.
|
||||
|
||||
Inputs:
|
||||
- **input_data** (Tensor) - The input of Hswish.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as the `input_data`.
|
||||
|
||||
"""
|
||||
def __init__(self):
|
||||
super(HSwish, self).__init__()
|
||||
self.hswish = P.HSwish()
|
||||
|
||||
def construct(self, x):
|
||||
return self.hswish(x)
|
||||
|
||||
|
||||
class HSigmoid(Cell):
|
||||
r"""
|
||||
Hard sigmoid activation function.
|
||||
|
||||
Applies hard sigmoid activation element-wise. The input is a Tensor with any valid shape.
|
||||
|
||||
Hard sigmoid is defined as:
|
||||
|
||||
.. math::
|
||||
\text{hsigmoid}(x_{i}) = max(0, min(1, \ftac{2 * x_{i} + 5}{10})),
|
||||
|
||||
where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor.
|
||||
|
||||
Inputs:
|
||||
- **input_data** (Tensor) - The input of HSigmoid.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as the `input_data`.
|
||||
|
||||
"""
|
||||
def __init__(self):
|
||||
super(HSigmoid, self).__init__()
|
||||
self.hsigmoid = P.HSigmoid()
|
||||
|
||||
def construct(self, x):
|
||||
return self.hsigmoid(x)
|
||||
|
||||
|
||||
_activation = {
|
||||
'softmax': Softmax,
|
||||
'logsoftmax': LogSoftmax,
|
||||
'relu': ReLU,
|
||||
'relu6': ReLU6,
|
||||
'tanh': Tanh,
|
||||
'gelu': GELU,
|
||||
'sigmoid': Sigmoid,
|
||||
'prelu': PReLU,
|
||||
'leakyrelu': LeakyReLU
|
||||
'leakyrelu': LeakyReLU,
|
||||
'hswish': HSwish,
|
||||
'hsigmoid': HSigmoid,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -172,6 +172,28 @@ def get_bprop_relu6(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.HSwish)
|
||||
def get_bprop_hswish(self):
|
||||
"""Grad definition for `HSwish` operation."""
|
||||
input_grad = G.HSwishGrad()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
dx = input_grad(dout, x)
|
||||
return (dx,)
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.HSigmoid)
|
||||
def get_bprop_hsigmoid(self):
|
||||
"""Grad definition for `HSigmoid` operation."""
|
||||
input_grad = G.HSigmoidGrad()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
dx = input_grad(dout, x)
|
||||
return (dx,)
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.Elu)
|
||||
def get_bprop_elu(self):
|
||||
"""Grad definition for `Elu` operation."""
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Generate bprop for aware quantization ops"""
|
||||
|
||||
from .. import operations as P
|
||||
from .grad_base import bprop_getters
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
|
||||
|
||||
@bprop_getters.register(P.FakeQuantWithMinMax)
|
||||
def get_bprop_fakequant_with_minmax(self):
|
||||
"""Generate bprop for FakeQuantWithMinMax"""
|
||||
op = P.FakeQuantWithMinMaxGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
|
||||
|
||||
def bprop(x, x_min, x_max, out, dout):
|
||||
dx = op(dout, x, x_min, x_max)
|
||||
return dx, zeros_like(x_min), zeros_like(x_max)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.FakeQuantWithMinMaxPerChannel)
|
||||
def get_bprop_fakequant_with_minmax_perchannel(self):
|
||||
"""Generate bprop for FakeQuantWithMinMaxPerChannel"""
|
||||
op = P.FakeQuantWithMinMaxPerChannelGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
|
||||
|
||||
def bprop(x, x_min, x_max, out, dout):
|
||||
dx = op(dout, x, x_min, x_max)
|
||||
return dx, zeros_like(x_min), zeros_like(x_max)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.BatchNormFold)
|
||||
def get_bprop_batchnorm_fold(self):
|
||||
"""Generate bprop for BatchNormFold"""
|
||||
op = P.BatchNormFoldGrad(self.epsilon, self.is_training, self.freeze_bn)
|
||||
|
||||
def bprop(x, mean, variance, global_step, out, dout):
|
||||
dx = op(dout[0], dout[1], x, out[0], out[1], global_step)
|
||||
return dx, zeros_like(mean), zeros_like(variance), zeros_like(global_step)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.CorrectionMul)
|
||||
def get_bprop_correction_mul(self):
|
||||
"""Generate bprop for CorrectionMul"""
|
||||
grad = P.CorrectionMulGrad()
|
||||
|
||||
def bprop(x, batch_std, running_std, out, dout):
|
||||
dx, d_batch_std = grad(dout, x, batch_std, running_std)
|
||||
return dx, d_batch_std, zeros_like(running_std)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.BatchNormFold2)
|
||||
def get_bprop_batchnorm_fold2(self):
|
||||
"""Generate bprop for CorrectionAdd"""
|
||||
op_f = P.BatchNormFold2Grad(freeze_bn=self.freeze_bn)
|
||||
|
||||
def bprop(x, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, out, dout):
|
||||
d_batch_std, d_batch_mean, d_beta, d_gamma, d_x = op_f(dout, x, gamma, batch_std, batch_mean, running_std,
|
||||
running_mean, global_step)
|
||||
return d_x, d_beta, d_gamma, d_batch_std, d_batch_mean, zeros_like(running_std), zeros_like(running_mean), \
|
||||
zeros_like(global_step)
|
||||
|
||||
return bprop
|
|
@ -59,7 +59,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
|
|||
LogSoftmax,
|
||||
MaxPool,
|
||||
AvgPool, Conv2DBackpropInput,
|
||||
MaxPoolWithArgmax, OneHot, Pad, PReLU, ReLU, ReLU6,
|
||||
MaxPoolWithArgmax, OneHot, Pad, PReLU, ReLU, ReLU6, HSwish, HSigmoid,
|
||||
ResizeBilinear, Sigmoid,
|
||||
SigmoidCrossEntropyWithLogits,
|
||||
SmoothL1Loss, Softmax,
|
||||
|
@ -68,7 +68,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
|
|||
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl,
|
||||
ApplyRMSProp, ApplyCenteredRMSProp)
|
||||
from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey
|
||||
|
||||
from . import _quant_ops
|
||||
from ._quant_ops import *
|
||||
|
||||
__all__ = [
|
||||
'TensorAdd',
|
||||
|
@ -138,6 +139,8 @@ __all__ = [
|
|||
'ReLU6',
|
||||
'Elu',
|
||||
'Sigmoid',
|
||||
'HSwish',
|
||||
'HSigmoid',
|
||||
'Tanh',
|
||||
'RandomChoiceWithMask',
|
||||
'ResizeBilinear',
|
||||
|
@ -241,4 +244,5 @@ __all__ = [
|
|||
"ApplyCenteredRMSProp"
|
||||
]
|
||||
|
||||
__all__.extend(_quant_ops.__all__)
|
||||
__all__.sort()
|
||||
|
|
|
@ -805,6 +805,38 @@ class SigmoidGrad(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
|
||||
class HSigmoidGrad(PrimitiveWithInfer):
|
||||
"""Gets the gradient of HSigmoid operation."""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, y_grad_shape, x_shape):
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, y_grad_dtype, x_dtype):
|
||||
validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32))
|
||||
validator.check_typename("x dtype", x_dtype, (mstype.float16, mstype.float32))
|
||||
return x_dtype
|
||||
|
||||
|
||||
class HSwishGrad(PrimitiveWithInfer):
|
||||
"""Gets the gradient of HSwish operation."""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, y_grad_shape, x_shape):
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, y_grad_dtype, x_dtype):
|
||||
validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32))
|
||||
validator.check_typename("x_ dtype", x_dtype, (mstype.float16, mstype.float32))
|
||||
return x_dtype
|
||||
|
||||
|
||||
class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer):
|
||||
"""Computes the gradients of `SigmoidCrossEntropyWithLogits`."""
|
||||
|
||||
|
|
|
@ -0,0 +1,525 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http: // www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Operators for quantization."""
|
||||
|
||||
from ..._checkparam import ParamValidator as validator
|
||||
from ..._checkparam import Rel, check_bool, check_int_positive, check_int
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||
from ...common import dtype as mstype
|
||||
|
||||
__all__ = ["FakeQuantWithMinMax",
|
||||
"FakeQuantWithMinMaxGrad",
|
||||
"FakeQuantWithMinMaxPerChannel",
|
||||
"FakeQuantWithMinMaxPerChannelGrad",
|
||||
"BatchNormFold",
|
||||
"BatchNormFoldGrad",
|
||||
"CorrectionMul",
|
||||
"CorrectionMulGrad",
|
||||
"BatchNormFold2",
|
||||
"BatchNormFold2Grad",
|
||||
]
|
||||
|
||||
|
||||
class FakeQuantWithMinMax(PrimitiveWithInfer):
|
||||
r"""
|
||||
Simulate the quantize and dequantize operations in training time.
|
||||
|
||||
Args:
|
||||
num_bits (int) : Number bits for aware quantilization. Default: 8.
|
||||
ema (bool): Use EMA algorithm update value min and max. Default: False.
|
||||
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
|
||||
quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
|
||||
simulate aware quantize funcion. After delay step in training time begin simulate the aware
|
||||
quantize funcion. Default: 0.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||
training (bool): Training the network or not. Default: True.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
|
||||
- **min** (Tensor) : Value of the min range of the input data x.
|
||||
- **max** (Tensor) : Value of the max range of the input data x.
|
||||
|
||||
Outputs:
|
||||
- Tensor: Simulate quantize tensor of x.
|
||||
|
||||
Examples:
|
||||
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
|
||||
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
|
||||
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
|
||||
>>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
|
||||
"""
|
||||
support_quant_bit = [4, 7, 8]
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
|
||||
training=True):
|
||||
"""init FakeQuantWithMinMax OP"""
|
||||
if num_bits not in self.support_quant_bit:
|
||||
raise ValueError("Attr \'num_bits\' is not support.")
|
||||
if ema and not ema_decay:
|
||||
raise ValueError(
|
||||
"Attr \'ema\' and \'ema_decay\' should set together.")
|
||||
|
||||
self.ema = check_bool(ema)
|
||||
self.symmetric = check_bool(symmetric)
|
||||
self.narrow_range = check_bool(narrow_range)
|
||||
self.training = check_bool(training)
|
||||
self.ema_decay = validator.check_number_range(
|
||||
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH)
|
||||
self.num_bits = check_int_positive(num_bits)
|
||||
self.quant_delay = check_int(quant_delay)
|
||||
self.init_prim_io_names(inputs=['x', 'min', 'max'],
|
||||
outputs=['out'])
|
||||
|
||||
def infer_shape(self, x_shape, min_shape, max_shape):
|
||||
validator.check_integer("x shape", len(x_shape), 1, Rel.GT)
|
||||
validator.check("min shape", min_shape, "max shape", max_shape)
|
||||
validator.check_integer("min shape", len(min_shape), 1, Rel.EQ)
|
||||
validator.check_integer("max shape", len(min_shape), 1, Rel.EQ)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type, min_type, max_type):
|
||||
validator.check_typename(
|
||||
"x type", x_type, (mstype.float16, mstype.float32))
|
||||
validator.check_typename("min type", min_type,
|
||||
(mstype.float16, mstype.float32))
|
||||
validator.check_typename("max type", max_type,
|
||||
(mstype.float16, mstype.float32))
|
||||
return x_type
|
||||
|
||||
|
||||
class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
|
||||
"""Performs grad of FakeQuantWithMinMax operation."""
|
||||
support_quant_bit = [4, 8]
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, num_bits=8, quant_delay=0):
|
||||
if num_bits not in self.support_quant_bit:
|
||||
raise ValueError("Attr \'num_bits\' is not support.")
|
||||
|
||||
self.quant_delay = check_int(quant_delay)
|
||||
self.num_bits = check_int_positive(num_bits)
|
||||
self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'],
|
||||
outputs=['dx'])
|
||||
|
||||
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
|
||||
validator.check("dout shape", dout_shape, "x shape", x_shape)
|
||||
validator.check("min shape", min_shape, "max shape", max_shape)
|
||||
validator.check_integer("min shape", len(min_shape), 1, Rel.EQ)
|
||||
validator.check_integer("max shape", len(min_shape), 1, Rel.EQ)
|
||||
return dout_shape
|
||||
|
||||
def infer_dtype(self, dout_type, x_type, min_type, max_type):
|
||||
validator.check_typename(
|
||||
"dout type", dout_type, (mstype.float16, mstype.float32))
|
||||
validator.check_typename(
|
||||
"x type", x_type, (mstype.float16, mstype.float32))
|
||||
validator.check_typename("min type", min_type,
|
||||
(mstype.float16, mstype.float32))
|
||||
validator.check_typename("max type", max_type,
|
||||
(mstype.float16, mstype.float32))
|
||||
return dout_type
|
||||
|
||||
|
||||
class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
|
||||
r"""
|
||||
Simulate the quantize and dequantize operations in training time base on per channel.
|
||||
|
||||
Args:
|
||||
num_bits (int) : Number bits to quantilization. Default: 8.
|
||||
ema (bool): Use EMA algorithm update tensor min and tensor max. Default: False.
|
||||
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
|
||||
quant_delay (int): Quantilization delay parameter. Before delay step in training time not
|
||||
update the weight data to simulate quantize operation. After delay step in training time
|
||||
begin simulate the quantize operation. Default: 0.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||
training (bool): Training the network or not. Default: True.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor.
|
||||
- **min** (int, float) : Value of the min range of the input data.
|
||||
- **max** (int, float) : Value of the max range of the input data.
|
||||
|
||||
Outputs:
|
||||
- Tensor, has the same type as input.
|
||||
|
||||
Examples:
|
||||
>>> input_tensor = Tensor(np.random.rand(3,4,5,5), mstype.float32)
|
||||
>>> min_tensor = Tensor(np.array([-6.0, -6.5, -4.0, -5.0]), mstype.float32)
|
||||
>>> max_tensor = Tensor(np.array([6.0, 6.5, 4.0, 5.0]), mstype.float32)
|
||||
>>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
|
||||
"""
|
||||
support_quant_bit = [4, 8]
|
||||
channel_idx = 0
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
|
||||
training=True):
|
||||
"""init FakeQuantWithMinMaxPerChannel OP"""
|
||||
if num_bits not in self.support_quant_bit:
|
||||
raise ValueError("Attr \'num_bits\' is not support.")
|
||||
if ema and not ema_decay:
|
||||
raise ValueError(
|
||||
"Attr \'ema\' and \'ema_decay\' should set together.")
|
||||
|
||||
self.ema = check_bool(ema)
|
||||
self.symmetric = check_bool(symmetric)
|
||||
self.narrow_range = check_bool(narrow_range)
|
||||
self.training = check_bool(training)
|
||||
self.ema_decay = validator.check_number_range(
|
||||
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH)
|
||||
self.num_bits = check_int_positive(num_bits)
|
||||
self.quant_delay = check_int(quant_delay)
|
||||
self.init_prim_io_names(inputs=['x', 'min', 'max'],
|
||||
outputs=['out'])
|
||||
|
||||
def infer_shape(self, x_shape, min_shape, max_shape):
|
||||
validator.check_integer("x shape", len(x_shape), 1, Rel.GT)
|
||||
validator.check_integer(
|
||||
"min len", min_shape[0], x_shape[self.channel_idx], Rel.EQ)
|
||||
validator.check_integer(
|
||||
"max len", max_shape[0], x_shape[self.channel_idx], Rel.EQ)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type, min_type, max_type):
|
||||
validator.check_typename(
|
||||
"x type", x_type, (mstype.float16, mstype.float32))
|
||||
validator.check_typename("min type", min_type,
|
||||
(mstype.float16, mstype.float32))
|
||||
validator.check_typename("max type", max_type,
|
||||
(mstype.float16, mstype.float32))
|
||||
return x_type
|
||||
|
||||
|
||||
class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
|
||||
"""Performs grad of FakeQuantWithMinMaxPerChannel operation."""
|
||||
support_quant_bit = [4, 8]
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, num_bits=8, quant_delay=0):
|
||||
"""init FakeQuantWithMinMaxPerChannel Fill"""
|
||||
if num_bits not in self.support_quant_bit:
|
||||
raise ValueError("Attr \'num_bits\' is not support.")
|
||||
|
||||
self.quant_delay = check_int(quant_delay)
|
||||
self.num_bits = check_int_positive(num_bits)
|
||||
self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'],
|
||||
outputs=['dx'])
|
||||
|
||||
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
|
||||
validator.check("dout shape", dout_shape, "x shape", x_shape)
|
||||
validator.check("min shape", min_shape, "max shape", max_shape)
|
||||
return dout_shape
|
||||
|
||||
def infer_dtype(self, dout_type, x_type, min_type, max_type):
|
||||
validator.check_typename(
|
||||
"dout", dout_type, (mstype.float16, mstype.float32))
|
||||
validator.check_typename("x", x_type, (mstype.float16, mstype.float32))
|
||||
validator.check_typename(
|
||||
"min", min_type, (mstype.float16, mstype.float32))
|
||||
validator.check_typename(
|
||||
"max", max_type, (mstype.float16, mstype.float32))
|
||||
return dout_type
|
||||
|
||||
|
||||
class BatchNormFold(PrimitiveWithInfer):
|
||||
"""
|
||||
Batch normalization folded.
|
||||
|
||||
Args:
|
||||
momentum (float): Momentum value should be [0, 1]. Default: 0.1.
|
||||
epsilon (float): A small float number to avoid dividing by 0. 1e-12 if dtype in
|
||||
float32 else 1e-3. Default: 1e-12.
|
||||
is_training (bool): In training mode set True, else set False. Default: True.
|
||||
freeze_bn (int): Delay in steps at which computation switches from regular batch
|
||||
norm to frozen mean and std. Default: 0.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
|
||||
- **mean** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **variance** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **global_step** (Tensor) - Tensor to record current global step.
|
||||
|
||||
Outputs:
|
||||
Tuple of 4 Tensor, the normalized input and the updated parameters.
|
||||
|
||||
- **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
|
||||
"""
|
||||
channel = 1
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, momentum=0.1, epsilon=1e-12, is_training=True, freeze_bn=0):
|
||||
"""init batch norm fold layer"""
|
||||
self.momentum = validator.check_number_range(
|
||||
'momentum', momentum, 0, 1, Rel.INC_BOTH)
|
||||
self.epsilon = validator.check_float_positive('epsilon', epsilon)
|
||||
self.is_training = check_bool(is_training)
|
||||
self.freeze_bn = check_int(freeze_bn)
|
||||
|
||||
self.init_prim_io_names(inputs=['x', 'mean', 'variance', 'global_step'],
|
||||
outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std'])
|
||||
|
||||
def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
|
||||
validator.check("mean shape", mean_shape,
|
||||
"gamma_shape", variance_shape)
|
||||
validator.check("mean_shape size",
|
||||
mean_shape[0], "input channel", x_shape[self.channel])
|
||||
validator.check_integer("global_step shape",
|
||||
len(global_step_shape), 1, Rel.EQ)
|
||||
return mean_shape, mean_shape, mean_shape, mean_shape
|
||||
|
||||
def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
|
||||
validator.check("input type", x_type, "mean type", mean_type)
|
||||
validator.check("input type", x_type, "variance type", variance_type)
|
||||
validator.check_typename("input type", x_type,
|
||||
(mstype.float16, mstype.float32))
|
||||
validator.check_typename(
|
||||
"global_step type", global_step_type, (mstype.int32,))
|
||||
return x_type, x_type, x_type, x_type
|
||||
|
||||
|
||||
class BatchNormFoldGrad(PrimitiveWithInfer):
|
||||
"""Performs grad of BatchNormFold operation."""
|
||||
channel = 1
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, epsilon=1e-12, is_training=True, freeze_bn=0):
|
||||
"""init BatchNormGrad layer"""
|
||||
self.is_training = check_bool(is_training)
|
||||
self.freeze_bn = check_int(freeze_bn)
|
||||
self.epsilon = validator.check_float_positive('epsilon', epsilon)
|
||||
self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'],
|
||||
outputs=['dx'])
|
||||
|
||||
def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape,
|
||||
global_step_shape):
|
||||
validator.check("d_batch_mean shape", d_batch_mean_shape,
|
||||
"d_batch_std shape", d_batch_std_shape)
|
||||
validator.check("d_batch_mean shape", d_batch_mean_shape,
|
||||
"batch_mean shape", batch_mean_shape)
|
||||
validator.check("d_batch_mean shape", d_batch_mean_shape,
|
||||
"batch_std shape", batch_std_shape)
|
||||
validator.check(
|
||||
"x_shape shape", d_batch_mean_shape[0], "input channel", x_shape[self.channel])
|
||||
validator.check_integer("global_step shape",
|
||||
len(global_step_shape), 1, Rel.EQ)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
|
||||
global_step_type):
|
||||
validator.check("input type", x_type,
|
||||
"d_batch_mean type", d_batch_mean_type)
|
||||
validator.check("input type", x_type,
|
||||
"d_batch_std type", d_batch_std_type)
|
||||
validator.check("input type", x_type,
|
||||
"batch_mean type", batch_mean_type)
|
||||
validator.check("input type", x_type, "batch_std type", batch_std_type)
|
||||
validator.check_typename("input type", x_type,
|
||||
(mstype.float16, mstype.float32))
|
||||
validator.check_typename(
|
||||
"global_step type", global_step_type, (mstype.int32,))
|
||||
return x_type
|
||||
|
||||
|
||||
class CorrectionMul(PrimitiveWithInfer):
|
||||
"""
|
||||
Scale the weights with a correction factor to the long term statistics
|
||||
prior to quantization. This ensures that there is no jitter in the quantized weights
|
||||
due to batch to batch variation.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
|
||||
- **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
|
||||
Outputs:
|
||||
- **out** (Tensor) - Tensor has the same shape as x.
|
||||
|
||||
"""
|
||||
channel = 0
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init correction mul layer"""
|
||||
self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'],
|
||||
outputs=['out'])
|
||||
|
||||
def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
|
||||
validator.check("batch_std shape", batch_std_shape,
|
||||
"running_std shape", running_std_shape)
|
||||
validator.check(
|
||||
"batch_std size", batch_std_shape[0], "x_shape channel size", x_shape[self.channel])
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type, batch_std_type, running_std_type):
|
||||
validator.check("batch_std type", batch_std_type,
|
||||
"running_std type", running_std_type)
|
||||
validator.check("batch_std_type", batch_std_type, "x_type", x_type)
|
||||
validator.check_typename(
|
||||
"batch_std type", batch_std_type, (mstype.float16, mstype.float32))
|
||||
return x_type
|
||||
|
||||
|
||||
class CorrectionMulGrad(PrimitiveWithInfer):
|
||||
"""Performs grad of CorrectionMul operation."""
|
||||
channel = 0
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init correction mul layer"""
|
||||
self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
|
||||
outputs=['dx', 'd_gamma'])
|
||||
|
||||
def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
|
||||
validator.check("dout shape", dout_shape, "x_shape x", x_shape)
|
||||
validator.check(
|
||||
"gamma size", gamma_shape[0], "dout channel size", dout_shape[self.channel])
|
||||
validator.check(
|
||||
"running_std size", running_std_shape[0], "dout channel size", dout_shape[self.channel])
|
||||
return x_shape, gamma_shape
|
||||
|
||||
def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
|
||||
validator.check("x type", x_type, "dout type", dout_type)
|
||||
validator.check("gamma type", gamma_type, "dout type", dout_type)
|
||||
validator.check("running_std type", running_std_type,
|
||||
"dout type", dout_type)
|
||||
validator.check_typename(
|
||||
"dout type", dout_type, (mstype.float16, mstype.float32))
|
||||
return x_type, x_type
|
||||
|
||||
|
||||
class BatchNormFold2(PrimitiveWithInfer):
    """
    Scale the bias with a correction factor to the long-term statistics
    prior to quantization. This ensures that there is no jitter in the quantized bias
    due to batch-to-batch variation.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
        - **beta** (Tensor) - Tensor of shape :math:`(C,)`.
        - **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **global_step** (Tensor) - Tensor to record current global step.

    Outputs:
        - **y** (Tensor) - Tensor has the same shape as x.

    """
    channel = 1

    @prim_attr_register
    def __init__(self, freeze_bn=0):
        """init conv2d fold layer"""
        self.freeze_bn = check_int(freeze_bn)
        self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean',
                                        'running_std', 'running_mean', 'global_step'],
                                outputs=['y'])

    def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape,
                    running_mean_shape, global_step_shape):
        validator.check("batch_std shape", batch_std_shape,
                        "running_std shape", running_std_shape)
        validator.check("batch_std shape", batch_std_shape,
                        "batch_mean shape", batch_mean_shape)
        validator.check("batch_std shape", batch_std_shape,
                        "beta shape", beta_shape)
        validator.check("batch_std shape", batch_std_shape,
                        "running_mean shape", running_mean_shape)
        validator.check("batch_std shape", batch_std_shape,
                        "gamma shape", gamma_shape)
        validator.check(
            "batch_std size", batch_std_shape[0], "x_shape channel size", x_shape[self.channel])
        validator.check_integer("global_step shape",
                                len(global_step_shape), 1, Rel.EQ)
        return x_shape

    def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
                    running_mean_type, global_step_type):
        validator.check("batch_std type", batch_std_type,
                        "running_std type", running_std_type)
        validator.check("batch_std type", batch_std_type,
                        "batch_mean type", batch_mean_type)
        validator.check("batch_std type", batch_std_type,
                        "beta type", beta_type)
        validator.check("batch_std type", batch_std_type,
                        "running_mean type", running_mean_type)
        validator.check("batch_std type", batch_std_type,
                        "gamma type", gamma_type)
        validator.check("x type", x_type, "batch_std type", batch_std_type)
        validator.check_typename(
            "batch_std type", batch_std_type, (mstype.float16, mstype.float32))
        validator.check_typename(
            "global_step type", global_step_type, (mstype.int32,))
        return x_type

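A standalone wiring sketch for BatchNormFold2, following the input names and shape checks registered above (an illustration, not part of the commit; the tensor values and the freeze_bn setting are arbitrary, and the exact correction formula is left to the backend kernel):

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

bn_fold2 = P.BatchNormFold2(freeze_bn=0)
x = Tensor(np.ones((2, 3), dtype=np.float32))               # (N, C), channel axis 1
beta = Tensor(np.zeros(3, dtype=np.float32))
gamma = Tensor(np.ones(3, dtype=np.float32))
batch_std = Tensor(np.ones(3, dtype=np.float32))
batch_mean = Tensor(np.zeros(3, dtype=np.float32))
running_std = Tensor(np.ones(3, dtype=np.float32))
running_mean = Tensor(np.zeros(3, dtype=np.float32))
global_step = Tensor(np.array([0], dtype=np.int32))         # rank-1, as the shape check requires
y = bn_fold2(x, beta, gamma, batch_std, batch_mean,
             running_std, running_mean, global_step)        # same shape as x
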
class BatchNormFold2Grad(PrimitiveWithInfer):
    """Performs grad of BatchNormFold2 operation."""
    channel = 1

    @prim_attr_register
    def __init__(self, freeze_bn=0):
        """init MulFold layer"""
        self.freeze_bn = freeze_bn
        self.init_prim_io_names(inputs=['dout', 'x', 'gamma',
                                        'batch_std', 'batch_mean',
                                        'running_std', 'running_mean', 'global_step'],
                                outputs=['d_batch_std', 'd_batch_mean', 'd_beta', 'd_gamma', 'dx'])

    def infer_shape(self, dout_shape, x_shape, gamma_shape,
                    batch_std_shape, batch_mean_shape,
                    running_std_shape, running_mean_shape, global_step_shape):
        validator.check("batch_std shape", batch_std_shape,
                        "batch_mean shape", batch_mean_shape)
        validator.check("batch_std shape", batch_std_shape,
                        "running_std shape", running_std_shape)
        validator.check("batch_std shape", batch_std_shape,
                        "running_mean shape", running_mean_shape)
        validator.check("batch_std shape", batch_std_shape,
                        "gamma shape", gamma_shape)
        validator.check(
            "batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel])
        validator.check_integer("global_step shape",
                                len(global_step_shape), 1, Rel.EQ)
        return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape

    def infer_dtype(self, dout_type, x_type, gamma_type,
                    batch_std_type, batch_mean_type,
                    running_std_type, running_mean_type, global_step_type):
        validator.check("batch_std type", batch_std_type,
                        "batch_mean type", batch_mean_type)
        validator.check("batch_std type", batch_std_type,
                        "gamma type", gamma_type)
        validator.check("batch_std type", batch_std_type,
                        "running_std type", running_std_type)
        validator.check("batch_std type", batch_std_type,
                        "running_mean type", running_mean_type)
        validator.check("batch_std type", batch_std_type,
                        "dout type", dout_type)
        validator.check_typename(
            "batch_std type", batch_std_type, (mstype.float16, mstype.float32))
        validator.check_typename(
            "global_step type", global_step_type, (mstype.int32,))
        return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type

@@ -207,7 +207,7 @@ class ReLU6(PrimitiveWithInfer):


class Elu(PrimitiveWithInfer):
    """
    r"""
    Computes exponential linear: `alpha * (exp(x) - 1)` if x < 0, `x` otherwise.
    The data type of input tensor should be float.

@@ -242,6 +242,40 @@ class Elu(PrimitiveWithInfer):
        return input_x

class HSwish(PrimitiveWithInfer):
    r"""
    Hard swish activation function.

    Applies hswish-type activation element-wise. The input is a Tensor with any valid shape.

    Hard swish is defined as:

    .. math::
        \text{hswish}(x_{i}) = x_{i} * \frac{ReLU6(x_{i} + 3)}{6},

    where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor.

    Inputs:
        - **input_data** (Tensor) - The input of Hswish.

    Outputs:
        Tensor, with the same type and shape as the `input_data`.

    """
    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def infer_shape(self, xshape):
        return xshape

    def infer_dtype(self, x_dtype):
        validator.check_subclass("x_dtype", x_dtype, mstype.tensor)
        validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32))
        return x_dtype

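A standalone usage sketch for the HSwish primitive above (an illustration, not part of the commit; it assumes the primitive is exported through mindspore.ops.operations like the neighbouring primitives):

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

hswish = P.HSwish()
x = Tensor(np.array([-4.0, -1.0, 0.0, 1.0, 4.0], dtype=np.float32))
out = hswish(x)   # elementwise x * ReLU6(x + 3) / 6, e.g. hswish(-4.0) = 0.0 and hswish(4.0) = 4.0
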
class Sigmoid(PrimitiveWithInfer):
    r"""
    Sigmoid activation function.

@@ -258,6 +292,7 @@ class Sigmoid(PrimitiveWithInfer):

    Outputs:
        Tensor, with the same type and shape as the input_x.

    """

    @prim_attr_register

@@ -273,6 +308,40 @@ class Sigmoid(PrimitiveWithInfer):
        return input_x

class HSigmoid(PrimitiveWithInfer):
    r"""
    Hard sigmoid activation function.

    Applies hard sigmoid activation element-wise. The input is a Tensor with any valid shape.

    Hard sigmoid is defined as:

    .. math::
        \text{hsigmoid}(x_{i}) = max(0, min(1, \frac{2 * x_{i} + 5}{10})),

    where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor.

    Inputs:
        - **input_data** (Tensor) - The input of HSigmoid.

    Outputs:
        Tensor, with the same type and shape as the `input_data`.

    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_subclass("x_dtype", x_dtype, mstype.tensor)
        validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32))
        return x_dtype

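A standalone usage sketch for the HSigmoid primitive above (an illustration, not part of the commit; same export assumption as for HSwish):

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

hsigmoid = P.HSigmoid()
x = Tensor(np.array([-3.0, 0.0, 3.0], dtype=np.float32))
out = hsigmoid(x)   # elementwise max(0, min(1, (2 * x + 5) / 10)), e.g. hsigmoid(0.0) = 0.5
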
class Tanh(PrimitiveWithInfer):
    r"""
    Tanh activation function.


@@ -27,11 +27,6 @@ def test_dense_none():
        nn.Dense(3, 2, None, None)


def test_dense_invalid_activation():
    with pytest.raises(KeyError):
        nn.Dense(3, 2, activation='relu6')


@non_graph_engine
def test_dense_str_activation():
    dense = nn.Dense(1, 1, activation='relu')


@@ -51,11 +51,6 @@ def test_activation_empty():
    assert nn.get_activation('') is None


def test_activation_invalid():
    with pytest.raises(KeyError):
        nn.get_activation('relu6')


# test softmax
def test_softmax_axis():
    layer = nn.Softmax(1)


@@ -68,11 +68,6 @@ def test_dense_none():
        nn.Dense(3, 2, None, None)


def test_dense_invalid_activation():
    with pytest.raises(KeyError):
        nn.Dense(3, 2, activation='relu6')


def test_dense_str_activation():
    dense = nn.Dense(1, 1, activation='relu')
    assert isinstance(dense.activation, nn.ReLU)

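A sketch of how the newly exported nn.HSwish and nn.HSigmoid layers could be covered in the same construction-only style as the tests above (the test names and assertions are assumptions for illustration, not part of the commit):

import mindspore.nn as nn


def test_hswish_layer_construction():
    # assumed: the new activation layers are plain Cells constructible without arguments
    assert isinstance(nn.HSwish(), nn.Cell)


def test_hsigmoid_layer_construction():
    assert isinstance(nn.HSigmoid(), nn.Cell)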