!7877 optimize internal depend

Merge pull request !7877 from kisnwang/optimize-internal-depend
This commit is contained in:
mindspore-ci-bot 2020-10-29 14:15:46 +08:00 committed by Gitee
commit deb38e0dbe
1 changed files with 23 additions and 5 deletions

View File

@ -1186,6 +1186,25 @@ bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) {
return true;
}
std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager,
const AnfNodePtr &front_node) {
auto node_users = front_func_graph_manager->node_users();
auto users = node_users[front_node];
std::vector<AnfNodePtr> result;
for (auto user : users) {
if (IsPrimitiveCNode(user.first, prim::kPrimControlDepend)) {
continue;
}
if (IsPrimitiveCNode(user.first, prim::kPrimDepend)) {
auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
result.insert(result.end(), res.begin(), res.end());
continue;
}
result.emplace_back(user.first);
}
return result;
}
void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backend_node,
const FuncGraphManagerPtr &front_func_graph_manager,
const std::shared_ptr<KernelGraph> &backend_graph) {
@ -1193,8 +1212,6 @@ void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backen
if (!AnfAlgo::IsRealKernel(front_node) && !AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
return;
}
auto node_users = front_func_graph_manager->node_users();
auto users = node_users[front_node];
auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0);
auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0);
@ -1210,16 +1227,17 @@ void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backen
}
}
if (internal_output) {
auto users = ExtendNodeUsers(front_func_graph_manager, front_node);
for (auto user : users) {
if (!CNodeFirstInputIsPrimitive(user.first)) {
if (!CNodeFirstInputIsPrimitive(user)) {
internal_output = false;
break;
}
if (!AnfAlgo::IsRealKernel(user.first)) {
if (!AnfAlgo::IsRealKernel(user)) {
internal_output = false;
break;
}
if (kernel_target != GetCNodeTarget(user.first)) {
if (kernel_target != GetCNodeTarget(user)) {
unique_target = false;
}
}