From 76075c0b36cc7295b5489660b5400c1b633bf624 Mon Sep 17 00:00:00 2001
From: yuchaojie <yuchaojie@huawei.com>
Date: Thu, 1 Jul 2021 17:33:14 +0800
Subject: [PATCH] SliceGrad unify mindir adaptation for Cangjie frontend

---
 .../ascend/mindir/slice_grad_unify_mindir.cc  | 32 +++++++++++++------
 1 file changed, 23 insertions(+), 9 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/slice_grad_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/slice_grad_unify_mindir.cc
index fd56fa17127..f02871bf52a 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/slice_grad_unify_mindir.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/slice_grad_unify_mindir.cc
@@ -31,6 +31,7 @@ namespace mindspore {
 namespace opt {
 namespace {
 constexpr size_t kSliceGradInputTensorNum = 4;
+constexpr size_t kSliceGradCangjieInputTensorNum = 2;
 
 std::vector<size_t> GetInputXShape(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
@@ -50,11 +51,8 @@ std::vector<int64_t> GetTupleValue(const AnfNodePtr &node) {
 }  // namespace
 
 const BaseRef SliceGradUnifyMindIR::DefinePattern() const {
-  VarPtr X1 = std::make_shared<Var>();
-  VarPtr X2 = std::make_shared<Var>();
-  VarPtr X3 = std::make_shared<Var>();
-  VarPtr X4 = std::make_shared<Var>();
-  VectorRef slice_grad({std::make_shared<Primitive>("SliceGrad"), X1, X2, X3, X4});
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VectorRef slice_grad({std::make_shared<Primitive>("SliceGrad"), Xs});
   return slice_grad;
 }
 
@@ -63,8 +61,16 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node);
 
-  auto slice_grad = CheckAnfNodeIfCNodeAndInputSize(node, kSliceGradInputTensorNum);
-  std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)), slice_grad->input(1)};
+  auto slice_grad = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(slice_grad);
+  auto input_num = AnfAlgo::GetInputTensorNum(slice_grad);
+  if (input_num != kSliceGradInputTensorNum && input_num != kSliceGradCangjieInputTensorNum) {
+    MS_LOG(EXCEPTION) << "The input tensor size[" << input_num
+                      << "] of node " + slice_grad->DebugString() + " is not equal to " << kSliceGradInputTensorNum
+                      << " or " << kSliceGradCangjieInputTensorNum;
+  }
+  std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)),
+                                        slice_grad->input(kIndex1)};
   auto pad = graph->NewCNode(pad_inputs);
   MS_EXCEPTION_IF_NULL(pad);
   pad->set_scope(slice_grad->scope());
@@ -72,8 +78,16 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const
 
   // set attr paddings
   auto x_shape = GetInputXShape(slice_grad);
-  auto begins = GetTupleValue(slice_grad->input(kIndex3));
-  auto sizes = GetTupleValue(slice_grad->input(kIndex4));
+  std::vector<int64_t> begins;
+  std::vector<int64_t> sizes;
+  if (input_num == kSliceGradInputTensorNum) {
+    begins = GetTupleValue(slice_grad->input(kIndex3));
+    sizes = GetTupleValue(slice_grad->input(kIndex4));
+  } else {
+    // if frontend is Cangjie and mode is pynative, input 2, 3 is already converted to attr
+    begins = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(slice_grad, kAttrBegin);
+    sizes = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(slice_grad, kAttrSize);
+  }
   if (x_shape.size() != begins.size() || begins.size() != sizes.size()) {
     MS_LOG(EXCEPTION) << "For SliceGrad, x's shape dim number should be equal to len(begin) and len(size).";
   }