!1864 add fake quant per channel and bug fix

Merge pull request !1864 from chenzhongming/master
This commit is contained in:
mindspore-ci-bot 2020-06-05 16:07:26 +08:00 committed by Gitee
commit 5b80937367
13 changed files with 944 additions and 347 deletions

View File

@ -171,6 +171,6 @@ bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
return true;
}
MS_REG_GPU_KERNEL(FakeQuantWithMinMax, FakeQuantGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantGpuKernel)
} // namespace kernel
} // namespace mindspore

View File

@ -153,6 +153,6 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const
return true;
}
MS_REG_GPU_KERNEL(FakeQuantWithMinMaxGrad, FakeQuantGradGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantGradGpuKernel)
} // namespace kernel
} // namespace mindspore

View File

@ -175,6 +175,6 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
return true;
}
MS_REG_GPU_KERNEL(FakeQuantWithMinMaxPerChannel, FakeQuantPerChannelGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerChannel, FakeQuantPerChannelGpuKernel)
} // namespace kernel
} // namespace mindspore

View File

@ -143,6 +143,6 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
return true;
}
MS_REG_GPU_KERNEL(FakeQuantWithMinMaxPerChannelGrad, FakeQuantPerChannelGradGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerChannelGrad, FakeQuantPerChannelGradGpuKernel)
} // namespace kernel
} // namespace mindspore

View File

