quantization aware training frontend operators define.

This commit is contained in:
chenzomi 2020-04-10 09:37:57 +08:00
parent c75f75a3e1
commit d64f662c76
12 changed files with 1505 additions and 23 deletions

View File

@ -17,7 +17,7 @@ Layer.
The high-level components(Cells) used to construct the neural network.
"""
from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU
from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU, HSigmoid, HSwish
from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm
from .container import SequentialCell, CellList
from .conv import Conv2d, Conv2dTranspose
@ -26,8 +26,9 @@ from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, ImageGradi
from .embedding import Embedding
from .pooling import AvgPool2d, MaxPool2d
__all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid', 'PReLU', 'get_activation', 'LeakyReLU',
'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'ELU',
__all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
'PReLU', 'get_activation', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU',
'BatchNorm1d', 'BatchNorm2d', 'LayerNorm',
'SequentialCell', 'CellList',
'Conv2d', 'Conv2dTranspose',
'LSTM',

View File

@ -0,0 +1,703 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Aware quantization."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import check_int_positive, check_bool, twice
from mindspore.nn.cell import Cell
from mindspore.nn.layer.conv import _Conv
from mindspore.nn.layer.activation import get_activation
__all__ = [
'FakeQuantWithMinMax',
'Conv2dBatchNormQuant',
'Conv2dQuant',
'DenseQuant',
'ReLUQuant',
'ReLU6Quant',
'HSwishQuant',
'HSigmoidQuant',
'TensorAddQuant',
]
class FakeQuantWithMinMax(Cell):
r"""
Aware Quantization training op. This OP provide Fake quantization observer function on data with min and max.
Args:
min_init (int, list): The dimension of channel or 1(layer). Default: -6.
max_init (int, list): The dimension of channel or 1(layer). Default: 6.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.9999.
per_channel (bool): Quantization by layer or channel. Default: False.
channel_size (int): declarate the min and max channel size, Default: 1.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
Inputs:
- **x** (Tensor) - The input of FakeQuantWithMinMax.
Outputs:
Tensor, with the same type and shape as the `x`.
"""
def __init__(self,
min_init=-6,
max_init=6,
num_bits=8,
ema=False,
ema_decay=0.999,
per_channel=False,
channel_size=1,
quant_delay=0,
symmetric=False,
narrow_range=False):
super(FakeQuantWithMinMax, self).__init__()
self.min_init = min_init
self.num_bits = num_bits
self.max_init = max_init
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.channel_size = channel_size
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
if per_channel:
min_array = np.array([self.min_init for i in range(
0, self.channel_size)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(
0, self.channel_size)]).astype(np.float32)
self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True)
self.fake_quant_infer = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False)
else:
min_array = np.array([min_init]).reshape(1).astype(np.float32)
max_array = np.array([max_init]).reshape(1).astype(np.float32)
self.fake_quant_train = P.FakeQuantWithMinMax(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True)
self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False)
self.min = Parameter(
Tensor(min_array), name='quant_min', requires_grad=False)
self.max = Parameter(
Tensor(max_array), name='quant_max', requires_grad=False)
def extend_repr(self):
s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size,
self.quant_delay)
return s
def construct(self, x):
if self.training:
out = self.fake_quant_train(x, self.min, self.max)
else:
out = self.fake_quant_infer(x, self.min, self.max)
return out
class Conv2dBatchNormQuant(Cell):
r"""
2D convolution with BatchNormal op folded layer.
For a more Detailed overview of Conv2d op.
Args:
in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
stride (int): Specifies stride for all spatial dimensions with the same value.
pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
padding: (int): Implicit paddings on both sides of the input. Default: 0.
eps (int): Parameters for BatchNormal. Default: 1e-5.
momentum (int): Parameters for BatchNormal op. Default: 0.9.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
convolution kernel. Default: 'None'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
beta vector. Default: 'None'.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
gamma vector. Default: 'None'.
mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
mean vector. Default: 'None'.
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
variance vector. Default: 'None'.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
padding=0,
eps=1e-5,
momentum=0.9,
weight_init=None,
beta_init=None,
gamma_init=None,
mean_init=None,
var_init=None,
group=1,
quant_delay=0,
freeze_bn=100000,
fake=True,
num_bits=8,
per_channel=False,
symmetric=False,
narrow_range=False):
super(Conv2dBatchNormQuant, self).__init__()
self.stride = stride
self.conv = P.Conv2D(out_channel=out_channels,
kernel_size=kernel_size,
mode=1,
pad_mode=pad_mode,
pad=padding,
stride=stride,
dilation=1,
group=group)
self.fake = fake
self.freeze_bn = freeze_bn
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
if weight_init is None:
weight_init = initializer(
'normal', [out_channels, in_channels // group, *kernel_size])
self.weight = Parameter(weight_init, name='weight')
if gamma_init is None:
gamma_init = initializer('ones', [out_channels])
self.gamma = Parameter(gamma_init, name='gamma')
if beta_init is None:
beta_init = initializer('zeros', [out_channels])
self.beta = Parameter(beta_init, name='beta')
if mean_init is None:
mean_init = initializer('zeros', [out_channels])
self.moving_mean = Parameter(
mean_init, name='moving_mean', requires_grad=False)
if var_init is None:
var_init = initializer('ones', [out_channels])
self.moving_variance = Parameter(
var_init, name='moving_variance', requires_grad=False)
self.step = Parameter(initializer(
'normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
channel_size=out_channels,
symmetric=symmetric,
narrow_range=narrow_range)
self.batchnorm_fold_train = P.BatchNormFold(epsilon=eps,
momentum=momentum,
is_training=True,
freeze_bn=freeze_bn)
self.batchnorm_fold_infer = P.BatchNormFold(epsilon=eps,
momentum=momentum,
is_training=False,
freeze_bn=freeze_bn)
self.correct_mul = P.CorrectionMul()
self.relu = P.ReLU()
self.batchnorm_fold2 = P.BatchNormFold2(freeze_bn=freeze_bn)
self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0)
self.one = Tensor(1, mstype.int32)
self.assignadd = P.AssignAdd()
def extend_repr(self):
s = 'fake={}, freeze_bn={}'.format(self.fake, self.freeze_bn)
return s
def construct(self, x):
if self.training:
beta = self.beta
gamma = self.gamma
gmean = self.moving_mean
gvar = self.moving_variance
step = self.step
out_conv = self.conv(x, self.weight)
batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_train(
out_conv, gmean, gvar, step)
# BN fold1
weight = self.correct_mul(self.weight, gamma, running_std)
if self.fake:
weight = self.fake_quant_weight(weight)
out = self.conv(x, weight)
# BN fold2
out = self.batchnorm_fold2(
out, beta, gamma, batch_std, batch_mean, running_std, running_mean, step)
F.control_depend(out, self.assignadd(self.step, self.one))
else:
step = self.step
out_conv = self.conv(x, self.weight)
batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_infer(
out_conv, self.moving_mean, self.moving_variance, step)
weight = self.correct_mul(self.weight, self.gamma, running_std)
if self.fake:
weight = self.fake_quant_weight(weight)
out = self.conv(x, weight)
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean,
running_std, running_mean, step)
return out
class Conv2dQuant(_Conv):
r"""
2D convolution with fake quant op layer.
For a more Detailed overview of Conv2d op.
Args:
in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
stride (int): Specifies stride for all spatial dimensions with the same value. Default: 1.
pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
padding: (int): Implicit paddings on both sides of the input. Default: 0.
dilation (int): Specifying the dilation rate to use for dilated convolution. Default: 1.
group (int): Split filter into groups, `in_ channels` and `out_channels` should be
divisible by the number of groups. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros',
quant_delay=0,
num_bits=8,
per_channel=False,
symmetric=False,
narrow_range=False):
kernel_size = twice(kernel_size)
super(Conv2dQuant, self).__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding, dilation,
group, has_bias, weight_init, bias_init)
self.conv2d = P.Conv2D(out_channel=self.out_channels, kernel_size=self.kernel_size, mode=1,
pad_mode=self.pad_mode, pad=self.padding, stride=self.stride, dilation=self.dilation,
group=self.group)
self.bias_add = P.BiasAdd()
if pad_mode not in ('valid', 'same', 'pad'):
raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
+ str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
channel_size=out_channels,
symmetric=symmetric,
narrow_range=narrow_range)
def construct(self, x):
weight_q = self.fake_quant_weight(self.weight)
out = self.conv2d(x, weight_q)
if self.has_bias:
return self.bias_add(out, self.bias)
return out
class DenseQuant(Cell):
r"""
The fully connected layer with fake quant op.
For a more Detailed overview of Dense op.
Args:
in_channels (int): The dimension of the input space.
out_channels (int): The dimension of the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
"""
def __init__(
self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True,
activation=None,
num_bits=8,
quant_delay=0,
per_channel=False,
symmetric=False,
narrow_range=False):
super(DenseQuant, self).__init__()
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias)
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
weight_init.shape()[1] != in_channels:
raise ValueError("weight_init shape error")
self.weight = Parameter(initializer(
weight_init, [out_channels, in_channels]), name="weight")
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(
bias_init, [out_channels]), name="bias")
self.matmul = P.MatMul(transpose_b=True)
self.bias_add = P.BiasAdd()
self.activation = get_activation(activation)
self.activation_flag = self.activation is not None
self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
channel_size=out_channels,
symmetric=symmetric,
narrow_range=narrow_range)
def construct(self, x):
"""Use operators to construct to Dense layer."""
output = self.fake_quant_weight(self.weight)
output = self.matmul(x, output)
if self.has_bias:
output = self.bias_add(output, self.bias)
if self.activation_flag:
return self.activation(output)
return output
def extend_repr(self):
"""A pretty print for Dense layer."""
str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}'.format(
self.in_channels, self.out_channels, self.weight, self.has_bias)
if self.has_bias:
str_info = str_info + ', bias={}'.format(self.bias)
if self.activation_flag:
str_info = str_info + ', activation={}'.format(self.activation)
return str_info
class ReLUQuant(Cell):
r"""
ReLUQuant activation function. Add Fake Quant OP after Relu OP.
For a more Detailed overview of ReLU op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
Inputs:
- **x** (Tensor) - The input of ReLUQuant.
Outputs:
Tensor, with the same type and shape as the `x`.
"""
def __init__(self,
num_bits=8,
quant_delay=0,
symmetric=False,
narrow_range=False):
super(ReLUQuant, self).__init__()
self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.relu = P.ReLU()
def construct(self, x):
x = self.relu(x)
x = self.fake_quant_act(x)
return x
class ReLU6Quant(Cell):
r"""
ReLU6Quant activation function.
Add Fake Quant OP after Relu6. Not Recommand to used these cell for Fake Quant Op
Will climp the max range of the activation and the relu6 do the same operation.
For a more Detailed overview of ReLU6 op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
Inputs:
- **x** (Tensor) - The input of ReLU6Quant.
Outputs:
Tensor, with the same type and shape as the `x`.
"""
def __init__(self, num_bits=8, quant_delay=0, symmetric=False,
narrow_range=False):
super(ReLU6Quant, self).__init__()
self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.relu6 = P.ReLU6()
def construct(self, x):
x = self.relu6(x)
x = self.fake_quant_act(x)
return x
class HSwishQuant(Cell):
r"""
HSwishQuant activation function. Add Fake Quant OP after HSwish OP.
For a more Detailed overview of HSwish op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
Inputs:
- **x** (Tensor) - The input of HSwishQuant.
Outputs:
Tensor, with the same type and shape as the `x`.
"""
def __init__(self,
num_bits=8,
quant_delay=0,
symmetric=False,
narrow_range=False):
super(HSwishQuant, self).__init__()
self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.act = P.HSwish()
def construct(self, x):
x = self.fake_quant_act_before(x)
x = self.act(x)
x = self.fake_quant_act_after(x)
return x
class HSigmoidQuant(Cell):
r"""
HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP.
For a more Detailed overview of HSigmoid op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
Inputs:
- **x** (Tensor) - The input of HSigmoidQuant.
Outputs:
Tensor, with the same type and shape as the `x`.
"""
def __init__(self,
num_bits=8,
quant_delay=0,
symmetric=False,
narrow_range=False):
super(HSigmoidQuant, self).__init__()
self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.act = P.HSigmoid()
def construct(self, x):
x = self.fake_quant_act_before(x)
x = self.act(x)
x = self.fake_quant_act_after(x)
return x
class TensorAddQuant(Cell):
r"""
Add Fake Quant OP after TensorAdd OP.
For a more Detailed overview of TensorAdd op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
Inputs:
- **x** (Tensor) - The input of TensorAddQuant.
Outputs:
Tensor, with the same type and shape as the `x`.
"""
def __init__(self,
num_bits=8,
quant_delay=0,
symmetric=False,
narrow_range=False):
super(TensorAddQuant, self).__init__()
self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.add = P.TensorAdd()
def construct(self, x1, x2):
x = self.add(x1, x2)
x = self.fake_quant_act(x)
return x

View File

@ -234,7 +234,7 @@ class Tanh(Cell):
class GELU(Cell):
"""
r"""
Gaussian error linear unit activation function.
Applies GELU function to each element of the input. The input is a Tensor with any valid shape.
@ -332,15 +332,74 @@ class PReLU(Cell):
return v
class HSwish(Cell):
r"""
rHard swish activation function.
Applies hswish-type activation element-wise. The input is a Tensor with any valid shape.
Hard swish is defined as:
.. math::
\text{hswish}(x_{i}) = x_{i} * \frac{ReLU6(x_{i} + 3)}{6},
where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor.
Inputs:
- **input_data** (Tensor) - The input of Hswish.
Outputs:
Tensor, with the same type and shape as the `input_data`.
"""
def __init__(self):
super(HSwish, self).__init__()
self.hswish = P.HSwish()
def construct(self, x):
return self.hswish(x)
class HSigmoid(Cell):
r"""
Hard sigmoid activation function.
Applies hard sigmoid activation element-wise. The input is a Tensor with any valid shape.
Hard sigmoid is defined as:
.. math::
\text{hsigmoid}(x_{i}) = max(0, min(1, \ftac{2 * x_{i} + 5}{10})),
where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor.
Inputs:
- **input_data** (Tensor) - The input of HSigmoid.
Outputs:
Tensor, with the same type and shape as the `input_data`.
"""
def __init__(self):
super(HSigmoid, self).__init__()
self.hsigmoid = P.HSigmoid()
def construct(self, x):
return self.hsigmoid(x)
_activation = {
'softmax': Softmax,
'logsoftmax': LogSoftmax,
'relu': ReLU,
'relu6': ReLU6,
'tanh': Tanh,
'gelu': GELU,
'sigmoid': Sigmoid,
'prelu': PReLU,
'leakyrelu': LeakyReLU
'leakyrelu': LeakyReLU,
'hswish': HSwish,
'hsigmoid': HSigmoid,
}

View File

@ -172,6 +172,28 @@ def get_bprop_relu6(self):
return bprop
@bprop_getters.register(P.HSwish)
def get_bprop_hswish(self):
"""Grad definition for `HSwish` operation."""
input_grad = G.HSwishGrad()
def bprop(x, out, dout):
dx = input_grad(dout, x)
return (dx,)
return bprop
@bprop_getters.register(P.HSigmoid)
def get_bprop_hsigmoid(self):
"""Grad definition for `HSigmoid` operation."""
input_grad = G.HSigmoidGrad()
def bprop(x, out, dout):
dx = input_grad(dout, x)
return (dx,)
return bprop
@bprop_getters.register(P.Elu)
def get_bprop_elu(self):
"""Grad definition for `Elu` operation."""

View File

@ -0,0 +1,82 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate bprop for aware quantization ops"""
from .. import operations as P
from .grad_base import bprop_getters
from ..composite.multitype_ops.zeros_like_impl import zeros_like
@bprop_getters.register(P.FakeQuantWithMinMax)
def get_bprop_fakequant_with_minmax(self):
"""Generate bprop for FakeQuantWithMinMax"""
op = P.FakeQuantWithMinMaxGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
def bprop(x, x_min, x_max, out, dout):
dx = op(dout, x, x_min, x_max)
return dx, zeros_like(x_min), zeros_like(x_max)
return bprop
@bprop_getters.register(P.FakeQuantWithMinMaxPerChannel)
def get_bprop_fakequant_with_minmax_perchannel(self):
"""Generate bprop for FakeQuantWithMinMaxPerChannel"""
op = P.FakeQuantWithMinMaxPerChannelGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
def bprop(x, x_min, x_max, out, dout):
dx = op(dout, x, x_min, x_max)
return dx, zeros_like(x_min), zeros_like(x_max)
return bprop
@bprop_getters.register(P.BatchNormFold)
def get_bprop_batchnorm_fold(self):
"""Generate bprop for BatchNormFold"""
op = P.BatchNormFoldGrad(self.epsilon, self.is_training, self.freeze_bn)
def bprop(x, mean, variance, global_step, out, dout):
dx = op(dout[0], dout[1], x, out[0], out[1], global_step)
return dx, zeros_like(mean), zeros_like(variance), zeros_like(global_step)
return bprop
@bprop_getters.register(P.CorrectionMul)
def get_bprop_correction_mul(self):
"""Generate bprop for CorrectionMul"""
grad = P.CorrectionMulGrad()
def bprop(x, batch_std, running_std, out, dout):
dx, d_batch_std = grad(dout, x, batch_std, running_std)
return dx, d_batch_std, zeros_like(running_std)
return bprop
@bprop_getters.register(P.BatchNormFold2)
def get_bprop_batchnorm_fold2(self):
"""Generate bprop for CorrectionAdd"""
op_f = P.BatchNormFold2Grad(freeze_bn=self.freeze_bn)
def bprop(x, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, out, dout):
d_batch_std, d_batch_mean, d_beta, d_gamma, d_x = op_f(dout, x, gamma, batch_std, batch_mean, running_std,
running_mean, global_step)
return d_x, d_beta, d_gamma, d_batch_std, d_batch_mean, zeros_like(running_std), zeros_like(running_mean), \
zeros_like(global_step)
return bprop

View File

@ -59,7 +59,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
LogSoftmax,
MaxPool,
AvgPool, Conv2DBackpropInput,
MaxPoolWithArgmax, OneHot, Pad, PReLU, ReLU, ReLU6,
MaxPoolWithArgmax, OneHot, Pad, PReLU, ReLU, ReLU6, HSwish, HSigmoid,
ResizeBilinear, Sigmoid,
SigmoidCrossEntropyWithLogits,
SmoothL1Loss, Softmax,
@ -68,7 +68,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl,
ApplyRMSProp, ApplyCenteredRMSProp)
from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey
from . import _quant_ops
from ._quant_ops import *
__all__ = [
'TensorAdd',
@ -138,6 +139,8 @@ __all__ = [
'ReLU6',
'Elu',
'Sigmoid',
'HSwish',
'HSigmoid',
'Tanh',
'RandomChoiceWithMask',
'ResizeBilinear',
@ -241,4 +244,5 @@ __all__ = [
"ApplyCenteredRMSProp"
]
__all__.extend(_quant_ops.__all__)
__all__.sort()

View File

@ -805,6 +805,38 @@ class SigmoidGrad(PrimitiveWithInfer):
return out
class HSigmoidGrad(PrimitiveWithInfer):
"""Gets the gradient of HSigmoid operation."""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
def infer_shape(self, y_grad_shape, x_shape):
return x_shape
def infer_dtype(self, y_grad_dtype, x_dtype):
validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32))
validator.check_typename("x dtype", x_dtype, (mstype.float16, mstype.float32))
return x_dtype
class HSwishGrad(PrimitiveWithInfer):
"""Gets the gradient of HSwish operation."""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
def infer_shape(self, y_grad_shape, x_shape):
return x_shape
def infer_dtype(self, y_grad_dtype, x_dtype):
validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32))
validator.check_typename("x_ dtype", x_dtype, (mstype.float16, mstype.float32))
return x_dtype
class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer):
"""Computes the gradients of `SigmoidCrossEntropyWithLogits`."""

