forked from mindspore-Ecosystem/mindspore
!3189 Reback memory pool to memory offset
Merge pull request !3189 from JoyLvliang/reback-memory-pool-to-memory-offset
This commit is contained in:
commit
55cd091f5e
|
@ -303,22 +303,12 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
|
|||
return sync_ok;
|
||||
}
|
||||
|
||||
void AscendDeviceAddress::UpdateCommunicationAddress() {
|
||||
MS_EXCEPTION_IF_NULL(ptr_);
|
||||
communication_ptr_ = reinterpret_cast<uint8_t *>(ptr_) - kMemAlignSize;
|
||||
}
|
||||
|
||||
AscendDeviceAddress::~AscendDeviceAddress() {
|
||||
if (ptr_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (from_mem_pool_) {
|
||||
if (communication_ptr_ != nullptr) {
|
||||
AscendMemoryPool::GetInstance().FreeTensorMem(communication_ptr_);
|
||||
communication_ptr_ = nullptr;
|
||||
} else {
|
||||
AscendMemoryPool::GetInstance().FreeTensorMem(ptr_);
|
||||
}
|
||||
AscendMemoryPool::GetInstance().FreeTensorMem(ptr_);
|
||||
ptr_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,7 +39,6 @@ class AscendDeviceAddress : public DeviceAddress {
|
|||
bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const override;
|
||||
bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override;
|
||||
DeviceAddressType DeviceType() const override { return DeviceAddressType::kAscend; }
|
||||
void UpdateCommunicationAddress() override;
|
||||
#ifdef ENABLE_DUMP_E2E
|
||||
bool DumpMemToFile(bool dump_mode, const std::string &filepath, const std::string &host_fmt,
|
||||
const std::vector<int> &host_shape, TypeId host_type) const;
|
||||
|
@ -55,7 +54,6 @@ class AscendDeviceAddress : public DeviceAddress {
|
|||
bool ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type,
|
||||
const void *host_ptr) const;
|
||||
void SyncStream() const;
|
||||
uint8_t *communication_ptr_{nullptr};
|
||||
};
|
||||
using AscendDeviceAddressPtr = std::shared_ptr<AscendDeviceAddress>;
|
||||
} // namespace ascend
|
||||
|
|
|
@ -21,22 +21,32 @@
|
|||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
constexpr uint64_t kAscendDeviceMemGB = 30;
|
||||
constexpr uint64_t kAscendDeviceMemGB = 26;
|
||||
constexpr uint64_t kAscendMemPoolGB = 4;
|
||||
constexpr uint64_t kMemSizeGB = 30;
|
||||
constexpr uint64_t kMaxMemSizeGB = 30;
|
||||
constexpr uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << kMemSizeGB);
|
||||
constexpr uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << kMemSizeGB);
|
||||
|
||||
void AscendMemoryManager::MallocDeviceMemory() {
|
||||
auto context_mem = GetDeviceMemSizeFromContext();
|
||||
device_mem_size_ = context_mem == 0 ? kAscendDeviceMemSize : context_mem;
|
||||
dynamic_mem_offset_ = device_mem_size_;
|
||||
auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), dynamic_mem_offset_, RT_MEMORY_HBM);
|
||||
static_mem_offset_ = device_mem_size_;
|
||||
auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), static_mem_offset_, RT_MEMORY_HBM);
|
||||
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << dynamic_mem_offset_ << "] fail, ret[" << ret << "]";
|
||||
MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << static_mem_offset_ << "] fail, ret[" << ret << "]";
|
||||
}
|
||||
|
||||
AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_base_);
|
||||
AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_);
|
||||
if (context_mem == 0) {
|
||||
device_mem_pool_size_ = kAscendMemPoolSize;
|
||||
ret = rtMalloc(reinterpret_cast<void **>(&device_mem_pool_base_), device_mem_pool_size_, RT_MEMORY_HBM);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]";
|
||||
}
|
||||
AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_pool_base_);
|
||||
AscendMemoryPool::GetInstance().set_device_mem_pool_size(device_mem_pool_size_);
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() {
|
||||
|
@ -54,7 +64,7 @@ uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() {
|
|||
auto gb_str = variable_memory_max_size.substr(0, pos);
|
||||
auto gb_var = std::stoull(gb_str);
|
||||
MS_LOG(INFO) << "variable_memory_max_size(GB):" << gb_var;
|
||||
if (gb_var > kAscendDeviceMemGB || gb_var == 0) {
|
||||
if (gb_var > kMaxMemSizeGB || gb_var == 0) {
|
||||
MS_LOG(EXCEPTION) << "Invalid allocate memory size:" << gb_var << " which should be in (0-30]GB";
|
||||
}
|
||||
return gb_var << kMemSizeGB;
|
||||
|
@ -77,71 +87,8 @@ void AscendMemoryManager::FreeDeviceMemory() {
|
|||
}
|
||||
}
|
||||
|
||||
void AscendMemoryManager::ResetDynamicMemory() {
|
||||
total_dynamic_size_ = 0;
|
||||
dynamic_mem_offset_ = device_mem_size_;
|
||||
AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_);
|
||||
}
|
||||
|
||||
void *AscendMemoryManager::MallocMemFromMemPool(size_t size) {
|
||||
auto align_size = GetCommonAlignSize(size);
|
||||
return AscendMemoryPool::GetInstance().AllocTensorMem(align_size);
|
||||
}
|
||||
|
||||
uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_mem) {
|
||||
size_t align_size = 0;
|
||||
if (communication_mem) {
|
||||
align_size = GetCommunicationAlignSize(size);
|
||||
} else {
|
||||
align_size = GetCommonAlignSize(size);
|
||||
}
|
||||
|
||||
auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
|
||||
MS_LOG(INFO) << "Malloc Memory: Static, total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
|
||||
<< "] memory pool[" << device_mem_pool_offset << "])"
|
||||
<< " malloc [" << align_size << "]";
|
||||
|
||||
if (communication_mem) {
|
||||
// create protect area [kMemAlignSize -- data -- kMemAlignSize]
|
||||
uint8_t *alloc_address = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
|
||||
return alloc_address + kMemAlignSize;
|
||||
} else {
|
||||
return reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
|
||||
}
|
||||
}
|
||||
|
||||
uint8_t *AscendMemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
|
||||
size_t align_size = 0;
|
||||
if (communication_mem) {
|
||||
align_size = GetCommunicationAlignSize(size);
|
||||
} else {
|
||||
align_size = GetCommonAlignSize(size);
|
||||
}
|
||||
|
||||
auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
|
||||
MS_LOG(INFO) << "Malloc Memory: Dynamic, total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
|
||||
<< "] memory pool[" << device_mem_pool_offset << "])"
|
||||
<< " malloc [" << align_size << "]";
|
||||
|
||||
if (dynamic_mem_offset_ < align_size) {
|
||||
MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
|
||||
<< "]) malloc [" << align_size << "] failed!";
|
||||
}
|
||||
auto new_offset = dynamic_mem_offset_ - align_size;
|
||||
if (new_offset <= device_mem_pool_offset) {
|
||||
MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
|
||||
<< "] memory pool[" << device_mem_pool_offset << "])"
|
||||
<< " malloc [" << align_size << "] failed!";
|
||||
}
|
||||
total_dynamic_size_ += align_size;
|
||||
dynamic_mem_offset_ = new_offset;
|
||||
AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_);
|
||||
if (communication_mem) {
|
||||
// create protect area [kMemAlignSize -- data -- kMemAlignSize]
|
||||
return device_mem_base_ + new_offset + kMemAlignSize;
|
||||
} else {
|
||||
return device_mem_base_ + new_offset;
|
||||
}
|
||||
return AscendMemoryPool::GetInstance().AllocTensorMem(size);
|
||||
}
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
|
|
|
@ -27,13 +27,8 @@ class AscendMemoryManager : public MemoryManager {
|
|||
|
||||
void MallocDeviceMemory() override;
|
||||
void FreeDeviceMemory() override;
|
||||
void ResetDynamicMemory() override;
|
||||
void *MallocMemFromMemPool(size_t size) override;
|
||||
|
||||
protected:
|
||||
uint8_t *MallocStaticMem(size_t size, bool communication_mem) override;
|
||||
uint8_t *MallocDynamicMem(size_t size, bool communication_mem) override;
|
||||
|
||||
private:
|
||||
uint8_t *device_mem_pool_base_{nullptr};
|
||||
uint64_t device_mem_pool_size_{0};
|
||||
|
|
|
@ -22,54 +22,51 @@ namespace mindspore {
|
|||
namespace device {
|
||||
namespace ascend {
|
||||
size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
|
||||
if (size == 0) {
|
||||
MS_LOG(EXCEPTION) << "Can not alloc memory size(0) in memory pool !";
|
||||
if (has_malloc_) {
|
||||
MS_LOG(EXCEPTION) << "Memory pool has been allocated memory resource!";
|
||||
}
|
||||
if (device_mem_pool_offset_ + size >= graph_dynamic_mem_offset_) {
|
||||
MS_LOG(EXCEPTION) << "Failed to alloc memory pool memory, the current device_mem_pool_offset_ ["
|
||||
<< device_mem_pool_offset_ << "], current graph_dynamic_mem_offset_ " << graph_dynamic_mem_offset_
|
||||
<< "], need memory size [" << size << "]";
|
||||
if (size == 0 || size > free_mem_size_) {
|
||||
MS_LOG(EXCEPTION) << "Failed to alloc memory pool resource, the size is zero or large than free mem size!";
|
||||
}
|
||||
*addr = device_mem_pool_base_ + device_mem_pool_offset_;
|
||||
device_mem_pool_offset_ += size;
|
||||
*addr = device_mem_pool_base_;
|
||||
if (*addr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Alloc device address is nullptr, failed to alloc memory pool memory!";
|
||||
MS_LOG(EXCEPTION) << "Device memory pool base address is nullptr, failed to alloc memory pool resource!";
|
||||
}
|
||||
has_malloc_ = true;
|
||||
free_mem_size_ -= size;
|
||||
return size;
|
||||
}
|
||||
|
||||
bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) {
|
||||
MS_EXCEPTION_IF_NULL(addr);
|
||||
has_malloc_ = false;
|
||||
free_mem_size_ = total_mem_size_;
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t AscendMemoryPool::AlignMemorySize(size_t size) const {
|
||||
if (size == 0) {
|
||||
MS_LOG(EXCEPTION) << "The align memory size is a zero !";
|
||||
return DYNAMIC_MEM_ALIGN_SIZE;
|
||||
}
|
||||
return size;
|
||||
return ((size + DYNAMIC_MEM_ALIGN_SIZE + 31) / DYNAMIC_MEM_ALIGN_SIZE) * DYNAMIC_MEM_ALIGN_SIZE;
|
||||
}
|
||||
|
||||
size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - DYNAMIC_MEM_ALIGN_SIZE; }
|
||||
|
||||
void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) {
|
||||
MS_EXCEPTION_IF_NULL(device_mem_pool_base);
|
||||
device_mem_pool_base_ = device_mem_pool_base;
|
||||
}
|
||||
|
||||
void AscendMemoryPool::set_graph_dynamic_mem_offset(uint64_t graph_dynamic_mem_offset) {
|
||||
graph_dynamic_mem_offset_ = graph_dynamic_mem_offset;
|
||||
void AscendMemoryPool::set_device_mem_pool_size(uint64_t device_mem_pool_size) {
|
||||
device_mem_pool_size_ = device_mem_pool_size;
|
||||
free_mem_size_ = device_mem_pool_size_;
|
||||
total_mem_size_ = free_mem_size_;
|
||||
}
|
||||
|
||||
uint64_t AscendMemoryPool::device_mem_pool_offset() const { return device_mem_pool_offset_; }
|
||||
size_t AscendMemoryPool::free_mem_size() { return free_mem_size_; }
|
||||
|
||||
size_t AscendMemoryPool::free_mem_size() {
|
||||
if (graph_dynamic_mem_offset_ < device_mem_pool_offset_) {
|
||||
MS_LOG(EXCEPTION) << "graph dynamic mem offset [" << graph_dynamic_mem_offset_
|
||||
<< "] less than device mem pool offset [" << device_mem_pool_offset_ << "]!";
|
||||
}
|
||||
return graph_dynamic_mem_offset_ - device_mem_pool_offset_;
|
||||
}
|
||||
|
||||
size_t AscendMemoryPool::total_mem_size() { return graph_dynamic_mem_offset_ == 0 ? 0 : graph_dynamic_mem_offset_ - 1; }
|
||||
size_t AscendMemoryPool::total_mem_size() { return total_mem_size_; }
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,9 +32,8 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
|
|||
size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override;
|
||||
bool FreeDeviceMem(const DeviceMemPtr &addr) override;
|
||||
void set_device_mem_pool_base(uint8_t *device_mem_pool_base);
|
||||
void set_graph_dynamic_mem_offset(uint64_t graph_dynamic_mem_offset);
|
||||
void set_device_mem_pool_size(uint64_t device_mem_pool_size);
|
||||
|
||||
uint64_t device_mem_pool_offset() const;
|
||||
size_t free_mem_size() override;
|
||||
size_t total_mem_size() override;
|
||||
|
||||
|
@ -46,12 +45,16 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
|
|||
protected:
|
||||
// The real size by memory alloc aligned.
|
||||
size_t AlignMemorySize(size_t size) const override;
|
||||
// Get the minimum memory unit size using for dynamic extend.
|
||||
size_t mem_alloc_unit_size() const override;
|
||||
|
||||
private:
|
||||
AscendMemoryPool() = default;
|
||||
bool has_malloc_{false};
|
||||
uint8_t *device_mem_pool_base_{nullptr};
|
||||
uint64_t device_mem_pool_offset_{0};
|
||||
uint64_t graph_dynamic_mem_offset_{0};
|
||||
uint64_t device_mem_pool_size_{0};
|
||||
size_t free_mem_size_{0};
|
||||
size_t total_mem_size_{0};
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
|
|
|
@ -61,7 +61,6 @@ class DeviceAddress : public mindspore::DeviceSync {
|
|||
std::string format() const { return format_; }
|
||||
TypeId type_id() const { return type_id_; }
|
||||
void set_host_shape(const std::vector<int> &shape) { host_shape_ = shape; }
|
||||
virtual void UpdateCommunicationAddress() {}
|
||||
virtual void set_status(DeviceAddressStatus status) {}
|
||||
virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
|
||||
virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
|
||||
|
|
|
@ -439,10 +439,6 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
|
|||
std::string output_format = AnfAlgo::GetOutputFormat(node, j);
|
||||
auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
|
||||
auto address = CreateDeviceAddress(output_ptr, output_sizes[j], output_format, output_type);
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
if (AnfAlgo::IsCommunicationOp(node) && context_ptr->enable_hccl()) {
|
||||
address->UpdateCommunicationAddress();
|
||||
}
|
||||
AnfAlgo::SetOutputAddr(address, j, node.get());
|
||||
output_ptr += align_size_list[j];
|
||||
}
|
||||
|
@ -492,8 +488,6 @@ void KernelRuntime::AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &
|
|||
}
|
||||
|
||||
void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) {
|
||||
|
@ -525,9 +519,6 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in
|
|||
auto device_address = CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type);
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
|
||||
if (AnfAlgo::IsCommunicationOp(node) && context_ptr->enable_hccl()) {
|
||||
device_address->UpdateCommunicationAddress();
|
||||
}
|
||||
AnfAlgo::SetOutputAddr(device_address, i, node.get());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ class MemoryManager {
|
|||
|
||||
virtual void MallocDeviceMemory() = 0;
|
||||
virtual void FreeDeviceMemory() = 0;
|
||||
virtual void ResetDynamicMemory() {
|
||||
void ResetDynamicMemory() {
|
||||
total_dynamic_size_ = 0;
|
||||
dynamic_mem_offset_ = 0;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue