diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc index 210bc232f5d..d550b77bba2 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc @@ -103,6 +103,7 @@ bool MemReuseUtil::InitDynamicWorkspaceKernelRef() { bool MemReuseUtil::InitDynamicKernelRef(const KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); graph_ = graph; + is_all_nop_node_ = opt::IsAllNopNode(graph); if (!InitDynamicOutputKernelRef()) { MS_LOG(INFO) << "InitDynamicOutputKernelRef fail"; return false; @@ -223,7 +224,6 @@ KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, int output_idx) { } KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t input_idx) { - auto is_all_nop_node = opt::IsAllNopNode(graph_); if (input_idx >= AnfAlgo::GetInputTensorNum(kernel)) { MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << AnfAlgo::GetInputTensorNum(kernel); @@ -231,7 +231,7 @@ KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t auto input_node = kernel->input(input_idx + 1); // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. session::KernelWithIndex kernel_input; - if (is_all_nop_node) { + if (is_all_nop_node_) { // The graph does not remove the nop node. kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); } else { @@ -265,7 +265,6 @@ void MemReuseUtil::SetKernelDefMap() { } void MemReuseUtil::SetKernelDefInputs() { - auto is_all_nop_node = opt::IsAllNopNode(graph_); for (const auto &kernel : graph_->execution_order()) { MS_EXCEPTION_IF_NULL(kernel); auto key = kernel.get(); @@ -282,7 +281,7 @@ void MemReuseUtil::SetKernelDefInputs() { auto input_node = AnfAlgo::GetInputNode(kernel, i); // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. session::KernelWithIndex input; - if (is_all_nop_node) { + if (is_all_nop_node_) { // The graph does not remove the nop node. input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); } else { @@ -349,11 +348,10 @@ void MemReuseUtil::SetSummaryNodesRefCount() { } void MemReuseUtil::SetGraphOutputRefCount() { - auto is_all_nop_node = opt::IsAllNopNode(graph_); auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem}); for (const auto &node : nodes) { session::KernelWithIndex kernel_input; - if (is_all_nop_node) { + if (is_all_nop_node_) { // The graph does not remove the nop node. kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, false); } else { diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h index c7a129f1e92..37281a7128b 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h @@ -42,7 +42,7 @@ class MemReuseUtil { KernelRefCountPtrList total_refs_list_; KernelRefCountPtrList total_wk_ref_list_; KernelRefs kernel_workspace_refs_; - MemReuseUtil() : util_index_(kInitIndex), graph_(nullptr) {} + MemReuseUtil() : util_index_(kInitIndex), graph_(nullptr), is_all_nop_node_(false) {} ~MemReuseUtil() { if (graph_ != nullptr) { graph_ = nullptr; @@ -87,6 +87,7 @@ class MemReuseUtil { private: int util_index_; const KernelGraph *graph_; + bool is_all_nop_node_; KernelRefCountPtrList ref_list_; KernelDefPtrMaps kernel_def_ptr_list_; KernelRefCountPtrList last_ref_list_;