From e6b0ac9b41b88cdb777700b8da828e58ef5e76a2 Mon Sep 17 00:00:00 2001 From: kswang Date: Fri, 31 Jul 2020 16:59:17 +0800 Subject: [PATCH] fix nopnode internal output error --- .../ccsrc/backend/session/session_basic.cc | 54 +++++++++++++------ 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index dfde9336b2b..d484a4dffca 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -980,8 +980,9 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: if (context_ptr->execution_mode() == kPynativeMode) { return backend_anf; } - auto front_real_kernel = AnfAlgo::VisitKernel(out, 0); - auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0); + auto front_real_kernel_pair = AnfAlgo::VisitKernel(out, 0); + auto front_real_kernel = front_real_kernel_pair.first; + auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_anf, 0); MS_EXCEPTION_IF_NULL(out); auto out_func_graph = out->func_graph(); MS_EXCEPTION_IF_NULL(out_func_graph); @@ -992,26 +993,47 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: auto node_users = out_func_graph_manager->node_users(); auto users = node_users[out]; bool internal_output = true; - std::string kernel_target = GetCNodeTarget(front_real_kernel.first); - for (auto user : users) { - auto cnode = user.first->cast(); - if (cnode == nullptr) { + std::string kernel_target = GetCNodeTarget(front_real_kernel); + if (front_real_kernel != nullptr && front_real_kernel->isa()) { + auto front_cnode = front_real_kernel->cast(); + if (front_cnode != nullptr) { + auto prim = front_cnode->input(kAnfPrimitiveIndex); + if (prim == nullptr || !prim->isa()) { + internal_output = false; + } + } else { internal_output = false; - break; } - auto prim = cnode->input(kAnfPrimitiveIndex); - if (prim == nullptr || !prim->isa()) { + } + if (internal_output && opt::IsNopNode(front_real_kernel)) { + auto pre_node_pair = AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0); + auto pre_node_target = GetCNodeTarget(pre_node_pair.first); + if (pre_node_target != kernel_target) { internal_output = false; - break; - } - if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) { - internal_output = false; - break; } } if (internal_output) { - MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString(); - graph->AddInternalOutput(out, backend_real_kernel.first); + for (auto user : users) { + auto cnode = user.first->cast(); + if (cnode == nullptr) { + internal_output = false; + break; + } + auto prim = cnode->input(kAnfPrimitiveIndex); + if (prim == nullptr || !prim->isa()) { + internal_output = false; + break; + } + if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) { + internal_output = false; + break; + } + } + } + if (internal_output) { + MS_LOG(INFO) << "Internal output: " << out->DebugString() << "To " + << backend_real_kernel_pair.first->DebugString(); + graph->AddInternalOutput(out, backend_real_kernel_pair.first); } return backend_anf; }