forked from mindspore-Ecosystem/mindspore

fix quantization aware training auto create graph bug

parent 6ef1a731db
commit c831d3eb60
@@ -81,6 +81,7 @@ class Cell:
         self.enable_hook = False
         self._bprop_debug = False
         self._is_run = False
+        self.cell_type = None

     @property
     def is_run(self):
@@ -140,6 +141,14 @@ class Cell:
         for cell_name, cell in cells_name:
             cell._param_prefix = cell_name

+    def update_cell_type(self, cell_type):
+        """
+        Update the current cell type, mainly used to identify whether this is a
+        quantization aware training network.
+
+        After being invoked, the cell type is set to 'cell_type'.
+        """
+        self.cell_type = cell_type
+
     @cell_init_args.setter
     def cell_init_args(self, value):
         if not isinstance(value, str):
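Note: a minimal sketch of how a conversion pass might use the new hook, assuming only the Cell API shown above; the "quant" tag value is illustrative, not a defined constant.

def tag_quant_cells(net):
    # cells_and_names() walks every sub-cell of a mindspore.nn.Cell;
    # update_cell_type is the method added in this hunk.
    for _, cell in net.cells_and_names():
        cell.update_cell_type("quant")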
@@ -17,6 +17,7 @@ from mindspore import log as logger
 from mindspore.ops import operations as P
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
+from mindspore._checkparam import ParamValidator as validator, Rel
 from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative
 from mindspore._extends import cell_attr_register
 from ..cell import Cell
@@ -397,3 +398,150 @@ class Conv2dTranspose(_Conv):
                                 self.weight,
                                 self.bias)
         return s
+
+
+class DepthwiseConv2d(Cell):
+    r"""
+    2D depthwise convolution layer.
+
+    Applies a 2D depthwise convolution over an input tensor which is typically of shape
+    :math:`(N, C_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel number.
+    For each batch of shape :math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as:
+
+    .. math::
+
+        out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,
+
+    where :math:`ccor` is the cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
+    from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to the :math:`i`-th channel of the :math:`j`-th
+    filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
+    of the kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and
+    :math:`\text{ks_w}` are the height and width of the convolution kernel. The full kernel has shape
+    :math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group number
+    to split the input in the channel dimension.
+
+    If 'pad_mode' is set to "valid", the output height and width will be
+    :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
+    (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
+    :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
+    (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.
+
+    The first introduction can be found in the paper `Gradient Based Learning Applied to Document Recognition
+    <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
+
+    Args:
+        in_channels (int): The number of input channels :math:`C_{in}`.
+        out_channels (int): The number of output channels :math:`C_{out}`.
+        kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the height
+            and width of the 2D convolution window. A single int means the value is for both the height and the
+            width of the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
+            width of the kernel.
+        stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
+            the height and width of movement are both strides, or a tuple of two int numbers that
+            represent height and width of movement respectively. Default: 1.
+        pad_mode (str): Specifies padding mode. The optional values are
+            "same", "valid", "pad". Default: "same".
+
+            - same: Adopts the way of completion. Output height and width will be the same as the input.
+              The total number of padding will be calculated for the horizontal and vertical
+              directions and evenly distributed to top and bottom, left and right if possible. Otherwise, the
+              last extra padding will be done from the bottom and the right side. If this mode is set, `padding`
+              must be 0.
+
+            - valid: Adopts the way of discarding. The possibly largest height and width of output will be returned
+              without padding. Extra pixels will be discarded. If this mode is set, `padding`
+              must be 0.
+
+            - pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
+              Tensor borders. `padding` should be greater than or equal to 0.
+
+        padding (int): Implicit paddings on both sides of the input. Default: 0.
+        dilation (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the dilation rate
+            to use for dilated convolution. If set to be :math:`k > 1`, there will
+            be :math:`k - 1` pixels skipped for each sampling location. Its value should
+            be greater than or equal to 1 and bounded by the height and width of the
+            input. Default: 1.
+        group (int): Splits filter into groups, `in_channels` and `out_channels` should be
+            divisible by the number of groups. Default: 1.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
+            It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
+            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
+            as constant 'One' and 'Zero' distributions are possible. Aliases 'xavier_uniform', 'he_uniform', 'ones'
+            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
+            Initializer for more details. Default: 'normal'.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
+            Initializer and string values are the same as 'weight_init'. Refer to the values of
+            Initializer for more details. Default: 'zeros'.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
+
+    Outputs:
+        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
+
+    Examples:
+        >>> net = nn.DepthwiseConv2d(120, 120, 4, group=120, has_bias=False, weight_init='normal')
+        >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
+        >>> net(input).shape
+        (1, 120, 1024, 640)
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 pad_mode='same',
+                 padding=0,
+                 dilation=1,
+                 group=1,
+                 has_bias=False,
+                 weight_init='normal',
+                 bias_init='zeros'):
+        super(DepthwiseConv2d, self).__init__()
+        self.kernel_size = twice(kernel_size)
+        self.stride = twice(stride)
+        self.dilation = twice(dilation)
+        self.in_channels = check_int_positive(in_channels)
+        self.out_channels = check_int_positive(out_channels)
+        validator.check_integer('group', group, in_channels, Rel.EQ)
+        validator.check_integer('group', group, out_channels, Rel.EQ)
+        validator.check_integer('group', group, 1, Rel.GE)
+        self.pad_mode = pad_mode
+        self.padding = padding
+        self.dilation = dilation
+        self.group = group
+        self.has_bias = has_bias
+        self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
+                                            kernel_size=self.kernel_size,
+                                            pad_mode=self.pad_mode,
+                                            pad=self.padding,
+                                            stride=self.stride,
+                                            dilation=self.dilation)
+        self.bias_add = P.BiasAdd()
+        weight_shape = [1, in_channels, *self.kernel_size]
+        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
+        if check_bool(has_bias):
+            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
+        else:
+            if bias_init != 'zeros':
+                logger.warning("value of `has_bias` is False, value of `bias_init` will be ignored.")
+            self.bias = None
+
+    def construct(self, x):
+        out = self.conv(x, self.weight)
+        if self.has_bias:
+            out = self.bias_add(out, self.bias)
+        return out
+
+    def extend_repr(self):
+        s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
+            'pad_mode={}, padding={}, dilation={}, group={}, ' \
+            'has_bias={}, weight_init={}, bias_init={}'.format(
+                self.in_channels, self.out_channels, self.kernel_size, self.stride,
+                self.pad_mode, self.padding, self.dilation, self.group,
+                self.has_bias, self.weight_init, self.bias_init)
+
+        if self.has_bias:
+            s += ', bias={}'.format(self.bias)
+        return s
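Note: a sanity-check sketch of the shape contract the validators above enforce — `group` must equal both `in_channels` and `out_channels`, so each input channel maps to exactly one output channel. `nn.DepthwiseConv2d` here is the class added in this hunk, not a released API.

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

net = nn.DepthwiseConv2d(8, 8, kernel_size=3, group=8)  # group == in_channels == out_channels
x = Tensor(np.ones([2, 8, 32, 32]), mindspore.float32)
print(net(x).shape)  # (2, 8, 32, 32) with the default pad_mode='same'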
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Aware quantization."""
+"""Quantization aware."""
+
 from functools import partial
 import numpy as np

 import mindspore.common.dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
@@ -23,10 +24,9 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 from mindspore._checkparam import check_int_positive, check_bool, twice
-from mindspore._checkparam import Validator as validator, Rel
-from mindspore.nn.cell import Cell
-from mindspore.nn.layer.activation import get_activation
+from mindspore._checkparam import Rel
 import mindspore.context as context

 from .normalization import BatchNorm2d
 from .activation import get_activation
 from ..cell import Cell
@@ -82,7 +82,7 @@ class Conv2dBnAct(Cell):
         bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
             Initializer and string values are the same as 'weight_init'. Refer to the values of
             Initializer for more details. Default: 'zeros'.
-        batchnorm (bool): Specifies to used batchnorm or not. Default: None.
+        has_bn (bool): Specifies whether to use batchnorm. Default: False.
         activation (string): Specifies activation type. The optional values are as following:
             'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
             'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
@@ -94,7 +94,7 @@ class Conv2dBnAct(Cell):
         Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

     Examples:
-        >>> net = Conv2dBnAct(120, 240, 4, batchnorm=True, activation='ReLU')
+        >>> net = Conv2dBnAct(120, 240, 4, has_bn=True, activation='ReLU')
         >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
         >>> net(input).shape
         (1, 240, 1024, 640)
@@ -112,28 +112,39 @@ class Conv2dBnAct(Cell):
                  has_bias=False,
                  weight_init='normal',
                  bias_init='zeros',
-                 batchnorm=None,
+                 has_bn=False,
                  activation=None):
         super(Conv2dBnAct, self).__init__()
-        self.conv = conv.Conv2d(
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride,
-            pad_mode,
-            padding,
-            dilation,
-            group,
-            has_bias,
-            weight_init,
-            bias_init)
-        self.has_bn = batchnorm is not None
+
+        if context.get_context('device_target') == "Ascend" and group > 1:
+            self.conv = conv.DepthwiseConv2d(in_channels,
+                                             out_channels,
+                                             kernel_size=kernel_size,
+                                             stride=stride,
+                                             pad_mode=pad_mode,
+                                             padding=padding,
+                                             dilation=dilation,
+                                             group=group,
+                                             has_bias=has_bias,
+                                             weight_init=weight_init,
+                                             bias_init=bias_init)
+        else:
+            self.conv = conv.Conv2d(in_channels,
+                                    out_channels,
+                                    kernel_size=kernel_size,
+                                    stride=stride,
+                                    pad_mode=pad_mode,
+                                    padding=padding,
+                                    dilation=dilation,
+                                    group=group,
+                                    has_bias=has_bias,
+                                    weight_init=weight_init,
+                                    bias_init=bias_init)
+
+        self.has_bn = has_bn
         self.has_act = activation is not None
-        self.batchnorm = batchnorm
-        if batchnorm is True:
+        if has_bn:
             self.batchnorm = BatchNorm2d(out_channels)
-        elif batchnorm is not None:
-            validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
         self.activation = get_activation(activation)

     def construct(self, x):
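Note: with the rename, callers pass has_bn instead of batchnorm, and grouped convolutions are routed to DepthwiseConv2d on Ascend. A hedged usage sketch (the activation key 'relu' follows get_activation's lowercase naming):

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

net = nn.Conv2dBnAct(120, 240, 4, has_bn=True, activation='relu')
x = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
out = net(x)  # shape (1, 240, 1024, 640)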
@@ -160,7 +171,7 @@ class DenseBnAct(Cell):
             same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
         activation (str): Regularizer function applied to the output of the layer, e.g. 'relu'. Default: None.
-        batchnorm (bool): Specifies to used batchnorm or not. Default: None.
+        has_bn (bool): Specifies whether to use batchnorm. Default: False.
         activation (string): Specifies activation type. The optional values are as following:
             'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
             'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
@@ -172,7 +183,7 @@ class DenseBnAct(Cell):
         Tensor of shape :math:`(N, out\_channels)`.

     Examples:
-        >>> net = nn.Dense(3, 4)
+        >>> net = nn.DenseBnAct(3, 4)
         >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
         >>> net(input)
     """
@@ -183,7 +194,7 @@ class DenseBnAct(Cell):
                  weight_init='normal',
                  bias_init='zeros',
                  has_bias=True,
-                 batchnorm=None,
+                 has_bn=False,
                  activation=None):
         super(DenseBnAct, self).__init__()
         self.dense = basic.Dense(
@@ -192,12 +203,10 @@ class DenseBnAct(Cell):
             weight_init,
             bias_init,
             has_bias)
-        self.has_bn = batchnorm is not None
+        self.has_bn = has_bn
         self.has_act = activation is not None
-        if batchnorm is True:
+        if has_bn:
             self.batchnorm = BatchNorm2d(out_channels)
-        elif batchnorm is not None:
-            validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
         self.activation = get_activation(activation)

     def construct(self, x):
@@ -271,20 +280,20 @@ class BatchNormFoldCell(Cell):

 class FakeQuantWithMinMax(Cell):
     r"""
-    Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
+    Quantization aware op. This OP provides a fake quantization observer function on data with min and max.

     Args:
         min_init (int, float): The dimension of channel or 1 (layer). Default: -6.
         max_init (int, float): The dimension of channel or 1 (layer). Default: 6.
-        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
         ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
         channel_axis (int): Quantization by channel axis. Default: 1.
-        out_channels (int): declarate the min and max channel size, Default: 1.
-        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
+        num_channels (int): Declares the min and max channel size. Default: 1.
+        num_bits (int): Quantization number bit, support 4 and 8 bit. Default: 8.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
+        quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

     Inputs:
         - **x** (Tensor) - The input of FakeQuantWithMinMax.
@@ -301,24 +310,27 @@ class FakeQuantWithMinMax(Cell):
     def __init__(self,
                  min_init=-6,
                  max_init=6,
-                 num_bits=8,
                  ema=False,
                  ema_decay=0.999,
                  per_channel=False,
                  channel_axis=1,
-                 out_channels=1,
-                 quant_delay=0,
+                 num_channels=1,
+                 num_bits=8,
                  symmetric=False,
-                 narrow_range=False):
+                 narrow_range=False,
+                 quant_delay=0):
         """init FakeQuantWithMinMax layer"""
         super(FakeQuantWithMinMax, self).__init__()
+        validator.check_type("min_init", min_init, [int, float])
+        validator.check_type("max_init", max_init, [int, float])
+        validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
         self.min_init = min_init
         self.max_init = max_init
         self.num_bits = num_bits
         self.ema = ema
         self.ema_decay = ema_decay
         self.per_channel = per_channel
-        self.out_channels = out_channels
+        self.num_channels = num_channels
         self.channel_axis = channel_axis
         self.quant_delay = quant_delay
         self.symmetric = symmetric
@@ -327,54 +339,54 @@ class FakeQuantWithMinMax(Cell):

         # init tensor min and max for fake quant op
         if self.per_channel:
-            min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
-            max_array = np.array([self.max_init for i in range(0, self.out_channels)]).astype(np.float32)
+            min_array = np.array([self.min_init] * self.num_channels).astype(np.float32)
+            max_array = np.array([self.max_init] * self.num_channels).astype(np.float32)
         else:
-            min_array = np.array([self.min_init]).reshape(1).astype(np.float32)
-            max_array = np.array([self.max_init]).reshape(1).astype(np.float32)
+            min_array = np.array([self.min_init]).astype(np.float32)
+            max_array = np.array([self.max_init]).astype(np.float32)
         self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
         self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)

         # init fake quant relative op
         if per_channel:
             quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
-            ema_fun = partial(Q.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis)
+            ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
         else:
             quant_fun = Q.FakeQuantPerLayer
-            ema_fun = Q.FakeQuantMinMaxPerLayerUpdate
+            ema_fun = Q.MinMaxUpdatePerLayer

+        self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay)
         if self.is_ascend:
-            self.fake_quant = quant_fun(num_bits=self.num_bits,
-                                        symmetric=self.symmetric,
-                                        narrow_range=self.narrow_range)
+            self.fake_quant_train = quant_fun(num_bits=self.num_bits,
+                                              symmetric=self.symmetric,
+                                              narrow_range=self.narrow_range)
+            self.fake_quant_infer = self.fake_quant_train
         else:
-            self.fake_quant = quant_fun(num_bits=self.num_bits,
-                                        ema=self.ema,
-                                        ema_decay=ema_decay,
-                                        quant_delay=quant_delay,
-                                        symmetric=self.symmetric,
-                                        narrow_range=self.narrow_range)
-            self.ema_update = ema_fun(num_bits=self.num_bits,
-                                      ema=self.ema,
-                                      ema_decay=self.ema_decay,
-                                      symmetric=self.symmetric,
-                                      narrow_range=self.narrow_range)
+            quant_fun = partial(quant_fun,
+                                ema=self.ema,
+                                ema_decay=ema_decay,
+                                num_bits=self.num_bits,
+                                symmetric=self.symmetric,
+                                narrow_range=self.narrow_range,
+                                quant_delay=quant_delay)
+            self.fake_quant_train = quant_fun(training=True)
+            self.fake_quant_infer = quant_fun(training=False)

     def extend_repr(self):
         s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
             'quant_delay={}, min_init={}, max_init={}'.format(
                 self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel,
-                self.channel_axis, self.out_channels, self.quant_delay, self.min_init, self.max_init)
+                self.channel_axis, self.num_channels, self.quant_delay, self.min_init, self.max_init)
         return s

     def construct(self, x):
-        if self.is_ascend and self.training:
+        if self.training:
             min_up, max_up = self.ema_update(x, self.minq, self.maxq)
-            out = self.fake_quant(x, min_up, max_up)
             P.Assign()(self.minq, min_up)
             P.Assign()(self.maxq, max_up)
+            out = self.fake_quant_train(x, self.minq, self.maxq)
         else:
-            out = self.fake_quant(x, self.minq, self.maxq)
+            out = self.fake_quant_infer(x, self.minq, self.maxq)
         return out
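Note: for intuition about what the observer computes, a minimal numpy sketch of layer-wise fake quantization (quantize, then dequantize) using the defaults above; this restates the standard simulated-quantization formula, not MindSpore internals.

import numpy as np

def fake_quant(x, min_val=-6.0, max_val=6.0, num_bits=8, narrow_range=False):
    # Map x onto the integer grid [qmin, qmax] and back to float.
    qmin = 1 if narrow_range else 0
    qmax = 2 ** num_bits - 1
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = np.round(qmin - min_val / scale)
    q = np.clip(np.round(x / scale + zero_point), qmin, qmax)
    return (q - zero_point) * scale

x = np.array([-7.0, -1.3, 0.0, 2.71, 8.0], dtype=np.float32)
print(fake_quant(x))  # values clamped to [-6, 6] and snapped to the 8-bit grid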
@@ -391,8 +403,8 @@ class Conv2dBatchNormQuant(Cell):
         stride (int): Specifies stride for all spatial dimensions with the same value.
         pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
         padding (int): Implicit paddings on both sides of the input. Default: 0.
-        eps (int): Parameters for BatchNormal. Default: 1e-5.
-        momentum (int): Parameters for BatchNormal op. Default: 0.997.
+        eps (float): Parameter for the BatchNormal op. Default: 1e-5.
+        momentum (float): Parameter for the BatchNormal op. Default: 0.997.
         weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
             convolution kernel. Default: 'normal'.
         beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
@@ -403,13 +415,13 @@ class Conv2dBatchNormQuant(Cell):
             mean vector. Default: 'zeros'.
         var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
             variance vector. Default: 'ones'.
-        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
-        freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
         fake (bool): Whether the Conv2dBatchNormQuant Cell adds a FakeQuantWithMinMax op. Default: True.
+        per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
         num_bits (int): Quantization number bit, support 4 and 8 bit. Default: 8.
-        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
+        quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
+        freeze_bn (int): Quantization freezes the BatchNormal op according to the global step. Default: 100000.

     Inputs:
         - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -440,13 +452,13 @@ class Conv2dBatchNormQuant(Cell):
                  gamma_init='ones',
                  mean_init='zeros',
                  var_init='ones',
-                 quant_delay=0,
-                 freeze_bn=100000,
                  fake=True,
-                 num_bits=8,
                  per_channel=False,
+                 num_bits=8,
                  symmetric=False,
-                 narrow_range=False):
+                 narrow_range=False,
+                 quant_delay=0,
+                 freeze_bn=100000):
         """init Conv2dBatchNormQuant layer"""
         super(Conv2dBatchNormQuant, self).__init__()
         self.in_channels = in_channels
@@ -503,12 +515,13 @@ class Conv2dBatchNormQuant(Cell):
         self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
                                                      max_init=6,
                                                      ema=False,
-                                                     num_bits=num_bits,
-                                                     quant_delay=quant_delay,
                                                      per_channel=per_channel,
-                                                     out_channels=out_channels,
+                                                     channel_axis=channel_axis,
+                                                     num_channels=out_channels,
+                                                     num_bits=num_bits,
                                                      symmetric=symmetric,
-                                                     narrow_range=narrow_range)
+                                                     narrow_range=narrow_range,
+                                                     quant_delay=quant_delay)
         self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
         self.correct_mul = Q.CorrectionMul(channel_axis)
         if context.get_context('device_target') == "Ascend":
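Note: the switch from out_channels= to channel_axis=/num_channels= is the substance of this hunk — the weight observer keeps one (min, max) pair per output channel along the stated axis. A numpy illustration of the per-channel reduction (axis convention assumed from the diff):

import numpy as np

w = np.random.randn(240, 120, 4, 4).astype(np.float32)  # (C_out, C_in, kh, kw)
channel_axis = 0
reduce_axes = tuple(i for i in range(w.ndim) if i != channel_axis)
min_per_ch = w.min(axis=reduce_axes)  # shape (240,), i.e. num_channels
max_per_ch = w.max(axis=reduce_axes)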
@@ -582,11 +595,11 @@ class Conv2dQuant(Cell):
         weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
             Default: 'normal'.
         bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
-        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
+        per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
         num_bits (int): Quantization number bit, support 4 and 8 bit. Default: 8.
-        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
+        quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

     Inputs:
         - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -613,11 +626,11 @@ class Conv2dQuant(Cell):
                  has_bias=False,
                  weight_init='normal',
                  bias_init='zeros',
-                 quant_delay=0,
-                 num_bits=8,
                  per_channel=False,
+                 num_bits=8,
                  symmetric=False,
-                 narrow_range=False):
+                 narrow_range=False,
+                 quant_delay=0):
         super(Conv2dQuant, self).__init__()
         if isinstance(kernel_size, int):
             self.kernel_size = (kernel_size, kernel_size)
@@ -653,12 +666,13 @@ class Conv2dQuant(Cell):
         self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
                                                      max_init=6,
                                                      ema=False,
-                                                     num_bits=num_bits,
-                                                     quant_delay=quant_delay,
                                                      per_channel=per_channel,
-                                                     out_channels=out_channels,
+                                                     channel_axis=0,
+                                                     num_channels=out_channels,
+                                                     num_bits=num_bits,
                                                      symmetric=symmetric,
-                                                     narrow_range=narrow_range)
+                                                     narrow_range=narrow_range,
+                                                     quant_delay=quant_delay)

     def construct(self, x):
         weight = self.fake_quant_weight(self.weight)
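Note: a hedged sketch of the documented quant_delay semantics — fake quantization only takes effect once the global training step passes quant_delay; fake_quant is the helper from the earlier sketch.

def maybe_fake_quant(x, global_step, quant_delay=0):
    if global_step < quant_delay:
        return x  # pass-through during warm-up
    return fake_quant(x)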
@@ -692,11 +706,11 @@ class DenseQuant(Cell):
             same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
         activation (str): Regularizer function applied to the output of the layer, e.g. 'relu'. Default: None.
+        per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
         num_bits (int): Quantization number bit, support 4 and 8 bit. Default: 8.
-        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
-        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
+        quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

     Inputs:
         - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -718,19 +732,19 @@ class DenseQuant(Cell):
                  bias_init='zeros',
                  has_bias=True,
                  activation=None,
-                 num_bits=8,
-                 quant_delay=0,
                  per_channel=False,
+                 num_bits=8,
                  symmetric=False,
-                 narrow_range=False):
+                 narrow_range=False,
+                 quant_delay=0):
         super(DenseQuant, self).__init__()
         self.in_channels = check_int_positive(in_channels)
         self.out_channels = check_int_positive(out_channels)
         self.has_bias = check_bool(has_bias)

         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
-                    weight_init.shape[1] != in_channels:
+            if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
+                    weight_init.shape()[1] != in_channels:
                 raise ValueError("weight_init shape error")

         self.weight = Parameter(initializer(
@@ -738,7 +752,7 @@ class DenseQuant(Cell):

         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
+                if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
                     raise ValueError("bias_init shape error")

             self.bias = Parameter(initializer(
@@ -752,12 +766,13 @@ class DenseQuant(Cell):
         self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
                                                      max_init=6,
                                                      ema=False,
-                                                     num_bits=num_bits,
-                                                     quant_delay=quant_delay,
                                                      per_channel=per_channel,
-                                                     out_channels=out_channels,
+                                                     channel_axis=0,
+                                                     num_channels=out_channels,
+                                                     num_bits=num_bits,
                                                      symmetric=symmetric,
-                                                     narrow_range=narrow_range)
+                                                     narrow_range=narrow_range,
+                                                     quant_delay=quant_delay)

     def construct(self, x):
         """Use operators to construct the Dense layer."""
@@ -780,13 +795,16 @@ class DenseQuant(Cell):

         return str_info

+
 class _QuantActivation(Cell):
     r"""
     Base class for Quant activation function. Add Fake Quant OP after activation OP.
     """
     def get_origin(self):
         raise NotImplementedError

+
+
 class ReLUQuant(_QuantActivation):
     r"""
     ReLUQuant activation function. Add Fake Quant OP after Relu OP.
@@ -794,12 +812,12 @@ class ReLUQuant(_QuantActivation):
     For a more detailed overview of the ReLU op.

     Args:
-        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
-        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
+        num_bits (int): Quantization number bit, support 4 and 8 bit. Default: 8.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
+        quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

     Inputs:
         - **x** (Tensor) - The input of ReLUQuant.
@@ -814,22 +832,22 @@ class ReLUQuant(_QuantActivation):
     """

     def __init__(self,
-                 num_bits=8,
-                 quant_delay=0,
                  ema_decay=0.999,
                  per_channel=False,
+                 num_bits=8,
                  symmetric=False,
-                 narrow_range=False):
+                 narrow_range=False,
+                 quant_delay=0):
         super(ReLUQuant, self).__init__()
         self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
                                                   max_init=6,
-                                                  num_bits=num_bits,
-                                                  quant_delay=quant_delay,
                                                   ema=True,
-                                                  per_channel=per_channel,
                                                   ema_decay=ema_decay,
+                                                  per_channel=per_channel,
+                                                  num_bits=num_bits,
                                                   symmetric=symmetric,
-                                                  narrow_range=narrow_range)
+                                                  narrow_range=narrow_range,
+                                                  quant_delay=quant_delay)
         self.relu = P.ReLU()

     def construct(self, x):
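Note: this commit reorders positional parameters (num_bits and quant_delay move to new slots), so call sites that passed them positionally would silently misbind after the change; keyword arguments are the safe pattern:

act = ReLUQuant(ema_decay=0.999, per_channel=False, num_bits=8,
                symmetric=False, narrow_range=False, quant_delay=0)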
@@ -850,12 +868,12 @@ class ReLU6Quant(_QuantActivation):
     For a more detailed overview of the ReLU6 op.

     Args:
-        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
-        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
+        num_bits (int): Quantization number bit, support 4 and 8 bit. Default: 8.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
+        quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

     Inputs:
         - **x** (Tensor) - The input of ReLU6Quant.
@@ -870,22 +888,22 @@ class ReLU6Quant(_QuantActivation):
     """

     def __init__(self,
-                 num_bits=8,
-                 quant_delay=0,
                  ema_decay=0.999,
                  per_channel=False,
+                 num_bits=8,
                  symmetric=False,
-                 narrow_range=False):
+                 narrow_range=False,
+                 quant_delay=0):
         super(ReLU6Quant, self).__init__()
         self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
                                                   max_init=6,
-                                                  num_bits=num_bits,
-                                                  quant_delay=quant_delay,
                                                   ema=True,
-                                                  per_channel=per_channel,
                                                   ema_decay=ema_decay,
+                                                  per_channel=per_channel,
+                                                  num_bits=num_bits,
                                                   symmetric=symmetric,
-                                                  narrow_range=narrow_range)
+                                                  narrow_range=narrow_range,
+                                                  quant_delay=quant_delay)
         self.relu6 = P.ReLU6()

     def construct(self, x):
@@ -896,6 +914,7 @@ class ReLU6Quant(_QuantActivation):
     def get_origin(self):
         return self.relu6

+
 class HSwishQuant(_QuantActivation):
     r"""
     HSwishQuant activation function. Add Fake Quant OP after HSwish OP.
@@ -903,12 +922,12 @@ class HSwishQuant(_QuantActivation):
     For a more detailed overview of the HSwish op.

     Args:
-        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
-        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
+        num_bits (int): Quantization number bit, support 4 and 8 bit. Default: 8.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
+        quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

     Inputs:
         - **x** (Tensor) - The input of HSwishQuant.
@@ -923,31 +942,31 @@ class HSwishQuant(_QuantActivation):
     """

     def __init__(self,
-                 num_bits=8,
-                 quant_delay=0,
                  ema_decay=0.999,
                  per_channel=False,
+                 num_bits=8,
                  symmetric=False,
-                 narrow_range=False):
+                 narrow_range=False,
+                 quant_delay=0):
         super(HSwishQuant, self).__init__()
         self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
                                                          max_init=6,
-                                                         num_bits=num_bits,
-                                                         quant_delay=quant_delay,
                                                          ema=True,
-                                                         per_channel=per_channel,
                                                          ema_decay=ema_decay,
+                                                         per_channel=per_channel,
+                                                         num_bits=num_bits,
                                                          symmetric=symmetric,
-                                                         narrow_range=narrow_range)
+                                                         narrow_range=narrow_range,
+                                                         quant_delay=quant_delay)
         self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
                                                         max_init=6,
-                                                        num_bits=num_bits,
-                                                        quant_delay=quant_delay,
                                                         ema=True,
-                                                        per_channel=per_channel,
                                                         ema_decay=ema_decay,
+                                                        per_channel=per_channel,
+                                                        num_bits=num_bits,
                                                         symmetric=symmetric,
-                                                        narrow_range=narrow_range)
+                                                        narrow_range=narrow_range,
+                                                        quant_delay=quant_delay)
         self.act = P.HSwish()

     def construct(self, x):
@@ -959,6 +978,7 @@ class HSwishQuant(_QuantActivation):
     def get_origin(self):
         return self.act

+
 class HSigmoidQuant(_QuantActivation):
     r"""
     HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP.
@@ -966,12 +986,12 @@ class HSigmoidQuant(_QuantActivation):
     For a more detailed overview of the HSigmoid op.

     Args:
-        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
-        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
+        num_bits (int): Quantization number bit, support 4 and 8 bit. Default: 8.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
+        quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

     Inputs:
         - **x** (Tensor) - The input of HSigmoidQuant.
@@ -986,30 +1006,31 @@ class HSigmoidQuant(_QuantActivation):
     """

     def __init__(self,
-                 num_bits=8,
-                 quant_delay=0,
                  ema_decay=0.999,
                  per_channel=False,
+                 num_bits=8,
                  symmetric=False,
-                 narrow_range=False):
+                 narrow_range=False,
+                 quant_delay=0):
         super(HSigmoidQuant, self).__init__()
         self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
                                                          max_init=6,
-                                                         num_bits=num_bits,
-                                                         quant_delay=quant_delay,
                                                          ema=True,
+                                                         ema_decay=ema_decay,
                                                          per_channel=per_channel,
+                                                         num_bits=num_bits,
                                                          symmetric=symmetric,
-                                                         narrow_range=narrow_range)
+                                                         narrow_range=narrow_range,
+                                                         quant_delay=quant_delay)
         self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
                                                         max_init=6,
-                                                        num_bits=num_bits,
-                                                        quant_delay=quant_delay,
                                                         ema=True,
-                                                        per_channel=per_channel,
                                                         ema_decay=ema_decay,
+                                                        per_channel=per_channel,
+                                                        num_bits=num_bits,
                                                         symmetric=symmetric,
-                                                        narrow_range=narrow_range)
+                                                        narrow_range=narrow_range,
+                                                        quant_delay=quant_delay)
         self.act = P.HSigmoid()

     def construct(self, x):
@@ -1021,6 +1042,7 @@ class HSigmoidQuant(_QuantActivation):
     def get_origin(self):
         return self.act

+
 class TensorAddQuant(Cell):
     r"""
     Add Fake Quant OP after TensorAdd OP.
@ -1028,12 +1050,12 @@ class TensorAddQuant(Cell):
|
||||||
For a more Detailed overview of TensorAdd op.
|
For a more Detailed overview of TensorAdd op.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
|
||||||
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
|
||||||
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
|
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
|
||||||
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
|
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
|
||||||
|
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
||||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||||
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||||
|
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **x** (Tensor) - The input of TensorAddQuant.
|
- **x** (Tensor) - The input of TensorAddQuant.
|
||||||
|
@ -1049,22 +1071,22 @@ class TensorAddQuant(Cell):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_bits=8,
|
|
||||||
quant_delay=0,
|
|
||||||
ema_decay=0.999,
|
ema_decay=0.999,
|
||||||
per_channel=False,
|
per_channel=False,
|
||||||
|
num_bits=8,
|
||||||
symmetric=False,
|
symmetric=False,
|
||||||
narrow_range=False):
|
narrow_range=False,
|
||||||
|
quant_delay=0):
|
||||||
super(TensorAddQuant, self).__init__()
|
super(TensorAddQuant, self).__init__()
|
||||||
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
|
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
|
||||||
max_init=6,
|
max_init=6,
|
||||||
num_bits=num_bits,
|
|
||||||
quant_delay=quant_delay,
|
|
||||||
ema=True,
|
ema=True,
|
||||||
per_channel=per_channel,
|
|
||||||
ema_decay=ema_decay,
|
ema_decay=ema_decay,
|
||||||
|
per_channel=per_channel,
|
||||||
|
num_bits=num_bits,
|
||||||
symmetric=symmetric,
|
symmetric=symmetric,
|
||||||
narrow_range=narrow_range)
|
narrow_range=narrow_range,
|
||||||
|
quant_delay=quant_delay)
|
||||||
self.add = P.TensorAdd()
|
self.add = P.TensorAdd()
|
||||||
|
|
||||||
def construct(self, x1, x2):
|
def construct(self, x1, x2):
|
||||||
|
@@ -1080,12 +1102,12 @@ class MulQuant(Cell):
     For a more Detailed overview of Mul op.

     Args:
-        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
-        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
+        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
+        quant_delay (int): Quantization delay parameters according by global step. Default: 0.

     Inputs:
         - **x** (Tensor) - The input of MulQuant.
@@ -1096,22 +1118,22 @@ class MulQuant(Cell):
     """

     def __init__(self,
-                 num_bits=8,
-                 quant_delay=0,
                  ema_decay=0.999,
                  per_channel=False,
+                 num_bits=8,
                  symmetric=False,
-                 narrow_range=False):
+                 narrow_range=False,
+                 quant_delay=0):
         super(MulQuant, self).__init__()
         self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
                                                   max_init=6,
-                                                  num_bits=num_bits,
-                                                  quant_delay=quant_delay,
                                                   ema=True,
-                                                  per_channel=per_channel,
                                                   ema_decay=ema_decay,
+                                                  per_channel=per_channel,
+                                                  num_bits=num_bits,
                                                   symmetric=symmetric,
-                                                  narrow_range=narrow_range)
+                                                  narrow_range=narrow_range,
+                                                  quant_delay=quant_delay)
         self.mul = P.Mul()

     def construct(self, x1, x2):
@@ -1173,12 +1195,13 @@ class QuantBlock(Cell):
         self.has_bias = bias is None
         self.activation = activation
         self.has_act = activation is None
+        self.bias_add = P.BiasAdd()

     def construct(self, x):
         x = self.quant(x)
         x = self.core_op(x, self.weight)
         if self.has_bias:
-            output = self.bias_add(output, self.bias)
+            x = self.bias_add(x, self.bias)
         if self.has_act:
             x = self.activation(x)
         x = self.dequant(x, self.dequant_scale)
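The `construct` change above is a real bug fix, not a rename: `output` was never assigned anywhere in the method, so any `QuantBlock` configured with a bias raised a NameError, and `self.bias_add` did not exist until the added `P.BiasAdd()` line. Threading `x` through keeps the dataflow consistent. (The `bias is None` / `activation is None` context lines look inverted, but they are outside this commit's scope.)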
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================

-"""Generate bprop for aware quantization ops"""
+"""Generate bprop for quantization aware ops"""

 from .. import operations as P
 from ..operations import _quant_ops as Q
@@ -133,9 +133,9 @@ def get_bprop_batchnorm_fold2_(self):
     return bprop


-@bprop_getters.register(Q.FakeQuantMinMaxPerLayerUpdate)
+@bprop_getters.register(Q.MinMaxUpdatePerLayer)
 def get_bprop_fakequant_with_minmax_per_layer_update(self):
-    """Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend"""
+    """Generate bprop for MinMaxUpdatePerLayer for Ascend"""

     def bprop(x, x_min, x_max, out, dout):
         return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
@@ -143,9 +143,9 @@ def get_bprop_fakequant_with_minmax_per_layer_update(self):
     return bprop


-@bprop_getters.register(Q.FakeQuantMinMaxPerChannelUpdate)
+@bprop_getters.register(Q.MinMaxUpdatePerChannel)
 def get_bprop_fakequant_with_minmax_per_channel_update(self):
-    """Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend"""
+    """Generate bprop for MinMaxUpdatePerChannel for Ascend"""

     def bprop(x, x_min, x_max, out, dout):
         return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
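Note that both re-registered bprops return zeros for every input: the min/max update primitives act as pure observers of the data distribution, so no gradient should flow through the collected statistics. Only the registration targets change here to follow the FakeQuantMinMax*Update to MinMaxUpdate* rename.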
@@ -30,7 +30,6 @@ batchnorm_fold2_op_info = TBERegOp("BatchNormFold2_D") \
     .compute_cost(10) \
     .kernel_name("batchnorm_fold2") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .input(0, "x", None, "required", None) \
     .input(1, "beta", None, "required", None) \
     .input(2, "gamma", None, "required", None) \
@@ -30,7 +30,6 @@ batchnorm_fold2_grad_op_info = TBERegOp("BatchNormFold2GradD") \
     .compute_cost(10) \
     .kernel_name("batchnorm_fold2_grad") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .input(0, "dout", None, "required", None) \
     .input(1, "dout_reduce", None, "required", None) \
     .input(2, "dout_x_reduce", None, "required", None) \
@@ -31,7 +31,6 @@ batchnorm_fold2_grad_reduce_op_info = TBERegOp("BatchNormFold2GradReduce") \
     .compute_cost(10) \
     .kernel_name("batchnorm_fold2_grad_reduce") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .input(0, "dout", None, "required", None) \
     .input(1, "x", None, "required", None) \
     .output(0, "dout_reduce", True, "required", "all") \
@@ -30,7 +30,6 @@ correction_mul_op_info = TBERegOp("CorrectionMul") \
     .compute_cost(10) \
     .kernel_name("correction_mul") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .attr("channel_axis", "optional", "int", "all") \
     .input(0, "x", None, "required", None) \
     .input(1, "batch_std", None, "required", None) \
@@ -30,7 +30,6 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
     .compute_cost(10) \
     .kernel_name("correction_mul_grad") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .attr("channel_axis", "optional", "int", "all") \
     .input(0, "dout", None, "required", None) \
     .input(1, "x", None, "required", None) \
@@ -128,7 +127,6 @@ correction_mul_grad_reduce_op_info = TBERegOp("CorrectionMulGradReduce") \
     .compute_cost(10) \
     .kernel_name("correction_mul_grad_reduce") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .attr("channel_axis", "optional", "int", "all") \
     .input(0, "dout", None, "required", None) \
     .output(0, "d_batch_std", True, "required", "all") \
@@ -99,11 +99,15 @@ def fake_quant_perchannel(x, min_val, max_val, y,
     min_dtype = min_val.get("dtype")
     max_shape = max_val.get("ori_shape")
     max_dtype = max_val.get("dtype")
+    # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
+    if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
+        channel_axis_ = 1
+    else:
+        channel_axis_ = channel_axis
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(x_shape)
-    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
-    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
+    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
+    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
     util.check_tensor_shape_size(x_shape)
     util.check_tensor_shape_size(min_shape)
     util.check_tensor_shape_size(max_shape)
@@ -126,8 +130,8 @@ def fake_quant_perchannel(x, min_val, max_val, y,
         quant_min = quant_min + 1

     shape_c = [1] * len(x_shape)
-    shape_c[channel_axis] = min_val.get("ori_shape")[0]
-    if x_format == "NC1HWC0" and channel_axis == 1:
+    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
+    if x_format == "NC1HWC0" and channel_axis_ == 1:
         shape_c = min_val.get("shape")
     input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
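Why the remap is needed: a Dense layer's weight is 2-D `[co, ci]` but reaches this kernel reshaped to 4-D `[1, co, ci, 1]`, so the per-channel axis effectively moves from 0 to 1. A standalone sketch of the same heuristic (plain Python, illustrative function name):

    def remap_channel_axis(x_shape, min_shape, channel_axis):
        """Mirror of the kernel's check: detect a 2d[co,ci] weight reshaped
        to 4d[1,co,ci,1] and shift the channel axis from 0 to 1."""
        if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
            return 1
        return channel_axis

    assert remap_channel_axis([1, 64, 128, 1], [64], 0) == 1   # reshaped Dense weight
    assert remap_channel_axis([64, 3, 3, 3], [64], 0) == 0     # ordinary Conv2d weight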
@@ -124,11 +124,15 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
     min_dtype = min_val.get("dtype")
     max_shape = max_val.get("ori_shape")
     max_dtype = max_val.get("dtype")
+    # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
+    if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
+        channel_axis_ = 1
+    else:
+        channel_axis_ = channel_axis
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(x_shape)
-    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
-    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
+    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
+    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
     util.check_tensor_shape_size(x_shape)
     util.check_tensor_shape_size(min_shape)
     util.check_tensor_shape_size(max_shape)
@@ -151,8 +155,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
         quant_min = quant_min + 1

     shape_c = [1] * len(x_shape)
-    shape_c[channel_axis] = min_val.get("ori_shape")[0]
-    if x_format == "NC1HWC0" and channel_axis == 1:
+    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
+    if x_format == "NC1HWC0" and channel_axis_ == 1:
         shape_c = min_val.get("shape")
     dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
     input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
@@ -1,4 +1,3 @@
-
 # Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +13,7 @@
 # limitations under the License.
 # ============================================================================

-"""FakeQuantMinMaxPerChannelUpdate op"""
+"""MinMaxUpdatePerChannel op"""
 import te.lang.cce
 from te import tvm
 from te.platform.fusion_manager import fusion_manager
@@ -22,20 +21,15 @@ from topi import generic
 from topi.cce import util
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

-fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \
+minmax_update_perchannel_op_info = TBERegOp("MinMaxUpdatePerChannel") \
     .fusion_type("OPAQUE") \
     .async_flag(False) \
-    .binfile_name("fake_quant_min_max_per_channel_update.so") \
+    .binfile_name("minmax_update_perchannel.so") \
     .compute_cost(10) \
-    .kernel_name("fake_quant_min_max_per_channel_update") \
+    .kernel_name("minmax_update_perchannel") \
     .partial_flag(True) \
     .attr("ema", "optional", "bool", "all") \
     .attr("ema_decay", "optional", "float", "all") \
-    .attr("symmetric", "optional", "bool", "all") \
-    .attr("narrow_range", "optional", "bool", "all") \
-    .attr("training", "optional", "bool", "all") \
-    .attr("num_bits", "optional", "int", "all") \
     .attr("channel_axis", "optional", "int", "all") \
     .input(0, "x", None, "required", None) \
     .input(1, "min", None, "required", None) \
@@ -47,43 +41,46 @@ fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \
     .get_op_info()


-@op_info_register(fake_quant_min_max_per_channel_update_op_info)
-def _fake_quant_min_max_per_channel_update_tbe():
-    """FakeQuantPerChannelUpdate TBE register"""
+@op_info_register(minmax_update_perchannel_op_info)
+def _minmax_update_perchannel_tbe():
+    """MinMaxUpdatePerChannel TBE register"""
     return


-@fusion_manager.register("fake_quant_min_max_per_channel_update")
-def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val,
-                                                  ema, ema_decay, quant_min, quant_max, training, channel_axis,
-                                                  kernel_name="fake_quant_min_max_per_channel_update"):
-    """FakeQuantPerChannelUpdate compute"""
+@fusion_manager.register("minmax_update_perchannel")
+def minmax_update_perchannel_compute(x, min_val, max_val,
+                                     ema, ema_decay, channel_axis):
+    """MinMaxUpdatePerChannel compute"""
     shape_min = te.lang.cce.util.shape_to_list(min_val.shape)

     if not ema:
         ema_decay = 0.0
-    if training:
-        # CalMinMax
+    # CalMinMax
+    if channel_axis == 0:
+        axis = [1, 2, 3, 4]
+    else:
         axis = [0, 2, 3]
-        x_min = te.lang.cce.reduce_min(x, axis=axis)
-        x_max = te.lang.cce.reduce_max(x, axis=axis)
-        x_min = te.lang.cce.broadcast(x_min, shape_min)
-        x_max = te.lang.cce.broadcast(x_max, shape_min)
-        min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
-            min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
-        max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
-            max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
-        min_val = te.lang.cce.vmins(min_val, 0)
-        max_val = te.lang.cce.vmaxs(max_val, 0)
+    x_min = te.lang.cce.reduce_min(x, axis=axis)
+    x_max = te.lang.cce.reduce_max(x, axis=axis)
+    x_min = te.lang.cce.broadcast(x_min, shape_min)
+    x_max = te.lang.cce.broadcast(x_max, shape_min)
+    min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
+        min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
+    max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
+        max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
+    min_val = te.lang.cce.vmins(min_val, 0)
+    max_val = te.lang.cce.vmaxs(max_val, 0)

     return [min_val, max_val]


-@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
-def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
-                                          ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis,
-                                          kernel_name="fake_quant_min_max_per_channel_update"):
-    """FakeQuantPerLayer op"""
+@util.check_input_type(dict, dict, dict, dict, dict, bool, float, int, str)
+def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
+                             ema, ema_decay, channel_axis,
+                             kernel_name="minmax_update_perchannel"):
+    """MinMaxUpdatePerChannel op"""
     x_shape = x.get("ori_shape")
     x_format = x.get("format")
     x_dtype = x.get("dtype")

@@ -91,11 +88,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
     min_dtype = min_val.get("dtype")
     max_shape = max_val.get("ori_shape")
     max_dtype = max_val.get("dtype")
+    # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
+    if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
+        channel_axis_ = 1
+    else:
+        channel_axis_ = channel_axis
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(x_shape)
-    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
-    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
+    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis_])
+    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis_])
     util.check_tensor_shape_size(x_shape)
     util.check_tensor_shape_size(min_shape)
     util.check_tensor_shape_size(max_shape)

@@ -108,21 +109,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
     util.check_dtype_rule(min_dtype, check_list)
     util.check_dtype_rule(max_dtype, check_list)

-    if symmetric:
-        quant_min = 0 - 2 ** (num_bits - 1)
-        quant_max = 2 ** (num_bits - 1) - 1
+    if channel_axis_ == 0:
+        shape_c = min_val.get("ori_shape")
     else:
-        quant_min = 0
-        quant_max = 2 ** num_bits - 1
-    if narrow_range:
-        quant_min = quant_min + 1
-
-    shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
+        shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
     input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
     max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
-    res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data,
-                                                             ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name)
+    res_list = minmax_update_perchannel_compute(input_data, min_data, max_data,
+                                                ema, ema_decay, channel_axis_)

     with tvm.target.cce():
         sch = generic.auto_schedule(res_list)
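Taken together, the renamed kernel is now a pure running min/max observer: the EMA update it performs can be sketched in NumPy as below (illustrative only, not the TBE implementation; the real kernel hard-codes the reduction axes for the 5-D NC1HWC0 layout):

    import numpy as np

    def ema_min_max(x, min_val, max_val, ema_decay, channel_axis, ema=True):
        if not ema:
            ema_decay = 0.0  # degenerate EMA: keep only the current batch stats
        axes = tuple(i for i in range(x.ndim) if i != channel_axis)
        x_min = x.min(axis=axes)
        x_max = x.max(axis=axes)
        min_val = ema_decay * min_val + (1 - ema_decay) * x_min
        max_val = ema_decay * max_val + (1 - ema_decay) * x_max
        # pin zero inside the range so the zero point stays representable
        return np.minimum(min_val, 0), np.maximum(max_val, 0)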
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================

-"""FakeQuantMinMaxPerLayerUpdate op"""
+"""MinMaxUpdatePerLayer op"""
 from functools import reduce as functools_reduce
 import te.lang.cce
 from te import tvm
@@ -22,20 +22,15 @@ from topi import generic
 from topi.cce import util
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

-fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
+minmax_update_perlayer_op_info = TBERegOp("MinMaxUpdatePerLayer") \
     .fusion_type("OPAQUE") \
     .async_flag(False) \
-    .binfile_name("fake_quant_minmax_update.so") \
+    .binfile_name("minmax_update_perlayer.so") \
     .compute_cost(10) \
-    .kernel_name("fake_quant_minmax_update") \
+    .kernel_name("minmax_update_perlayer") \
     .partial_flag(True) \
     .attr("ema", "optional", "bool", "all") \
     .attr("ema_decay", "optional", "float", "all") \
-    .attr("symmetric", "optional", "bool", "all") \
-    .attr("narrow_range", "optional", "bool", "all") \
-    .attr("training", "optional", "bool", "all") \
-    .attr("num_bits", "optional", "int", "all") \
     .input(0, "x", None, "required", None) \
     .input(1, "min", None, "required", None) \
     .input(2, "max", None, "required", None) \
@@ -46,44 +41,42 @@ fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
     .get_op_info()


-@op_info_register(fake_quant_minmax_update_op_info)
-def _fake_quant_minmax_update_tbe():
-    """FakeQuantMinMaxPerLayerUpdate TBE register"""
+@op_info_register(minmax_update_perlayer_op_info)
+def _minmax_update_perlayer_tbe():
+    """MinMaxUpdatePerLayer TBE register"""
     return


-@fusion_manager.register("fake_quant_minmax_update")
-def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
-                                     kernel_name="fake_quant_minmax_update"):
-    """FakeQuantMinMaxPerLayerUpdate compute"""
+@fusion_manager.register("minmax_update_perlayer")
+def minmax_update_perlayer_compute(x, min_val, max_val, ema, ema_decay):
+    """MinMaxUpdatePerLayer compute"""
     shape = te.lang.cce.util.shape_to_list(x.shape)
     shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
     min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
     max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
     if not ema:
         ema_decay = 0.0
-    if training:
-        # CalMinMax
-        axis = tuple(range(len(shape)))
-        x_min = te.lang.cce.reduce_min(x, axis=axis)
-        x_max = te.lang.cce.reduce_max(x, axis=axis)
-        x_min = te.lang.cce.broadcast(x_min, shape_min)
-        x_max = te.lang.cce.broadcast(x_max, shape_min)
-        min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
-            min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
-        max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
-            max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
-        min_val = te.lang.cce.vmins(min_val, 0)
-        max_val = te.lang.cce.vmaxs(max_val, 0)
+    # CalMinMax
+    axis = tuple(range(len(shape)))
+    x_min = te.lang.cce.reduce_min(x, axis=axis)
+    x_max = te.lang.cce.reduce_max(x, axis=axis)
+    x_min = te.lang.cce.broadcast(x_min, shape_min)
+    x_max = te.lang.cce.broadcast(x_max, shape_min)
+    min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
+        min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
+    max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
+        max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
+    min_val = te.lang.cce.vmins(min_val, 0)
+    max_val = te.lang.cce.vmaxs(max_val, 0)

     return [min_val, max_val]


-@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str)
-def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
-                             ema, ema_decay, symmetric, narrow_range, training, num_bits,
-                             kernel_name="fake_quant_minmax_update"):
-    """FakeQuantPerLayer op"""
+@util.check_input_type(dict, dict, dict, dict, dict, bool, float, str)
+def minmax_update_perlayer(x, min_val, max_val, min_up, max_up,
+                           ema, ema_decay, kernel_name="minmax_update_perlayer"):
+    """MinMaxUpdatePerLayer op"""
     input_shape = x.get("shape")
     input_dtype = x.get("dtype")
     min_shape = min_val.get("ori_shape")

@@ -112,20 +105,10 @@ def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
     input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
     shape_min, _, _ = util.produce_shapes(min_shape, input_shape)

-    if symmetric:
-        quant_min = 0 - 2 ** (num_bits - 1)
-        quant_max = 2 ** (num_bits - 1) - 1
-    else:
-        quant_min = 0
-        quant_max = 2 ** num_bits - 1
-    if narrow_range:
-        quant_min = quant_min + 1
-
     input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
     max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
-    res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data,
-                                                ema, ema_decay, quant_min, quant_max, training, kernel_name)
+    res_list = minmax_update_perlayer_compute(input_data, min_data, max_data, ema, ema_decay)

     with tvm.target.cce():
         sch = generic.auto_schedule(res_list)
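The deleted `quant_min`/`quant_max` arithmetic does not disappear from the codebase; it belongs to the FakeQuant kernels, which still derive the integer range this way (reconstructed from the lines removed above):

    def quant_range(num_bits, symmetric, narrow_range):
        if symmetric:
            quant_min, quant_max = 0 - 2 ** (num_bits - 1), 2 ** (num_bits - 1) - 1
        else:
            quant_min, quant_max = 0, 2 ** num_bits - 1
        if narrow_range:
            quant_min = quant_min + 1
        return quant_min, quant_max

Dropping it here makes the min/max update independent of bit width, symmetry, and narrow-range settings, which is what allows the op's attribute list to shrink to `ema`/`ema_decay` (plus `channel_axis` for the per-channel variant).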
@@ -21,12 +21,12 @@ from ..._checkparam import Rel
 from ..primitive import PrimitiveWithInfer, prim_attr_register
 from ...common import dtype as mstype

-__all__ = ["FakeQuantPerLayer",
+__all__ = ["MinMaxUpdatePerLayer",
+           "MinMaxUpdatePerChannel",
+           "FakeQuantPerLayer",
            "FakeQuantPerLayerGrad",
            "FakeQuantPerChannel",
            "FakeQuantPerChannelGrad",
-           "FakeQuantMinMaxPerLayerUpdate",
-           "FakeQuantMinMaxPerChannelUpdate",
            "BatchNormFold",
            "BatchNormFoldGrad",
            "CorrectionMul",
@@ -38,20 +38,141 @@ __all__ = ["FakeQuantPerLayer",
            "BatchNormFoldGradD",
            "BatchNormFold2_D",
            "BatchNormFold2GradD",
-           "BatchNormFold2GradReduce",
+           "BatchNormFold2GradReduce"
            ]


+class MinMaxUpdatePerLayer(PrimitiveWithInfer):
+    r"""
+    Update min and max per layer.
+
+    Args:
+        ema (bool): Use EMA algorithm update value min and max. Default: False.
+        ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
+
+    Inputs:
+        - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
+        - **min** (Tensor) : Value of the min range of the input data x.
+        - **max** (Tensor) : Value of the max range of the input data x.
+
+    Outputs:
+        - Tensor: Simulate quantize tensor of x.
+
+    Examples:
+        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
+        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
+        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
+        >>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
+    """
+    support_quant_bit = [4, 7, 8]
+
+    @prim_attr_register
+    def __init__(self, ema=False, ema_decay=0.999):
+        """init FakeQuantMinMaxPerLayerUpdate OP"""
+        if context.get_context('device_target') == "Ascend":
+            from mindspore.ops._op_impl._custom_op import minmax_update_perlayer
+        if ema and not ema_decay:
+            raise ValueError(
+                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
+
+        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
+        self.ema_decay = validator.check_number_range(
+            'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
+        self.init_prim_io_names(inputs=['x', 'min', 'max'],
+                                outputs=['min_up', 'max_up'])
+
+    def infer_shape(self, x_shape, min_shape, max_shape):
+        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
+        validator.check("min shape", min_shape, "max shape",
+                        max_shape, Rel.EQ, self.name)
+        validator.check_integer("min shape", len(
+            min_shape), 1, Rel.EQ, self.name)
+        return min_shape, max_shape
+
+    def infer_dtype(self, x_type, min_type, max_type):
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
+        validator.check_tensor_type_same(
+            {"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same(
+            {"max": max_type}, valid_types, self.name)
+        return min_type, max_type
+
+
+class MinMaxUpdatePerChannel(PrimitiveWithInfer):
+    r"""
+    Update min and max per channel.
+
+    Args:
+        ema (bool): Use EMA algorithm update value min and max. Default: False.
+        ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
+        channel_axis (int): Quantization by channel axis, support 0 and 1. Default: 1.
+
+    Inputs:
+        - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
+        - **min** (Tensor) : Value of the min range of the input data x.
+        - **max** (Tensor) : Value of the max range of the input data x.
+
+    Outputs:
+        - Tensor: Simulate quantize tensor of x.
+
+    Examples:
+        >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
+        >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
+        >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
+        >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
+    """
+    support_quant_bit = [4, 7, 8]
+    support_x_rank = [2, 4]
+
+    @prim_attr_register
+    def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
+        """init FakeQuantPerChannelUpdate OP for Ascend"""
+        if context.get_context('device_target') == "Ascend":
+            from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
+        if ema and not ema_decay:
+            raise ValueError(
+                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
+
+        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
+        self.ema_decay = validator.check_number_range(
+            'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
+        self.channel_axis = validator.check_int_range(
+            'channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
+        self.init_prim_io_names(
+            inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
+
+    def infer_shape(self, x_shape, min_shape, max_shape):
+        if len(x_shape) not in self.support_x_rank:
+            raise ValueError(f"For '{self.name}' x rank should be in '{self.support_x_rank}'")
+        validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
+        validator.check("min shape", min_shape, "max shape",
+                        max_shape, Rel.EQ, self.name)
+        validator.check_integer("min shape", len(
+            min_shape), 1, Rel.EQ, self.name)
+        return min_shape, max_shape
+
+    def infer_dtype(self, x_type, min_type, max_type):
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same(
+            {"x": x_type}, valid_types, self.name)
+        validator.check_tensor_type_same(
+            {"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same(
+            {"max": max_type}, valid_types, self.name)
+        return min_type, max_type
+
+
 class FakeQuantPerLayer(PrimitiveWithInfer):
     r"""
     Simulate the quantize and dequantize operations in training time.

     Args:
-        num_bits (int) : Number bits for aware quantilization. Default: 8.
+        num_bits (int) : Number bits for quantization aware. Default: 8.
         ema (bool): Use EMA algorithm update value min and max. Default: False.
         ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
         quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
-            simulate aware quantize funcion. After delay step in training time begin simulate the aware
+            simulate quantization aware funcion. After delay step in training time begin simulate the aware
             quantize funcion. Default: 0.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
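One caveat in the added docstrings: both Examples blocks still construct the primitives with `num_bits=8`, but the new `__init__` signatures only accept `ema`/`ema_decay` (plus `channel_axis` for the per-channel op), so the doctests as written would likely raise a TypeError. The call the new signature actually supports is, for instance:

    >>> output_tensor = MinMaxUpdatePerLayer(ema=True, ema_decay=0.999)(input_tensor, min_tensor, max_tensor)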
@@ -103,8 +224,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
             'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
         self.num_bits = validator.check_integer(
             'num_bits', num_bits, 0, Rel.GT, self.name)
-        self.quant_delay = validator.check_value_type(
-            'quant_delay', quant_delay, (int,), self.name)
+        self.quant_delay = validator.check_integer(
+            'quant_delay', quant_delay, 0, Rel.GE, self.name)
         self.init_prim_io_names(inputs=['x', 'min', 'max'],
                                 outputs=['out'])
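`check_value_type` only guarded the type, so a negative delay slipped through to the kernel; `check_integer(..., 0, Rel.GE, ...)` fails fast at construction instead. Sketch of the new behavior (module path assumed):

    from mindspore.ops.operations._quant_ops import FakeQuantPerLayer  # assumed path

    FakeQuantPerLayer(quant_delay=900)   # accepted: integer >= 0
    # FakeQuantPerLayer(quant_delay=-1)  # now rejected at construction time (ValueError)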
@@ -196,6 +317,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
         training (bool): Training the network or not. Default: True.
+        channel_axis (int): Quantization by channel axis, support 0 and 1. Default: 1.

     Inputs:
         - **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor.

@@ -213,6 +335,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
         >>> result = fake_quant(input_x, _min, _max)
     """
     support_quant_bit = [4, 7, 8]
+    support_x_rank = [2, 4]

     @prim_attr_register
     def __init__(self,
@@ -245,14 +368,15 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
             'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
         self.num_bits = validator.check_integer(
             'num_bits', num_bits, 0, Rel.GT, self.name)
-        self.quant_delay = validator.check_value_type(
-            'quant_delay', quant_delay, (int,), self.name)
-        self.channel_axis = validator.check_integer(
-            'channel_axis', channel_axis, 0, Rel.GE, self.name)
+        self.quant_delay = validator.check_integer(
+            'quant_delay', quant_delay, 0, Rel.GE, self.name)
+        self.channel_axis = validator.check_int_range(
+            'channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
         self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])

     def infer_shape(self, x_shape, min_shape, max_shape):
-        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
+        if len(x_shape) not in self.support_x_rank:
+            raise ValueError(f"For '{self.name}' x rank should be in '{self.support_x_rank}'")
         validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
         validator.check_integer(
             "min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
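With `support_x_rank = [2, 4]` the primitive now accepts exactly the two layouts quantization aware training produces: 2-D Dense weights and 4-D Conv weights or activations; anything else fails with an explicit ValueError instead of the weaker `rank >= 1` check. Sketch (module path and shapes are assumptions):

    import numpy as np
    from mindspore import Tensor
    from mindspore.common import dtype as mstype
    from mindspore.ops.operations._quant_ops import FakeQuantPerChannel  # assumed path

    fake_quant = FakeQuantPerChannel(channel_axis=0, num_bits=8)
    w4 = Tensor(np.random.rand(16, 3, 3, 3).astype(np.float32))  # ok: rank 4
    mins = Tensor(np.full(16, -6.0, np.float32))
    maxs = Tensor(np.full(16, 6.0, np.float32))
    out = fake_quant(w4, mins, maxs)
    # a 1-D input would now raise: ValueError ... x rank should be in '[2, 4]'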
@@ -832,153 +956,3 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
     def infer_dtype(self, dout_type, x_type):
         validator.check("dout type", dout_type, "x type", x_type)
         return dout_type, dout_type
-
-
-class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
-    r"""
-    Update min and max value for fake quant per layer op.
-
-    Args:
-        num_bits (int) : Number bits for aware quantilization. Default: 8.
-        ema (bool): Use EMA algorithm update value min and max. Default: False.
-        ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
-        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
-        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
-        training (bool): Training the network or not. Default: True.
-
-    Inputs:
-        - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
-        - **min** (Tensor) : Value of the min range of the input data x.
-        - **max** (Tensor) : Value of the max range of the input data x.
-
-    Outputs:
-        - Tensor: Simulate quantize tensor of x.
-
-    Examples:
-        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
-        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
-        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
-        >>> output_tensor = FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
-    """
-    support_quant_bit = [4, 7, 8]
-
-    @prim_attr_register
-    def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
-                 training=True):
-        """init FakeQuantMinMaxPerLayerUpdate OP"""
-        if context.get_context('device_target') == "Ascend":
-            from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update
-        if num_bits not in self.support_quant_bit:
-            raise ValueError(
-                f"For '{self.name}' attr \'num_bits\' is not support.")
-        if ema and not ema_decay:
-            raise ValueError(
-                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
-
-        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
-        self.symmetric = validator.check_value_type(
-            'symmetric', symmetric, (bool,), self.name)
-        self.narrow_range = validator.check_value_type(
-            'narrow_range', narrow_range, (bool,), self.name)
-        self.training = validator.check_value_type(
-            'training', training, (bool,), self.name)
-        self.ema_decay = validator.check_number_range(
-            'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
-        self.num_bits = validator.check_integer(
-            'num_bits', num_bits, 0, Rel.GT, self.name)
-        self.init_prim_io_names(inputs=['x', 'min', 'max'],
-                                outputs=['min_up', 'max_up'])
-
-    def infer_shape(self, x_shape, min_shape, max_shape):
-        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
-        validator.check("min shape", min_shape, "max shape",
-                        max_shape, Rel.EQ, self.name)
-        validator.check_integer("min shape", len(
-            min_shape), 1, Rel.EQ, self.name)
-        return min_shape, max_shape
-
-    def infer_dtype(self, x_type, min_type, max_type):
-        valid_types = (mstype.float16, mstype.float32)
-        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
-        validator.check_tensor_type_same(
-            {"min": min_type}, valid_types, self.name)
-        validator.check_tensor_type_same(
-            {"max": max_type}, valid_types, self.name)
-        return min_type, max_type
-
-
-class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
-    r"""
-    Update min and max value for fake quant per layer op.
-
-    Args:
-        num_bits (int) : Number bits for aware quantilization. Default: 8.
-        ema (bool): Use EMA algorithm update value min and max. Default: False.
-        ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
-        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
-        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
-        training (bool): Training the network or not. Default: True.
-        channel_axis (int): Channel asis for per channel compute. Default: 1.
-
-    Inputs:
-        - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
-        - **min** (Tensor) : Value of the min range of the input data x.
-        - **max** (Tensor) : Value of the max range of the input data x.
-
-    Outputs:
-        - Tensor: Simulate quantize tensor of x.
-
-    Examples:
-        >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
-        >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
-        >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
-        >>> output_tensor = FakeQuantWithMinMax(num_bits=8)(x, min, max)
-    """
-    support_quant_bit = [4, 7, 8]
-
-    @prim_attr_register
-    def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
-                 training=True, channel_axis=1):
-        """init FakeQuantPerChannelUpdate OP for Ascend"""
-        if context.get_context('device_target') == "Ascend":
-            from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update
-        if num_bits not in self.support_quant_bit:
-            raise ValueError(
-                f"For '{self.name}' attr \'num_bits\' is not support.")
-        if ema and not ema_decay:
-            raise ValueError(
-                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
-
-        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
-        self.symmetric = validator.check_value_type(
-            'symmetric', symmetric, (bool,), self.name)
-        self.narrow_range = validator.check_value_type(
-            'narrow_range', narrow_range, (bool,), self.name)
-        self.training = validator.check_value_type(
-            'training', training, (bool,), self.name)
-        self.ema_decay = validator.check_number_range(
-            'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
-        self.num_bits = validator.check_integer(
-            'num_bits', num_bits, 0, Rel.GT, self.name)
-        self.channel_axis = validator.check_integer(
-            'channel axis', channel_axis, 0, Rel.GE, self.name)
-        self.init_prim_io_names(
-            inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
-
-    def infer_shape(self, x_shape, min_shape, max_shape):
-        validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
-        validator.check("min shape", min_shape, "max shape",
-                        max_shape, Rel.EQ, self.name)
-        validator.check_integer("min shape", len(
-            min_shape), 1, Rel.EQ, self.name)
-        return min_shape, max_shape
-
-    def infer_dtype(self, x_type, min_type, max_type):
-        valid_types = (mstype.float16, mstype.float32)
-        validator.check_tensor_type_same(
-            {"x": x_type}, valid_types, self.name)
-        validator.check_tensor_type_same(
-            {"min": min_type}, valid_types, self.name)
-        validator.check_tensor_type_same(
-            {"max": max_type}, valid_types, self.name)
-        return min_type, max_type
@@ -32,7 +32,6 @@ from ...ops.operations import _inner_ops as inner
 from ...train import serialization
 from . import quant_utils
-
 
 _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
                    nn.ReLU6: quant.ReLU6Quant,
                    nn.HSigmoid: quant.HSigmoidQuant,
@@ -41,15 +40,14 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,

 class _AddFakeQuantInput(nn.Cell):
     """
-    Add FakeQuant at input and output of the Network. Only support one input and one output case.
+    Add FakeQuant OP at input of the network. Only support one input case.
     """

     def __init__(self, network, quant_delay=0):
         super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
-        self.network = network
-        self.fake_quant_input = quant.FakeQuantWithMinMax(
-            min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
+        self.fake_quant_input = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
         self.fake_quant_input.update_parameters_name('fake_quant_input')
+        self.network = network

     def construct(self, data):
         data = self.fake_quant_input(data)
@@ -59,7 +57,7 @@ class _AddFakeQuantInput(nn.Cell):

 class _AddFakeQuantAfterSubCell(nn.Cell):
     """
-    Add FakeQuant after of the sub Cell.
+    Add FakeQuant OP after of the sub Cell.
     """

     def __init__(self, subcell, **kwargs):
@@ -114,11 +112,12 @@ class ConvertToQuantNetwork:
         self.network.update_cell_prefix()
         network = self._convert_subcells2quant(self.network)
         network = _AddFakeQuantInput(network)
+        self.network.update_cell_type("quant")
         return network

     def _convert_subcells2quant(self, network):
         """
-        convet sub cell to quant cell
+        convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell
         """
         cells = network.name_cells()
         change = False
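The added `update_cell_type("quant")` call is the new `Cell` hook from the top of this commit being put to use: after conversion the network is tagged so downstream tooling can tell a quantization aware graph from a float one. Condensed flow, assuming the enclosing method is the converter's `run` (constructor arguments elided, names otherwise from this hunk):

    converter = ConvertToQuantNetwork(network=float_net, bn_fold=True)  # float_net: a pretrained nn.Cell
    quant_net = converter.run()  # convert subcells, wrap input FakeQuant, then tag the cell type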
@@ -137,19 +136,19 @@ class ConvertToQuantNetwork:
         if isinstance(network, nn.SequentialCell) and change:
             network.cell_list = list(network.cells())

-        # tensoradd to tensoradd quant
+        # add FakeQuant OP after OP in while list
         add_list = []
         for name in network.__dict__:
             if name[0] == '_':
                 continue
             attr = network.__dict__[name]
-            if isinstance(attr, ops.Primitive) and attr.name in ConvertToQuantNetwork.__quant_op_name__:
+            if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
                 add_list.append((name, attr))
         for name, prim_op in add_list:
             prefix = name
             add_quant = _AddFakeQuantAfterSubCell(prim_op,
                                                   num_bits=self.act_bits,
-                                                  quant_delay=self.act_delay,
+                                                  quant_delay=self.act_qdelay,
                                                   per_channel=self.act_channel,
                                                   symmetric=self.act_symmetric,
                                                   narrow_range=self.act_range)
@@ -161,11 +160,11 @@ class ConvertToQuantNetwork:

     def _convert_conv(self, subcell):
         """
-        convet conv cell to quant cell
+        convert Conv2d cell to quant cell
         """
         conv_inner = subcell.conv
         bn_inner = subcell.batchnorm
-        if subcell.batchnorm is not None and self.bn_fold:
+        if subcell.has_bn and self.bn_fold:
             conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
                                                     conv_inner.out_channels,
                                                     kernel_size=conv_inner.kernel_size,
@@ -175,7 +174,7 @@ class ConvertToQuantNetwork:
                                                     dilation=conv_inner.dilation,
                                                     group=conv_inner.group,
                                                     eps=bn_inner.eps,
-                                                    momentum=bn_inner.momentum,
+                                                    momentum=1 - bn_inner.momentum,
                                                     quant_delay=self.weight_qdelay,
                                                     freeze_bn=self.freeze_bn,
                                                     per_channel=self.weight_channel,
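The momentum flip matters because the two cells interpret the argument with opposite conventions: one treats `momentum` as the weight on the old moving statistic, the other as the weight on the new batch statistic (the diff only shows that they disagree, not which is which). Passing the raw value made the folded BN statistics drift at the wrong rate; `1 - bn_inner.momentum` preserves the effective update:

    bn_momentum = 0.9                  # as stored on the float BatchNorm cell (illustrative value)
    quant_momentum = 1 - bn_momentum   # equivalent rate under the opposite convention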
@@ -183,6 +182,11 @@ class ConvertToQuantNetwork:
                                                     fake=True,
                                                     symmetric=self.weight_symmetric,
                                                     narrow_range=self.weight_range)
+            # change original network BatchNormal OP parameters to quant network
+            conv_inner.gamma = subcell.batchnorm.gamma
+            conv_inner.beta = subcell.batchnorm.beta
+            conv_inner.moving_mean = subcell.batchnorm.moving_mean
+            conv_inner.moving_variance = subcell.batchnorm.moving_variance
             del subcell.batchnorm
             subcell.batchnorm = None
             subcell.has_bn = False
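Copying `gamma`, `beta`, `moving_mean` and `moving_variance` here (and, in the following hunks, the Conv2D and Dense weights and biases) is what makes conversion usable on a pretrained float network: before this change the quant cells started from fresh initializers and silently discarded the trained parameters, which appears to be part of the auto-create-graph bug this commit's title refers to.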
@@ -201,6 +205,10 @@ class ConvertToQuantNetwork:
                                            num_bits=self.weight_bits,
                                            symmetric=self.weight_symmetric,
                                            narrow_range=self.weight_range)
+            # change original network Conv2D OP parameters to quant network
+            conv_inner.weight = subcell.conv.weight
+            if subcell.conv.has_bias:
+                conv_inner.bias = subcell.conv.bias
         subcell.conv = conv_inner
         if subcell.has_act and subcell.activation is not None:
             subcell.activation = self._convert_activation(subcell.activation)
@@ -227,6 +235,10 @@ class ConvertToQuantNetwork:
                                          per_channel=self.weight_channel,
                                          symmetric=self.weight_symmetric,
                                          narrow_range=self.weight_range)
+        # carry the original network's Dense parameters over to the quant cell
+        dense_inner.weight = subcell.dense.weight
+        if subcell.dense.has_bias:
+            dense_inner.bias = subcell.dense.bias
         subcell.dense = dense_inner
         if subcell.has_act and subcell.activation is not None:
             subcell.activation = self._convert_activation(subcell.activation)
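
The added lines in the last two hunks share one purpose: each freshly constructed quant cell must inherit the already-trained float parameters, so that quantization aware fine-tuning starts from the converged fusion model rather than a fresh initialization. A hypothetical helper expressing the pattern (not part of the source):

```python
def carry_over_params(float_cell, quant_cell):
    """Reuse the trained float parameters in the replacement quant cell."""
    quant_cell.weight = float_cell.weight
    if getattr(float_cell, "has_bias", False):
        quant_cell.bias = float_cell.bias
    return quant_cell
```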
@@ -234,7 +246,7 @@ class ConvertToQuantNetwork:
             subcell.has_act = True
             subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
                                                            num_bits=self.act_bits,
-                                                           quant_delay=self.act_delay,
+                                                           quant_delay=self.act_qdelay,
                                                            per_channel=self.act_channel,
                                                            symmetric=self.act_symmetric,
                                                            narrow_range=self.act_range)
@@ -244,12 +256,12 @@ class ConvertToQuantNetwork:
         act_class = activation.__class__
         if act_class not in _ACTIVATION_MAP:
             raise ValueError(
-                "Unsupported activation in auto Quant: ", act_class)
+                "Unsupported activation in auto quant: ", act_class)
         return _ACTIVATION_MAP[act_class](num_bits=self.act_bits,
                                           quant_delay=self.act_qdelay,
                                           per_channel=self.act_channel,
-                                          symmetric=self.weight_symmetric,
-                                          narrow_range=self.weight_range)
+                                          symmetric=self.act_symmetric,
+                                          narrow_range=self.act_range)
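
`_ACTIVATION_MAP` itself is not shown in this diff; judging from the lookup above, it maps float activation classes to fake-quant-aware replacements, roughly like the sketch below. The entries and the import path are assumptions, not the file's actual table:

```python
import mindspore.nn as nn
from mindspore.nn.layer import quant  # import path assumed

# Assumed shape of the lookup used by _convert_activation above.
_ACTIVATION_MAP = {
    nn.ReLU: quant.ReLUQuant,
    nn.ReLU6: quant.ReLU6Quant,
}
```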


 class ExportQuantNetworkDeploy:
@@ -403,7 +415,7 @@ def convert_quant_network(network,
                           narrow_range=(False, False)
                           ):
     r"""
-    Create aware quantizaiton training network.
+    Create quantization aware training network.

     Args:
         network (Cell): the network to be converted for quantization aware training.
@@ -417,7 +429,7 @@ def convert_quant_network(network,
             then base on per channel, otherwise base on per layer. The first element represents weights
             and the second element represents data flow. Default: [False, False]
         symmetric (list of bool): Whether the quantization algorithm is symmetric or not. If `True` then base on
-            symmetric otherwise base on assymmetric. The first element represent weights and second
+            symmetric, otherwise base on asymmetric. The first element represents weights and the second
             element represents data flow. Default: [False, False]
         narrow_range (list of bool): Whether the quantization algorithm uses narrow range or not. If `True` then base
             on narrow range, otherwise base on off narrow range. The first element represents weights and
@@ -426,6 +438,7 @@ def convert_quant_network(network,
     Returns:
         Cell, the network converted to a quantization aware training network.
     """

     def convert2list(name, value):
         if not isinstance(value, list) and not isinstance(value, tuple):
             value = [value]
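
As a usage sketch of the helper above: `convert2list` lets every knob be passed either as a scalar or as a `(weight, activation)` pair, so the following two calls should be equivalent. The argument values are illustrative and the import path is assumed from the README's `quant.convert_quant_network` usage:

```python
from mindspore.train.quant import quant

net_a = quant.convert_quant_network(network, num_bits=8, symmetric=False)
net_b = quant.convert_quant_network(network,
                                    num_bits=(8, 8),
                                    symmetric=(False, False))
```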
@@ -2,13 +2,13 @@

 ## Description

-Training LeNet with MNIST dataset in MindSpore with quantization aware trainging.
+Training LeNet with the MNIST dataset in MindSpore with quantization aware training.

 This is a simple, basic tutorial for constructing a quantization aware network in MindSpore.

 In this tutorial, you will:

-1. Train a Mindspore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`.
+1. Train a MindSpore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`.
 2. Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`; after the network converges, export a quantization aware model checkpoint file (see the workflow sketch after this hunk).
 3. Use the quantization aware model to create an actually quantized model for the Ascend inference backend.
 4. See that accuracy persists in the inference backend with a 4x smaller model. To see the latency benefits on mobile, try out the Ascend inference backend examples.
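
A condensed sketch of the workflow those four steps describe, using the tutorial's `LeNet5Fusion` network; the checkpoint file name is illustrative and the export step is outside this excerpt:

```python
from mindspore.train.quant import quant
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# 1. Build the float fusion model (its training loop is shown later).
network = LeNet5Fusion(10)

# 2. Resume from the converged float checkpoint, then insert fake-quant cells.
load_param_into_net(network, load_checkpoint("checkpoint_lenet.ckpt"))
network = quant.convert_quant_network(network)

# 3./4. Fine-tune the quantization aware network and export it for the
# Ascend inference backend.
```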
@@ -24,16 +24,16 @@ Install MindSpore base on the ascend device and GPU device from [MindSpore](http
 ```bash
 pip uninstall -y mindspore-ascend
 pip uninstall -y mindspore-gpu
-pip install mindspore-ascend-0.4.0.whl
+pip install mindspore-ascend.whl
 ```

-then you will get the following display
+Then you will see the following output:

 ```bash
 >>> Found existing installation: mindspore-ascend
 >>> Uninstalling mindspore-ascend:
 >>> Successfully uninstalled mindspore-ascend.
 ```

 ### Prepare Dataset
@@ -87,7 +87,7 @@ class LeNet5(nn.Cell):
         return x
 ```

-get the MNIST from scratch dataset.
+Get the MNIST dataset for training from scratch.

 ```python
 ds_train = create_dataset(os.path.join(args.data_path, "train"),
@@ -97,7 +97,7 @@ step_size = ds_train.get_dataset_size()

 ### Train model

-Load teh Lenet fusion network, traing network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`.
+Load the LeNet fusion network and train it with the `nn.SoftmaxCrossEntropyWithLogits` loss and the `nn.Momentum` optimizer.

 ```python
 # Define the network
@@ -133,7 +133,7 @@ After all the following we will get the loss value of each step as following:
 >>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
 ```

-To save your time, just run this command.
+Alternatively, just run this command instead:

 ```bash
 python train.py --data_path MNIST_Data --device_target Ascend
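
For reference, a minimal sketch of the wiring that sentence describes. The loss line matches the example scripts; the import paths, optimizer hyperparameters, and callback list are assumptions about the 0.x API, with `network`, `cfg`, `ds_train`, `time_cb`, and `ckpt_callback` coming from the surrounding tutorial code:

```python
import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore.train.model import Model

net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
model.train(cfg.epoch_size, ds_train, callbacks=[time_cb, ckpt_callback])
```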
@@ -165,17 +165,17 @@ Note that the resulting model is quantization aware but not quantized (e.g. the
 # define fusion network
 network = LeNet5Fusion(cfg.num_classes)

-# load aware quantizaiton network checkpoint
+# load quantization aware network checkpoint
 param_dict = load_checkpoint(args.ckpt_path)
 load_param_into_net(network, param_dict)

-# convert funsion netwrok to aware quantizaiton network
+# convert fusion network to quantization aware network
 network = quant.convert_quant_network(network)
 ```

 ### Load checkpoint

-after convert to quantization aware network, we can load the checkpoint file.
+After converting to a quantization aware network, we can load the checkpoint file.

 ```python
 config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
@@ -186,7 +186,7 @@ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

 ### Train quantization aware model

-To save your time, just run this command.
+Alternatively, just run this command instead:

 ```bash
 python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
@@ -210,7 +210,7 @@ Procedure of quantization aware model evaluation is different from normal. Becau
 # define fusion network
 network = LeNet5Fusion(cfg.num_classes)

-# load aware quantizaiton network checkpoint
+# load quantization aware network checkpoint
 param_dict = load_checkpoint(args.ckpt_path)
 load_param_into_net(network, param_dict)

@@ -218,10 +218,10 @@ load_param_into_net(network, param_dict)
 network = quant.convert_quant_network(network)
 ```

-To save your time, just run this command.
+Alternatively, just run this command instead:

 ```bash
-python eval.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
+python eval_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
 ```

 The top-1 accuracy will be displayed in the shell.
@@ -235,7 +235,7 @@ The top1 accuracy would display on shell.
 Here are some optional parameters:

 ```bash
---device_target {Ascend,GPU,CPU}
+--device_target {Ascend,GPU}
     device where the code will be implemented (default: Ascend)
 --data_path DATA_PATH
     path where the dataset is saved
@@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion

 parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
 parser.add_argument('--device_target', type=str, default="Ascend",
-                    choices=['Ascend', 'GPU', 'CPU'],
+                    choices=['Ascend', 'GPU'],
                     help='device where the code will be implemented (default: Ascend)')
 parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                     help='path where the dataset is saved')
@@ -32,7 +32,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion

 parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
 parser.add_argument('--device_target', type=str, default="Ascend",
-                    choices=['Ascend', 'GPU', 'CPU'],
+                    choices=['Ascend', 'GPU'],
                     help='device where the code will be implemented (default: Ascend)')
 parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                     help='path where the dataset is saved')
@@ -61,7 +61,7 @@ if __name__ == "__main__":
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

     # load quantization aware network checkpoint
-    param_dict = load_checkpoint(args.ckpt_path, model_type="quant")
+    param_dict = load_checkpoint(args.ckpt_path)
     load_param_into_net(network, param_dict)

     print("============== Starting Testing ==============")
@@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion

 parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
 parser.add_argument('--device_target', type=str, default="Ascend",
-                    choices=['Ascend', 'GPU', 'CPU'],
+                    choices=['Ascend', 'GPU'],
                     help='device where the code will be implemented (default: Ascend)')
 parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                     help='path where the dataset is saved')
@@ -56,8 +56,7 @@ if __name__ == "__main__":
     # call back and monitor
     time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
     config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
-                                   keep_checkpoint_max=cfg.keep_checkpoint_max,
-                                   model_type=network.type)
+                                   keep_checkpoint_max=cfg.keep_checkpoint_max)
     ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)

     # define model
@@ -33,7 +33,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion

 parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
 parser.add_argument('--device_target', type=str, default="Ascend",
-                    choices=['Ascend', 'GPU', 'CPU'],
+                    choices=['Ascend', 'GPU'],
                     help='device where the code will be implemented (default: Ascend)')
 parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                     help='path where the dataset is saved')
@@ -50,11 +50,13 @@ if __name__ == "__main__":

     # define fusion network
     network = LeNet5Fusion(cfg.num_classes)

+    # convert fusion network to quantization aware network
+    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)

     # load quantization aware network checkpoint
     param_dict = load_checkpoint(args.ckpt_path, network.type)
     load_param_into_net(network, param_dict)
-    # convert fusion network to quantization aware network
-    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)

     # define network loss
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
@@ -64,8 +66,7 @@ if __name__ == "__main__":
     # call back and monitor
     time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
     config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
-                                   keep_checkpoint_max=cfg.keep_checkpoint_max,
-                                   model_type="quant")
+                                   keep_checkpoint_max=cfg.keep_checkpoint_max)
     ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)

     # define model
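
The reordering in the hunk above appears to be the point of the fix: `convert_quant_network` creates new cells and parameters (fake-quant min/max, folded-BN weights), so a checkpoint saved from a quant network only matches the network's parameter dictionary after conversion. Condensed from the diff:

```python
network = LeNet5Fusion(cfg.num_classes)
# Convert first: after this call the network contains the quant parameters
# that the quantization aware checkpoint refers to.
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False,
                                      freeze_bn=10000)
param_dict = load_checkpoint(args.ckpt_path, network.type)
load_param_into_net(network, param_dict)
```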
@@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH
 export DEVICE_ID=0
 export RANK_ID=0
 export RANK_SIZE=1
-if [ -d "eval" ];
+if [ -d "../eval" ];
 then
     rm -rf ../eval
 fi
@@ -62,7 +62,7 @@ run_gpu()

 BASEPATH=$(cd "`dirname $0`" || exit; pwd)
 export PYTHONPATH=${BASEPATH}:$PYTHONPATH
-if [ -d "train" ];
+if [ -d "../train" ];
 then
     rm -rf ../train
 fi
@@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH
 export DEVICE_ID=0
 export RANK_ID=0
 export RANK_SIZE=1
-if [ -d "eval" ];
+if [ -d "../eval" ];
 then
     rm -rf ../eval
 fi
@@ -60,7 +60,7 @@ run_gpu()

 BASEPATH=$(cd "`dirname $0`" || exit; pwd)
 export PYTHONPATH=${BASEPATH}:$PYTHONPATH
-if [ -d "train" ];
+if [ -d "../train" ];
 then
     rm -rf ../train
 fi
@@ -17,7 +17,7 @@ def _conv_bn(in_channel,
                             out_channel,
                             kernel_size=ksize,
                             stride=stride,
-                            batchnorm=True)])
+                            has_bn=True)])


 class InvertedResidual(nn.Cell):
@@ -35,25 +35,25 @@ class InvertedResidual(nn.Cell):
                                3,
                                stride,
                                group=hidden_dim,
-                               batchnorm=True,
+                               has_bn=True,
                                activation='relu6'),
                 nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
-                               batchnorm=True)
+                               has_bn=True)
             ])
         else:
             self.conv = nn.SequentialCell([
                 nn.Conv2dBnAct(inp, hidden_dim, 1, 1,
-                               batchnorm=True,
+                               has_bn=True,
                                activation='relu6'),
                 nn.Conv2dBnAct(hidden_dim,
                                hidden_dim,
                                3,
                                stride,
                                group=hidden_dim,
-                               batchnorm=True,
+                               has_bn=True,
                                activation='relu6'),
                 nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
-                               batchnorm=True)
+                               has_bn=True)
             ])
         self.add = P.TensorAdd()
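
The model changes in the last few hunks are mechanical: the `batchnorm` keyword of `nn.Conv2dBnAct` is renamed to `has_bn`, so any user code constructing these fusion cells needs the same one-word change. For example, with illustrative channel sizes:

```python
import mindspore.nn as nn

# old: nn.Conv2dBnAct(3, 16, kernel_size=3, batchnorm=True, activation='relu6')
conv = nn.Conv2dBnAct(3, 16, kernel_size=3, has_bn=True, activation='relu6')
```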
@@ -42,7 +42,7 @@ class LeNet5(nn.Cell):
     def __init__(self, num_class=10):
         super(LeNet5, self).__init__()
         self.num_class = num_class
-        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6', pad_mode="valid")
+        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu6', pad_mode="valid")
         self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
         self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
         self.fc2 = nn.DenseBnAct(120, 84, activation='relu')