!2718 fix quantization aware training auto create graph bug

Merge pull request !2718 from chenzhongming/master
mindspore-ci-bot 2020-07-02 09:57:30 +08:00 committed by Gitee
commit f1a9a7ceb1
28 changed files with 322 additions and 142 deletions

View File

@ -22,7 +22,6 @@ message Checkpoint {
required TensorProto tensor = 2; required TensorProto tensor = 2;
} }
repeated Value value = 1; repeated Value value = 1;
required string model_type = 2;
} }

View File

@ -81,6 +81,7 @@ class Cell:
self.enable_hook = False self.enable_hook = False
self._bprop_debug = False self._bprop_debug = False
self._is_run = False self._is_run = False
self.cell_type = None
@property @property
def is_run(self): def is_run(self):
@ -140,6 +141,14 @@ class Cell:
for cell_name, cell in cells_name: for cell_name, cell in cells_name:
cell._param_prefix = cell_name cell._param_prefix = cell_name
def update_cell_type(self, cell_type):
"""
Update the current cell type, mainly used to identify whether this is a quantization aware training network.
After being invoked, the cell type is set to `cell_type`.
"""
self.cell_type = cell_type
@cell_init_args.setter @cell_init_args.setter
def cell_init_args(self, value): def cell_init_args(self, value):
if not isinstance(value, str): if not isinstance(value, str):
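For orientation, a minimal sketch of how the new `cell_type` tag might be used; `TinyNet` is illustrative and not part of this change.

```python
import mindspore.nn as nn

class TinyNet(nn.Cell):
    """Illustrative network; any nn.Cell subclass now carries a cell_type attribute."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc = nn.Dense(4, 2)

    def construct(self, x):
        return self.fc(x)

net = TinyNet()
net.update_cell_type("quant")      # tag the cell, e.g. after quantization aware conversion
if net.cell_type == "quant":
    print("network is tagged as a quantization aware training network")
```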

View File

@ -17,11 +17,12 @@ from mindspore import log as logger
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore._checkparam import ParamValidator as validator, Rel
from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative
from mindspore._extends import cell_attr_register from mindspore._extends import cell_attr_register
from ..cell import Cell from ..cell import Cell
__all__ = ['Conv2d', 'Conv2dTranspose'] __all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d']
class _Conv(Cell): class _Conv(Cell):
""" """
@ -397,3 +398,150 @@ class Conv2dTranspose(_Conv):
self.weight, self.weight,
self.bias) self.bias)
return s return s
class DepthwiseConv2d(Cell):
r"""
2D depthwise convolution layer.
Applies a 2D depthwise convolution over an input tensor which is typically of shape
:math:`(N, C_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel number.
For each batch of shape :math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as:
.. math::
out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,
where :math:`ccor` is cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to :math:`i`-th channel of the :math:`j`-th
filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
of kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and
:math:`\text{ks_w}` are height and width of the convolution kernel. The full kernel has shape
:math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group number
to split the input in the channel dimension.
If the 'pad_mode' is set to be "valid", the output height and width will be
:math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
(\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
:math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
(\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.
The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition
<http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
Args:
in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the height
and width of the 2D convolution window. A single int means the value is for both the height and width of
the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
width of the kernel.
stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
the height and width of movement are both strides, or a tuple of two int numbers that
represent height and width of movement respectively. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are
"same", "valid", "pad". Default: "same".
- same: Adopts the way of completion. Output height and width will be the same as the input.
Total number of padding will be calculated for horizontal and vertical
direction and evenly distributed to top and bottom, left and right if possible. Otherwise, the
last extra padding will be done from the bottom and the right side. If this mode is set, `padding`
must be 0.
- valid: Adopts the way of discarding. The possibly largest height and width of output will be returned
without padding. Extra pixels will be discarded. If this mode is set, `padding`
must be 0.
- pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
Tensor borders. `padding` should be greater than or equal to 0.
padding (int): Implicit paddings on both sides of the input. Default: 0.
dilation (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the dilation rate
to use for dilated convolution. If set to be :math:`k > 1`, there will
be :math:`k - 1` pixels skipped for each sampling location. Its value should
be greater or equal to 1 and bounded by the height and width of the
input. Default: 1.
group (int): Split filter into groups. For this layer, `group` should be equal to `in_channels`
and `out_channels`. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
Initializer for more details. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
Initializer and string are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> net = nn.DepthwiseConv2d(120, 120, 4, group=120, has_bias=False, weight_init='normal')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> net(input).shape
(1, 120, 1024, 640)
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros'):
super(DepthwiseConv2d, self).__init__()
self.kernel_size = twice(kernel_size)
self.stride = twice(stride)
self.dilation = twice(dilation)
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
validator.check_integer('group', group, in_channels, Rel.EQ)
validator.check_integer('group', group, out_channels, Rel.EQ)
validator.check_integer('group', group, 1, Rel.GE)
self.pad_mode = pad_mode
self.padding = padding
self.group = group
self.has_bias = has_bias
self.weight_init = weight_init
self.bias_init = bias_init
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation)
self.bias_add = P.BiasAdd()
weight_shape = [1, in_channels, *self.kernel_size]
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
if check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
else:
if bias_init != 'zeros':
logger.warning("value of `has_bias` is False, value of `bias_init` will be ignore.")
self.bias = None
def construct(self, x):
out = self.conv(x, self.weight)
if self.has_bias:
out = self.bias_add(out, self.bias)
return out
def extend_repr(self):
s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
'pad_mode={}, padding={}, dilation={}, group={}, ' \
'has_bias={}, weight_init={}, bias_init={}'.format(
self.in_channels, self.out_channels, self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation, self.group,
self.has_bias, self.weight_init, self.bias_init)
if self.has_bias:
s += ', bias={}'.format(self.bias)
return s
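As a quick check of the `pad_mode='valid'` output-size formula quoted in the docstring above, a small standalone sketch (plain Python; the helper name is illustrative):

```python
import math

def valid_output_size(h_in, w_in, ks_h, ks_w, stride=1, dilation=1, padding=0):
    """Output height/width for pad_mode='valid', following the formula in the docstring."""
    h_out = math.floor(1 + (h_in + 2 * padding - ks_h - (ks_h - 1) * (dilation - 1)) / stride)
    w_out = math.floor(1 + (w_in + 2 * padding - ks_w - (ks_w - 1) * (dilation - 1)) / stride)
    return h_out, w_out

# A 1024x640 input with a 4x4 kernel, stride 1, dilation 1 and no padding gives (1021, 637).
print(valid_output_size(1024, 640, 4, 4))
```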

View File

@ -16,6 +16,7 @@
from functools import partial from functools import partial
import numpy as np import numpy as np
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
@ -23,10 +24,9 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore._checkparam import check_int_positive, check_bool, twice from mindspore._checkparam import check_int_positive, check_bool, twice
from mindspore._checkparam import Validator as validator, Rel from mindspore._checkparam import Rel
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
import mindspore.context as context import mindspore.context as context
from .normalization import BatchNorm2d from .normalization import BatchNorm2d
from .activation import get_activation from .activation import get_activation
from ..cell import Cell from ..cell import Cell
@ -82,7 +82,7 @@ class Conv2dBnAct(Cell):
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
Initializer and string are the same as 'weight_init'. Refer to the values of Initializer and string are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'. Initializer for more details. Default: 'zeros'.
batchnorm (bool): Specifies whether to use batchnorm. Default: None. has_bn (bool): Specifies whether to use batchnorm. Default: False.
activation (string): Specifies activation type. The optional values are as following: activation (string): Specifies activation type. The optional values are as following:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
@ -94,7 +94,7 @@ class Conv2dBnAct(Cell):
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples: Examples:
>>> net = Conv2dBnAct(120, 240, 4, batchnorm=True, activation='ReLU') >>> net = Conv2dBnAct(120, 240, 4, has_bn=True, activation='ReLU')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> net(input).shape >>> net(input).shape
(1, 240, 1024, 640) (1, 240, 1024, 640)
@ -112,28 +112,39 @@ class Conv2dBnAct(Cell):
has_bias=False, has_bias=False,
weight_init='normal', weight_init='normal',
bias_init='zeros', bias_init='zeros',
batchnorm=None, has_bn=False,
activation=None): activation=None):
super(Conv2dBnAct, self).__init__() super(Conv2dBnAct, self).__init__()
self.conv = conv.Conv2d(
in_channels, if context.get_context('device_target') == "Ascend" and group > 1:
out_channels, self.conv = conv.DepthwiseConv2d(in_channels,
kernel_size, out_channels,
stride, kernel_size=kernel_size,
pad_mode, stride=stride,
padding, pad_mode=pad_mode,
dilation, padding=padding,
group, dilation=dilation,
has_bias, group=group,
weight_init, has_bias=has_bias,
bias_init) weight_init=weight_init,
self.has_bn = batchnorm is not None bias_init=bias_init)
else:
self.conv = conv.Conv2d(in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
pad_mode=pad_mode,
padding=padding,
dilation=dilation,
group=group,
has_bias=has_bias,
weight_init=weight_init,
bias_init=bias_init)
self.has_bn = validator.check_bool("has_bn", has_bn)
self.has_act = activation is not None self.has_act = activation is not None
self.batchnorm = batchnorm if has_bn:
if batchnorm is True:
self.batchnorm = BatchNorm2d(out_channels) self.batchnorm = BatchNorm2d(out_channels)
elif batchnorm is not None:
validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
self.activation = get_activation(activation) self.activation = get_activation(activation)
def construct(self, x): def construct(self, x):
@ -160,7 +171,7 @@ class DenseBnAct(Cell):
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
batchnorm (bool): Specifies whether to use batchnorm. Default: None. has_bn (bool): Specifies whether to use batchnorm. Default: False.
activation (string): Specifies activation type. The optional values are as following: activation (string): Specifies activation type. The optional values are as following:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
@ -183,7 +194,7 @@ class DenseBnAct(Cell):
weight_init='normal', weight_init='normal',
bias_init='zeros', bias_init='zeros',
has_bias=True, has_bias=True,
batchnorm=None, has_bn=False,
activation=None): activation=None):
super(DenseBnAct, self).__init__() super(DenseBnAct, self).__init__()
self.dense = basic.Dense( self.dense = basic.Dense(
@ -192,12 +203,10 @@ class DenseBnAct(Cell):
weight_init, weight_init,
bias_init, bias_init,
has_bias) has_bias)
self.has_bn = batchnorm is not None self.has_bn = validator.check_bool("has_bn", has_bn)
self.has_act = activation is not None self.has_act = activation is not None
if batchnorm is True: if has_bn:
self.batchnorm = BatchNorm2d(out_channels) self.batchnorm = BatchNorm2d(out_channels)
elif batchnorm is not None:
validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
self.activation = get_activation(activation) self.activation = get_activation(activation)
def construct(self, x): def construct(self, x):
@ -312,6 +321,10 @@ class FakeQuantWithMinMax(Cell):
quant_delay=0): quant_delay=0):
"""init FakeQuantWithMinMax layer""" """init FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMax, self).__init__() super(FakeQuantWithMinMax, self).__init__()
validator.check_type("min_init", min_init, [int, float])
validator.check_type("max_init", max_init, [int, float])
validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
self.min_init = min_init self.min_init = min_init
self.max_init = max_init self.max_init = max_init
self.num_bits = num_bits self.num_bits = num_bits
@ -1183,12 +1196,13 @@ class QuantBlock(Cell):
self.has_bias = bias is None self.has_bias = bias is None
self.activation = activation self.activation = activation
self.has_act = activation is None self.has_act = activation is None
self.bias_add = P.BiasAdd()
def construct(self, x): def construct(self, x):
x = self.quant(x) x = self.quant(x)
x = self.core_op(x, self.weight) x = self.core_op(x, self.weight)
if self.has_bias: if self.has_bias:
output = self.bias_add(output, self.bias) x = self.bias_add(x, self.bias)
if self.has_act: if self.has_act:
x = self.activation(x) x = self.activation(x)
x = self.dequant(x, self.dequant_scale) x = self.dequant(x, self.dequant_scale)

View File

@ -30,7 +30,6 @@ batchnorm_fold2_op_info = TBERegOp("BatchNormFold2_D") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("batchnorm_fold2") \ .kernel_name("batchnorm_fold2") \
.partial_flag(True) \ .partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "x", None, "required", None) \ .input(0, "x", None, "required", None) \
.input(1, "beta", None, "required", None) \ .input(1, "beta", None, "required", None) \
.input(2, "gamma", None, "required", None) \ .input(2, "gamma", None, "required", None) \

View File

@ -30,7 +30,6 @@ batchnorm_fold2_grad_op_info = TBERegOp("BatchNormFold2GradD") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("batchnorm_fold2_grad") \ .kernel_name("batchnorm_fold2_grad") \
.partial_flag(True) \ .partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "dout", None, "required", None) \ .input(0, "dout", None, "required", None) \
.input(1, "dout_reduce", None, "required", None) \ .input(1, "dout_reduce", None, "required", None) \
.input(2, "dout_x_reduce", None, "required", None) \ .input(2, "dout_x_reduce", None, "required", None) \

View File

@ -31,7 +31,6 @@ batchnorm_fold2_grad_reduce_op_info = TBERegOp("BatchNormFold2GradReduce") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("batchnorm_fold2_grad_reduce") \ .kernel_name("batchnorm_fold2_grad_reduce") \
.partial_flag(True) \ .partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "dout", None, "required", None) \ .input(0, "dout", None, "required", None) \
.input(1, "x", None, "required", None) \ .input(1, "x", None, "required", None) \
.output(0, "dout_reduce", True, "required", "all") \ .output(0, "dout_reduce", True, "required", "all") \

View File

@ -30,7 +30,6 @@ correction_mul_op_info = TBERegOp("CorrectionMul") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("correction_mul") \ .kernel_name("correction_mul") \
.partial_flag(True) \ .partial_flag(True) \
.op_pattern("formatAgnostic") \
.attr("channel_axis", "optional", "int", "all") \ .attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \ .input(0, "x", None, "required", None) \
.input(1, "batch_std", None, "required", None) \ .input(1, "batch_std", None, "required", None) \

View File

@ -30,7 +30,6 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("correction_mul_grad") \ .kernel_name("correction_mul_grad") \
.partial_flag(True) \ .partial_flag(True) \
.op_pattern("formatAgnostic") \
.attr("channel_axis", "optional", "int", "all") \ .attr("channel_axis", "optional", "int", "all") \
.input(0, "dout", None, "required", None) \ .input(0, "dout", None, "required", None) \
.input(1, "x", None, "required", None) \ .input(1, "x", None, "required", None) \
@ -128,7 +127,6 @@ correction_mul_grad_reduce_op_info = TBERegOp("CorrectionMulGradReduce") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("correction_mul_grad_reduce") \ .kernel_name("correction_mul_grad_reduce") \
.partial_flag(True) \ .partial_flag(True) \
.op_pattern("formatAgnostic") \
.attr("channel_axis", "optional", "int", "all") \ .attr("channel_axis", "optional", "int", "all") \
.input(0, "dout", None, "required", None) \ .input(0, "dout", None, "required", None) \
.output(0, "d_batch_std", True, "required", "all") \ .output(0, "d_batch_std", True, "required", "all") \

View File

@ -99,11 +99,15 @@ def fake_quant_perchannel(x, min_val, max_val, y,
min_dtype = min_val.get("dtype") min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape") max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype") max_dtype = max_val.get("dtype")
# for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ needs to change to 1.
if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
channel_axis_ = 1
else:
channel_axis_ = channel_axis
util.check_kernel_name(kernel_name) util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape) util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis]) util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis]) util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
util.check_tensor_shape_size(x_shape) util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape) util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape) util.check_tensor_shape_size(max_shape)
@ -126,8 +130,8 @@ def fake_quant_perchannel(x, min_val, max_val, y,
quant_min = quant_min + 1 quant_min = quant_min + 1
shape_c = [1] * len(x_shape) shape_c = [1] * len(x_shape)
shape_c[channel_axis] = min_val.get("ori_shape")[0] shape_c[channel_axis_] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis == 1: if x_format == "NC1HWC0" and channel_axis_ == 1:
shape_c = min_val.get("shape") shape_c = min_val.get("shape")
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype) input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype) min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
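To make the remapping above easier to follow, a minimal sketch of the same decision in isolation (the helper name and shapes are illustrative):

```python
def effective_channel_axis(x_shape, min_shape, channel_axis):
    """Per-channel axis actually used by the kernel.

    For Dense weight quantization the 2-D weight [co, ci] is reshaped to 4-D
    [1, co, ci, 1], so an original channel_axis of 0 ends up on axis 1.
    """
    if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
        return 1
    return channel_axis

print(effective_channel_axis([1, 84, 120, 1], [84], 0))  # reshaped Dense weight -> 1
print(effective_channel_axis([84, 120], [84], 0))        # plain 2-D weight      -> 0
```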

View File

@ -124,11 +124,15 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
min_dtype = min_val.get("dtype") min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape") max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype") max_dtype = max_val.get("dtype")
# for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ needs to change to 1.
if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
channel_axis_ = 1
else:
channel_axis_ = channel_axis
util.check_kernel_name(kernel_name) util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape) util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis]) util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis]) util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
util.check_tensor_shape_size(x_shape) util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape) util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape) util.check_tensor_shape_size(max_shape)
@ -151,8 +155,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
quant_min = quant_min + 1 quant_min = quant_min + 1
shape_c = [1] * len(x_shape) shape_c = [1] * len(x_shape)
shape_c[channel_axis] = min_val.get("ori_shape")[0] shape_c[channel_axis_] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis == 1: if x_format == "NC1HWC0" and channel_axis_ == 1:
shape_c = min_val.get("shape") shape_c = min_val.get("shape")
dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype) dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype) input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)

View File

@ -88,11 +88,15 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
min_dtype = min_val.get("dtype") min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape") max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype") max_dtype = max_val.get("dtype")
# for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ needs to change to 1.
if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
channel_axis_ = 1
else:
channel_axis_ = channel_axis
util.check_kernel_name(kernel_name) util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape) util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis]) util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis_])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis]) util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis_])
util.check_tensor_shape_size(x_shape) util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape) util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape) util.check_tensor_shape_size(max_shape)
@ -105,7 +109,7 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
util.check_dtype_rule(min_dtype, check_list) util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list) util.check_dtype_rule(max_dtype, check_list)
if channel_axis == 0: if channel_axis_ == 0:
shape_c = min_val.get("ori_shape") shape_c = min_val.get("ori_shape")
else: else:
shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]] shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
@ -113,7 +117,7 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype) min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype) max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res_list = minmax_update_perchannel_compute(input_data, min_data, max_data, res_list = minmax_update_perchannel_compute(input_data, min_data, max_data,
ema, ema_decay, channel_axis) ema, ema_decay, channel_axis_)
with tvm.target.cce(): with tvm.target.cce():
sch = generic.auto_schedule(res_list) sch = generic.auto_schedule(res_list)

View File

@ -106,7 +106,7 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
Args: Args:
ema (bool): Use EMA algorithm update value min and max. Default: False. ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
channel_axis (int): Channel asis for per channel compute. Default: 1. channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.
Inputs: Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor. - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
@ -123,11 +123,13 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max) >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
""" """
support_quant_bit = [4, 7, 8] support_quant_bit = [4, 7, 8]
ascend_support_x_rank = [2, 4]
@prim_attr_register @prim_attr_register
def __init__(self, ema=False, ema_decay=0.999, channel_axis=1): def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend""" """init FakeQuantPerChannelUpdate OP for Ascend"""
if context.get_context('device_target') == "Ascend": self.is_ascend = context.get_context('device_target') == "Ascend"
if self.is_ascend:
from mindspore.ops._op_impl._custom_op import minmax_update_perchannel from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
if ema and not ema_decay: if ema and not ema_decay:
raise ValueError( raise ValueError(
@ -136,13 +138,18 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
self.ema = validator.check_value_type('ema', ema, (bool,), self.name) self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range( self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.channel_axis = validator.check_integer( if self.is_ascend:
'channel axis', channel_axis, 0, Rel.GE, self.name) self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
else:
self.channel_axis = validator.check_integer('channel_axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names( self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up']) inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape): def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
if not self.is_ascend:
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape", validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name) max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len( validator.check_integer("min shape", len(
@ -221,8 +228,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer( self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name) 'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type( self.quant_delay = validator.check_integer(
'quant_delay', quant_delay, (int,), self.name) 'quant_delay', quant_delay, 0, Rel.GE, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'], self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['out']) outputs=['out'])
@ -314,6 +321,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True. training (bool): Training the network or not. Default: True.
channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.
Inputs: Inputs:
- **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor. - **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor.
@ -331,6 +339,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
>>> result = fake_quant(input_x, _min, _max) >>> result = fake_quant(input_x, _min, _max)
""" """
support_quant_bit = [4, 7, 8] support_quant_bit = [4, 7, 8]
ascend_support_x_rank = [2, 4]
@prim_attr_register @prim_attr_register
def __init__(self, def __init__(self,
@ -343,7 +352,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
training=True, training=True,
channel_axis=1): channel_axis=1):
"""init FakeQuantPerChannel OP""" """init FakeQuantPerChannel OP"""
if context.get_context('device_target') == "Ascend": self.is_ascend = context.get_context('device_target') == "Ascend"
if self.is_ascend:
from mindspore.ops._op_impl._custom_op import fake_quant_perchannel from mindspore.ops._op_impl._custom_op import fake_quant_perchannel
if num_bits not in self.support_quant_bit: if num_bits not in self.support_quant_bit:
raise ValueError( raise ValueError(
@ -363,14 +373,19 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer( self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name) 'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type( self.quant_delay = validator.check_integer(
'quant_delay', quant_delay, (int,), self.name) 'quant_delay', quant_delay, 0, Rel.GE, self.name)
self.channel_axis = validator.check_integer( if self.is_ascend:
'channel_axis', channel_axis, 0, Rel.GE, self.name) self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
else:
self.channel_axis = validator.check_integer('channel_axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
def infer_shape(self, x_shape, min_shape, max_shape): def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
if not self.is_ascend:
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer( validator.check_integer(
"min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) "min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)

View File

@ -21,7 +21,7 @@ import time
import mindspore.context as context import mindspore.context as context
from mindspore import log as logger from mindspore import log as logger
from mindspore._checkparam import check_bool, check_string, check_int_non_negative from mindspore._checkparam import check_bool, check_int_non_negative
from mindspore.train._utils import _make_directory from mindspore.train._utils import _make_directory
from mindspore.train.serialization import _exec_save_checkpoint, _save_graph from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
from ._callback import Callback, set_cur_net from ._callback import Callback, set_cur_net
@ -86,7 +86,6 @@ class CheckpointConfig:
Can't be used with keep_checkpoint_max at the same time. Can't be used with keep_checkpoint_max at the same time.
integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True. integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
Integrated save function is only supported in automatic parallel scene, not supported in manual parallel. Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
model_type (str): Model type in `normal`, `fusion` or `quant`. Default: "normal".
Raises: Raises:
ValueError: If the input_param is None or 0. ValueError: If the input_param is None or 0.
@ -101,8 +100,7 @@ class CheckpointConfig:
save_checkpoint_seconds=0, save_checkpoint_seconds=0,
keep_checkpoint_max=5, keep_checkpoint_max=5,
keep_checkpoint_per_n_minutes=0, keep_checkpoint_per_n_minutes=0,
integrated_save=True, integrated_save=True):
model_type="normal"):
if not save_checkpoint_steps and not save_checkpoint_seconds and \ if not save_checkpoint_steps and not save_checkpoint_seconds and \
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
@ -116,8 +114,6 @@ class CheckpointConfig:
keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max) keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
if keep_checkpoint_per_n_minutes: if keep_checkpoint_per_n_minutes:
keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes) keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
if model_type:
model_type = check_string(model_type, ["normal", "fusion", "quant"])
self._save_checkpoint_steps = save_checkpoint_steps self._save_checkpoint_steps = save_checkpoint_steps
self._save_checkpoint_seconds = save_checkpoint_seconds self._save_checkpoint_seconds = save_checkpoint_seconds
@ -132,7 +128,6 @@ class CheckpointConfig:
if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0: if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
self._keep_checkpoint_max = 1 self._keep_checkpoint_max = 1
self._model_type = model_type
self._integrated_save = check_bool(integrated_save) self._integrated_save = check_bool(integrated_save)
@property @property
@ -160,18 +155,12 @@ class CheckpointConfig:
"""Get the value of _integrated_save.""" """Get the value of _integrated_save."""
return self._integrated_save return self._integrated_save
@property
def model_type(self):
"""Get the value of model_type."""
return self._model_type
def get_checkpoint_policy(self): def get_checkpoint_policy(self):
"""Get the policy of checkpoint.""" """Get the policy of checkpoint."""
checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps, checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
'save_checkpoint_seconds': self._save_checkpoint_seconds, 'save_checkpoint_seconds': self._save_checkpoint_seconds,
'keep_checkpoint_max': self._keep_checkpoint_max, 'keep_checkpoint_max': self._keep_checkpoint_max,
'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes, 'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}
'model_type': self._model_type}
return checkpoint_policy return checkpoint_policy
@ -236,7 +225,7 @@ class ModelCheckpoint(Callback):
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta') graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
_save_graph(cb_params.train_network, graph_file_name) _save_graph(cb_params.train_network, graph_file_name)
self._graph_saved = True self._graph_saved = True
self._save_ckpt(cb_params, self._config.model_type) self._save_ckpt(cb_params)
def end(self, run_context): def end(self, run_context):
""" """
@ -247,7 +236,7 @@ class ModelCheckpoint(Callback):
""" """
cb_params = run_context.original_args() cb_params = run_context.original_args()
_to_save_last_ckpt = True _to_save_last_ckpt = True
self._save_ckpt(cb_params, self._config.model_type, _to_save_last_ckpt) self._save_ckpt(cb_params, _to_save_last_ckpt)
from mindspore.parallel._cell_wrapper import destroy_allgather_cell from mindspore.parallel._cell_wrapper import destroy_allgather_cell
destroy_allgather_cell() destroy_allgather_cell()
@ -266,7 +255,7 @@ class ModelCheckpoint(Callback):
return False return False
def _save_ckpt(self, cb_params, model_type, force_to_save=False): def _save_ckpt(self, cb_params, force_to_save=False):
"""Save checkpoint files.""" """Save checkpoint files."""
if cb_params.cur_step_num == self._last_triggered_step: if cb_params.cur_step_num == self._last_triggered_step:
return return
@ -302,7 +291,7 @@ class ModelCheckpoint(Callback):
set_cur_net(cb_params.train_network) set_cur_net(cb_params.train_network)
cb_params.train_network.exec_checkpoint_graph() cb_params.train_network.exec_checkpoint_graph()
_exec_save_checkpoint(cb_params.train_network, gen_file, model_type, self._config.integrated_save) _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
if os.path.exists(gen_file): if os.path.exists(gen_file):
shutil.move(gen_file, cur_file) shutil.move(gen_file, cur_file)

View File

@ -86,7 +86,7 @@ class LossMonitor(Callback):
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
"loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format( "loss: [{:5.4f}], avg los: [{:5.4f}], time: [{:5.4f}]".format(
cb_params.cur_epoch_num, cb_params.epoch_num, cb_params.cur_epoch_num, cb_params.epoch_num,
cur_step_in_epoch, int(cb_params.batch_num), cur_step_in_epoch, int(cb_params.batch_num),
step_loss, np.mean(self.losses), step_loss, np.mean(self.losses),

View File

@ -42,15 +42,14 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
class _AddFakeQuantInput(nn.Cell): class _AddFakeQuantInput(nn.Cell):
""" """
Add FakeQuant at the input and output of the network. Only supports the one-input, one-output case. Add FakeQuant OP at the input of the network. Only supports the one-input case.
""" """
def __init__(self, network, quant_delay=0): def __init__(self, network, quant_delay=0):
super(_AddFakeQuantInput, self).__init__(auto_prefix=False) super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
self.fake_quant_input = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
self.fake_quant_input.update_parameters_name('fake_quant_input.')
self.network = network self.network = network
self.fake_quant_input = quant.FakeQuantWithMinMax(
min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
self.fake_quant_input.update_parameters_name('fake_quant_input')
def construct(self, data): def construct(self, data):
data = self.fake_quant_input(data) data = self.fake_quant_input(data)
@ -60,7 +59,7 @@ class _AddFakeQuantInput(nn.Cell):
class _AddFakeQuantAfterSubCell(nn.Cell): class _AddFakeQuantAfterSubCell(nn.Cell):
""" """
Add FakeQuant after the sub Cell. Add FakeQuant OP after the sub Cell.
""" """
def __init__(self, subcell, **kwargs): def __init__(self, subcell, **kwargs):
@ -115,11 +114,12 @@ class ConvertToQuantNetwork:
self.network.update_cell_prefix() self.network.update_cell_prefix()
network = self._convert_subcells2quant(self.network) network = self._convert_subcells2quant(self.network)
network = _AddFakeQuantInput(network) network = _AddFakeQuantInput(network)
self.network.update_cell_type("quant")
return network return network
def _convert_subcells2quant(self, network): def _convert_subcells2quant(self, network):
""" """
convet sub cell to quant cell convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell
""" """
cells = network.name_cells() cells = network.name_cells()
change = False change = False
@ -138,13 +138,13 @@ class ConvertToQuantNetwork:
if isinstance(network, nn.SequentialCell) and change: if isinstance(network, nn.SequentialCell) and change:
network.cell_list = list(network.cells()) network.cell_list = list(network.cells())
# tensoradd to tensoradd quant # add FakeQuant OP after OP in the white list
add_list = [] add_list = []
for name in network.__dict__: for name in network.__dict__:
if name[0] == '_': if name[0] == '_':
continue continue
attr = network.__dict__[name] attr = network.__dict__[name]
if isinstance(attr, ops.Primitive) and attr.name in ConvertToQuantNetwork.__quant_op_name__: if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
add_list.append((name, attr)) add_list.append((name, attr))
for name, prim_op in add_list: for name, prim_op in add_list:
prefix = name prefix = name
@ -164,11 +164,11 @@ class ConvertToQuantNetwork:
def _convert_conv(self, subcell): def _convert_conv(self, subcell):
""" """
convet conv cell to quant cell convert Conv2d cell to quant cell
""" """
conv_inner = subcell.conv conv_inner = subcell.conv
bn_inner = subcell.batchnorm if subcell.has_bn and self.bn_fold:
if subcell.batchnorm is not None and self.bn_fold: bn_inner = subcell.batchnorm
conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels, conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
conv_inner.out_channels, conv_inner.out_channels,
kernel_size=conv_inner.kernel_size, kernel_size=conv_inner.kernel_size,
@ -178,7 +178,7 @@ class ConvertToQuantNetwork:
dilation=conv_inner.dilation, dilation=conv_inner.dilation,
group=conv_inner.group, group=conv_inner.group,
eps=bn_inner.eps, eps=bn_inner.eps,
momentum=bn_inner.momentum, momentum=1 - bn_inner.momentum,
quant_delay=self.weight_qdelay, quant_delay=self.weight_qdelay,
freeze_bn=self.freeze_bn, freeze_bn=self.freeze_bn,
per_channel=self.weight_channel, per_channel=self.weight_channel,
@ -186,6 +186,11 @@ class ConvertToQuantNetwork:
fake=True, fake=True,
symmetric=self.weight_symmetric, symmetric=self.weight_symmetric,
narrow_range=self.weight_range) narrow_range=self.weight_range)
# change original network BatchNorm OP parameters to quant network
conv_inner.gamma = subcell.batchnorm.gamma
conv_inner.beta = subcell.batchnorm.beta
conv_inner.moving_mean = subcell.batchnorm.moving_mean
conv_inner.moving_variance = subcell.batchnorm.moving_variance
del subcell.batchnorm del subcell.batchnorm
subcell.batchnorm = None subcell.batchnorm = None
subcell.has_bn = False subcell.has_bn = False
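A short note on the `momentum=1 - bn_inner.momentum` change above: the two batch-norm implementations appear to weight the running statistics with opposite conventions, which is presumably why the value is flipped when the parameters are carried over. A minimal numeric sketch of why flipping the momentum preserves the update:

```python
def update_a(running, batch_stat, m):
    # convention A: momentum weights the running statistic
    return m * running + (1 - m) * batch_stat

def update_b(running, batch_stat, m):
    # convention B: momentum weights the new batch statistic
    return (1 - m) * running + m * batch_stat

# Flipping the momentum converts between the two conventions without changing the result.
assert update_a(0.5, 2.0, 0.9) == update_b(0.5, 2.0, 1 - 0.9)
```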
@ -204,6 +209,10 @@ class ConvertToQuantNetwork:
num_bits=self.weight_bits, num_bits=self.weight_bits,
symmetric=self.weight_symmetric, symmetric=self.weight_symmetric,
narrow_range=self.weight_range) narrow_range=self.weight_range)
# change original network Conv2D OP parameters to quant network
conv_inner.weight = subcell.conv.weight
if subcell.conv.has_bias:
conv_inner.bias = subcell.conv.bias
subcell.conv = conv_inner subcell.conv = conv_inner
if subcell.has_act and subcell.activation is not None: if subcell.has_act and subcell.activation is not None:
subcell.activation = self._convert_activation(subcell.activation) subcell.activation = self._convert_activation(subcell.activation)
@ -230,6 +239,10 @@ class ConvertToQuantNetwork:
per_channel=self.weight_channel, per_channel=self.weight_channel,
symmetric=self.weight_symmetric, symmetric=self.weight_symmetric,
narrow_range=self.weight_range) narrow_range=self.weight_range)
# change original network Dense OP parameters to quant network
dense_inner.weight = subcell.dense.weight
if subcell.dense.has_bias:
dense_inner.bias = subcell.dense.bias
subcell.dense = dense_inner subcell.dense = dense_inner
if subcell.has_act and subcell.activation is not None: if subcell.has_act and subcell.activation is not None:
subcell.activation = self._convert_activation(subcell.activation) subcell.activation = self._convert_activation(subcell.activation)
@ -247,12 +260,12 @@ class ConvertToQuantNetwork:
act_class = activation.__class__ act_class = activation.__class__
if act_class not in _ACTIVATION_MAP: if act_class not in _ACTIVATION_MAP:
raise ValueError( raise ValueError(
"Unsupported activation in auto Quant: ", act_class) "Unsupported activation in auto quant: ", act_class)
return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, return _ACTIVATION_MAP[act_class](num_bits=self.act_bits,
quant_delay=self.act_qdelay, quant_delay=self.act_qdelay,
per_channel=self.act_channel, per_channel=self.act_channel,
symmetric=self.weight_symmetric, symmetric=self.act_symmetric,
narrow_range=self.weight_range) narrow_range=self.act_range)
class ExportQuantNetworkDeploy: class ExportQuantNetworkDeploy:

View File

@ -40,8 +40,6 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64, "Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
ModelType = ["normal", "fusion", "quant"]
def _special_process_par(par, new_par): def _special_process_par(par, new_par):
""" """
@ -103,7 +101,7 @@ def _update_param(param, new_param):
param.set_parameter_data(type(param.data)(new_param.data)) param.set_parameter_data(type(param.data)(new_param.data))
def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"): def save_checkpoint(parameter_list, ckpt_file_name):
""" """
Saves checkpoint info to a specified file. Saves checkpoint info to a specified file.
@ -111,14 +109,12 @@ def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
parameter_list (list): Parameters list, each element is a dict parameter_list (list): Parameters list, each element is a dict
like {"name":xx, "type":xx, "shape":xx, "data":xx}. like {"name":xx, "type":xx, "shape":xx, "data":xx}.
ckpt_file_name (str): Checkpoint file name. ckpt_file_name (str): Checkpoint file name.
model_type (str): The name of model type. Default: "normal".
Raises: Raises:
RuntimeError: Failed to save the Checkpoint file. RuntimeError: Failed to save the Checkpoint file.
""" """
logger.info("Execute save checkpoint process.") logger.info("Execute save checkpoint process.")
checkpoint_list = Checkpoint() checkpoint_list = Checkpoint()
checkpoint_list.model_type = model_type
try: try:
for param in parameter_list: for param in parameter_list:
@ -147,13 +143,12 @@ def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
logger.info("Save checkpoint process finish.") logger.info("Save checkpoint process finish.")
def load_checkpoint(ckpt_file_name, model_type="normal", net=None): def load_checkpoint(ckpt_file_name, net=None):
""" """
Loads checkpoint info from a specified file. Loads checkpoint info from a specified file.
Args: Args:
ckpt_file_name (str): Checkpoint file name. ckpt_file_name (str): Checkpoint file name.
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
net (Cell): Cell network. Default: None net (Cell): Cell network. Default: None
Returns: Returns:
@ -165,9 +160,6 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
if not isinstance(ckpt_file_name, str): if not isinstance(ckpt_file_name, str):
raise ValueError("The ckpt_file_name must be string.") raise ValueError("The ckpt_file_name must be string.")
if model_type not in ModelType:
raise ValueError(f"The model_type is not in {ModelType}.")
if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt": if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
raise ValueError("Please input the correct checkpoint file name.") raise ValueError("Please input the correct checkpoint file name.")
@ -186,10 +178,6 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
raise ValueError(e.__str__()) raise ValueError(e.__str__())
parameter_dict = {} parameter_dict = {}
if checkpoint_list.model_type:
if model_type != checkpoint_list.model_type:
raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
checkpoint_list.model_type, model_type))
try: try:
for element in checkpoint_list.value: for element in checkpoint_list.value:
data = element.tensor.tensor_content data = element.tensor.tensor_content
@ -314,14 +302,13 @@ def _save_graph(network, file_name):
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True): def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
""" """
Saves checkpoint for 'ms' backend. Saves checkpoint for 'ms' backend.
Args: Args:
train_network (Network): The train network for training. train_network (Network): The train network for training.
ckpt_file_name (str): The name of checkpoint file. ckpt_file_name (str): The name of checkpoint file.
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
integrated_save (bool): Whether to integrated save in automatic model parallel scene. integrated_save (bool): Whether to integrated save in automatic model parallel scene.
""" """
@ -346,7 +333,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", in
each_param["data"] = param_data each_param["data"] = param_data
param_list.append(each_param) param_list.append(each_param)
save_checkpoint(param_list, ckpt_file_name, model_type) save_checkpoint(param_list, ckpt_file_name)
def _get_merged_param_data(net, param_name, param_data): def _get_merged_param_data(net, param_name, param_data):
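For completeness, a minimal sketch of the simplified checkpoint API after `model_type` is removed; the checkpoint path and the placeholder network are illustrative only:

```python
import mindspore.nn as nn
from mindspore.train.serialization import load_checkpoint, load_param_into_net

network = nn.Dense(16, 10)                                    # placeholder network
param_dict = load_checkpoint("checkpoint_lenet-1_1875.ckpt")  # placeholder path; no model_type argument anymore
load_param_into_net(network, param_dict)
```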

View File

@ -33,7 +33,7 @@ Then you will get the following display
```bash ```bash
>>> Found existing installation: mindspore-ascend >>> Found existing installation: mindspore-ascend
>>> Uninstalling mindspore-ascend: >>> Uninstalling mindspore-ascend:
>>> Successfully uninstalled mindspore-ascend. >>> Successfully uninstalled mindspore-ascend.
``` ```
### Prepare Dataset ### Prepare Dataset
@ -186,7 +186,7 @@ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
### train quantization aware model ### train quantization aware model
Also, you can just run this command instread. Also, you can just run this command instead.
```python ```python
python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
@ -235,7 +235,7 @@ The top1 accuracy would display on shell.
Here are some optional parameters: Here are some optional parameters:
```bash ```bash
--device_target {Ascend,GPU,CPU} --device_target {Ascend,GPU}
device where the code will be implemented (default: Ascend) device where the code will be implemented (default: Ascend)
--data_path DATA_PATH --data_path DATA_PATH
path where the dataset is saved path where the dataset is saved

View File

@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'], choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data", parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved') help='path where the dataset is saved')

View File

@ -32,7 +32,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'], choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data", parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved') help='path where the dataset is saved')
@ -61,7 +61,7 @@ if __name__ == "__main__":
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
# load quantization aware network checkpoint # load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path, model_type="quant") param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
print("============== Starting Testing ==============") print("============== Starting Testing ==============")

View File

@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'], choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data", parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved') help='path where the dataset is saved')
@ -56,8 +56,7 @@ if __name__ == "__main__":
# call back and monitor # call back and monitor
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max, keep_checkpoint_max=cfg.keep_checkpoint_max)
model_type=network.type)
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
# define model # define model

View File

@ -33,7 +33,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'], choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data", parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved') help='path where the dataset is saved')
@ -50,11 +50,13 @@ if __name__ == "__main__":
# define fusion network # define fusion network
network = LeNet5Fusion(cfg.num_classes) network = LeNet5Fusion(cfg.num_classes)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
# load quantization aware network checkpoint # load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path, network.type) param_dict = load_checkpoint(args.ckpt_path, network.type)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
# define network loss # define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
@ -64,8 +66,7 @@ if __name__ == "__main__":
# call back and monitor # call back and monitor
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max, keep_checkpoint_max=cfg.keep_checkpoint_max)
model_type="quant")
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
# define model # define model

View File

@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export DEVICE_ID=0 export DEVICE_ID=0
export RANK_ID=0 export RANK_ID=0
export RANK_SIZE=1 export RANK_SIZE=1
if [ -d "eval" ]; if [ -d "../eval" ];
then then
rm -rf ../eval rm -rf ../eval
fi fi

View File

@ -62,7 +62,7 @@ run_gpu()
BASEPATH=$(cd "`dirname $0`" || exit; pwd) BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "train" ]; if [ -d "../train" ];
then then
rm -rf ../train rm -rf ../train
fi fi

View File

@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export DEVICE_ID=0 export DEVICE_ID=0
export RANK_ID=0 export RANK_ID=0
export RANK_SIZE=1 export RANK_SIZE=1
if [ -d "eval" ]; if [ -d "../eval" ];
then then
rm -rf ../eval rm -rf ../eval
fi fi

View File

@ -60,7 +60,7 @@ run_gpu()
BASEPATH=$(cd "`dirname $0`" || exit; pwd) BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "train" ]; if [ -d "../train" ];
then then
rm -rf ../train rm -rf ../train
fi fi

View File

@ -31,7 +31,7 @@ def _conv_bn(in_channel,
out_channel, out_channel,
kernel_size=ksize, kernel_size=ksize,
stride=stride, stride=stride,
batchnorm=True)]) has_bn=True)])
class InvertedResidual(nn.Cell): class InvertedResidual(nn.Cell):
@ -49,25 +49,25 @@ class InvertedResidual(nn.Cell):
3, 3,
stride, stride,
group=hidden_dim, group=hidden_dim,
batchnorm=True, has_bn=True,
activation='relu6'), activation='relu6'),
nn.Conv2dBnAct(hidden_dim, oup, 1, 1, nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
batchnorm=True) has_bn=True)
]) ])
else: else:
self.conv = nn.SequentialCell([ self.conv = nn.SequentialCell([
nn.Conv2dBnAct(inp, hidden_dim, 1, 1, nn.Conv2dBnAct(inp, hidden_dim, 1, 1,
batchnorm=True, has_bn=True,
activation='relu6'), activation='relu6'),
nn.Conv2dBnAct(hidden_dim, nn.Conv2dBnAct(hidden_dim,
hidden_dim, hidden_dim,
3, 3,
stride, stride,
group=hidden_dim, group=hidden_dim,
batchnorm=True, has_bn=True,
activation='relu6'), activation='relu6'),
nn.Conv2dBnAct(hidden_dim, oup, 1, 1, nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
batchnorm=True) has_bn=True)
]) ])
self.add = P.TensorAdd() self.add = P.TensorAdd()

View File

@ -42,7 +42,7 @@ class LeNet5(nn.Cell):
def __init__(self, num_class=10): def __init__(self, num_class=10):
super(LeNet5, self).__init__() super(LeNet5, self).__init__()
self.num_class = num_class self.num_class = num_class
self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6', pad_mode="valid") self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu6', pad_mode="valid")
self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid") self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
self.fc2 = nn.DenseBnAct(120, 84, activation='relu') self.fc2 = nn.DenseBnAct(120, 84, activation='relu')