diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.h
index 022626d58c5..6141a9a2711 100644
--- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.h
+++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.h
@@ -95,6 +95,11 @@ class DynamicMemPoolBestFit {
     return global_idle_mem_buf_map_;
   }
 
+  // Get the minimum memory unit size used for dynamic extension.
+  size_t mem_alloc_unit_size() const { return mem_alloc_unit_size_; }
+  // Set the minimum memory unit size used for dynamic extension.
+  void set_mem_alloc_unit_size(size_t size) { mem_alloc_unit_size_ = size; }
+
   // Get the related memory statistics information.
   size_t total_mem_statistics() const { return total_mem_statistics_; }
   size_t used_mem_statistics() const { return total_used_mem_statistics_; }
@@ -113,8 +118,6 @@ class DynamicMemPoolBestFit {
   virtual size_t CalMemBlockAllocSize(size_t size);
 
  private:
-  // Get the minimum memory unit size using for dynamic extend.
-  size_t mem_alloc_unit_size() const { return DYNAMIC_MEM_ALLOC_UNIT_SIZE; }
   // Find the idle memory buf by aligned size when memory alloc.
   DeviceMemPtr FindIdleMemBuf(size_t size);
   // Add the memory block and memory buf when memory alloc not find the idle memory buf.
@@ -143,6 +146,9 @@ class DynamicMemPoolBestFit {
   size_t total_used_mem_statistics_{0};
   size_t used_mem_peak_statistics_{0};
 
+  // The minimum memory unit size for dynamic extension.
+  size_t mem_alloc_unit_size_{DYNAMIC_MEM_ALLOC_UNIT_SIZE};
+
   // Support multi-thread.
   std::mutex mutex_;
 };
diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
index b85e7ff9ff9..28184740f3c 100644
--- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
@@ -341,6 +341,68 @@ bool GPUKernelRuntime::IsDistributedTraining(const session::KernelGraph *graph)
                      [](const AnfNodePtr &kernel) { return AnfAlgo::IsCommunicationOp(kernel); });
 }
 
+void GPUKernelRuntime::FetchMemUnitSize(const session::KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  mem_reuse_util_->ResetDynamicUsedRefCount();
+  size_t max_sum_size = 0;
+  size_t current_sum_size = 0;
+  auto &kernels = graph->execution_order();
+  for (const auto &cnode : kernels) {
+    auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
+    MS_EXCEPTION_IF_NULL(kernel_mod);
+    auto kernel = cnode->cast<AnfNodePtr>();
+    MS_EXCEPTION_IF_NULL(kernel);
+    if (AnfAlgo::IsCommunicationOp(kernel)) {
+      continue;
+    }
+    const auto &input_size_list = kernel_mod->GetInputSizeList();
+    const auto &output_size_list = kernel_mod->GetOutputSizeList();
+    const auto &workspace_size_list = kernel_mod->GetWorkspaceSizeList();
+
+    size_t input_size = std::accumulate(input_size_list.begin(), input_size_list.end(), size_t(0));
+    size_t output_size = std::accumulate(output_size_list.begin(), output_size_list.end(), size_t(0));
+    size_t workspace_size = std::accumulate(workspace_size_list.begin(), workspace_size_list.end(), size_t(0));
+    current_sum_size = current_sum_size + input_size + output_size + workspace_size;
+    if (current_sum_size > max_sum_size) {
+      max_sum_size = current_sum_size;
+    }
+
+    // Free the input memory of the kernel by reference count.
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
+    if (input_num != input_size_list.size()) {
+      continue;
+    }
+    for (size_t i = 0; i < input_num; ++i) {
+      auto kernel_ref_count_ptr = mem_reuse_util_->GetKernelInputRef(cnode, i);
+      if (kernel_ref_count_ptr == nullptr) {
+        continue;
+      }
+      kernel_ref_count_ptr->ref_count_dynamic_use_--;
+      if (kernel_ref_count_ptr->ref_count_dynamic_use_ < 0) {
+        MS_LOG(EXCEPTION) << "Check dynamic reference count failed.";
+      }
+      if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
+        auto remove_size = kernel_ref_count_ptr->ref_count_ * input_size_list.at(i);
+        if (remove_size <= current_sum_size) {
+          current_sum_size -= remove_size;
+        } else {
+          current_sum_size = 0;
+        }
+      }
+    }
+    auto output_workspace_size = output_size + workspace_size;
+    if (output_workspace_size <= current_sum_size) {
+      current_sum_size -= output_workspace_size;
+    } else {
+      current_sum_size = 0;
+    }
+  }
+  if (max_sum_size > GPUMemoryAllocator::GetInstance().mem_alloc_unit_size()) {
+    size_t unit_size = (max_sum_size / DYNAMIC_MEM_ALLOC_UNIT_SIZE + 1) * DYNAMIC_MEM_ALLOC_UNIT_SIZE;
+    GPUMemoryAllocator::GetInstance().set_mem_alloc_unit_size(unit_size);
+  }
+}
+
 void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
@@ -416,6 +478,7 @@ bool GPUKernelRuntime::RunOneStep(const session::KernelGraph *graph) {
     return LaunchKernelDynamic(graph);
   }
   // Mock run first step
+  FetchMemUnitSize(graph);
   bool ret = LaunchKernelDynamic(graph, true, false);
   is_first_step_map_[graph_id] = false;
   if (ret) {
@@ -1001,6 +1064,9 @@ bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::Kern
   MS_EXCEPTION_IF_NULL(kernel);
   MS_EXCEPTION_IF_NULL(kernel_outputs);
   UpdateHostSwapOutQueue(mock);
+  if (AnfAlgo::IsCommunicationOp(kernel)) {
+    AllocCommunicationOpOutputDynamicRes(kernel);
+  }
   auto output_sizes = kernel_mod.GetOutputSizeList();
   for (size_t i = 0; i < output_sizes.size(); ++i) {
     auto device_address = GetMutableOutputAddr(kernel, i, false);
@@ -1054,7 +1120,6 @@ void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph
     if (AnfAlgo::IsCommunicationOp(kernel) && AnfAlgo::GetCNodeName(kernel) != kHcomSendOpName &&
         AnfAlgo::GetCNodeName(kernel) != kReceiveOpName) {
       AllocCommunicationOpInputDynamicRes(kernel);
-      AllocCommunicationOpOutputDynamicRes(kernel);
     }
   }
 }
@@ -1147,13 +1212,6 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel)
       }
     }
 
-    auto kernel_with_index = GetPrevNodeOutput(kernel, i);
-    // Maintain output addr of fused communication op to improve training performance
-    if (AnfAlgo::IsCommunicationOp(kernel_with_index.first) &&
-        AnfAlgo::GetInputTensorNum(kernel_with_index.first) > 1) {
-      continue;
-    }
-
     auto kernel_ref_count_ptr = mem_reuse_util_->GetKernelInputRef(cnode, i);
     if (kernel_ref_count_ptr == nullptr) {
       continue;
diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h
index 3d1a6a19a58..ffc8c39baf1 100644
--- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h
@@ -106,6 +106,7 @@ class GPUKernelRuntime : public KernelRuntime {
   void ClearSwapInfo(bool mock);
   void AllocInplaceNodeMemory(const session::KernelGraph *graph);
   bool IsDistributedTraining(const session::KernelGraph *graph);
+  void FetchMemUnitSize(const session::KernelGraph *graph);
   DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node);
   DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node);
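
The heart of the patch is the sizing policy in FetchMemUnitSize: mock-walk the kernels in execution order, keep a running total of live bytes (inputs, outputs, and workspaces, minus buffers whose dynamic reference count drops to zero), record the peak, and when the peak exceeds the allocator's current unit, round it up to the next whole multiple of DYNAMIC_MEM_ALLOC_UNIT_SIZE. The standalone C++ sketch below reduces that walk to plain numbers; KernelFootprint, EstimateUnitSize, and the 1 GB kUnit are illustrative stand-ins, not MindSpore APIs.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// One entry per kernel in execution order.
struct KernelFootprint {
  size_t acquired;  // bytes that become live when the kernel launches
  size_t released;  // bytes freed again once ref counts hit zero
};

constexpr size_t kUnit = size_t(1) << 30;  // stand-in for DYNAMIC_MEM_ALLOC_UNIT_SIZE

size_t EstimateUnitSize(const std::vector<KernelFootprint> &kernels) {
  size_t current = 0;
  size_t peak = 0;
  for (const auto &k : kernels) {
    current += k.acquired;
    peak = std::max(peak, current);
    current -= std::min(k.released, current);  // clamp at zero, like the patch
  }
  // Keep the default unit unless the peak exceeds it; otherwise round the
  // peak up to the next whole multiple of the unit, mirroring the patch.
  return peak > kUnit ? (peak / kUnit + 1) * kUnit : kUnit;
}

int main() {
  // Two toy kernels with a 2.5 GB live peak: the unit becomes 3 GB, so one
  // pool extension covers the whole step instead of three 1 GB blocks.
  std::vector<KernelFootprint> kernels = {
      {size_t(1) << 30, size_t(1) << 29},  // +1.0 GB live, then 0.5 GB freed
      {size_t(2) << 30, size_t(2) << 30},  // +2.0 GB live -> 2.5 GB peak
  };
  std::cout << EstimateUnitSize(kernels) / double(size_t(1) << 30) << " GB\n";
  return 0;
}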
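On the allocator side, the fixed private getter becomes a public getter/setter pair backed by the new mem_alloc_unit_size_ member, so the runtime can raise the unit before the mock first step allocates anything. The sketch below shows how a best-fit pool might consume that value when it has to extend; ToyBestFitPool is hypothetical, and the round-up in its CalMemBlockAllocSize is an assumption for illustration rather than the actual MindSpore override.

#include <cassert>
#include <cstddef>

class ToyBestFitPool {
 public:
  size_t mem_alloc_unit_size() const { return mem_alloc_unit_size_; }
  void set_mem_alloc_unit_size(size_t size) { mem_alloc_unit_size_ = size; }

  // Size of the next device block to request when no idle buffer fits an
  // allocation of `size`: at least one unit, rounded up to whole units.
  size_t CalMemBlockAllocSize(size_t size) const {
    const size_t unit = mem_alloc_unit_size_;
    return size <= unit ? unit : ((size + unit - 1) / unit) * unit;
  }

 private:
  size_t mem_alloc_unit_size_{size_t(1) << 30};  // default 1 GB unit
};

int main() {
  ToyBestFitPool pool;
  pool.set_mem_alloc_unit_size(size_t(3) << 30);  // raised by the runtime
  // Even a tiny request now grows the pool by one 3 GB block.
  assert(pool.CalMemBlockAllocSize(512) == (size_t(3) << 30));
  return 0;
}

The apparent payoff of making the unit configurable: with it raised to the step's estimated peak, the pool grows through one large device allocation instead of many small ones, cutting cudaMalloc calls and fragmentation during the first real step.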