From bc2df2c913f83e54864b4e3d3637e2c47f615c6d Mon Sep 17 00:00:00 2001 From: YuJianfeng Date: Sat, 18 Apr 2020 16:48:51 +0800 Subject: [PATCH] Fix inputs size and attr for AddN fission pass --- .../ascend/ir_fission/addn_fission.cc | 16 ++++++++++------ mindspore/ccsrc/utils/utils.h | 2 +- .../pre_activate/addn_fission_test.py | 11 +++-------- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc index f6eb6aca64e..b9a86f7bcb8 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc @@ -34,6 +34,8 @@ AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_ new_addn->set_scope(origin_addn_cnode->scope()); new_addn->set_abstract(origin_addn_cnode->abstract()); AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_addn); + std::vector dyn_input_sizes{SizeToInt(offset)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_addn); return new_addn; } } // namespace @@ -55,22 +57,24 @@ const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfN } CNodePtr new_cnode = cnode; while (origin_input_size > inputs_divisor_) { + MS_EXCEPTION_IF_NULL(new_cnode); std::vector base_addn_inputs{NewValueNode(std::make_shared(prim::kPrimAddN->name()))}; size_t cur_input_index = 1; - // Divide the inputs of addn by 63. - while (origin_input_size - cur_input_index + 1 > inputs_divisor_) { + // Divide the inputs of addn by inputs_divisor_. + while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) { base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_)); cur_input_index += inputs_divisor_; } - base_addn_inputs.push_back( - CreateNewAddn(func_graph, new_cnode, cur_input_index, origin_input_size - cur_input_index + 1)); - + for (size_t i = cur_input_index; i <= origin_input_size; i++) { + base_addn_inputs.push_back(new_cnode->input(i)); + } CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs); MS_EXCEPTION_IF_NULL(base_addn); - MS_EXCEPTION_IF_NULL(new_cnode); base_addn->set_scope(new_cnode->scope()); base_addn->set_abstract(new_cnode->abstract()); AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_addn_inputs.size() - 1)), base_addn); + std::vector dyn_input_sizes{SizeToInt(base_addn_inputs.size() - 1)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_addn); new_cnode = base_addn; origin_input_size = base_addn->inputs().size() - 1; } diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 10ef4abf62f..eac901b74de 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -149,7 +149,7 @@ constexpr auto kAttrDynInputSizes = "dyn_input_sizes"; constexpr auto kAttrSrcFormat = "src_format"; constexpr auto kAttrOutputUsedNum = "output_used_num"; constexpr auto kAttrHasBias = "has_bias"; -constexpr auto kAttrN = "N"; +constexpr auto kAttrN = "n"; constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active"; // attr value diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py index c120ac3e68e..76d7e73a800 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py @@ -45,13 +45,10 @@ def test_addn_fission(tag): b = addn((input2, input3)) c = addn((input4, input5)) d = addn((input6, input7)) - e = addn((input8,)) f = addn((a, b)) g = addn((c, d)) - h = addn((e,)) i = addn((f, g)) - j = addn((h,)) - return addn((i, j)) + return addn((i, input8)) @fns def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8): @@ -64,14 +61,12 @@ def test_addn_fission(tag): def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8): a = addn((input0, input1, input2, input3)) b = addn((input4, input5, input6, input7)) - c = addn((input8,)) - return addn((a, b, c)) + return addn((a, b, input8)) @fns def after_divided_by_8(input0, input1, input2, input3, input4, input5, input6, input7, input8): a = addn((input0, input1, input2, input3, input4, input5, input6, input7)) - b = addn((input8,)) - return addn((a, b)) + return addn((a, input8)) @fns def after_divided_by_9(input0, input1, input2, input3, input4, input5, input6, input7, input8):