!27563 Remove MS_CTX_ENABLE_MEM_SCHEDULER

Merge pull request !27563 from tanghuikang/remove_mem_scheduler
i-robot 2021-12-14 06:26:16 +00:00 committed by Gitee
commit b4bd09fa47
7 changed files with 19 additions and 22 deletions

View File

@@ -1603,8 +1603,7 @@ void InitHccl() {
#if ENABLE_D
bool task_sink = true;
auto single_op = common::GetEnv(kGraphOpRun);
-auto enable_mem_scheduler = common::GetEnv(kEnableMemScheduler);
-if (single_op == "1" || enable_mem_scheduler == "1") {
+if (single_op == "1") {
task_sink = false;
}
auto mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE);
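Note: after this hunk, task sink during HCCL init is controlled by the GRAPH_OP_RUN environment variable alone; ENABLE_MEM_SCHEDULER no longer participates. A minimal standalone sketch of the resulting decision, with std::getenv standing in for common::GetEnv and DecideTaskSink as a hypothetical wrapper name:

#include <cstdlib>
#include <string>

// Hypothetical standalone form of the task_sink decision in the hunk above.
// std::getenv stands in for common::GetEnv; "GRAPH_OP_RUN" is the value of
// kGraphOpRun (see the constants hunk further down).
bool DecideTaskSink() {
  bool task_sink = true;
  const char *raw = std::getenv("GRAPH_OP_RUN");
  const std::string single_op = (raw != nullptr) ? raw : "";
  if (single_op == "1") {
    task_sink = false;  // single-op execution launches kernels one by one
  }
  return task_sink;
}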

View File

@@ -1137,7 +1137,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
bool KernelRuntime::UseMemScheduler() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
-if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER)) {
+if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
return false;
}
// Not use MemScheduler when running single op
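The gate is thus inferred rather than opted into: UseMemScheduler now returns false whenever task sink is enabled, instead of requiring the removed MS_CTX_ENABLE_MEM_SCHEDULER flag. A toy model of the inverted check (names and the trailing single-op behavior are simplifications, not the real API):

#include <cassert>

// Toy context: only the task-sink flag matters for this sketch.
struct ToyContext {
  bool enable_task_sink = true;
};

bool UseMemSchedulerSketch(const ToyContext &ctx) {
  if (ctx.enable_task_sink) {
    return false;  // task-sink graphs plan device memory ahead of time
  }
  return true;  // the real function also rejects single-op runs (see comment above)
}

int main() {
  assert(!UseMemSchedulerSketch(ToyContext{true}));
  assert(UseMemSchedulerSketch(ToyContext{false}));
  return 0;
}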
@@ -1438,6 +1438,7 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
}
if (mem_scheduler->HasDeviceMem(tensor_address.get())) {
tensor_address->set_ptr(nullptr);
+tensor->set_device_address(nullptr);
}
continue;
}
@@ -1589,13 +1590,13 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
LaunchKernelEvent(kernel_post_run_events, kernels[i]);
}
if (UseMemScheduler() && !mock) {
-SyncUpdatedParameter(graph, mem_scheduler);
+SyncParameter(graph, mem_scheduler);
}
return true;
}
-void KernelRuntime::SyncUpdatedParameter(const session::KernelGraph &graph,
-                                         const std::shared_ptr<MemScheduler> &mem_scheduler) {
+void KernelRuntime::SyncParameter(const session::KernelGraph &graph,
+                                  const std::shared_ptr<MemScheduler> &mem_scheduler) {
MS_EXCEPTION_IF_NULL(mem_scheduler);
auto &input_nodes = graph.input_nodes();
auto &input_tensors = graph.input_tensors();
@@ -1608,17 +1609,22 @@ void KernelRuntime::SyncUpdatedParameter(const session::KernelGraph &graph,
if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) {
continue;
}
-auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
-MS_EXCEPTION_IF_NULL(device_address);
auto parameter = input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter);
-if (!graph.IsUpdatedParameter(parameter)) {
+if (!AnfAlgo::IsParameterWeight(parameter) && !graph.IsUpdatedParameter(parameter)) {
continue;
}
+auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
-auto tensor = input_tensors[i];
-MS_EXCEPTION_IF_NULL(tensor);
-auto device_ptr = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size(), kMemPriorityHigh);
-if (device_ptr != nullptr) {
+if (mem_scheduler->HasDeviceMem(device_address.get())) {
+auto device_ptr = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size(), kMemPriorityHigh);
device_address->set_ptr(device_ptr);
+auto tensor = input_tensors[i];
+MS_EXCEPTION_IF_NULL(tensor);
+auto origin_tensor_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
+if (origin_tensor_device_address != nullptr) {
+origin_tensor_device_address->set_ptr(nullptr);
+}
tensor->set_device_address(device_address);
tensor->set_sync_status(kNeedSyncDeviceToHost);
}
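Two behavioral changes land in this hunk: weight parameters are now synced as well (the new IsParameterWeight condition), and the tensor's previous device address is detached before the parameter's address is handed over, so two DeviceAddress objects never hold the same device pointer. A self-contained sketch of that hand-over discipline (toy types, not the real device::DeviceAddress):

#include <memory>

// Toy stand-in for device::DeviceAddress.
struct ToyDeviceAddress {
  void *ptr = nullptr;
  void set_ptr(void *p) { ptr = p; }
};

// Mirrors the rebinding above: clear the stale wrapper first, then point the
// tensor at the parameter's address.
void RebindTensorAddress(std::shared_ptr<ToyDeviceAddress> &tensor_addr,
                         const std::shared_ptr<ToyDeviceAddress> &param_addr,
                         void *device_ptr) {
  param_addr->set_ptr(device_ptr);
  if (tensor_addr != nullptr) {
    tensor_addr->set_ptr(nullptr);  // avoid double ownership of device_ptr
  }
  tensor_addr = param_addr;
}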

View File

@@ -96,7 +96,7 @@ class KernelRuntime {
void set_device_id(uint32_t device_id) { device_id_ = device_id; }
uint32_t device_id() { return device_id_; }
static bool UseMemScheduler();
-void SyncUpdatedParameter(const session::KernelGraph &graph, const std::shared_ptr<MemScheduler> &mem_scheduler);
+void SyncParameter(const session::KernelGraph &graph, const std::shared_ptr<MemScheduler> &mem_scheduler);
#ifdef ENABLE_DEBUGGER
// set debugger

View File

@@ -522,7 +522,6 @@ constexpr auto kValueTrue = "true";
// env key
constexpr auto kGraphOpRun = "GRAPH_OP_RUN";
-constexpr auto kEnableMemScheduler = "ENABLE_MEM_SCHEDULER";
// some size
const size_t kShape4dDims = 4;

View File

@@ -589,11 +589,6 @@ BackendPtr CreateBackend() {
if (single_op == "1") {
context_ptr->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
}
-auto enable_mem_scheduler = common::GetEnv(kEnableMemScheduler);
-if (enable_mem_scheduler == "1") {
-context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, true);
-context_ptr->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
-}
}
}
return backend;
@@ -625,7 +620,7 @@ void SetMindRTEnable() {
return;
}
-if ((common::GetEnv(kGraphOpRun) == "1" || common::GetEnv(kEnableMemScheduler) == "1") && target == kAscendDevice) {
+if ((common::GetEnv(kGraphOpRun) == "1" && target == kAscendDevice)) {
return;
}
} else {
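With ENABLE_MEM_SCHEDULER gone, the only remaining opt-out from MindRT on this path is GRAPH_OP_RUN=1 on Ascend. A hypothetical standalone form of the simplified predicate (std::getenv replaces common::GetEnv, and "Ascend" is assumed to be the value of kAscendDevice):

#include <cstdlib>
#include <string>

// Returns true when the old runtime should be kept, per the hunk above.
bool SkipMindRTForGraphOpRun(const std::string &target) {
  const char *raw = std::getenv("GRAPH_OP_RUN");
  const std::string graph_op_run = (raw != nullptr) ? raw : "";
  return graph_op_run == "1" && target == "Ascend";
}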

View File

@@ -82,7 +82,6 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
set_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE, "0");
set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, target == kAscendDevice || target == kDavinciDevice);
set_param<bool>(MS_CTX_ENABLE_PROFILING, false);
-set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, false);
set_param<std::string>(MS_CTX_PROFILING_OPTIONS, "training_trace");
set_param<bool>(MS_CTX_CHECK_BPROP_FLAG, false);
set_param<float>(MS_CTX_MAX_DEVICE_MEMORY, kDefaultMaxDeviceMemory);

View File

@@ -74,7 +74,6 @@ enum MsCtxParam : unsigned {
MS_CTX_ENABLE_GRAPH_KERNEL,
MS_CTX_ENABLE_HCCL,
MS_CTX_ENABLE_LOOP_SINK,
-MS_CTX_ENABLE_MEM_SCHEDULER,
MS_CTX_ENABLE_PYNATIVE_HOOK,
MS_CTX_ENABLE_PYNATIVE_INFER,
MS_CTX_ENABLE_REDUCE_PRECISION,
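A side effect of this last hunk worth noting: MsCtxParam enumerators take implicit consecutive values, so removing MS_CTX_ENABLE_MEM_SCHEDULER renumbers every enumerator that follows it, and anything compiled against the old header layout must be rebuilt. A minimal illustration with hypothetical names:

// Illustration only, hypothetical names: implicit enum values shift when a
// member is removed from the middle of an enum.
enum OldParams : unsigned { kAlpha, kRemoved, kBeta };  // kBeta == 2
enum NewParams : unsigned { kAlpha2, kBeta2 };          // kBeta2 == 1
static_assert(OldParams::kBeta == 2, "before removal: kBeta is 2");
static_assert(NewParams::kBeta2 == 1, "after removal: kBeta2 is 1");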