diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc index 2927b1204f1..210bc232f5d 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc @@ -223,13 +223,21 @@ KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, int output_idx) { } KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t input_idx) { + auto is_all_nop_node = opt::IsAllNopNode(graph_); if (input_idx >= AnfAlgo::GetInputTensorNum(kernel)) { MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << AnfAlgo::GetInputTensorNum(kernel); } auto input_node = kernel->input(input_idx + 1); // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. - auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); + session::KernelWithIndex kernel_input; + if (is_all_nop_node) { + // The graph does not remove the nop node. + kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); + } else { + // The graph removes the nop node. + kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); + } if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) { MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple"; } @@ -257,6 +265,7 @@ void MemReuseUtil::SetKernelDefMap() { } void MemReuseUtil::SetKernelDefInputs() { + auto is_all_nop_node = opt::IsAllNopNode(graph_); for (const auto &kernel : graph_->execution_order()) { MS_EXCEPTION_IF_NULL(kernel); auto key = kernel.get(); @@ -272,7 +281,14 @@ void MemReuseUtil::SetKernelDefInputs() { // set the inputs of this kernel_def auto input_node = AnfAlgo::GetInputNode(kernel, i); // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. - auto input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); + session::KernelWithIndex input; + if (is_all_nop_node) { + // The graph does not remove the nop node. + input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); + } else { + // The graph removes the nop node. + input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); + } if (IsPrimitive(input.first, prim::kPrimMakeTuple)) { MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple"; }