forked from mindspore-Ecosystem/mindspore
!18494 dynamic default memory size for vnpu
Merge pull request !18494 from zhoufeng/vnpu
This commit is contained in:
commit
930755a16b
|
@ -27,14 +27,29 @@ using mindspore::profiler::MemoryProfiling;
|
|||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
namespace {
|
||||
constexpr uint64_t kAscendInitDeviceMemGB = 30;
|
||||
constexpr uint64_t kAscendMaxDeviceMemGB = 31;
|
||||
constexpr uint64_t kMemSizeGB = 30;
|
||||
constexpr uint64_t kAscendDeviceMemSize = (kAscendInitDeviceMemGB << kMemSizeGB);
|
||||
|
||||
uint64_t GetDefaultDeviceMemSize() {
|
||||
size_t free = 0;
|
||||
size_t total = 0;
|
||||
rtError_t ret = rtMemGetInfoEx(RT_MEMORYINFO_HBM, &free, &total);
|
||||
if (ret != RT_ERROR_NONE || total == 0) {
|
||||
MS_LOG(WARNING) << "Get total HBM memory size failed, ret = " << ret << ", use default value "
|
||||
<< kAscendDeviceMemSize;
|
||||
return kAscendDeviceMemSize;
|
||||
}
|
||||
|
||||
return total * 15 / 16; // reserved memory is 1/16 of total
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void AscendMemoryManager::MallocDeviceMemory() {
|
||||
auto context_mem = GetDeviceMemSizeFromContext();
|
||||
device_mem_size_ = context_mem == 0 ? kAscendDeviceMemSize : context_mem;
|
||||
device_mem_size_ = context_mem == 0 ? GetDefaultDeviceMemSize() : context_mem;
|
||||
auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), device_mem_size_, RT_MEMORY_HBM);
|
||||
if (ret != ACL_RT_SUCCESS) {
|
||||
if (ret == ACL_ERROR_RT_MEMORY_ALLOCATION) {
|
||||
|
@ -56,7 +71,7 @@ void AscendMemoryManager::MallocDeviceMemory() {
|
|||
|
||||
uint64_t AscendMemoryManager::GetDeviceMemSize() {
|
||||
auto mem_size = GetDeviceMemSizeFromContext();
|
||||
return mem_size == 0 ? kAscendDeviceMemSize : mem_size;
|
||||
return mem_size == 0 ? GetDefaultDeviceMemSize() : mem_size;
|
||||
}
|
||||
|
||||
uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() {
|
||||
|
|
|
@ -195,3 +195,5 @@ RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim
|
|||
rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags) {
|
||||
return RT_ERROR_NONE;
|
||||
}
|
||||
|
||||
RTS_API rtError_t rtMemGetInfoEx(rtMemInfoType_t memInfoType, size_t *free, size_t *total) { return RT_ERROR_NONE; }
|
||||
|
|
Loading…
Reference in New Issue