From 1089c908a9029441759c33f963762801e8d9a809 Mon Sep 17 00:00:00 2001 From: chenzomi Date: Tue, 30 Jun 2020 16:51:32 +0800 Subject: [PATCH] cherry-pick r0.5 to master for quantizaiton aware training --- mindspore/ccsrc/utils/checkpoint.proto | 1 - mindspore/nn/cell.py | 9 ++ mindspore/nn/layer/conv.py | 150 +++++++++++++++++- mindspore/nn/layer/quant.py | 74 +++++---- .../_op_impl/_custom_op/batchnorm_fold2.py | 1 - .../_custom_op/batchnorm_fold2_grad.py | 1 - .../_custom_op/batchnorm_fold2_grad_reduce.py | 1 - .../ops/_op_impl/_custom_op/correction_mul.py | 1 - .../_custom_op/correction_mul_grad.py | 2 - .../_custom_op/fake_quant_perchannel.py | 14 +- .../_custom_op/fake_quant_perchannel_grad.py | 14 +- .../_custom_op/minmax_update_perchannel.py | 14 +- mindspore/ops/operations/_quant_ops.py | 41 +++-- mindspore/train/callback/_checkpoint.py | 25 +-- mindspore/train/callback/_loss_monitor.py | 2 +- mindspore/train/quant/quant.py | 43 +++-- mindspore/train/serialization.py | 21 +-- model_zoo/lenet_quant/README.md | 6 +- model_zoo/lenet_quant/eval.py | 2 +- model_zoo/lenet_quant/eval_quant.py | 4 +- model_zoo/lenet_quant/train.py | 5 +- model_zoo/lenet_quant/train_quant.py | 11 +- model_zoo/mobilenetv2/scripts/run_infer.sh | 2 +- model_zoo/mobilenetv2/scripts/run_train.sh | 2 +- model_zoo/mobilenetv3/scripts/run_infer.sh | 2 +- model_zoo/mobilenetv3/scripts/run_train.sh | 2 +- .../train/quant/mobilenetv2_combined.py | 12 +- tests/ut/python/train/quant/test_quant.py | 2 +- 28 files changed, 322 insertions(+), 142 deletions(-) diff --git a/mindspore/ccsrc/utils/checkpoint.proto b/mindspore/ccsrc/utils/checkpoint.proto index 7fca399e2b..31c7cd8004 100644 --- a/mindspore/ccsrc/utils/checkpoint.proto +++ b/mindspore/ccsrc/utils/checkpoint.proto @@ -22,7 +22,6 @@ message Checkpoint { required TensorProto tensor = 2; } repeated Value value = 1; - required string model_type = 2; } diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 0533546400..cffe00a920 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -81,6 +81,7 @@ class Cell: self.enable_hook = False self._bprop_debug = False self._is_run = False + self.cell_type = None @property def is_run(self): @@ -140,6 +141,14 @@ class Cell: for cell_name, cell in cells_name: cell._param_prefix = cell_name + def update_cell_type(self, cell_type): + """ + Update current cell type mainly identify if quantization aware training network. + + After invoked, can set the cell type to 'cell_type'. + """ + self.cell_type = cell_type + @cell_init_args.setter def cell_init_args(self, value): if not isinstance(value, str): diff --git a/mindspore/nn/layer/conv.py b/mindspore/nn/layer/conv.py index b2a0de9cbe..52ec9f2d63 100644 --- a/mindspore/nn/layer/conv.py +++ b/mindspore/nn/layer/conv.py @@ -17,11 +17,12 @@ from mindspore import log as logger from mindspore.ops import operations as P from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer +from mindspore._checkparam import ParamValidator as validator, Rel from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative from mindspore._extends import cell_attr_register from ..cell import Cell -__all__ = ['Conv2d', 'Conv2dTranspose'] +__all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d'] class _Conv(Cell): """ @@ -397,3 +398,150 @@ class Conv2dTranspose(_Conv): self.weight, self.bias) return s + + +class DepthwiseConv2d(Cell): + r""" + 2D depthwise convolution layer. 
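As a minimal sketch of the new `cell_type` bookkeeping added to `Cell` in this patch (the `PlainCell` name is hypothetical; it mirrors only the two touched members):

```python
class PlainCell:
    """Stand-in for nn.Cell, showing only the cell_type bookkeeping added above."""

    def __init__(self):
        self.cell_type = None  # unset until a converter tags the network

    def update_cell_type(self, cell_type):
        # e.g. convert_quant_network() tags the converted network as "quant"
        self.cell_type = cell_type


net = PlainCell()
net.update_cell_type("quant")
assert net.cell_type == "quant"
```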
+ + Applies a 2D depthwise convolution over an input tensor which is typically of shape: + math:`(N, C_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel number. + For each batch of shape:math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as: + + .. math:: + + out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j, + + where :math:`ccor` is cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges + from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to :math:`i`-th channel of the :math:`j`-th + filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice + of kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and + :math:`\text{ks_w}` are height and width of the convolution kernel. The full kernel has shape + :math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group number + to split the input in the channel dimension. + + If the 'pad_mode' is set to be "valid", the output height and width will be + :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} - + (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and + :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} - + (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively. + + The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition + `_. + + Args: + in_channels (int): The number of input channel :math:`C_{in}`. + out_channels (int): The number of output channel :math:`C_{out}`. + kernel_size (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the height + and width of the 2D convolution window. Single int means the value if for both height and width of + the kernel. A tuple of 2 ints means the first value is for the height and the other is for the + width of the kernel. + stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents + the height and width of movement are both strides, or a tuple of two int numbers that + represent height and width of movement respectively. Default: 1. + pad_mode (str): Specifies padding mode. The optional values are + "same", "valid", "pad". Default: "same". + + - same: Adopts the way of completion. Output height and width will be the same as the input. + Total number of padding will be calculated for horizontal and vertical + direction and evenly distributed to top and bottom, left and right if possible. Otherwise, the + last extra padding will be done from the bottom and the right side. If this mode is set, `padding` + must be 0. + + - valid: Adopts the way of discarding. The possibly largest height and width of output will be return + without padding. Extra pixels will be discarded. If this mode is set, `padding` + must be 0. + + - pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input + Tensor borders. `padding` should be greater than or equal to 0. + + padding (int): Implicit paddings on both sides of the input. Default: 0. + dilation (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the dilation rate + to use for dilated convolution. If set to be :math:`k > 1`, there will + be :math:`k - 1` pixels skipped for each sampling location. 
Its value should + be greater or equal to 1 and bounded by the height and width of the + input. Default: 1. + group (int): Split filter into groups, `in_ channels` and `out_channels` should be + divisible by the number of groups. Default: 1. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. + It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified, + values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well + as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones' + and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of + Initializer for more details. Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible + Initializer and string are the same as 'weight_init'. Refer to the values of + Initializer for more details. Default: 'zeros'. + + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. + + Outputs: + Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. + + Examples: + >>> net = nn.DepthwiseConv2d(120, 240, 4, has_bias=False, weight_init='normal') + >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) + >>> net(input).shape + (1, 240, 1024, 640) + """ + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + pad_mode='same', + padding=0, + dilation=1, + group=1, + has_bias=False, + weight_init='normal', + bias_init='zeros'): + super(DepthwiseConv2d, self).__init__() + self.kernel_size = twice(kernel_size) + self.stride = twice(stride) + self.dilation = twice(dilation) + self.in_channels = check_int_positive(in_channels) + self.out_channels = check_int_positive(out_channels) + validator.check_integer('group', group, in_channels, Rel.EQ) + validator.check_integer('group', group, out_channels, Rel.EQ) + validator.check_integer('group', group, 1, Rel.GE) + self.pad_mode = pad_mode + self.padding = padding + self.dilation = dilation + self.group = group + self.has_bias = has_bias + self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, + kernel_size=self.kernel_size, + pad_mode=self.pad_mode, + pad=self.padding, + stride=self.stride, + dilation=self.dilation) + self.bias_add = P.BiasAdd() + weight_shape = [1, in_channels, *self.kernel_size] + self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') + if check_bool(has_bias): + self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias') + else: + if bias_init != 'zeros': + logger.warning("value of `has_bias` is False, value of `bias_init` will be ignore.") + self.bias = None + + def construct(self, x): + out = self.conv(x, self.weight) + if self.has_bias: + out = self.bias_add(out, self.bias) + return out + + def extend_repr(self): + s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \ + 'pad_mode={}, padding={}, dilation={}, group={},' \ + 'has_bias={}, weight_init={}, bias_init={}'.format( + self.in_channels, self.out_channels, self.kernel_size, self.stride, + self.pad_mode, self.padding, self.dilation, self.group, + self.has_bias, self.weight_init, self.bias_init) + + if self.has_bias: + s += ', bias={}'.format(self.bias) + return s diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 14731c6262..225d37bf84 
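To make the per-channel semantics concrete, here is a minimal numpy sketch of the 'valid'-mode computation that `DepthwiseConv2d` delegates to `P.DepthwiseConv2dNative` with `channel_multiplier=1`; it is loop-based for clarity and ignores padding and dilation. Note the validators above require `group` to equal both `in_channels` and `out_channels`, so a depthwise layer keeps the channel count unchanged.

```python
import numpy as np

def depthwise_conv2d(x, w, stride=1):
    # x: (N, C, H, W); w: (1, C, kh, kw) -- the weight_shape used above.
    n, c, h, wd = x.shape
    _, _, kh, kw = w.shape
    oh, ow = (h - kh) // stride + 1, (wd - kw) // stride + 1
    out = np.zeros((n, c, oh, ow), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            patch = x[:, :, i * stride:i * stride + kh, j * stride:j * stride + kw]
            # channel_multiplier=1: each input channel convolves with its own kernel
            out[:, :, i, j] = (patch * w[0]).sum(axis=(2, 3))
    return out

x = np.random.randn(1, 3, 8, 8).astype(np.float32)
w = np.random.randn(1, 3, 3, 3).astype(np.float32)
print(depthwise_conv2d(x, w).shape)  # (1, 3, 6, 6)
```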
100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -16,6 +16,7 @@ from functools import partial import numpy as np + import mindspore.common.dtype as mstype from mindspore.ops import operations as P from mindspore.ops import functional as F @@ -23,10 +24,9 @@ from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer from mindspore.common.tensor import Tensor from mindspore._checkparam import check_int_positive, check_bool, twice -from mindspore._checkparam import Validator as validator, Rel -from mindspore.nn.cell import Cell -from mindspore.nn.layer.activation import get_activation +from mindspore._checkparam import Rel import mindspore.context as context + from .normalization import BatchNorm2d from .activation import get_activation from ..cell import Cell @@ -82,7 +82,7 @@ class Conv2dBnAct(Cell): bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible Initializer and string are the same as 'weight_init'. Refer to the values of Initializer for more details. Default: 'zeros'. - batchnorm (bool): Specifies to used batchnorm or not. Default: None. + has_bn (bool): Specifies to used batchnorm or not. Default: False. activation (string): Specifies activation type. The optional values are as following: 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. @@ -94,7 +94,7 @@ class Conv2dBnAct(Cell): Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. Examples: - >>> net = Conv2dBnAct(120, 240, 4, batchnorm=True, activation='ReLU') + >>> net = Conv2dBnAct(120, 240, 4, has_bn=True, activation='ReLU') >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) >>> net(input).shape (1, 240, 1024, 640) @@ -112,28 +112,39 @@ class Conv2dBnAct(Cell): has_bias=False, weight_init='normal', bias_init='zeros', - batchnorm=None, + has_bn=False, activation=None): super(Conv2dBnAct, self).__init__() - self.conv = conv.Conv2d( - in_channels, - out_channels, - kernel_size, - stride, - pad_mode, - padding, - dilation, - group, - has_bias, - weight_init, - bias_init) - self.has_bn = batchnorm is not None + + if context.get_context('device_target') == "Ascend" and group > 1: + self.conv = conv.DepthwiseConv2d(in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + pad_mode=pad_mode, + padding=padding, + dilation=dilation, + group=group, + has_bias=has_bias, + weight_init=weight_init, + bias_init=bias_init) + else: + self.conv = conv.Conv2d(in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + pad_mode=pad_mode, + padding=padding, + dilation=dilation, + group=group, + has_bias=has_bias, + weight_init=weight_init, + bias_init=bias_init) + + self.has_bn = validator.check_bool("has_bn", has_bn) self.has_act = activation is not None - self.batchnorm = batchnorm - if batchnorm is True: + if has_bn: self.batchnorm = BatchNorm2d(out_channels) - elif batchnorm is not None: - validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,)) self.activation = get_activation(activation) def construct(self, x): @@ -160,7 +171,7 @@ class DenseBnAct(Cell): same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. - batchnorm (bool): Specifies to used batchnorm or not. 
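The backend routing that `Conv2dBnAct` gains here can be summarized in isolation; a sketch of the decision rule (function name hypothetical), assuming a depthwise workload is expressed through `group > 1`:

```python
def pick_conv_impl(device_target, group):
    """Mirror of Conv2dBnAct's new selection rule: grouped convolutions on
    Ascend are lowered through the dedicated DepthwiseConv2d cell."""
    if device_target == "Ascend" and group > 1:
        return "DepthwiseConv2d"
    return "Conv2d"

assert pick_conv_impl("Ascend", group=32) == "DepthwiseConv2d"
assert pick_conv_impl("GPU", group=32) == "Conv2d"
assert pick_conv_impl("Ascend", group=1) == "Conv2d"
```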
Default: None. + has_bn (bool): Specifies to used batchnorm or not. Default: False. activation (string): Specifies activation type. The optional values are as following: 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. @@ -183,7 +194,7 @@ class DenseBnAct(Cell): weight_init='normal', bias_init='zeros', has_bias=True, - batchnorm=None, + has_bn=False, activation=None): super(DenseBnAct, self).__init__() self.dense = basic.Dense( @@ -192,12 +203,10 @@ class DenseBnAct(Cell): weight_init, bias_init, has_bias) - self.has_bn = batchnorm is not None + self.has_bn = validator.check_bool("has_bn", has_bn) self.has_act = activation is not None - if batchnorm is True: + if has_bn: self.batchnorm = BatchNorm2d(out_channels) - elif batchnorm is not None: - validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,)) self.activation = get_activation(activation) def construct(self, x): @@ -312,6 +321,10 @@ class FakeQuantWithMinMax(Cell): quant_delay=0): """init FakeQuantWithMinMax layer""" super(FakeQuantWithMinMax, self).__init__() + validator.check_type("min_init", min_init, [int, float]) + validator.check_type("max_init", max_init, [int, float]) + validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) + validator.check_integer('quant_delay', quant_delay, 0, Rel.GE) self.min_init = min_init self.max_init = max_init self.num_bits = num_bits @@ -1183,12 +1196,13 @@ class QuantBlock(Cell): self.has_bias = bias is None self.activation = activation self.has_act = activation is None + self.bias_add = P.BiasAdd() def construct(self, x): x = self.quant(x) x = self.core_op(x, self.weight) if self.has_bias: - output = self.bias_add(output, self.bias) + x = self.bias_add(x, self.bias) if self.has_act: x = self.activation(x) x = self.dequant(x, self.dequant_scale) diff --git a/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py b/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py index 7e98517057..9daab5a75f 100644 --- a/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +++ b/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py @@ -30,7 +30,6 @@ batchnorm_fold2_op_info = TBERegOp("BatchNormFold2_D") \ .compute_cost(10) \ .kernel_name("batchnorm_fold2") \ .partial_flag(True) \ - .op_pattern("formatAgnostic") \ .input(0, "x", None, "required", None) \ .input(1, "beta", None, "required", None) \ .input(2, "gamma", None, "required", None) \ diff --git a/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py b/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py index 824da62d19..9994a88f30 100644 --- a/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +++ b/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py @@ -30,7 +30,6 @@ batchnorm_fold2_grad_op_info = TBERegOp("BatchNormFold2GradD") \ .compute_cost(10) \ .kernel_name("batchnorm_fold2_grad") \ .partial_flag(True) \ - .op_pattern("formatAgnostic") \ .input(0, "dout", None, "required", None) \ .input(1, "dout_reduce", None, "required", None) \ .input(2, "dout_x_reduce", None, "required", None) \ diff --git a/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py b/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py index 7806c6834e..92b91ff712 100644 --- a/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +++ b/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py @@ -31,7 +31,6 @@ batchnorm_fold2_grad_reduce_op_info = TBERegOp("BatchNormFold2GradReduce") \ .compute_cost(10) \ 
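For orientation, a numpy sketch of the per-tensor simulated-quantization forward pass that `FakeQuantWithMinMax` models; the production kernels are the TBE ops registered in this patch, and details such as rounding mode and symmetric-range handling may differ from this illustration:

```python
import numpy as np

def fake_quant(x, min_val, max_val, num_bits=8, symmetric=False, narrow_range=False):
    """Quantize to the integer grid, then dequantize back to float."""
    quant_min, quant_max = 0, 2 ** num_bits - 1
    if narrow_range:
        quant_min += 1  # mirrors the `quant_min = quant_min + 1` step in the kernels
    if symmetric:
        bound = max(abs(min_val), abs(max_val))
        min_val, max_val = -bound, bound
    scale = (max_val - min_val) / (quant_max - quant_min)
    zero_point = quant_min - round(min_val / scale)
    q = np.clip(np.round(x / scale) + zero_point, quant_min, quant_max)
    return ((q - zero_point) * scale).astype(np.float32)

x = np.linspace(-8, 8, 9, dtype=np.float32)
print(fake_quant(x, min_val=-6.0, max_val=6.0))  # values saturate at +/-6
```

This also shows why the new `min_init < max_init` check matters: a degenerate range would make `scale` zero.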
.kernel_name("batchnorm_fold2_grad_reduce") \ .partial_flag(True) \ - .op_pattern("formatAgnostic") \ .input(0, "dout", None, "required", None) \ .input(1, "x", None, "required", None) \ .output(0, "dout_reduce", True, "required", "all") \ diff --git a/mindspore/ops/_op_impl/_custom_op/correction_mul.py b/mindspore/ops/_op_impl/_custom_op/correction_mul.py index ce92d2bbc5..49cd35cc11 100644 --- a/mindspore/ops/_op_impl/_custom_op/correction_mul.py +++ b/mindspore/ops/_op_impl/_custom_op/correction_mul.py @@ -30,7 +30,6 @@ correction_mul_op_info = TBERegOp("CorrectionMul") \ .compute_cost(10) \ .kernel_name("correction_mul") \ .partial_flag(True) \ - .op_pattern("formatAgnostic") \ .attr("channel_axis", "optional", "int", "all") \ .input(0, "x", None, "required", None) \ .input(1, "batch_std", None, "required", None) \ diff --git a/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py b/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py index da3a634454..6c11ce6855 100644 --- a/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +++ b/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py @@ -30,7 +30,6 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \ .compute_cost(10) \ .kernel_name("correction_mul_grad") \ .partial_flag(True) \ - .op_pattern("formatAgnostic") \ .attr("channel_axis", "optional", "int", "all") \ .input(0, "dout", None, "required", None) \ .input(1, "x", None, "required", None) \ @@ -128,7 +127,6 @@ correction_mul_grad_reduce_op_info = TBERegOp("CorrectionMulGradReduce") \ .compute_cost(10) \ .kernel_name("correction_mul_grad_reduce") \ .partial_flag(True) \ - .op_pattern("formatAgnostic") \ .attr("channel_axis", "optional", "int", "all") \ .input(0, "dout", None, "required", None) \ .output(0, "d_batch_std", True, "required", "all") \ diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py index f6c133c808..dae2d7058d 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py @@ -99,11 +99,15 @@ def fake_quant_perchannel(x, min_val, max_val, y, min_dtype = min_val.get("dtype") max_shape = max_val.get("ori_shape") max_dtype = max_val.get("dtype") - + # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1. 
+ if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]: + channel_axis_ = 1 + else: + channel_axis_ = channel_axis util.check_kernel_name(kernel_name) util.check_shape_rule(x_shape) - util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis]) - util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis]) + util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_]) + util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_]) util.check_tensor_shape_size(x_shape) util.check_tensor_shape_size(min_shape) util.check_tensor_shape_size(max_shape) @@ -126,8 +130,8 @@ def fake_quant_perchannel(x, min_val, max_val, y, quant_min = quant_min + 1 shape_c = [1] * len(x_shape) - shape_c[channel_axis] = min_val.get("ori_shape")[0] - if x_format == "NC1HWC0" and channel_axis == 1: + shape_c[channel_axis_] = min_val.get("ori_shape")[0] + if x_format == "NC1HWC0" and channel_axis_ == 1: shape_c = min_val.get("shape") input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype) min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py index 4e9053fcb1..795aab52a3 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py @@ -124,11 +124,15 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx, min_dtype = min_val.get("dtype") max_shape = max_val.get("ori_shape") max_dtype = max_val.get("dtype") - + # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1. + if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]: + channel_axis_ = 1 + else: + channel_axis_ = channel_axis util.check_kernel_name(kernel_name) util.check_shape_rule(x_shape) - util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis]) - util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis]) + util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_]) + util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_]) util.check_tensor_shape_size(x_shape) util.check_tensor_shape_size(min_shape) util.check_tensor_shape_size(max_shape) @@ -151,8 +155,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx, quant_min = quant_min + 1 shape_c = [1] * len(x_shape) - shape_c[channel_axis] = min_val.get("ori_shape")[0] - if x_format == "NC1HWC0" and channel_axis == 1: + shape_c[channel_axis_] = min_val.get("ori_shape")[0] + if x_format == "NC1HWC0" and channel_axis_ == 1: shape_c = min_val.get("shape") dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype) input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype) diff --git a/mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py b/mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py index 1ff63464c3..f29fc53325 100644 --- a/mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +++ b/mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py @@ -88,11 +88,15 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up, min_dtype = min_val.get("dtype") max_shape = max_val.get("ori_shape") max_dtype = max_val.get("dtype") - + # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1. 
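The same axis remapping appears in all three kernels touched here; pulled out as a self-contained function (name hypothetical) with the shapes from the comment:

```python
def remap_channel_axis(x_shape, min_shape, channel_axis):
    """Dense weights are reshaped 2d [co, ci] -> 4d [1, co, ci, 1] before these ops,
    so a per-channel axis of 0 has to shift to 1; conv weights pass through."""
    if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
        return 1
    return channel_axis

# Dense weight [co=84, ci=120] arriving as [1, 84, 120, 1] with 84 min/max values:
assert remap_channel_axis([1, 84, 120, 1], [84], 0) == 1
# Conv weight quantized along axis 0 keeps its axis:
assert remap_channel_axis([64, 3, 3, 3], [64], 0) == 0
```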
+ if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]: + channel_axis_ = 1 + else: + channel_axis_ = channel_axis util.check_kernel_name(kernel_name) util.check_shape_rule(x_shape) - util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis]) - util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis]) + util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis_]) + util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis_]) util.check_tensor_shape_size(x_shape) util.check_tensor_shape_size(min_shape) util.check_tensor_shape_size(max_shape) @@ -105,7 +109,7 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up, util.check_dtype_rule(min_dtype, check_list) util.check_dtype_rule(max_dtype, check_list) - if channel_axis == 0: + if channel_axis_ == 0: shape_c = min_val.get("ori_shape") else: shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]] @@ -113,7 +117,7 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up, min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype) max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype) res_list = minmax_update_perchannel_compute(input_data, min_data, max_data, - ema, ema_decay, channel_axis) + ema, ema_decay, channel_axis_) with tvm.target.cce(): sch = generic.auto_schedule(res_list) diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index 42c2406906..1f4de03d3c 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -106,7 +106,7 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer): Args: ema (bool): Use EMA algorithm update value min and max. Default: False. ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. - channel_axis (int): Channel asis for per channel compute. Default: 1. + channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1. Inputs: - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. 
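A numpy sketch of the running min/max rule that the min/max update ops describe, shown per-tensor for brevity (the per-channel op applies it along `channel_axis`); the EMA branch follows the standard decay-weighted update, while the non-EMA branch here (plain running extremes) is an assumption for illustration:

```python
import numpy as np

def minmax_update(x, running_min, running_max, ema=True, ema_decay=0.999):
    """One observation step for the quantization range."""
    batch_min, batch_max = float(x.min()), float(x.max())
    if ema:
        running_min = running_min * ema_decay + batch_min * (1 - ema_decay)
        running_max = running_max * ema_decay + batch_max * (1 - ema_decay)
    else:
        running_min = min(running_min, batch_min)
        running_max = max(running_max, batch_max)
    return running_min, running_max

x = np.random.randn(32, 16).astype(np.float32)
print(minmax_update(x, running_min=-6.0, running_max=6.0))
```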
@@ -123,11 +123,13 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer): >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max) """ support_quant_bit = [4, 7, 8] + ascend_support_x_rank = [2, 4] @prim_attr_register def __init__(self, ema=False, ema_decay=0.999, channel_axis=1): """init FakeQuantPerChannelUpdate OP for Ascend""" - if context.get_context('device_target') == "Ascend": + self.is_ascend = context.get_context('device_target') == "Ascend" + if self.is_ascend: from mindspore.ops._op_impl._custom_op import minmax_update_perchannel if ema and not ema_decay: raise ValueError( @@ -136,13 +138,18 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer): self.ema = validator.check_value_type('ema', ema, (bool,), self.name) self.ema_decay = validator.check_number_range( 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) - self.channel_axis = validator.check_integer( - 'channel axis', channel_axis, 0, Rel.GE, self.name) + if self.is_ascend: + self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name) + else: + self.channel_axis = validator.check_integer('channel_axis', channel_axis, 0, Rel.GE, self.name) self.init_prim_io_names( inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up']) def infer_shape(self, x_shape, min_shape, max_shape): - validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) + if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank: + raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'") + if not self.is_ascend: + validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) validator.check_integer("min shape", len( @@ -221,8 +228,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer): 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) self.num_bits = validator.check_integer( 'num_bits', num_bits, 0, Rel.GT, self.name) - self.quant_delay = validator.check_value_type( - 'quant_delay', quant_delay, (int,), self.name) + self.quant_delay = validator.check_integer( + 'quant_delay', quant_delay, 0, Rel.GE, self.name) self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) @@ -314,6 +321,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer): symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. training (bool): Training the network or not. Default: True. + channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1. Inputs: - **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor. 
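The new Ascend-only rank restriction reads as a small guard; a standalone sketch mirroring the `infer_shape` change:

```python
ASCEND_SUPPORT_X_RANK = (2, 4)  # Dense weights (2-D) or NCHW features/weights (4-D)

def check_x_rank(x_shape, is_ascend, op_name="MinMaxUpdatePerChannel"):
    """Mirror of the infer_shape guard added in this patch."""
    if is_ascend and len(x_shape) not in ASCEND_SUPPORT_X_RANK:
        raise ValueError(f"For '{op_name}' x rank should be in '{ASCEND_SUPPORT_X_RANK}'")
    if not is_ascend and len(x_shape) < 1:
        raise ValueError(f"For '{op_name}' x rank should be >= 1")

check_x_rank([84, 120], is_ascend=True)        # ok: Dense weight
check_x_rank([1, 32, 14, 14], is_ascend=True)  # ok: NCHW feature map
```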
@@ -331,6 +339,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer): >>> result = fake_quant(input_x, _min, _max) """ support_quant_bit = [4, 7, 8] + ascend_support_x_rank = [2, 4] @prim_attr_register def __init__(self, @@ -343,7 +352,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer): training=True, channel_axis=1): """init FakeQuantPerChannel OP""" - if context.get_context('device_target') == "Ascend": + self.is_ascend = context.get_context('device_target') == "Ascend" + if self.is_ascend: from mindspore.ops._op_impl._custom_op import fake_quant_perchannel if num_bits not in self.support_quant_bit: raise ValueError( @@ -363,14 +373,19 @@ class FakeQuantPerChannel(PrimitiveWithInfer): 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) self.num_bits = validator.check_integer( 'num_bits', num_bits, 0, Rel.GT, self.name) - self.quant_delay = validator.check_value_type( - 'quant_delay', quant_delay, (int,), self.name) - self.channel_axis = validator.check_integer( - 'channel_axis', channel_axis, 0, Rel.GE, self.name) + self.quant_delay = validator.check_integer( + 'quant_delay', quant_delay, 0, Rel.GE, self.name) + if self.is_ascend: + self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name) + else: + self.channel_axis = validator.check_integer('channel_axis', channel_axis, 0, Rel.GE, self.name) self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) def infer_shape(self, x_shape, min_shape, max_shape): - validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) + if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank: + raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'") + if not self.is_ascend: + validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) validator.check_integer( "min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py index 4e686c414f..e0048ad713 100644 --- a/mindspore/train/callback/_checkpoint.py +++ b/mindspore/train/callback/_checkpoint.py @@ -21,7 +21,7 @@ import time import mindspore.context as context from mindspore import log as logger -from mindspore._checkparam import check_bool, check_string, check_int_non_negative +from mindspore._checkparam import check_bool, check_int_non_negative from mindspore.train._utils import _make_directory from mindspore.train.serialization import _exec_save_checkpoint, _save_graph from ._callback import Callback, set_cur_net @@ -86,7 +86,6 @@ class CheckpointConfig: Can't be used with keep_checkpoint_max at the same time. integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True. Integrated save function is only supported in automatic parallel scene, not supported in manual parallel. - model_type (str): Model type in `normal`, `fusion` or `quant`. Default: "normal". Raises: ValueError: If the input_param is None or 0. 
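For reference, the integer range implied by `num_bits` and `narrow_range` in the fake-quant ops (the Ascend registrations above accept 4-, 7- and 8-bit modes); a minimal sketch:

```python
def quant_range(num_bits=8, narrow_range=False):
    """Integer range targeted by the fake-quant ops."""
    quant_min, quant_max = 0, 2 ** num_bits - 1
    if narrow_range:
        quant_min += 1
    return quant_min, quant_max

assert quant_range(8) == (0, 255)
assert quant_range(8, narrow_range=True) == (1, 255)
assert quant_range(4) == (0, 15)
```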
@@ -101,8 +100,7 @@ class CheckpointConfig:
                  save_checkpoint_seconds=0,
                  keep_checkpoint_max=5,
                  keep_checkpoint_per_n_minutes=0,
-                 integrated_save=True,
-                 model_type="normal"):
+                 integrated_save=True):
 
         if not save_checkpoint_steps and not save_checkpoint_seconds and \
                 not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
@@ -116,8 +114,6 @@ class CheckpointConfig:
             keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
         if keep_checkpoint_per_n_minutes:
             keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
-        if model_type:
-            model_type = check_string(model_type, ["normal", "fusion", "quant"])
 
         self._save_checkpoint_steps = save_checkpoint_steps
         self._save_checkpoint_seconds = save_checkpoint_seconds
@@ -132,7 +128,6 @@ class CheckpointConfig:
         if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
             self._keep_checkpoint_max = 1
 
-        self._model_type = model_type
         self._integrated_save = check_bool(integrated_save)
 
     @property
@@ -160,18 +155,12 @@ class CheckpointConfig:
         """Get the value of _integrated_save."""
         return self._integrated_save
 
-    @property
-    def model_type(self):
-        """Get the value of model_type."""
-        return self._model_type
-
     def get_checkpoint_policy(self):
         """Get the policy of checkpoint."""
         checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
                              'save_checkpoint_seconds': self._save_checkpoint_seconds,
                              'keep_checkpoint_max': self._keep_checkpoint_max,
-                             'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes,
-                             'model_type': self._model_type}
+                             'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}
 
         return checkpoint_policy
 
@@ -236,7 +225,7 @@ class ModelCheckpoint(Callback):
             graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
             _save_graph(cb_params.train_network, graph_file_name)
             self._graph_saved = True
-        self._save_ckpt(cb_params, self._config.model_type)
+        self._save_ckpt(cb_params)
 
     def end(self, run_context):
         """
@@ -247,7 +236,7 @@ class ModelCheckpoint(Callback):
         """
         cb_params = run_context.original_args()
         _to_save_last_ckpt = True
-        self._save_ckpt(cb_params, self._config.model_type, _to_save_last_ckpt)
+        self._save_ckpt(cb_params, _to_save_last_ckpt)
 
         from mindspore.parallel._cell_wrapper import destroy_allgather_cell
         destroy_allgather_cell()
@@ -266,7 +255,7 @@ class ModelCheckpoint(Callback):
 
         return False
 
-    def _save_ckpt(self, cb_params, model_type, force_to_save=False):
+    def _save_ckpt(self, cb_params, force_to_save=False):
         """Save checkpoint files."""
         if cb_params.cur_step_num == self._last_triggered_step:
             return
@@ -302,7 +291,7 @@ class ModelCheckpoint(Callback):
                 set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()
 
-            _exec_save_checkpoint(cb_params.train_network, gen_file, model_type, self._config.integrated_save)
+            _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
 
             if os.path.exists(gen_file):
                 shutil.move(gen_file, cur_file)
diff --git a/mindspore/train/callback/_loss_monitor.py b/mindspore/train/callback/_loss_monitor.py
index 3f93c6314d..bdd2220441 100644
--- a/mindspore/train/callback/_loss_monitor.py
+++ b/mindspore/train/callback/_loss_monitor.py
@@ -86,7 +86,7 @@ class LossMonitor(Callback):
 
         if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
             print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
-                  "loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format(
+                  "loss: [{:5.4f}], avg loss: [{:5.4f}], time: [{:5.4f}]".format(
cb_params.cur_epoch_num, cb_params.epoch_num,
                      cur_step_in_epoch, int(cb_params.batch_num),
                      step_loss, np.mean(self.losses),
diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py
index cb4cb39e66..3709c171e5 100644
--- a/mindspore/train/quant/quant.py
+++ b/mindspore/train/quant/quant.py
@@ -42,15 +42,14 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
 
 class _AddFakeQuantInput(nn.Cell):
     """
-    Add FakeQuant at input and output of the Network. Only support one input and one output case.
+    Add FakeQuant OP at input of the network. Only support one input case.
     """
 
     def __init__(self, network, quant_delay=0):
         super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
+        self.fake_quant_input = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
+        self.fake_quant_input.update_parameters_name('fake_quant_input.')
         self.network = network
-        self.fake_quant_input = quant.FakeQuantWithMinMax(
-            min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
-        self.fake_quant_input.update_parameters_name('fake_quant_input')
 
     def construct(self, data):
         data = self.fake_quant_input(data)
@@ -60,7 +59,7 @@ class _AddFakeQuantInput(nn.Cell):
 
 class _AddFakeQuantAfterSubCell(nn.Cell):
     """
-    Add FakeQuant after of the sub Cell.
+    Add FakeQuant OP after the sub Cell.
     """
 
     def __init__(self, subcell, **kwargs):
@@ -115,11 +114,12 @@ class ConvertToQuantNetwork:
         self.network.update_cell_prefix()
         network = self._convert_subcells2quant(self.network)
         network = _AddFakeQuantInput(network)
+        self.network.update_cell_type("quant")
         return network
 
     def _convert_subcells2quant(self, network):
         """
-        convet sub cell to quant cell
+        convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell
         """
         cells = network.name_cells()
         change = False
@@ -138,13 +138,13 @@ class ConvertToQuantNetwork:
         if isinstance(network, nn.SequentialCell) and change:
             network.cell_list = list(network.cells())
 
-        # tensoradd to tensoradd quant
+        # add FakeQuant OP after OP in white list
         add_list = []
         for name in network.__dict__:
            if name[0] == '_':
                continue
            attr = network.__dict__[name]
-            if isinstance(attr, ops.Primitive) and attr.name in ConvertToQuantNetwork.__quant_op_name__:
+            if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
                add_list.append((name, attr))
        for name, prim_op in add_list:
            prefix = name
@@ -164,11 +164,11 @@ class ConvertToQuantNetwork:
 
     def _convert_conv(self, subcell):
         """
-        convet conv cell to quant cell
+        convert Conv2d cell to quant cell
         """
         conv_inner = subcell.conv
-        bn_inner = subcell.batchnorm
-        if subcell.batchnorm is not None and self.bn_fold:
+        if subcell.has_bn and self.bn_fold:
+            bn_inner = subcell.batchnorm
             conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
                                                     conv_inner.out_channels,
                                                     kernel_size=conv_inner.kernel_size,
                                                     stride=conv_inner.stride,
                                                     pad_mode=conv_inner.pad_mode,
                                                     padding=conv_inner.padding,
                                                     dilation=conv_inner.dilation,
                                                     group=conv_inner.group,
                                                     eps=bn_inner.eps,
-                                                    momentum=bn_inner.momentum,
+                                                    momentum=1 - bn_inner.momentum,
                                                     quant_delay=self.weight_qdelay,
                                                     freeze_bn=self.freeze_bn,
                                                     per_channel=self.weight_channel,
                                                     num_bits=self.weight_bits,
                                                     fake=True,
                                                     symmetric=self.weight_symmetric,
                                                     narrow_range=self.weight_range)
+            # change original network BatchNormal OP parameters to quant network
+            conv_inner.gamma = subcell.batchnorm.gamma
+            conv_inner.beta = subcell.batchnorm.beta
+            conv_inner.moving_mean = subcell.batchnorm.moving_mean
+            conv_inner.moving_variance = subcell.batchnorm.moving_variance
             del subcell.batchnorm
             subcell.batchnorm
= None subcell.has_bn = False @@ -204,6 +209,10 @@ class ConvertToQuantNetwork: num_bits=self.weight_bits, symmetric=self.weight_symmetric, narrow_range=self.weight_range) + # change original network Conv2D OP parameters to quant network + conv_inner.weight = subcell.conv.weight + if subcell.conv.has_bias: + conv_inner.bias = subcell.conv.bias subcell.conv = conv_inner if subcell.has_act and subcell.activation is not None: subcell.activation = self._convert_activation(subcell.activation) @@ -230,6 +239,10 @@ class ConvertToQuantNetwork: per_channel=self.weight_channel, symmetric=self.weight_symmetric, narrow_range=self.weight_range) + # change original network Dense OP parameters to quant network + dense_inner.weight = subcell.dense.weight + if subcell.dense.has_bias: + dense_inner.bias = subcell.dense.bias subcell.dense = dense_inner if subcell.has_act and subcell.activation is not None: subcell.activation = self._convert_activation(subcell.activation) @@ -247,12 +260,12 @@ class ConvertToQuantNetwork: act_class = activation.__class__ if act_class not in _ACTIVATION_MAP: raise ValueError( - "Unsupported activation in auto Quant: ", act_class) + "Unsupported activation in auto quant: ", act_class) return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, quant_delay=self.act_qdelay, per_channel=self.act_channel, - symmetric=self.weight_symmetric, - narrow_range=self.weight_range) + symmetric=self.act_symmetric, + narrow_range=self.act_range) class ExportQuantNetworkDeploy: diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index ff1b8c3122..fc135b18a9 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -40,8 +40,6 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin "Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64, "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} -ModelType = ["normal", "fusion", "quant"] - def _special_process_par(par, new_par): """ @@ -103,7 +101,7 @@ def _update_param(param, new_param): param.set_parameter_data(type(param.data)(new_param.data)) -def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"): +def save_checkpoint(parameter_list, ckpt_file_name): """ Saves checkpoint info to a specified file. @@ -111,14 +109,12 @@ def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"): parameter_list (list): Parameters list, each element is a dict like {"name":xx, "type":xx, "shape":xx, "data":xx}. ckpt_file_name (str): Checkpoint file name. - model_type (str): The name of model type. Default: "normal". Raises: RuntimeError: Failed to save the Checkpoint file. """ logger.info("Execute save checkpoint process.") checkpoint_list = Checkpoint() - checkpoint_list.model_type = model_type try: for param in parameter_list: @@ -147,13 +143,12 @@ def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"): logger.info("Save checkpoint process finish.") -def load_checkpoint(ckpt_file_name, model_type="normal", net=None): +def load_checkpoint(ckpt_file_name, net=None): """ Loads checkpoint info from a specified file. Args: ckpt_file_name (str): Checkpoint file name. - model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal". net (Cell): Cell network. 
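The conversion above copies `gamma`, `beta`, `moving_mean` and `moving_variance` into `Conv2dBatchNormQuant` so that batchnorm can be folded into the convolution. A numpy sketch of the standard folding identity; the real cell additionally handles `freeze_bn`, correction factors and training-time statistics, which this omits:

```python
import numpy as np

def fold_bn_into_conv(w, gamma, beta, moving_mean, moving_variance, eps=1e-5):
    """y = gamma * (conv(x, w) - mean) / std + beta  ==  conv(x, w_fold) + b_fold."""
    std = np.sqrt(moving_variance + eps)
    w_fold = w * (gamma / std).reshape(-1, 1, 1, 1)  # scale each output channel
    b_fold = beta - gamma * moving_mean / std
    return w_fold, b_fold

w = np.random.randn(16, 8, 3, 3).astype(np.float32)   # (co, ci, kh, kw)
gamma, beta = np.ones(16, np.float32), np.zeros(16, np.float32)
mean, var = np.zeros(16, np.float32), np.ones(16, np.float32)
w_fold, b_fold = fold_bn_into_conv(w, gamma, beta, mean, var)
print(w_fold.shape, b_fold.shape)  # (16, 8, 3, 3) (16,)
```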
Default: None Returns: @@ -165,9 +160,6 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None): if not isinstance(ckpt_file_name, str): raise ValueError("The ckpt_file_name must be string.") - if model_type not in ModelType: - raise ValueError(f"The model_type is not in {ModelType}.") - if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt": raise ValueError("Please input the correct checkpoint file name.") @@ -186,10 +178,6 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None): raise ValueError(e.__str__()) parameter_dict = {} - if checkpoint_list.model_type: - if model_type != checkpoint_list.model_type: - raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format( - checkpoint_list.model_type, model_type)) try: for element in checkpoint_list.value: data = element.tensor.tensor_content @@ -314,14 +302,13 @@ def _save_graph(network, file_name): os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) -def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True): +def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): """ Saves checkpoint for 'ms' backend. Args: train_network (Network): The train network for training. ckpt_file_name (str): The name of checkpoint file. - model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal". integrated_save (bool): Whether to integrated save in automatic model parallel scene. """ @@ -346,7 +333,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", in each_param["data"] = param_data param_list.append(each_param) - save_checkpoint(param_list, ckpt_file_name, model_type) + save_checkpoint(param_list, ckpt_file_name) def _get_merged_param_data(net, param_name, param_data): diff --git a/model_zoo/lenet_quant/README.md b/model_zoo/lenet_quant/README.md index 2f949f6d76..2fd3e129a2 100644 --- a/model_zoo/lenet_quant/README.md +++ b/model_zoo/lenet_quant/README.md @@ -33,7 +33,7 @@ Then you will get the following display ```bash >>> Found existing installation: mindspore-ascend >>> Uninstalling mindspore-ascend: ->>> Successfully uninstalled mindspore-ascend. +>>> Successfully uninstalled mindspore-ascend. ``` ### Prepare Dataset @@ -186,7 +186,7 @@ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) ### train quantization aware model -Also, you can just run this command instread. +Also, you can just run this command instead. ```python python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt @@ -235,7 +235,7 @@ The top1 accuracy would display on shell. 
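With `model_type` removed, saving and loading are the same for normal, fusion and quant networks. A minimal usage sketch assuming the simplified signatures above (`checkpoint_lenet.ckpt` is a placeholder path, and any `Cell` can stand in for the real network):

```python
import mindspore.nn as nn
from mindspore.train.serialization import load_checkpoint, load_param_into_net

network = nn.Dense(120, 84)                            # stand-in for LeNet5Fusion
param_dict = load_checkpoint("checkpoint_lenet.ckpt")  # no model_type argument anymore
load_param_into_net(network, param_dict)
```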
Here are some optional parameters: ```bash ---device_target {Ascend,GPU,CPU} +--device_target {Ascend,GPU} device where the code will be implemented (default: Ascend) --data_path DATA_PATH path where the dataset is saved diff --git a/model_zoo/lenet_quant/eval.py b/model_zoo/lenet_quant/eval.py index d94e77279f..c0293ae1f7 100644 --- a/model_zoo/lenet_quant/eval.py +++ b/model_zoo/lenet_quant/eval.py @@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU', 'CPU'], + choices=['Ascend', 'GPU'], help='device where the code will be implemented (default: Ascend)') parser.add_argument('--data_path', type=str, default="./MNIST_Data", help='path where the dataset is saved') diff --git a/model_zoo/lenet_quant/eval_quant.py b/model_zoo/lenet_quant/eval_quant.py index 2c2477123f..bc9b62121d 100644 --- a/model_zoo/lenet_quant/eval_quant.py +++ b/model_zoo/lenet_quant/eval_quant.py @@ -32,7 +32,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU', 'CPU'], + choices=['Ascend', 'GPU'], help='device where the code will be implemented (default: Ascend)') parser.add_argument('--data_path', type=str, default="./MNIST_Data", help='path where the dataset is saved') @@ -61,7 +61,7 @@ if __name__ == "__main__": model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) # load quantization aware network checkpoint - param_dict = load_checkpoint(args.ckpt_path, model_type="quant") + param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) print("============== Starting Testing ==============") diff --git a/model_zoo/lenet_quant/train.py b/model_zoo/lenet_quant/train.py index b6040776ef..a34b6d5ed6 100644 --- a/model_zoo/lenet_quant/train.py +++ b/model_zoo/lenet_quant/train.py @@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU', 'CPU'], + choices=['Ascend', 'GPU'], help='device where the code will be implemented (default: Ascend)') parser.add_argument('--data_path', type=str, default="./MNIST_Data", help='path where the dataset is saved') @@ -56,8 +56,7 @@ if __name__ == "__main__": # call back and monitor time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, - keep_checkpoint_max=cfg.keep_checkpoint_max, - model_type=network.type) + keep_checkpoint_max=cfg.keep_checkpoint_max) ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) # define model diff --git a/model_zoo/lenet_quant/train_quant.py b/model_zoo/lenet_quant/train_quant.py index eb1f783a7c..ba54e63d80 100644 --- a/model_zoo/lenet_quant/train_quant.py +++ b/model_zoo/lenet_quant/train_quant.py @@ -33,7 +33,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU', 'CPU'], + choices=['Ascend', 'GPU'], help='device where the code will be implemented (default: Ascend)') parser.add_argument('--data_path', type=str, 
default="./MNIST_Data", help='path where the dataset is saved') @@ -50,11 +50,13 @@ if __name__ == "__main__": # define fusion network network = LeNet5Fusion(cfg.num_classes) + + # convert fusion network to quantization aware network + network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) + # load quantization aware network checkpoint param_dict = load_checkpoint(args.ckpt_path, network.type) load_param_into_net(network, param_dict) - # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) # define network loss net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") @@ -64,8 +66,7 @@ if __name__ == "__main__": # call back and monitor time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, - keep_checkpoint_max=cfg.keep_checkpoint_max, - model_type="quant") + keep_checkpoint_max=cfg.keep_checkpoint_max) ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) # define model diff --git a/model_zoo/mobilenetv2/scripts/run_infer.sh b/model_zoo/mobilenetv2/scripts/run_infer.sh index e200e600bf..7385a221d4 100644 --- a/model_zoo/mobilenetv2/scripts/run_infer.sh +++ b/model_zoo/mobilenetv2/scripts/run_infer.sh @@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH export DEVICE_ID=0 export RANK_ID=0 export RANK_SIZE=1 -if [ -d "eval" ]; +if [ -d "../eval" ]; then rm -rf ../eval fi diff --git a/model_zoo/mobilenetv2/scripts/run_train.sh b/model_zoo/mobilenetv2/scripts/run_train.sh index fc013d474c..3414aa7528 100644 --- a/model_zoo/mobilenetv2/scripts/run_train.sh +++ b/model_zoo/mobilenetv2/scripts/run_train.sh @@ -62,7 +62,7 @@ run_gpu() BASEPATH=$(cd "`dirname $0`" || exit; pwd) export PYTHONPATH=${BASEPATH}:$PYTHONPATH - if [ -d "train" ]; + if [ -d "../train" ]; then rm -rf ../train fi diff --git a/model_zoo/mobilenetv3/scripts/run_infer.sh b/model_zoo/mobilenetv3/scripts/run_infer.sh index e200e600bf..7385a221d4 100644 --- a/model_zoo/mobilenetv3/scripts/run_infer.sh +++ b/model_zoo/mobilenetv3/scripts/run_infer.sh @@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH export DEVICE_ID=0 export RANK_ID=0 export RANK_SIZE=1 -if [ -d "eval" ]; +if [ -d "../eval" ]; then rm -rf ../eval fi diff --git a/model_zoo/mobilenetv3/scripts/run_train.sh b/model_zoo/mobilenetv3/scripts/run_train.sh index 78b79b235f..47dabffe01 100644 --- a/model_zoo/mobilenetv3/scripts/run_train.sh +++ b/model_zoo/mobilenetv3/scripts/run_train.sh @@ -60,7 +60,7 @@ run_gpu() BASEPATH=$(cd "`dirname $0`" || exit; pwd) export PYTHONPATH=${BASEPATH}:$PYTHONPATH - if [ -d "train" ]; + if [ -d "../train" ]; then rm -rf ../train fi diff --git a/tests/ut/python/train/quant/mobilenetv2_combined.py b/tests/ut/python/train/quant/mobilenetv2_combined.py index b0cbafb29a..51916192d8 100644 --- a/tests/ut/python/train/quant/mobilenetv2_combined.py +++ b/tests/ut/python/train/quant/mobilenetv2_combined.py @@ -31,7 +31,7 @@ def _conv_bn(in_channel, out_channel, kernel_size=ksize, stride=stride, - batchnorm=True)]) + has_bn=True)]) class InvertedResidual(nn.Cell): @@ -49,25 +49,25 @@ class InvertedResidual(nn.Cell): 3, stride, group=hidden_dim, - batchnorm=True, + has_bn=True, activation='relu6'), nn.Conv2dBnAct(hidden_dim, oup, 1, 1, - batchnorm=True) + has_bn=True) ]) else: self.conv = nn.SequentialCell([ nn.Conv2dBnAct(inp, hidden_dim, 1, 1, - 
batchnorm=True, + has_bn=True, activation='relu6'), nn.Conv2dBnAct(hidden_dim, hidden_dim, 3, stride, group=hidden_dim, - batchnorm=True, + has_bn=True, activation='relu6'), nn.Conv2dBnAct(hidden_dim, oup, 1, 1, - batchnorm=True) + has_bn=True) ]) self.add = P.TensorAdd() diff --git a/tests/ut/python/train/quant/test_quant.py b/tests/ut/python/train/quant/test_quant.py index 54563d86eb..1a21bc2c02 100644 --- a/tests/ut/python/train/quant/test_quant.py +++ b/tests/ut/python/train/quant/test_quant.py @@ -42,7 +42,7 @@ class LeNet5(nn.Cell): def __init__(self, num_class=10): super(LeNet5, self).__init__() self.num_class = num_class - self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6', pad_mode="valid") + self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu6', pad_mode="valid") self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid") self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
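
Putting the renamed pieces together, a condensed sketch of the quantization aware training entry flow that `train_quant.py` above now follows (`LeNet5Fusion` is imported as in that script; the class count and checkpoint path stand in for `cfg.num_classes` and `args.ckpt_path`); note the fusion network is converted first and the checkpoint is loaded into the converted network:

```python
from mindspore.train.quant import quant
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.lenet_fusion import LeNet5 as LeNet5Fusion  # as imported in train_quant.py

network = LeNet5Fusion(10)  # cfg.num_classes in the script; 10 for MNIST
# convert the fusion network to a quantization aware network first,
# then restore the checkpoint into the converted network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
param_dict = load_checkpoint("checkpoint_lenet.ckpt")  # args.ckpt_path in the script
load_param_into_net(network, param_dict)
```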