forked from mindspore-Ecosystem/mindspore
!306 gpu uses dynamic memory pool by default
Merge pull request !306 from limingqi107/master
commit d90e121547
@@ -127,9 +127,10 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool();
+  bool is_enable_pynative_infer = context_ptr->enable_pynative_infer();
   struct timeval start_time, end_time;
   (void)gettimeofday(&start_time, nullptr);
-  if (is_enable_dynamic_mem) {
+  if (is_enable_dynamic_mem && !is_enable_pynative_infer) {
     ret = LaunchKernelDynamic(graph);
   } else {
     ret = LaunchKernel(graph);
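For reference, a minimal standalone sketch of the launch-path gating above, assuming only the two context flags matter; MockContext and UseDynamicLaunch are hypothetical stand-ins rather than MindSpore types, and the real code dispatches to LaunchKernelDynamic or LaunchKernel on the kernel graph.

#include <iostream>

// Hypothetical stand-in for MsContext: only the two flags consulted by Run().
struct MockContext {
  bool enable_dynamic_mem_pool = true;   // new default after this commit
  bool enable_pynative_infer = false;
};

// Mirrors the gating in GPUKernelRuntime::Run(): the dynamic-memory-pool path
// is taken only when the pool is enabled AND PyNative inference is off.
bool UseDynamicLaunch(const MockContext &ctx) {
  return ctx.enable_dynamic_mem_pool && !ctx.enable_pynative_infer;
}

int main() {
  MockContext graph_mode;                       // defaults: dynamic pool on
  MockContext pynative_infer;
  pynative_infer.enable_pynative_infer = true;  // PyNative inference falls back

  std::cout << UseDynamicLaunch(graph_mode) << "\n";      // 1 -> LaunchKernelDynamic
  std::cout << UseDynamicLaunch(pynative_infer) << "\n";  // 0 -> LaunchKernel
  return 0;
}

With enable_dynamic_mem_pool_ now defaulting to true (see the MsContext hunk below), graph-mode runs take the dynamic pool path unless PyNative inference is enabled.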
@@ -152,7 +153,7 @@ void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) {
   }
   mem_reuse_util_ptr->SetKernelDefMap();
   mem_reuse_util_ptr->SetReuseRefCount();
-  // Can't free the device address of graph output, so set the reference count of graph output specially,
+  // Can't free the device address of graph output, so set the reference count of graph output specially.
   mem_reuse_util_ptr->SetGraphOutputRefCount();
   mem_reuse_util_ptr_ = mem_reuse_util_ptr;
 }
@@ -351,6 +352,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
     if (kernel_ref_count_ptr == nullptr) {
       continue;
     }
+    // Can't free the output of graph.
+    if (kernel_ref_count_ptr->ref_count_dynamic_use_ == memreuse::kMaxRefCount) {
+      continue;
+    }
     kernel_ref_count_ptr->ref_count_dynamic_use_--;
     if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
       // Reset the reference count.
@@ -360,14 +365,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
       FreeCommunicationOpDynamicRes(kernel, i, &is_communication_op);
       if (!is_communication_op) {
         auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
-        MS_EXCEPTION_IF_NULL(device_address);
-        MS_EXCEPTION_IF_NULL(device_address->ptr_);
-        mem_manager_->FreeMemFromMemPool(device_address->ptr_);
-        device_address->ptr_ = nullptr;
+        mem_manager_->FreeMemFromMemPool(device_address);
       }
     }
   }

   // Free the workspace of kernel.
   for (size_t i = 0; i < kernel_workspaces.size(); ++i) {
     auto workspace = kernel_workspaces[i];
@@ -388,10 +389,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr
     communication_op_input_ref_count_--;
     if (communication_op_input_ref_count_ == 0) {
       auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0);
-      MS_EXCEPTION_IF_NULL(device_address);
-      MS_EXCEPTION_IF_NULL(device_address->ptr_);
-      mem_manager_->FreeMemFromMemPool(device_address->ptr_);
-      device_address->ptr_ = nullptr;
+      mem_manager_->FreeMemFromMemPool(device_address);
     }
     *is_communication_op = true;
     return;
@@ -410,10 +408,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr
       communication_op_output_ref_count_--;
       if (communication_op_output_ref_count_ == 0) {
         auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_input.first, 0);
-        MS_EXCEPTION_IF_NULL(device_address);
-        MS_EXCEPTION_IF_NULL(device_address->ptr_);
-        mem_manager_->FreeMemFromMemPool(device_address->ptr_);
-        device_address->ptr_ = nullptr;
+        mem_manager_->FreeMemFromMemPool(device_address);
       }
       *is_communication_op = true;
     }
@@ -155,6 +155,13 @@ void *MemoryManager::MallocMemFromMemPool(size_t size) {
   return nullptr;
 }

+void MemoryManager::FreeMemFromMemPool(const DeviceAddressPtr address) {
+  MS_EXCEPTION_IF_NULL(address);
+  MS_EXCEPTION_IF_NULL(address->ptr_);
+  FreeMemFromMemPool(address->ptr_);
+  address->ptr_ = nullptr;
+}
+
 void MemoryManager::FreeMemFromMemPool(void *device_ptr) {
   if (device_ptr == nullptr) {
     MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null.";
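A minimal sketch of what the new overload centralizes, using simplified stand-ins (a bare DeviceAddress struct, std::malloc/std::free in place of the GPU memory pool) rather than the real MindSpore types: the null checks, the raw-pointer free, and the ptr_ reset move from every call site into one place.

#include <cassert>
#include <cstdlib>
#include <memory>

// Simplified stand-in for the real type; only what the overload touches.
struct DeviceAddress {
  void *ptr_ = nullptr;
};
using DeviceAddressPtr = std::shared_ptr<DeviceAddress>;

class MemoryManager {
 public:
  // Raw-pointer overload: the underlying pool release.
  void FreeMemFromMemPool(void *device_ptr) {
    assert(device_ptr != nullptr);
    std::free(device_ptr);  // stand-in for the GPU memory pool free
  }
  // New overload: validates the address, frees it, and clears ptr_ so the
  // device address cannot be double-freed or reused by mistake.
  void FreeMemFromMemPool(const DeviceAddressPtr address) {
    assert(address != nullptr);
    assert(address->ptr_ != nullptr);
    FreeMemFromMemPool(address->ptr_);
    address->ptr_ = nullptr;
  }
};

int main() {
  MemoryManager mem_manager;
  auto device_address = std::make_shared<DeviceAddress>();
  device_address->ptr_ = std::malloc(64);         // pretend pool allocation
  mem_manager.FreeMemFromMemPool(device_address);  // one call replaces four lines at the call site
  assert(device_address->ptr_ == nullptr);
  return 0;
}

This is why the three call sites in GPUKernelRuntime above shrink from four lines each to a single mem_manager_->FreeMemFromMemPool(device_address) call.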
@@ -47,6 +47,7 @@ class MemoryManager {

   virtual void MallocMemFromMemPool(const DeviceAddressPtr address, size_t size);
   virtual void *MallocMemFromMemPool(size_t size);
+  virtual void FreeMemFromMemPool(const DeviceAddressPtr address);
   virtual void FreeMemFromMemPool(void *device_ptr);

   size_t GetCommonAlignSize(size_t input_size) const;
@@ -273,30 +273,21 @@ void MemReuseUtil::SetReuseRefCount() {
 }

 void MemReuseUtil::SetGraphOutputRefCount() {
-  for (const auto &output : graph_->outputs()) {
-    MS_EXCEPTION_IF_NULL(output);
-    for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(output); ++i) {
-      if (!(output->isa<CNode>())) {
-        continue;
-      }
-      auto cnode = output->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(cnode);
-      auto input_node = cnode->input(i + 1);
-      MS_EXCEPTION_IF_NULL(input_node);
-      auto kernel_input = AnfAlgo::VisitKernel(input_node, 0);
-      MS_EXCEPTION_IF_NULL(kernel_input.first);
-      if (!(kernel_input.first->isa<CNode>())) {
-        continue;
-      }
-      auto ak_node = kernel_input.first->cast<CNodePtr>();
-      auto key = ak_node.get();
-      auto iter = kernel_output_refs_.find(key);
-      if ((iter != kernel_output_refs_.end()) && (kernel_input.second < iter->second.size())) {
-        auto kernel_ref_count_ptr = kernel_output_refs_[key][kernel_input.second];
-        MS_EXCEPTION_IF_NULL(kernel_ref_count_ptr);
-        kernel_ref_count_ptr->ref_count_ = kMaxRefCount;
-        kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount;
-      }
+  auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem});
+  for (const auto &node : nodes) {
+    auto kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0);
+    MS_EXCEPTION_IF_NULL(kernel_input.first);
+    if (!kernel_input.first->isa<CNode>() || !AnfAlgo::IsRealKernel(kernel_input.first)) {
+      continue;
+    }
+    auto ak_node = kernel_input.first->cast<CNodePtr>();
+    auto key = ak_node.get();
+    auto iter = kernel_output_refs_.find(key);
+    if ((iter != kernel_output_refs_.end()) && (kernel_input.second < iter->second.size())) {
+      auto kernel_ref_count_ptr = kernel_output_refs_[key][kernel_input.second];
+      MS_EXCEPTION_IF_NULL(kernel_ref_count_ptr);
+      kernel_ref_count_ptr->ref_count_ = kMaxRefCount;
+      kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount;
     }
   }
 #ifdef MEM_REUSE_DEBUG
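To show how the kMaxRefCount marking here interacts with the guard added in FreeKernelDynamicRes above, a minimal sketch with a simplified KernelRefCount and an illustrative sentinel value; the names mirror the diff, but the struct and constant are stand-ins, not the real mem_reuse code.

#include <iostream>
#include <limits>

namespace memreuse {
// Sentinel meaning "never free"; illustrative value, the real constant lives in mem_reuse.
constexpr int kMaxRefCount = std::numeric_limits<int>::max();
}  // namespace memreuse

// Simplified stand-in for KernelRefCount: just the two counters used here.
struct KernelRefCount {
  int ref_count_ = 0;              // static count, used to reset after release
  int ref_count_dynamic_use_ = 0;  // decremented as consuming kernels run
};

// Mirrors the per-input logic of FreeKernelDynamicRes: graph outputs carry the
// sentinel and are skipped, so their device memory survives the graph run.
bool TryRelease(KernelRefCount *rc) {
  if (rc->ref_count_dynamic_use_ == memreuse::kMaxRefCount) {
    return false;  // can't free the output of the graph
  }
  rc->ref_count_dynamic_use_--;
  if (rc->ref_count_dynamic_use_ == 0) {
    rc->ref_count_dynamic_use_ = rc->ref_count_;  // reset the reference count
    return true;                                  // caller frees the device address
  }
  return false;
}

int main() {
  KernelRefCount intermediate{2, 2};
  KernelRefCount graph_output{2, memreuse::kMaxRefCount};  // marked by SetGraphOutputRefCount

  std::cout << TryRelease(&intermediate);          // 0: still referenced
  std::cout << TryRelease(&intermediate) << "\n";  // 1: last use, memory returned to the pool
  std::cout << TryRelease(&graph_output) << "\n";  // 0: never freed during the run
  return 0;
}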
@@ -75,7 +75,7 @@ MsContext::MsContext(const std::string& policy, const std::string& target) {
   precompile_only_ = false;
   auto_mixed_precision_flag_ = true;
   enable_pynative_infer_ = false;
-  enable_dynamic_mem_pool_ = false;
+  enable_dynamic_mem_pool_ = true;
   graph_memory_max_size_ = "0";
   variable_memory_max_size_ = "0";
   MS_LOG(INFO) << "Create context with backend policy:" << policy << ", device target:" << target << ".";