From 7cded1ec32ad0455532b5ba9ae545134d0b7ea7f Mon Sep 17 00:00:00 2001 From: huanghui Date: Fri, 17 Apr 2020 16:50:16 +0800 Subject: [PATCH] bugfix: confusion_softmax_grad need to be set with axis and keep_dims attr --- .../ir_fusion/confusion_softmax_grad_rule.cc | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc index 1270ae77c19..8078247c2a8 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc @@ -21,9 +21,30 @@ #include "session/anf_runtime_algorithm.h" #include "ir/primitive.h" #include "utils/utils.h" +#include "pre_activate/common/helper.h" namespace mindspore { namespace opt { +namespace { +void SetAttrsForFusionNode(const AnfNodePtr &sub_anf, const AnfNodePtr &fusion_node) { + MS_EXCEPTION_IF_NULL(sub_anf); + MS_EXCEPTION_IF_NULL(fusion_node); + auto sub = sub_anf->cast(); + MS_EXCEPTION_IF_NULL(sub); + if (sub->size() != kSubInputNum) { + MS_LOG(EXCEPTION) << "Sub's size is not equal with 3"; + } + auto reduce_sum_anf = sub->input(2); + MS_EXCEPTION_IF_NULL(reduce_sum_anf); + auto reduce_sum = reduce_sum_anf->cast(); + if (reduce_sum == nullptr) { + MS_LOG(EXCEPTION) << "Sub's second input is not a cnode"; + } + AnfAlgo::CopyNodeAttr(kAttrAxis, reduce_sum, fusion_node); + AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node); +} +} // namespace + const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const { return VectorRef( {prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input0_, input1_})})}); @@ -48,6 +69,7 @@ const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, co auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; AnfAlgo::SetOutputInferTypeAndShape(types, shapes, confusion_softmax_grad.get()); confusion_softmax_grad->set_scope(node->scope()); + SetAttrsForFusionNode(node, confusion_softmax_grad); return confusion_softmax_grad; } } // namespace opt