!1811 handle control-depend with virtual node

Merge pull request !1811 from caifubi/handle-contrl-depend-with-virtual-node
This commit is contained in:
mindspore-ci-bot 2020-06-28 09:10:53 +08:00 committed by Gitee
commit 8d49de00e8
2 changed files with 57 additions and 2 deletions

View File

@ -521,6 +521,47 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
return output_nodes;
}
// Find control_depend real input nodes.
void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, std::set<AnfNodePtr> *visited) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(result);
MS_EXCEPTION_IF_NULL(visited);
if (visited->find(anf_node) != visited->end()) {
MS_LOG(WARNING) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited";
return;
}
visited->insert(anf_node);
if (AnfAlgo::IsRealKernel(anf_node)) {
result->emplace_back(anf_node);
return;
}
if (!anf_node->isa<CNode>()) {
return;
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().empty()) {
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString();
}
auto input0 = cnode->input(0);
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
GetAllFatherRealNode(cnode->input(i), result, visited);
}
} else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
if (cnode->inputs().size() != kTupleGetItemInputSize) {
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
}
GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
} else if (IsPrimitive(input0, prim::kPrimDepend)) {
if (cnode->inputs().size() != kDependInputSize) {
MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
}
GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
}
}
// update the depend relations of control depend
void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
for (const auto &node : depends) {
@ -551,11 +592,24 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
if (depend_node->isa<Parameter>() && depend_mode == 1) {
depend_nodes = GetOutputNodes(depend_node);
}
for (auto &first_node : prior_nodes) {
std::vector<AnfNodePtr> real_prior_nodes;
std::set<AnfNodePtr> prior_visited;
for (const auto &tmp : prior_nodes) {
GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
}
std::vector<AnfNodePtr> real_depend_nodes;
std::set<AnfNodePtr> depend_visited;
for (const auto &tmp : depend_nodes) {
GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
}
for (auto &first_node : real_prior_nodes) {
if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) {
continue;
}
for (auto &second_node : depend_nodes) {
for (auto &second_node : real_depend_nodes) {
if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) {
continue;
}

View File

@ -252,6 +252,7 @@ constexpr auto kControlDependMode = "depend_mode";
// index define of depend
constexpr auto kRealInputIndexInDepend = 1;
constexpr auto kDependAttachNodeIndex = 2;
constexpr auto kDependInputSize = 3;
// format
constexpr auto kOpFormat_DEFAULT = "DefaultFormat";
constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0";