diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc index 6b3cacfb4f0..190a17581fa 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc @@ -31,6 +31,7 @@ namespace mindspore { namespace opt { namespace { constexpr size_t kCastInputNum = 2; +constexpr size_t kDependInputNum = 2; constexpr auto softmax_output_shape_size = 2; constexpr auto kAttrDepth = "depth"; constexpr auto kAttrMultiples = "multiples"; @@ -411,6 +412,20 @@ CNodePtr CreateCast(const FuncGraphPtr &graph, const CNodePtr &cast, const AnfNo return new_cast_node; } +CNodePtr CreateDepend(const FuncGraphPtr &graph, const CNodePtr &depend, const AnfNodePtr &depend_input, + const PatternProcessPass &pass) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(depend); + MS_EXCEPTION_IF_NULL(depend_input); + + std::vector<AnfNodePtr> new_depend_inputs = {depend->input(kAnfPrimitiveIndex), depend_input, depend->input(kIndex2)}; + auto new_depend_node = pass.NewCNode(new_depend_inputs, graph); + MS_EXCEPTION_IF_NULL(new_depend_node); + new_depend_node->set_scope(depend->scope()); + new_depend_node->set_abstract(depend->abstract()); + return new_depend_node; +} + bool IsSparseSoftmaxCrossEntropyWithLogitsGrad(const CNodePtr &sparse, const string &pass_name) { MS_EXCEPTION_IF_NULL(sparse); if (common::AnfAlgo::GetCNodeName(sparse) != kSparseSoftmaxCrossEntropyWithLogitsOpName) { @@ -696,5 +711,48 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::P new_mul_node->set_abstract(mul_node->abstract()); return new_mul_node; } + +const BaseRef PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3::DefinePattern() const { + 
VarPtr x1 = std::make_shared<Var>(); + VarPtr x2 = std::make_shared<Var>(); + VarPtr x3 = std::make_shared<Var>(); + VarPtr x4 = std::make_shared<Var>(); + VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2}); + VectorRef depend({prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits, x3}); + return VectorRef({prim::kPrimMul, depend, x4}); +} + +const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3::Process(const FuncGraphPtr &graph, + const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + + auto mul_node = node->cast<CNodePtr>(); + MS_EXCEPTION_IF_NULL(mul_node); + CheckCNodeInputSize(mul_node, kMulInputTensorNum); + + auto depend_node = mul_node->input(kIndex1); + auto depend_cnode = depend_node->cast<CNodePtr>(); + MS_EXCEPTION_IF_NULL(depend_cnode); + CheckCNodeInputSize(depend_cnode, kDependInputNum); + + auto sparse_softmax_node = depend_cnode->input(kIndex1); + bool is_sp_grad_flag = true; + std::vector<AnfNodePtr> softmax_node_outputs; + auto expand_dims_node = + CreateMulInput(graph, mul_node, sparse_softmax_node, name(), *this, &softmax_node_outputs, &is_sp_grad_flag); + if (!is_sp_grad_flag) { + return nullptr; + } + auto new_depend = CreateDepend(graph, depend_cnode, softmax_node_outputs[1], *this); + std::vector<AnfNodePtr> new_mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), new_depend, + expand_dims_node}; + auto new_mul_node = NewCNode(new_mul_inputs, graph); + MS_EXCEPTION_IF_NULL(new_mul_node); + new_mul_node->set_scope(mul_node->scope()); + new_mul_node->set_abstract(mul_node->abstract()); + return new_mul_node; +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.h index edfaa721a13..4c8267d132d 100644 --- 
a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.h @@ -77,6 +77,15 @@ class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2 : public Patt const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; }; + +class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3 : public PatternProcessPass { + public: + explicit PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3(bool multigraph = true) + : PatternProcessPass("pynative_grad_sparse_softmax_cross_entropy_with_logits_unify_mindir_v3", multigraph) {} + ~PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; +}; } // namespace opt } // namespace mindspore #endif // MINDSPORE_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_UNIFY_MINDIR_H