fix allgather depend bug

This commit is contained in:
wangjun 2021-06-07 09:50:29 +08:00
parent 8fa246e980
commit 63d230da3b
1 changed files with 0 additions and 13 deletions

View File

@ -39,7 +39,6 @@ bool InsertDependForAllGather::Run(const FuncGraphPtr &graph) {
all_gather_node[AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion)] = node; all_gather_node[AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion)] = node;
} }
} }
std::vector<AnfNodePtr> depends = {NewValueNode(prim::kPrimMakeTuple)};
auto iter = all_gather_node.begin(); auto iter = all_gather_node.begin();
for (int64_t i = 0; i < SizeToInt(all_gather_node.size()) - 1; ++i) { for (int64_t i = 0; i < SizeToInt(all_gather_node.size()) - 1; ++i) {
auto current_node = iter->second; auto current_node = iter->second;
@ -50,20 +49,8 @@ bool InsertDependForAllGather::Run(const FuncGraphPtr &graph) {
auto new_input = graph->NewCNode(inputs); auto new_input = graph->NewCNode(inputs);
new_input->set_abstract(AnfAlgo::GetInputNode(next_cnode, 0)->abstract()); new_input->set_abstract(AnfAlgo::GetInputNode(next_cnode, 0)->abstract());
AnfAlgo::SetNodeInput(next_cnode, new_input, 0); AnfAlgo::SetNodeInput(next_cnode, new_input, 0);
depends.push_back(new_input);
}
if (depends.size() > 1) {
auto make_tuple = graph->NewCNode(depends);
auto return_node = graph->get_return();
auto return_cnode = return_node->cast<CNodePtr>();
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
AnfAlgo::GetInputNode(return_cnode, 0), make_tuple};
auto depend_node = graph->NewCNode(inputs);
depend_node->set_abstract(AnfAlgo::GetInputNode(return_cnode, 0)->abstract());
AnfAlgo::SetNodeInput(return_cnode, depend_node, 0);
changed = true; changed = true;
} }
return changed; return changed;
} }
} // namespace opt } // namespace opt