!19090 Fix mem_unit_size overflow bug.

Merge pull request !19090 from linqingke/gpu_memory
This commit is contained in:
i-robot 2021-07-01 06:35:13 +00:00 committed by Gitee
commit 3ffb9ea147
1 changed file with 8 additions and 3 deletions

View File

@@ -351,6 +351,7 @@ void GPUKernelRuntime::FetchMemUnitSize(const session::KernelGraph *graph) {
mem_reuse_util_->ResetDynamicUsedRefCount();
size_t max_sum_size = 0;
size_t current_sum_size = 0;
constexpr size_t kZeroNumber = 0;
auto &kernels = graph->execution_order();
for (const auto &cnode : kernels) {
auto kernel_mode = AnfAlgo::GetKernelMod(cnode);
@@ -364,9 +365,9 @@ void GPUKernelRuntime::FetchMemUnitSize(const session::KernelGraph *graph) {
const auto &output_size_list = kernel_mode->GetOutputSizeList();
const auto &workspace_size_list = kernel_mode->GetWorkspaceSizeList();
size_t input_size = std::accumulate(input_size_list.begin(), input_size_list.end(), 0);
size_t output_size = std::accumulate(output_size_list.begin(), output_size_list.end(), 0);
size_t workspace_size = std::accumulate(workspace_size_list.begin(), workspace_size_list.end(), 0);
size_t input_size = std::accumulate(input_size_list.begin(), input_size_list.end(), kZeroNumber);
size_t output_size = std::accumulate(output_size_list.begin(), output_size_list.end(), kZeroNumber);
size_t workspace_size = std::accumulate(workspace_size_list.begin(), workspace_size_list.end(), kZeroNumber);
current_sum_size = current_sum_size + input_size + output_size + workspace_size;
if (current_sum_size > max_sum_size) {
max_sum_size = current_sum_size;
@@ -404,6 +405,10 @@ void GPUKernelRuntime::FetchMemUnitSize(const session::KernelGraph *graph) {
}
if (max_sum_size > GPUMemoryAllocator::GetInstance().mem_alloc_unit_size()) {
size_t unit_size = (max_sum_size / DYNAMIC_MEM_ALLOC_UNIT_SIZE + 1) * DYNAMIC_MEM_ALLOC_UNIT_SIZE;
if (unit_size < DYNAMIC_MEM_ALLOC_UNIT_SIZE) {
MS_LOG(WARNING) << "Current memory unit size [" << unit_size << "] is too small.";
return;
}
size_t free_mem_size = GPUMemoryAllocator::GetInstance().free_mem_size();
constexpr float kValidMemoryRatio = 0.9;
free_mem_size = kValidMemoryRatio * free_mem_size;