forked from mindspore-Ecosystem/mindspore
fix bug of not insert memcpy when hccl_op has more than one input
This commit is contained in:
parent
c17c525f2f
commit
c66fe00049
|
@ -98,16 +98,23 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
|
|||
void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(hccl_node);
|
||||
if (hccl_node->size() != 2) {
|
||||
MS_LOG(INFO) << "node[" + AnfAlgo::GetCNodeName(hccl_node) + "]'s inputs size not equal 2";
|
||||
return;
|
||||
bool has_insert_memcpy = false;
|
||||
AnfNodePtr memcpy_async = nullptr;
|
||||
std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
|
||||
for (size_t i = 1; i < hccl_node->size(); ++i) {
|
||||
auto input = hccl_node->input(i);
|
||||
if (NeedInsertMemcpy(graph, input)) {
|
||||
memcpy_async = CreateMemcpyAsyncOp(graph, input);
|
||||
has_insert_memcpy = true;
|
||||
new_inputs.push_back(memcpy_async);
|
||||
} else {
|
||||
new_inputs.push_back(input);
|
||||
}
|
||||
}
|
||||
|
||||
auto input = hccl_node->input(1);
|
||||
if (NeedInsertMemcpy(graph, input)) {
|
||||
auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
|
||||
if (has_insert_memcpy) {
|
||||
CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
|
||||
new_hccl_node->set_inputs({hccl_node->input(0), memcpy_async});
|
||||
new_hccl_node->set_inputs(new_inputs);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node";
|
||||
|
@ -115,8 +122,10 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
|
|||
MS_LOG(DEBUG) << "end replace";
|
||||
|
||||
// transer hccl op's control to the memcpy_async
|
||||
if (hccl_node->size() == 2) {
|
||||
TransferControl(new_hccl_node, memcpy_async, graph);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const AnfNodePtr InsertMemcpyAsyncForHcclOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
|
|
Loading…
Reference in New Issue