!19090 Fix mem_unit_size overflow bug.
Merge pull request !19090 from linqingke/gpu_memory
This commit is contained in:
commit
3ffb9ea147
|
@ -351,6 +351,7 @@ void GPUKernelRuntime::FetchMemUnitSize(const session::KernelGraph *graph) {
|
|||
mem_reuse_util_->ResetDynamicUsedRefCount();
|
||||
size_t max_sum_size = 0;
|
||||
size_t current_sum_size = 0;
|
||||
constexpr size_t kZeroNumber = 0;
|
||||
auto &kernels = graph->execution_order();
|
||||
for (const auto &cnode : kernels) {
|
||||
auto kernel_mode = AnfAlgo::GetKernelMod(cnode);
|
||||
|
@ -364,9 +365,9 @@ void GPUKernelRuntime::FetchMemUnitSize(const session::KernelGraph *graph) {
|
|||
const auto &output_size_list = kernel_mode->GetOutputSizeList();
|
||||
const auto &workspace_size_list = kernel_mode->GetWorkspaceSizeList();
|
||||
|
||||
size_t input_size = std::accumulate(input_size_list.begin(), input_size_list.end(), 0);
|
||||
size_t output_size = std::accumulate(output_size_list.begin(), output_size_list.end(), 0);
|
||||
size_t workspace_size = std::accumulate(workspace_size_list.begin(), workspace_size_list.end(), 0);
|
||||
size_t input_size = std::accumulate(input_size_list.begin(), input_size_list.end(), kZeroNumber);
|
||||
size_t output_size = std::accumulate(output_size_list.begin(), output_size_list.end(), kZeroNumber);
|
||||
size_t workspace_size = std::accumulate(workspace_size_list.begin(), workspace_size_list.end(), kZeroNumber);
|
||||
current_sum_size = current_sum_size + input_size + output_size + workspace_size;
|
||||
if (current_sum_size > max_sum_size) {
|
||||
max_sum_size = current_sum_size;
|
||||
|
@ -404,6 +405,10 @@ void GPUKernelRuntime::FetchMemUnitSize(const session::KernelGraph *graph) {
|
|||
}
|
||||
if (max_sum_size > GPUMemoryAllocator::GetInstance().mem_alloc_unit_size()) {
|
||||
size_t unit_size = (max_sum_size / DYNAMIC_MEM_ALLOC_UNIT_SIZE + 1) * DYNAMIC_MEM_ALLOC_UNIT_SIZE;
|
||||
if (unit_size < DYNAMIC_MEM_ALLOC_UNIT_SIZE) {
|
||||
MS_LOG(WARNING) << "Current memory unit size [" << unit_size << "] is too small.";
|
||||
return;
|
||||
}
|
||||
size_t free_mem_size = GPUMemoryAllocator::GetInstance().free_mem_size();
|
||||
constexpr float kValidMemoryRatio = 0.9;
|
||||
free_mem_size = kValidMemoryRatio * free_mem_size;
|
||||
|
|
Loading…
Reference in New Issue