diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index d60aa7c21cb..2e20760856b 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -1006,8 +1006,17 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod } // get output device addr of anf_node -bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) { +bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node) { MS_EXCEPTION_IF_NULL(node); + if (opt::IsNopNode(node) && visit_nop_node) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() > 1) { + auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(cnode, 0); + return OutputAddrExist(kernel_with_index.first, kernel_with_index.second); + } + return false; + } // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice` auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index 43c873ac37a..d75c4f62380 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -174,7 +174,7 @@ class AnfRuntimeAlgorithm { // get mutable output device addr of anf_node static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); // check whether output addr is exist or not - static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx); + static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = false); // check whether workspace addr is exist or not static bool WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx); // get address from prev node,input_index is the input index of current node related to prev node diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 1735ca6ebb7..fce9c18182d 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1612,7 +1612,7 @@ void SessionBasic::UpdateOutputTensors(const VectorRef *outputs, if (iter != tensor_to_node.end()) { const auto &node = iter->second.first; const auto &output_index = iter->second.second; - if (!AnfAlgo::OutputAddrExist(node, output_index)) { + if (!AnfAlgo::OutputAddrExist(node, output_index, true)) { continue; } const auto &address = AnfAlgo::GetMutableOutputAddr(node, output_index);