!17505 Update gpu memory reuse.

Merge pull request !17505 from linqingke/gpu_memory
This commit is contained in:
i-robot 2021-06-16 14:50:20 +08:00 committed by Gitee
commit 5a34e74551
3 changed files with 75 additions and 10 deletions

View File

@ -95,6 +95,11 @@ class DynamicMemPoolBestFit {
return global_idle_mem_buf_map_;
}
// Returns the minimum memory unit size that is used when the pool is extended dynamically.
size_t mem_alloc_unit_size() const {
  return mem_alloc_unit_size_;
}
// Overrides the minimum memory unit size that is used when the pool is extended dynamically.
void set_mem_alloc_unit_size(const size_t &size) {
  mem_alloc_unit_size_ = size;
}
// Memory statistics accessors.
// Returns the value of the total memory statistics counter.
size_t total_mem_statistics() const {
  return total_mem_statistics_;
}
// Returns the value of the total used memory statistics counter.
size_t used_mem_statistics() const {
  return total_used_mem_statistics_;
}
@ -113,8 +118,6 @@ class DynamicMemPoolBestFit {
virtual size_t CalMemBlockAllocSize(size_t size);
private:
// Get the minimum memory unit size using for dynamic extend.
size_t mem_alloc_unit_size() const { return DYNAMIC_MEM_ALLOC_UNIT_SIZE; }
// Find the idle memory buf by aligned size when memory alloc.
DeviceMemPtr FindIdleMemBuf(size_t size);
// Add the memory block and memory buf when memory alloc not find the idle memory buf.
@ -143,6 +146,9 @@ class DynamicMemPoolBestFit {
size_t total_used_mem_statistics_{0};
size_t used_mem_peak_statistics_{0};
// The minimum memory unit size.
size_t mem_alloc_unit_size_{DYNAMIC_MEM_ALLOC_UNIT_SIZE};
// Support multi-thread.
std::mutex mutex_;
};

View File

@ -341,6 +341,68 @@ bool GPUKernelRuntime::IsDistributedTraining(const session::KernelGraph *graph)
[](const AnfNodePtr &kernel) { return AnfAlgo::IsCommunicationOp(kernel); });
}
// Estimates the peak amount of memory that is live at once while running `graph` and, when that peak is
// larger than the GPU allocator's current allocation unit, raises the unit to a multiple of
// DYNAMIC_MEM_ALLOC_UNIT_SIZE that is strictly greater than the peak.
//
// The estimate walks the kernels in execution order (communication ops are skipped):
//   1. the kernel's input, output and workspace sizes are added to a running total and the maximum is tracked;
//   2. each input whose dynamic reference count drops to zero is subtracted again;
//   3. the kernel's output and workspace sizes are subtracted.
// All subtractions saturate at zero so the running total can never wrap around.
//
// Raises (via MS_EXCEPTION_IF_NULL / MS_LOG(EXCEPTION)) when `graph`, a kernel mod or a kernel node is null,
// or when a dynamic reference count becomes negative.
void GPUKernelRuntime::FetchMemUnitSize(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  mem_reuse_util_->ResetDynamicUsedRefCount();
  size_t max_sum_size = 0;
  size_t current_sum_size = 0;
  auto &kernels = graph->execution_order();
  for (const auto &cnode : kernels) {
    auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto kernel = cnode->cast<AnfNodePtr>();
    MS_EXCEPTION_IF_NULL(kernel);
    if (AnfAlgo::IsCommunicationOp(kernel)) {
      continue;
    }

    const auto &input_size_list = kernel_mod->GetInputSizeList();
    const auto &output_size_list = kernel_mod->GetOutputSizeList();
    const auto &workspace_size_list = kernel_mod->GetWorkspaceSizeList();
    // The seed must be a size_t: std::accumulate folds in the type of its init argument, so an `int` seed
    // would truncate every partial sum to int and overflow once the sizes add up to more than INT_MAX bytes.
    const size_t input_size =
      std::accumulate(input_size_list.begin(), input_size_list.end(), static_cast<size_t>(0));
    const size_t output_size =
      std::accumulate(output_size_list.begin(), output_size_list.end(), static_cast<size_t>(0));
    const size_t workspace_size =
      std::accumulate(workspace_size_list.begin(), workspace_size_list.end(), static_cast<size_t>(0));

    current_sum_size = current_sum_size + input_size + output_size + workspace_size;
    if (current_sum_size > max_sum_size) {
      max_sum_size = current_sum_size;
    }

    // Free the input of kernel by reference count.
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
    if (input_num != input_size_list.size()) {
      // The size list cannot be indexed by input position, so nothing is released for this kernel
      // (its output and workspace sizes also stay in the running total).
      continue;
    }
    for (size_t i = 0; i < input_num; ++i) {
      auto kernel_ref_count_ptr = mem_reuse_util_->GetKernelInputRef(cnode, i);
      if (kernel_ref_count_ptr == nullptr) {
        continue;
      }
      kernel_ref_count_ptr->ref_count_dynamic_use_--;
      if (kernel_ref_count_ptr->ref_count_dynamic_use_ < 0) {
        MS_LOG(EXCEPTION) << "Check dynamic reference count failed.";
      }
      if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
        // This input was added once per consumer, so remove it ref_count_ times; saturate at zero.
        auto remove_size = kernel_ref_count_ptr->ref_count_ * input_size_list.at(i);
        if (remove_size <= current_sum_size) {
          current_sum_size -= remove_size;
        } else {
          current_sum_size = 0;
        }
      }
    }

    // Drop this kernel's output and workspace contribution again; saturate at zero.
    auto output_workspace_size = output_size + workspace_size;
    if (output_workspace_size <= current_sum_size) {
      current_sum_size -= output_workspace_size;
    } else {
      current_sum_size = 0;
    }
  }

  if (max_sum_size > GPUMemoryAllocator::GetInstance().mem_alloc_unit_size()) {
    // Round up to the next multiple of the default unit (always adds one extra unit, even on exact multiples).
    size_t unit_size = (max_sum_size / DYNAMIC_MEM_ALLOC_UNIT_SIZE + 1) * DYNAMIC_MEM_ALLOC_UNIT_SIZE;
    GPUMemoryAllocator::GetInstance().set_mem_alloc_unit_size(unit_size);
  }
}
void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
@ -416,6 +478,7 @@ bool GPUKernelRuntime::RunOneStep(const session::KernelGraph *graph) {
return LaunchKernelDynamic(graph);
}
// Mock run first step
FetchMemUnitSize(graph);
bool ret = LaunchKernelDynamic(graph, true, false);
is_first_step_map_[graph_id] = false;
if (ret) {
@ -1001,6 +1064,9 @@ bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::Kern
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(kernel_outputs);
UpdateHostSwapOutQueue(mock);
if (AnfAlgo::IsCommunicationOp(kernel)) {
AllocCommunicationOpOutputDynamicRes(kernel);
}
auto output_sizes = kernel_mod.GetOutputSizeList();
for (size_t i = 0; i < output_sizes.size(); ++i) {
auto device_address = GetMutableOutputAddr(kernel, i, false);
@ -1054,7 +1120,6 @@ void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph
if (AnfAlgo::IsCommunicationOp(kernel) && AnfAlgo::GetCNodeName(kernel) != kHcomSendOpName &&
AnfAlgo::GetCNodeName(kernel) != kReceiveOpName) {
AllocCommunicationOpInputDynamicRes(kernel);
AllocCommunicationOpOutputDynamicRes(kernel);
}
}
}
@ -1147,13 +1212,6 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel)
}
}
auto kernel_with_index = GetPrevNodeOutput(kernel, i);
// Maintain output addr of fused communication op to improve training performance
if (AnfAlgo::IsCommunicationOp(kernel_with_index.first) &&
AnfAlgo::GetInputTensorNum(kernel_with_index.first) > 1) {
continue;
}
auto kernel_ref_count_ptr = mem_reuse_util_->GetKernelInputRef(cnode, i);
if (kernel_ref_count_ptr == nullptr) {
continue;

View File

@ -106,6 +106,7 @@ class GPUKernelRuntime : public KernelRuntime {
void ClearSwapInfo(bool mock);
void AllocInplaceNodeMemory(const session::KernelGraph *graph);
bool IsDistributedTraining(const session::KernelGraph *graph);
void FetchMemUnitSize(const session::KernelGraph *graph);
DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node);
DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node);