Fix bug where memcpy is not inserted when hccl_op has more than one input

This commit is contained in:
huanghui 2020-06-19 11:14:15 +08:00
parent c17c525f2f
commit c66fe00049
1 changed files with 17 additions and 8 deletions

View File

@ -98,16 +98,23 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(hccl_node);
if (hccl_node->size() != 2) {
MS_LOG(INFO) << "node[" + AnfAlgo::GetCNodeName(hccl_node) + "]'s inputs size not equal 2";
return;
bool has_insert_memcpy = false;
AnfNodePtr memcpy_async = nullptr;
std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
for (size_t i = 1; i < hccl_node->size(); ++i) {
auto input = hccl_node->input(i);
if (NeedInsertMemcpy(graph, input)) {
memcpy_async = CreateMemcpyAsyncOp(graph, input);
has_insert_memcpy = true;
new_inputs.push_back(memcpy_async);
} else {
new_inputs.push_back(input);
}
}
auto input = hccl_node->input(1);
if (NeedInsertMemcpy(graph, input)) {
auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
if (has_insert_memcpy) {
CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
new_hccl_node->set_inputs({hccl_node->input(0), memcpy_async});
new_hccl_node->set_inputs(new_inputs);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node";
@ -115,8 +122,10 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
MS_LOG(DEBUG) << "end replace";
// transfer hccl op's control to the memcpy_async
if (hccl_node->size() == 2) {
TransferControl(new_hccl_node, memcpy_async, graph);
}
}
}
const AnfNodePtr InsertMemcpyAsyncForHcclOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,