!7147 insert memcpy in hccl input

Merge pull request !7147 from jjfeing/master
This commit is contained in:
mindspore-ci-bot 2020-10-11 10:31:41 +08:00 committed by Gitee
commit 7af0d3374f
2 changed files with 17 additions and 23 deletions

View File

@ -94,6 +94,22 @@ void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &m
}
}
}
// NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users) {
if (node_users.size() == 1) {
MS_LOG(INFO) << "This node only used once, no need to insert memcpy node.";
return false;
}
for (const auto &node_pair : node_users) {
auto node = node_pair.first;
if (AnfAlgo::IsRealKernel(node) && !AnfAlgo::IsCommunicationOp(node)) {
MS_LOG(INFO) << "This node only used other real kernel: " << node->fullname_with_scope();
return true;
}
}
MS_LOG(INFO) << "This node used by other node, but the node is not real kernel, no need to insert memcpy node.";
return false;
}
} // namespace
bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input,
@ -126,7 +142,7 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "node has no output in manager";
}
if (iter->second.size() > 1) {
if (IsNodeOutPutUsedByOtherRealKernel(iter->second)) {
return true;
}
}

View File

@ -52,28 +52,6 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery {
}
};
TEST_F(TestHWInsertMemcpyForHccl, test_cond1) {
get_py_fun_.SetDoResolve(true);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before1");
ASSERT_TRUE(g != nullptr);
std::vector<int> shp_x{1, 64, 112, 112};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
AbstractBasePtrList args_spec_list{x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
EXPECT_NE(kg, nullptr);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>();
pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>();
pm->AddPass(pass);
optimizer->AddPassManager(pm);
auto new_graph = optimizer->Optimize(kg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWInsertMemcpyForHccl, test_cond1_no_insert) {
get_py_fun_.SetDoResolve(true);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before2");