diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc
index 36be63387e6..59d359a2024 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc
@@ -395,7 +395,7 @@ void ExecutorPy::GetGeBackendPolicy() const {
 
 bool IsPhaseExportGeir(const std::string &phase_s) {
   auto phase_to_export = "export.geir";
-  return phase_s.rfind(phase_to_export, 0) != std::string::npos;
+  return phase_s.rfind(phase_to_export) != std::string::npos;
 }
 
 std::vector<ActionItem> GetPipline(const ResourcePtr &resource, const std::string &phase_s, bool use_vm) {
diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare.cc
index 7efc6061586..1f803d4d813 100644
--- a/mindspore/ccsrc/transform/graph_ir/op_declare.cc
+++ b/mindspore/ccsrc/transform/graph_ir/op_declare.cc
@@ -757,7 +757,7 @@ ATTR_MAP(ExtractImagePatches) = {{"ksizes", ATTR_DESC(ksizes, AnyTraits<int64_t>(),
 OUTPUT_MAP(ExtractImagePatches) = {{0, OUTPUT_DESC(y)}};
 
 // Conv2D
-INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}};
+INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(bias)}};
 ATTR_MAP(Conv2D) = {
   {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
   {"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
@@ -794,7 +794,7 @@ ATTR_MAP(Conv2DBackpropFilterD) = {
 OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}};
 
 // DepthwiseConv2D
-INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}};
+INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(bias)}};
 ATTR_MAP(DepthwiseConv2D) = {
   {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
   {"pads", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
@@ -826,7 +826,7 @@ ATTR_MAP(DepthwiseConv2DBackpropFilterD) = {
 OUTPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{0, OUTPUT_DESC(filter_grad)}};
 
 // MatMulV2
-INPUT_MAP(MatMulV2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
+INPUT_MAP(MatMulV2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(bias)}};
 ATTR_MAP(MatMulV2) = {{"transpose_a", ATTR_DESC(transpose_x1, AnyTraits<bool>())},
                       {"transpose_b", ATTR_DESC(transpose_x2, AnyTraits<bool>())}};
 OUTPUT_MAP(MatMulV2) = {{0, OUTPUT_DESC(y)}};
@@ -1347,7 +1347,8 @@ OUTPUT_MAP(AscendQuant) = {{0, OUTPUT_DESC(y)}};
 // AscendDequant
 INPUT_MAP(AscendDequant) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(deq_scale)}};
 ATTR_MAP(AscendDequant) = {{"sqrt_mode", ATTR_DESC(sqrt_mode, AnyTraits<bool>())},
-                           {"relu_flag", ATTR_DESC(relu_flag, AnyTraits<bool>())}};
+                           {"relu_flag", ATTR_DESC(relu_flag, AnyTraits<bool>())},
+                           {"dtype", ATTR_DESC(dtype, AnyTraits<GEType>())}};
 OUTPUT_MAP(AscendDequant) = {{0, OUTPUT_DESC(y)}};
 #ifdef ENABLE_GE
 // Print
diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index dc30d33ac18..19d2034370d 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -28,8 +28,8 @@ from mindspore._checkparam import check_int_positive, check_bool, twice
 from mindspore._checkparam import Rel
 import mindspore.context as context
 
-from .normalization import BatchNorm2d
-from .activation import get_activation
+from .normalization import BatchNorm2d, BatchNorm1d
+from .activation import get_activation, ReLU
 from ..cell import Cell
 from . import conv, basic
 from ..._checkparam import ParamValidator as validator
@@ -206,7 +206,7 @@ class DenseBnAct(Cell):
         self.has_bn = validator.check_bool("has_bn", has_bn)
         self.has_act = activation is not None
         if has_bn:
-            self.batchnorm = BatchNorm2d(out_channels)
+            self.batchnorm = BatchNorm1d(out_channels)
         self.activation = get_activation(activation)
 
     def construct(self, x):
@@ -1156,13 +1156,18 @@ class QuantBlock(Cell):
         self.has_bias = bias is not None
         self.activation = activation
         self.has_act = activation is not None
+        if isinstance(activation, ReLU):
+            self.activation = None
+            self.has_act = False
+            self.dequant.add_prim_attr("relu_flag", True)
         self.bias_add = P.BiasAdd()
 
     def construct(self, x):
         x = self.quant(x)
-        x = self.core_op(x, self.weight)
         if self.has_bias:
-            x = self.bias_add(x, self.bias)
+            x = self.core_op(x, self.weight, self.bias)
+        else:
+            x = self.core_op(x, self.weight)
         if self.has_act:
             x = self.activation(x)
         x = self.dequant(x, self.dequant_scale)
diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py
index 7ae59a64670..0c421cd3fcc 100644
--- a/mindspore/ops/operations/_inner_ops.py
+++ b/mindspore/ops/operations/_inner_ops.py
@@ -383,6 +383,7 @@ class Dequant(PrimitiveWithInfer):
     def __init__(self, sqrt_mode=False, relu_flag=False):
         self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
         self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
+        self.add_prim_attr("dtype", mstype.float16)
 
     def infer_shape(self, x_shape, deq_scale_shape):
         return x_shape
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index 4e4e9187cb3..940bf65576c 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -596,7 +596,7 @@ class MatMul(PrimitiveWithInfer):
             raise ValueError('MatMul input x, y should be the same dimension size and should be ' +
                              f'equal to 2, while x size = {len(x)}, y size= {len(y)}')
 
-    def infer_shape(self, x, y):
+    def infer_shape(self, x, y, bias=None):
         self.check_shape_size(x, y)
         cls_name = self.name
         # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
@@ -621,7 +621,7 @@ class MatMul(PrimitiveWithInfer):
         ret_dims = x[: -2] + [x_last[self.transpose_a], y_last[not self.transpose_b]]
         return ret_dims
 
-    def infer_dtype(self, x, y):
+    def infer_dtype(self, x, y, bias=None):
         args = {"x": x, "y": y}
         validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
         if x.element_type() == mstype.int8:
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index e2f6888c886..aa55ab62a91 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -842,7 +842,7 @@ class Conv2D(PrimitiveWithInfer):
         self.group = validator.check_integer('group', group, 0, Rel.GT, self.name)
         self.add_prim_attr('offset_a', 0)
 
-    def infer_shape(self, x_shape, w_shape):
+    def infer_shape(self, x_shape, w_shape, b_shape=None):
         validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
         validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
         validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
@@ -887,7 +887,7 @@ class Conv2D(PrimitiveWithInfer):
         out_shape = [x_shape[0], out_channel, h_out, w_out]
         return out_shape
 
-    def infer_dtype(self, x_dtype, w_dtype):
+    def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
         args = {'x': x_dtype, 'w': w_dtype}
         valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
         validator.check_tensor_type_same(args, valid_types, self.name)
@@ -968,7 +968,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
         self.group = validator.check_integer("group", group, 0, Rel.GT, self.name)
         self.add_prim_attr('offset_a', 0)
 
-    def infer_shape(self, x_shape, w_shape):
+    def infer_shape(self, x_shape, w_shape, b_shape=None):
         validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
         validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
         validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
@@ -1011,7 +1011,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
         out_shape = [x_shape[0], out_channel, h_out, w_out]
         return out_shape
 
-    def infer_dtype(self, x_dtype, w_dtype):
+    def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
         args = {'x': x_dtype, 'w': w_dtype}
         validator.check_tensor_type_same(args, mstype.number_type, self.name)
         if x_dtype.element_type() == mstype.int8:
diff --git a/tests/ut/python/train/quant/test_quant.py b/tests/ut/python/train/quant/test_quant.py
index 4816af89360..b7a2372464f 100644
--- a/tests/ut/python/train/quant/test_quant.py
+++ b/tests/ut/python/train/quant/test_quant.py
@@ -78,7 +78,7 @@ def test_qat_lenet():
 def test_qat_mobile_per_channel_tf():
     network = mobilenetV2(num_classes=1000)
     img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
-    network = qat.convert_quant_network(network, bn_fold=True, per_channel=[False, True], symmetric=[True, False])
+    network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
     # should load the checkpoint. mock here
     for param in network.get_parameters():
         param.init_data()
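
Reviewer note (not part of the patch): a minimal plain-Python sketch of the dataflow the QuantBlock change above produces. The function name quant_block_forward and its parameters are illustrative stand-ins, not MindSpore APIs. The point it shows: bias now enters the core op (Conv2D / DepthwiseConv2D / MatMulV2 accept a third bias input on the GE backend) instead of a separate BiasAdd, a trailing ReLU is folded into Dequant via relu_flag, and Dequant is tagged with a float16 output dtype.

def quant_block_forward(x, quant, core_op, dequant, dequant_scale,
                        weight, bias=None, activation=None):
    """Illustrative-only paraphrase of QuantBlock.construct after this patch."""
    x = quant(x)                      # quantize the activation (AscendQuant)
    if bias is not None:
        x = core_op(x, weight, bias)  # bias is now a third input of the core op
    else:
        x = core_op(x, weight)
    if activation is not None:        # a ReLU activation was folded into Dequant's
        x = activation(x)             # relu_flag, so only other activations run here
    return dequant(x, dequant_scale)  # dequantize back to float16 (AscendDequant)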