forked from mindspore-Ecosystem/mindspore
fix allgather depend bug
This commit is contained in:
parent
8fa246e980
commit
63d230da3b
|
@ -39,7 +39,6 @@ bool InsertDependForAllGather::Run(const FuncGraphPtr &graph) {
|
|||
all_gather_node[AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion)] = node;
|
||||
}
|
||||
}
|
||||
std::vector<AnfNodePtr> depends = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
auto iter = all_gather_node.begin();
|
||||
for (int64_t i = 0; i < SizeToInt(all_gather_node.size()) - 1; ++i) {
|
||||
auto current_node = iter->second;
|
||||
|
@ -50,20 +49,8 @@ bool InsertDependForAllGather::Run(const FuncGraphPtr &graph) {
|
|||
auto new_input = graph->NewCNode(inputs);
|
||||
new_input->set_abstract(AnfAlgo::GetInputNode(next_cnode, 0)->abstract());
|
||||
AnfAlgo::SetNodeInput(next_cnode, new_input, 0);
|
||||
depends.push_back(new_input);
|
||||
}
|
||||
if (depends.size() > 1) {
|
||||
auto make_tuple = graph->NewCNode(depends);
|
||||
auto return_node = graph->get_return();
|
||||
auto return_cnode = return_node->cast<CNodePtr>();
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
|
||||
AnfAlgo::GetInputNode(return_cnode, 0), make_tuple};
|
||||
auto depend_node = graph->NewCNode(inputs);
|
||||
depend_node->set_abstract(AnfAlgo::GetInputNode(return_cnode, 0)->abstract());
|
||||
AnfAlgo::SetNodeInput(return_cnode, depend_node, 0);
|
||||
changed = true;
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
Loading…
Reference in New Issue