From b5223681613da3b1896e66c0476354ba3c6b78a6 Mon Sep 17 00:00:00 2001
From: Wei Luning
Date: Thu, 9 Jul 2020 10:53:23 +0800
Subject: [PATCH] fix bug in quant export when the input has no fakequant

---
 mindspore/ccsrc/pipeline/pipeline.cc               |   3 +
 mindspore/nn/layer/quant.py                        |   4 +-
 mindspore/ops/primitive.py                         |  22 ++--
 mindspore/train/quant/quant.py                     |   9 +-
 mindspore/train/quant/quant_utils.py               |  27 ++--
 tests/ut/python/train/quant/mobilenetv2.py         | 115 -----------------
 .../train/quant/mobilenetv2_combined.py            | 122 ------------------
 tests/ut/python/train/quant/test_quant.py          |  17 ++-
 8 files changed, 47 insertions(+), 272 deletions(-)
 delete mode 100644 tests/ut/python/train/quant/mobilenetv2.py
 delete mode 100644 tests/ut/python/train/quant/mobilenetv2_combined.py

diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc
index 7f5f3c3ffad..b164d9ca3fd 100644
--- a/mindspore/ccsrc/pipeline/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/pipeline.cc
@@ -328,6 +328,9 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
       x = cnode->input(1);
       count += 1;
     }
+    if (x->isa<Parameter>()) {
+      fake_quant_table[weight_name] = std::make_pair(nullptr, "input");
+    }
     // get the fakequant parameter minq's name
     if (!is_quant_cnode(x)) {
       continue;
     }
diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index e0871ee3647..994f09dfd85 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -1169,9 +1169,9 @@ class QuantBlock(Cell):
         return x
 
     def extend_repr(self):
-        str_info = f'quant={self.quant}, core_op={type(self.core_op)}'
+        str_info = f'quant={self.quant}, core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
         if self.has_bias:
-            str_info = str_info + f', bias={self.bias}'
+            str_info = str_info + f', bias=shape[{self.bias.shape}]'
         if self.has_act:
             str_info = str_info + f', activation={self.activation}'
         str_info = str_info + f', dequant={self.dequant}'
diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py
index 7ceb6877780..768e9db2db4 100644
--- a/mindspore/ops/primitive.py
+++ b/mindspore/ops/primitive.py
@@ -237,12 +237,14 @@ class PrimitiveWithInfer(Primitive):
         """
         Infer output shape based on input shape.
 
-        Args:
-            inputs (tuple(int)): dimensions of input tensors.
-            outputs (tuple(int)): dimensions of output tensors.
-
         Note:
             The shape of scalar is an empty tuple.
+
+        Args:
+            args (tuple(int)): shapes of input tensors.
+
+        Returns:
+            `tuple(int)`, shapes of output tensors.
         """
         return None
 
@@ -251,8 +253,10 @@
         Infer output dtype based on input dtype.
 
         Args:
-            inputs (mstype): data type of inputs.
-            outputs (mstype): data type of outputs.
+            args (:class:`mindspore.dtype`): data type of inputs.
+
+        Returns:
+            :class:`mindspore.dtype`, data type of outputs.
         """
         return None
 
@@ -261,8 +265,10 @@
         Infer output value based on input value at compile time.
 
        Args:
-            inputs (any): value of inputs.
-            outputs (any): value of outputs.
+            args (Any): value of inputs.
+
+        Returns:
+            Value of outputs. Return `None` if the value cannot be inferred at compile time.
        """
        return None
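Since the reworked docstrings above describe the `PrimitiveWithInfer` contract, a minimal sketch of a custom primitive implementing all three hooks may help; the `Double` op below is hypothetical and exists only to illustrate the documented behavior:

# A minimal sketch of the contract documented above. `Double` is a
# hypothetical primitive used purely for illustration.
from mindspore import Tensor
from mindspore.ops import PrimitiveWithInfer, prim_attr_register


class Double(PrimitiveWithInfer):
    """Hypothetical element-wise op that doubles its single input."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

    def infer_shape(self, x_shape):
        # element-wise: the output shape equals the input shape;
        # a scalar arrives as an empty tuple, per the Note above
        return x_shape

    def infer_dtype(self, x_dtype):
        # element-wise: the output dtype equals the input dtype
        return x_dtype

    def infer_value(self, x_value):
        # constant-fold when the input value is known at compile time;
        # returning None defers evaluation to runtime
        if x_value is None:
            return None
        return Tensor(x_value.asnumpy() * 2)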
""" return None diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py index a079644aef1..b553373f105 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/train/quant/quant.py @@ -318,9 +318,12 @@ class ExportToQuantInferNetwork: info = self.quant_info_table.get(w_minq_name, None) if info: fack_quant_a_in_op, minq_name = info - maxq = self.all_parameters[minq_name[:-4] + "maxq"] - minq = self.all_parameters[minq_name] - scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type) + if minq_name == 'input': + scale_a_in, zp_a_in = self.input_scale, self.input_zero_point + else: + maxq = self.all_parameters[minq_name[:-4] + "maxq"] + minq = self.all_parameters[minq_name] + scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type) else: logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}") return None diff --git a/mindspore/train/quant/quant_utils.py b/mindspore/train/quant/quant_utils.py index da6d4fc8728..69505970fd8 100644 --- a/mindspore/train/quant/quant_utils.py +++ b/mindspore/train/quant/quant_utils.py @@ -104,19 +104,20 @@ def weight2int(data, scale, zero_point): raise ValueError("`scale` and `zero_point` should have the same shape.") if scale.shape[0] < 0: raise ValueError("`scale` and `zero_point` shape should greater than zero.") - - if scale.shape[0] == data.shape[0]: - # `Conv2d` or `Dense` op weight - shape_list = [-1] + [1] * len(data.shape[1:]) - scale = scale.reshape(shape_list) - zero_point = zero_point.reshape(shape_list) - elif scale.shape[0] == data.shape[1]: - # `DepthwiseConv2d` op weight - shape_list = [1, -1] + [1] * len(data.shape[2:]) - scale = scale.reshape(shape_list) - zero_point = zero_point.reshape(shape_list) - else: - raise ValueError("Unsupported weight shape({})".format(data.shape)) + if len(scale.shape) > 1: + # for perchannel + if scale.shape[0] == data.shape[0]: + # `Conv2d` or `Dense` op weight + shape_list = [-1] + [1] * len(data.shape[1:]) + scale = scale.reshape(shape_list) + zero_point = zero_point.reshape(shape_list) + elif scale.shape[0] == data.shape[1]: + # `DepthwiseConv2d` op weight + shape_list = [1, -1] + [1] * len(data.shape[2:]) + scale = scale.reshape(shape_list) + zero_point = zero_point.reshape(shape_list) + else: + raise ValueError("Unsupported weight shape({})".format(data.shape)) return np.round((data / scale) + zero_point) diff --git a/tests/ut/python/train/quant/mobilenetv2.py b/tests/ut/python/train/quant/mobilenetv2.py deleted file mode 100644 index 163b230e1e0..00000000000 --- a/tests/ut/python/train/quant/mobilenetv2.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
diff --git a/tests/ut/python/train/quant/mobilenetv2.py b/tests/ut/python/train/quant/mobilenetv2.py
deleted file mode 100644
index 163b230e1e0..00000000000
--- a/tests/ut/python/train/quant/mobilenetv2.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""MobileNetV2"""
-from mindspore import nn
-from mindspore.ops import operations as P
-
-
-def make_divisible(input_x, div_by=8):
-    return int((input_x + div_by) // div_by)
-
-
-def _conv_bn(in_channel,
-             out_channel,
-             ksize,
-             stride=1):
-    """Get a conv2d batchnorm and relu layer."""
-    return nn.SequentialCell(
-        [nn.Conv2d(in_channel,
-                   out_channel,
-                   kernel_size=ksize,
-                   stride=stride),
-         nn.BatchNorm2d(out_channel)])
-
-
-class InvertedResidual(nn.Cell):
-    def __init__(self, inp, oup, stride, expend_ratio):
-        super(InvertedResidual, self).__init__()
-        self.stride = stride
-        assert stride in [1, 2]
-
-        hidden_dim = int(inp * expend_ratio)
-        self.use_res_connect = self.stride == 1 and inp == oup
-        if expend_ratio == 1:
-            self.conv = nn.SequentialCell([
-                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim),
-                nn.BatchNorm2d(hidden_dim),
-                nn.ReLU6(),
-                nn.Conv2d(hidden_dim, oup, 1, 1),
-                nn.BatchNorm2d(oup)
-            ])
-        else:
-            self.conv = nn.SequentialCell([
-                nn.Conv2d(inp, hidden_dim, 1, 1),
-                nn.BatchNorm2d(hidden_dim),
-                nn.ReLU6(),
-
-                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim),
-                nn.BatchNorm2d(hidden_dim),
-                nn.ReLU6(),
-
-                nn.Conv2d(hidden_dim, oup, 1, 1),
-                nn.BatchNorm2d(oup)
-            ])
-
-    def construct(self, input_x):
-        out = self.conv(input_x)
-        if self.use_res_connect:
-            out = input_x + out
-        return out
-
-
-class MobileNetV2(nn.Cell):
-    def __init__(self, num_class=1000, input_size=224, width_mul=1.):
-        super(MobileNetV2, self).__init__()
-        _ = input_size
-        block = InvertedResidual
-        input_channel = 32
-        last_channel = 1280
-        inverted_residual_setting = [
-            [1, 16, 1, 1],
-            [6, 24, 2, 2],
-            [6, 32, 3, 2],
-            [6, 64, 4, 2],
-            [6, 96, 3, 1],
-            [6, 160, 3, 2],
-            [6, 230, 1, 1],
-        ]
-        if width_mul > 1.0:
-            last_channel = make_divisible(last_channel * width_mul)
-        self.last_channel = last_channel
-        features = [_conv_bn(3, input_channel, 3, 2)]
-
-        for t, c, n, s in inverted_residual_setting:
-            out_channel = make_divisible(c * width_mul) if t > 1 else c
-            for i in range(n):
-                if i == 0:
-                    features.append(block(input_channel, out_channel, s, t))
-                else:
-                    features.append(block(input_channel, out_channel, 1, t))
-                input_channel = out_channel
-
-        features.append(_conv_bn(input_channel, self.last_channel, 1))
-
-        self.features = nn.SequentialCell(features)
-        self.mean = P.ReduceMean(keep_dims=False)
-        self.classifier = nn.Dense(self.last_channel, num_class)
-
-    def construct(self, input_x):
-        out = input_x
-        out = self.features(out)
-        out = self.mean(out, (2, 3))
-        out = self.classifier(out)
-        return out
diff --git a/tests/ut/python/train/quant/mobilenetv2_combined.py b/tests/ut/python/train/quant/mobilenetv2_combined.py
deleted file mode 100644
index 51916192d84..00000000000
--- a/tests/ut/python/train/quant/mobilenetv2_combined.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""mobile net v2"""
-from mindspore import nn
-from mindspore.ops import operations as P
-
-
-def make_divisible(input_x, div_by=8):
-    return int((input_x + div_by) // div_by)
-
-
-def _conv_bn(in_channel,
-             out_channel,
-             ksize,
-             stride=1):
-    """Get a conv2d batchnorm and relu layer."""
-    return nn.SequentialCell(
-        [nn.Conv2dBnAct(in_channel,
-                        out_channel,
-                        kernel_size=ksize,
-                        stride=stride,
-                        has_bn=True)])
-
-
-class InvertedResidual(nn.Cell):
-    def __init__(self, inp, oup, stride, expend_ratio):
-        super(InvertedResidual, self).__init__()
-        self.stride = stride
-        assert stride in [1, 2]
-
-        hidden_dim = int(inp * expend_ratio)
-        self.use_res_connect = self.stride == 1 and inp == oup
-        if expend_ratio == 1:
-            self.conv = nn.SequentialCell([
-                nn.Conv2dBnAct(hidden_dim,
-                               hidden_dim,
-                               3,
-                               stride,
-                               group=hidden_dim,
-                               has_bn=True,
-                               activation='relu6'),
-                nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
-                               has_bn=True)
-            ])
-        else:
-            self.conv = nn.SequentialCell([
-                nn.Conv2dBnAct(inp, hidden_dim, 1, 1,
-                               has_bn=True,
-                               activation='relu6'),
-                nn.Conv2dBnAct(hidden_dim,
-                               hidden_dim,
-                               3,
-                               stride,
-                               group=hidden_dim,
-                               has_bn=True,
-                               activation='relu6'),
-                nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
-                               has_bn=True)
-            ])
-        self.add = P.TensorAdd()
-
-    def construct(self, input_x):
-        out = self.conv(input_x)
-        if self.use_res_connect:
-            out = self.add(input_x, out)
-        return out
-
-
-class MobileNetV2(nn.Cell):
-    def __init__(self, num_class=1000, input_size=224, width_mul=1.):
-        super(MobileNetV2, self).__init__()
-        _ = input_size
-        block = InvertedResidual
-        input_channel = 32
-        last_channel = 1280
-        inverted_residual_setting = [
-            [1, 16, 1, 1],
-            [6, 24, 2, 2],
-            [6, 32, 3, 2],
-            [6, 64, 4, 2],
-            [6, 96, 3, 1],
-            [6, 160, 3, 2],
-            [6, 230, 1, 1],
-        ]
-        if width_mul > 1.0:
-            last_channel = make_divisible(last_channel * width_mul)
-        self.last_channel = last_channel
-        features = [_conv_bn(3, input_channel, 3, 2)]
-
-        for t, c, n, s in inverted_residual_setting:
-            out_channel = make_divisible(c * width_mul) if t > 1 else c
-            for i in range(n):
-                if i == 0:
-                    features.append(block(input_channel, out_channel, s, t))
-                else:
-                    features.append(block(input_channel, out_channel, 1, t))
-                input_channel = out_channel
-
-        features.append(_conv_bn(input_channel, self.last_channel, 1))
-
-        self.features = nn.SequentialCell(features)
-        self.mean = P.ReduceMean(keep_dims=False)
-        self.classifier = nn.DenseBnAct(self.last_channel, num_class)
-
-    def construct(self, input_x):
-        out = input_x
-        out = self.features(out)
-        out = self.mean(out, (2, 3))
-        out = self.classifier(out)
-        return out
diff --git a/tests/ut/python/train/quant/test_quant.py b/tests/ut/python/train/quant/test_quant.py
index 1a21bc2c023..39e887170ca 100644
--- a/tests/ut/python/train/quant/test_quant.py
+++ b/tests/ut/python/train/quant/test_quant.py
@@ -20,7 +20,7 @@ import mindspore.context as context
 from mindspore import Tensor
 from mindspore import nn
 from mindspore.train.quant import quant as qat
-from mobilenetv2_combined import MobileNetV2
+from model_zoo.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2
 
 context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
 
@@ -42,7 +42,7 @@ class LeNet5(nn.Cell):
     def __init__(self, num_class=10):
         super(LeNet5, self).__init__()
         self.num_class = num_class
-        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu6', pad_mode="valid")
+        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu', pad_mode="valid")
         self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
         self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
         self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
@@ -67,20 +67,19 @@ def test_qat_lenet():
     img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
     net = LeNet5()
     net = qat.convert_quant_network(
-        net, freeze_bn=10000, num_bits=8)
+        net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
     # should load the checkpoint. mock here
     for param in net.get_parameters():
         param.init_data()
-    qat.export_geir(net, img, file_name="quant.pb")
+    qat.export(net, img, file_name="quant.pb")
 
 
 @pytest.mark.skip(reason="no `te.lang.cce` in ut env")
 def test_qat_mobile():
-    net = MobileNetV2()
+    network = mobilenetV2(num_classes=1000)
     img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
-    net = qat.convert_quant_network(
-        net, quant_delay=0, bn_fold=True, freeze_bn=10000, num_bits=8)
+    network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
     # should load the checkpoint. mock here
-    for param in net.get_parameters():
+    for param in network.get_parameters():
         param.init_data()
-    qat.export_geir(net, img, file_name="quant.pb")
+    qat.export(network, img, file_name="quant.pb")
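End to end, the updated UT exercises the export path below. A standalone sketch, assuming the `LeNet5` cell defined in test_quant.py above; as in the test, checkpoint loading is mocked by materializing the parameters in place:

import numpy as np
from mindspore import Tensor
from mindspore.train.quant import quant as qat

net = LeNet5(num_class=10)  # the cell defined in test_quant.py above
# insert fake-quant nodes; the first element of per_channel/symmetric
# applies to weights, the second to activations (data flow)
net = qat.convert_quant_network(net, bn_fold=True,
                                per_channel=[True, False],
                                symmetric=[True, False])
# a real run would load a trained checkpoint here; the UT mocks it like this
for param in net.get_parameters():
    param.init_data()
img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
qat.export(net, img, file_name="quant.pb")

With the pipeline.cc change above, a quantized op whose activation input traces back to a bare network Parameter (i.e. no input fake-quant node) is tagged "input" in the fake-quant table, so the exporter falls back to the network-level input scale and zero point instead of failing to look up a `fake_quant.minq` parameter.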