From 755863ebaef79cb746eec7f0b10fc475ef6f29c1 Mon Sep 17 00:00:00 2001
From: jjfeing
Date: Sat, 10 Oct 2020 14:43:43 +0800
Subject: [PATCH] insert memcpy when hccl node

---
Reviewer notes (remarks placed after the three-dash marker are not part of the commit message):
  * NeedInsertMemcpy used to return true whenever the input had more than one user. With this patch it
    returns true on that path only if one of the users is a real kernel that is not a communication op.
  * Review adjustments to the added helper: the INFO message on the "return true" path now says
    "is used by", and the loop binds the user node by const reference instead of copying it.
    The post-image blob id on the first "index" line was not recomputed after these adjustments.
  * NOTE(review): test_cond1 is deleted without a replacement that covers the case where a memcpy is
    still inserted; consider adding one.
  * NOTE(review): template arguments (after std::vector / std::make_shared) look stripped in this copy
    of the patch; they are left exactly as found - confirm against the repository before applying.

 .../insert_memcpy_async_for_hccl_op.cc        | 18 ++++++++++++++-
 .../insert_memcpy_async_for_hccl_op_test.cc   | 22 -------------------
 2 files changed, 17 insertions(+), 23 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc
index b0bdfd30cde..76e301d1f65 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc
@@ -94,6 +94,22 @@ void TransferControl(const CNodePtr &hccl_node, const std::vector &m
     }
   }
 }
+// NodeUsersMap: when input i of node B uses node A, the map holds one item with key A and value (B, i).
+bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users) {
+  if (node_users.size() == 1) {
+    MS_LOG(INFO) << "This node only used once, no need to insert memcpy node.";
+    return false;
+  }
+  for (const auto &node_pair : node_users) {
+    const auto &node = node_pair.first;
+    if (AnfAlgo::IsRealKernel(node) && !AnfAlgo::IsCommunicationOp(node)) {
+      MS_LOG(INFO) << "This node is used by other real kernel: " << node->fullname_with_scope();
+      return true;
+    }
+  }
+  MS_LOG(INFO) << "This node used by other node, but the node is not real kernel, no need to insert memcpy node.";
+  return false;
+}
 }  // namespace
 
 bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input,
@@ -126,7 +142,7 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
     if (iter == node_users.end()) {
       MS_LOG(EXCEPTION) << "node has no output in manager";
     }
-    if (iter->second.size() > 1) {
+    if (IsNodeOutPutUsedByOtherRealKernel(iter->second)) {
       return true;
     }
   }
diff --git a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc
index bbf2efc0552..ceafdade726 100644
--- a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc
@@ -52,28 +52,6 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery {
   }
 };
 
-TEST_F(TestHWInsertMemcpyForHccl, test_cond1) {
-  get_py_fun_.SetDoResolve(true);
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before1");
-  ASSERT_TRUE(g != nullptr);
-  std::vector shp_x{1, 64, 112, 112};
-  auto x_abstract = std::make_shared(kFloat32, shp_x);
-  AbstractBasePtrList args_spec_list{x_abstract};
-  auto kg = GetKernelGraph(g, args_spec_list);
-  EXPECT_NE(kg, nullptr);
-
-  auto optimizer = std::make_shared();
-  auto pm = std::make_shared();
-  auto pass = std::make_shared();
-  pass->kernel_query_ = std::make_shared();
-  pm->AddPass(pass);
-  optimizer->AddPassManager(pm);
-  auto new_graph = optimizer->Optimize(kg);
-
-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "after");
-  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
-}
-
 TEST_F(TestHWInsertMemcpyForHccl, test_cond1_no_insert) {
   get_py_fun_.SetDoResolve(true);
   FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before2");