forked from mindspore-Ecosystem/mindspore
!7147 insert memcpy in hccl input
Merge pull request !7147 from jjfeing/master
This commit is contained in:
commit
7af0d3374f
|
@ -94,6 +94,22 @@ void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &m
|
|||
}
|
||||
}
|
||||
}
|
||||
// NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
|
||||
bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users) {
|
||||
if (node_users.size() == 1) {
|
||||
MS_LOG(INFO) << "This node only used once, no need to insert memcpy node.";
|
||||
return false;
|
||||
}
|
||||
for (const auto &node_pair : node_users) {
|
||||
auto node = node_pair.first;
|
||||
if (AnfAlgo::IsRealKernel(node) && !AnfAlgo::IsCommunicationOp(node)) {
|
||||
MS_LOG(INFO) << "This node only used other real kernel: " << node->fullname_with_scope();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "This node used by other node, but the node is not real kernel, no need to insert memcpy node.";
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input,
|
||||
|
@ -126,7 +142,7 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
|
|||
if (iter == node_users.end()) {
|
||||
MS_LOG(EXCEPTION) << "node has no output in manager";
|
||||
}
|
||||
if (iter->second.size() > 1) {
|
||||
if (IsNodeOutPutUsedByOtherRealKernel(iter->second)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -52,28 +52,6 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery {
|
|||
}
|
||||
};
|
||||
|
||||
// Runs InsertMemcpyAsyncForHcclOp on the "before1" graph of test_insert_memcpy_async_for_hccl_op_cond1 and
// expects the optimized graph to equal the "after" graph from the same Python function.
TEST_F(TestHWInsertMemcpyForHccl, test_cond1) {
  get_py_fun_.SetDoResolve(true);
  FuncGraphPtr before_graph = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before1");
  ASSERT_TRUE(before_graph != nullptr);

  // Build the kernel graph from a single float32 tensor argument of shape {1, 64, 112, 112}.
  std::vector<int> input_shape{1, 64, 112, 112};
  AbstractBasePtrList args_spec_list{std::make_shared<abstract::AbstractTensor>(kFloat32, input_shape)};
  auto kernel_graph = GetKernelGraph(before_graph, args_spec_list);
  EXPECT_NE(kernel_graph, nullptr);

  // Create the pass under test and replace its kernel query with the mock implementation.
  auto memcpy_pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>();
  memcpy_pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>();

  // Register the pass with a pass manager and run the optimizer over the kernel graph.
  auto pass_manager = std::make_shared<opt::PassManager>();
  pass_manager->AddPass(memcpy_pass);
  auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
  graph_optimizer->AddPassManager(pass_manager);
  auto optimized_graph = graph_optimizer->Optimize(kernel_graph);

  // Compare against the expected graph.
  FuncGraphPtr expected_graph = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "after");
  EXPECT_TRUE(CheckEqualGraph(expected_graph, optimized_graph));
}
|
||||
|
||||
TEST_F(TestHWInsertMemcpyForHccl, test_cond1_no_insert) {
|
||||
get_py_fun_.SetDoResolve(true);
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before2");
|
||||
|
|
Loading…
Reference in New Issue