forked from mindspore-Ecosystem/mindspore
commit f110c7616b (parent 3e3cbbba0f)

fix perchannel num_channels not set bug and adjust quant.py params order
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Aware quantization."""
"""Quantization aware."""

from functools import partial
import numpy as np

@@ -27,9 +27,16 @@ from mindspore._checkparam import Validator as validator, Rel
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
import mindspore.context as context
from .normalization import BatchNorm2d
from .activation import get_activation
from ..cell import Cell
from . import conv, basic
from ..._checkparam import ParamValidator as validator
from ...ops.operations import _quant_ops as Q

__all__ = [
'Conv2dBnAct',
'DenseBnAct',
'FakeQuantWithMinMax',
'Conv2dBatchNormQuant',
'Conv2dQuant',
@@ -43,6 +50,165 @@ __all__ = [
]

class Conv2dBnAct(Cell):
r"""
A combination of convolution, Batchnorm and activation layers.

For a more detailed overview of the Conv2d op.

Args:
in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple]): The data type is int or tuple with 2 integers. Specifies the height
and width of the 2D convolution window. A single int means the value is for both height and width of
the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
width of the kernel.
stride (int): Specifies stride for all spatial dimensions with the same value. Value of stride should be
greater or equal to 1 but bounded by the height and width of the input. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
padding (int): Implicit paddings on both sides of the input. Default: 0.
dilation (int): Specifying the dilation rate to use for dilated convolution. If set to be :math:`k > 1`,
there will be :math:`k - 1` pixels skipped for each sampling location. Its value should be greater
or equal to 1 and bounded by the height and width of the input. Default: 1.
group (int): Split filter into groups, `in_channels` and `out_channels` should be
divisible by the number of groups. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
Initializer for more details. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
Initializer and string are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'.
batchnorm (bool): Specifies whether to use batchnorm. Default: None.
activation (string): Specifies activation type. The optional values are as follows:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.

Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

Examples:
>>> net = Conv2dBnAct(120, 240, 4, batchnorm=True, activation='ReLU')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> net(input).shape
(1, 240, 1024, 640)
"""

def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros',
batchnorm=None,
activation=None):
super(Conv2dBnAct, self).__init__()
self.conv = conv.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
has_bias,
weight_init,
bias_init)
self.has_bn = batchnorm is not None
self.has_act = activation is not None
self.batchnorm = batchnorm
if batchnorm is True:
self.batchnorm = BatchNorm2d(out_channels)
elif batchnorm is not None:
validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
self.activation = get_activation(activation)

def construct(self, x):
x = self.conv(x)
if self.has_bn:
x = self.batchnorm(x)
if self.has_act:
x = self.activation(x)
return x
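A minimal usage sketch for the cell above (hedged: it assumes Conv2dBnAct is re-exported as mindspore.nn.Conv2dBnAct, as the __all__ list suggests, and uses the lowercase activation key accepted by get_activation):

import numpy as np
import mindspore
from mindspore import Tensor, nn

net = nn.Conv2dBnAct(120, 240, 4, batchnorm=True, activation='relu')
x = Tensor(np.ones([1, 120, 32, 32]), mindspore.float32)
print(net(x).shape)  # (1, 240, 32, 32): pad_mode='same' and stride=1 keep the spatial size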


class DenseBnAct(Cell):
r"""
A combination of Dense, Batchnorm and activation layers.

For a more detailed overview of the Dense op.

Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
batchnorm (bool): Specifies whether to use batchnorm. Default: None.
activation (string): Specifies activation type. The optional values are as follows:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.

Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

Outputs:
Tensor of shape :math:`(N, out\_channels)`.

Examples:
>>> net = nn.DenseBnAct(3, 4)
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net(input)
"""

def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True,
batchnorm=None,
activation=None):
super(DenseBnAct, self).__init__()
self.dense = basic.Dense(
in_channels,
out_channels,
weight_init,
bias_init,
has_bias)
self.has_bn = batchnorm is not None
self.has_act = activation is not None
if batchnorm is True:
self.batchnorm = BatchNorm2d(out_channels)
elif batchnorm is not None:
validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
self.activation = get_activation(activation)

def construct(self, x):
x = self.dense(x)
if self.has_bn:
x = self.batchnorm(x)
if self.has_act:
x = self.activation(x)
return x
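A similar hedged sketch for DenseBnAct (assuming it is exported as mindspore.nn.DenseBnAct); construct() above runs Dense, then the optional BatchNorm, then the optional activation:

import numpy as np
import mindspore
from mindspore import Tensor, nn

net = nn.DenseBnAct(3, 4, activation='relu')
x = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
print(net(x).shape)  # (2, 4)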


class BatchNormFoldCell(Cell):
"""
Batch normalization folded.

@@ -105,20 +271,20 @@ class BatchNormFoldCell(Cell):

class FakeQuantWithMinMax(Cell):
r"""
Aware Quantization op. This OP provides a fake quantization observer function on data with min and max.
Quantization aware op. This OP provides a fake quantization observer function on data with min and max.

Args:
min_init (int, float): The initialized min value, with the dimension of channel or 1 (layer). Default: -6.
max_init (int, float): The initialized max value, with the dimension of channel or 1 (layer). Default: 6.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
ema (bool): Use Exponential Moving Average algorithm to update min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization by layer or channel. Default: False.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
channel_axis (int): Quantization by channel axis. Default: 1.
out_channels (int): declare the min and max channel size. Default: 1.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
num_channels (int): declare the min and max channel size. Default: 1.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

Inputs:
- **x** (Tensor) - The input of FakeQuantWithMinMax.

@@ -135,15 +301,15 @@ class FakeQuantWithMinMax(Cell):
def __init__(self,
min_init=-6,
max_init=6,
num_bits=8,
ema=False,
ema_decay=0.999,
per_channel=False,
channel_axis=1,
out_channels=1,
quant_delay=0,
num_channels=1,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
"""init FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMax, self).__init__()
self.min_init = min_init

@@ -152,7 +318,7 @@ class FakeQuantWithMinMax(Cell):
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.out_channels = out_channels
self.num_channels = num_channels
self.channel_axis = channel_axis
self.quant_delay = quant_delay
self.symmetric = symmetric

@@ -161,54 +327,54 @@ class FakeQuantWithMinMax(Cell):

# init tensor min and max for fake quant op
if self.per_channel:
min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(0, self.out_channels)]).astype(np.float32)
min_array = np.array([self.min_init] * self.num_channels).astype(np.float32)
max_array = np.array([self.max_init] * self.num_channels).astype(np.float32)
else:
min_array = np.array([self.min_init]).reshape(1).astype(np.float32)
max_array = np.array([self.max_init]).reshape(1).astype(np.float32)
min_array = np.array([self.min_init]).astype(np.float32)
max_array = np.array([self.max_init]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)

# init fake quant relative op
if per_channel:
quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
ema_fun = partial(Q.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis)
ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
else:
quant_fun = Q.FakeQuantPerLayer
ema_fun = Q.FakeQuantMinMaxPerLayerUpdate
ema_fun = Q.MinMaxUpdatePerLayer

self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay)
if self.is_ascend:
self.fake_quant = quant_fun(num_bits=self.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range)
self.fake_quant_train = quant_fun(num_bits=self.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range)
self.fake_quant_infer = self.fake_quant_train
else:
self.fake_quant = quant_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range)
self.ema_update = ema_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
symmetric=self.symmetric,
narrow_range=self.narrow_range)
quant_fun = partial(quant_fun,
ema=self.ema,
ema_decay=ema_decay,
num_bits=self.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
quant_delay=quant_delay)
self.fake_quant_train = quant_fun(training=True)
self.fake_quant_infer = quant_fun(training=False)

def extend_repr(self):
s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
'quant_delay={}, min_init={}, max_init={}'.format(
self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel,
self.channel_axis, self.out_channels, self.quant_delay, self.min_init, self.max_init)
self.channel_axis, self.num_channels, self.quant_delay, self.min_init, self.max_init)
return s

def construct(self, x):
if self.is_ascend and self.training:
if self.training:
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
out = self.fake_quant(x, min_up, max_up)
P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up)
out = self.fake_quant_train(x, self.minq, self.maxq)
else:
out = self.fake_quant(x, self.minq, self.maxq)
out = self.fake_quant_infer(x, self.minq, self.maxq)
return out
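The bug this commit fixes is visible above: in per-channel mode the min/max Parameters must have num_channels entries (previously out_channels was read, which the new argument order no longer guarantees). The numpy sketch below is a hedged restatement of what the observer simulates; it is not the TBE/GPU kernel, the real ops may round and place the zero point slightly differently, and symmetric mode is omitted for brevity:

import numpy as np

def fake_quant(x, min_val, max_val, num_bits=8, narrow_range=False):
    # simulated quantize-dequantize round trip given observed min/max
    quant_min = 1 if narrow_range else 0
    quant_max = 2 ** num_bits - 1
    scale = (max_val - min_val) / (quant_max - quant_min)
    zero_point = np.round(quant_min - min_val / scale)
    q = np.clip(np.round(x / scale + zero_point), quant_min, quant_max)
    return (q - zero_point) * scale

num_channels = 4                                # e.g. out_channels of a conv weight
min_array = np.array([-6.0] * num_channels)     # same shape logic as the cell above
max_array = np.array([6.0] * num_channels)
weight = np.random.randn(num_channels, 3, 3, 3).astype(np.float32)
# channel_axis=0: reshape min/max so they broadcast along the channel dimension
w_q = fake_quant(weight, min_array.reshape(-1, 1, 1, 1), max_array.reshape(-1, 1, 1, 1))
print(w_q.shape)  # (4, 3, 3, 3)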


@@ -225,8 +391,8 @@ class Conv2dBatchNormQuant(Cell):
stride (int): Specifies stride for all spatial dimensions with the same value.
pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
padding (int): Implicit paddings on both sides of the input. Default: 0.
eps (int): Parameter for the BatchNorm op. Default: 1e-5.
momentum (int): Parameter for the BatchNorm op. Default: 0.997.
eps (float): Parameter for the BatchNorm op. Default: 1e-5.
momentum (float): Parameter for the BatchNorm op. Default: 0.997.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
convolution kernel. Default: 'normal'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the

@@ -237,13 +403,13 @@ class Conv2dBatchNormQuant(Cell):
mean vector. Default: 'zeros'.
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
variance vector. Default: 'ones'.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
freeze_bn (int): Quantization freeze BatchNorm op according to the global step. Default: 100000.
fake (bool): Whether the Conv2dBatchNormQuant cell adds the FakeQuantWithMinMax op. Default: True.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
per_channel (bool): Parameter passed through to FakeQuantWithMinMax. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
freeze_bn (int): Quantization freeze BatchNorm op according to the global step. Default: 100000.

Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

@@ -274,13 +440,13 @@ class Conv2dBatchNormQuant(Cell):
gamma_init='ones',
mean_init='zeros',
var_init='ones',
quant_delay=0,
freeze_bn=100000,
fake=True,
num_bits=8,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0,
freeze_bn=100000):
"""init Conv2dBatchNormQuant layer"""
super(Conv2dBatchNormQuant, self).__init__()
self.in_channels = in_channels

@@ -304,8 +470,8 @@ class Conv2dBatchNormQuant(Cell):

# initialize convolution op and Parameter
if context.get_context('device_target') == "Ascend" and group > 1:
validator.check_integer('group', group, in_channels, Rel.EQ, 'Conv2dBatchNormQuant')
validator.check_integer('group', group, out_channels, Rel.EQ, 'Conv2dBatchNormQuant')
validator.check_integer('group', group, in_channels, Rel.EQ)
validator.check_integer('group', group, out_channels, Rel.EQ)
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=pad_mode,

@@ -337,12 +503,13 @@ class Conv2dBatchNormQuant(Cell):
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
out_channels=out_channels,
channel_axis=channel_axis,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
self.correct_mul = Q.CorrectionMul(channel_axis)
if context.get_context('device_target') == "Ascend":
@@ -416,11 +583,11 @@ class Conv2dQuant(Cell):
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
per_channel (bool): Parameter passed through to FakeQuantWithMinMax. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

@@ -447,11 +614,11 @@ class Conv2dQuant(Cell):
has_bias=False,
weight_init='normal',
bias_init='zeros',
quant_delay=0,
num_bits=8,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(Conv2dQuant, self).__init__()
if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size)

@@ -487,12 +654,13 @@ class Conv2dQuant(Cell):
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
out_channels=out_channels,
channel_axis=0,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)

def construct(self, x):
weight = self.fake_quant_weight(self.weight)

@@ -526,11 +694,11 @@ class DenseQuant(Cell):
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
per_channel (bool): Parameter passed through to FakeQuantWithMinMax. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

@@ -552,11 +720,11 @@ class DenseQuant(Cell):
bias_init='zeros',
has_bias=True,
activation=None,
num_bits=8,
quant_delay=0,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(DenseQuant, self).__init__()
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)

@@ -586,12 +754,13 @@ class DenseQuant(Cell):
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
out_channels=out_channels,
channel_axis=0,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)

def construct(self, x):
"""Use operators to construct the Dense layer."""

@@ -615,18 +784,28 @@ class DenseQuant(Cell):
return str_info
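A hedged illustration of why the weight observers above pass num_channels=out_channels together with channel_axis=0: a dense or convolution weight keeps one (min, max) pair per output channel, so the observer needs exactly that many slots:

import numpy as np

out_channels, in_channels = 4, 3
weight = np.random.randn(out_channels, in_channels).astype(np.float32)
w_min = weight.min(axis=1)   # per-channel statistics along channel_axis=0
w_max = weight.max(axis=1)
assert w_min.shape == (out_channels,)   # matches num_channels=out_channels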


class ReLUQuant(Cell):
class _QuantActivation(Cell):
r"""
Base class for Quant activation function. Add Fake Quant OP after activation OP.
"""

def get_origin(self):
raise NotImplementedError
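A hedged sketch of how a conversion pass might use get_origin() when exporting a trained network: every *Quant activation can be swapped back for the float op it wraps. The traversal below is illustrative only; it assumes Cell.name_cells() and attribute re-assignment are enough, while MindSpore's real export utilities do more:

def strip_fake_quant(cell):
    # replace each quantization-aware activation with its original float op
    for name, subcell in cell.name_cells().items():
        if isinstance(subcell, _QuantActivation):
            setattr(cell, name, subcell.get_origin())
        else:
            strip_fake_quant(subcell)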


class ReLUQuant(_QuantActivation):
r"""
ReLUQuant activation function. Add Fake Quant OP after ReLU OP.

For a more detailed overview of the ReLU op.

Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

Inputs:
- **x** (Tensor) - The input of ReLUQuant.

@@ -641,20 +820,22 @@ class ReLUQuant(Cell):
"""

def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(ReLUQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.relu = P.ReLU()

def construct(self, x):

@@ -662,8 +843,11 @@ class ReLUQuant(Cell):
x = self.fake_quant_act(x)
return x

def get_origin(self):
return self.relu

class ReLU6Quant(Cell):

class ReLU6Quant(_QuantActivation):
r"""
ReLU6Quant activation function.

@@ -672,11 +856,12 @@ class ReLU6Quant(Cell):
For a more detailed overview of the ReLU6 op.

Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

Inputs:
- **x** (Tensor) - The input of ReLU6Quant.

@@ -691,20 +876,22 @@ class ReLU6Quant(Cell):
"""

def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(ReLU6Quant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.relu6 = P.ReLU6()

def construct(self, x):

@@ -712,19 +899,23 @@ class ReLU6Quant(Cell):
x = self.fake_quant_act(x)
return x

def get_origin(self):
return self.relu6

class HSwishQuant(Cell):

class HSwishQuant(_QuantActivation):
r"""
HSwishQuant activation function. Add Fake Quant OP before and after HSwish OP.

For a more detailed overview of the HSwish op.

Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

Inputs:
- **x** (Tensor) - The input of HSwishQuant.

@@ -739,28 +930,31 @@ class HSwishQuant(Cell):
"""

def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(HSwishQuant, self).__init__()
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.act = P.HSwish()

def construct(self, x):

@@ -769,19 +963,23 @@ class HSwishQuant(Cell):
x = self.fake_quant_act_after(x)
return x

def get_origin(self):
return self.act

class HSigmoidQuant(Cell):

class HSigmoidQuant(_QuantActivation):
r"""
HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP.

For a more detailed overview of the HSigmoid op.

Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

Inputs:
- **x** (Tensor) - The input of HSigmoidQuant.

@@ -796,27 +994,31 @@ class HSigmoidQuant(Cell):
"""

def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(HSigmoidQuant, self).__init__()
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.act = P.HSigmoid()

def construct(self, x):

@@ -825,6 +1027,9 @@ class HSigmoidQuant(Cell):
x = self.fake_quant_act_after(x)
return x

def get_origin(self):
return self.act

class TensorAddQuant(Cell):
r"""

@@ -833,11 +1038,12 @@ class TensorAddQuant(Cell):
For a more detailed overview of the TensorAdd op.

Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

Inputs:
- **x** (Tensor) - The input of TensorAddQuant.

@@ -853,20 +1059,22 @@ class TensorAddQuant(Cell):
"""

def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(TensorAddQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.add = P.TensorAdd()

def construct(self, x1, x2):

@@ -882,11 +1090,11 @@ class MulQuant(Cell):
For a more detailed overview of the Mul op.

Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

Inputs:
- **x** (Tensor) - The input of MulQuant.

@@ -897,23 +1106,99 @@ class MulQuant(Cell):
"""

def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(MulQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.mul = P.Mul()

def construct(self, x1, x2):
x = self.mul(x1, x2)
x = self.fake_quant_act(x)
return x

class QuantBlock(Cell):
r"""
A quant block of Conv/Dense, activation layer for Ascend deploy.

Calculate Conv or Dense in Int8, with AscendQuant and AscendDeQuant.

Note:
This block is only for deploy, and not trainable.

Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
batchnorm (bool): Specifies whether to use batchnorm. Default: None.
activation (string): Specifies activation type. The optional values are as follows:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.

Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

Outputs:
Tensor of shape :math:`(N, out\_channels)`.

Examples:
>>> net = nn.Dense(3, 4)
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net(input)
"""

def __init__(self,
core_op,
weight,
quant_op,
dequant_op,
dequant_scale,
bias=None,
activation=None):
super(QuantBlock, self).__init__()
self.core_op = core_op
self.weight = weight
self.quant = quant_op
self.dequant = dequant_op
self.dequant_scale = dequant_scale
self.bias = bias
self.bias_add = P.BiasAdd()
self.has_bias = bias is not None
self.activation = activation
self.has_act = activation is not None

def construct(self, x):
x = self.quant(x)
x = self.core_op(x, self.weight)
if self.has_bias:
x = self.bias_add(x, self.bias)
if self.has_act:
x = self.activation(x)
x = self.dequant(x, self.dequant_scale)
return x

def extend_repr(self):
str_info = f'quant={self.quant}, core_op={type(self.core_op)}'
if self.has_bias:
str_info = str_info + f', bias={self.bias}'
if self.has_act:
str_info = str_info + f', activation={self.activation}'
str_info = str_info + f', dequant={self.dequant}'
return str_info
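A hedged numpy sketch of the deploy-time dataflow QuantBlock wires together: quantize the activation to int8, run the core op with integer accumulation, then dequantize with a per-tensor scale. The real AscendQuant/AscendDeQuant kernels differ in detail (offsets, per-channel scales, fused requantization):

import numpy as np

def quant_like(x, scale, offset=0):
    return np.clip(np.round(x / scale) + offset, -128, 127).astype(np.int8)

x = np.random.randn(2, 3).astype(np.float32)
w_q = np.random.randint(-128, 128, size=(3, 4), dtype=np.int8)
act_scale, deq_scale = 0.05, 0.001
x_q = quant_like(x, act_scale)                       # self.quant
acc = x_q.astype(np.int32) @ w_q.astype(np.int32)    # self.core_op, int32 accumulation
y = acc.astype(np.float32) * deq_scale               # self.dequant(x, self.dequant_scale)
print(y.shape)  # (2, 4)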

@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================

"""Generate bprop for aware quantization ops"""
"""Generate bprop for quantization aware ops"""

from .. import operations as P
from ..operations import _quant_ops as Q

@@ -133,9 +133,9 @@ def get_bprop_batchnorm_fold2_(self):
return bprop


@bprop_getters.register(Q.FakeQuantMinMaxPerLayerUpdate)
@bprop_getters.register(Q.MinMaxUpdatePerLayer)
def get_bprop_fakequant_with_minmax_per_layer_update(self):
"""Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend"""
"""Generate bprop for MinMaxUpdatePerLayer for Ascend"""

def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)

@@ -143,9 +143,9 @@ def get_bprop_fakequant_with_minmax_per_layer_update(self):
return bprop


@bprop_getters.register(Q.FakeQuantMinMaxPerChannelUpdate)
@bprop_getters.register(Q.MinMaxUpdatePerChannel)
def get_bprop_fakequant_with_minmax_per_channel_update(self):
"""Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend"""
"""Generate bprop for MinMaxUpdatePerChannel for Ascend"""

def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)

@@ -1,4 +1,3 @@

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");

@@ -14,7 +13,7 @@
# limitations under the License.
# ============================================================================

"""FakeQuantMinMaxPerChannelUpdate op"""
"""MinMaxUpdatePerChannel op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager

@@ -22,20 +21,15 @@ from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType


fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \
minmax_update_perchannel_op_info = TBERegOp("MinMaxUpdatePerChannel") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_min_max_per_channel_update.so") \
.binfile_name("minmax_update_perchannel.so") \
.compute_cost(10) \
.kernel_name("fake_quant_min_max_per_channel_update") \
.kernel_name("minmax_update_perchannel") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \

@@ -47,43 +41,46 @@ fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChan
.get_op_info()


@op_info_register(fake_quant_min_max_per_channel_update_op_info)
def _fake_quant_min_max_per_channel_update_tbe():
"""FakeQuantPerChannelUpdate TBE register"""
@op_info_register(minmax_update_perchannel_op_info)
def _minmax_update_perchannel_tbe():
"""MinMaxUpdatePerChannel TBE register"""
return


@fusion_manager.register("fake_quant_min_max_per_channel_update")
def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val,
ema, ema_decay, quant_min, quant_max, training, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerChannelUpdate compute"""
@fusion_manager.register("minmax_update_perchannel")
def minmax_update_perchannel_compute(x, min_val, max_val,
ema, ema_decay, channel_axis):
"""MinMaxUpdatePerChannel compute"""
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)

if not ema:
ema_decay = 0.0
if training:
# CalMinMax

# CalMinMax
if channel_axis == 0:
axis = [1, 2, 3, 4]
else:
axis = [0, 2, 3]
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)

x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)

return [min_val, max_val]
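A hedged numpy restatement of the compute above (the kernel itself works on the 5-D fractal layout, which is why channel_axis == 0 reduces over [1, 2, 3, 4]): per-channel min/max are reduced over the non-channel axes, blended with the running values by EMA, then clamped so that min <= 0 <= max:

import numpy as np

def minmax_update_perchannel_np(x, min_val, max_val, ema, ema_decay, channel_axis=1):
    if not ema:
        ema_decay = 0.0
    axes = tuple(a for a in range(x.ndim) if a != channel_axis)
    x_min = x.min(axis=axes)
    x_max = x.max(axis=axes)
    min_val = ema_decay * min_val + (1 - ema_decay) * x_min
    max_val = ema_decay * max_val + (1 - ema_decay) * x_max
    return np.minimum(min_val, 0), np.maximum(max_val, 0)

x = np.random.randn(3, 16, 5, 5).astype(np.float32)
mn, mx = minmax_update_perchannel_np(x, np.zeros(16), np.zeros(16), ema=True, ema_decay=0.999)
print(mn.shape, mx.shape)  # (16,) (16,)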


@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerLayer op"""
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, int, str)
def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
ema, ema_decay, channel_axis,
kernel_name="minmax_update_perchannel"):
"""MinMaxUpdatePerChannel op"""
x_shape = x.get("ori_shape")
x_format = x.get("format")
x_dtype = x.get("dtype")

@@ -112,21 +109,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)

if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
if channel_axis_ == 0:
shape_c = min_val.get("ori_shape")
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1

shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, channel_axis_, kernel_name)
res_list = minmax_update_perchannel_compute(input_data, min_data, max_data,
ema, ema_decay, channel_axis_)

with tvm.target.cce():
sch = generic.auto_schedule(res_list)

@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================

"""FakeQuantMinMaxPerLayerUpdate op"""
"""MinMaxUpdatePerLayer op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm

@@ -22,20 +22,15 @@ from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType


fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
minmax_update_perlayer_op_info = TBERegOp("MinMaxUpdatePerLayer") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_minmax_update.so") \
.binfile_name("minmax_update_perlayer.so") \
.compute_cost(10) \
.kernel_name("fake_quant_minmax_update") \
.kernel_name("minmax_update_perlayer") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \

@@ -46,44 +41,42 @@ fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
.get_op_info()


@op_info_register(fake_quant_minmax_update_op_info)
def _fake_quant_minmax_update_tbe():
"""FakeQuantMinMaxPerLayerUpdate TBE register"""
@op_info_register(minmax_update_perlayer_op_info)
def _minmax_update_perlayer_tbe():
"""MinMaxUpdatePerLayer TBE register"""
return


@fusion_manager.register("fake_quant_minmax_update")
def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
kernel_name="fake_quant_minmax_update"):
"""FakeQuantMinMaxPerLayerUpdate compute"""
@fusion_manager.register("minmax_update_perlayer")
def minmax_update_perlayer_compute(x, min_val, max_val, ema, ema_decay):
"""MinMaxUpdatePerLayer compute"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
if not ema:
ema_decay = 0.0
if training:
# CalMinMax
axis = tuple(range(len(shape)))
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)

# CalMinMax
axis = tuple(range(len(shape)))
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)

return [min_val, max_val]


@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str)
def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits,
kernel_name="fake_quant_minmax_update"):
"""FakeQuantPerLayer op"""
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, str)
def minmax_update_perlayer(x, min_val, max_val, min_up, max_up,
ema, ema_decay, kernel_name="minmax_update_perlayer"):
"""MinMaxUpdatePerLayer op"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")

@@ -112,20 +105,10 @@ def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
shape_min, _, _ = util.produce_shapes(min_shape, input_shape)

if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1

input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, kernel_name)
res_list = minmax_update_perlayer_compute(input_data, min_data, max_data, ema, ema_decay)

with tvm.target.cce():
sch = generic.auto_schedule(res_list)

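The symmetric/narrow_range arithmetic deleted from both min/max-update kernels above now lives only with the FakeQuant ops; as a reference, the removed branch reduces to the following sketch (not a new API):

def quant_range(num_bits=8, symmetric=False, narrow_range=False):
    if symmetric:
        quant_min = -2 ** (num_bits - 1)
        quant_max = 2 ** (num_bits - 1) - 1
    else:
        quant_min = 0
        quant_max = 2 ** num_bits - 1
    if narrow_range:
        quant_min += 1
    return quant_min, quant_max

print(quant_range())                   # (0, 255)
print(quant_range(symmetric=True))     # (-128, 127)
print(quant_range(narrow_range=True))  # (1, 255)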
@ -21,12 +21,12 @@ from ..._checkparam import Rel
|
|||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||
from ...common import dtype as mstype
|
||||
|
||||
__all__ = ["FakeQuantPerLayer",
|
||||
__all__ = ["MinMaxUpdatePerLayer",
|
||||
"MinMaxUpdatePerChannel",
|
||||
"FakeQuantPerLayer",
|
||||
"FakeQuantPerLayerGrad",
|
||||
"FakeQuantPerChannel",
|
||||
"FakeQuantPerChannelGrad",
|
||||
"FakeQuantMinMaxPerLayerUpdate",
|
||||
"FakeQuantMinMaxPerChannelUpdate",
|
||||
"BatchNormFold",
|
||||
"BatchNormFoldGrad",
|
||||
"CorrectionMul",
|
||||
|
@ -36,23 +36,140 @@ __all__ = ["FakeQuantPerLayer",
|
|||
"BatchNormFold2Grad",
|
||||
"BatchNormFoldD",
|
||||
"BatchNormFoldGradD",
|
||||
"BNTrainingReduce",
|
||||
"BatchNormFold2_D",
|
||||
"BatchNormFold2GradD",
|
||||
"BatchNormFold2GradReduce",
|
||||
"BatchNormFold2GradReduce"
|
||||
]
|
||||
|
||||
|
||||
class MinMaxUpdatePerLayer(PrimitiveWithInfer):
|
||||
r"""
|
||||
Update min and max per layer.
|
||||
|
||||
Args:
|
||||
ema (bool): Use EMA algorithm update value min and max. Default: False.
|
||||
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
|
||||
- **min** (Tensor) : Value of the min range of the input data x.
|
||||
- **max** (Tensor) : Value of the max range of the input data x.
|
||||
|
||||
Outputs:
|
||||
- Tensor: Simulate quantize tensor of x.
|
||||
|
||||
Examples:
|
||||
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
|
||||
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
|
||||
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
|
||||
>>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
|
||||
"""
|
||||
support_quant_bit = [4, 7, 8]
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, ema=False, ema_decay=0.999):
|
||||
"""init FakeQuantMinMaxPerLayerUpdate OP"""
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
from mindspore.ops._op_impl._custom_op import minmax_update_perlayer
|
||||
if ema and not ema_decay:
|
||||
raise ValueError(
|
||||
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
|
||||
|
||||
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
|
||||
self.ema_decay = validator.check_number_range(
|
||||
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.init_prim_io_names(inputs=['x', 'min', 'max'],
|
||||
outputs=['min_up', 'max_up'])
|
||||
|
||||
def infer_shape(self, x_shape, min_shape, max_shape):
|
||||
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
|
||||
validator.check("min shape", min_shape, "max shape",
|
||||
max_shape, Rel.EQ, self.name)
|
||||
validator.check_integer("min shape", len(
|
||||
min_shape), 1, Rel.EQ, self.name)
|
||||
return min_shape, max_shape
|
||||
|
||||
def infer_dtype(self, x_type, min_type, max_type):
|
||||
valid_types = (mstype.float16, mstype.float32)
|
||||
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
|
||||
validator.check_tensor_type_same(
|
||||
{"min": min_type}, valid_types, self.name)
|
||||
validator.check_tensor_type_same(
|
||||
{"max": max_type}, valid_types, self.name)
|
||||
return min_type, max_type
|
||||
|
||||
|
||||
class MinMaxUpdatePerChannel(PrimitiveWithInfer):
    r"""
    Update the running min and max values per channel.

    Args:
        ema (bool): Whether to update min and max with an EMA algorithm. Default: False.
        ema_decay (float): EMA algorithm decay parameter. Default: 0.999.
        channel_axis (int): Channel axis for the per-channel computation. Default: 1.

    Inputs:
        - **x** (Tensor) - The input tensor, of data type float16 or float32.
        - **min** (Tensor) - Current minimum value of the input data `x`, one value per channel.
        - **max** (Tensor) - Current maximum value of the input data `x`, one value per channel.

    Outputs:
        - **min_up** (Tensor) - Updated minimum value, with the same shape as `min`.
        - **max_up** (Tensor) - Updated maximum value, with the same shape as `max`.

    Examples:
        >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
        >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
        >>> min_up, max_up = MinMaxUpdatePerChannel(channel_axis=1)(x, min, max)
    """
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
        """init MinMaxUpdatePerChannel OP for Ascend"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
        if ema and not ema_decay:
            raise ValueError(
                f"For '{self.name}' attr 'ema' and 'ema_decay' should be set together.")

        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.ema_decay = validator.check_number_range(
            'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
        self.channel_axis = validator.check_integer(
            'channel axis', channel_axis, 0, Rel.GE, self.name)
        self.init_prim_io_names(
            inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
        validator.check("min shape", min_shape, "max shape",
                        max_shape, Rel.EQ, self.name)
        validator.check_integer("min shape", len(
            min_shape), 1, Rel.EQ, self.name)
        return min_shape, max_shape

    def infer_dtype(self, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same(
            {"x": x_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"min": min_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"max": max_type}, valid_types, self.name)
        return min_type, max_type

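# A pure-NumPy sketch of the per-channel variant (assumed semantics, reusing the
# module-level `numpy as np` import): reduce over every axis except `channel_axis`,
# then optionally blend with the running values using EMA, mirroring the arguments
# of MinMaxUpdatePerChannel above. The helper name is illustrative only.
def _minmax_update_perchannel_sketch(x, min_val, max_val, ema=False, ema_decay=0.999, channel_axis=1):
    """Return per-channel (min, max) arrays with one entry per channel (sketch)."""
    reduce_axes = tuple(axis for axis in range(x.ndim) if axis != channel_axis)
    batch_min = np.minimum(x.min(axis=reduce_axes), 0.0)
    batch_max = np.maximum(x.max(axis=reduce_axes), 0.0)
    if ema:
        min_up = ema_decay * min_val + (1.0 - ema_decay) * batch_min
        max_up = ema_decay * max_val + (1.0 - ema_decay) * batch_max
    else:
        min_up, max_up = batch_min, batch_max
    return min_up, max_up
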
class FakeQuantPerLayer(PrimitiveWithInfer):
    r"""
    Simulate the quantize and dequantize operations at training time.

    Args:
        num_bits (int): Number of bits for quantization aware training. Default: 8.
        ema (bool): Whether to update min and max with an EMA algorithm. Default: False.
        ema_decay (float): EMA algorithm decay parameter. Default: 0.999.
        quant_delay (int): Quantization delay parameter. Before the delay step of training the
            simulated quantization function is not applied; after the delay step it is. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.

@@ -334,7 +451,7 @@ class BatchNormFold(PrimitiveWithInfer):
    Batch normalization folded.

    Args:
        momentum (float): Momentum value, must be in [0, 1]. Default: 0.9.
        epsilon (float): A small float added to the denominator to avoid division by zero:
            1e-5 if dtype is float32, else 1e-3. Default: 1e-5.
        is_training (bool): Set to True in training mode, otherwise False. Default: True.

@@ -366,7 +483,7 @@ class BatchNormFold(PrimitiveWithInfer):
    channel_axis = 1

    @prim_attr_register
    def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
        """init batch norm fold layer"""
        self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
        self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
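# A pure-NumPy sketch of the folding that BatchNormFold prepares for (assumed
# formulas, reusing the module-level `numpy as np` import): the batch-norm scale is
# folded into the convolution weight and the shift into its bias, so quantization
# sees the weights the fused inference graph will actually use. The helper name and
# NCHW layout are assumptions for illustration.
def _batchnorm_fold_sketch(weight, gamma, beta, mean, variance, epsilon=1e-5):
    """Fold batch-norm statistics into an NCHW convolution weight and bias (sketch)."""
    std = np.sqrt(variance + epsilon)
    w_fold = weight * (gamma / std).reshape(-1, 1, 1, 1)  # scale each output channel
    b_fold = beta - gamma * mean / std
    return w_fold, b_fold
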
@@ -731,32 +848,6 @@ class BatchNormFoldGradD(PrimitiveWithInfer):
        return x_type


class BNTrainingReduce(PrimitiveWithInfer):
    """
    Reduce sum along axes [0, 2, 3].

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`.

    Outputs:
        - **x_sum** (Tensor) - Per-channel sum of `x`, of shape :math:`(C,)`.
        - **x_square_sum** (Tensor) - Per-channel sum of the squares of `x`, of shape :math:`(C,)`.
    """

    @prim_attr_register
    def __init__(self):
        """init _BNTrainingReduce layer"""
        self.init_prim_io_names(inputs=['x'],
                                outputs=['x_sum', 'x_square_sum'])

    def infer_shape(self, x_shape):
        return [x_shape[1]], [x_shape[1]]

    def infer_dtype(self, x_type):
        return x_type, x_type
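# A pure-NumPy sketch of the reduction performed above (reusing the module-level
# `numpy as np` import): per-channel sum and squared sum over the N, H and W axes of
# an NCHW tensor, matching the (C,) shapes returned by infer_shape.
def _bn_training_reduce_sketch(x):
    """Return (x_sum, x_square_sum), each of shape (C,) (sketch)."""
    x_sum = x.sum(axis=(0, 2, 3))
    x_square_sum = np.square(x).sum(axis=(0, 2, 3))
    return x_sum, x_square_sum
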


class BatchNormFold2_D(PrimitiveWithInfer):
    """
    Scale the bias with a correction factor to the long term statistics
@@ -859,153 +950,3 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
    def infer_dtype(self, dout_type, x_type):
        validator.check("dout type", dout_type, "x type", x_type)
        return dout_type, dout_type

class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
    r"""
    Update min and max value for the fake quant per-layer op.

    Args:
        num_bits (int): Number of bits for quantization. Default: 8.
        ema (bool): Whether to update min and max with an EMA algorithm. Default: False.
        ema_decay (float): EMA algorithm decay parameter. Default: 0.999.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.
        training (bool): Whether the network is being trained. Default: True.

    Inputs:
        - **x** (Tensor) - The input tensor, of data type float16 or float32.
        - **min** (Tensor) - Current minimum value of the input data `x`.
        - **max** (Tensor) - Current maximum value of the input data `x`.

    Outputs:
        - **min_up** (Tensor) - Updated minimum value.
        - **max_up** (Tensor) - Updated maximum value.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> min_up, max_up = FakeQuantMinMaxPerLayerUpdate(num_bits=8)(input_tensor, min_tensor, max_tensor)
    """
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
                 training=True):
        """init FakeQuantMinMaxPerLayerUpdate OP"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' attr 'num_bits' is not supported.")
        if ema and not ema_decay:
            raise ValueError(
                f"For '{self.name}' attr 'ema' and 'ema_decay' should be set together.")

        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.training = validator.check_value_type(
            'training', training, (bool,), self.name)
        self.ema_decay = validator.check_number_range(
            'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
        self.num_bits = validator.check_integer(
            'num_bits', num_bits, 0, Rel.GT, self.name)
        self.init_prim_io_names(inputs=['x', 'min', 'max'],
                                outputs=['min_up', 'max_up'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
        validator.check("min shape", min_shape, "max shape",
                        max_shape, Rel.EQ, self.name)
        validator.check_integer("min shape", len(
            min_shape), 1, Rel.EQ, self.name)
        return min_shape, max_shape

    def infer_dtype(self, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"min": min_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"max": max_type}, valid_types, self.name)
        return min_type, max_type

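# A pure-NumPy sketch of how a [min, max] range of the kind maintained above is
# typically consumed by simulated quantization (assumed semantics; the module-level
# `numpy as np` import is reused, and `quant_delay`/EMA behaviour are not modelled):
# derive a scale and zero point from the range, quantize, then dequantize back to float.
def _fake_quant_perlayer_sketch(x, min_val, max_val, num_bits=8, symmetric=False, narrow_range=False):
    """Quantize `x` to `num_bits` integers and map it back to float (sketch)."""
    quant_min = 1 if narrow_range else 0
    quant_max = (1 << num_bits) - 1
    if symmetric:
        bound = max(abs(min_val), abs(max_val))
        min_val, max_val = -bound, bound
    scale = max((max_val - min_val) / (quant_max - quant_min), 1e-12)
    zero_point = np.clip(np.round(quant_min - min_val / scale), quant_min, quant_max)
    quantized = np.clip(np.round(x / scale) + zero_point, quant_min, quant_max)
    return (quantized - zero_point) * scale
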
class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
    r"""
    Update min and max value for the fake quant per-channel op.

    Args:
        num_bits (int): Number of bits for quantization. Default: 8.
        ema (bool): Whether to update min and max with an EMA algorithm. Default: False.
        ema_decay (float): EMA algorithm decay parameter. Default: 0.999.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.
        training (bool): Whether the network is being trained. Default: True.
        channel_axis (int): Channel axis for the per-channel computation. Default: 1.

    Inputs:
        - **x** (Tensor) - The input tensor, of data type float16 or float32.
        - **min** (Tensor) - Current minimum value of the input data `x`, one value per channel.
        - **max** (Tensor) - Current maximum value of the input data `x`, one value per channel.

    Outputs:
        - **min_up** (Tensor) - Updated minimum value.
        - **max_up** (Tensor) - Updated maximum value.

    Examples:
        >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
        >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
        >>> min_up, max_up = FakeQuantMinMaxPerChannelUpdate(num_bits=8)(x, min, max)
    """
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
                 training=True, channel_axis=1):
        """init FakeQuantMinMaxPerChannelUpdate OP for Ascend"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' attr 'num_bits' is not supported.")
        if ema and not ema_decay:
            raise ValueError(
                f"For '{self.name}' attr 'ema' and 'ema_decay' should be set together.")

        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.training = validator.check_value_type(
            'training', training, (bool,), self.name)
        self.ema_decay = validator.check_number_range(
            'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
        self.num_bits = validator.check_integer(
            'num_bits', num_bits, 0, Rel.GT, self.name)
        self.channel_axis = validator.check_integer(
            'channel axis', channel_axis, 0, Rel.GE, self.name)
        self.init_prim_io_names(
            inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
        validator.check("min shape", min_shape, "max shape",
                        max_shape, Rel.EQ, self.name)
        validator.check_integer("min shape", len(
            min_shape), 1, Rel.EQ, self.name)
        return min_shape, max_shape

    def infer_dtype(self, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same(
            {"x": x_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"min": min_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"max": max_type}, valid_types, self.name)
        return min_type, max_type