From f110c7616bb7805352b4b176e40882431ef8b38f Mon Sep 17 00:00:00 2001 From: wangdongxu Date: Sat, 20 Jun 2020 17:53:47 +0800 Subject: [PATCH] fix perchannel num_channels not set bug and adjust quant.py params order --- mindspore/nn/layer/quant.py | 537 ++++++++++++++---- mindspore/ops/_grad/grad_quant_ops.py | 10 +- ..._update.py => minmax_update_perchannel.py} | 83 ++- ...er_update.py => minmax_update_perlayer.py} | 73 +-- mindspore/ops/operations/_quant_ops.py | 311 ++++------ 5 files changed, 607 insertions(+), 407 deletions(-) rename mindspore/ops/_op_impl/_custom_op/{fake_quant_minmax_perchannel_update.py => minmax_update_perchannel.py} (57%) rename mindspore/ops/_op_impl/_custom_op/{fake_quant_minmax_perlayer_update.py => minmax_update_perlayer.py} (61%) diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 26d56689ff..14731c6262 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Aware quantization.""" +"""Quantization aware.""" from functools import partial import numpy as np @@ -27,9 +27,16 @@ from mindspore._checkparam import Validator as validator, Rel from mindspore.nn.cell import Cell from mindspore.nn.layer.activation import get_activation import mindspore.context as context +from .normalization import BatchNorm2d +from .activation import get_activation +from ..cell import Cell +from . import conv, basic +from ..._checkparam import ParamValidator as validator from ...ops.operations import _quant_ops as Q __all__ = [ + 'Conv2dBnAct', + 'DenseBnAct', 'FakeQuantWithMinMax', 'Conv2dBatchNormQuant', 'Conv2dQuant', @@ -43,6 +50,165 @@ __all__ = [ ] +class Conv2dBnAct(Cell): + r""" + A combination of convolution, Batchnorm, activation layer. + + For a more Detailed overview of Conv2d op. 
+ + Args: + in_channels (int): The number of input channels :math:`C_{in}`. + out_channels (int): The number of output channels :math:`C_{out}`. + kernel_size (Union[int, tuple]): The data type is int or tuple with 2 integers. Specifies the height + and width of the 2D convolution window. Single int means the value is for both height and width of + the kernel. A tuple of 2 ints means the first value is for the height and the other is for the + width of the kernel. + stride (int): Specifies stride for all spatial dimensions with the same value. Value of stride should be + greater than or equal to 1 but bounded by the height and width of the input. Default: 1. + pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". + padding (int): Implicit paddings on both sides of the input. Default: 0. + dilation (int): Specifying the dilation rate to use for dilated convolution. If set to be :math:`k > 1`, + there will be :math:`k - 1` pixels skipped for each sampling location. Its value should be greater than + or equal to 1 and bounded by the height and width of the input. Default: 1. + group (int): Split filter into groups, `in_channels` and `out_channels` should be + divisible by the number of groups. Default: 1. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. + It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified, + values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well + as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones' + and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of + Initializer for more details. Default: 'normal'. 
+ bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible + Initializer and string are the same as 'weight_init'. Refer to the values of + Initializer for more details. Default: 'zeros'. + batchnorm (bool): Specifies to used batchnorm or not. Default: None. + activation (string): Specifies activation type. The optional values are as following: + 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', + 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. + + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. + + Outputs: + Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. + + Examples: + >>> net = Conv2dBnAct(120, 240, 4, batchnorm=True, activation='ReLU') + >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) + >>> net(input).shape + (1, 240, 1024, 640) + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + pad_mode='same', + padding=0, + dilation=1, + group=1, + has_bias=False, + weight_init='normal', + bias_init='zeros', + batchnorm=None, + activation=None): + super(Conv2dBnAct, self).__init__() + self.conv = conv.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + pad_mode, + padding, + dilation, + group, + has_bias, + weight_init, + bias_init) + self.has_bn = batchnorm is not None + self.has_act = activation is not None + self.batchnorm = batchnorm + if batchnorm is True: + self.batchnorm = BatchNorm2d(out_channels) + elif batchnorm is not None: + validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,)) + self.activation = get_activation(activation) + + def construct(self, x): + x = self.conv(x) + if self.has_bn: + x = self.batchnorm(x) + if self.has_act: + x = self.activation(x) + return x + + +class DenseBnAct(Cell): + r""" + A combination of Dense, Batchnorm, activation layer. + + For a more Detailed overview of Dense op. 
+ + Args: + in_channels (int): The number of channels in the input space. + out_channels (int): The number of channels in the output space. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype + is the same as input x. The values of str refer to the function `initializer`. Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is + the same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. + activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None. + batchnorm (bool): Specifies whether to use batchnorm or not. Default: None. + activation (string): Specifies activation type. The optional values are as follows: + 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', + 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. + + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`. + + Outputs: + Tensor of shape :math:`(N, out\_channels)`. 
+ + Examples: + >>> net = nn.DenseBnAct(3, 4) + >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) + >>> net(input) + """ + + def __init__(self, + in_channels, + out_channels, + weight_init='normal', + bias_init='zeros', + has_bias=True, + batchnorm=None, + activation=None): + super(DenseBnAct, self).__init__() + self.dense = basic.Dense( + in_channels, + out_channels, + weight_init, + bias_init, + has_bias) + self.has_bn = batchnorm is not None + self.has_act = activation is not None + if batchnorm is True: + self.batchnorm = BatchNorm2d(out_channels) + elif batchnorm is not None: + validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,)) + self.activation = get_activation(activation) + + def construct(self, x): + x = self.dense(x) + if self.has_bn: + x = self.batchnorm(x) + if self.has_act: + x = self.activation(x) + return x + + class BatchNormFoldCell(Cell): """ Batch normalization folded. @@ -105,20 +271,20 @@ class BatchNormFoldCell(Cell): class FakeQuantWithMinMax(Cell): r""" - Aware Quantization op. This OP provide Fake quantization observer function on data with min and max. + Quantization aware op. This OP provide Fake quantization observer function on data with min and max. Args: min_init (int, float): The dimension of channel or 1(layer). Default: -6. max_init (int, float): The dimension of channel or 1(layer). Default: 6. - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. ema (bool): Exponential Moving Average algorithm update min and max. Default: False. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. - per_channel (bool): Quantization by layer or channel. Default: False. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. channel_axis (int): Quantization by channel axis. Default: 1. - out_channels (int): declarate the min and max channel size, Default: 1. 
- quant_delay (int): Quantization delay parameters according by global step. Default: 0. + num_channels (int): declarate the min and max channel size, Default: 1. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of FakeQuantWithMinMax. @@ -135,15 +301,15 @@ class FakeQuantWithMinMax(Cell): def __init__(self, min_init=-6, max_init=6, - num_bits=8, ema=False, ema_decay=0.999, per_channel=False, channel_axis=1, - out_channels=1, - quant_delay=0, + num_channels=1, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): """init FakeQuantWithMinMax layer""" super(FakeQuantWithMinMax, self).__init__() self.min_init = min_init @@ -152,7 +318,7 @@ class FakeQuantWithMinMax(Cell): self.ema = ema self.ema_decay = ema_decay self.per_channel = per_channel - self.out_channels = out_channels + self.num_channels = num_channels self.channel_axis = channel_axis self.quant_delay = quant_delay self.symmetric = symmetric @@ -161,54 +327,54 @@ class FakeQuantWithMinMax(Cell): # init tensor min and max for fake quant op if self.per_channel: - min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32) - max_array = np.array([self.max_init for i in range(0, self.out_channels)]).astype(np.float32) + min_array = np.array([self.min_init] * self.num_channels).astype(np.float32) + max_array = np.array([self.max_init] * self.num_channels).astype(np.float32) else: - min_array = np.array([self.min_init]).reshape(1).astype(np.float32) - max_array = np.array([self.max_init]).reshape(1).astype(np.float32) + min_array = np.array([self.min_init]).astype(np.float32) + max_array = 
np.array([self.max_init]).astype(np.float32) self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) # init fake quant relative op if per_channel: quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis) - ema_fun = partial(Q.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis) + ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis) else: quant_fun = Q.FakeQuantPerLayer - ema_fun = Q.FakeQuantMinMaxPerLayerUpdate + ema_fun = Q.MinMaxUpdatePerLayer + self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay) if self.is_ascend: - self.fake_quant = quant_fun(num_bits=self.num_bits, - symmetric=self.symmetric, - narrow_range=self.narrow_range) + self.fake_quant_train = quant_fun(num_bits=self.num_bits, + symmetric=self.symmetric, + narrow_range=self.narrow_range) + self.fake_quant_infer = self.fake_quant_train else: - self.fake_quant = quant_fun(num_bits=self.num_bits, - ema=self.ema, - ema_decay=ema_decay, - quant_delay=quant_delay, - symmetric=self.symmetric, - narrow_range=self.narrow_range) - self.ema_update = ema_fun(num_bits=self.num_bits, - ema=self.ema, - ema_decay=self.ema_decay, - symmetric=self.symmetric, - narrow_range=self.narrow_range) + quant_fun = partial(quant_fun, + ema=self.ema, + ema_decay=ema_decay, + num_bits=self.num_bits, + symmetric=self.symmetric, + narrow_range=self.narrow_range, + quant_delay=quant_delay) + self.fake_quant_train = quant_fun(training=True) + self.fake_quant_infer = quant_fun(training=False) def extend_repr(self): s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \ 'quant_delay={}, min_init={}, max_init={}'.format( self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel, - self.channel_axis, self.out_channels, self.quant_delay, self.min_init, self.max_init) + self.channel_axis, 
self.num_channels, self.quant_delay, self.min_init, self.max_init) return s def construct(self, x): - if self.is_ascend and self.training: + if self.training: min_up, max_up = self.ema_update(x, self.minq, self.maxq) - out = self.fake_quant(x, min_up, max_up) P.Assign()(self.minq, min_up) P.Assign()(self.maxq, max_up) + out = self.fake_quant_train(x, self.minq, self.maxq) else: - out = self.fake_quant(x, self.minq, self.maxq) + out = self.fake_quant_infer(x, self.minq, self.maxq) return out @@ -225,8 +391,8 @@ class Conv2dBatchNormQuant(Cell): stride (int): Specifies stride for all spatial dimensions with the same value. pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". padding: (int): Implicit paddings on both sides of the input. Default: 0. - eps (int): Parameters for BatchNormal. Default: 1e-5. - momentum (int): Parameters for BatchNormal op. Default: 0.997. + eps (float): Parameters for BatchNormal. Default: 1e-5. + momentum (float): Parameters for BatchNormal op. Default: 0.997. weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. Default: 'normal'. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the @@ -237,13 +403,13 @@ class Conv2dBatchNormQuant(Cell): mean vector. Default: 'zeros'. var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the variance vector. Default: 'ones'. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. - freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000. fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True. - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. 
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. + freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000. Inputs: - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -274,13 +440,13 @@ class Conv2dBatchNormQuant(Cell): gamma_init='ones', mean_init='zeros', var_init='ones', - quant_delay=0, - freeze_bn=100000, fake=True, - num_bits=8, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0, + freeze_bn=100000): """init Conv2dBatchNormQuant layer""" super(Conv2dBatchNormQuant, self).__init__() self.in_channels = in_channels @@ -304,8 +470,8 @@ class Conv2dBatchNormQuant(Cell): # initialize convolution op and Parameter if context.get_context('device_target') == "Ascend" and group > 1: - validator.check_integer('group', group, in_channels, Rel.EQ, 'Conv2dBatchNormQuant') - validator.check_integer('group', group, out_channels, Rel.EQ, 'Conv2dBatchNormQuant') + validator.check_integer('group', group, in_channels, Rel.EQ) + validator.check_integer('group', group, out_channels, Rel.EQ) self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=self.kernel_size, pad_mode=pad_mode, @@ -337,12 +503,13 @@ class Conv2dBatchNormQuant(Cell): self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, max_init=6, ema=False, - num_bits=num_bits, - quant_delay=quant_delay, per_channel=per_channel, - out_channels=out_channels, + channel_axis=channel_axis, + num_channels=out_channels, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn) self.correct_mul = Q.CorrectionMul(channel_axis) if 
context.get_context('device_target') == "Ascend": @@ -416,11 +583,11 @@ class Conv2dQuant(Cell): weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. Default: 'normal'. bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -447,11 +614,11 @@ class Conv2dQuant(Cell): has_bias=False, weight_init='normal', bias_init='zeros', - quant_delay=0, - num_bits=8, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(Conv2dQuant, self).__init__() if isinstance(kernel_size, int): self.kernel_size = (kernel_size, kernel_size) @@ -487,12 +654,13 @@ class Conv2dQuant(Cell): self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, max_init=6, ema=False, - num_bits=num_bits, - quant_delay=quant_delay, per_channel=per_channel, - out_channels=out_channels, + channel_axis=0, + num_channels=out_channels, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) def construct(self, x): weight = self.fake_quant_weight(self.weight) @@ -526,11 +694,11 @@ class DenseQuant(Cell): same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. 
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -552,11 +720,11 @@ class DenseQuant(Cell): bias_init='zeros', has_bias=True, activation=None, - num_bits=8, - quant_delay=0, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(DenseQuant, self).__init__() self.in_channels = check_int_positive(in_channels) self.out_channels = check_int_positive(out_channels) @@ -586,12 +754,13 @@ class DenseQuant(Cell): self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, max_init=6, ema=False, - num_bits=num_bits, - quant_delay=quant_delay, per_channel=per_channel, - out_channels=out_channels, + channel_axis=0, + num_channels=out_channels, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) def construct(self, x): """Use operators to construct to Dense layer.""" @@ -615,18 +784,28 @@ class DenseQuant(Cell): return str_info -class ReLUQuant(Cell): +class _QuantActivation(Cell): + r""" + Base class for Quant activation function. Add Fake Quant OP after activation OP. 
+ """ + + def get_origin(self): + raise NotImplementedError + + +class ReLUQuant(_QuantActivation): r""" ReLUQuant activation function. Add Fake Quant OP after Relu OP. For a more Detailed overview of ReLU op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of ReLUQuant. @@ -641,20 +820,22 @@ class ReLUQuant(Cell): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, + per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(ReLUQuant, self).__init__() self.fake_quant_act = FakeQuantWithMinMax(min_init=0, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.relu = P.ReLU() def construct(self, x): @@ -662,8 +843,11 @@ class ReLUQuant(Cell): x = self.fake_quant_act(x) return x + def get_origin(self): + return self.relu -class ReLU6Quant(Cell): + +class ReLU6Quant(_QuantActivation): r""" ReLU6Quant activation function. @@ -672,11 +856,12 @@ class ReLU6Quant(Cell): For a more Detailed overview of ReLU6 op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. 
- quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of ReLU6Quant. @@ -691,20 +876,22 @@ class ReLU6Quant(Cell): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, + per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(ReLU6Quant, self).__init__() self.fake_quant_act = FakeQuantWithMinMax(min_init=0, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.relu6 = P.ReLU6() def construct(self, x): @@ -712,19 +899,23 @@ class ReLU6Quant(Cell): x = self.fake_quant_act(x) return x + def get_origin(self): + return self.relu6 -class HSwishQuant(Cell): + +class HSwishQuant(_QuantActivation): r""" HSwishQuant activation function. Add Fake Quant OP after HSwish OP. For a more Detailed overview of HSwish op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. 
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of HSwishQuant. @@ -739,28 +930,31 @@ class HSwishQuant(Cell): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, + per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(HSwishQuant, self).__init__() self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.act = P.HSwish() def construct(self, x): @@ -769,19 +963,23 @@ class HSwishQuant(Cell): x = self.fake_quant_act_after(x) return x + def get_origin(self): + return self.act -class HSigmoidQuant(Cell): + +class HSigmoidQuant(_QuantActivation): r""" HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP. For a more Detailed overview of HSigmoid op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. 
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of HSigmoidQuant. @@ -796,27 +994,31 @@ class HSigmoidQuant(Cell): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, + per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(HSigmoidQuant, self).__init__() self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, + ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.act = P.HSigmoid() def construct(self, x): @@ -825,6 +1027,9 @@ class HSigmoidQuant(Cell): x = self.fake_quant_act_after(x) return x + def get_origin(self): + return self.act + class TensorAddQuant(Cell): r""" @@ -833,11 +1038,12 @@ class TensorAddQuant(Cell): For a more Detailed overview of TensorAdd op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. 
Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of TensorAddQuant. @@ -853,20 +1059,22 @@ class TensorAddQuant(Cell): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, + per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(TensorAddQuant, self).__init__() self.fake_quant_act = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.add = P.TensorAdd() def construct(self, x1, x2): @@ -882,11 +1090,12 @@ class MulQuant(Cell): For a more Detailed overview of Mul op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. + per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of MulQuant. 
@@ -897,23 +1106,99 @@ class MulQuant(Cell): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, + per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(MulQuant, self).__init__() self.fake_quant_act = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.mul = P.Mul() def construct(self, x1, x2): x = self.mul(x1, x2) x = self.fake_quant_act(x) return x + + +class QuantBlock(Cell): + r""" + A quant block of Conv/Dense, activation layer for Ascend deploy. + + Calculate Conv or Dense in Int8, with AscendQuant and AscendDeQuant. + + Notes: + This block is only for deploy, and not trainable. + + Args: + in_channels (int): The number of channels in the input space. + out_channels (int): The number of channels in the output space. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype + is same as input x. The values of str refer to the function `initializer`. Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is + same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. + activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. + batchnorm (bool): Specifies to used batchnorm or not. Default: None. + activation (string): Specifies activation type. The optional values are as following: + 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', + 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. 
+ + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`. + + Outputs: + Tensor of shape :math:`(N, out\_channels)`. + + Examples: + >>> net = nn.Dense(3, 4) + >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) + >>> net(input) + """ + + def __init__(self, + core_op, + weight, + quant_op, + dequant_op, + dequant_scale, + bias=None, + activation=None): + super(QuantBlock, self).__init__() + self.core_op = core_op + self.weight = weight + self.quant = quant_op + self.dequant = dequant_op + self.dequant_scale = dequant_scale + self.bias = bias + self.has_bias = bias is None + self.activation = activation + self.has_act = activation is None + + def construct(self, x): + x = self.quant(x) + x = self.core_op(x, self.weight) + if self.has_bias: + output = self.bias_add(output, self.bias) + if self.has_act: + x = self.activation(x) + x = self.dequant(x, self.dequant_scale) + return x + + def extend_repr(self): + str_info = f'quant={self.quant}, core_op={type(self.core_op)}' + if self.has_bias: + str_info = str_info + f', bias={self.bias}' + if self.has_act: + str_info = str_info + f', activation={self.activation}' + str_info = str_info + f', dequant={self.dequant}' + return str_info diff --git a/mindspore/ops/_grad/grad_quant_ops.py b/mindspore/ops/_grad/grad_quant_ops.py index da19662e97..a2b0ba8d97 100644 --- a/mindspore/ops/_grad/grad_quant_ops.py +++ b/mindspore/ops/_grad/grad_quant_ops.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ -"""Generate bprop for aware quantization ops""" +"""Generate bprop for quantization aware ops""" from .. 
import operations as P from ..operations import _quant_ops as Q @@ -133,9 +133,9 @@ def get_bprop_batchnorm_fold2_(self): return bprop -@bprop_getters.register(Q.FakeQuantMinMaxPerLayerUpdate) +@bprop_getters.register(Q.MinMaxUpdatePerLayer) def get_bprop_fakequant_with_minmax_per_layer_update(self): - """Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend""" + """Generate bprop for MinMaxUpdatePerLayer for Ascend""" def bprop(x, x_min, x_max, out, dout): return zeros_like(x), zeros_like(x_min), zeros_like(x_max) @@ -143,9 +143,9 @@ def get_bprop_fakequant_with_minmax_per_layer_update(self): return bprop -@bprop_getters.register(Q.FakeQuantMinMaxPerChannelUpdate) +@bprop_getters.register(Q.MinMaxUpdatePerChannel) def get_bprop_fakequant_with_minmax_per_channel_update(self): - """Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend""" + """Generate bprop for MinMaxUpdatePerChannel for Ascend""" def bprop(x, x_min, x_max, out, dout): return zeros_like(x), zeros_like(x_min), zeros_like(x_max) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py b/mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py similarity index 57% rename from mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py rename to mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py index fee7f3ed1b..f29fc53325 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py +++ b/mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py @@ -1,4 +1,3 @@ - # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +13,7 @@ # limitations under the License. 
# ============================================================================ -"""FakeQuantMinMaxPerChannelUpdate op""" +"""MinMaxUpdatePerChannel op""" import te.lang.cce from te import tvm from te.platform.fusion_manager import fusion_manager @@ -22,20 +21,15 @@ from topi import generic from topi.cce import util from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType - -fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \ +minmax_update_perchannel_op_info = TBERegOp("MinMaxUpdatePerChannel") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("fake_quant_min_max_per_channel_update.so") \ + .binfile_name("minmax_update_perchannel.so") \ .compute_cost(10) \ - .kernel_name("fake_quant_min_max_per_channel_update") \ + .kernel_name("minmax_update_perchannel") \ .partial_flag(True) \ .attr("ema", "optional", "bool", "all") \ .attr("ema_decay", "optional", "float", "all") \ - .attr("symmetric", "optional", "bool", "all") \ - .attr("narrow_range", "optional", "bool", "all") \ - .attr("training", "optional", "bool", "all") \ - .attr("num_bits", "optional", "int", "all") \ .attr("channel_axis", "optional", "int", "all") \ .input(0, "x", None, "required", None) \ .input(1, "min", None, "required", None) \ @@ -47,43 +41,46 @@ fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChan .get_op_info() -@op_info_register(fake_quant_min_max_per_channel_update_op_info) -def _fake_quant_min_max_per_channel_update_tbe(): - """FakeQuantPerChannelUpdate TBE register""" +@op_info_register(minmax_update_perchannel_op_info) +def _minmax_update_perchannel_tbe(): + """MinMaxUpdatePerChannel TBE register""" return -@fusion_manager.register("fake_quant_min_max_per_channel_update") -def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val, - ema, ema_decay, quant_min, quant_max, training, channel_axis, - kernel_name="fake_quant_min_max_per_channel_update"): - 
"""FakeQuantPerChannelUpdate compute""" +@fusion_manager.register("minmax_update_perchannel") +def minmax_update_perchannel_compute(x, min_val, max_val, + ema, ema_decay, channel_axis): + """MinMaxUpdatePerChannel compute""" shape_min = te.lang.cce.util.shape_to_list(min_val.shape) if not ema: ema_decay = 0.0 - if training: - # CalMinMax + + # CalMinMax + if channel_axis == 0: + axis = [1, 2, 3, 4] + else: axis = [0, 2, 3] - x_min = te.lang.cce.reduce_min(x, axis=axis) - x_max = te.lang.cce.reduce_max(x, axis=axis) - x_min = te.lang.cce.broadcast(x_min, shape_min) - x_max = te.lang.cce.broadcast(x_max, shape_min) - min_val = te.lang.cce.vadd(te.lang.cce.vmuls( - min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) - max_val = te.lang.cce.vadd(te.lang.cce.vmuls( - max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) - min_val = te.lang.cce.vmins(min_val, 0) - max_val = te.lang.cce.vmaxs(max_val, 0) + + x_min = te.lang.cce.reduce_min(x, axis=axis) + x_max = te.lang.cce.reduce_max(x, axis=axis) + x_min = te.lang.cce.broadcast(x_min, shape_min) + x_max = te.lang.cce.broadcast(x_max, shape_min) + min_val = te.lang.cce.vadd(te.lang.cce.vmuls( + min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) + max_val = te.lang.cce.vadd(te.lang.cce.vmuls( + max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) + min_val = te.lang.cce.vmins(min_val, 0) + max_val = te.lang.cce.vmaxs(max_val, 0) return [min_val, max_val] -@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str) -def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up, - ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis, - kernel_name="fake_quant_min_max_per_channel_update"): - """FakeQuantPerLayer op""" +@util.check_input_type(dict, dict, dict, dict, dict, bool, float, int, str) +def minmax_update_perchannel(x, min_val, max_val, min_up, max_up, + ema, ema_decay, channel_axis, + 
kernel_name="minmax_update_perchannel"): + """MinMaxUpdatePerChannel op""" x_shape = x.get("ori_shape") x_format = x.get("format") x_dtype = x.get("dtype") @@ -112,21 +109,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up, util.check_dtype_rule(min_dtype, check_list) util.check_dtype_rule(max_dtype, check_list) - if symmetric: - quant_min = 0 - 2 ** (num_bits - 1) - quant_max = 2 ** (num_bits - 1) - 1 + if channel_axis_ == 0: + shape_c = min_val.get("ori_shape") else: - quant_min = 0 - quant_max = 2 ** num_bits - 1 - if narrow_range: - quant_min = quant_min + 1 - - shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]] + shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]] input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype) min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype) max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype) - res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data, - ema, ema_decay, quant_min, quant_max, training, channel_axis_, kernel_name) + res_list = minmax_update_perchannel_compute(input_data, min_data, max_data, + ema, ema_decay, channel_axis_) with tvm.target.cce(): sch = generic.auto_schedule(res_list) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py b/mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py similarity index 61% rename from mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py rename to mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py index 0ad2315bb3..4d2096d55b 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py +++ b/mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py @@ -13,7 +13,7 @@ # limitations under the License. 
# ============================================================================ -"""FakeQuantMinMaxPerLayerUpdate op""" +"""MinMaxUpdatePerLayer op""" from functools import reduce as functools_reduce import te.lang.cce from te import tvm @@ -22,20 +22,15 @@ from topi import generic from topi.cce import util from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType - -fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \ +minmax_update_perlayer_op_info = TBERegOp("MinMaxUpdatePerLayer") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("fake_quant_minmax_update.so") \ + .binfile_name("minmax_update_perlayer.so") \ .compute_cost(10) \ - .kernel_name("fake_quant_minmax_update") \ + .kernel_name("minmax_update_perlayer") \ .partial_flag(True) \ .attr("ema", "optional", "bool", "all") \ .attr("ema_decay", "optional", "float", "all") \ - .attr("symmetric", "optional", "bool", "all") \ - .attr("narrow_range", "optional", "bool", "all") \ - .attr("training", "optional", "bool", "all") \ - .attr("num_bits", "optional", "int", "all") \ .input(0, "x", None, "required", None) \ .input(1, "min", None, "required", None) \ .input(2, "max", None, "required", None) \ @@ -46,44 +41,42 @@ fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \ .get_op_info() -@op_info_register(fake_quant_minmax_update_op_info) -def _fake_quant_minmax_update_tbe(): - """FakeQuantMinMaxPerLayerUpdate TBE register""" +@op_info_register(minmax_update_perlayer_op_info) +def _minmax_update_perlayer_tbe(): + """MinMaxUpdatePerLayer TBE register""" return -@fusion_manager.register("fake_quant_minmax_update") -def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training, - kernel_name="fake_quant_minmax_update"): - """FakeQuantMinMaxPerLayerUpdate compute""" +@fusion_manager.register("minmax_update_perlayer") +def minmax_update_perlayer_compute(x, min_val, max_val, ema, ema_decay): + 
"""MinMaxUpdatePerLayer compute""" shape = te.lang.cce.util.shape_to_list(x.shape) shape_min = te.lang.cce.util.shape_to_list(min_val.shape) min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype) max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype) if not ema: ema_decay = 0.0 - if training: - # CalMinMax - axis = tuple(range(len(shape))) - x_min = te.lang.cce.reduce_min(x, axis=axis) - x_max = te.lang.cce.reduce_max(x, axis=axis) - x_min = te.lang.cce.broadcast(x_min, shape_min) - x_max = te.lang.cce.broadcast(x_max, shape_min) - min_val = te.lang.cce.vadd(te.lang.cce.vmuls( - min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) - max_val = te.lang.cce.vadd(te.lang.cce.vmuls( - max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) - min_val = te.lang.cce.vmins(min_val, 0) - max_val = te.lang.cce.vmaxs(max_val, 0) + + # CalMinMax + axis = tuple(range(len(shape))) + x_min = te.lang.cce.reduce_min(x, axis=axis) + x_max = te.lang.cce.reduce_max(x, axis=axis) + x_min = te.lang.cce.broadcast(x_min, shape_min) + x_max = te.lang.cce.broadcast(x_max, shape_min) + min_val = te.lang.cce.vadd(te.lang.cce.vmuls( + min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) + max_val = te.lang.cce.vadd(te.lang.cce.vmuls( + max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) + min_val = te.lang.cce.vmins(min_val, 0) + max_val = te.lang.cce.vmaxs(max_val, 0) return [min_val, max_val] -@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str) -def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up, - ema, ema_decay, symmetric, narrow_range, training, num_bits, - kernel_name="fake_quant_minmax_update"): - """FakeQuantPerLayer op""" +@util.check_input_type(dict, dict, dict, dict, dict, bool, float, str) +def minmax_update_perlayer(x, min_val, max_val, min_up, max_up, + ema, ema_decay, kernel_name="minmax_update_perlayer"): + """MinMaxUpdatePerLayer op""" input_shape = x.get("shape") 
input_dtype = x.get("dtype") min_shape = min_val.get("ori_shape") @@ -112,20 +105,10 @@ def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up, input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),) shape_min, _, _ = util.produce_shapes(min_shape, input_shape) - if symmetric: - quant_min = 0 - 2 ** (num_bits - 1) - quant_max = 2 ** (num_bits - 1) - 1 - else: - quant_min = 0 - quant_max = 2 ** num_bits - 1 - if narrow_range: - quant_min = quant_min + 1 - input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) - res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data, - ema, ema_decay, quant_min, quant_max, training, kernel_name) + res_list = minmax_update_perlayer_compute(input_data, min_data, max_data, ema, ema_decay) with tvm.target.cce(): sch = generic.auto_schedule(res_list) diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index 852b3c638e..42c2406906 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -21,12 +21,12 @@ from ..._checkparam import Rel from ..primitive import PrimitiveWithInfer, prim_attr_register from ...common import dtype as mstype -__all__ = ["FakeQuantPerLayer", +__all__ = ["MinMaxUpdatePerLayer", + "MinMaxUpdatePerChannel", + "FakeQuantPerLayer", "FakeQuantPerLayerGrad", "FakeQuantPerChannel", "FakeQuantPerChannelGrad", - "FakeQuantMinMaxPerLayerUpdate", - "FakeQuantMinMaxPerChannelUpdate", "BatchNormFold", "BatchNormFoldGrad", "CorrectionMul", @@ -36,23 +36,140 @@ __all__ = ["FakeQuantPerLayer", "BatchNormFold2Grad", "BatchNormFoldD", "BatchNormFoldGradD", - "BNTrainingReduce", "BatchNormFold2_D", "BatchNormFold2GradD", - "BatchNormFold2GradReduce", + "BatchNormFold2GradReduce" ] +class MinMaxUpdatePerLayer(PrimitiveWithInfer): + r""" + Update min and max 
per layer. + + Args: + ema (bool): Use EMA algorithm update value min and max. Default: False. + ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. + + Inputs: + - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. + - **min** (Tensor) : Value of the min range of the input data x. + - **max** (Tensor) : Value of the max range of the input data x. + + Outputs: + - Tensor: Simulate quantize tensor of x. + + Examples: + >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) + >>> min_tensor = Tensor(np.array([-6]), mstype.float32) + >>> max_tensor = Tensor(np.array([6]), mstype.float32) + >>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor) + """ + support_quant_bit = [4, 7, 8] + + @prim_attr_register + def __init__(self, ema=False, ema_decay=0.999): + """init FakeQuantMinMaxPerLayerUpdate OP""" + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import minmax_update_perlayer + if ema and not ema_decay: + raise ValueError( + f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") + + self.ema = validator.check_value_type('ema', ema, (bool,), self.name) + self.ema_decay = validator.check_number_range( + 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) + self.init_prim_io_names(inputs=['x', 'min', 'max'], + outputs=['min_up', 'max_up']) + + def infer_shape(self, x_shape, min_shape, max_shape): + validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) + validator.check("min shape", min_shape, "max shape", + max_shape, Rel.EQ, self.name) + validator.check_integer("min shape", len( + min_shape), 1, Rel.EQ, self.name) + return min_shape, max_shape + + def infer_dtype(self, x_type, min_type, max_type): + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) + validator.check_tensor_type_same( + {"min": min_type}, valid_types, self.name) 
+ validator.check_tensor_type_same( + {"max": max_type}, valid_types, self.name) + return min_type, max_type + + +class MinMaxUpdatePerChannel(PrimitiveWithInfer): + r""" + Update min and max per channel. + + Args: + ema (bool): Use EMA algorithm update value min and max. Default: False. + ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. + channel_axis (int): Channel asis for per channel compute. Default: 1. + + Inputs: + - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. + - **min** (Tensor) : Value of the min range of the input data x. + - **max** (Tensor) : Value of the max range of the input data x. + + Outputs: + - Tensor: Simulate quantize tensor of x. + + Examples: + >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) + >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) + >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) + >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max) + """ + support_quant_bit = [4, 7, 8] + + @prim_attr_register + def __init__(self, ema=False, ema_decay=0.999, channel_axis=1): + """init FakeQuantPerChannelUpdate OP for Ascend""" + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import minmax_update_perchannel + if ema and not ema_decay: + raise ValueError( + f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") + + self.ema = validator.check_value_type('ema', ema, (bool,), self.name) + self.ema_decay = validator.check_number_range( + 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) + self.channel_axis = validator.check_integer( + 'channel axis', channel_axis, 0, Rel.GE, self.name) + self.init_prim_io_names( + inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up']) + + def infer_shape(self, x_shape, min_shape, max_shape): + validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) + validator.check("min shape", min_shape, "max shape", + max_shape, 
Rel.EQ, self.name) + validator.check_integer("min shape", len( + min_shape), 1, Rel.EQ, self.name) + return min_shape, max_shape + + def infer_dtype(self, x_type, min_type, max_type): + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same( + {"x": x_type}, valid_types, self.name) + validator.check_tensor_type_same( + {"min": min_type}, valid_types, self.name) + validator.check_tensor_type_same( + {"max": max_type}, valid_types, self.name) + return min_type, max_type + + class FakeQuantPerLayer(PrimitiveWithInfer): r""" Simulate the quantize and dequantize operations in training time. Args: - num_bits (int) : Number bits for aware quantilization. Default: 8. + num_bits (int) : Number bits for quantization aware. Default: 8. ema (bool): Use EMA algorithm update value min and max. Default: False. ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. quant_delay (int): Quantilization delay parameter. Before delay step in training time not update - simulate aware quantize funcion. After delay step in training time begin simulate the aware + simulate quantization aware funcion. After delay step in training time begin simulate the aware quantize funcion. Default: 0. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. @@ -334,7 +451,7 @@ class BatchNormFold(PrimitiveWithInfer): Batch normalization folded. Args: - momentum (float): Momentum value should be [0, 1]. Default: 0.1. + momentum (float): Momentum value should be [0, 1]. Default: 0.9. epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in float32 else 1e-3. Default: 1e-5. is_training (bool): In training mode set True, else set False. Default: True. 
@@ -366,7 +483,7 @@ class BatchNormFold(PrimitiveWithInfer): channel_axis = 1 @prim_attr_register - def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0): + def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0): """init batch norm fold layer""" self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) @@ -731,32 +848,6 @@ class BatchNormFoldGradD(PrimitiveWithInfer): return x_type -class BNTrainingReduce(PrimitiveWithInfer): - """ - reduce sum at axis [0, 2, 3]. - - Inputs: - - **x** (Tensor) - Tensor of shape :math:`(N, C)`. - - Outputs: - - **x_sum** (Tensor) - Tensor has the same shape as x. - - **x_square_sum** (Tensor) - Tensor has the same shape as x. - - """ - - @prim_attr_register - def __init__(self): - """init _BNTrainingReduce layer""" - self.init_prim_io_names(inputs=['x'], - outputs=['x_sum', 'x_square_sum']) - - def infer_shape(self, x_shape): - return [x_shape[1]], [x_shape[1]] - - def infer_dtype(self, x_type): - return x_type, x_type - - class BatchNormFold2_D(PrimitiveWithInfer): """ Scale the bias with a correction factor to the long term statistics @@ -859,153 +950,3 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer): def infer_dtype(self, dout_type, x_type): validator.check("dout type", dout_type, "x type", x_type) return dout_type, dout_type - - -class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer): - r""" - Update min and max value for fake quant per layer op. - - Args: - num_bits (int) : Number bits for aware quantilization. Default: 8. - ema (bool): Use EMA algorithm update value min and max. Default: False. - ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. - symmetric (bool): Quantization algorithm use symmetric or not. Default: False. - narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. 
- training (bool): Training the network or not. Default: True. - - Inputs: - - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. - - **min** (Tensor) : Value of the min range of the input data x. - - **max** (Tensor) : Value of the max range of the input data x. - - Outputs: - - Tensor: Simulate quantize tensor of x. - - Examples: - >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) - >>> min_tensor = Tensor(np.array([-6]), mstype.float32) - >>> max_tensor = Tensor(np.array([6]), mstype.float32) - >>> output_tensor = FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor) - """ - support_quant_bit = [4, 7, 8] - - @prim_attr_register - def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, - training=True): - """init FakeQuantMinMaxPerLayerUpdate OP""" - if context.get_context('device_target') == "Ascend": - from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update - if num_bits not in self.support_quant_bit: - raise ValueError( - f"For '{self.name}' attr \'num_bits\' is not support.") - if ema and not ema_decay: - raise ValueError( - f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") - - self.ema = validator.check_value_type('ema', ema, (bool,), self.name) - self.symmetric = validator.check_value_type( - 'symmetric', symmetric, (bool,), self.name) - self.narrow_range = validator.check_value_type( - 'narrow_range', narrow_range, (bool,), self.name) - self.training = validator.check_value_type( - 'training', training, (bool,), self.name) - self.ema_decay = validator.check_number_range( - 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) - self.num_bits = validator.check_integer( - 'num_bits', num_bits, 0, Rel.GT, self.name) - self.init_prim_io_names(inputs=['x', 'min', 'max'], - outputs=['min_up', 'max_up']) - - def infer_shape(self, x_shape, min_shape, max_shape): - validator.check_integer("x rank", len(x_shape), 1, 
Rel.GE, self.name) - validator.check("min shape", min_shape, "max shape", - max_shape, Rel.EQ, self.name) - validator.check_integer("min shape", len( - min_shape), 1, Rel.EQ, self.name) - return min_shape, max_shape - - def infer_dtype(self, x_type, min_type, max_type): - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"min": min_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"max": max_type}, valid_types, self.name) - return min_type, max_type - - -class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer): - r""" - Update min and max value for fake quant per layer op. - - Args: - num_bits (int) : Number bits for aware quantilization. Default: 8. - ema (bool): Use EMA algorithm update value min and max. Default: False. - ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. - symmetric (bool): Quantization algorithm use symmetric or not. Default: False. - narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. - training (bool): Training the network or not. Default: True. - channel_axis (int): Channel asis for per channel compute. Default: 1. - - Inputs: - - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. - - **min** (Tensor) : Value of the min range of the input data x. - - **max** (Tensor) : Value of the max range of the input data x. - - Outputs: - - Tensor: Simulate quantize tensor of x. 
- - Examples: - >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) - >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) - >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) - >>> output_tensor = FakeQuantWithMinMax(num_bits=8)(x, min, max) - """ - support_quant_bit = [4, 7, 8] - - @prim_attr_register - def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, - training=True, channel_axis=1): - """init FakeQuantPerChannelUpdate OP for Ascend""" - if context.get_context('device_target') == "Ascend": - from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update - if num_bits not in self.support_quant_bit: - raise ValueError( - f"For '{self.name}' attr \'num_bits\' is not support.") - if ema and not ema_decay: - raise ValueError( - f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") - - self.ema = validator.check_value_type('ema', ema, (bool,), self.name) - self.symmetric = validator.check_value_type( - 'symmetric', symmetric, (bool,), self.name) - self.narrow_range = validator.check_value_type( - 'narrow_range', narrow_range, (bool,), self.name) - self.training = validator.check_value_type( - 'training', training, (bool,), self.name) - self.ema_decay = validator.check_number_range( - 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) - self.num_bits = validator.check_integer( - 'num_bits', num_bits, 0, Rel.GT, self.name) - self.channel_axis = validator.check_integer( - 'channel axis', channel_axis, 0, Rel.GE, self.name) - self.init_prim_io_names( - inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up']) - - def infer_shape(self, x_shape, min_shape, max_shape): - validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) - validator.check("min shape", min_shape, "max shape", - max_shape, Rel.EQ, self.name) - validator.check_integer("min shape", len( - min_shape), 1, Rel.EQ, self.name) - return min_shape, max_shape - - def 
infer_dtype(self, x_type, min_type, max_type): - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same( - {"x": x_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"min": min_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"max": max_type}, valid_types, self.name) - return min_type, max_type