diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc index eebc6503475..db79484f8c8 100644 --- a/mindspore/ccsrc/device/kernel_runtime.cc +++ b/mindspore/ccsrc/device/kernel_runtime.cc @@ -355,6 +355,10 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in AssignCommunicationNodeOutputMem(flag, node); return; } + if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) { + MS_LOG(INFO) << "GetNext disable mem_reuse"; + flag = kDynamicMem; + } auto kernel_mod = AnfAlgo::GetKernelMod(node); MS_EXCEPTION_IF_NULL(kernel_mod); auto output_sizes = kernel_mod->GetOutputSizeList(); diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 893c379a072..29a27a65b1b 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -825,5 +825,10 @@ bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) { } return false; } + +bool AnfRuntimeAlgorithm::IsGetNext(const NotNull &node) { + auto kernel_name = AnfAlgo::GetCNodeName(node); + return kernel_name == kGetNextOpName; +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h index 1a1d471b84e..ab5a68db7f2 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h @@ -31,6 +31,7 @@ #include "kernel/kernel.h" #include "kernel/kernel_build_info.h" #include "operator/ops.h" +#include "utils/contract.h" namespace mindspore { namespace session { @@ -169,6 +170,7 @@ class AnfRuntimeAlgorithm { // get real input index for some tbe ops which input order is different between me and tbe impl static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); static bool IsCommunicationOp(const AnfNodePtr &node); + static bool IsGetNext(const NotNull &node); }; } // namespace session using AnfAlgo = session::AnfRuntimeAlgorithm; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 39b4b7a1600..e1df2a8d256 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -42,6 +42,7 @@ constexpr auto kBNGrad2OpName = "BNGrad2"; constexpr auto kBNGrad3OpName = "BNGrad3"; constexpr auto kClearZeroOpName = "ClearZero"; constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean"; +constexpr auto kGetNextOpName = "GetNext"; constexpr auto kAllReduceOpName = "AllReduce"; constexpr auto kAllGatherOpName = "AllGather"; constexpr auto kBroadcastOpName = "Broadcast";