diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc index cae036b2bd1..63ea59d744c 100644 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc +++ b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc @@ -98,16 +98,23 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(hccl_node); - if (hccl_node->size() != 2) { - MS_LOG(INFO) << "node[" + AnfAlgo::GetCNodeName(hccl_node) + "]'s inputs size not equal 2"; - return; + bool has_insert_memcpy = false; + AnfNodePtr memcpy_async = nullptr; + std::vector new_inputs = {hccl_node->input(0)}; + for (size_t i = 1; i < hccl_node->size(); ++i) { + auto input = hccl_node->input(i); + if (NeedInsertMemcpy(graph, input)) { + memcpy_async = CreateMemcpyAsyncOp(graph, input); + has_insert_memcpy = true; + new_inputs.push_back(memcpy_async); + } else { + new_inputs.push_back(input); + } } - auto input = hccl_node->input(1); - if (NeedInsertMemcpy(graph, input)) { - auto memcpy_async = CreateMemcpyAsyncOp(graph, input); + if (has_insert_memcpy) { CNodePtr new_hccl_node = std::make_shared(*hccl_node); - new_hccl_node->set_inputs({hccl_node->input(0), memcpy_async}); + new_hccl_node->set_inputs(new_inputs); auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node"; @@ -115,7 +122,9 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co MS_LOG(DEBUG) << "end replace"; // transer hccl op's control to the memcpy_async - TransferControl(new_hccl_node, memcpy_async, graph); + if (hccl_node->size() == 2) { + TransferControl(new_hccl_node, memcpy_async, graph); + } } }