From 90a83db5f1de429174a01126b34e768a5ba2238f Mon Sep 17 00:00:00 2001
From: yuchaojie
Date: Thu, 4 Feb 2021 09:46:59 +0800
Subject: [PATCH] fix Conv2DBackpropInputUnifyMindIR when it is a forward op in
 pynative

---
 .../ascend/mindir/conv2d_unify_mindir.cc | 38 ++++++++++++-------
 mindspore/ops/operations/nn_ops.py       | 11 ++++--
 2 files changed, 31 insertions(+), 18 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc
index 60e7248139c..56b4664332c 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc
@@ -37,6 +37,8 @@ constexpr auto kAttrPadList = "pad_list";
 constexpr auto kAttrMode = "mode";
 constexpr auto kAttrChannelMultiplier = "channel_multiplier";
 constexpr auto kAttrPerm = "perm";
+constexpr auto kAttrInputSizes = "input_sizes";
+constexpr auto kAttrInputSize = "input_size";
 
 bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vector<size_t> out_shape) {
   MS_EXCEPTION_IF_NULL(conv2d);
@@ -144,14 +146,22 @@ CNodePtr CreateDepthwiseConv2DBackpropInput(const FuncGraphPtr &graph, const CNo
                                             const CNodePtr &transpose) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(conv2d_backin);
-  if (conv2d_backin->inputs().size() != kConv2DBackpropInputNum) {
-    MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << ", but got "
-                      << conv2d_backin->inputs().size() - 1;
+
+  CNodePtr depth_conv_backin = nullptr;
+  if (conv2d_backin->inputs().size() == kConv2DBackpropInputNum) {
+    std::vector<AnfNodePtr> depth_conv_backin_inputs = {
+      NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), conv2d_backin->input(3),
+      transpose, conv2d_backin->input(1)};
+    depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
+  } else {
+    // In nn.Conv2DTranspose, Conv2DBackpropInput is a forward op and its input_sizes input will be converted to an
+    // attr in pynative mode.
+    std::vector<AnfNodePtr> depth_conv_backin_inputs = {
+      NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), transpose,
+      conv2d_backin->input(1)};
+    depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
+    AnfAlgo::CopyNodeAttr(kAttrInputSizes, kAttrInputSize, conv2d_backin, depth_conv_backin);
   }
-  std::vector<AnfNodePtr> depth_conv_backin_inputs = {
-    NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), conv2d_backin->input(3),
-    transpose, conv2d_backin->input(1)};
-  auto depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
   MS_EXCEPTION_IF_NULL(depth_conv_backin);
   depth_conv_backin->set_abstract(conv2d_backin->abstract());
   depth_conv_backin->set_scope(conv2d_backin->scope());
@@ -265,10 +275,8 @@ const AnfNodePtr Conv2DUnifyMindIR::Process(const FuncGraphPtr &graph, const Anf
 }
 
 const BaseRef Conv2DBackpropInputUnifyMindIR::DefinePattern() const {
-  VarPtr dout = std::make_shared<Var>();
-  VarPtr weight = std::make_shared<Var>();
-  VarPtr input_size = std::make_shared<Var>();
-  VectorRef pattern({prim::kPrimConv2DBackpropInput, dout, weight, input_size});
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VectorRef pattern({prim::kPrimConv2DBackpropInput, Xs});
   return pattern;
 }
 
@@ -285,9 +293,11 @@ const AnfNodePtr Conv2DBackpropInputUnifyMindIR::Process(const FuncGraphPtr &gra
     return nullptr;
   }
 
-  if (conv2d_backin->inputs().size() != kConv2DBackpropInputNum) {
-    MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << ", but got "
-                      << conv2d_backin->inputs().size() - 1;
+  auto input_size = conv2d_backin->inputs().size();
+  // In pynative mode, the input_sizes input will be converted to an attr if Conv2DBackpropInput is a forward op.
+  if (input_size != kConv2DBackpropInputNum && input_size != kConv2DBackpropInputNum - 1) {
+    MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << " or "
+                      << kConv2DBackpropInputNum - 2 << ", but got " << input_size - 1;
   }
   auto transpose = CreateTranspose(graph, conv2d_backin, conv2d_backin->input(2), true);
   auto depth_conv_backin = CreateDepthwiseConv2DBackpropInput(graph, conv2d_backin, transpose);
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 82cc291de4c..c8d24501d95 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -803,8 +803,8 @@ class FusedBatchNorm(Primitive):
         mode (int): Mode of batch normalization, value is 0 or 1. Default: 0.
         epsilon (float): A small value added for numerical stability. Default: 1e-5.
         momentum (float): The hyper parameter to compute moving average for running_mean and running_var
-            (e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
-            Momentum value must be [0, 1]. Default: 0.9.
+            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
+            Momentum value must be [0, 1]. Default: 0.1.
 
     Inputs:
         - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`.
@@ -893,8 +893,8 @@ class FusedBatchNormEx(PrimitiveWithInfer):
         mode (int): Mode of batch normalization, value is 0 or 1. Default: 0.
         epsilon (float): A small value added for numerical stability. Default: 1e-5.
         momentum (float): The hyper parameter to compute moving average for running_mean and running_var
-            (e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
-            Momentum value must be [0, 1]. Default: 0.9.
+            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
+            Momentum value must be [0, 1]. Default: 0.1.
         data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. Default: "NCHW".
@@ -1262,6 +1262,9 @@ class BatchNorm(PrimitiveWithInfer):
         is_training (bool): If `is_training` is True, `mean` and `variance` are computed during training.
             If `is_training` is False, they're loaded from checkpoint during inference. Default: False.
         epsilon (float): A small value added for numerical stability. Default: 1e-5.
+        momentum (float): The hyper parameter to compute moving average for running_mean and running_var
+            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
+            Momentum value must be [0, 1]. Default: 0.1.
         data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. Default: "NCHW".
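
Note: the snippet below is an illustrative repro sketch of the pynative scenario this patch addresses, not part of the patch itself; the channel counts, input shape, and device target are assumptions. In pynative mode, nn.Conv2dTranspose runs Conv2DBackpropInput as a forward op and its input_sizes input is converted to an attr, so the op reaches Conv2DBackpropInputUnifyMindIR with one fewer input than the pass previously accepted, triggering the MS_LOG(EXCEPTION) removed above.

    # Illustrative sketch only: assumed shapes and parameters; requires an
    # Ascend environment for the unify-mindir pass to run.
    import numpy as np
    import mindspore
    import mindspore.nn as nn
    from mindspore import Tensor, context

    # PyNative mode is what converts Conv2DBackpropInput's input_sizes input
    # into an attr when the op executes as a forward op.
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

    # group == in_channels == out_channels is the depthwise case that the
    # Conv2DBackpropInputUnifyMindIR pass rewrites to
    # DepthwiseConv2dNativeBackpropInput.
    net = nn.Conv2dTranspose(in_channels=4, out_channels=4, kernel_size=3, group=4)
    x = Tensor(np.ones((1, 4, 16, 16)), mindspore.float32)
    out = net(x)  # before this patch: exception about Conv2DBackpropInput's input number
    print(out.shape)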