From c8d33568f2d2c0665c32a595651b67b93186d137 Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Mon, 18 May 2020 16:47:42 +0800 Subject: [PATCH] Add an new output to FusedMulApplyMomentum --- .../ascend/ir_fusion/momentum_lossscale_fusion.cc | 15 +++++++++++++-- mindspore/ccsrc/pre_activate/common/helper.h | 1 + .../momentum_lossscale_fusion_test.py | 2 +- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.cc index 11fd02d2d89..8833e75c761 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.cc @@ -23,6 +23,7 @@ namespace mindspore { namespace opt { namespace { +constexpr size_t kAccumIndex = 1; bool CheckValueNodeInputOfMul(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { @@ -79,9 +80,19 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph input_names_value[3] = "x1"; input_names_value.emplace_back("x2"); AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_value), new_node); - new_node->set_abstract(node->abstract()); + auto node_to_output = cnode->input(kAccumIndex + 1); + MS_EXCEPTION_IF_NULL(node_to_output); + AbstractBasePtrList abstract_list{node->abstract(), node_to_output->abstract()}; + auto abstract_tuple = std::make_shared(abstract_list); + new_node->set_abstract(abstract_tuple); new_node->set_scope(node->scope()); - return new_node; + // Create Output + std::vector new_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, new_node, kFusedMulApplyMomentumOutputNum, &new_outputs); + if (new_outputs.size() != kFusedMulApplyMomentumOutputNum) { + MS_LOG(EXCEPTION) << "Failed to create outputs of " << new_node->DebugString(); + } + return new_outputs[0]; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/helper.h b/mindspore/ccsrc/pre_activate/common/helper.h index 78c7e11eca2..9a162d66f46 100644 --- a/mindspore/ccsrc/pre_activate/common/helper.h +++ b/mindspore/ccsrc/pre_activate/common/helper.h @@ -92,6 +92,7 @@ constexpr size_t kApplyMomentumInputNum = 6; constexpr size_t kBiasAddInputNum = 3; constexpr size_t kTopkInputNum = 3; constexpr size_t kLarsV2InputNum = 5; +constexpr size_t kFusedMulApplyMomentumOutputNum = 2; enum FusedBatchNormInput { kX = 1, diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/momentum_lossscale_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/momentum_lossscale_fusion_test.py index b2464ecc46b..acc27b25a59 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/momentum_lossscale_fusion_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/momentum_lossscale_fusion_test.py @@ -47,6 +47,6 @@ def test_momentum_lossscale_fusion(tag): @fns def after(input0, input1, input2, input3, input4): - return make_tuple(FusedMulApplyMomentum(input0, input1, input2, input3, input4, constant)) + return make_tuple(tuple_getitem(FusedMulApplyMomentum(input0, input1, input2, input3, input4, constant), 0)) return fns[tag]