forked from mindspore-Ecosystem/mindspore
getnext disable memory reuse
This commit is contained in:
parent
a4cf9028ee
commit
2aad57c595
|
@ -355,6 +355,10 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in
|
||||||
AssignCommunicationNodeOutputMem(flag, node);
|
AssignCommunicationNodeOutputMem(flag, node);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) {
|
||||||
|
MS_LOG(INFO) << "GetNext disable mem_reuse";
|
||||||
|
flag = kDynamicMem;
|
||||||
|
}
|
||||||
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
||||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||||
auto output_sizes = kernel_mod->GetOutputSizeList();
|
auto output_sizes = kernel_mod->GetOutputSizeList();
|
||||||
|
|
|
@ -825,5 +825,10 @@ bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
|
||||||
|
auto kernel_name = AnfAlgo::GetCNodeName(node);
|
||||||
|
return kernel_name == kGetNextOpName;
|
||||||
|
}
|
||||||
} // namespace session
|
} // namespace session
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
#include "kernel/kernel.h"
|
#include "kernel/kernel.h"
|
||||||
#include "kernel/kernel_build_info.h"
|
#include "kernel/kernel_build_info.h"
|
||||||
#include "operator/ops.h"
|
#include "operator/ops.h"
|
||||||
|
#include "utils/contract.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace session {
|
namespace session {
|
||||||
|
@ -169,6 +170,7 @@ class AnfRuntimeAlgorithm {
|
||||||
// get real input index for some tbe ops which input order is different between me and tbe impl
|
// get real input index for some tbe ops which input order is different between me and tbe impl
|
||||||
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
|
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
|
||||||
static bool IsCommunicationOp(const AnfNodePtr &node);
|
static bool IsCommunicationOp(const AnfNodePtr &node);
|
||||||
|
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
|
||||||
};
|
};
|
||||||
} // namespace session
|
} // namespace session
|
||||||
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||||
|
|
|
@ -42,6 +42,7 @@ constexpr auto kBNGrad2OpName = "BNGrad2";
|
||||||
constexpr auto kBNGrad3OpName = "BNGrad3";
|
constexpr auto kBNGrad3OpName = "BNGrad3";
|
||||||
constexpr auto kClearZeroOpName = "ClearZero";
|
constexpr auto kClearZeroOpName = "ClearZero";
|
||||||
constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean";
|
constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean";
|
||||||
|
constexpr auto kGetNextOpName = "GetNext";
|
||||||
constexpr auto kAllReduceOpName = "AllReduce";
|
constexpr auto kAllReduceOpName = "AllReduce";
|
||||||
constexpr auto kAllGatherOpName = "AllGather";
|
constexpr auto kAllGatherOpName = "AllGather";
|
||||||
constexpr auto kBroadcastOpName = "Broadcast";
|
constexpr auto kBroadcastOpName = "Broadcast";
|
||||||
|
|
Loading…
Reference in New Issue