diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc
index a524d694e67..9e2c6374ce9 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc
@@ -25,29 +25,8 @@
 namespace mindspore {
 namespace opt {
 
-namespace {
-void SetAttrsForFusionNode(const AnfNodePtr &sub_anf, const AnfNodePtr &fusion_node) {
-  MS_EXCEPTION_IF_NULL(sub_anf);
-  MS_EXCEPTION_IF_NULL(fusion_node);
-  auto sub = sub_anf->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(sub);
-  if (sub->size() != kSubInputNum) {
-    MS_LOG(EXCEPTION) << "Sub's size is not equal with 3";
-  }
-  auto reduce_sum_anf = sub->input(2);
-  MS_EXCEPTION_IF_NULL(reduce_sum_anf);
-  auto reduce_sum = reduce_sum_anf->cast<CNodePtr>();
-  if (reduce_sum == nullptr) {
-    MS_LOG(EXCEPTION) << "Sub's second input is not a cnode";
-  }
-  AnfAlgo::CopyNodeAttr(kAttrAxis, reduce_sum, fusion_node);
-  AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node);
-}
-}  // namespace
-
 const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const {
-  return VectorRef(
-    {prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input1_, input0_})})});
+  return VectorRef({prim::kPrimSub, input0_, VectorRef({reduce_sum_, VectorRef({prim::kPrimMul, input1_, input0_})})});
 }
 
 const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
@@ -55,22 +34,28 @@
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(equiv);
-  auto input0 = utils::cast<AnfNodePtr>((*equiv)[input0_]);
-  auto input1 = utils::cast<AnfNodePtr>((*equiv)[input1_]);
-  MS_EXCEPTION_IF_NULL(input0);
-  MS_EXCEPTION_IF_NULL(input1);
+  AnfNodePtr input0 = GetAnfNodeByVar(equiv, input0_);
+  AnfNodePtr input1 = GetAnfNodeByVar(equiv, input1_);
+  AnfNodePtr sum_anf = GetAnfNodeByVar(equiv, reduce_sum_);
+  if (sum_anf == nullptr || !sum_anf->isa<CNode>()) {
+    MS_LOG(WARNING) << "Matched ReduceSum is not a CNode!";
+    return nullptr;
+  }
+  if (!GetBoolAttr(sum_anf, kAttrKeepDims)) {
+    MS_LOG(INFO) << "ReduceSum's attr keep_dims should be true for this fusion. "
+                 << "Otherwise the calculation will be wrong";
+    return nullptr;
+  }
   auto prim = std::make_shared<Primitive>(kConfusionSoftmaxGradOpName);
   MS_EXCEPTION_IF_NULL(prim);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim), input0, input1};
-  auto confusion_softmax_grad = graph->NewCNode(inputs);
-  MS_EXCEPTION_IF_NULL(confusion_softmax_grad);
-  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
-  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, confusion_softmax_grad.get());
-  confusion_softmax_grad->set_scope(node->scope());
-  SetAttrsForFusionNode(node, confusion_softmax_grad);
-  return confusion_softmax_grad;
+  auto fusion_node = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(fusion_node);
+  fusion_node->set_abstract(node->abstract());
+  fusion_node->set_scope(node->scope());
+  AnfAlgo::CopyNodeAttr(kAttrAxis, sum_anf, fusion_node);
+  AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum_anf, fusion_node);
+  return fusion_node;
 }
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h
index 58722e586f9..a4d0d1ce7aa 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h
@@ -24,9 +24,11 @@ namespace opt {
 class ConfusionSoftmaxGradRule : public PatternProcessPass {
  public:
   explicit ConfusionSoftmaxGradRule(bool multigraph = true)
-      : PatternProcessPass("confusion_softmax_grad_rule", multigraph),
-        input0_(std::make_shared<Var>()),
-        input1_(std::make_shared<Var>()) {}
+      : PatternProcessPass("confusion_softmax_grad_rule", multigraph) {
+    input0_ = std::make_shared<Var>();
+    input1_ = std::make_shared<Var>();
+    reduce_sum_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimReduceSum->name()));
+  }
   ~ConfusionSoftmaxGradRule() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
@@ -34,6 +36,7 @@ class ConfusionSoftmaxGradRule : public PatternProcessPass {
  private:
   VarPtr input0_;
   VarPtr input1_;
+  VarPtr reduce_sum_;
 };
 }  // namespace opt
 }  // namespace mindspore
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/confusion_softmax_grad_rule.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/confusion_softmax_grad_rule.py
index db435712f81..93902c24cac 100644
--- a/tests/ut/cpp/python_input/gtest_input/pre_activate/confusion_softmax_grad_rule.py
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/confusion_softmax_grad_rule.py
@@ -16,7 +16,7 @@ from mindspore.ops import Primitive
 from mindspore.ops import operations as P
 
 mul = P.Mul()
-reduce_sum = P.ReduceSum()
+reduce_sum = P.ReduceSum(keep_dims=True)
 sub = P.Sub()
 confusion_softmax_grad = Primitive('ConfusionSoftmaxGrad')
 make_tuple = Primitive('make_tuple')
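
Note for reviewers: the matched subgraph `Sub(input0, ReduceSum(Mul(input1, input0)))` appears to be the inner part of the usual softmax backward, where the reduced sum must keep its rank so it broadcasts against the full-rank operands. The NumPy sketch below (not part of the change) illustrates why the pass now refuses to fuse when `keep_dims` is false:

```python
import numpy as np

np.random.seed(0)
y = np.random.rand(2, 3)             # softmax output
y /= y.sum(axis=-1, keepdims=True)
dout = np.random.rand(2, 3)          # upstream gradient

# With keep_dims the reduced sum has shape (2, 1) and broadcasts correctly
# against the (2, 3) operands: each row gets its own row-sum subtracted.
good = dout - (y * dout).sum(axis=-1, keepdims=True)

# Without keep_dims the sum collapses to shape (2,). Here NumPy refuses to
# broadcast (2, 3) against (2,); on a square input it would instead broadcast
# along the wrong (last) axis and silently produce wrong values.
try:
    bad = dout - (y * dout).sum(axis=-1)
except ValueError as err:
    print("shape mismatch without keep_dims:", err)
```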
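
The diff only shows the changed operator declaration in the gtest_input file. For context, here is a minimal sketch of how a before/after case for this pass is conventionally laid out in these files; the `FnDict` helper and the function bodies are illustrative assumptions, not the actual test content:

```python
from mindspore.ops import Primitive
from mindspore.ops import operations as P

mul = P.Mul()
reduce_sum = P.ReduceSum(keep_dims=True)   # keep_dims=True so the pass fires
sub = P.Sub()
confusion_softmax_grad = Primitive('ConfusionSoftmaxGrad')
make_tuple = Primitive('make_tuple')


class FnDict:
    # Conventional helper in these gtest_input files: registers the tagged
    # before/after graphs so the C++ unit test can look them up by name.
    def __init__(self):
        self.fn_dict = {}

    def __call__(self, fn):
        self.fn_dict[fn.__name__] = fn

    def __getitem__(self, name):
        return self.fn_dict[name]


def test_confusion_softmax_grad_rule(tag):
    """Illustrative sketch: the pattern the pass matches and the fused result."""
    fns = FnDict()

    @fns
    def before(input0, input1):
        # Sub(input0, ReduceSum(Mul(input1, input0))) -- the matched subgraph
        return sub(input0, reduce_sum(mul(input1, input0)))

    @fns
    def after(input0, input1):
        # Replaced by one fused node carrying the copied axis/keep_dims attrs
        return make_tuple(confusion_softmax_grad(input0, input1))

    return fns[tag]
```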