diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc index 7f79aed72b8..d4929b85cd1 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc @@ -50,7 +50,8 @@ ValueNodePtr CreateValueNode(const ValuePtr &value_ptr, TypeId output_type) { return new_node; } -CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, bool is_pynative = false) { +CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, + bool is_convert_const_to_attr = false) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(sparse_softmax_node); @@ -80,7 +81,7 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_ one_hot_primitive->set_attr(kAttrOutputNames, MakeValue(output_names)); std::vector<AnfNodePtr> one_hot_inputs; - if (is_pynative) { + if (is_convert_const_to_attr) { one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), value_on_node, value_off_node}; } else { auto depth_node = NewValueNode(depth); @@ -97,7 +98,7 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_ std::vector<size_t> labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 1); labels_shape.emplace_back(depth); AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {labels_shape}, one_hot_node.get()); - if (is_pynative) { + if (is_convert_const_to_attr) { AnfAlgo::SetNodeAttr(kAttrDepth, MakeValue(depth), one_hot_node); } return one_hot_node; @@ -252,7 +253,7 @@ CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &rea } CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const CNodePtr 
&mul_node, - bool is_pynative = false) { + bool is_convert_const_to_attr = false) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(sparse_softmax_node); MS_EXCEPTION_IF_NULL(mul_node); @@ -268,6 +269,9 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no std::vector<int64_t> multiple_value; std::transform(labels_shape.begin(), labels_shape.end(), std::back_inserter(multiple_value), [](size_t label) { return static_cast<int64_t>(label); }); + if (std::all_of(multiple_value.begin(), multiple_value.end(), [](int64_t value) { return value == 1; })) { + return nullptr; + } auto multiples = MakeValue(multiple_value); auto multiples_node = CreateValueNode(multiples, kNumberTypeInt64); MS_EXCEPTION_IF_NULL(multiples_node); @@ -279,7 +283,7 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no tile_primitive->set_attr(kAttrOutputNames, MakeValue(output_names)); std::vector<AnfNodePtr> tile_inputs; - if (is_pynative) { + if (is_convert_const_to_attr) { tile_inputs = {NewValueNode(tile_primitive), mul_node->input(2)}; } else { auto kernel_graph = graph->cast<KernelGraphPtr>(); @@ -292,7 +296,7 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no tile_node->set_scope(mul_node->scope()); AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1)}, {labels_shape}, tile_node.get()); - if (is_pynative) { + if (is_convert_const_to_attr) { AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), tile_node); } // feature map set @@ -302,7 +306,7 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no return tile_node; } -CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const CNodePtr &tile_node) { +CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const AnfNodePtr &tile_node) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(sparse_softmax_node); MS_EXCEPTION_IF_NULL(tile_node); @@ -464,16 +468,24 @@ 
const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con std::vector<AnfNodePtr> softmax_node_outputs; CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs); auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node); - auto real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node); + CNodePtr real_div_node; + if (tile_node == nullptr) { + real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(2)); + } else { + real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node); + } auto expand_dims_node = CreateExpandDims(graph, real_div_node); - - mul_node->set_input(1, softmax_node_outputs[1]); - mul_node->set_input(2, expand_dims_node); + std::vector<AnfNodePtr> new_mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), + softmax_node_outputs[1], expand_dims_node}; + auto new_mul_node = graph->NewCNode(new_mul_inputs); + MS_EXCEPTION_IF_NULL(new_mul_node); + new_mul_node->set_scope(mul_node->scope()); + new_mul_node->set_abstract(mul_node->abstract()); auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]); - return mul_node; + return new_mul_node; } const BaseRef GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::DefinePattern() const { @@ -563,19 +575,26 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Pro } CNodePtr softmax_node; - auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad, true); + auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad); softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node); std::vector<AnfNodePtr> softmax_node_outputs; CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs); - auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node, true); - auto real_div_node = 
CreateRealDiv(graph, sparse_softmax_node_grad, tile_node); + auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node); + CNodePtr real_div_node; + if (tile_node == nullptr) { + real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(2)); + } else { + real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node); + } auto expand_dims_node = CreateExpandDimsPynative(graph, real_div_node); - - mul_node->set_input(1, softmax_node_outputs[1]); - mul_node->set_input(2, expand_dims_node); - - return mul_node; + std::vector<AnfNodePtr> new_mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), + softmax_node_outputs[1], expand_dims_node}; + auto new_mul_node = graph->NewCNode(new_mul_inputs); + MS_EXCEPTION_IF_NULL(new_mul_node); + new_mul_node->set_scope(mul_node->scope()); + new_mul_node->set_abstract(mul_node->abstract()); + return new_mul_node; } } // namespace opt } // namespace mindspore