!306 gpu uses dynamic memory pool by default

Merge pull request !306 from limingqi107/master
This commit is contained in:
mindspore-ci-bot 2020-04-14 21:50:24 +08:00 committed by Gitee
commit d90e121547
5 changed files with 34 additions and 40 deletions

View File

@ -127,9 +127,10 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool();
bool is_enable_pynative_infer = context_ptr->enable_pynative_infer();
struct timeval start_time, end_time; struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr); (void)gettimeofday(&start_time, nullptr);
if (is_enable_dynamic_mem) { if (is_enable_dynamic_mem && !is_enable_pynative_infer) {
ret = LaunchKernelDynamic(graph); ret = LaunchKernelDynamic(graph);
} else { } else {
ret = LaunchKernel(graph); ret = LaunchKernel(graph);
@ -152,7 +153,7 @@ void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) {
} }
mem_reuse_util_ptr->SetKernelDefMap(); mem_reuse_util_ptr->SetKernelDefMap();
mem_reuse_util_ptr->SetReuseRefCount(); mem_reuse_util_ptr->SetReuseRefCount();
// Can't free the device address of graph output, so set the reference count of graph output specially, // Can't free the device address of graph output, so set the reference count of graph output specially.
mem_reuse_util_ptr->SetGraphOutputRefCount(); mem_reuse_util_ptr->SetGraphOutputRefCount();
mem_reuse_util_ptr_ = mem_reuse_util_ptr; mem_reuse_util_ptr_ = mem_reuse_util_ptr;
} }
@ -351,6 +352,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
if (kernel_ref_count_ptr == nullptr) { if (kernel_ref_count_ptr == nullptr) {
continue; continue;
} }
// Can't free the output of graph.
if (kernel_ref_count_ptr->ref_count_dynamic_use_ == memreuse::kMaxRefCount) {
continue;
}
kernel_ref_count_ptr->ref_count_dynamic_use_--; kernel_ref_count_ptr->ref_count_dynamic_use_--;
if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
// Reset the reference count. // Reset the reference count.
@ -360,14 +365,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
FreeCommunicationOpDynamicRes(kernel, i, &is_communication_op); FreeCommunicationOpDynamicRes(kernel, i, &is_communication_op);
if (!is_communication_op) { if (!is_communication_op) {
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
MS_EXCEPTION_IF_NULL(device_address); mem_manager_->FreeMemFromMemPool(device_address);
MS_EXCEPTION_IF_NULL(device_address->ptr_);
mem_manager_->FreeMemFromMemPool(device_address->ptr_);
device_address->ptr_ = nullptr;
} }
} }
} }
// Free the workspace of kernel. // Free the workspace of kernel.
for (size_t i = 0; i < kernel_workspaces.size(); ++i) { for (size_t i = 0; i < kernel_workspaces.size(); ++i) {
auto workspace = kernel_workspaces[i]; auto workspace = kernel_workspaces[i];
@ -388,10 +389,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr
communication_op_input_ref_count_--; communication_op_input_ref_count_--;
if (communication_op_input_ref_count_ == 0) { if (communication_op_input_ref_count_ == 0) {
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0); auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0);
MS_EXCEPTION_IF_NULL(device_address); mem_manager_->FreeMemFromMemPool(device_address);
MS_EXCEPTION_IF_NULL(device_address->ptr_);
mem_manager_->FreeMemFromMemPool(device_address->ptr_);
device_address->ptr_ = nullptr;
} }
*is_communication_op = true; *is_communication_op = true;
return; return;
@ -410,10 +408,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr
communication_op_output_ref_count_--; communication_op_output_ref_count_--;
if (communication_op_output_ref_count_ == 0) { if (communication_op_output_ref_count_ == 0) {
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_input.first, 0); auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_input.first, 0);
MS_EXCEPTION_IF_NULL(device_address); mem_manager_->FreeMemFromMemPool(device_address);
MS_EXCEPTION_IF_NULL(device_address->ptr_);
mem_manager_->FreeMemFromMemPool(device_address->ptr_);
device_address->ptr_ = nullptr;
} }
*is_communication_op = true; *is_communication_op = true;
} }

View File