View File

@ -0,0 +1,525 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0(the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http: // www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Operators for quantization."""
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Rel, check_bool, check_int_positive, check_int
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ...common import dtype as mstype
__all__ = ["FakeQuantWithMinMax",
"FakeQuantWithMinMaxGrad",
"FakeQuantWithMinMaxPerChannel",
"FakeQuantWithMinMaxPerChannelGrad",
"BatchNormFold",
"BatchNormFoldGrad",
"CorrectionMul",
"CorrectionMulGrad",
"BatchNormFold2",
"BatchNormFold2Grad",
]
class FakeQuantWithMinMax(PrimitiveWithInfer):
r"""
Simulate the quantize and dequantize operations in training time.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
simulate aware quantize funcion. After delay step in training time begin simulate the aware
quantize funcion. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
training=True):
"""init FakeQuantWithMinMax OP"""
if num_bits not in self.support_quant_bit:
raise ValueError("Attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
"Attr \'ema\' and \'ema_decay\' should set together.")
self.ema = check_bool(ema)
self.symmetric = check_bool(symmetric)
self.narrow_range = check_bool(narrow_range)
self.training = check_bool(training)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH)
self.num_bits = check_int_positive(num_bits)
self.quant_delay = check_int(quant_delay)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['out'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x shape", len(x_shape), 1, Rel.GT)
validator.check("min shape", min_shape, "max shape", max_shape)
validator.check_integer("min shape", len(min_shape), 1, Rel.EQ)
validator.check_integer("max shape", len(min_shape), 1, Rel.EQ)
return x_shape
def infer_dtype(self, x_type, min_type, max_type):
validator.check_typename(
"x type", x_type, (mstype.float16, mstype.float32))
validator.check_typename("min type", min_type,
(mstype.float16, mstype.float32))
validator.check_typename("max type", max_type,
(mstype.float16, mstype.float32))
return x_type
class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
"""Performs grad of FakeQuantWithMinMax operation."""
support_quant_bit = [4, 8]
@prim_attr_register
def __init__(self, num_bits=8, quant_delay=0):
if num_bits not in self.support_quant_bit:
raise ValueError("Attr \'num_bits\' is not support.")
self.quant_delay = check_int(quant_delay)
self.num_bits = check_int_positive(num_bits)
self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'],
outputs=['dx'])
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
validator.check("dout shape", dout_shape, "x shape", x_shape)
validator.check("min shape", min_shape, "max shape", max_shape)
validator.check_integer("min shape", len(min_shape), 1, Rel.EQ)
validator.check_integer("max shape", len(min_shape), 1, Rel.EQ)
return dout_shape
def infer_dtype(self, dout_type, x_type, min_type, max_type):
validator.check_typename(
"dout type", dout_type, (mstype.float16, mstype.float32))
validator.check_typename(
"x type", x_type, (mstype.float16, mstype.float32))
validator.check_typename("min type", min_type,
(mstype.float16, mstype.float32))
validator.check_typename("max type", max_type,
(mstype.float16, mstype.float32))
return dout_type
class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
r"""
Simulate the quantize and dequantize operations in training time base on per channel.
Args:
num_bits (int) : Number bits to quantilization. Default: 8.
ema (bool): Use EMA algorithm update tensor min and tensor max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
quant_delay (int): Quantilization delay parameter. Before delay step in training time not
update the weight data to simulate quantize operation. After delay step in training time
begin simulate the quantize operation. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
Inputs:
- **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor.
- **min** (int, float) : Value of the min range of the input data.
- **max** (int, float) : Value of the max range of the input data.
Outputs:
- Tensor, has the same type as input.
Examples:
>>> input_tensor = Tensor(np.random.rand(3,4,5,5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6.0, -6.5, -4.0, -5.0]), mstype.float32)
>>> max_tensor = Tensor(np.array([6.0, 6.5, 4.0, 5.0]), mstype.float32)
>>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 8]
channel_idx = 0
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
training=True):
"""init FakeQuantWithMinMaxPerChannel OP"""
if num_bits not in self.support_quant_bit:
raise ValueError("Attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
"Attr \'ema\' and \'ema_decay\' should set together.")
self.ema = check_bool(ema)
self.symmetric = check_bool(symmetric)
self.narrow_range = check_bool(narrow_range)
self.training = check_bool(training)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH)
self.num_bits = check_int_positive(num_bits)
self.quant_delay = check_int(quant_delay)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['out'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x shape", len(x_shape), 1, Rel.GT)
validator.check_integer(
"min len", min_shape[0], x_shape[self.channel_idx], Rel.EQ)
validator.check_integer(
"max len", max_shape[0], x_shape[self.channel_idx], Rel.EQ)
return x_shape
def infer_dtype(self, x_type, min_type, max_type):
validator.check_typename(
"x type", x_type, (mstype.float16, mstype.float32))
validator.check_typename("min type", min_type,
(mstype.float16, mstype.float32))
validator.check_typename("max type", max_type,
(mstype.float16, mstype.float32))
return x_type
class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
"""Performs grad of FakeQuantWithMinMaxPerChannel operation."""
support_quant_bit = [4, 8]
@prim_attr_register
def __init__(self, num_bits=8, quant_delay=0):
"""init FakeQuantWithMinMaxPerChannel Fill"""
if num_bits not in self.support_quant_bit:
raise ValueError("Attr \'num_bits\' is not support.")
self.quant_delay = check_int(quant_delay)
self.num_bits = check_int_positive(num_bits)
self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'],
outputs=['dx'])
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
validator.check("dout shape", dout_shape, "x shape", x_shape)
validator.check("min shape", min_shape, "max shape", max_shape)
return dout_shape
def infer_dtype(self, dout_type, x_type, min_type, max_type):
validator.check_typename(
"dout", dout_type, (mstype.float16, mstype.float32))
validator.check_typename("x", x_type, (mstype.float16, mstype.float32))
validator.check_typename(
"min", min_type, (mstype.float16, mstype.float32))
validator.check_typename(
"max", max_type, (mstype.float16, mstype.float32))
return dout_type
class BatchNormFold(PrimitiveWithInfer):
"""
Batch normalization folded.
Args:
momentum (float): Momentum value should be [0, 1]. Default: 0.1.
epsilon (float): A small float number to avoid dividing by 0. 1e-12 if dtype in
float32 else 1e-3. Default: 1e-12.
is_training (bool): In training mode set True, else set False. Default: True.
freeze_bn (int): Delay in steps at which computation switches from regular batch
norm to frozen mean and std. Default: 0.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
- **mean** (Tensor) - Tensor of shape :math:`(C,)`.
- **variance** (Tensor) - Tensor of shape :math:`(C,)`.
- **global_step** (Tensor) - Tensor to record current global step.
Outputs:
Tuple of 4 Tensor, the normalized input and the updated parameters.
- **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
- **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
- **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
- **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
"""
channel = 1
@prim_attr_register
def __init__(self, momentum=0.1, epsilon=1e-12, is_training=True, freeze_bn=0):
"""init batch norm fold layer"""
self.momentum = validator.check_number_range(
'momentum', momentum, 0, 1, Rel.INC_BOTH)
self.epsilon = validator.check_float_positive('epsilon', epsilon)
self.is_training = check_bool(is_training)
self.freeze_bn = check_int(freeze_bn)
self.init_prim_io_names(inputs=['x', 'mean', 'variance', 'global_step'],
outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std'])
def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
validator.check("mean shape", mean_shape,
"gamma_shape", variance_shape)
validator.check("mean_shape size",
mean_shape[0], "input channel", x_shape[self.channel])
validator.check_integer("global_step shape",
len(global_step_shape), 1, Rel.EQ)
return mean_shape, mean_shape, mean_shape, mean_shape
def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
validator.check("input type", x_type, "mean type", mean_type)
validator.check("input type", x_type, "variance type", variance_type)
validator.check_typename("input type", x_type,
(mstype.float16, mstype.float32))
validator.check_typename(
"global_step type", global_step_type, (mstype.int32,))
return x_type, x_type, x_type, x_type
class BatchNormFoldGrad(PrimitiveWithInfer):
"""Performs grad of BatchNormFold operation."""
channel = 1
@prim_attr_register
def __init__(self, epsilon=1e-12, is_training=True, freeze_bn=0):
"""init BatchNormGrad layer"""
self.is_training = check_bool(is_training)
self.freeze_bn = check_int(freeze_bn)
self.epsilon = validator.check_float_positive('epsilon', epsilon)
self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'],
outputs=['dx'])
def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape,
global_step_shape):
validator.check("d_batch_mean shape", d_batch_mean_shape,
"d_batch_std shape", d_batch_std_shape)
validator.check("d_batch_mean shape", d_batch_mean_shape,
"batch_mean shape", batch_mean_shape)
validator.check("d_batch_mean shape", d_batch_mean_shape,
"batch_std shape", batch_std_shape)
validator.check(
"x_shape shape", d_batch_mean_shape[0], "input channel", x_shape[self.channel])
validator.check_integer("global_step shape",
len(global_step_shape), 1, Rel.EQ)
return x_shape
def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
global_step_type):
validator.check("input type", x_type,
"d_batch_mean type", d_batch_mean_type)
validator.check("input type", x_type,
"d_batch_std type", d_batch_std_type)
validator.check("input type", x_type,
"batch_mean type", batch_mean_type)
validator.check("input type", x_type, "batch_std type", batch_std_type)
validator.check_typename("input type", x_type,
(mstype.float16, mstype.float32))
validator.check_typename(
"global_step type", global_step_type, (mstype.int32,))
return x_type
class CorrectionMul(PrimitiveWithInfer):
"""
Scale the weights with a correction factor to the long term statistics
prior to quantization. This ensures that there is no jitter in the quantized weights
due to batch to batch variation.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
- **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
- **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
Outputs:
- **out** (Tensor) - Tensor has the same shape as x.
"""
channel = 0
@prim_attr_register
def __init__(self):
"""init correction mul layer"""
self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'],
outputs=['out'])
def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
validator.check("batch_std shape", batch_std_shape,
"running_std shape", running_std_shape)
validator.check(
"batch_std size", batch_std_shape[0], "x_shape channel size", x_shape[self.channel])
return x_shape
def infer_dtype(self, x_type, batch_std_type, running_std_type):
validator.check("batch_std type", batch_std_type,
"running_std type", running_std_type)
validator.check("batch_std_type", batch_std_type, "x_type", x_type)
validator.check_typename(
"batch_std type", batch_std_type, (mstype.float16, mstype.float32))
return x_type
class CorrectionMulGrad(PrimitiveWithInfer):
"""Performs grad of CorrectionMul operation."""
channel = 0
@prim_attr_register
def __init__(self):
"""init correction mul layer"""
self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
outputs=['dx', 'd_gamma'])
def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
validator.check("dout shape", dout_shape, "x_shape x", x_shape)
validator.check(
"gamma size", gamma_shape[0], "dout channel size", dout_shape[self.channel])
validator.check(
"running_std size", running_std_shape[0], "dout channel size", dout_shape[self.channel])
return x_shape, gamma_shape
def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
validator.check("x type", x_type, "dout type", dout_type)
validator.check("gamma type", gamma_type, "dout type", dout_type)
validator.check("running_std type", running_std_type,
"dout type", dout_type)
validator.check_typename(
"dout type", dout_type, (mstype.float16, mstype.float32))
return x_type, x_type
class BatchNormFold2(PrimitiveWithInfer):
"""
Scale the bias with a correction factor to the long term statistics
prior to quantization. This ensures that there is no jitter in the quantized bias
due to batch to batch variation.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
- **beta** (Tensor) - Tensor of shape :math:`(C,)`.
- **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
- **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
- **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
- **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
- **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
- **global_step** (Tensor) - Tensor to record current global step.
Outputs:
- **y** (Tensor) - Tensor has the same shape as x.
"""
channel = 1
@prim_attr_register
def __init__(self, freeze_bn=0):
"""init conv2d fold layer"""
self.freeze_bn = check_int(freeze_bn)
self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean',
'running_std', 'running_mean', 'global_step'],
outputs=['y'])
def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape,
running_mean_shape, global_step_shape):
validator.check("batch_std shape", batch_std_shape,
"running_std shape", running_std_shape)
validator.check("batch_std shape", batch_std_shape,
"batch_mean shape", batch_mean_shape)
validator.check("batch_std shape", batch_std_shape,
"beta shape", beta_shape)
validator.check("batch_std shape", batch_std_shape,
"running_mean shape", running_mean_shape)
validator.check("batch_std shape", batch_std_shape,
"batch_mean shape", gamma_shape)
validator.check(
"batch_std size", batch_std_shape[0], "x_shape channel size", x_shape[self.channel])
validator.check_integer("global_step shape",
len(global_step_shape), 1, Rel.EQ)
return x_shape
def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
running_mean_type, global_step_type):
validator.check("batch_std type", batch_std_type,
"running_std type", running_std_type)
validator.check("batch_std type", batch_std_type,
"batch_mean type", batch_mean_type)
validator.check("batch_std type", batch_std_type,
"beta type", beta_type)
validator.check("batch_std type", batch_std_type,
"running_mean type", running_mean_type)
validator.check("batch_std type", batch_std_type,
"gamma type", gamma_type)
validator.check("x_type", x_type, "batch_std type", batch_std_type)
validator.check_typename(
"batch_std type", batch_std_type, (mstype.float16, mstype.float32))
validator.check_typename(
"global_step type", global_step_type, (mstype.int32,))
return x_type
class BatchNormFold2Grad(PrimitiveWithInfer):
"""Performs grad of CorrectionAddGrad operation."""
channel = 1
@prim_attr_register
def __init__(self, freeze_bn=0):
"""init MulFold layer"""
self.freeze_bn = freeze_bn
self.init_prim_io_names(inputs=['dout', 'x', 'gamma',
'batch_std', 'batch_mean',
'running_std', 'running_mean', 'global_step'],
outputs=['d_batch_std', 'd_batch_mean', 'd_beta', 'd_gamma', 'dx'])
def infer_shape(self, dout_shape, x_shape, gamma_shape,
batch_std_shape, batch_mean_shape,
running_std_shape, running_mean_shape, global_step_shape):
validator.check("batch_std shape", batch_std_shape,
"batch_mean shape", batch_mean_shape)
validator.check("batch_std shape", batch_std_shape,
"running_std shape", running_std_shape)
validator.check("batch_std shape", batch_std_shape,
"running_mean shape", running_mean_shape)
validator.check("batch_std shape", batch_std_shape,
"gamma shape", gamma_shape)
validator.check(
"batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel])
validator.check_integer("global_step shape",
len(global_step_shape), 1, Rel.EQ)
return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape
def infer_dtype(self, dout_type, x_type, gamma_type,
batch_std_type, batch_mean_type,
running_std_type, running_mean_type, global_step_type):
validator.check("batch_std type", batch_std_type,
"batch_mean type", batch_mean_type)
validator.check("batch_std type", batch_std_type,
"gamma type", gamma_type)
validator.check("batch_std type", batch_std_type,
"running_std type", running_std_type)
validator.check("batch_std type", batch_std_type,
"running_mean type", running_mean_type)
validator.check("batch_std_type", batch_std_type,
"dout type", dout_type)
validator.check_typename(
"batch_std type", batch_std_type, (mstype.float16, mstype.float32))
validator.check_typename(
"global_step type", global_step_type, (mstype.int32,))
return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type

View File

@ -207,7 +207,7 @@ class ReLU6(PrimitiveWithInfer):
class Elu(PrimitiveWithInfer):
"""
r"""
Computes exponential linear: `alpha * (exp(x) - 1)` if x < 0, `x` otherwise.
The data type of input tensor should be float.
@ -242,6 +242,40 @@ class Elu(PrimitiveWithInfer):
return input_x
class HSwish(PrimitiveWithInfer):
r"""
Hard swish activation function.
Applies hswish-type activation element-wise. The input is a Tensor with any valid shape.
Hard swish is defined as:
.. math::
\text{hswish}(x_{i}) = x_{i} * \frac{ReLU6(x_{i} + 3)}{6},
where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor.
Inputs:
- **input_data** (Tensor) - The input of Hswish.
Outputs:
Tensor, with the same type and shape as the `input_data`.
"""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, xshape):
return xshape
def infer_dtype(self, x_dtype):
validator.check_subclass("x_dtype", x_dtype, mstype.tensor)
validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32))
return x_dtype
class Sigmoid(PrimitiveWithInfer):
r"""
Sigmoid activation function.
@ -258,6 +292,7 @@ class Sigmoid(PrimitiveWithInfer):
Outputs:
Tensor, with the same type and shape as the input_x.
"""
@prim_attr_register
@ -273,6 +308,40 @@ class Sigmoid(PrimitiveWithInfer):
return input_x
class HSigmoid(PrimitiveWithInfer):
r"""
Hard sigmoid activation function.
Applies hard sigmoid activation element-wise. The input is a Tensor with any valid shape.
Hard sigmoid is defined as:
.. math::
\text{hsigmoid}(x_{i}) = max(0, min(1, \ftac{2 * x_{i} + 5}{10})),
where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor.
Inputs:
- **input_data** (Tensor) - The input of HSigmoid.
Outputs:
Tensor, with the same type and shape as the `input_data`.
"""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_subclass("x_dtype", x_dtype, mstype.tensor)
validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32))
return x_dtype
class Tanh(PrimitiveWithInfer):
r"""
Tanh activation function.

View File

@ -27,11 +27,6 @@ def test_dense_none():
nn.Dense(3, 2, None, None)
def test_dense_invalid_activation():
with pytest.raises(KeyError):
nn.Dense(3, 2, activation='relu6')
@non_graph_engine
def test_dense_str_activation():
dense = nn.Dense(1, 1, activation='relu')

View File

@ -51,11 +51,6 @@ def test_activation_empty():
assert nn.get_activation('') is None
def test_activation_invalid():
with pytest.raises(KeyError):
nn.get_activation('relu6')
# test softmax
def test_softmax_axis():
layer = nn.Softmax(1)

View File

@ -68,11 +68,6 @@ def test_dense_none():
nn.Dense(3, 2, None, None)
def test_dense_invalid_activation():
with pytest.raises(KeyError):
nn.Dense(3, 2, activation='relu6')
def test_dense_str_activation():
dense = nn.Dense(1, 1, activation='relu')
assert isinstance(dense.activation, nn.ReLU)