fix nopnode internal output error

This commit is contained in:
kswang 2020-07-31 16:59:17 +08:00
parent b55e5e2ce2
commit e6b0ac9b41
1 changed files with 38 additions and 16 deletions

View File

@ -980,8 +980,9 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
if (context_ptr->execution_mode() == kPynativeMode) { if (context_ptr->execution_mode() == kPynativeMode) {
return backend_anf; return backend_anf;
} }
auto front_real_kernel = AnfAlgo::VisitKernel(out, 0); auto front_real_kernel_pair = AnfAlgo::VisitKernel(out, 0);
auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0); auto front_real_kernel = front_real_kernel_pair.first;
auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_anf, 0);
MS_EXCEPTION_IF_NULL(out); MS_EXCEPTION_IF_NULL(out);
auto out_func_graph = out->func_graph(); auto out_func_graph = out->func_graph();
MS_EXCEPTION_IF_NULL(out_func_graph); MS_EXCEPTION_IF_NULL(out_func_graph);
@ -992,26 +993,47 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
auto node_users = out_func_graph_manager->node_users(); auto node_users = out_func_graph_manager->node_users();
auto users = node_users[out]; auto users = node_users[out];
bool internal_output = true; bool internal_output = true;
std::string kernel_target = GetCNodeTarget(front_real_kernel.first); std::string kernel_target = GetCNodeTarget(front_real_kernel);
for (auto user : users) { if (front_real_kernel != nullptr && front_real_kernel->isa<CNode>()) {
auto cnode = user.first->cast<CNodePtr>(); auto front_cnode = front_real_kernel->cast<CNodePtr>();
if (cnode == nullptr) { if (front_cnode != nullptr) {
auto prim = front_cnode->input(kAnfPrimitiveIndex);
if (prim == nullptr || !prim->isa<ValueNode>()) {
internal_output = false;
}
} else {
internal_output = false; internal_output = false;
break;
} }
auto prim = cnode->input(kAnfPrimitiveIndex); }
if (prim == nullptr || !prim->isa<ValueNode>()) { if (internal_output && opt::IsNopNode(front_real_kernel)) {
auto pre_node_pair = AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0);
auto pre_node_target = GetCNodeTarget(pre_node_pair.first);
if (pre_node_target != kernel_target) {
internal_output = false; internal_output = false;
break;
}
if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) {
internal_output = false;
break;
} }
} }
if (internal_output) { if (internal_output) {
MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString(); for (auto user : users) {
graph->AddInternalOutput(out, backend_real_kernel.first); auto cnode = user.first->cast<CNodePtr>();
if (cnode == nullptr) {
internal_output = false;
break;
}
auto prim = cnode->input(kAnfPrimitiveIndex);
if (prim == nullptr || !prim->isa<ValueNode>()) {
internal_output = false;
break;
}
if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) {
internal_output = false;
break;
}
}
}
if (internal_output) {
MS_LOG(INFO) << "Internal output: " << out->DebugString() << "To "
<< backend_real_kernel_pair.first->DebugString();
graph->AddInternalOutput(out, backend_real_kernel_pair.first);
} }
return backend_anf; return backend_anf;
} }