@ -155,6 +155,13 @@ void *MemoryManager::MallocMemFromMemPool(size_t size) {
return nullptr; return nullptr;
} }
void MemoryManager::FreeMemFromMemPool(const DeviceAddressPtr address) {
MS_EXCEPTION_IF_NULL(address);
MS_EXCEPTION_IF_NULL(address->ptr_);
FreeMemFromMemPool(address->ptr_);
address->ptr_ = nullptr;
}
void MemoryManager::FreeMemFromMemPool(void *device_ptr) { void MemoryManager::FreeMemFromMemPool(void *device_ptr) {
if (device_ptr == nullptr) { if (device_ptr == nullptr) {
MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null."; MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null.";

View File

@ -47,6 +47,7 @@ class MemoryManager {
virtual void MallocMemFromMemPool(const DeviceAddressPtr address, size_t size); virtual void MallocMemFromMemPool(const DeviceAddressPtr address, size_t size);
virtual void *MallocMemFromMemPool(size_t size); virtual void *MallocMemFromMemPool(size_t size);
virtual void FreeMemFromMemPool(const DeviceAddressPtr address);
virtual void FreeMemFromMemPool(void *device_ptr); virtual void FreeMemFromMemPool(void *device_ptr);
size_t GetCommonAlignSize(size_t input_size) const; size_t GetCommonAlignSize(size_t input_size) const;

View File

@ -273,30 +273,21 @@ void MemReuseUtil::SetReuseRefCount() {
} }
void MemReuseUtil::SetGraphOutputRefCount() { void MemReuseUtil::SetGraphOutputRefCount() {
for (const auto &output : graph_->outputs()) { auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem});
MS_EXCEPTION_IF_NULL(output); for (const auto &node : nodes) {
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(output); ++i) { auto kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0);
if (!(output->isa<CNode>())) { MS_EXCEPTION_IF_NULL(kernel_input.first);
continue; if (!kernel_input.first->isa<CNode>() || !AnfAlgo::IsRealKernel(kernel_input.first)) {
} continue;
auto cnode = output->cast<CNodePtr>(); }
MS_EXCEPTION_IF_NULL(cnode); auto ak_node = kernel_input.first->cast<CNodePtr>();
auto input_node = cnode->input(i + 1); auto key = ak_node.get();
MS_EXCEPTION_IF_NULL(input_node); auto iter = kernel_output_refs_.find(key);
auto kernel_input = AnfAlgo::VisitKernel(input_node, 0); if ((iter != kernel_output_refs_.end()) && (kernel_input.second < iter->second.size())) {
MS_EXCEPTION_IF_NULL(kernel_input.first); auto kernel_ref_count_ptr = kernel_output_refs_[key][kernel_input.second];
if (!(kernel_input.first->isa<CNode>())) { MS_EXCEPTION_IF_NULL(kernel_ref_count_ptr);
continue; kernel_ref_count_ptr->ref_count_ = kMaxRefCount;
} kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount;
auto ak_node = kernel_input.first->cast<CNodePtr>();
auto key = ak_node.get();
auto iter = kernel_output_refs_.find(key);
if ((iter != kernel_output_refs_.end()) && (kernel_input.second < iter->second.size())) {
auto kernel_ref_count_ptr = kernel_output_refs_[key][kernel_input.second];
MS_EXCEPTION_IF_NULL(kernel_ref_count_ptr);
kernel_ref_count_ptr->ref_count_ = kMaxRefCount;
kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount;
}
} }
} }
#ifdef MEM_REUSE_DEBUG #ifdef MEM_REUSE_DEBUG

View File

@ -75,7 +75,7 @@ MsContext::MsContext(const std::string& policy, const std::string& target) {
precompile_only_ = false; precompile_only_ = false;
auto_mixed_precision_flag_ = true; auto_mixed_precision_flag_ = true;
enable_pynative_infer_ = false; enable_pynative_infer_ = false;
enable_dynamic_mem_pool_ = false; enable_dynamic_mem_pool_ = true;
graph_memory_max_size_ = "0"; graph_memory_max_size_ = "0";
variable_memory_max_size_ = "0"; variable_memory_max_size_ = "0";
MS_LOG(INFO) << "Create context with backend policy:" << policy << ", device target:" << target << "."; MS_LOG(INFO) << "Create context with backend policy:" << policy << ", device target:" << target << ".";