@ -14,6 +14,7 @@
# ============================================================================
"""Aware quantization."""
from functools import partial
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
@ -101,111 +102,7 @@ class BatchNormFoldCell(Cell):
return batch_mean, batch_std, running_mean, running_std
class FakeQuantWithMinMaxD(Cell):
r"""
Aware Quantization training op of ascend. This OP provide Fake quantization observer
function on data with min and max.
Args:
min_init (int, list): The dimension of channel or 1(layer). Default: -6.
max_init (int, list): The dimension of channel or 1(layer). Default: 6.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization by layer or channel. Default: False.
out_channels (int): declarate the min and max channel size, Default: 1.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
Inputs:
- **x** (Tensor) - The input of FakeQuantWithMinMax.
Outputs:
Tensor, with the same type and shape as the `x`.
Examples:
>>> fake_quant = nn.FakeQuantWithMinMaxD()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = fake_quant(input_x)
"""
def __init__(self,
min_init=-6,
max_init=6,
num_bits=8,
ema=False,
ema_decay=0.999,
per_channel=False,
channel_size=1,
quant_delay=0,
symmetric=False,
narrow_range=False,
training=True):
"""init FakeQuantWithMinMax ascend layer"""
super(FakeQuantWithMinMaxD, self).__init__()
self.min_init = min_init
self.num_bits = num_bits
self.max_init = max_init
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.channel_size = channel_size
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
self.training = training
if not per_channel:
self.fake_quant = P.FakeQuantWithMinMax(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=training)
self.ema_update = P.FakeQuantWithMinMaxUpdate(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=training)
else:
raise RuntimeError("not support per channel")
if isinstance(min_init, Parameter):
self.minq = min_init
self.maxq = max_init
else:
self.minq = Parameter(Tensor(np.array([min_init]).astype(np.float32)),
name='quant_min',
requires_grad=False)
self.maxq = Parameter(Tensor(np.array([max_init]).astype(np.float32)),
name='quant_max',
requires_grad=False)
self.reduce_min = P.ReduceMin()
self.reduce_max = P.ReduceMax()
def extend_repr(self):
s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size,
self.quant_delay)
return s
def construct(self, x, minq, maxq):
if self.training:
min_up, max_up = self.ema_update(x, minq, maxq)
out = self.fake_quant(x, min_up, max_up)
P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up)
else:
out = self.fake_quant(x, minq, maxq)
return out
class FakeQuantWithMinMax(Cell):
class FakeQuantWithMinMaxAscend(Cell):
r"""
Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
@ -240,98 +137,174 @@ class FakeQuantWithMinMax(Cell):
ema=False,
ema_decay=0.999,
per_channel=False,
channel_axis=1,
out_channels=1,
quant_delay=0,
symmetric=False,
narrow_range=False):
"""init FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMax, self).__init__()
narrow_range=False,
training=True):
"""init FakeQuantWithMinMaxAscend layer"""
super(FakeQuantWithMinMaxAscend, self).__init__()
self.min_init = min_init
self.num_bits = num_bits
self.max_init = max_init
self.num_bits = num_bits
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.out_channels = out_channels
self.channel_axis = channel_axis
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
self.training = training
if per_channel:
min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(0, self.channel_size)]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True)
self.fake_quant_infer = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False)
else:
# init tensor min and max for fake quant op
if isinstance(min_init, int):
min_array = np.array([min_init]).reshape(1).astype(np.float32)
max_array = np.array([max_init]).reshape(1).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
if context.get_context('device_target') == "Ascend":
self.fake_quant_train = FakeQuantWithMinMaxD(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True,
min_init=self.minq,
max_init=self.maxq)
self.fake_quant_infer = FakeQuantWithMinMaxD(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False,
min_init=self.minq,
max_init=self.maxq)
elif context.get_context('device_target') == "GPU":
self.fake_quant_train = P.FakeQuantWithMinMax(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True)
self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False)
else:
raise ValueError("Not support platform.")
elif isinstance(min_init, list):
min_array = np.array([self.min_init for i in range(
0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(
0, self.out_channels)]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
if per_channel:
quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
ema_fun = partial(P.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis)
else:
quant_fun = P.FakeQuantPerLayer
ema_fun = P.FakeQuantMinMaxPerLayerUpdate
self.fake_quant = quant_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)
self.ema_update = ema_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)
def extend_repr(self):
s = 'min={}, max={}, ema={}, ema_decay={}, per_channel={}, quant_delay={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.quant_delay)
s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay,
self.per_channel, self.quant_delay, self.channel_axis)
return s
def construct(self, x):
if self.training:
out = self.fake_quant_train(x, self.minq, self.maxq)
if self.update:
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
out = self.fake_quant(x, min_up, max_up)
P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up)
else:
out = self.fake_quant_infer(x, self.minq, self.maxq)
out = self.fake_quant(x, self.minq, self.maxq)
return out
class FakeQuantWithMinMaxGPU(Cell):
r"""
Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
Args:
min_init (int, list): The dimension of channel or 1(layer). Default: -6.
max_init (int, list): The dimension of channel or 1(layer). Default: 6.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization by layer or channel. Default: False.
out_channels (int): declarate the min and max channel size, Default: 1.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
Inputs:
- **x** (Tensor) - The input of FakeQuantWithMinMax.
Outputs:
Tensor, with the same type and shape as the `x`.
Examples:
>>> fake_quant = FakeQuantWithMinMax()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = fake_quant(input_x)
"""
def __init__(self,
min_init=-6,
max_init=6,
num_bits=8,
ema=False,
ema_decay=0.999,
per_channel=False,
channel_axis=1,
out_channels=1,
quant_delay=0,
symmetric=False,
narrow_range=False,
training=True):
super(FakeQuantWithMinMaxGPU, self).__init__()
self.min_init = min_init
self.max_init = max_init
self.num_bits = num_bits
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.channel_axis = channel_axis
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
self.training = training
# init tensor min and max for fake quant op
if isinstance(min_init, int):
min_array = np.array([min_init]).reshape(1).astype(np.float32)
max_array = np.array([max_init]).reshape(1).astype(np.float32)
elif isinstance(min_init, list):
min_array = np.array([self.min_init for i in range(
0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(
0, self.out_channels)]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
if per_channel:
quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
else:
quant_fun = P.FakeQuantPerLayer
self.fake_quant = quant_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)
def extend_repr(self):
s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay,
self.per_channel, self.quant_delay, self.channel_axis)
return s
def construct(self, x):
out = self.fake_quant(x, self.minq, self.maxq)
return out
def FakeQuantWithMinMax(**kwargs):
if context.get_context('device_target') == "Ascend":
out = FakeQuantWithMinMaxAscend(**kwargs)
if context.get_context('device_target') == "GPU":
out = FakeQuantWithMinMaxGPU(**kwargs)
else:
raise ValueError("Not support platform or channel mode.")
return out
class Conv2dBatchNormQuant(Cell):
r"""
2D convolution with BatchNormal op folded layer.
@ -420,7 +393,6 @@ class Conv2dBatchNormQuant(Cell):
self.per_channel = per_channel
self.symmetric = symmetric
self.narrow_range = narrow_range
self.channel_axis = int(group > 1)
self.is_gpu = context.get_context('device_target') == "GPU"
# initialize convolution op and Parameter
@ -435,6 +407,7 @@ class Conv2dBatchNormQuant(Cell):
dilation=self.dilation)
if weight_init is None:
weight_init = initializer('normal', [1, in_channels, *self.kernel_size])
channel_axis = 1
else:
self.conv = P.Conv2D(out_channel=out_channels,
kernel_size=self.kernel_size,
@ -445,6 +418,7 @@ class Conv2dBatchNormQuant(Cell):
group=group)
if weight_init is None:
weight_init = initializer('normal', [out_channels, in_channels // group, *self.kernel_size])
channel_axis = 0
self.weight = Parameter(weight_init, name='weight')
# initialize batchnorm Parameter
@ -472,7 +446,7 @@ class Conv2dBatchNormQuant(Cell):
symmetric=symmetric,
narrow_range=narrow_range)
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
self.correct_mul = P.CorrectionMul(self.channel_axis)
self.correct_mul = P.CorrectionMul(channel_axis)
if context.get_context('device_target') == "Ascend":
self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0)
@ -520,7 +494,7 @@ class Conv2dBatchNormQuant(Cell):
out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
F.control_depend(out, self.assignadd(self.step, self.one))
else:
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, running_std, running_mean, running_std)
return out

View File

@ -20,10 +20,11 @@ from .grad_base import bprop_getters
from ..composite.multitype_ops.zeros_like_impl import zeros_like
@bprop_getters.register(P.FakeQuantWithMinMax)
@bprop_getters.register(P.FakeQuantPerLayer)
def get_bprop_fakequant_with_minmax(self):
"""Generate bprop for FakeQuantWithMinMax for GPU and Ascend"""
op = P.FakeQuantWithMinMaxGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
"""Generate bprop for FakeQuantPerLayer for GPU and Ascend"""
op = P.FakeQuantPerLayerGrad(
num_bits=self.num_bits, quant_delay=self.quant_delay)
def bprop(x, x_min, x_max, out, dout):
dx = op(dout, x, x_min, x_max)
@ -32,10 +33,14 @@ def get_bprop_fakequant_with_minmax(self):
return bprop
@bprop_getters.register(P.FakeQuantWithMinMaxPerChannel)
@bprop_getters.register(P.FakeQuantPerChannel)
def get_bprop_fakequant_with_minmax_perchannel(self):
"""Generate bprop for FakeQuantWithMinMaxPerChannel for GPU"""
op = P.FakeQuantWithMinMaxPerChannelGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
"""Generate bprop for FakeQuantPerChannel"""
op = P.FakeQuantPerChannelGrad(num_bits=self.num_bits,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.symmetric,
channel_axis=self.channel_axis)
def bprop(x, x_min, x_max, out, dout):
dx = op(dout, x, x_min, x_max)
@ -77,7 +82,7 @@ def get_bprop_batchnorm_fold2(self):
d_batch_std, d_batch_mean, d_beta, d_gamma, d_x = op_f(dout, x, gamma, batch_std, batch_mean, running_std,
running_mean, global_step)
return d_x, d_beta, d_gamma, d_batch_std, d_batch_mean, zeros_like(running_std), zeros_like(running_mean), \
zeros_like(global_step)
zeros_like(global_step)
return bprop
@ -117,9 +122,19 @@ def get_bprop_batchnorm_fold2_(self):
return bprop
@bprop_getters.register(P.FakeQuantWithMinMaxUpdate)
def get_bprop_fakequant_with_minmax_update(self):
"""Generate bprop for FakeQuantWithMinMaxUpdate for Ascend"""
@bprop_getters.register(P.FakeQuantMinMaxPerLayerUpdate)
def get_bprop_fakequant_with_minmax_per_layer_update(self):
"""Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend"""
def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
return bprop
@bprop_getters.register(P.FakeQuantMinMaxPerChannelUpdate)
def get_bprop_fakequant_with_minmax_per_channel_update(self):
"""Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend"""
def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)

View File

@ -0,0 +1,135 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantMinMaxPerChannelUpdate op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_min_max_per_channel_update.so") \
.compute_cost(10) \
.kernel_name("fake_quant_min_max_per_channel_update") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
.output(0, "min_up", True, "required", "all") \
.output(1, "max_up", True, "required", "all") \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.get_op_info()
@op_info_register(fake_quant_min_max_per_channel_update_op_info)
def _fake_quant_min_max_per_channel_update_tbe():
"""FakeQuantPerChannelUpdate TBE register"""
return
@fusion_manager.register("fake_quant_min_max_per_channel_update")
def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val,
ema, ema_decay, quant_min, quant_max, training, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerChannelUpdate compute"""
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
if not ema:
ema_decay = 0.0
if training:
# CalMinMax
axis = [0, 2, 3]
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerLayer op"""
x_shape = x.get("ori_shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
check_list = ["float32", "float16"]
x_dtype = x_dtype.lower()
min_dtype = min_dtype.lower()
max_dtype = max_dtype.lower()
util.check_dtype_rule(x_dtype, check_list)
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
tensor_list = [input_data, min_data, max_data] + list(res_list)
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""FakeQuantWithMinMaxUpdate op"""
"""FakeQuantMinMaxPerLayerUpdate op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
@ -23,12 +23,12 @@ from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_with_min_max_update.so") \
.binfile_name("fake_quant_minmax_update.so") \
.compute_cost(10) \
.kernel_name("fake_quant_with_min_max_update") \
.kernel_name("fake_quant_minmax_update") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
@ -36,7 +36,6 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("quant_delay", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
@ -47,16 +46,16 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
.get_op_info()
@op_info_register(fake_quant_update_op_info)
def _fake_quant_update_tbe():
"""_FakeQuantWithMinMaxUpdate TBE register"""
@op_info_register(fake_quant_minmax_update_op_info)
def _fake_quant_minmax_update_tbe():
"""FakeQuantMinMaxPerLayerUpdate TBE register"""
return
@fusion_manager.register("fake_quant_with_min_max_update")
def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
kernel_name="fake_quant_update"):
"""FakeQuantWithMinMaxUpdate compute"""
@fusion_manager.register("fake_quant_minmax_update")
def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
kernel_name="fake_quant_minmax_update"):
"""FakeQuantMinMaxPerLayerUpdate compute"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
@ -70,19 +69,21 @@ def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay,
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay,
kernel_name="fake_quant_update"):
"""FakeQuantWithMinMax op"""
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str)
def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits,
kernel_name="fake_quant_minmax_update"):
"""FakeQuantPerLayer op"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
@ -123,8 +124,8 @@ def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up,
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res_list = fake_quant_with_min_max_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, kernel_name)
res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)

View File

@ -0,0 +1,145 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantPerChannel op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_perchannel_op_info = TBERegOp("FakeQuantPerChannel") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("fake_quant_perchannel.so") \
.compute_cost(10) \
.kernel_name("fake_quant_perchannel") \
.partial_flag(True) \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(fake_quant_perchannel_op_info)
def _fake_quant_perchannel_tbe():
"""FakeQuantPerChannel TBE register"""
return
@fusion_manager.register("fake_quant_perchannel")
def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max,
kernel_name="fake_quant_perchannel"):
"""FakeQuantPerChannel"""
x_shape = te.lang.cce.util.shape_to_list(x.shape)
minmax_shape = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = tvm.const(quant_min, x.dtype)
quant_max = tvm.const(quant_max, x.dtype)
quant_min = te.lang.cce.broadcast(quant_min, minmax_shape, x.dtype)
quant_max = te.lang.cce.broadcast(quant_max, minmax_shape, x.dtype)
# CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point
nudge_zp_ = te.lang.cce.vmin(
quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
# FakeQuant
nudge_min_b = te.lang.cce.broadcast(nudge_min, x_shape)
nudge_max_b = te.lang.cce.broadcast(nudge_max, x_shape)
scale_b = te.lang.cce.broadcast(scale, x_shape)
input_x = te.lang.cce.vmin(nudge_max_b, te.lang.cce.vmax(nudge_min_b, x))
nudge_input_ = te.lang.cce.vdiv(
te.lang.cce.vsub(input_x, nudge_min_b), scale_b)
nudge_input = te.lang.cce.floor(te.lang.cce.vadds(nudge_input_, 0.5))
res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale_b), nudge_min_b)
return res
@util.check_input_type(dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel(x, min_val, max_val, y,
symmetric, narrow_range, num_bits, channel_axis,
kernel_name="fake_quant_perchannel"):
"""FakeQuantPerChannel"""
x_shape = x.get("shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
check_list = ["float32", "float16"]
x_dtype = x_dtype.lower()
min_dtype = min_dtype.lower()
max_dtype = max_dtype.lower()
util.check_dtype_rule(x_dtype, check_list)
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [1] * len(x_shape)
shape_c[channel_axis] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis == 1:
shape_c = min_val.get("shape")
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res = fake_quant_perchannel_compute(input_data, min_data, max_data, y,
quant_min, quant_max, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)
tensor_list = [input_data, min_data, max_data, res]
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)

View File

@ -0,0 +1,171 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantPerChannelGrad op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
SHAPE_SIZE_LIMIT = 2147483648
D_TYPE = 'float32'
fake_quant_perchannel_grad_op_info = TBERegOp("FakeQuantPerChannelGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_perchannel_grad.so") \
.compute_cost(10) \
.kernel_name("fake_quant_perchannel_grad") \
.partial_flag(True) \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "dout", None, "required", None) \
.input(1, "x", None, "required", None) \
.input(2, "min", None, "required", None) \
.input(3, "max", None, "required", None) \
.output(0, "dx", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
def _less_compare_float32(data_x, data_y):
"""_less_compare_float32 compute"""
input_shape = te.lang.cce.util.shape_to_list(data_x.shape)
min_value = tvm.const(2 ** (-126), dtype=D_TYPE)
max_value = tvm.const(2 ** 62, dtype=D_TYPE)
factor_value = tvm.const(2 ** 2, dtype=D_TYPE)
data_zero = te.lang.cce.broadcast(
tvm.const(0, dtype=D_TYPE), input_shape, D_TYPE)
min_value_tensor = te.lang.cce.vadds(data_zero, min_value)
res_sub = te.lang.cce.vsub(data_y, data_x)
res_min = te.lang.cce.vmin(res_sub, min_value_tensor)
res_max = te.lang.cce.vmax(res_min, data_zero)
res_max_mul = te.lang.cce.vmuls(res_max, max_value)
res_max_mul_max = te.lang.cce.vmuls(res_max_mul, max_value)
res = te.lang.cce.vmuls(res_max_mul_max, factor_value)
return res
@op_info_register(fake_quant_perchannel_grad_op_info)
def _fake_quant_perchannel_grad_tbe():
"""FakeQuantPerChannelGrad TBE register"""
return
@fusion_manager.register("fake_quant_perchannel_grad")
def fake_quant_perchannel_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
kernel_name="fake_quant_perchannel_grad"):
"""FakeQuantPerChannelGrad"""
x_shape = te.lang.cce.util.shape_to_list(x.shape)
minmax_shape = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = tvm.const(quant_min, x.dtype)
quant_max = tvm.const(quant_max, x.dtype)
quant_min = te.lang.cce.broadcast(quant_min, minmax_shape, x.dtype)
quant_max = te.lang.cce.broadcast(quant_max, minmax_shape, x.dtype)
# CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point
nudge_zp_ = te.lang.cce.vmin(
quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
# FakeQuant Grad
nudge_min_b = te.lang.cce.broadcast(nudge_min, x_shape)
nudge_max_b = te.lang.cce.broadcast(nudge_max, x_shape)
bool_over_min = _less_compare_float32(nudge_min_b, x)
bool_less_max = _less_compare_float32(x, nudge_max_b)
bool_between = te.lang.cce.vmul(bool_over_min, bool_less_max)
res = te.lang.cce.vmul(dout, bool_between)
return res
@util.check_input_type(dict, dict, dict, dict, dict, bool, bool, int, int, str)
def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
symmetric, narrow_range, num_bits, channel_axis,
kernel_name="fake_quant_perchannel_grad"):
"""FakeQuantPerChannelGrad"""
x_shape = x.get("shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype")
util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)
check_list = ["float32", "float16"]
x_dtype = x_dtype.lower()
min_dtype = min_dtype.lower()
max_dtype = max_dtype.lower()
util.check_dtype_rule(x_dtype, check_list)
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [1] * len(x_shape)
shape_c[channel_axis] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis == 1:
shape_c = min_val.get("shape")
dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res = fake_quant_perchannel_grad_compute(dout_data, input_data, min_data, max_data,
quant_min, quant_max, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)
tensor_list = [dout_data, input_data, min_data, max_data, res]
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)

View File

@ -13,8 +13,7 @@
# limitations under the License.
# ============================================================================
"""FakeQuantWithMinMax op"""
"""FakeQuantPerLayer op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
@ -23,20 +22,16 @@ from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \
fake_quant_per_layer_op_info = TBERegOp("FakeQuantPerLayer") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("fake_quant_with_min_max_vars_ema.so") \
.binfile_name("fake_quant_per_layer.so") \
.compute_cost(10) \
.kernel_name("fake_quant_with_min_max_vars_ema") \
.kernel_name("fake_quant_per_layer") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("quant_delay", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
@ -49,15 +44,15 @@ fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \
@op_info_register(fake_quant_op_info)
def _fake_quant_tbe():
"""FakeQuantWithMinMax TBE register"""
def _fake_quant_per_layer_tbe():
"""FakeQuantPerLayer TBE register"""
return
@fusion_manager.register("fake_quant_with_min_max_vars_ema")
def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, quant_max,
kernel_name="correction_mul"):
"""FakeQuantWithMinMax"""
@fusion_manager.register("fake_quant_per_layer")
def fake_quant_per_layer_compute(x, min_val, max_val, y, quant_min, quant_max,
kernel_name="fake_quant_per_layer"):
"""FakeQuantPerLayer"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype)
@ -66,10 +61,13 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min,
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
# CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point
nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min)))
nudge_zp_ = te.lang.cce.vmin(
quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
@ -80,17 +78,19 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min,
# FakeQuant
input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x))
nudge_input = te.lang.cce.round(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale))
nudge_input_ = te.lang.cce.vdiv(
te.lang.cce.vsub(input_x, nudge_min), scale)
nudge_input = te.lang.cce.floor(te.lang.cce.vadds(nudge_input_, 0.5))
res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min)
return res
@util.check_input_type(dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y,
ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay,
kernel_name="fake_quant"):
"""FakeQuantWithMinMax"""
@util.check_input_type(dict, dict, dict, dict, bool, bool, int, str)
def fake_quant_per_layer(x, min_val, max_val, y,
symmetric, narrow_range, num_bits,
kernel_name="fake_quant_per_layer"):
"""FakeQuantPerLayer"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
@ -131,8 +131,8 @@ def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y,
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res = fake_quant_with_min_max_vars_ema_compute(input_data, min_data, max_data, y,
quant_min, quant_max, kernel_name)
res = fake_quant_per_layer_compute(input_data, min_data, max_data, y,
quant_min, quant_max, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""FakeQuantWithMinMaxGrad op"""
"""FakeQuantPerLayerGrad op"""
from functools import reduce as functools_reduce
import te.lang.cce
@ -26,15 +26,14 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
SHAPE_SIZE_LIMIT = 2147483648
D_TYPE = 'float32'
fake_quant_grad_op_info = TBERegOp("FakeQuantWithMinMaxGrad") \
fake_quant_per_layer_grad_op_info = TBERegOp("FakeQuantPerLayerGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_with_min_max_grad.so") \
.binfile_name("fake_quant_per_layer_grad.so") \
.compute_cost(10) \
.kernel_name("fake_quant_with_min_max_grad") \
.kernel_name("fake_quant_per_layer_grad") \
.partial_flag(True) \
.attr("num_bits", "optional", "int", "all") \
.attr("quant_delay", "optional", "int", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.input(0, "dout", None, "required", None) \
@ -57,7 +56,8 @@ def _less_compare_float32(data_x, data_y):
min_value = tvm.const(2 ** (-126), dtype=D_TYPE)
max_value = tvm.const(2 ** 62, dtype=D_TYPE)
factor_value = tvm.const(2 ** 2, dtype=D_TYPE)
data_zero = te.lang.cce.broadcast(tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE)
data_zero = te.lang.cce.broadcast(
tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE)
min_value_tensor = te.lang.cce.vadds(data_zero, min_value)
res_sub = te.lang.cce.vsub(data_y, data_x)
@ -71,16 +71,16 @@ def _less_compare_float32(data_x, data_y):
return res
@op_info_register(fake_quant_grad_op_info)
def _fake_quant_grad_tbe():
"""FakeQuantWithMinMaxGrad TBE register"""
@op_info_register(fake_quant_per_layer_grad_op_info)
def _fake_quant_per_layer_grad_tbe():
"""FakeQuantPerLayerGrad TBE register"""
return
@fusion_manager.register("fake_quant_with_min_max_grad")
def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
kernel_name="fake_quant_with_min_max_grad"):
"""FakeQuantWithMinMaxGrad"""
@fusion_manager.register("fake_quant_per_layer_grad")
def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
kernel_name="fake_quant_per_layer_grad"):
"""FakeQuantPerLayerGrad"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = tvm.const(quant_min, x.dtype)
@ -89,10 +89,13 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q
quant_max = te.lang.cce.broadcast(quant_max, shape_min)
# CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
# Nudge zero point
nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min)))
nudge_zp_ = te.lang.cce.vmin(
quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
nudge_min = te.lang.cce.broadcast(nudge_min, shape)
@ -106,11 +109,11 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q
return res
@util.check_input_type(dict, dict, dict, dict, dict, int, int, bool, bool, str)
def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx,
num_bits, quant_delay, symmetric, narrow_range,
kernel_name="fake_quant_with_min_max_grad"):
"""FakeQuantWithMinMaxGrad"""
@util.check_input_type(dict, dict, dict, dict, dict, int, bool, bool, str)
def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx,
num_bits, symmetric, narrow_range,
kernel_name="fake_quant_per_layer_grad"):
"""FakeQuantPerLayerGrad"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
@ -152,8 +155,8 @@ def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx,
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res = fake_quant_with_min_max_grad_compute(dout_data, input_data, min_data, max_data, quant_min,
quant_max, kernel_name)
res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data, max_data, quant_min,
quant_max, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)

View File

@ -20,10 +20,12 @@ from ..._checkparam import Rel
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ...common import dtype as mstype
__all__ = ["FakeQuantWithMinMax",
"FakeQuantWithMinMaxGrad",
"FakeQuantWithMinMaxPerChannel",
"FakeQuantWithMinMaxPerChannelGrad",
__all__ = ["FakeQuantPerLayer",
"FakeQuantPerLayerGrad",
"FakeQuantPerChannel",
"FakeQuantPerChannelGrad",
"FakeQuantMinMaxPerLayerUpdate",
"FakeQuantMinMaxPerChannelUpdate",
"BatchNormFold",
"BatchNormFoldGrad",
"CorrectionMul",
@ -36,11 +38,10 @@ __all__ = ["FakeQuantWithMinMax",
"BatchNormFold2_D",
"BatchNormFold2GradD",
"BatchNormFold2GradReduce",
"FakeQuantWithMinMaxUpdate",
]
class FakeQuantWithMinMax(PrimitiveWithInfer):
class FakeQuantPerLayer(PrimitiveWithInfer):
r"""
Simulate the quantize and dequantize operations in training time.
@ -67,49 +68,67 @@ class FakeQuantWithMinMax(PrimitiveWithInfer):
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
>>> output_tensor = P.FakeQuantPerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
def __init__(self,
num_bits=8,
ema=False,
ema_decay=0.999,
quant_delay=0,
symmetric=False,
narrow_range=False,
training=True):
"""init FakeQuantWithMinMax OP"""
"""init FakeQuantPerLayer OP"""
if num_bits not in self.support_quant_bit:
raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type('training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type(
'quant_delay', quant_delay, (int,), self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['out'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(
min_shape), 1, Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return x_type
class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
class FakeQuantPerLayerGrad(PrimitiveWithInfer):
r"""
Performs grad of FakeQuantWithMinMax operation.
Performs grad of FakeQuantPerLayerGrad operation.
Examples:
>>> fake_min_max_grad = P.FakeQuantWithMinMaxGrad()
>>> fake_min_max_grad = P.FakeQuantPerLayerGrad()
>>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
>>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
>>> _min = Tensor(np.array([-4]), mindspore.float32)
@ -119,32 +138,48 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False):
def __init__(self,
num_bits=8,
quant_delay=0,
symmetric=False,
narrow_range=False):
if num_bits not in self.support_quant_bit:
raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type(
'quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.init_prim_io_names(
inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name)
validator.check("dout shape", dout_shape, "x shape",
x_shape, Rel.EQ, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(
min_shape), 1, Rel.EQ, self.name)
return dout_shape
def infer_dtype(self, dout_type, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return dout_type
class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
class FakeQuantPerChannel(PrimitiveWithInfer):
r"""
Simulate the quantize and dequantize operations in training time base on per channel.
@ -168,53 +203,73 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
- Tensor, has the same type as input.
Examples:
>>> fake_quant = P.FakeQuantWithMinMaxPerChannel()
>>> fake_quant = P.FakeQuantPerChannel()
>>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32)
>>> _min = Tensor(np.linspace(-2, 2, 12).reshape(3, 2, 2), mindspore.float32)
>>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32)
>>> result = fake_quant(input_x, _min, _max)
"""
support_quant_bit = [4, 7, 8]
channel_axis = 0
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
training=True):
"""init FakeQuantWithMinMaxPerChannel OP"""
def __init__(self,
num_bits=8,
ema=False,
ema_decay=0.999,
quant_delay=0,
symmetric=False,
narrow_range=False,
training=True,
channel_axis=1):
"""init FakeQuantPerChannel OP"""
if num_bits not in self.support_quant_bit:
raise ValueError(f"For '{self.name}' Attr \'num_bits\' is not support.")
raise ValueError(
f"For '{self.name}' Attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type('training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type(
'quant_delay', quant_delay, (int,), self.name)
self.channel_axis = validator.check_integer(
'channel_axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check_integer("min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer("max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer(
"min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer(
"max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return x_type
class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
class FakeQuantPerChannelGrad(PrimitiveWithInfer):
r"""
Performs grad of FakeQuantWithMinMaxPerChannel operation.
Performs grad of FakeQuantPerChannelGrad operation.
Examples:
>>> fqmmpc_grad = P.FakeQuantWithMinMaxPerChannelGrad()
>>> fqmmpc_grad = P.FakeQuantPerChannelGrad()
>>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32)
>>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32)
>>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32)
@ -224,16 +279,29 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False):
"""init FakeQuantWithMinMaxPerChannel Fill"""
def __init__(self,
num_bits=8,
quant_delay=0,
symmetric=False,
narrow_range=False,
channel_axis=1):
"""init FakeQuantPerChannelGrad Fill"""
if num_bits not in self.support_quant_bit:
raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type(
'quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
validator.check("dout shape", dout_shape, "x shape", x_shape)
@ -242,10 +310,13 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
def infer_dtype(self, dout_type, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return dout_type
@ -744,17 +815,14 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
return dout_type, dout_type
class FakeQuantWithMinMaxUpdate(PrimitiveWithInfer):
class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
r"""
Simulate the quantize and dequantize operations in training time.
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
simulate aware quantize funcion. After delay step in training time begin simulate the aware
quantize funcion. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
@ -776,36 +844,121 @@ class FakeQuantWithMinMaxUpdate(PrimitiveWithInfer):
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True):
"""init FakeQuantWithMinMax OP"""
"""init FakeQuantMinMaxPerLayerUpdate OP"""
from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update
if num_bits not in self.support_quant_bit:
raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type('training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
channel_axis (int): Channel asis for per channel compute. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type