diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc
index c5ce0bf3a33..14f4e119fce 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc
@@ -27,14 +27,35 @@ using mindspore::profiler::MemoryProfiling;
 namespace mindspore {
 namespace device {
 namespace ascend {
+namespace {
 constexpr uint64_t kAscendInitDeviceMemGB = 30;
 constexpr uint64_t kAscendMaxDeviceMemGB = 31;
 constexpr uint64_t kMemSizeGB = 30;
 constexpr uint64_t kAscendDeviceMemSize = (kAscendInitDeviceMemGB << kMemSizeGB);
+// The default pool takes 15/16 of the total HBM size; the remaining 1/16 is reserved memory.
+constexpr uint64_t kAscendUsableMemNumerator = 15;
+constexpr uint64_t kAscendUsableMemDenominator = 16;
 
+// Returns the device memory size used when the context does not configure one.
+// Queries the total HBM size through rtMemGetInfoEx and returns 15/16 of it; when the query
+// fails or reports a total of 0, logs a warning and falls back to kAscendDeviceMemSize.
+uint64_t GetDefaultDeviceMemSize() {
+  size_t free_size = 0;  // required by the rtMemGetInfoEx signature; its value is not used here
+  size_t total_size = 0;
+  const rtError_t ret = rtMemGetInfoEx(RT_MEMORYINFO_HBM, &free_size, &total_size);
+  if (ret != RT_ERROR_NONE || total_size == 0) {
+    MS_LOG(WARNING) << "Get total HBM memory size failed, ret = " << ret << ", use default value "
+                    << kAscendDeviceMemSize;
+    return kAscendDeviceMemSize;
+  }
+
+  return total_size * kAscendUsableMemNumerator / kAscendUsableMemDenominator;
+}
+}  // namespace
+
 void AscendMemoryManager::MallocDeviceMemory() {
   auto context_mem = GetDeviceMemSizeFromContext();
-  device_mem_size_ = context_mem == 0 ? kAscendDeviceMemSize : context_mem;
+  device_mem_size_ = context_mem == 0 ? GetDefaultDeviceMemSize() : context_mem;
   auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), device_mem_size_, RT_MEMORY_HBM);
   if (ret != ACL_RT_SUCCESS) {
     if (ret == ACL_ERROR_RT_MEMORY_ALLOCATION) {
@@ -56,7 +77,7 @@ void AscendMemoryManager::MallocDeviceMemory() {
 
 uint64_t AscendMemoryManager::GetDeviceMemSize() {
   auto mem_size = GetDeviceMemSizeFromContext();
-  return mem_size == 0 ? kAscendDeviceMemSize : mem_size;
+  return mem_size == 0 ? GetDefaultDeviceMemSize() : mem_size;
 }
 
 uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() {
diff --git a/tests/ut/cpp/stub/runtime/runtime_stub.cc b/tests/ut/cpp/stub/runtime/runtime_stub.cc
index 8e7d749e58a..0682ce3e7f8 100644
--- a/tests/ut/cpp/stub/runtime/runtime_stub.cc
+++ b/tests/ut/cpp/stub/runtime/runtime_stub.cc
@@ -195,3 +195,5 @@ RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim
                                          rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags) {
   return RT_ERROR_NONE;
 }
+
+RTS_API rtError_t rtMemGetInfoEx(rtMemInfoType_t memInfoType, size_t *free, size_t *total) { return RT_ERROR_NONE; }