fix perchannel num_channels not set bug and adjust quant.py params order

This commit is contained in:
wangdongxu 2020-06-20 17:53:47 +08:00
parent 3e3cbbba0f
commit f110c7616b
5 changed files with 607 additions and 407 deletions

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Aware quantization."""
"""Quantization aware."""
from functools import partial
import numpy as np
@ -27,9 +27,16 @@ from mindspore._checkparam import Validator as validator, Rel
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
import mindspore.context as context
from .normalization import BatchNorm2d
from .activation import get_activation
from ..cell import Cell
from . import conv, basic
from ..._checkparam import ParamValidator as validator
from ...ops.operations import _quant_ops as Q
__all__ = [
'Conv2dBnAct',
'DenseBnAct',
'FakeQuantWithMinMax',
'Conv2dBatchNormQuant',
'Conv2dQuant',
@ -43,6 +50,165 @@ __all__ = [
]
class Conv2dBnAct(Cell):
r"""
A combination of convolution, Batchnorm, activation layer.
For a more Detailed overview of Conv2d op.
Args:
in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple]): The data type is int or tuple with 2 integers. Specifies the height
and width of the 2D convolution window. Single int means the value if for both height and width of
the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
width of the kernel.
stride (int): Specifies stride for all spatial dimensions with the same value. Value of stride should be
greater or equal to 1 but bounded by the height and width of the input. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
padding (int): Implicit paddings on both sides of the input. Default: 0.
dilation (int): Specifying the dilation rate to use for dilated convolution. If set to be :math:`k > 1`,
there will be :math:`k - 1` pixels skipped for each sampling location. Its value should be greater
or equal to 1 and bounded by the height and width of the input. Default: 1.
group (int): Split filter into groups, `in_ channels` and `out_channels` should be
divisible by the number of groups. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
Initializer for more details. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
Initializer and string are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'.
batchnorm (bool): Specifies to used batchnorm or not. Default: None.
activation (string): Specifies activation type. The optional values are as following:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> net = Conv2dBnAct(120, 240, 4, batchnorm=True, activation='ReLU')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> net(input).shape
(1, 240, 1024, 640)
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros',
batchnorm=None,
activation=None):
super(Conv2dBnAct, self).__init__()
self.conv = conv.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
has_bias,
weight_init,
bias_init)
self.has_bn = batchnorm is not None
self.has_act = activation is not None
self.batchnorm = batchnorm
if batchnorm is True:
self.batchnorm = BatchNorm2d(out_channels)
elif batchnorm is not None:
validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
self.activation = get_activation(activation)
def construct(self, x):
x = self.conv(x)
if self.has_bn:
x = self.batchnorm(x)
if self.has_act:
x = self.activation(x)
return x
class DenseBnAct(Cell):
r"""
A combination of Dense, Batchnorm, activation layer.
For a more Detailed overview of Dense op.
Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
batchnorm (bool): Specifies to used batchnorm or not. Default: None.
activation (string): Specifies activation type. The optional values are as following:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
Outputs:
Tensor of shape :math:`(N, out\_channels)`.
Examples:
>>> net = nn.DenseBnAct(3, 4)
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net(input)
"""
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True,
batchnorm=None,
activation=None):
super(DenseBnAct, self).__init__()
self.dense = basic.Dense(
in_channels,
out_channels,
weight_init,
bias_init,
has_bias)
self.has_bn = batchnorm is not None
self.has_act = activation is not None
if batchnorm is True:
self.batchnorm = BatchNorm2d(out_channels)
elif batchnorm is not None:
validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
self.activation = get_activation(activation)
def construct(self, x):
x = self.dense(x)
if self.has_bn:
x = self.batchnorm(x)
if self.has_act:
x = self.activation(x)
return x
class BatchNormFoldCell(Cell):
"""
Batch normalization folded.
@ -105,20 +271,20 @@ class BatchNormFoldCell(Cell):
class FakeQuantWithMinMax(Cell):
r"""
Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
Quantization aware op. This OP provide Fake quantization observer function on data with min and max.
Args:
min_init (int, float): The dimension of channel or 1(layer). Default: -6.
max_init (int, float): The dimension of channel or 1(layer). Default: 6.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization by layer or channel. Default: False.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
channel_axis (int): Quantization by channel axis. Default: 1.
out_channels (int): declarate the min and max channel size, Default: 1.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
num_channels (int): declarate the min and max channel size, Default: 1.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of FakeQuantWithMinMax.
@ -135,15 +301,15 @@ class FakeQuantWithMinMax(Cell):
def __init__(self,
min_init=-6,
max_init=6,
num_bits=8,
ema=False,
ema_decay=0.999,
per_channel=False,
channel_axis=1,
out_channels=1,
quant_delay=0,
num_channels=1,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
"""init FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMax, self).__init__()
self.min_init = min_init
@ -152,7 +318,7 @@ class FakeQuantWithMinMax(Cell):
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.out_channels = out_channels
self.num_channels = num_channels
self.channel_axis = channel_axis
self.quant_delay = quant_delay
self.symmetric = symmetric
@ -161,54 +327,54 @@ class FakeQuantWithMinMax(Cell):
# init tensor min and max for fake quant op
if self.per_channel:
min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(0, self.out_channels)]).astype(np.float32)
min_array = np.array([self.min_init] * self.num_channels).astype(np.float32)
max_array = np.array([self.max_init] * self.num_channels).astype(np.float32)
else:
min_array = np.array([self.min_init]).reshape(1).astype(np.float32)
max_array = np.array([self.max_init]).reshape(1).astype(np.float32)
min_array = np.array([self.min_init]).astype(np.float32)
max_array = np.array([self.max_init]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
# init fake quant relative op
if per_channel:
quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
ema_fun = partial(Q.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis)
ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
else:
quant_fun = Q.FakeQuantPerLayer
ema_fun = Q.FakeQuantMinMaxPerLayerUpdate
ema_fun = Q.MinMaxUpdatePerLayer
self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay)
if self.is_ascend:
self.fake_quant = quant_fun(num_bits=self.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range)
self.fake_quant_train = quant_fun(num_bits=self.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range)
self.fake_quant_infer = self.fake_quant_train
else:
self.fake_quant = quant_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range)
self.ema_update = ema_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
symmetric=self.symmetric,
narrow_range=self.narrow_range)
quant_fun = partial(quant_fun,
ema=self.ema,
ema_decay=ema_decay,
num_bits=self.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
quant_delay=quant_delay)
self.fake_quant_train = quant_fun(training=True)
self.fake_quant_infer = quant_fun(training=False)
def extend_repr(self):
s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
'quant_delay={}, min_init={}, max_init={}'.format(
self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel,
self.channel_axis, self.out_channels, self.quant_delay, self.min_init, self.max_init)
self.channel_axis, self.num_channels, self.quant_delay, self.min_init, self.max_init)
return s
def construct(self, x):
if self.is_ascend and self.training:
if self.training:
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
out = self.fake_quant(x, min_up, max_up)
P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up)
out = self.fake_quant_train(x, self.minq, self.maxq)
else:
out = self.fake_quant(x, self.minq, self.maxq)
out = self.fake_quant_infer(x, self.minq, self.maxq)
return out
@ -225,8 +391,8 @@ class Conv2dBatchNormQuant(Cell):
stride (int): Specifies stride for all spatial dimensions with the same value.
pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
padding: (int): Implicit paddings on both sides of the input. Default: 0.
eps (int): Parameters for BatchNormal. Default: 1e-5.
momentum (int): Parameters for BatchNormal op. Default: 0.997.
eps (float): Parameters for BatchNormal. Default: 1e-5.
momentum (float): Parameters for BatchNormal op. Default: 0.997.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
convolution kernel. Default: 'normal'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
@ -237,13 +403,13 @@ class Conv2dBatchNormQuant(Cell):
mean vector. Default: 'zeros'.
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
variance vector. Default: 'ones'.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -274,13 +440,13 @@ class Conv2dBatchNormQuant(Cell):
gamma_init='ones',
mean_init='zeros',
var_init='ones',
quant_delay=0,
freeze_bn=100000,
fake=True,
num_bits=8,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0,
freeze_bn=100000):
"""init Conv2dBatchNormQuant layer"""
super(Conv2dBatchNormQuant, self).__init__()
self.in_channels = in_channels
@ -304,8 +470,8 @@ class Conv2dBatchNormQuant(Cell):
# initialize convolution op and Parameter
if context.get_context('device_target') == "Ascend" and group > 1:
validator.check_integer('group', group, in_channels, Rel.EQ, 'Conv2dBatchNormQuant')
validator.check_integer('group', group, out_channels, Rel.EQ, 'Conv2dBatchNormQuant')
validator.check_integer('group', group, in_channels, Rel.EQ)
validator.check_integer('group', group, out_channels, Rel.EQ)
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=pad_mode,
@ -337,12 +503,13 @@ class Conv2dBatchNormQuant(Cell):
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
out_channels=out_channels,
channel_axis=channel_axis,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
self.correct_mul = Q.CorrectionMul(channel_axis)
if context.get_context('device_target') == "Ascend":
@ -416,11 +583,11 @@ class Conv2dQuant(Cell):
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -447,11 +614,11 @@ class Conv2dQuant(Cell):
has_bias=False,
weight_init='normal',
bias_init='zeros',
quant_delay=0,
num_bits=8,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(Conv2dQuant, self).__init__()
if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size)
@ -487,12 +654,13 @@ class Conv2dQuant(Cell):
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
out_channels=out_channels,
channel_axis=0,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
def construct(self, x):
weight = self.fake_quant_weight(self.weight)
@ -526,11 +694,11 @@ class DenseQuant(Cell):
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -552,11 +720,11 @@ class DenseQuant(Cell):
bias_init='zeros',
has_bias=True,
activation=None,
num_bits=8,
quant_delay=0,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(DenseQuant, self).__init__()
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
@ -586,12 +754,13 @@ class DenseQuant(Cell):
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
out_channels=out_channels,
channel_axis=0,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
def construct(self, x):
"""Use operators to construct to Dense layer."""
@ -615,18 +784,28 @@ class DenseQuant(Cell):
return str_info
class ReLUQuant(Cell):
class _QuantActivation(Cell):
r"""
Base class for Quant activation function. Add Fake Quant OP after activation OP.
"""
def get_origin(self):
raise NotImplementedError
class ReLUQuant(_QuantActivation):
r"""
ReLUQuant activation function. Add Fake Quant OP after Relu OP.
For a more Detailed overview of ReLU op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of ReLUQuant.
@ -641,20 +820,22 @@ class ReLUQuant(Cell):
"""
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(ReLUQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.relu = P.ReLU()
def construct(self, x):
@ -662,8 +843,11 @@ class ReLUQuant(Cell):
x = self.fake_quant_act(x)
return x
def get_origin(self):
return self.relu
class ReLU6Quant(Cell):
class ReLU6Quant(_QuantActivation):
r"""
ReLU6Quant activation function.
@ -672,11 +856,12 @@ class ReLU6Quant(Cell):
For a more Detailed overview of ReLU6 op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of ReLU6Quant.
@ -691,20 +876,22 @@ class ReLU6Quant(Cell):
"""
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(ReLU6Quant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.relu6 = P.ReLU6()
def construct(self, x):
@ -712,19 +899,23 @@ class ReLU6Quant(Cell):
x = self.fake_quant_act(x)
return x
def get_origin(self):
return self.relu6
class HSwishQuant(Cell):
class HSwishQuant(_QuantActivation):
r"""
HSwishQuant activation function. Add Fake Quant OP after HSwish OP.
For a more Detailed overview of HSwish op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of HSwishQuant.
@ -739,28 +930,31 @@ class HSwishQuant(Cell):
"""
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(HSwishQuant, self).__init__()
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.act = P.HSwish()
def construct(self, x):
@ -769,19 +963,23 @@ class HSwishQuant(Cell):
x = self.fake_quant_act_after(x)
return x
def get_origin(self):
return self.act
class HSigmoidQuant(Cell):
class HSigmoidQuant(_QuantActivation):
r"""
HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP.
For a more Detailed overview of HSigmoid op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of HSigmoidQuant.
@ -796,27 +994,31 @@ class HSigmoidQuant(Cell):
"""
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(HSigmoidQuant, self).__init__()
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.act = P.HSigmoid()
def construct(self, x):
@ -825,6 +1027,9 @@ class HSigmoidQuant(Cell):
x = self.fake_quant_act_after(x)
return x
def get_origin(self):
return self.act
class TensorAddQuant(Cell):
r"""
@ -833,11 +1038,12 @@ class TensorAddQuant(Cell):
For a more Detailed overview of TensorAdd op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of TensorAddQuant.
@ -853,20 +1059,22 @@ class TensorAddQuant(Cell):
"""
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(TensorAddQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.add = P.TensorAdd()
def construct(self, x1, x2):
@ -882,11 +1090,12 @@ class MulQuant(Cell):
For a more Detailed overview of Mul op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of MulQuant.
@ -897,23 +1106,99 @@ class MulQuant(Cell):
"""
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(MulQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.mul = P.Mul()
def construct(self, x1, x2):
x = self.mul(x1, x2)
x = self.fake_quant_act(x)
return x
class QuantBlock(Cell):
r"""
A quant block of Conv/Dense, activation layer for Ascend deploy.
Calculate Conv or Dense in Int8, with AscendQuant and AscendDeQuant.
Notes:
This block is only for deploy, and not trainable.
Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
batchnorm (bool): Specifies to used batchnorm or not. Default: None.
activation (string): Specifies activation type. The optional values are as following:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
Outputs:
Tensor of shape :math:`(N, out\_channels)`.
Examples:
>>> net = nn.Dense(3, 4)
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net(input)
"""
def __init__(self,
core_op,
weight,
quant_op,
dequant_op,
dequant_scale,
bias=None,
activation=None):
super(QuantBlock, self).__init__()
self.core_op = core_op
self.weight = weight
self.quant = quant_op
self.dequant = dequant_op
self.dequant_scale = dequant_scale
self.bias = bias
self.has_bias = bias is None
self.activation = activation
self.has_act = activation is None
def construct(self, x):
x = self.quant(x)
x = self.core_op(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
if self.has_act:
x = self.activation(x)
x = self.dequant(x, self.dequant_scale)
return x
def extend_repr(self):
str_info = f'quant={self.quant}, core_op={type(self.core_op)}'
if self.has_bias:
str_info = str_info + f', bias={self.bias}'
if self.has_act:
str_info = str_info + f', activation={self.activation}'
str_info = str_info + f', dequant={self.dequant}'
return str_info

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""Generate bprop for aware quantization ops"""
"""Generate bprop for quantization aware ops"""
from .. import operations as P
from ..operations import _quant_ops as Q
@ -133,9 +133,9 @@ def get_bprop_batchnorm_fold2_(self):
return bprop
@bprop_getters.register(Q.FakeQuantMinMaxPerLayerUpdate)
@bprop_getters.register(Q.MinMaxUpdatePerLayer)
def get_bprop_fakequant_with_minmax_per_layer_update(self):
"""Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend"""
"""Generate bprop for MinMaxUpdatePerLayer for Ascend"""
def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
@ -143,9 +143,9 @@ def get_bprop_fakequant_with_minmax_per_layer_update(self):
return bprop
@bprop_getters.register(Q.FakeQuantMinMaxPerChannelUpdate)
@bprop_getters.register(Q.MinMaxUpdatePerChannel)
def get_bprop_fakequant_with_minmax_per_channel_update(self):
"""Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend"""
"""Generate bprop for MinMaxUpdatePerChannel for Ascend"""
def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)

View File

@ -1,4 +1,3 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -14,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""FakeQuantMinMaxPerChannelUpdate op"""
"""MinMaxUpdatePerChannel op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
@ -22,20 +21,15 @@ from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \
minmax_update_perchannel_op_info = TBERegOp("MinMaxUpdatePerChannel") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_min_max_per_channel_update.so") \
.binfile_name("minmax_update_perchannel.so") \
.compute_cost(10) \
.kernel_name("fake_quant_min_max_per_channel_update") \
.kernel_name("minmax_update_perchannel") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
@ -47,43 +41,46 @@ fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChan
.get_op_info()
@op_info_register(fake_quant_min_max_per_channel_update_op_info)
def _fake_quant_min_max_per_channel_update_tbe():
"""FakeQuantPerChannelUpdate TBE register"""
@op_info_register(minmax_update_perchannel_op_info)
def _minmax_update_perchannel_tbe():
"""MinMaxUpdatePerChannel TBE register"""
return
@fusion_manager.register("fake_quant_min_max_per_channel_update")
def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val,
ema, ema_decay, quant_min, quant_max, training, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerChannelUpdate compute"""
@fusion_manager.register("minmax_update_perchannel")
def minmax_update_perchannel_compute(x, min_val, max_val,
ema, ema_decay, channel_axis):
"""MinMaxUpdatePerChannel compute"""
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
if not ema:
ema_decay = 0.0
if training:
# CalMinMax
# CalMinMax
if channel_axis == 0:
axis = [1, 2, 3, 4]
else:
axis = [0, 2, 3]
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerLayer op"""
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, int, str)
def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
ema, ema_decay, channel_axis,
kernel_name="minmax_update_perchannel"):
"""MinMaxUpdatePerChannel op"""
x_shape = x.get("ori_shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
@ -112,21 +109,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
if channel_axis_ == 0:
shape_c = min_val.get("ori_shape")
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, channel_axis_, kernel_name)
res_list = minmax_update_perchannel_compute(input_data, min_data, max_data,
ema, ema_decay, channel_axis_)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""FakeQuantMinMaxPerLayerUpdate op"""
"""MinMaxUpdatePerLayer op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
@ -22,20 +22,15 @@ from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
minmax_update_perlayer_op_info = TBERegOp("MinMaxUpdatePerLayer") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_minmax_update.so") \
.binfile_name("minmax_update_perlayer.so") \
.compute_cost(10) \
.kernel_name("fake_quant_minmax_update") \
.kernel_name("minmax_update_perlayer") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
@ -46,44 +41,42 @@ fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
.get_op_info()
@op_info_register(fake_quant_minmax_update_op_info)
def _fake_quant_minmax_update_tbe():
"""FakeQuantMinMaxPerLayerUpdate TBE register"""
@op_info_register(minmax_update_perlayer_op_info)
def _minmax_update_perlayer_tbe():
"""MinMaxUpdatePerLayer TBE register"""
return
@fusion_manager.register("fake_quant_minmax_update")
def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
kernel_name="fake_quant_minmax_update"):
"""FakeQuantMinMaxPerLayerUpdate compute"""
@fusion_manager.register("minmax_update_perlayer")
def minmax_update_perlayer_compute(x, min_val, max_val, ema, ema_decay):
"""MinMaxUpdatePerLayer compute"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
if not ema:
ema_decay = 0.0
if training:
# CalMinMax
axis = tuple(range(len(shape)))
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
# CalMinMax
axis = tuple(range(len(shape)))
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str)
def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits,
kernel_name="fake_quant_minmax_update"):
"""FakeQuantPerLayer op"""
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, str)
def minmax_update_perlayer(x, min_val, max_val, min_up, max_up,
ema, ema_decay, kernel_name="minmax_update_perlayer"):
"""MinMaxUpdatePerLayer op"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
@ -112,20 +105,10 @@ def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, kernel_name)
res_list = minmax_update_perlayer_compute(input_data, min_data, max_data, ema, ema_decay)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)

View File

@ -21,12 +21,12 @@ from ..._checkparam import Rel
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ...common import dtype as mstype
__all__ = ["FakeQuantPerLayer",
__all__ = ["MinMaxUpdatePerLayer",
"MinMaxUpdatePerChannel",
"FakeQuantPerLayer",
"FakeQuantPerLayerGrad",
"FakeQuantPerChannel",
"FakeQuantPerChannelGrad",
"FakeQuantMinMaxPerLayerUpdate",
"FakeQuantMinMaxPerChannelUpdate",
"BatchNormFold",
"BatchNormFoldGrad",
"CorrectionMul",
@ -36,23 +36,140 @@ __all__ = ["FakeQuantPerLayer",
"BatchNormFold2Grad",
"BatchNormFoldD",
"BatchNormFoldGradD",
"BNTrainingReduce",
"BatchNormFold2_D",
"BatchNormFold2GradD",
"BatchNormFold2GradReduce",
"BatchNormFold2GradReduce"
]
class MinMaxUpdatePerLayer(PrimitiveWithInfer):
r"""
Update min and max per layer.
Args:
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, ema=False, ema_decay=0.999):
"""init FakeQuantMinMaxPerLayerUpdate OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import minmax_update_perlayer
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class MinMaxUpdatePerChannel(PrimitiveWithInfer):
r"""
Update min and max per channel.
Args:
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
channel_axis (int): Channel asis for per channel compute. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class FakeQuantPerLayer(PrimitiveWithInfer):
r"""
Simulate the quantize and dequantize operations in training time.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
num_bits (int) : Number bits for quantization aware. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
simulate aware quantize funcion. After delay step in training time begin simulate the aware
simulate quantization aware funcion. After delay step in training time begin simulate the aware
quantize funcion. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
@ -334,7 +451,7 @@ class BatchNormFold(PrimitiveWithInfer):
Batch normalization folded.
Args:
momentum (float): Momentum value should be [0, 1]. Default: 0.1.
momentum (float): Momentum value should be [0, 1]. Default: 0.9.
epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
float32 else 1e-3. Default: 1e-5.
is_training (bool): In training mode set True, else set False. Default: True.
@ -366,7 +483,7 @@ class BatchNormFold(PrimitiveWithInfer):
channel_axis = 1
@prim_attr_register
def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0):
def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
"""init batch norm fold layer"""
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
@ -731,32 +848,6 @@ class BatchNormFoldGradD(PrimitiveWithInfer):
return x_type
class BNTrainingReduce(PrimitiveWithInfer):
"""
reduce sum at axis [0, 2, 3].
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
Outputs:
- **x_sum** (Tensor) - Tensor has the same shape as x.
- **x_square_sum** (Tensor) - Tensor has the same shape as x.
"""
@prim_attr_register
def __init__(self):
"""init _BNTrainingReduce layer"""
self.init_prim_io_names(inputs=['x'],
outputs=['x_sum', 'x_square_sum'])
def infer_shape(self, x_shape):
return [x_shape[1]], [x_shape[1]]
def infer_dtype(self, x_type):
return x_type, x_type
class BatchNormFold2_D(PrimitiveWithInfer):
"""
Scale the bias with a correction factor to the long term statistics
@ -859,153 +950,3 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
def infer_dtype(self, dout_type, x_type):
validator.check("dout type", dout_type, "x type", x_type)
return dout_type, dout_type
class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True):
"""init FakeQuantMinMaxPerLayerUpdate OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
channel_axis (int): Channel asis for per channel compute. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = FakeQuantWithMinMax(num_bits=8)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type