Optimize pynative device memory use

Add gradient to pynative unique

Signed-off-by: zjun <zhangjun0@huawei.com>
Author: zjun
Date: 2021-10-15 09:21:31 +08:00
Parent: 4711194524
Commit: eb450dd31f
41 changed files with 574 additions and 274 deletions

View File

@@ -24,6 +24,8 @@ The MindSpore context is used to configure the current execution environment, including the execution mode
 |                         | max_device_memory            | GPU                        |
 |                         +------------------------------+----------------------------+
 |                         | variable_memory_max_size     | Ascend                     |
+|                         +------------------------------+----------------------------+
+|                         | mempool_block_size           | GPU/Ascend                 |
 +-------------------------+------------------------------+----------------------------+
 | Debug configuration     | save_graphs                  | CPU/GPU/Ascend             |
 |                         +------------------------------+----------------------------+
@@ -76,6 +78,7 @@ The MindSpore context is used to configure the current execution environment, including the execution mode
 - **device_target** (str) - The target device to run on, which supports Ascend, GPU, and CPU. If the device target is not set, the version of the MindSpore package is used.
 - **max_device_memory** (str) - Sets the maximum memory available to the device. This is currently supported only on GPU. The format is "xxGB". Default: 1024GB. The actual memory used is the minimum of the device's available memory and the value of `max_device_memory`.
 - **variable_memory_max_size** (str) - Sets the maximum size of variable memory. Default: 30GB. After this parameter is set, the maximum memory used by the framework is restricted to the configured value.
+- **mempool_block_size** (str) - Sets the block size of the device memory pool in PyNative mode. The format is "xxGB". Default: 1GB. The minimum value is 1GB. The actual memory pool block size is the minimum of the device's available memory and the value of `mempool_block_size`.
 - **save_graphs** (bool) - Whether to save graphs. Default: False. When `save_graphs` is set to True, `save_graphs_path` is used to set the storage path of the intermediate compiled graphs. By default, graphs are saved in the current directory.
 - **save_graphs_path** (str) - The path for saving graphs. Default: ".". If the specified directory does not exist, the system creates it automatically. During distributed training, graphs are saved to the `save_graphs_path/rank_${rank_id}/` directory, where `rank_id` is the ID of the current device in the cluster.
 - **enable_dump** (bool) - This parameter is deprecated and will be removed in the next release.
@@ -155,6 +158,7 @@ The MindSpore context is used to configure the current execution environment, including the execution mode
 ...     profiling_options='{"output":"/home/data/output","training_trace":"on"}')
 >>> context.set_context(check_bprop=True)
 >>> context.set_context(max_device_memory="3.5GB")
+>>> context.set_context(mempool_block_size="1GB")
 >>> context.set_context(print_file_path="print.pb")
 >>> context.set_context(enable_sparse=True)
 >>> context.set_context(max_call_depth=80)

View File

@@ -15,39 +15,51 @@
  */
 #include "backend/optimizer/mem_reuse/mem_dynamic_allocator.h"
+#include <string>
 #include "utils/ms_utils.h"
 #include "utils/convert_utils.h"
 #include "utils/log_adapter.h"
+#include "utils/ms_context.h"
 namespace mindspore {
 namespace device {
-DynamicMemPoolBestFit::~DynamicMemPoolBestFit() {
-  global_mem_block_list_.clear();
-  global_idle_mem_buf_map_.clear();
+static const char kPynativeParamMem[] = "Pynative unique mem";
+static const char kCommonMem[] = "Common mem";
+const size_t kGBToByte = 1024 << 20;
+static bool IsPynativeMode() {
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  return ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
 }
-DeviceMemPtr DynamicMemPoolBestFit::AllocTensorMem(size_t size) {
+DynamicMemPoolBestFit::~DynamicMemPoolBestFit() {
+  persistent_mem_->clear();
+  common_mem_->clear();
+}
+DeviceMemPtr DynamicMemPoolBestFit::AllocTensorMem(size_t size, bool from_persistent_mem) {
   size_t align_size = AlignMemorySize(size);
   std::lock_guard<std::mutex> locker(mutex_);
   // Find the idle memory buf by tensor size, if not find, then add new memory block and memory buf.
-  DeviceMemPtr device_addr = FindIdleMemBuf(align_size);
+  DeviceMemPtr device_addr = FindIdleMemBuf(align_size, from_persistent_mem);
   if (!device_addr) {
-    device_addr = AddMemBlockAndMemBuf(align_size);
+    device_addr = AddMemBlockAndMemBuf(align_size, from_persistent_mem);
   }
   return device_addr;
 }
 std::vector<DeviceMemPtr> DynamicMemPoolBestFit::AllocContinuousTensorMem(size_t total_size,
-                                                                          std::vector<size_t> size_list) {
+                                                                          const std::vector<size_t> &size_list) {
   std::vector<DeviceMemPtr> device_addr_list;
   // Pre-alloc the one whole piece memory.
-  auto device_addr = AllocTensorMem(total_size);
+  auto device_addr = AllocTensorMem(total_size, false);
   if (!device_addr) {
     return device_addr_list;
   }
   std::lock_guard<std::mutex> locker(mutex_);
   // Remove the pre-alloc memory.
-  const auto &mem_block = FindMemBlock(device_addr);
+  const auto &mem_block = FindMemBlock(device_addr, common_mem_);
   MS_EXCEPTION_IF_NULL(mem_block);
   const auto &iter = mem_block->block_all_mem_buf_map_.find(device_addr);
   if (iter == mem_block->block_all_mem_buf_map_.end()) {
@@ -63,11 +75,11 @@ std::vector<DeviceMemPtr> DynamicMemPoolBestFit::AllocContinuousTensorMem(size_t
   // Split the pre-alloc memory into continuous memory by the size list.
   DynamicMemBufPtr continuous_mem_buf;
   auto buf_addr = device_addr;
-  for (size_t i = 0; i < size_list.size(); i++) {
-    continuous_mem_buf = std::make_shared<DynamicMemBuf>(buf_addr, kMemBufUsed, size_list[i]);
+  for (size_t i : size_list) {
+    continuous_mem_buf = std::make_shared<DynamicMemBuf>(buf_addr, kMemBufUsed, i);
     (void)mem_block->block_all_mem_buf_map_.emplace(buf_addr, continuous_mem_buf);
     device_addr_list.emplace_back(buf_addr);
-    buf_addr = AddressOffset(buf_addr, size_list[i]);
+    buf_addr = AddressOffset(buf_addr, i);
   }
   // Update the size of the last memory buf.
   continuous_mem_buf->size_ += rest_size;
@@ -81,9 +93,14 @@ size_t DynamicMemPoolBestFit::AlignMemorySize(size_t size) const {
   return ((size + DYNAMIC_MEM_ALIGN_SIZE - 1) / DYNAMIC_MEM_ALIGN_SIZE) * DYNAMIC_MEM_ALIGN_SIZE;
 }
-DeviceMemPtr DynamicMemPoolBestFit::FindIdleMemBuf(size_t size) {
-  const auto &iter = global_idle_mem_buf_map_.lower_bound(size);
-  if (iter != global_idle_mem_buf_map_.end()) {
+DeviceMemPtr DynamicMemPoolBestFit::FindIdleMemBuf(size_t size, bool from_persistent_mem) {
+  auto mem_mng = common_mem_;
+  if (from_persistent_mem) {
+    mem_mng = persistent_mem_;
+  }
+  MS_EXCEPTION_IF_NULL(mem_mng);
+  const auto &iter = mem_mng->idle_mem_buf_map_.lower_bound(size);
+  if (iter != mem_mng->idle_mem_buf_map_.end()) {
     auto mem_buf = iter->second;
     MS_EXCEPTION_IF_NULL(mem_buf);
     if (mem_buf->status_ != kMemBufIdle) {
@@ -92,24 +109,69 @@ DeviceMemPtr DynamicMemPoolBestFit::FindIdleMemBuf(size_t size) {
     }
     mem_buf->status_ = kMemBufUsed;
     // Remove map of old idle memory buf
-    (void)global_idle_mem_buf_map_.erase(iter);
+    (void)mem_mng->idle_mem_buf_map_.erase(iter);
     // Divide memory buf
-    if (IsDivide(size, mem_buf->size_)) {
-      DivideMemBuf(size, mem_buf);
+    if (IsSplit(size, mem_buf->size_)) {
+      SplitMemBuf(size, mem_buf, mem_mng);
     }
     // Memory statistics
-    total_used_mem_statistics_ += mem_buf->size_;
-    if (total_used_mem_statistics_ > used_mem_peak_statistics_) {
-      used_mem_peak_statistics_ = total_used_mem_statistics_;
+    mem_mng->mps_.total_used_mem_size_ += mem_buf->size_;
+    if (mem_mng->mps_.total_used_mem_size_ > mem_mng->mps_.used_mem_peak_size_) {
+      mem_mng->mps_.used_mem_peak_size_ = mem_mng->mps_.total_used_mem_size_;
     }
     return mem_buf->device_addr_;
   }
   return nullptr;
 }
-DeviceMemPtr DynamicMemPoolBestFit::AddMemBlockAndMemBuf(size_t size) {
-  size_t alloc_mem_size = CalMemBlockAllocSize(size);
+size_t DynamicMemPoolBestFit::MemAllocUnitSize(bool from_persistent_mem) const {
+  return from_persistent_mem ? persistent_mem_->unit_size_ : common_mem_->unit_size_;
+}
+void DynamicMemPoolBestFit::SetMemAllocUintSize(size_t size) {
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  persistent_mem_->unit_size_ = DYNAMIC_MEM_ALLOC_UNIT_SIZE;
+  common_mem_->unit_size_ = size - persistent_mem_->unit_size_;
+  MS_LOG(INFO) << "Set mem alloc unit size " << size;
+}
+void DynamicMemPoolBestFit::SetMempoolBlockSize(size_t device_mem_size) {
+  if (!IsPynativeMode()) {
+    return;
+  }
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  float mem_block_size = ms_context->get_param<float>(MS_CTX_MEMPOOL_BLOCK_SIZE);
+  if (mem_block_size == kDefaultMempoolBlockSize) {
+    return;
+  }
+  size_t config_size = FloatToSize(mem_block_size * kGBToByte);
+  size_t real_block_size = std::min(config_size, device_mem_size);
+  if (config_size > device_mem_size) {
+    MS_LOG(WARNING) << "Memory pool block size " << config_size << " is bigger than currently available maximum memory "
+                    << device_mem_size << ", and the actual effective value will be " << device_mem_size;
+  }
+  SetMemAllocUintSize(real_block_size);
+}
+DeviceMemPtr DynamicMemPoolBestFit::AddMemBlockAndMemBuf(size_t size, bool from_persistent_mem) {
+  // Pynative unique mem is not enough, find from common
+  if (from_persistent_mem && !persistent_mem_->mem_block_list_.empty()) {
+    auto mem_addr = FindIdleMemBuf(size, false);
+    if (mem_addr != nullptr) {
+      return mem_addr;
+    }
+    from_persistent_mem = false;
+  }
+  size_t alloc_mem_size = CalMemBlockAllocSize(size, from_persistent_mem);
   if (alloc_mem_size == 0) {
+    MS_LOG(DEBUG) << "Try to find in other mem";
+    auto mem_addr = FindIdleMemBuf(size, !from_persistent_mem);
+    if (mem_addr != nullptr) {
+      return mem_addr;
+    }
+    DumpDynamicMemPoolInfo();
     return nullptr;
   }
   // Add new memory block
@@ -118,40 +180,50 @@ DeviceMemPtr DynamicMemPoolBestFit::AddMemBlockAndMemBuf(size_t size) {
   if (real_alloc_size < size) {
     MS_LOG(WARNING) << "Memory not enough: alloc size[" << real_alloc_size << "] is smaller than required size[" << size
                     << "].";
+    DumpDynamicMemPoolInfo();
     return nullptr;
   }
-  mem_alloc_unit_size_ = DYNAMIC_MEM_ALLOC_UNIT_SIZE;
+  // In graph mode, unit_size is set once using an estimated memory size, and subsequent memory requests use the
+  // default size
+  if (!IsPynativeMode()) {
+    common_mem_->unit_size_ = DYNAMIC_MEM_ALLOC_UNIT_SIZE;
+  }
+  auto mem_mng = common_mem_;
+  if (from_persistent_mem) {
+    mem_mng = persistent_mem_;
+  }
+  MS_EXCEPTION_IF_NULL(mem_mng);
   auto mem_block = std::make_shared<DynamicMemBlock>(device_addr, real_alloc_size);
   MS_EXCEPTION_IF_NULL(mem_block);
   const auto &iter =
-    std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock);
-  (void)global_mem_block_list_.insert(iter, mem_block);
+    std::upper_bound(mem_mng->mem_block_list_.begin(), mem_mng->mem_block_list_.end(), device_addr, CmpMemBlock);
+  (void)mem_mng->mem_block_list_.insert(iter, mem_block);
   // Add new memory buf
   auto mem_buf = std::make_shared<DynamicMemBuf>(device_addr, kMemBufUsed, real_alloc_size);
   MS_EXCEPTION_IF_NULL(mem_buf);
   // Add map of new memory buf in the block
   (void)mem_block->block_all_mem_buf_map_.emplace(device_addr, mem_buf);
-  // Divide memory buf
-  if (IsDivide(size, mem_buf->size_)) {
-    DivideMemBuf(size, mem_buf);
+  // Split memory buf
+  if (IsSplit(size, mem_buf->size_)) {
+    SplitMemBuf(size, mem_buf, mem_mng);
   }
   // Memory statistics
-  total_mem_statistics_ += real_alloc_size;
-  total_used_mem_statistics_ += mem_buf->size_;
-  if (total_used_mem_statistics_ > used_mem_peak_statistics_) {
-    used_mem_peak_statistics_ = total_used_mem_statistics_;
+  mem_mng->mps_.total_mem_size_ += real_alloc_size;
+  mem_mng->mps_.total_used_mem_size_ += mem_buf->size_;
+  if (mem_mng->mps_.total_used_mem_size_ > mem_mng->mps_.used_mem_peak_size_) {
+    mem_mng->mps_.used_mem_peak_size_ = mem_mng->mps_.total_used_mem_size_;
   }
   return mem_buf->device_addr_;
 }
-size_t DynamicMemPoolBestFit::CalMemBlockAllocSize(size_t size) {
+size_t DynamicMemPoolBestFit::CalMemBlockAllocSize(size_t size, bool from_persistent_mem) {
   auto device_free_mem_size = free_mem_size();
   if (device_free_mem_size < size) {
     MS_LOG(WARNING) << "Memory not enough: current free memory size[" << device_free_mem_size
                     << "] is smaller than required size[" << size << "].";
     return 0;
   }
-  auto alloc_mem_size = mem_alloc_unit_size();
+  auto alloc_mem_size = MemAllocUnitSize(from_persistent_mem);
   // Growing at twice of alloc size
   constexpr size_t kDouble = 2;
   while (alloc_mem_size < size) {
@@ -161,13 +233,14 @@ size_t DynamicMemPoolBestFit::CalMemBlockAllocSize(size_t size) {
   return alloc_mem_size;
 }
-bool DynamicMemPoolBestFit::IsDivide(size_t tensor_size, size_t mem_buf_size) const {
+bool DynamicMemPoolBestFit::IsSplit(size_t tensor_size, size_t mem_buf_size) const {
   return mem_buf_size - tensor_size >= DYNAMIC_MEM_ALIGN_SIZE;
 }
-void DynamicMemPoolBestFit::DivideMemBuf(size_t size, const DynamicMemBufPtr &mem_buf) {
+void DynamicMemPoolBestFit::SplitMemBuf(size_t size, const DynamicMemBufPtr &mem_buf,
+                                        const MemStatusManagerPtr &mem_mng) {
   MS_EXCEPTION_IF_NULL(mem_buf);
-  const auto &mem_block = FindMemBlock(mem_buf->device_addr_);
+  const auto &mem_block = FindMemBlock(mem_buf->device_addr_, mem_mng);
   MS_EXCEPTION_IF_NULL(mem_block);
   // Divide new memory buf
   if (mem_buf->size_ < size) {
@@ -180,7 +253,7 @@ void DynamicMemPoolBestFit::DivideMemBuf(size_t size, const DynamicMemBufPtr &me
   // Add map of new memory buf in the block
   (void)mem_block->block_all_mem_buf_map_.emplace(newbuf_addr, new_mem_buf);
   // Add map of new idle memory buf
-  (void)global_idle_mem_buf_map_.emplace(newbuf_size, new_mem_buf);
+  (void)mem_mng->idle_mem_buf_map_.emplace(newbuf_size, new_mem_buf);
 }
 bool DynamicMemPoolBestFit::CmpMemBlock(const DeviceMemPtr &device_addr, const DynamicMemBlockPtr &mem_block) {
@@ -189,11 +262,12 @@ bool DynamicMemPoolBestFit::CmpMemBlock(const DeviceMemPtr &device_addr, const D
   return device_addr < mem_block->device_addr();
 }
-DynamicMemBlockPtr DynamicMemPoolBestFit::FindMemBlock(const DeviceMemPtr &device_addr) {
+DynamicMemBlockPtr DynamicMemPoolBestFit::FindMemBlock(const DeviceMemPtr &device_addr,
+                                                       const MemStatusManagerPtr &mem_mng) {
   MS_EXCEPTION_IF_NULL(device_addr);
   auto &&iter =
-    std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock);
-  if (iter != global_mem_block_list_.begin()) {
+    std::upper_bound(mem_mng->mem_block_list_.begin(), mem_mng->mem_block_list_.end(), device_addr, CmpMemBlock);
+  if (iter != mem_mng->mem_block_list_.begin()) {
     return *(--iter);
   }
   return nullptr;
@@ -202,16 +276,32 @@ DynamicMemBlockPtr DynamicMemPoolBestFit::FindMemBlock(const DeviceMemPtr &devic
 void DynamicMemPoolBestFit::FreeTensorMem(const DeviceMemPtr &device_addr) {
   MS_EXCEPTION_IF_NULL(device_addr);
   std::lock_guard<std::mutex> locker(mutex_);
-  const auto &mem_block = FindMemBlock(device_addr);
+  auto fn = [this](const MemStatusManagerPtr &mem_mng, const DeviceMemPtr &device_addr) -> DynamicMemBlockPtr {
+    auto mem_block = FindMemBlock(device_addr, mem_mng);
+    if (mem_block != nullptr) {
+      const auto &iter = mem_block->block_all_mem_buf_map_.find(device_addr);
+      if (iter != mem_block->block_all_mem_buf_map_.end()) {
+        return mem_block;
+      }
+    }
+    return nullptr;
+  };
+  auto mem_block = fn(common_mem_, device_addr);
   if (mem_block == nullptr) {
-    // May be destroy the memory pool first, then destroy the address, so this is normal case.
-    MS_LOG(DEBUG) << "Can't find the mem_block of the device address[" << device_addr << "].";
-    return;
+    mem_block = fn(persistent_mem_, device_addr);
+    if (mem_block == nullptr) {
+      // May be destroy the memory pool first, then destroy the address, so this is normal case.
+      MS_LOG(DEBUG) << "Can't find the mem_block of the device address[" << device_addr << "].";
+      return;
+    }
+    CombineMemBuf(mem_block, device_addr, persistent_mem_);
+  } else {
+    CombineMemBuf(mem_block, device_addr, common_mem_);
   }
-  CombineMemBuf(mem_block, device_addr);
 }
-void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr &device_addr) {
+void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr &device_addr,
+                                          const MemStatusManagerPtr &mem_mng) {
   MS_EXCEPTION_IF_NULL(mem_block);
   MS_EXCEPTION_IF_NULL(device_addr);
   const auto &iter = mem_block->block_all_mem_buf_map_.find(device_addr);
@@ -224,10 +314,10 @@ void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, c
     MS_LOG(EXCEPTION) << "Find the mem_buf is not used, mem_buf_address[" << mem_buf->device_addr_ << "].";
   }
   mem_buf->status_ = kMemBufIdle;
-  if (total_used_mem_statistics_ < mem_buf->size_) {
+  if (mem_mng->mps_.total_used_mem_size_ < mem_buf->size_) {
     MS_LOG(EXCEPTION) << "The total used mem size is less than the size of membuf.";
   }
-  total_used_mem_statistics_ -= mem_buf->size_;
+  mem_mng->mps_.total_used_mem_size_ -= mem_buf->size_;
   // Combine backward(combine the next_mem_buf to mem_buf)
   auto next_iter = iter;
   (void)next_iter++;
@@ -236,7 +326,7 @@ void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, c
     MS_EXCEPTION_IF_NULL(next_mem_buf);
     if (next_mem_buf->status_ == kMemBufIdle) {
       mem_buf->size_ += next_mem_buf->size_;
-      EraseIdleMemBuf(next_mem_buf->size_, next_mem_buf->device_addr_);
+      EraseIdleMemBuf(next_mem_buf->size_, next_mem_buf->device_addr_, mem_mng);
       (void)mem_block->block_all_mem_buf_map_.erase(next_iter);
     }
   }
@@ -249,7 +339,7 @@ void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, c
     prev_mem_buf = prev_iter->second;
     MS_EXCEPTION_IF_NULL(prev_mem_buf);
     if (prev_mem_buf->status_ == kMemBufIdle) {
-      EraseIdleMemBuf(prev_mem_buf->size_, prev_mem_buf->device_addr_);
+      EraseIdleMemBuf(prev_mem_buf->size_, prev_mem_buf->device_addr_, mem_mng);
       prev_mem_buf->size_ += mem_buf->size_;
       (void)mem_block->block_all_mem_buf_map_.erase(iter);
       forward_combine = true;
@@ -257,20 +347,21 @@ void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, c
   }
   // Add map of new idle memory
   if (forward_combine) {
-    (void)global_idle_mem_buf_map_.emplace(prev_mem_buf->size_, prev_mem_buf);
+    (void)mem_mng->idle_mem_buf_map_.emplace(prev_mem_buf->size_, prev_mem_buf);
   } else {
-    (void)global_idle_mem_buf_map_.emplace(mem_buf->size_, mem_buf);
+    (void)mem_mng->idle_mem_buf_map_.emplace(mem_buf->size_, mem_buf);
   }
 }
-void DynamicMemPoolBestFit::EraseIdleMemBuf(size_t size, const DeviceMemPtr &device_addr) {
+void DynamicMemPoolBestFit::EraseIdleMemBuf(size_t size, const DeviceMemPtr &device_addr,
+                                            const MemStatusManagerPtr &mem_mng) {
   MS_EXCEPTION_IF_NULL(device_addr);
-  auto &&iter = global_idle_mem_buf_map_.equal_range(size);
+  auto &&iter = mem_mng->idle_mem_buf_map_.equal_range(size);
   while (iter.first != iter.second) {
     MS_EXCEPTION_IF_NULL(iter.first->second);
     // Remove map of the idle memory buf by size and device address
     if (iter.first->second->device_addr_ == device_addr) {
-      (void)global_idle_mem_buf_map_.erase(iter.first);
+      (void)mem_mng->idle_mem_buf_map_.erase(iter.first);
       return;
     }
     (void)iter.first++;
@@ -280,70 +371,48 @@ void DynamicMemPoolBestFit::EraseIdleMemBuf(size_t size, const DeviceMemPtr &dev
 void DynamicMemPoolBestFit::ReleaseDeviceRes() {
   std::lock_guard<std::mutex> locker(mutex_);
-  MS_LOG(INFO) << "The dynamic memory pool total size is " << total_mem_statistics_ << ", total used size is "
-               << total_used_mem_statistics_ << ", used peak size is " << used_mem_peak_statistics_ << ".";
-  for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) {
-    auto &device_addr = (*iter)->device_addr_base_;
-    if (device_addr != nullptr) {
-      if (!FreeDeviceMem(device_addr)) {
-        MS_LOG(EXCEPTION) << "Free device memory[" << device_addr << "] error.";
-      }
-      device_addr = nullptr;
-    }
-  }
-  global_mem_block_list_.clear();
-  global_idle_mem_buf_map_.clear();
+  auto fn = [this](const MemStatusManagerPtr &mem_mng) {
+    for (auto &iter : mem_mng->mem_block_list_) {
+      auto &device_addr = iter->device_addr_base_;
+      if (device_addr != nullptr) {
+        if (!FreeDeviceMem(device_addr)) {
+          MS_LOG(EXCEPTION) << "Free device memory[" << device_addr << "] error.";
+        }
+        device_addr = nullptr;
+      }
+    }
+    mem_mng->mem_block_list_.clear();
+    mem_mng->idle_mem_buf_map_.clear();
+  };
+  fn(common_mem_);
+  fn(persistent_mem_);
 }
 void DynamicMemPoolBestFit::DumpDynamicMemPoolInfo() {
-  std::lock_guard<std::mutex> locker(mutex_);
-  MS_LOG(INFO) << "Start dump dynamic memory pool info.";
-  DeviceAddrMapMemBuf mem_block_map;
-  DynamicMemBufPtr mem_buf;
-  size_t total_mem = 0;
-  size_t total_used_mem = 0;
-  size_t total_idle_mem1 = 0;
-  size_t total_idle_mem2 = 0;
-  // Dump the memory block info and memory buf info
-  MS_LOG(INFO) << "Dump all mem_block info: counts[" << global_mem_block_list_.size() << "].";
-  for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) {
-    total_mem += (*iter)->size();
-    mem_block_map = (*iter)->block_all_mem_buf_map_;
-    MS_LOG(INFO) << "MemBlock info: number[" << iter - global_mem_block_list_.begin() << "] mem_buf_counts["
-                 << mem_block_map.size() << "] base_address[" << (*iter)->device_addr() << "] block_size["
-                 << (*iter)->size() << "].";
-    for (auto iter_mem_buf = mem_block_map.begin(); iter_mem_buf != mem_block_map.end(); ++iter_mem_buf) {
-      mem_buf = iter_mem_buf->second;
-      MS_EXCEPTION_IF_NULL(mem_buf);
-      if (mem_buf->status_ == kMemBufIdle) {
-        total_idle_mem1 += mem_buf->size_;
-      } else {
-        total_used_mem += mem_buf->size_;
-      }
-      MS_LOG(INFO) << "MemBuf info: address[" << mem_buf->device_addr_ << "] size[" << mem_buf->size_ << "] status["
-                   << mem_buf->status_ << "].";
-    }
-  }
-  // Dump all the idle memory buf info
-  MS_LOG(INFO) << "Dump all idle mem_buf info: counts[" << global_idle_mem_buf_map_.size() << "].";
-  for (auto iter_idle = global_idle_mem_buf_map_.begin(); iter_idle != global_idle_mem_buf_map_.end(); ++iter_idle) {
-    mem_buf = iter_idle->second;
-    MS_EXCEPTION_IF_NULL(mem_buf);
-    total_idle_mem2 += mem_buf->size_;
-    MS_LOG(INFO) << "Idle mem_buf info: size[" << mem_buf->size_ << "] address[" << mem_buf->device_addr_ << "] status["
-                 << mem_buf->status_ << "].";
-  }
-  // Dump the memory statistical info
-  MS_LOG(INFO) << "Total allocated memory[" << total_mem << "], used memory[" << total_used_mem << "], idle memory["
-               << total_idle_mem1 << "].";
-  if (total_idle_mem1 != total_idle_mem2) {
-    MS_LOG(ERROR) << "Check error: the idle memory in the mem_block is not equal the global idle memory.";
-  }
-  if (total_mem != total_used_mem + total_idle_mem1) {
-    MS_LOG(ERROR) << "Check error: the the total memory is not equal the sum of used memory and idle memory.";
-  }
-  MS_LOG(INFO) << "Finish dump dynamic memory pool info.";
+  auto fn = [](const MemStatusManagerPtr &mem_mng, const std::string &mem_type) {
+    if (mem_mng->mem_block_list_.empty()) {
+      return;
+    }
+    std::ostringstream buf;
+    for (size_t i = 0; i < mem_mng->mem_block_list_.size(); ++i) {
+      size_t idle_size = 0;
+      for (auto mb = mem_mng->mem_block_list_[i]->block_all_mem_buf_map_.begin();
+           mb != mem_mng->mem_block_list_[i]->block_all_mem_buf_map_.end(); ++mb) {
+        if (mb->second->status_ == kMemBufIdle) {
+          idle_size += mb->second->size_;
+        }
+      }
+      buf << ", block[" << i << "] idle size " << idle_size;
+    }
+    // Dump all the memory buf info
+    MS_LOG(WARNING) << mem_type << "pool info: block size " << mem_mng->unit_size_ << ", block counts "
+                    << mem_mng->mem_block_list_.size() << buf.str() << ". Total allocated mem "
+                    << mem_mng->mps_.total_mem_size_ << ", peak used mem " << mem_mng->mps_.used_mem_peak_size_
+                    << ", in used mem " << mem_mng->mps_.total_used_mem_size_ << ", total idle mem "
+                    << mem_mng->mps_.total_mem_size_ - mem_mng->mps_.total_used_mem_size_;
+  };
+  fn(common_mem_, std::string(kCommonMem));
+  fn(persistent_mem_, std::string(kPynativeParamMem));
 }
 } // namespace device
 } // namespace mindspore
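The rewrite above replaces the single global block list and idle-buffer map with two MemStatusManager areas, a persistent one labelled "Pynative unique mem" and a common one, each holding its own best-fit map of idle buffers keyed by size, with a fallback to the other area when the first cannot serve a request. The stand-alone sketch below illustrates only that lookup-and-fallback pattern; the types and function names are simplified stand-ins rather than the MindSpore classes, and block growth and buffer splitting are deliberately left out.

// Simplified sketch of a two-area best-fit lookup (not the MindSpore implementation).
// Each area keeps idle buffers in a std::multimap keyed by size, so lower_bound()
// returns the smallest idle buffer that still fits the request.
#include <cstddef>
#include <iostream>
#include <map>

struct IdleBuf {
  void *addr;
  size_t size;
};

struct MemArea {                        // stands in for MemStatusManager
  std::multimap<size_t, IdleBuf> idle;  // analogue of idle_mem_buf_map_
};

// Best-fit: take the smallest idle buffer >= size from the requested area.
void *FindIdle(MemArea *area, size_t size) {
  auto it = area->idle.lower_bound(size);
  if (it == area->idle.end()) {
    return nullptr;
  }
  void *addr = it->second.addr;
  area->idle.erase(it);  // the buffer is now considered in use
  return addr;
}

// Mirrors the policy of AllocTensorMem/AddMemBlockAndMemBuf: try the requested
// area first, then fall back to the other one.
void *Alloc(MemArea *persistent, MemArea *common, size_t size, bool from_persistent) {
  MemArea *first = from_persistent ? persistent : common;
  MemArea *second = from_persistent ? common : persistent;
  if (void *addr = FindIdle(first, size)) {
    return addr;
  }
  return FindIdle(second, size);  // the real pool would also try to add a new block here
}

int main() {
  MemArea persistent, common;
  static char pool[4096];
  persistent.idle.insert({1024, {pool, 1024}});
  common.idle.insert({2048, {pool + 1024, 2048}});

  // A 512-byte "parameter" request fits the persistent area; a 1500-byte request
  // does not, so it falls through to the common area.
  void *a = Alloc(&persistent, &common, 512, true);
  void *b = Alloc(&persistent, &common, 1500, true);
  std::cout << "offset a: " << static_cast<char *>(a) - pool << "\n";  // prints 0
  std::cout << "offset b: " << static_cast<char *>(b) - pool << "\n";  // prints 1024
  return 0;
}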

View File

@@ -36,7 +36,7 @@ enum DynamicMemBufStatus : int { kMemBufIdle, kMemBufUsed };
 static const size_t DYNAMIC_MEM_ALIGN_SIZE = 512;
 // The minimum unit size (1G) of memory block used for dynamic extend.
-static const size_t DYNAMIC_MEM_ALLOC_UNIT_SIZE = 1024 << 20;
+static const size_t DYNAMIC_MEM_ALLOC_UNIT_SIZE = 1073741824;
 // The Comparator of device address from small to large.
 struct DeviceAddrCmp {
@@ -77,16 +77,40 @@ class DynamicMemBlock {
 };
 using DynamicMemBlockPtr = std::shared_ptr<DynamicMemBlock>;
+struct DeviceState {
+  // Memory allocated from device
+  size_t total_mem_size_{0};
+  // Memory in use
+  size_t total_used_mem_size_{0};
+  // Maximum peak memory usage
+  size_t used_mem_peak_size_{0};
+};
+struct MemStatusManager {
+  size_t unit_size_{DYNAMIC_MEM_ALLOC_UNIT_SIZE};
+  // Mempool state
+  DeviceState mps_;
+  std::vector<DynamicMemBlockPtr> mem_block_list_;
+  // The map of all idle memory buf by size.
+  SizeMapMemBuf idle_mem_buf_map_;
+  void clear() {
+    mem_block_list_.clear();
+    idle_mem_buf_map_.clear();
+  }
+};
+using MemStatusManagerPtr = std::shared_ptr<MemStatusManager>;
 // The main class of dynamic memory pool.
 class DynamicMemPoolBestFit {
  public:
-  DynamicMemPoolBestFit() = default;
+  DynamicMemPoolBestFit()
+      : persistent_mem_(std::make_shared<MemStatusManager>()), common_mem_(std::make_shared<MemStatusManager>()) {}
   virtual ~DynamicMemPoolBestFit();
   // The main program entry of memory alloc.
-  DeviceMemPtr AllocTensorMem(size_t size);
+  DeviceMemPtr AllocTensorMem(size_t size, bool from_persistent_mem = false);
   // The main program entry of continuous memory alloc.
-  std::vector<DeviceMemPtr> AllocContinuousTensorMem(size_t total_size, std::vector<size_t> size_list);
+  std::vector<DeviceMemPtr> AllocContinuousTensorMem(size_t total_size, const std::vector<size_t> &size_list);
   // The main program entry of memory free.
   void FreeTensorMem(const DeviceMemPtr &device_addr);
@@ -94,21 +118,22 @@ class DynamicMemPoolBestFit {
   void ReleaseDeviceRes();
   // Display the information of memory block and memory buf.
   void DumpDynamicMemPoolInfo();
-  // Get the map of global idle mem buf and size.
-  SizeMapMemBuf global_idle_mem_buf_map() {
-    std::lock_guard<std::mutex> locker(mutex_);
-    return global_idle_mem_buf_map_;
-  }
   // Get the minimum memory unit size using for dynamic extend.
-  size_t mem_alloc_unit_size() const { return mem_alloc_unit_size_; }
+  size_t MemAllocUnitSize(bool from_persistent_mem = false) const;
   // Set the minimum memory unit size using for dynamic extend.
-  void set_mem_alloc_unit_size(const size_t &size) { mem_alloc_unit_size_ = size; }
-  // Get the related memory statistics information.
-  size_t total_mem_statistics() const { return total_mem_statistics_; }
-  size_t used_mem_statistics() const { return total_used_mem_statistics_; }
-  size_t used_mem_peak_statistics() const { return used_mem_peak_statistics_; }
+  void SetMemAllocUintSize(size_t size);
+  // Set mempool init percent in pynative mode
+  void SetMempoolBlockSize(size_t device_mem_size);
+  size_t TotalMemStatistics() const {
+    return common_mem_->mps_.total_mem_size_ + persistent_mem_->mps_.total_mem_size_;
+  }
+  size_t TotalUsedMemStatistics() const {
+    return common_mem_->mps_.total_used_mem_size_ + persistent_mem_->mps_.total_used_mem_size_;
+  }
+  size_t UsedMemPeakStatistics() const {
+    return common_mem_->mps_.used_mem_peak_size_ + persistent_mem_->mps_.used_mem_peak_size_;
+  }
   // The related interface of device memory real operation, needs override by device type.
   virtual size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) = 0;
@@ -116,43 +141,35 @@ class DynamicMemPoolBestFit {
   virtual size_t free_mem_size() = 0;
  protected:
+  MemStatusManagerPtr &common_mem() { return common_mem_; }
+  MemStatusManagerPtr &persistent_mem() { return persistent_mem_; }
   // The real size by memory alloc aligned.
   virtual size_t AlignMemorySize(size_t size) const;
   // Calculate memory block required alloc size when adding the memory block.
-  virtual size_t CalMemBlockAllocSize(size_t size);
+  virtual size_t CalMemBlockAllocSize(size_t size, bool from_persistent_mem);
  private:
   // Find the idle memory buf by aligned size when memory alloc.
-  DeviceMemPtr FindIdleMemBuf(size_t size);
+  DeviceMemPtr FindIdleMemBuf(size_t size, bool from_persistent_mem);
   // Add the memory block and memory buf when memory alloc not find the idle memory buf.
-  DeviceMemPtr AddMemBlockAndMemBuf(size_t size);
-  // Judge whether need divide the memory buf by alloc size and memory buf size.
-  bool IsDivide(size_t tensor_size, size_t mem_buf_size) const;
-  // Divide the memory buf by alloc size.
-  void DivideMemBuf(size_t size, const DynamicMemBufPtr &mem_buf);
+  DeviceMemPtr AddMemBlockAndMemBuf(size_t size, bool from_persistent_mem);
+  // Judge whether need split the memory buf by alloc size and memory buf size.
+  bool IsSplit(size_t tensor_size, size_t mem_buf_size) const;
+  // Split the memory buf by alloc size.
+  void SplitMemBuf(size_t size, const DynamicMemBufPtr &mem_buf, const MemStatusManagerPtr &mem_mng);
   // Find the memory block by device address.
-  DynamicMemBlockPtr FindMemBlock(const DeviceMemPtr &device_addr);
+  DynamicMemBlockPtr FindMemBlock(const DeviceMemPtr &device_addr, const MemStatusManagerPtr &mem_mgr);
   // The Comparator of memory block by device address, because memory blocks are arranged in order by device address.
   static bool CmpMemBlock(const DeviceMemPtr &device_addr, const DynamicMemBlockPtr &mem_block);
   // Combine the memory buf when memory free, to avoid the memory fragmentation.
-  void CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr &device_addr);
+  void CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr &device_addr,
+                     const MemStatusManagerPtr &mem_mng);
   // Erase the idle memory buf by size and device address when idle memory buf is combined.
-  void EraseIdleMemBuf(size_t size, const DeviceMemPtr &device_addr);
-  // The global memory block list which is arranged in order by base device address of memory block.
-  std::vector<DynamicMemBlockPtr> global_mem_block_list_;
-  // The map of all idle memory buf by size.
-  SizeMapMemBuf global_idle_mem_buf_map_;
-  // The related memory statistics information.
-  size_t total_mem_statistics_{0};
-  size_t total_used_mem_statistics_{0};
-  size_t used_mem_peak_statistics_{0};
-  // The minimum memory unit size.
-  size_t mem_alloc_unit_size_{DYNAMIC_MEM_ALLOC_UNIT_SIZE};
+  void EraseIdleMemBuf(size_t size, const DeviceMemPtr &device_addr, const MemStatusManagerPtr &mem_mng);
+  MemStatusManagerPtr persistent_mem_{nullptr};
+  MemStatusManagerPtr common_mem_{nullptr};
   // Support multi-thread.
   std::mutex mutex_;
 };
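MemStatusManager stores a per-area unit_size_, and the .cc changes above split a configured mempool block size between the two areas: the persistent area keeps the 1 GiB DYNAMIC_MEM_ALLOC_UNIT_SIZE default and the common area gets the remainder, with the configured value capped by the device memory. The short sketch below reproduces just that arithmetic under those assumptions; everything apart from the two constants is an illustrative name, not the real interface.

// Arithmetic sketch of carving a configured mempool block size into unit sizes
// for the persistent and common areas (illustrative names, not the real API).
#include <algorithm>
#include <cstddef>
#include <iostream>

constexpr size_t kGBToByte = 1024 << 20;                 // 1 GiB, as in the .cc file
constexpr size_t kDynamicMemAllocUnitSize = 1073741824;  // DYNAMIC_MEM_ALLOC_UNIT_SIZE

struct UnitSizes {
  size_t persistent;
  size_t common;
};

// mem_block_size_gb comes from context.set_context(mempool_block_size="xxGB");
// device_mem_size caps the configured value, mirroring SetMempoolBlockSize().
// Assumes the capped block size is at least 1 GiB, the documented minimum.
UnitSizes CarveBlock(float mem_block_size_gb, size_t device_mem_size) {
  size_t config_size = static_cast<size_t>(mem_block_size_gb * kGBToByte);
  size_t real_block_size = std::min(config_size, device_mem_size);
  UnitSizes units;
  units.persistent = kDynamicMemAllocUnitSize;        // persistent area keeps the 1 GiB default
  units.common = real_block_size - units.persistent;  // common area gets the remainder
  return units;
}

int main() {
  // 4 GB configured, 16 GiB free on the device: 1 GiB persistent + 3 GiB common.
  UnitSizes u = CarveBlock(4.0f, size_t(16) * kGBToByte);
  std::cout << "persistent unit: " << u.persistent << " bytes\n"
            << "common unit:     " << u.common << " bytes\n";
  return 0;
}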

View File

@@ -864,7 +864,7 @@ void AscendSession::RunOpImplOrigin(const GraphInfo &graph_info, OpRunInfo *op_r
   // malloc mem
   RunOpRemoveNopNode(graph);
-  RunOpMemoryAlloc(*input_tensors, graph.get());
+  RunOpMemoryAlloc(*input_tensors, graph.get(), op_run_info->is_gradient_out);
   RunOpGenKernelEvent(graph.get());
   AnfAlgo::CacheAddrForGraph(graph);
   // Build dynamic kernel
@@ -998,7 +998,7 @@ void AscendSession::BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfN
     GetOpInputStubTensors(kernel, parameter_index, graph_inputs, op_output_info, &input_tensor_info);
     // Get OpRunInfo and GraphInfo
     const GraphInfo &graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors);
-    OpRunInfo op_run_info = GetSingleOpRunInfo(kernel, graph_info, input_tensor_info);
+    OpRunInfo op_run_info = GetSingleOpRunInfo(kernel, graph_info, input_tensor_info, nullptr);
     if (op_run_info.is_dynamic_shape) {
       MS_LOG(INFO) << "BuildOpsInGraph stop, op " << op_run_info.op_name << " is dynamic shape.";
       break;
@@ -1318,19 +1318,19 @@ void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const {
   MS_LOG(INFO) << "Status record: end memory alloc. graph id: " << kernel_graph->graph_id()
               << ", Memory Statistics:" << device::ascend::AscendMemAdapter::GetInstance().DevMemStatistics();
   MS_LOG(INFO) << "The dynamic memory pool total size is "
-              << device::ascend::AscendMemoryPool::GetInstance().total_mem_statistics() / kMBToByte
+              << device::ascend::AscendMemoryPool::GetInstance().TotalMemStatistics() / kMBToByte
               << "M, total used size is "
-              << device::ascend::AscendMemoryPool::GetInstance().used_mem_statistics() / kMBToByte
+              << device::ascend::AscendMemoryPool::GetInstance().TotalUsedMemStatistics() / kMBToByte
               << "M, used peak size is "
-              << device::ascend::AscendMemoryPool::GetInstance().used_mem_peak_statistics() / kMBToByte << "M.";
+              << device::ascend::AscendMemoryPool::GetInstance().UsedMemPeakStatistics() / kMBToByte << "M.";
 }
-void AscendSession::RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors,
-                                     KernelGraph *kernel_graph) const {
+void AscendSession::RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph,
+                                     bool is_gradient_out) const {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
-  runtime_instance->RunOpAssignMemory(input_tensors, *kernel_graph);
+  runtime_instance->RunOpAssignMemory(input_tensors, *kernel_graph, is_gradient_out);
 }
 void AscendSession::RunOpMemoryAllocNew(const std::vector<tensor::TensorPtr> &input_tensors,
@@ -1338,7 +1338,7 @@ void AscendSession::RunOpMemoryAllocNew(const std::vector<tensor::TensorPtr> &in
                                         const KernelGraph &kernel_graph) const {
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
-  runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph, tensor_to_node);
+  runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph, false, tensor_to_node);
 }
 void AscendSession::RunOpGenKernelEvent(const KernelGraph *graph) const {

View File

@@ -103,7 +103,8 @@ class AscendSession : public SessionBasic {
   static void BuildKernel(const std::vector<CNodePtr> &kernels);
   void BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   void MemoryAlloc(KernelGraph *kernel_graph) const;
-  void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
+  void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph,
+                        bool is_gradient_out) const;
   void RunOpMemoryAllocNew(const std::vector<tensor::TensorPtr> &input_tensors,
                            const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
                            const KernelGraph &kernel_graph) const;

View File

@@ -255,11 +255,11 @@ void GPUSession::AllocateMemory(const KernelGraph *kernel_graph) const {
 }
 void GPUSession::RunOpAllocateMemory(const std::vector<tensor::TensorPtr> &input_tensors,
-                                     const KernelGraph *kernel_graph) const {
+                                     const KernelGraph *kernel_graph, bool is_gradient_out) const {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
-  runtime_instance->RunOpAssignMemory(input_tensors, *kernel_graph);
+  runtime_instance->RunOpAssignMemory(input_tensors, *kernel_graph, is_gradient_out);
 }
 void GPUSession::RunOpGenKernelEvent(const KernelGraph *graph) const {
@@ -701,7 +701,7 @@ void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
   // run op
   MS_EXCEPTION_IF_NULL(kernel_graph);
   RunOpRemoveNopNode(kernel_graph);
-  RunOpAllocateMemory(*input_tensors, kernel_graph.get());
+  RunOpAllocateMemory(*input_tensors, kernel_graph.get(), op_run_info->is_gradient_out);
   RunOpGenKernelEvent(kernel_graph.get());
   // Execute the computation
   LoadInputData(kernel_graph, *input_tensors);
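RunOpImpl now forwards op_run_info->is_gradient_out through RunOpAllocateMemory into RunOpAssignMemory. A plausible reading of the allocator changes in this commit is that the flag eventually selects the persistent area via AllocTensorMem(size, from_persistent_mem); the kernel-runtime side is not part of this diff, so the call chain below is only a hedged sketch of that plumbing with invented stand-in types.

// Hedged sketch: threading a per-op "is gradient output" flag from the op-run info
// down to the pool call that picks the persistent area. All names are stand-ins.
#include <cstddef>
#include <iostream>

struct OpRunInfoLite {  // stands in for OpRunInfo
  bool is_gradient_out = false;
};

void *AllocTensorMem(size_t size, bool from_persistent_mem) {
  // The real pool would select persistent_mem_ or common_mem_; here we just report it.
  std::cout << (from_persistent_mem ? "persistent" : "common") << " alloc of " << size << " bytes\n";
  static char fake[1];
  return fake;
}

// Mirrors RunOpAssignMemory(input_tensors, graph, is_gradient_out) at the runtime layer,
// under the assumption that the flag maps onto from_persistent_mem.
void RunOpAssignMemory(size_t output_size, bool is_gradient_out) {
  AllocTensorMem(output_size, /*from_persistent_mem=*/is_gradient_out);
}

// Mirrors RunOpAllocateMemory / RunOpImpl passing op_run_info->is_gradient_out.
void RunOpAllocateMemory(const OpRunInfoLite &op_run_info, size_t output_size) {
  RunOpAssignMemory(output_size, op_run_info.is_gradient_out);
}

int main() {
  OpRunInfoLite grad_out_op;
  grad_out_op.is_gradient_out = true;
  OpRunInfoLite ordinary_op;

  RunOpAllocateMemory(grad_out_op, 4096);  // reported as a persistent-area allocation
  RunOpAllocateMemory(ordinary_op, 4096);  // reported as a common-area allocation
  return 0;
}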

View File

@@ -82,7 +82,8 @@ class GPUSession : public SessionBasic {
   void AllocateMemory(const KernelGraph *kernel_graph) const;
-  void RunOpAllocateMemory(const std::vector<tensor::TensorPtr> &input_tensors, const KernelGraph *kernel_graph) const;
+  void RunOpAllocateMemory(const std::vector<tensor::TensorPtr> &input_tensors, const KernelGraph *kernel_graph,
+                           bool is_gradient_out) const;
   void RunOpClearMemory(const KernelGraph *kernel_graph) const;

View File

@@ -54,6 +54,7 @@
 #endif
 #include "backend/session/session_factory.h"
 #include "backend/session/pynative_task_manager.h"
+#include "pipeline/pynative/pynative_execute.h"
 namespace mindspore {
 namespace session {
@@ -1220,7 +1221,8 @@ GraphInfo SessionBasic::GetSingleOpGraphInfo(const CNodePtr &kernel,
 }
 OpRunInfo SessionBasic::GetSingleOpRunInfo(const CNodePtr &cnode, const GraphInfo &graph_info,
-                                           const InputTensorInfo &tensor_info) {
+                                           const InputTensorInfo &tensor_info,
+                                           GraphOutputInfo *const graph_output_info) {
   MS_EXCEPTION_IF_NULL(cnode);
   auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
   const auto &abstract = cnode->abstract();
@@ -1230,7 +1232,14 @@ OpRunInfo SessionBasic::GetSingleOpRunInfo(const CNodePtr &cnode, const GraphInf
   const auto &shape = abstract->BuildShape();
   MS_EXCEPTION_IF_NULL(shape);
-  OpRunInfo op_run_info = {.op_name = primitive->name(),
+  bool is_gradient_out =
+    graph_output_info != nullptr &&
+    std::any_of(graph_output_info->output_indexes.begin(), graph_output_info->output_indexes.end(),
+                [cnode](const std::pair<KernelWithIndex, std::vector<std::vector<size_t>>> &output_index) {
+                  return output_index.first.first == cnode;
+                });
+  OpRunInfo op_run_info = {.is_gradient_out = is_gradient_out,
+                           .op_name = primitive->name(),
                            .primitive = primitive.get(),
                            .abstract = abstract,
                            .is_dynamic_shape = shape->IsDynamic(),
@@ -1304,16 +1313,63 @@ void SessionBasic::GetRefCount(const KernelGraph *graph, std::map<KernelWithInde
   }
 }
+void SessionBasic::GetForwardOutputRefCount(const KernelGraph *graph,
+                                            std::map<AnfNodePtr, size_t> *forward_output_refcount) {
+  if (!pynative::PynativeExecutor::GetInstance()->grad_flag()) {
+    return;
+  }
+  const auto &forward_value_nodes_id = pynative::PynativeExecutor::GetInstance()->grad_executor()->forward_outputs_id();
+  for (const auto &kernel : graph->execution_order()) {
+    for (size_t i = 1; i < kernel->inputs().size(); i += 1) {
+      const auto &input = kernel->input(i);
+      if (!input->isa<ValueNode>()) {
+        continue;
+      }
+      auto value_node = input->cast<ValueNodePtr>();
+      auto value = GetValueNode(value_node);
+      MS_EXCEPTION_IF_NULL(value);
+      if (value->isa<tensor::Tensor>()) {
+        auto tensor = value->cast<tensor::TensorPtr>();
+        MS_EXCEPTION_IF_NULL(tensor);
+        if (forward_value_nodes_id.find(tensor->id()) != forward_value_nodes_id.end()) {
+          (*forward_output_refcount)[input] += 1;
+        }
+      }
+    }
+  }
+}
 void SessionBasic::HandleOpInputs(const std::set<KernelWithIndex> &input_kernel,
                                   std::map<KernelWithIndex, size_t> *ref_count,
+                                  std::map<AnfNodePtr, size_t> *forward_output_refcount,
                                   std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) {
   MS_EXCEPTION_IF_NULL(ref_count);
+  MS_EXCEPTION_IF_NULL(forward_output_refcount);
   MS_EXCEPTION_IF_NULL(op_output_map);
   for (auto &kernel_with_index : input_kernel) {
+    MS_EXCEPTION_IF_NULL(kernel_with_index.first);
     if (!kernel_with_index.first->isa<CNode>()) {
       continue;
     }
+    // Release forward output
+    auto cnode = kernel_with_index.first->cast<CNodePtr>();
+    for (size_t i = 1; i < cnode->inputs().size(); i += 1) {
+      const auto &input = cnode->input(i);
+      auto it = forward_output_refcount->find(input);
+      if (it != forward_output_refcount->end()) {
+        if (--(it->second) == 0) {
+          auto value_node = input->cast<ValueNodePtr>();
+          auto value = GetValueNode(value_node);
+          MS_EXCEPTION_IF_NULL(value);
+          auto tensor = value->cast<tensor::TensorPtr>();
+          MS_EXCEPTION_IF_NULL(tensor);
+          tensor->set_device_address(nullptr);
+          forward_output_refcount->erase(it);
+        }
+      }
+    }
+    // Release previous output
     auto ref_iter = ref_count->find(kernel_with_index);
     if (ref_iter == ref_count->end()) {
       MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in cnode reference count map, input cnode = "
@@ -2329,7 +2385,9 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
   graph_output_info.graph_outputs = outputs;
   CreateOutputPlaceholder(kernel_graph, inputs, graph_output_info.graph_outputs, &graph_output_info.output_indexes);
   std::map<KernelWithIndex, size_t> cnode_refcount;
+  std::map<AnfNodePtr, size_t> forward_output_refcount;
   GetRefCount(kernel_graph.get(), &cnode_refcount);
+  GetForwardOutputRefCount(kernel_graph.get(), &forward_output_refcount);
   BuildOpsInGraph(graph_id, parameter_index, inputs, cnode_refcount);
   // Clear bucket resources every step
@@ -2346,14 +2404,14 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
     VectorRef op_outputs;
     // Get OpRunInfo and GraphInfo
     GraphInfo graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors);
-    OpRunInfo run_info = GetSingleOpRunInfo(kernel, graph_info, input_tensor_info);
+    OpRunInfo run_info = GetSingleOpRunInfo(kernel, graph_info, input_tensor_info, &graph_output_info);
     // Build and run current single op
     RunOpImplOrigin(graph_info, &run_info, &input_tensor_info.input_tensors, &op_outputs,
                     input_tensor_info.input_tensors_mask);
     graph_output_info.graph_output_tensors.clear();
     // Handle inputs and outputs of current op
-    HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &op_output_map);
+    HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &forward_output_refcount, &op_output_map);
     HandleOpOutputs(kernel, op_outputs, cnode_refcount, &op_output_map, &graph_output_info);
     // Save grad node to Bucket
     if (kernel_graph->is_bprop()) {
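GetForwardOutputRefCount and the new block in HandleOpInputs above count how many ops in the bprop graph consume each forward-output value node and drop that tensor's device address as soon as its last consumer has run, so cached PyNative forward outputs stop pinning device memory for the rest of the graph. The toy sketch below reproduces the same count-then-release pattern with standard-library types only; the Tensor here is an invented stand-in, not tensor::Tensor.

// Release-on-last-use sketch: count consumers of each cached forward output, then
// drop its device memory when the count reaches zero (toy types, illustration only).
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct Tensor {                         // toy stand-in for tensor::Tensor
  std::string id;
  std::shared_ptr<int> device_address;  // non-null while device memory is held
  void set_device_address(std::shared_ptr<int> addr) { device_address = std::move(addr); }
};

int main() {
  auto forward_out = std::make_shared<Tensor>();
  forward_out->id = "t0";
  forward_out->device_address = std::make_shared<int>(42);

  // Pass 1: like GetForwardOutputRefCount, count how many ops read the forward output.
  std::vector<std::vector<std::string>> op_inputs = {{"t0", "x"}, {"t0"}, {"y"}};
  std::map<std::string, size_t> forward_output_refcount;
  for (const auto &inputs : op_inputs) {
    for (const auto &in : inputs) {
      if (in == forward_out->id) {
        ++forward_output_refcount[in];
      }
    }
  }

  // Pass 2: like HandleOpInputs, decrement after each op and release on the last use.
  for (const auto &inputs : op_inputs) {
    for (const auto &in : inputs) {
      auto it = forward_output_refcount.find(in);
      if (it != forward_output_refcount.end() && --(it->second) == 0) {
        forward_out->set_device_address(nullptr);  // device memory can be reused now
        forward_output_refcount.erase(it);
        std::cout << "released forward output " << in << "\n";
      }
    }
  }
  return 0;
}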

View File

@@ -57,6 +57,7 @@ using AnyList = std::vector<Any>;
 using AnyListPtr = std::shared_ptr<AnyList>;
 struct OpRunInfo {
+  bool is_gradient_out = false;
   std::string op_name;
   Primitive *primitive;
   AbstractBasePtr abstract;
@@ -193,7 +194,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
                                VectorRef *const outputs,
                                std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes);
   void GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count);
+  void GetForwardOutputRefCount(const KernelGraph *graph, std::map<AnfNodePtr, size_t> *forward_output_refcount);
   void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<KernelWithIndex, size_t> *ref_count,
+                      std::map<AnfNodePtr, size_t> *forward_output_refcount,
                       std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map);
   void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
@@ -280,7 +283,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph);
   // Generate graph info for a single op graph
   GraphInfo GetSingleOpGraphInfo(const CNodePtr &kernel, const std::vector<tensor::TensorPtr> &input_tensors);
-  OpRunInfo GetSingleOpRunInfo(const CNodePtr &cnode, const GraphInfo &graph_info, const InputTensorInfo &tensor_info);
+  OpRunInfo GetSingleOpRunInfo(const CNodePtr &cnode, const GraphInfo &graph_info, const InputTensorInfo &tensor_info,
+                               GraphOutputInfo *const graph_output_info);
   tensor::TensorPtr GetValueNodeOutputTensor(const AnfNodePtr &node, size_t output_index);
   tensor::TensorPtr GetParameterOutputTensor(const AnfNodePtr &node,
                                              const std::map<AnfNodePtr, size_t> &parameter_index,

View File

@@ -50,6 +50,7 @@ enum PynativeStatusCode {
 enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
 struct OpExecInfo {
+  bool is_nop_prim = false;
   bool is_dynamic_shape = false;
   bool is_mixed_precision_cast = false;
   size_t next_input_index = 0;

View File

@ -880,6 +880,37 @@ py::object GetDstType(const TypeId &type_id) {
bool IsPyObjTypeInvalid(const py::object &obj) { bool IsPyObjTypeInvalid(const py::object &obj) {
return !py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj); return !py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj);
} }
inline bool IsNopPrim(const std::string &op_name) {
static std::set<std::string> nop_prim = {prim::kPrimReshape->name(), kExpandDimsOpName, prim::kPrimSqueeze->name(),
prim::kPrimFlatten->name(), kFlattenGradOpName, prim::kPrimReformat->name()};
return nop_prim.find(op_name) != nop_prim.end();
}
// Shallow Copy Value and change shape
ValuePtr ShallowCopyValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(op_exec_info);
MS_EXCEPTION_IF_NULL(value);
auto tensor_abs = op_exec_info->abstract;
if (tensor_abs->isa<abstract::AbstractRef>()) {
tensor_abs = tensor_abs->cast<abstract::AbstractRefPtr>()->CloneAsTensor();
}
auto new_shape = tensor_abs->BuildShape()->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(new_shape);
if (value->isa<mindspore::tensor::Tensor>()) {
auto tensor_value = value->cast<mindspore::tensor::TensorPtr>();
return std::make_shared<mindspore::tensor::Tensor>(tensor_value->data_type(), new_shape->shape(),
tensor_value->data_c(), tensor_value->Size());
} else if (value->isa<ValueTuple>()) {
std::vector<ValuePtr> values;
auto value_tuple = value->cast<ValueTuplePtr>();
(void)std::transform(value_tuple->value().begin(), value_tuple->value().end(), std::back_inserter(values),
[op_exec_info](const ValuePtr &elem) { return ShallowCopyValue(op_exec_info, elem); });
return std::make_shared<ValueTuple>(values);
} else {
return value;
}
}
} // namespace } // namespace
py::object RealRunOp(const py::args &args) { py::object RealRunOp(const py::args &args) {
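The helpers above classify shape-only primitives (Reshape, ExpandDims, Squeeze, Flatten, FlattenGrad, Reformat) as nops and build their output value by shallow-copying the first input with the inferred shape, instead of launching a backend kernel. A minimal self-contained sketch of that idea, using a hypothetical ToyTensor type rather than the MindSpore tensor API; the real ShallowCopyValue builds a mindspore::tensor::Tensor from the input's data pointer, data type, and the abstract's shape, and recurses into value tuples.

#include <cstdio>
#include <memory>
#include <vector>

// Hypothetical stand-in for a device tensor: shared data buffer plus shape metadata.
struct ToyTensor {
  std::shared_ptr<std::vector<float>> data;  // shared storage
  std::vector<int> shape;                    // per-view shape
};

// A shape-only op such as Reshape can produce its output without a kernel
// launch: reuse the input storage and attach the newly inferred shape.
ToyTensor NopOutput(const ToyTensor &input, const std::vector<int> &inferred_shape) {
  return ToyTensor{input.data, inferred_shape};
}

int main() {
  ToyTensor x{std::make_shared<std::vector<float>>(6, 1.0f), {2, 3}};
  ToyTensor y = NopOutput(x, {3, 2});  // no backend launch, no device copy
  std::printf("same storage: %d, new shape: %dx%d\n",
              static_cast<int>(x.data == y.data), y.shape[0], y.shape[1]);
  return 0;
}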
@ -957,6 +988,7 @@ void TopCellInfo::Clear() {
k_pynative_cell_ptr_ = nullptr; k_pynative_cell_ptr_ = nullptr;
graph_info_map_.clear(); graph_info_map_.clear();
sub_cell_list_.clear(); sub_cell_list_.clear();
outputs_id_.clear();
op_info_with_tensor_id_.clear(); op_info_with_tensor_id_.clear();
tensor_id_with_tensor_object_.clear(); tensor_id_with_tensor_object_.clear();
op_info_with_ms_func_forward_tensors_.clear(); op_info_with_ms_func_forward_tensors_.clear();
@ -992,6 +1024,7 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) {
const auto &op_exec_info = std::make_shared<OpExecInfo>(); const auto &op_exec_info = std::make_shared<OpExecInfo>();
const auto &op_name = py::cast<std::string>(args[PY_NAME]); const auto &op_name = py::cast<std::string>(args[PY_NAME]);
op_exec_info->op_name = op_name; op_exec_info->op_name = op_name;
op_exec_info->is_nop_prim = false; // IsNopPrim(op_exec_info->op_name);
const auto &adapter = py::cast<PrimitivePyAdapterPtr>(args[PY_PRIM]); const auto &adapter = py::cast<PrimitivePyAdapterPtr>(args[PY_PRIM]);
MS_EXCEPTION_IF_NULL(adapter); MS_EXCEPTION_IF_NULL(adapter);
@ -1013,7 +1046,7 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) {
void ForwardExecutor::SetCastForInputs(const OpExecInfoPtr &op_exec_info) { void ForwardExecutor::SetCastForInputs(const OpExecInfoPtr &op_exec_info) {
MS_EXCEPTION_IF_NULL(op_exec_info); MS_EXCEPTION_IF_NULL(op_exec_info);
// No need cast self // No need cast self
if (op_exec_info->op_name == prim::kPrimCast->name()) { if (op_exec_info->op_name == prim::kPrimCast->name() || op_exec_info->is_nop_prim) {
return; return;
} }
@ -1167,6 +1200,21 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
} }
} }
void ForwardExecutor::DoNopOutput(const OpExecInfoPtr &op_exec_info, ValuePtr *out_real_value) {
MS_EXCEPTION_IF_NULL(op_exec_info);
  // Get the first input
  if (op_exec_info->op_inputs.empty()) {
    MS_LOG(EXCEPTION) << "The inputs of " << op_exec_info->op_name << " are empty";
}
const auto &obj = op_exec_info->op_inputs[0];
if (!py::isinstance<tensor::Tensor>(obj)) {
MS_LOG(EXCEPTION) << "First input of " << op_exec_info->op_name << " must be a tensor";
}
const auto &tensor_ptr = py::cast<tensor::TensorPtr>(obj);
*out_real_value = ShallowCopyValue(op_exec_info, tensor_ptr);
MS_LOG(DEBUG) << "New copy value is " << (*out_real_value)->ToString();
}
void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info, void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
const abstract::AbstractBasePtrList &args_spec_list, const CNodePtr &cnode, const abstract::AbstractBasePtrList &args_spec_list, const CNodePtr &cnode,
bool prim_cache_hit, py::object *ret) { bool prim_cache_hit, py::object *ret) {
@ -1194,25 +1242,32 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
out[args_spec_list].abs = op_exec_info->abstract; out[args_spec_list].abs = op_exec_info->abstract;
out[args_spec_list].attrs = prim->evaluate_added_attrs(); out[args_spec_list].attrs = prim->evaluate_added_attrs();
} }
// run op with selected backend
  auto result = RunOpWithInitBackendPolicy(op_exec_info); // Run the op with the selected backend; a nop primitive does not need to run on the backend
py::object out_real = result;
if (result.size() == 1 && op_exec_info->abstract != nullptr &&
!op_exec_info->abstract->isa<abstract::AbstractSequence>()) {
out_real = result[0];
}
// get output value
ValuePtr out_real_value = nullptr; ValuePtr out_real_value = nullptr;
if (grad()->grad_flag()) { if (op_exec_info->is_nop_prim) {
out_real_value = PyObjToValue(out_real); DoNopOutput(op_exec_info, &out_real_value);
*ret = BaseRefToPyData(out_real_value);
} else {
auto result = RunOpWithInitBackendPolicy(op_exec_info);
py::object out_real = result;
if (result.size() == 1 && op_exec_info->abstract != nullptr &&
!op_exec_info->abstract->isa<abstract::AbstractSequence>()) {
out_real = result[0];
}
// get output value
if (grad()->grad_flag()) {
out_real_value = PyObjToValue(out_real);
}
*ret = out_real;
} }
// Save cnode info and build grad graph
if (grad()->need_construct_graph() && !grad()->in_cell_with_custom_bprop_()) { if (grad()->need_construct_graph() && !grad()->in_cell_with_custom_bprop_()) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
const auto &obj_id = GetId(out_real); const auto &obj_id = GetId(*ret);
cnode->set_abstract(op_exec_info->abstract); cnode->set_abstract(op_exec_info->abstract);
node_abs_map_[obj_id] = op_exec_info->abstract; node_abs_map_[obj_id] = op_exec_info->abstract;
grad()->SaveOutputNodeMap(obj_id, out_real, cnode); grad()->SaveOutputNodeMap(obj_id, *ret, cnode);
grad()->DoOpGrad(op_exec_info, cnode, out_real_value); grad()->DoOpGrad(op_exec_info, cnode, out_real_value);
} else { } else {
node_abs_map_.clear(); node_abs_map_.clear();
@ -1220,7 +1275,6 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
// Record op info for judge whether the construct of cell has been changed // Record op info for judge whether the construct of cell has been changed
grad()->RecordGradOpInfo(op_exec_info, out_real_value); grad()->RecordGradOpInfo(op_exec_info, out_real_value);
grad()->UpdateForwardTensorInfoInBpropGraph(op_exec_info, out_real_value); grad()->UpdateForwardTensorInfoInBpropGraph(op_exec_info, out_real_value);
*ret = out_real;
} }
py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name,
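Taken together, the GetOpOutput changes above let the nop branch return a shallow copy without touching the backend, while the grad bookkeeping (SaveOutputNodeMap, DoOpGrad, RecordGradOpInfo) runs for both branches so the bprop graph still records the op. Note that GenerateOpExecInfo above assigns is_nop_prim = false with the IsNopPrim call commented out, so the fast path appears disabled in this commit. A hedged, self-contained sketch of the control flow; all names below are placeholders, not the executor's real signatures.

#include <cstdio>
#include <string>

struct ToyOp { bool is_nop_prim; std::string name; };

// Placeholder stand-ins for the executor's helpers.
std::string ShallowCopyFirstInput(const ToyOp &op) { return op.name + "_input_view"; }
std::string RunOnBackend(const ToyOp &op) { return op.name + "_backend_out"; }

std::string GetOpOutputSketch(const ToyOp &op, bool need_grad_graph) {
  // Nop primitives skip the backend launch entirely.
  std::string out = op.is_nop_prim ? ShallowCopyFirstInput(op) : RunOnBackend(op);
  // Grad bookkeeping happens in both branches, so autodiff still records the op.
  if (need_grad_graph) {
    std::printf("save cnode for %s -> %s\n", op.name.c_str(), out.c_str());
  }
  return out;
}

int main() {
  GetOpOutputSketch({true, "Reshape"}, true);   // nop: shallow copy only
  GetOpOutputSketch({false, "MatMul"}, true);   // regular op: backend launch
  return 0;
}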
@ -1701,7 +1755,7 @@ void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object
return; return;
} }
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
MS_LOG(DEBUG) << "Cnode is " << cnode->DebugString() << " id " << obj_id; MS_LOG(DEBUG) << "Cnode is " << cnode->DebugString() << ", out value id " << obj_id;
if (py::isinstance<py::tuple>(out_real)) { if (py::isinstance<py::tuple>(out_real)) {
auto value = py::cast<py::tuple>(out_real); auto value = py::cast<py::tuple>(out_real);
auto size = static_cast<int64_t>(value.size()); auto size = static_cast<int64_t>(value.size());
@ -1925,6 +1979,7 @@ void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr
continue; continue;
} }
tensor_id_with_tensor_object[tensor->id()].emplace_back(tensor); tensor_id_with_tensor_object[tensor->id()].emplace_back(tensor);
top_cell()->outputs_id().insert(tensor->id());
MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id() MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
<< " device address: " << tensor->device_address() << " shape and dtype " << " device address: " << tensor->device_address() << " shape and dtype "
<< tensor->GetShapeAndDataTypeInfo(); << tensor->GetShapeAndDataTypeInfo();
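SaveForwardTensorInfoInBpropGraph above now also records every forward output tensor id in the top cell's outputs_id set, exposed through GradExecutor::forward_outputs_id(). The point of the set is presumably a cheap membership test later on, for example when deciding whether a tensor's device memory was produced by the forward graph. A tiny illustration of the bookkeeping, with made-up tensor ids:

#include <cstdio>
#include <set>
#include <string>

int main() {
  // Toy version of top_cell()->outputs_id(): remember which tensor ids were
  // produced by forward ops of the current top cell.
  std::set<std::string> forward_outputs_id;
  forward_outputs_id.insert("tensor_id_1");
  forward_outputs_id.insert("tensor_id_2");
  // Membership is a single lookup when later passes need to know whether a
  // tensor came from the forward graph.
  std::printf("tensor_id_2 from forward: %d\n",
              static_cast<int>(forward_outputs_id.count("tensor_id_2") > 0));
  return 0;
}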
@ -2090,7 +2145,8 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
// get graph info for checking it whether existing in the cache // get graph info for checking it whether existing in the cache
GetSingleOpGraphInfo(op_exec_info, input_tensors, tensors_mask, &graph_info); GetSingleOpGraphInfo(op_exec_info, input_tensors, tensors_mask, &graph_info);
#if defined(__APPLE__) #if defined(__APPLE__)
session::OpRunInfo op_run_info = {op_exec_info->op_name, session::OpRunInfo op_run_info = {false,
op_exec_info->op_name,
op_exec_info->py_primitive.get(), op_exec_info->py_primitive.get(),
op_exec_info->abstract, op_exec_info->abstract,
op_exec_info->is_dynamic_shape, op_exec_info->is_dynamic_shape,
@ -2102,7 +2158,8 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
tensors_mask, tensors_mask,
input_tensors}; input_tensors};
#else #else
session::OpRunInfo op_run_info = {op_exec_info->op_name, session::OpRunInfo op_run_info = {false,
op_exec_info->op_name,
op_exec_info->py_primitive.get(), op_exec_info->py_primitive.get(),
op_exec_info->abstract, op_exec_info->abstract,
op_exec_info->is_dynamic_shape, op_exec_info->is_dynamic_shape,

View File

@ -100,6 +100,7 @@ class TopCellInfo {
const std::string &grad_operation() const { return grad_operation_; } const std::string &grad_operation() const { return grad_operation_; }
void set_grad_operation(const std::string &grad_operation) { grad_operation_ = grad_operation; } void set_grad_operation(const std::string &grad_operation) { grad_operation_ = grad_operation; }
mindspore::HashSet<std::string> &sub_cell_list() { return sub_cell_list_; } mindspore::HashSet<std::string> &sub_cell_list() { return sub_cell_list_; }
std::set<std::string> &outputs_id() { return outputs_id_; }
bool IsSubCell(const std::string &cell_id) const; bool IsSubCell(const std::string &cell_id) const;
OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() { return graph_info_map_; } OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() { return graph_info_map_; }
OpInfoWithTensorId &op_info_with_tensor_id() { return op_info_with_tensor_id_; } OpInfoWithTensorId &op_info_with_tensor_id() { return op_info_with_tensor_id_; }
@ -139,6 +140,8 @@ class TopCellInfo {
std::string grad_operation_; std::string grad_operation_;
OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_; OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
mindspore::HashSet<std::string> sub_cell_list_; mindspore::HashSet<std::string> sub_cell_list_;
// Record op output tensor id
std::set<std::string> outputs_id_;
OpInfoWithTensorId op_info_with_tensor_id_; OpInfoWithTensorId op_info_with_tensor_id_;
TensorIdWithTensorObject tensor_id_with_tensor_object_; TensorIdWithTensorObject tensor_id_with_tensor_object_;
OpInfoWithMsFuncForwardTensors op_info_with_ms_func_forward_tensors_; OpInfoWithMsFuncForwardTensors op_info_with_ms_func_forward_tensors_;
@ -195,6 +198,7 @@ class GradExecutor {
void set_grad_flag(bool flag) { grad_flag_ = flag; } void set_grad_flag(bool flag) { grad_flag_ = flag; }
void set_graph_phase(const std::string &graph_phase) { graph_phase_ = graph_phase; } void set_graph_phase(const std::string &graph_phase) { graph_phase_ = graph_phase; }
bool in_cell_with_custom_bprop_() const { return custom_bprop_cell_count_ > 0; } bool in_cell_with_custom_bprop_() const { return custom_bprop_cell_count_ > 0; }
std::set<std::string> &forward_outputs_id() const { return top_cell()->outputs_id(); }
AnfNodePtr GetInput(const py::object &obj, bool op_mask); AnfNodePtr GetInput(const py::object &obj, bool op_mask);
std::string GetCellId(const py::object &obj, const py::args &args); std::string GetCellId(const py::object &obj, const py::args &args);
void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out); void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out);
@ -347,6 +351,7 @@ class ForwardExecutor {
bool *prim_cache_hit); bool *prim_cache_hit);
void GetOpOutput(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list, void GetOpOutput(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
const CNodePtr &cnode, bool prim_cache_hit, py::object *ret); const CNodePtr &cnode, bool prim_cache_hit, py::object *ret);
void DoNopOutput(const OpExecInfoPtr &op_exec_info, ValuePtr *out_real_value);
// Mix precision and Implicit transform // Mix precision and Implicit transform
void SetCastForInputs(const OpExecInfoPtr &op_exec_info); void SetCastForInputs(const OpExecInfoPtr &op_exec_info);
void SetTensorMixPrecisionCast(const OpExecInfoPtr &op_exec_info); void SetTensorMixPrecisionCast(const OpExecInfoPtr &op_exec_info);

View File

@ -84,6 +84,7 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
.value("save_graphs", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG) .value("save_graphs", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)
.value("enable_parallel_split", MsCtxParam::MS_CTX_ENABLE_PARALLEL_SPLIT) .value("enable_parallel_split", MsCtxParam::MS_CTX_ENABLE_PARALLEL_SPLIT)
.value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY) .value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY)
.value("mempool_block_size", MsCtxParam::MS_CTX_MEMPOOL_BLOCK_SIZE)
.value("mode", MsCtxParam::MS_CTX_EXECUTION_MODE) .value("mode", MsCtxParam::MS_CTX_EXECUTION_MODE)
.value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET) .value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET)
.value("_graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE) .value("_graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)

View File

@ -77,6 +77,7 @@ bool AscendMemAdapter::Initialize() {
static_mem_offset_ = ms_used_hbm_size_; static_mem_offset_ = ms_used_hbm_size_;
cur_dynamic_mem_offset_ = 0; cur_dynamic_mem_offset_ = 0;
max_dynamic_mem_offset_ = 0; max_dynamic_mem_offset_ = 0;
AscendMemoryPool::GetInstance().SetMempoolBlockSize(ms_used_hbm_size_);
MS_LOG(INFO) << "Ascend Memory Adapter initialize success, Memory Statistics:" << DevMemStatistics(); MS_LOG(INFO) << "Ascend Memory Adapter initialize success, Memory Statistics:" << DevMemStatistics();
initialized_ = true; initialized_ = true;
return true; return true;
@ -109,7 +110,7 @@ bool AscendMemAdapter::DeInitialize() {
return ret; return ret;
} }
uint8_t *AscendMemAdapter::MallocStaticDevMem(size_t size, std::string tag) { uint8_t *AscendMemAdapter::MallocStaticDevMem(size_t size, const std::string &tag) {
std::lock_guard<std::mutex> locker(mutex_); std::lock_guard<std::mutex> locker(mutex_);
auto new_static_offset = static_mem_offset_ - size; auto new_static_offset = static_mem_offset_ - size;
if (new_static_offset < max_dynamic_mem_offset_) { if (new_static_offset < max_dynamic_mem_offset_) {
@ -122,11 +123,10 @@ uint8_t *AscendMemAdapter::MallocStaticDevMem(size_t size, std::string tag) {
auto memory_block_ptr = device_mem_base_addr_ + new_static_offset; auto memory_block_ptr = device_mem_base_addr_ + new_static_offset;
static_mem_offset_ = new_static_offset; static_mem_offset_ = new_static_offset;
static_memory_block_list_.push_back(std::make_shared<MemoryBlock>(memory_block_ptr, size, tag)); static_memory_block_list_.push_back(std::make_shared<MemoryBlock>(memory_block_ptr, size, tag));
return memory_block_ptr; return memory_block_ptr;
} }
uint8_t *AscendMemAdapter::MallocDynamicDevMem(size_t size, std::string tag) { uint8_t *AscendMemAdapter::MallocDynamicDevMem(size_t size, const std::string &tag) {
std::lock_guard<std::mutex> locker(mutex_); std::lock_guard<std::mutex> locker(mutex_);
auto new_dynamic_offset = cur_dynamic_mem_offset_ + size; auto new_dynamic_offset = cur_dynamic_mem_offset_ + size;
if (new_dynamic_offset > static_mem_offset_) { if (new_dynamic_offset > static_mem_offset_) {

View File

@ -22,6 +22,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "runtime/device/ascend/ascend_memory_pool.h"
namespace mindspore { namespace mindspore {
namespace device { namespace device {
@ -37,8 +38,8 @@ class AscendMemAdapter {
bool Initialize(); bool Initialize();
bool DeInitialize(); bool DeInitialize();
uint8_t *MallocStaticDevMem(size_t size, std::string tag = ""); uint8_t *MallocStaticDevMem(size_t size, const std::string &tag = "");
uint8_t *MallocDynamicDevMem(size_t size, std::string tag = ""); uint8_t *MallocDynamicDevMem(size_t size, const std::string &tag = "");
bool FreeStaticDevMem(void *devPtr) { return true; } bool FreeStaticDevMem(void *devPtr) { return true; }
void ResetDynamicMemory(); void ResetDynamicMemory();
@ -76,12 +77,13 @@ class AscendMemAdapter {
uint8_t *device_mem_base_addr_{nullptr}; uint8_t *device_mem_base_addr_{nullptr};
uint64_t ms_used_hbm_size_{0}; uint64_t ms_used_hbm_size_{0};
// dynamic memory info // dynamic memory info, from a low address to a high address
uint64_t cur_dynamic_mem_offset_{0}; uint64_t cur_dynamic_mem_offset_{0};
  // Maximum dynamic memory that has been allocated so far; updated dynamically
uint64_t max_dynamic_mem_offset_{0}; uint64_t max_dynamic_mem_offset_{0};
std::vector<std::shared_ptr<MemoryBlock>> dynamic_memory_block_list_; std::vector<std::shared_ptr<MemoryBlock>> dynamic_memory_block_list_;
// static memory info // static memory info, from a high address to a low address
uint64_t static_mem_offset_{0}; uint64_t static_mem_offset_{0};
std::vector<std::shared_ptr<MemoryBlock>> static_memory_block_list_; std::vector<std::shared_ptr<MemoryBlock>> static_memory_block_list_;
}; };

View File

@ -15,7 +15,6 @@
*/ */
#include <string> #include <string>
#include "runtime/device/ascend/ascend_memory_manager.h" #include "runtime/device/ascend/ascend_memory_manager.h"
#include "runtime/device/ascend/ascend_memory_pool.h"
#include "runtime/device/ascend/ascend_memory_adapter.h" #include "runtime/device/ascend/ascend_memory_adapter.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "runtime/mem.h" #include "runtime/mem.h"
@ -49,9 +48,9 @@ void *AscendMemoryManager::MallocDevice(size_t size) {
return AscendMemoryPool::GetInstance().AllocTensorMem(align_size); return AscendMemoryPool::GetInstance().AllocTensorMem(align_size);
} }
void *AscendMemoryManager::MallocMemFromMemPool(size_t size) { void *AscendMemoryManager::MallocMemFromMemPool(size_t size, bool from_persistent_mem) {
auto align_size = GetCommonAlignSize(size); auto align_size = GetCommonAlignSize(size);
const auto device_addr = AscendMemoryPool::GetInstance().AllocTensorMem(align_size); const auto device_addr = AscendMemoryPool::GetInstance().AllocTensorMem(align_size, from_persistent_mem);
if (device_addr == nullptr) { if (device_addr == nullptr) {
MS_LOG(EXCEPTION) << "Fail to alloc memory, size: " << align_size MS_LOG(EXCEPTION) << "Fail to alloc memory, size: " << align_size
<< ", memory statistics:" << AscendMemAdapter::GetInstance().DevMemStatistics(); << ", memory statistics:" << AscendMemAdapter::GetInstance().DevMemStatistics();
@ -132,8 +131,8 @@ uint8_t *AscendMemoryManager::MallocCommunicationMemFromMemPool(size_t size) {
size_t AscendMemoryManager::GetAvailableMemSize() { size_t AscendMemoryManager::GetAvailableMemSize() {
auto available_mem_size = AscendMemoryPool::GetInstance().free_mem_size() + auto available_mem_size = AscendMemoryPool::GetInstance().free_mem_size() +
AscendMemoryPool::GetInstance().total_mem_statistics() - AscendMemoryPool::GetInstance().TotalMemStatistics() -
AscendMemoryPool::GetInstance().used_mem_statistics(); AscendMemoryPool::GetInstance().TotalUsedMemStatistics();
return available_mem_size; return available_mem_size;
} }

View File

@ -32,7 +32,7 @@ class AscendMemoryManager : public MemoryManager {
void FreeDeviceMemory() override; void FreeDeviceMemory() override;
void ResetDynamicMemory() override; void ResetDynamicMemory() override;
void ClearGlobalIdleMem() override; void ClearGlobalIdleMem() override;
void *MallocMemFromMemPool(size_t size) override; void *MallocMemFromMemPool(size_t size, bool from_persistent_mem) override;
void *MallocDevice(size_t size) override; void *MallocDevice(size_t size) override;
void FreeMemFromMemPool(void *device_ptr) override; void FreeMemFromMemPool(void *device_ptr) override;
uint64_t GetMsMaxMemSize(); uint64_t GetMsMaxMemSize();

View File

@ -24,37 +24,37 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
// The minimum unit size (256MB) of memory block used for dynamic extend.
static const size_t ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE = 256 << 20;
// The minimum unit size (8MB) of memory block used for dynamic extend in graph mode. // The minimum unit size (8MB) of memory block used for dynamic extend in graph mode.
static const size_t ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE_FOR_GRAPH = 8 << 20; static const size_t ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE_FOR_GRAPH = 8 << 20;
size_t AscendMemoryPool::CalMemBlockAllocSize(size_t size) { size_t AscendMemoryPool::CalMemBlockAllocSize(size_t size, bool from_persistent_mem) {
auto device_free_mem_size = free_mem_size(); auto device_free_mem_size = free_mem_size();
if (device_free_mem_size < size) { if (device_free_mem_size < size) {
MS_LOG(WARNING) << "The dynamic memory pool total size is " MS_LOG(WARNING) << "The dynamic memory pool total size is "
<< device::ascend::AscendMemoryPool::GetInstance().total_mem_statistics() / kMBToByte << device::ascend::AscendMemoryPool::GetInstance().TotalMemStatistics() / kMBToByte
<< "M, total used size is " << "M, total used size is "
<< device::ascend::AscendMemoryPool::GetInstance().used_mem_statistics() / kMBToByte << device::ascend::AscendMemoryPool::GetInstance().TotalUsedMemStatistics() / kMBToByte
<< "M, used peak size is " << "M, used peak size is "
<< device::ascend::AscendMemoryPool::GetInstance().used_mem_peak_statistics() / kMBToByte << "M."; << device::ascend::AscendMemoryPool::GetInstance().UsedMemPeakStatistics() / kMBToByte << "M.";
MS_LOG(WARNING) << "Out of Memory. Request memory size: " << size MS_LOG(WARNING) << "Out of Memory. Request memory size: " << size << ", device free size " << device_free_mem_size
<< ", Memory Statistic:" << AscendMemAdapter::GetInstance().DevMemStatistics() << ", Memory Statistic:" << AscendMemAdapter::GetInstance().DevMemStatistics()
<< "Please try to reduce 'batch_size' or check whether exists extra large shape. More " << "Please try to reduce 'batch_size' or check whether exists extra large shape. More "
"details can be found in MindSpore's FAQ with keyword 'Out of Memory'."; "details can be found in MindSpore's FAQ with keyword 'Out of Memory'.";
return 0; return 0;
} }
auto alloc_mem_size = ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE; size_t alloc_mem_size = MemAllocUnitSize(from_persistent_mem);
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode); const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
if (pynative_mode) { if (pynative_mode) {
// Growing at twice of alloc size // Growing at twice of alloc size
MS_LOG(DEBUG) << "Get unit block size " << alloc_mem_size;
constexpr size_t kDouble = 2; constexpr size_t kDouble = 2;
while (alloc_mem_size < size) { while (alloc_mem_size < size) {
alloc_mem_size = alloc_mem_size * kDouble; alloc_mem_size = alloc_mem_size * kDouble;
} }
} else { } else {
    // Graph mode manages its allocation unit size independently
alloc_mem_size = ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE_FOR_GRAPH; alloc_mem_size = ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE_FOR_GRAPH;
while (alloc_mem_size < size) { while (alloc_mem_size < size) {
alloc_mem_size = alloc_mem_size + ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE_FOR_GRAPH; alloc_mem_size = alloc_mem_size + ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE_FOR_GRAPH;
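CalMemBlockAllocSize above now takes its starting unit from MemAllocUnitSize(from_persistent_mem), presumably seeded by the SetMempoolBlockSize calls elsewhere in this commit, and grows it geometrically in PyNative mode, while graph mode keeps its fixed 8MB stepping. A sketch of the two growth policies under those assumptions; the constants and names here are illustrative, and the real function additionally checks the request against the device's free memory first.

#include <cstdio>
#include <cstddef>

constexpr std::size_t kGraphUnit = 8ull << 20;  // 8MB step used in graph mode

std::size_t CalcBlockAllocSize(std::size_t request, std::size_t unit, bool pynative) {
  std::size_t alloc = pynative ? unit : kGraphUnit;
  if (pynative) {
    while (alloc < request) alloc *= 2;          // double until the request fits
  } else {
    while (alloc < request) alloc += kGraphUnit; // fixed-step growth
  }
  return alloc;
}

int main() {
  // e.g. a 1GB unit (the documented mempool_block_size default) and a 1.5GB request
  std::printf("pynative: %zu MB\n", CalcBlockAllocSize(1536ull << 20, 1024ull << 20, true) >> 20);
  std::printf("graph:    %zu MB\n", CalcBlockAllocSize(100ull << 20, 1024ull << 20, false) >> 20);
  return 0;
}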
@ -69,7 +69,6 @@ size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
if (size == 0) { if (size == 0) {
MS_LOG(EXCEPTION) << "Failed to alloc memory pool resource, the size is zero!"; MS_LOG(EXCEPTION) << "Failed to alloc memory pool resource, the size is zero!";
} }
*addr = AscendMemAdapter::GetInstance().MallocStaticDevMem(size); *addr = AscendMemAdapter::GetInstance().MallocStaticDevMem(size);
if (*addr == nullptr) { if (*addr == nullptr) {
MS_LOG(EXCEPTION) << "Alloc device memory pool address is nullptr, failed to alloc memory pool resource!"; MS_LOG(EXCEPTION) << "Alloc device memory pool address is nullptr, failed to alloc memory pool resource!";
@ -83,11 +82,17 @@ bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) {
} }
void AscendMemoryPool::ResetIdleMemBuf() { void AscendMemoryPool::ResetIdleMemBuf() {
auto idle_mem_buf_map = DynamicMemPoolBestFit::global_idle_mem_buf_map(); auto fn = [this](const MemStatusManagerPtr &mem_mng) {
for (auto &it : idle_mem_buf_map) { if (mem_mng->mem_block_list_.empty()) {
MS_EXCEPTION_IF_NULL(it.second); return;
(void)rtMemset(it.second->device_addr_, it.first, 0, it.first); }
} for (auto &it : mem_mng->idle_mem_buf_map_) {
MS_EXCEPTION_IF_NULL(it.second);
(void)rtMemset(it.second->device_addr_, it.first, 0, it.first);
}
};
fn(persistent_mem());
fn(common_mem());
} }
size_t AscendMemoryPool::free_mem_size() { return AscendMemAdapter::GetInstance().FreeDevMemSize(); } size_t AscendMemoryPool::free_mem_size() { return AscendMemAdapter::GetInstance().FreeDevMemSize(); }
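ResetIdleMemBuf above now walks two pools instead of one global idle map: the same routine is wrapped in a local lambda and applied to the persistent pool and the common pool, skipping a pool whose block list is empty. A minimal sketch of that pattern with toy types; rtMemset is replaced by a print.

#include <cstdio>
#include <cstddef>
#include <map>
#include <vector>

// Toy pool bookkeeping: a block list plus an idle-buffer map keyed by size.
struct ToyMemStatus {
  std::vector<int> mem_block_list;
  std::map<std::size_t, int> idle_mem_buf_map;
};

int main() {
  ToyMemStatus persistent, common;
  persistent.mem_block_list.push_back(0);
  persistent.idle_mem_buf_map[256] = 1;
  // One lambda, applied once per pool, mirrors fn(persistent_mem()); fn(common_mem()).
  auto reset_idle = [](const ToyMemStatus &pool) {
    if (pool.mem_block_list.empty()) {
      return;  // nothing allocated from this pool yet
    }
    for (const auto &kv : pool.idle_mem_buf_map) {
      std::printf("reset idle buffer of size %zu\n", kv.first);  // stands in for rtMemset
    }
  };
  reset_idle(persistent);
  reset_idle(common);
  return 0;
}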

View File

@ -42,7 +42,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
protected: protected:
// Calculate memory block required alloc size when adding the memory block. // Calculate memory block required alloc size when adding the memory block.
size_t CalMemBlockAllocSize(size_t size) override; size_t CalMemBlockAllocSize(size_t size, bool from_persistent_mem) override;
private: private:
AscendMemoryPool() = default; AscendMemoryPool() = default;

View File

@ -112,6 +112,7 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>()); size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
DeviceAddressPtr address = nullptr; DeviceAddressPtr address = nullptr;
address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, output_type_id); address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, output_type_id);
address->set_from_persistent_mem(tensor->is_parameter());
MS_EXCEPTION_IF_NULL(address); MS_EXCEPTION_IF_NULL(address);
if (tensor->data_type() == output_type_id) { if (tensor->data_type() == output_type_id) {
address->ptr_ = tensor->data_c(); address->ptr_ = tensor->data_c();
@ -146,6 +147,7 @@ void CPUKernelRuntime::AssignInputNodeAddress(const session::KernelGraph *kernel
: std::accumulate(fmt_shape.begin(), fmt_shape.end(), type_size, std::multiplies<size_t>()); : std::accumulate(fmt_shape.begin(), fmt_shape.end(), type_size, std::multiplies<size_t>());
auto format = AnfAlgo::GetOutputFormat(item, index); auto format = AnfAlgo::GetOutputFormat(item, index);
auto address = CreateDeviceAddress(nullptr, tensor_size, format, output_type_id); auto address = CreateDeviceAddress(nullptr, tensor_size, format, output_type_id);
address->set_from_persistent_mem(true);
AnfAlgo::SetOutputAddr(address, index, item.get()); AnfAlgo::SetOutputAddr(address, index, item.get());
} }
} }

View File

@ -46,7 +46,9 @@ class CPUMemoryManager : public MemoryManager {
void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
void *MallocMemFromMemPool(size_t size) override { return CPUMemoryPool::GetInstance().AllocTensorMem(size); } void *MallocMemFromMemPool(size_t size, bool from_persistent_mem) override {
return CPUMemoryPool::GetInstance().AllocTensorMem(size, from_persistent_mem);
}
void FreeMemFromMemPool(void *device_ptr) override { CPUMemoryPool::GetInstance().FreeTensorMem(device_ptr); } void FreeMemFromMemPool(void *device_ptr) override { CPUMemoryPool::GetInstance().FreeTensorMem(device_ptr); }
std::vector<void *> MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list) override { std::vector<void *> MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list) override {
return CPUMemoryPool::GetInstance().AllocContinuousTensorMem(total_size, size_list); return CPUMemoryPool::GetInstance().AllocContinuousTensorMem(total_size, size_list);

View File

@ -101,6 +101,8 @@ class DeviceAddress : public mindspore::DeviceSync {
bool is_ptr_persisted() const { return is_ptr_persisted_; } bool is_ptr_persisted() const { return is_ptr_persisted_; }
void set_is_ptr_persisted(bool is_ptr_persisted) { is_ptr_persisted_ = is_ptr_persisted; } void set_is_ptr_persisted(bool is_ptr_persisted) { is_ptr_persisted_ = is_ptr_persisted; }
void set_host_shape(const ShapeVector &shape) { host_shape_ = shape; } void set_host_shape(const ShapeVector &shape) { host_shape_ = shape; }
bool from_persistent_mem() const { return from_persistent_mem_; }
void set_from_persistent_mem(bool from_persistent_mem) { from_persistent_mem_ = from_persistent_mem; }
virtual void set_status(DeviceAddressStatus status) {} virtual void set_status(DeviceAddressStatus status) {}
virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; } virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; } virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
@ -145,6 +147,7 @@ class DeviceAddress : public mindspore::DeviceSync {
// The key of device context. // The key of device context.
std::string device_name_{""}; std::string device_name_{""};
uint32_t device_id_{0}; uint32_t device_id_{0};
bool from_persistent_mem_{false};
friend class KernelRuntime; friend class KernelRuntime;
friend class MemoryManager; friend class MemoryManager;
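from_persistent_mem is a new flag on DeviceAddress, defaulting to false and set in later hunks for parameters, value nodes, and gradient outputs; the allocation entry points that now take a from_persistent_mem argument use it to choose a pool. A toy two-pool router under those assumptions; the types below are hypothetical and not the MindSpore pool API.

#include <cstddef>
#include <cstdio>
#include <vector>

// Two toy pools: persistent for weights/constants/gradient outputs, common for
// short-lived activations that can be freed or reset between single-op graphs.
struct ToyPool {
  std::vector<std::vector<char>> blocks;
  void *Alloc(std::size_t size) {
    blocks.emplace_back(size);
    return blocks.back().data();
  }
};

struct ToyAllocator {
  ToyPool persistent, common;
  void *AllocTensorMem(std::size_t size, bool from_persistent_mem) {
    // Route by the flag that the DeviceAddress carries.
    return from_persistent_mem ? persistent.Alloc(size) : common.Alloc(size);
  }
};

int main() {
  ToyAllocator alloc;
  void *weight = alloc.AllocTensorMem(4096, /*from_persistent_mem=*/true);
  void *activation = alloc.AllocTensorMem(4096, /*from_persistent_mem=*/false);
  std::printf("persistent blocks: %zu, common blocks: %zu\n",
              alloc.persistent.blocks.size(), alloc.common.blocks.size());
  (void)weight; (void)activation;
  return 0;
}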

View File

@ -422,7 +422,7 @@ void GPUKernelRuntime::FetchMemUnitSize(const session::KernelGraph *graph) {
current_sum_size = 0; current_sum_size = 0;
} }
} }
if (max_sum_size > GPUMemoryAllocator::GetInstance().mem_alloc_unit_size()) { if (max_sum_size > GPUMemoryAllocator::GetInstance().MemAllocUnitSize()) {
size_t unit_size = (max_sum_size / DYNAMIC_MEM_ALLOC_UNIT_SIZE + 1) * DYNAMIC_MEM_ALLOC_UNIT_SIZE; size_t unit_size = (max_sum_size / DYNAMIC_MEM_ALLOC_UNIT_SIZE + 1) * DYNAMIC_MEM_ALLOC_UNIT_SIZE;
if (unit_size < DYNAMIC_MEM_ALLOC_UNIT_SIZE) { if (unit_size < DYNAMIC_MEM_ALLOC_UNIT_SIZE) {
MS_LOG(WARNING) << "Current memory unit size [" << unit_size << "] is too small."; MS_LOG(WARNING) << "Current memory unit size [" << unit_size << "] is too small.";
@ -432,7 +432,7 @@ void GPUKernelRuntime::FetchMemUnitSize(const session::KernelGraph *graph) {
constexpr float kValidMemoryRatio = 0.9; constexpr float kValidMemoryRatio = 0.9;
free_mem_size = kValidMemoryRatio * free_mem_size; free_mem_size = kValidMemoryRatio * free_mem_size;
unit_size = std::min(unit_size, free_mem_size); unit_size = std::min(unit_size, free_mem_size);
GPUMemoryAllocator::GetInstance().set_mem_alloc_unit_size(unit_size); GPUMemoryAllocator::GetInstance().SetMemAllocUintSize(unit_size);
} }
} }

View File

@ -25,6 +25,7 @@ namespace mindspore {
namespace device { namespace device {
namespace gpu { namespace gpu {
const size_t kGBToByte = 1024 << 20; const size_t kGBToByte = 1024 << 20;
constexpr float kReservedMemoryRatio = 0.0625; // 1/16
bool GPUMemoryAllocator::Init() { bool GPUMemoryAllocator::Init() {
size_t total_size = CudaDriver::total_mem_size(); size_t total_size = CudaDriver::total_mem_size();
@ -42,6 +43,13 @@ bool GPUMemoryAllocator::Init() {
<< total_size << ", current free memory size " << free_size << ", set max available memory size " << total_size << ", current free memory size " << free_size << ", set max available memory size "
<< available_device_memory_ << "."; << available_device_memory_ << ".";
} }
// In GPU mode, it is recommended to reserve 1/16 of device memory for other CUDA usage
if (available_device_memory_ > total_size) {
size_t recommend_mem_size_for_others = FloatToSize(total_size * kReservedMemoryRatio);
SetMempoolBlockSize(std::min(available_device_memory_, total_size - recommend_mem_size_for_others));
} else {
SetMempoolBlockSize(std::min(available_device_memory_, total_size));
}
return true; return true;
} }
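The GPU allocator initialization above seeds the mempool block size: when the configured available-memory limit exceeds the physical total, roughly 1/16 of device memory is held back for other CUDA consumers (workspaces, runtime buffers) before clamping. A small sketch of that arithmetic, with FloatToSize replaced by a plain cast:

#include <algorithm>
#include <cstdio>
#include <cstddef>

// Sketch of the reservation rule: cap the mempool block size so about 1/16 of
// the physical device memory stays free when the configured limit allows it.
std::size_t PickMempoolBlockSize(std::size_t available_limit, std::size_t total) {
  const double kReservedRatio = 0.0625;  // 1/16
  if (available_limit > total) {
    std::size_t reserved = static_cast<std::size_t>(total * kReservedRatio);
    return std::min(available_limit, total - reserved);
  }
  return std::min(available_limit, total);
}

int main() {
  std::size_t total = 16ull << 30;  // a 16GB card
  std::printf("%zu GB\n", PickMempoolBlockSize(1024ull << 30, total) >> 30);  // -> 15GB
  return 0;
}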
@ -71,7 +79,7 @@ bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) {
auto alloc_size = AllocDeviceMem(size, addr); auto alloc_size = AllocDeviceMem(size, addr);
buffer_q_addr_ = *addr; buffer_q_addr_ = *addr;
// Buffer queue needs to ensure that the alloc_size and size is equal. // Buffer queue needs to ensure that the alloc_size and size is equal.
return (alloc_size == size) ? true : false; return alloc_size == size;
} }
size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
@ -90,8 +98,9 @@ size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
} }
total_used_device_memory_ += alloc_size; total_used_device_memory_ += alloc_size;
available_device_memory_ -= alloc_size; available_device_memory_ -= alloc_size;
MS_LOG(INFO) << "Current free memory size[" << free_size - alloc_size << "], current alloc size[" << alloc_size MS_LOG(INFO) << "Cuda current free memory size[" << free_size << "], alloc size[" << alloc_size
<< "], total used size[" << total_used_device_memory_ << "]."; << "], left free memory size[" << free_size - alloc_size << "]"
<< ".Total used size[" << total_used_device_memory_ << "].";
return alloc_size; return alloc_size;
} }

View File

@ -24,8 +24,8 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace gpu { namespace gpu {
void *GPUMemoryManager::MallocMemFromMemPool(size_t size) { void *GPUMemoryManager::MallocMemFromMemPool(size_t size, bool from_persistent_mem) {
return GPUMemoryAllocator::GetInstance().AllocTensorMem(size); return GPUMemoryAllocator::GetInstance().AllocTensorMem(size, from_persistent_mem);
} }
void GPUMemoryManager::FreeMemFromMemPool(void *device_ptr) { void GPUMemoryManager::FreeMemFromMemPool(void *device_ptr) {
@ -39,7 +39,7 @@ std::vector<void *> GPUMemoryManager::MallocContinuousMemFromMemPool(size_t tota
bool GPUMemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList &addr_list, size_t total_size, bool GPUMemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList &addr_list, size_t total_size,
std::vector<size_t> size_list) { std::vector<size_t> size_list) {
auto device_ptr_list = MallocContinuousMemFromMemPool(total_size, size_list); auto device_ptr_list = MallocContinuousMemFromMemPool(total_size, size_list);
if (device_ptr_list.size() == 0) { if (device_ptr_list.empty()) {
return false; return false;
} }
if (addr_list.size() != device_ptr_list.size()) { if (addr_list.size() != device_ptr_list.size()) {
@ -76,7 +76,7 @@ void GPUMemoryManager::MallocDeviceMemory() {
if (ps::ps_cache_instance.initialized_ps_cache()) { if (ps::ps_cache_instance.initialized_ps_cache()) {
return; return;
} }
auto device_addr = MallocMemFromMemPool(1); auto device_addr = MallocMemFromMemPool(1, false);
if (!device_addr) { if (!device_addr) {
MS_LOG(EXCEPTION) << "Dynamic memory pool init error."; MS_LOG(EXCEPTION) << "Dynamic memory pool init error.";
} }
@ -87,7 +87,7 @@ void GPUMemoryManager::FreeDeviceMemory() { GPUMemoryAllocator::GetInstance().Re
uint8_t *GPUMemoryManager::MallocStaticMem(size_t size, bool, uint32_t) { uint8_t *GPUMemoryManager::MallocStaticMem(size_t size, bool, uint32_t) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
auto device_ptr = MallocMemFromMemPool(size); auto device_ptr = MallocMemFromMemPool(size, false);
if (device_ptr == nullptr) { if (device_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << size; MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << size;
} }

View File

@ -29,7 +29,7 @@ class GPUMemoryManager : public MemoryManager {
void MallocDeviceMemory() override; void MallocDeviceMemory() override;
void FreeDeviceMemory() override; void FreeDeviceMemory() override;
void *MallocMemFromMemPool(size_t size) override; void *MallocMemFromMemPool(size_t size, bool from_persistent_mem) override;
void FreeMemFromMemPool(void *device_ptr) override; void FreeMemFromMemPool(void *device_ptr) override;
std::vector<void *> MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list) override; std::vector<void *> MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list) override;
bool MallocContinuousMemFromMemPool(const DeviceAddressPtrList &addr_list, size_t total_size, bool MallocContinuousMemFromMemPool(const DeviceAddressPtrList &addr_list, size_t total_size,

View File

@ -258,6 +258,7 @@ void KernelRuntime::RunOpMallocPre(const session::KernelGraph &graph,
auto output_format = op_runtime_info->output_format(index); auto output_format = op_runtime_info->output_format(index);
auto device_address = auto device_address =
CreateDeviceAddress(nullptr, output_tensor_size, output_format, output_type_id, {item, index}); CreateDeviceAddress(nullptr, output_tensor_size, output_format, output_type_id, {item, index});
device_address->set_from_persistent_mem(current_tensor->is_parameter());
AnfAlgo::SetOutputAddr(device_address, index, item.get()); AnfAlgo::SetOutputAddr(device_address, index, item.get());
current_tensor->set_device_address(device_address); current_tensor->set_device_address(device_address);
current_tensor->set_sync_status(kNeedSyncHostToDevice); current_tensor->set_sync_status(kNeedSyncHostToDevice);
@ -287,6 +288,7 @@ void KernelRuntime::ResetNodeAddress(const session::KernelGraph &kernel_graph) {
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, index); auto tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, index);
auto device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, index), auto device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, index),
output_type_id, {input_node, index}); output_type_id, {input_node, index});
device_address->set_from_persistent_mem(input_node->isa<Parameter>());
AnfAlgo::SetOutputAddr(device_address, index, input_node.get()); AnfAlgo::SetOutputAddr(device_address, index, input_node.get());
} }
@ -306,7 +308,7 @@ void KernelRuntime::ResetNodeAddress(const session::KernelGraph &kernel_graph) {
} }
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
const session::KernelGraph &graph, const session::KernelGraph &graph, bool is_gradient_out,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) { const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
MS_EXCEPTION_IF_NULL(mem_manager_); MS_EXCEPTION_IF_NULL(mem_manager_);
mem_manager_->ResetDynamicMemory(); mem_manager_->ResetDynamicMemory();
@ -319,7 +321,7 @@ void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &inpu
RunOpAssignInputMemory(input_tensors, graph); RunOpAssignInputMemory(input_tensors, graph);
AssignStaticMemoryValueNode(graph); AssignStaticMemoryValueNode(graph);
for (const auto &node : graph.execution_order()) { for (const auto &node : graph.execution_order()) {
RunOpAssignOutputMemory(node, tensor_to_node); RunOpAssignOutputMemory(node, tensor_to_node, is_gradient_out);
RunOpAssignWorkSpaceMemory(node); RunOpAssignWorkSpaceMemory(node);
} }
UpdateRefNodeOutputMem(graph); UpdateRefNodeOutputMem(graph);
@ -398,6 +400,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
auto current_tensor = input_tensors[input_index]; auto current_tensor = input_tensors[input_index];
MS_EXCEPTION_IF_NULL(current_tensor); MS_EXCEPTION_IF_NULL(current_tensor);
auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(current_tensor->device_address()); auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(current_tensor->device_address());
    // The device address has already been created
if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) { if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) {
if (output_address->ptr_ == nullptr) { if (output_address->ptr_ == nullptr) {
if (!mem_manager_->MallocMemFromMemPool(output_address, output_address->size())) { if (!mem_manager_->MallocMemFromMemPool(output_address, output_address->size())) {
@ -413,10 +416,12 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
output_type_id = AnfAlgo::GetOutputInferDataType(item, index); output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
} }
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index); auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
    // Create a new device address
auto device_address = auto device_address =
CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, {item, index}); CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, {item, index});
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
MS_EXCEPTION_IF_NULL(mem_manager_); MS_EXCEPTION_IF_NULL(mem_manager_);
device_address->set_from_persistent_mem(true);
auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size); auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
if (!ret) { if (!ret) {
MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << tensor_size; MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << tensor_size;
@ -426,8 +431,9 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
} }
} }
void KernelRuntime::RunOpAssignOutputMemory( void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel,
const AnfNodePtr &kernel, const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) { const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
bool is_gradient_out) {
MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(mem_manager_); MS_EXCEPTION_IF_NULL(mem_manager_);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel); auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
@ -464,8 +470,11 @@ void KernelRuntime::RunOpAssignOutputMemory(
std::string output_format = AnfAlgo::GetOutputFormat(kernel, i); std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {kernel, i}); auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {kernel, i});
device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
if (is_gradient_out) {
device_address->set_from_persistent_mem(true);
}
auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]); auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
if (!ret) { if (!ret) {
MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << output_sizes[i]; MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << output_sizes[i];
@ -917,6 +926,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx); auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
DeviceAddressPtr address = DeviceAddressPtr address =
CreateDeviceAddress(nullptr, node_size, output_format, output_type_id, {value_node, output_idx}); CreateDeviceAddress(nullptr, node_size, output_format, output_type_id, {value_node, output_idx});
address->set_from_persistent_mem(true);
MS_EXCEPTION_IF_NULL(address); MS_EXCEPTION_IF_NULL(address);
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) && if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
!mem_manager_->MallocMemFromMemPool(address, node_size)) { !mem_manager_->MallocMemFromMemPool(address, node_size)) {
@ -979,6 +989,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(const session::KernelGraph &grap
ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode; ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
auto address = CreateDeviceAddressForStringValue(node_value, use_mem_from_memory_pool, graph.graph_id()); auto address = CreateDeviceAddressForStringValue(node_value, use_mem_from_memory_pool, graph.graph_id());
MS_EXCEPTION_IF_NULL(address); MS_EXCEPTION_IF_NULL(address);
address->set_from_persistent_mem(true);
AnfAlgo::SetOutputAddr(address, 0, value_node.get()); AnfAlgo::SetOutputAddr(address, 0, value_node.get());
} }
} }
@ -991,6 +1002,7 @@ DeviceAddressPtr KernelRuntime::CreateDeviceAddressForStringValue(const ValuePtr
size_t tensor_size = value_string.size(); size_t tensor_size = value_string.size();
DeviceAddressPtr address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); DeviceAddressPtr address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
MS_EXCEPTION_IF_NULL(address); MS_EXCEPTION_IF_NULL(address);
address->set_from_persistent_mem(true);
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (use_mem_pool && !mem_manager_->MallocMemFromMemPool(address, tensor_size)) { if (use_mem_pool && !mem_manager_->MallocMemFromMemPool(address, tensor_size)) {

View File

@ -56,6 +56,7 @@ class KernelRuntime {
virtual bool Init() = 0; virtual bool Init() = 0;
virtual void AssignMemory(const session::KernelGraph &graph); virtual void AssignMemory(const session::KernelGraph &graph);
void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph &graph, void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph &graph,
bool is_gradient_out,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node = {}); const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node = {});
void AssignCommunicationOutputFromMemoryPool(const AnfNodePtr &node) const; void AssignCommunicationOutputFromMemoryPool(const AnfNodePtr &node) const;
void AssignCommunicationInputFromMemoryPool(const AnfNodePtr &node) const; void AssignCommunicationInputFromMemoryPool(const AnfNodePtr &node) const;
@ -173,7 +174,8 @@ class KernelRuntime {
const std::shared_ptr<MemScheduler> &mem_schedule = nullptr); const std::shared_ptr<MemScheduler> &mem_schedule = nullptr);
void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph &graph); void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph &graph);
void RunOpAssignOutputMemory(const AnfNodePtr &kernel, void RunOpAssignOutputMemory(const AnfNodePtr &kernel,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node = {}); const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
bool is_gradient_out);
void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel); void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel);
void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, const session::KernelGraph &graph); void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, const session::KernelGraph &graph);
void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx); void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);

View File

@ -141,9 +141,9 @@ uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
return nullptr; return nullptr;
} }
bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr address, size_t size) { bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr &address, size_t size) {
MS_EXCEPTION_IF_NULL(address); MS_EXCEPTION_IF_NULL(address);
auto device_ptr = MallocMemFromMemPool(size); auto device_ptr = MallocMemFromMemPool(size, address->from_persistent_mem_);
if (!device_ptr) { if (!device_ptr) {
return false; return false;
} }
@ -154,7 +154,7 @@ bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr address, size_t
return true; return true;
} }
void *MemoryManager::MallocMemFromMemPool(size_t size) { void *MemoryManager::MallocMemFromMemPool(size_t size, bool from_persistent_mem) {
if (size == 0) { if (size == 0) {
MS_LOG(ERROR) << "MallocMemFromMemPool size is 0."; MS_LOG(ERROR) << "MallocMemFromMemPool size is 0.";
} }

View File

@ -49,8 +49,10 @@ class MemoryManager : public MemHandler {
virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address, virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address,
uint32_t graph_id = kInvalidGraphId); uint32_t graph_id = kInvalidGraphId);
virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size); // The address parameter is the device-specific address type
virtual void *MallocMemFromMemPool(size_t size); // from_persistent_mem indicates whether the tensor is a parameter (persistent memory) in PyNative mode
virtual bool MallocMemFromMemPool(const DeviceAddressPtr &address, size_t size);
virtual void *MallocMemFromMemPool(size_t size, bool from_persistent_mem);
virtual uint8_t *MallocCommunicationMemFromMemPool(size_t size) { return nullptr; } virtual uint8_t *MallocCommunicationMemFromMemPool(size_t size) { return nullptr; }
virtual void FreeMemFromMemPool(const DeviceAddressPtr address); virtual void FreeMemFromMemPool(const DeviceAddressPtr address);
virtual void FreeMemFromMemPool(void *device_ptr); virtual void FreeMemFromMemPool(void *device_ptr);
@ -62,7 +64,7 @@ class MemoryManager : public MemHandler {
static size_t GetCommunicationAlignSize(size_t input_size); static size_t GetCommunicationAlignSize(size_t input_size);
// swap manager interface // swap manager interface
void *MallocDevice(size_t mem_size) override { return MallocMemFromMemPool(mem_size); } void *MallocDevice(size_t mem_size) override { return MallocMemFromMemPool(mem_size, false); }
void FreeDevice(void *ptr) override { void FreeDevice(void *ptr) override {
MS_EXCEPTION_IF_NULL(ptr); MS_EXCEPTION_IF_NULL(ptr);
FreeMemFromMemPool(ptr); FreeMemFromMemPool(ptr);

View File

@ -607,6 +607,7 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
(device_tensor->DeviceType() != device_context->GetDeviceAddressType())) { (device_tensor->DeviceType() != device_context->GetDeviceAddressType())) {
host_tensor_address = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(), host_tensor_address = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(),
device_tensor->format(), device_tensor->type_id()); device_tensor->format(), device_tensor->type_id());
host_tensor_address->set_from_persistent_mem(tensor->is_parameter());
} else { } else {
host_tensor_address = device_tensor; host_tensor_address = device_tensor;
} }

View File

@ -100,6 +100,7 @@ void CreateParameterDeviceAddress(const DeviceContext *device_context, const Ker
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index); size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size, auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size,
AnfAlgo::GetOutputFormat(item, index), output_type_id); AnfAlgo::GetOutputFormat(item, index), output_type_id);
device_address->set_from_persistent_mem(item->isa<Parameter>());
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(item) << " addr:" << device_address; MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(item) << " addr:" << device_address;
AnfAlgo::SetOutputAddr(device_address, index, item.get()); AnfAlgo::SetOutputAddr(device_address, index, item.get());
} }
@ -143,6 +144,7 @@ void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, cons
device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id); device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address; MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
MS_EXCEPTION_IF_NULL(address); MS_EXCEPTION_IF_NULL(address);
address->set_from_persistent_mem(true);
AnfAlgo::SetOutputAddr(address, output_idx++, value_node.get()); AnfAlgo::SetOutputAddr(address, output_idx++, value_node.get());
} }
} }
@ -165,6 +167,7 @@ void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const Ker
size_t tensor_size = value.size(); size_t tensor_size = value.size();
auto address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); auto address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
MS_EXCEPTION_IF_NULL(address); MS_EXCEPTION_IF_NULL(address);
address->set_from_persistent_mem(true);
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address; MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
AnfAlgo::SetOutputAddr(address, 0, value_node.get()); AnfAlgo::SetOutputAddr(address, 0, value_node.get());
@ -172,7 +175,8 @@ void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const Ker
} }
} }
void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) { void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph,
bool is_gradient_out) {
MS_EXCEPTION_IF_NULL(device_context); MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
const std::vector<CNodePtr> &kernels = graph->execution_order(); const std::vector<CNodePtr> &kernels = graph->execution_order();
@ -191,6 +195,9 @@ void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
auto address_size = AnfAlgo::GetOutputTensorMemSize(kernel, i); auto address_size = AnfAlgo::GetOutputTensorMemSize(kernel, i);
auto device_address = device_context->CreateDeviceAddress(nullptr, address_size, output_format, output_type); auto device_address = device_context->CreateDeviceAddress(nullptr, address_size, output_format, output_type);
if (is_gradient_out) {
device_address->set_from_persistent_mem(true);
}
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(kernel) << " addr:" << device_address; MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(kernel) << " addr:" << device_address;
AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
} }
@ -434,7 +441,7 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
device_context->PreprocessBeforeRunGraph(graph); device_context->PreprocessBeforeRunGraph(graph);
// Create device address for all anf nodes of graph. // Create device address for all anf nodes of graph.
CreateDeviceAddress(graph, device_context); CreateDeviceAddress(graph, device_context, false);
graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get())); graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
@ -498,7 +505,7 @@ GraphId GraphCompiler::CompileGraph(const session::OpRunInfo &op_run_info, bool
device_context->OptimizeSingleOpGraph(graph); device_context->OptimizeSingleOpGraph(graph);
// Create device address for all anf nodes of graph. // Create device address for all anf nodes of graph.
CreateDeviceAddressWithoutWorkspace(graph, device_context); CreateDeviceAddressWithoutWorkspace(graph, device_context, op_run_info.is_gradient_out);
graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get())); graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
run_op_graphs_[op_run_info.graph_info] = graph; run_op_graphs_[op_run_info.graph_info] = graph;
@ -546,20 +553,22 @@ KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const {
return iter->second; return iter->second;
} }
void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) const { void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context,
bool is_gradient_out) const {
CreateParameterDeviceAddress(device_context, graph); CreateParameterDeviceAddress(device_context, graph);
CreateValueNodeDeviceAddress(device_context, graph); CreateValueNodeDeviceAddress(device_context, graph);
CreateKernelOutputDeviceAddress(device_context, graph); CreateKernelOutputDeviceAddress(device_context, graph, is_gradient_out);
CreateKernelWorkspaceDeviceAddress(device_context, graph); CreateKernelWorkspaceDeviceAddress(device_context, graph);
UpdateDeviceAddressForInplaceNode(graph); UpdateDeviceAddressForInplaceNode(graph);
UpdateDeviceAddressForRefNode(graph); UpdateDeviceAddressForRefNode(graph);
} }
void GraphCompiler::CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph, void GraphCompiler::CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph,
const DeviceContext *device_context) const { const DeviceContext *device_context,
bool is_gradient_out) const {
CreateParameterDeviceAddress(device_context, graph); CreateParameterDeviceAddress(device_context, graph);
CreateValueNodeDeviceAddress(device_context, graph); CreateValueNodeDeviceAddress(device_context, graph);
CreateKernelOutputDeviceAddress(device_context, graph); CreateKernelOutputDeviceAddress(device_context, graph, is_gradient_out);
UpdateDeviceAddressForInplaceNode(graph); UpdateDeviceAddressForInplaceNode(graph);
UpdateDeviceAddressForRefNode(graph); UpdateDeviceAddressForRefNode(graph);
} }
@ -593,24 +602,27 @@ TensorPtr GraphCompiler::GetSingleOpInputTensorByIndex(const CNodePtr &kernel,
} }
void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const InputTensorInfo &tensor_info, void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const InputTensorInfo &tensor_info,
OpRunInfo *run_info, GraphInfo *graph_info) { OpRunInfo *run_info, GraphInfo *graph_info,
GraphOutputInfo *const graph_output_info) {
MS_EXCEPTION_IF_NULL(session_); MS_EXCEPTION_IF_NULL(session_);
MS_EXCEPTION_IF_NULL(graph_info); MS_EXCEPTION_IF_NULL(graph_info);
*graph_info = session_->GetSingleOpGraphInfo(kernel, tensor_info.input_tensors); *graph_info = session_->GetSingleOpGraphInfo(kernel, tensor_info.input_tensors);
*run_info = session_->GetSingleOpRunInfo(kernel, *graph_info, tensor_info); *run_info = session_->GetSingleOpRunInfo(kernel, *graph_info, tensor_info, graph_output_info);
} }
void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const { void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count,
std::map<AnfNodePtr, size_t> *forward_output_refcount) const {
MS_EXCEPTION_IF_NULL(session_); MS_EXCEPTION_IF_NULL(session_);
session_->GetRefCount(graph.get(), ref_count); session_->GetRefCount(graph.get(), ref_count);
session_->GetForwardOutputRefCount(graph.get(), forward_output_refcount);
} }
void GraphCompiler::UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index, void GraphCompiler::UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index,
std::map<KernelWithIndex, size_t> *ref_count, std::map<KernelWithIndex, size_t> *ref_count,
std::map<AnfNodePtr, size_t> *forward_output_refcount,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) const { std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) const {
MS_EXCEPTION_IF_NULL(session_); MS_EXCEPTION_IF_NULL(session_);
session_->HandleOpInputs(input_kernels_with_index, ref_count, op_output_map); session_->HandleOpInputs(input_kernels_with_index, ref_count, forward_output_refcount, op_output_map);
} }
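Besides the per-kernel ref counts, the compiler now also counts how many bprop ops still read each forward output, so a forward tensor's device memory can be dropped as soon as its last consumer has executed. A rough Python model of that bookkeeping (function and variable names here are placeholders, not the framework API):

# Sketch only: forward-output ref counting during PyNative back propagation.
from collections import Counter

def build_forward_output_refcount(bprop_consumers):
    """bprop_consumers: iterable of (op_name, forward_outputs_it_reads) pairs."""
    refcount = Counter()
    for _op, forward_outputs in bprop_consumers:
        refcount.update(forward_outputs)
    return refcount

def handle_op_inputs(consumed, refcount, device_memory):
    """Decrement counts for consumed forward outputs; free them at zero."""
    for tensor in consumed:
        if tensor not in refcount:
            continue
        refcount[tensor] -= 1
        if refcount[tensor] == 0:
            device_memory.pop(tensor, None)  # stand-in for freeing device memory
            del refcount[tensor]

rc = build_forward_output_refcount([("grad_mul", ["x", "y"]), ("grad_add", ["x"])])
mem = {"x": object(), "y": object()}
handle_op_inputs(["x", "y"], rc, mem)  # "y" freed, "x" still has one consumer left
handle_op_inputs(["x"], rc, mem)       # "x" freed after its last consumer runs
print(rc, sorted(mem))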
void GraphCompiler::RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs, void GraphCompiler::RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs,

View File

@ -139,14 +139,16 @@ class GraphCompiler {
// Get OpRunInfo and GraphInfo for single op compile and run. // Get OpRunInfo and GraphInfo for single op compile and run.
void GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const InputTensorInfo &tensor_info, OpRunInfo *run_info, void GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const InputTensorInfo &tensor_info, OpRunInfo *run_info,
GraphInfo *graph_info); GraphInfo *graph_info, GraphOutputInfo *const graph_output_info);
// Calculate ref count of PyNative back propagation operators. // Calculate ref count of PyNative back propagation operators.
void CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const; void CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count,
std::map<AnfNodePtr, size_t> *forward_output_refcount) const;
// Update ref count of PyNative back propagation operators. // Update ref count of PyNative back propagation operators.
void UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index, void UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index,
std::map<KernelWithIndex, size_t> *ref_count, std::map<KernelWithIndex, size_t> *ref_count,
std::map<AnfNodePtr, size_t> *forward_output_refcount,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) const; std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) const;
// Handle single op output tensor and recover output of original complete kernel graph. // Handle single op output tensor and recover output of original complete kernel graph.
@ -182,10 +184,12 @@ class GraphCompiler {
GraphId CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const; GraphId CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const;
// Create device address for all anf nodes of graph. // Create device address for all anf nodes of graph.
void CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) const; void CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context,
bool is_gradient_out) const;
// Create device address for input and output of ops. // Create device address for input and output of ops.
void CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph, const DeviceContext *device_context) const; void CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph, const DeviceContext *device_context,
bool is_gradient_out) const;
// Set Graph's dependencies for pre_graph and post_graph. // Set Graph's dependencies for pre_graph and post_graph.
void SetGraphDependency(const KernelGraphPtr &graph, const GraphSegmentPtr &segment) const; void SetGraphDependency(const KernelGraphPtr &graph, const GraphSegmentPtr &segment) const;

View File

@ -1763,6 +1763,7 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
auto other_type_device_tensor = device_context->CreateDeviceAddress( auto other_type_device_tensor = device_context->CreateDeviceAddress(
nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id()); nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id());
other_type_device_tensor->SetNodeIndex(input_node, 0); other_type_device_tensor->SetNodeIndex(input_node, 0);
other_type_device_tensor->set_from_persistent_mem(input_node->isa<Parameter>());
AddDeviceTensorStore(front_node.get(), other_type_device_tensor); AddDeviceTensorStore(front_node.get(), other_type_device_tensor);
} }
} }

View File

@ -511,7 +511,7 @@ bool AscendDeviceContext::AllocateMemory(DeviceAddress *const &address, size_t s
MS_EXCEPTION_IF_NULL(address); MS_EXCEPTION_IF_NULL(address);
MS_EXCEPTION_IF_NULL(runtime_instance_); MS_EXCEPTION_IF_NULL(runtime_instance_);
runtime_instance_->SetContext(); runtime_instance_->SetContext();
auto device_ptr = mem_manager_->MallocMemFromMemPool(size); auto device_ptr = mem_manager_->MallocMemFromMemPool(size, address->from_persistent_mem_);
if (!device_ptr) { if (!device_ptr) {
return false; return false;
} }

View File

@ -89,7 +89,7 @@ bool CPUDeviceContext::AllocateMemory(DeviceAddress *const &address, size_t size
MS_LOG(EXCEPTION) << "The device address type is wrong: " << address->DeviceType(); MS_LOG(EXCEPTION) << "The device address type is wrong: " << address->DeviceType();
} }
auto device_ptr = mem_manager_->MallocMemFromMemPool(size); auto device_ptr = mem_manager_->MallocMemFromMemPool(size, false);
if (!device_ptr) { if (!device_ptr) {
return false; return false;
} }

View File

@ -173,7 +173,7 @@ bool GPUDeviceContext::AllocateMemory(DeviceAddress *const &address, size_t size
if (!BindDeviceToCurrentThread()) { if (!BindDeviceToCurrentThread()) {
return false; return false;
} }
auto device_ptr = mem_manager_->MallocMemFromMemPool(size); auto device_ptr = mem_manager_->MallocMemFromMemPool(size, address->from_persistent_mem_);
if (!device_ptr) { if (!device_ptr) {
return false; return false;
} }
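All three device backends now thread the address's from_persistent_mem_ flag into the pool allocator, which lets persistent data (parameters, gradient outputs) and ordinary op outputs be placed in separate regions. The toy pool below illustrates that split and the idea of growing each region in mempool_block_size steps; it is an illustration, not the actual DynamicMemPoolBestFit implementation.

# Sketch only: a toy two-region pool keyed by the persistent flag.
GB = 1 << 30

class ToyMemPool:
    def __init__(self, block_size_bytes=1 * GB):
        self.block_size = block_size_bytes
        # Bytes used in each block, tracked separately per region.
        self.blocks = {"persistent": [], "common": []}

    def malloc_from_pool(self, size, from_persistent_mem):
        region = "persistent" if from_persistent_mem else "common"
        blocks = self.blocks[region]
        # Reuse the last block while the request still fits, otherwise open a new one.
        if not blocks or blocks[-1] + size > self.block_size:
            blocks.append(0)
        blocks[-1] += size
        return region, len(blocks) - 1  # fake "pointer": (region, block index)

pool = ToyMemPool()
print(pool.malloc_from_pool(256 * (1 << 20), from_persistent_mem=True))   # e.g. a gradient output
print(pool.malloc_from_pool(512 * (1 << 20), from_persistent_mem=False))  # an ordinary op output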

View File

@ -234,7 +234,7 @@ void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, Vec
} }
} }
void ClearGraphDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) { void ClearGraphDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context, bool is_gradient_out) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
for (const auto &node : graph->execution_order()) { for (const auto &node : graph->execution_order()) {
auto output_address_num = AnfAlgo::GetOutputAddressNum(node); auto output_address_num = AnfAlgo::GetOutputAddressNum(node);
@ -252,6 +252,9 @@ void ClearGraphDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *d
MS_EXCEPTION_IF_NULL(new_device_address); MS_EXCEPTION_IF_NULL(new_device_address);
new_device_address->set_original_ref_count(device_address->original_ref_count()); new_device_address->set_original_ref_count(device_address->original_ref_count());
new_device_address->ResetRefCount(); new_device_address->ResetRefCount();
if (is_gradient_out) {
new_device_address->set_from_persistent_mem(true);
}
AnfAlgo::SetOutputAddr(new_device_address, i, node.get()); AnfAlgo::SetOutputAddr(new_device_address, i, node.get());
} }
} }
@ -800,9 +803,10 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
&graph_output_info.output_indexes); &graph_output_info.output_indexes);
std::map<KernelWithIndex, size_t> cnode_ref_count; std::map<KernelWithIndex, size_t> cnode_ref_count;
std::map<AnfNodePtr, size_t> forward_output_refcount;
auto iter = cnode_ref_counts_.find(graph->graph_id()); auto iter = cnode_ref_counts_.find(graph->graph_id());
if (iter == cnode_ref_counts_.end()) { if (iter == cnode_ref_counts_.end()) {
graph_compiler_->CalculateRefCount(graph, &cnode_ref_count); graph_compiler_->CalculateRefCount(graph, &cnode_ref_count, &forward_output_refcount);
(void)cnode_ref_counts_.emplace(graph->graph_id(), cnode_ref_count); (void)cnode_ref_counts_.emplace(graph->graph_id(), cnode_ref_count);
} else { } else {
cnode_ref_count = iter->second; cnode_ref_count = iter->second;
@ -822,7 +826,8 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
GraphInfo graph_info; GraphInfo graph_info;
graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index], graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index],
&input_tensor_info); &input_tensor_info);
graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info, &op_run_info, &graph_info); graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info, &op_run_info, &graph_info,
&graph_output_info);
RunOp(&op_run_info, &op_outputs); RunOp(&op_run_info, &op_outputs);
} else { } else {
@ -832,7 +837,8 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
runtime::OpLazyBuilder::GetInstance().ExecuteRemainingTasks(); runtime::OpLazyBuilder::GetInstance().ExecuteRemainingTasks();
} }
graph_compiler_->UpdateRefCount(input_tensor_info.input_kernel, &cnode_ref_count, &op_output_map); graph_compiler_->UpdateRefCount(input_tensor_info.input_kernel, &cnode_ref_count, &forward_output_refcount,
&op_output_map);
graph_output_info.graph_output_tensors.clear(); graph_output_info.graph_output_tensors.clear();
graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info); graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);
@ -1246,7 +1252,8 @@ void MindRTBackend::LazyExecuteTaskCallback() {
const auto &context = op_run_task->context(); const auto &context = op_run_task->context();
RunSingleOpGraph(context->graph(), context->output_nodes(), context->op_run_info(), RunSingleOpGraph(context->graph(), context->output_nodes(), context->op_run_info(),
context->graph_compiler_info(), context->device_context()); context->graph_compiler_info(), context->device_context());
ClearGraphDeviceAddress(context->graph(), context->device_context()); ClearGraphDeviceAddress(context->graph(), context->device_context(), false);
UpdateInputDeviceAddress(context->graph()); UpdateInputDeviceAddress(context->graph());
op_lazy_builder.PopOpRunTask(); op_lazy_builder.PopOpRunTask();
@ -1304,7 +1311,7 @@ void MindRTBackend::RunOpInternal(bool single_op_cache_hit, GraphCompilerInfo *g
} }
RunSingleOpGraph(graph, output_nodes, *op_run_info, graph_compiler_info, device_context); RunSingleOpGraph(graph, output_nodes, *op_run_info, graph_compiler_info, device_context);
UpdateOutput(output_nodes, outputs); UpdateOutput(output_nodes, outputs);
ClearGraphDeviceAddress(graph, device_context); ClearGraphDeviceAddress(graph, device_context, false);
UpdateInputDeviceAddress(graph); UpdateInputDeviceAddress(graph);
if (op_run_info->is_dynamic_shape) { if (op_run_info->is_dynamic_shape) {
UpdateOutputAbstract(graph, op_run_info); UpdateOutputAbstract(graph, op_run_info);

View File

@ -273,6 +273,14 @@ class _Context:
raise ValueError("For 'context.set_context', the argument 'max_device_memory' should not be \"0GB\".") raise ValueError("For 'context.set_context', the argument 'max_device_memory' should not be \"0GB\".")
self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value) self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value)
def set_mempool_block_size(self, mempool_block_size):
    """Set the device memory pool block size for PyNative mode."""
if not Validator.check_str_by_regular(mempool_block_size, _re_pattern):
raise ValueError("Context param mempool_block_size should be in correct format! Such as \"10GB\"")
mempool_block_size_value = float(mempool_block_size[:-2])
if mempool_block_size_value < 1.0:
raise ValueError("Context param mempool_block_size should be greater or equal to \"1GB\"")
self.set_param(ms_ctx_param.mempool_block_size, mempool_block_size_value)
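Given the two checks above, only "xxGB" strings of at least one gigabyte get through; anything smaller or in another unit raises before the value reaches the backend. For example (assuming a GPU or Ascend build with this change applied):

from mindspore import context
context.set_context(mempool_block_size="2GB")      # accepted: "xxGB" format, >= 1GB
# context.set_context(mempool_block_size="0.5GB")  # rejected by the minimum-size check above
# context.set_context(mempool_block_size="512MB")  # rejected: not in "xxGB" format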
def set_print_file_path(self, file_path): def set_print_file_path(self, file_path):
"""Add timestamp suffix to file name. Sets print file path.""" """Add timestamp suffix to file name. Sets print file path."""
print_file_path = os.path.realpath(file_path) print_file_path = os.path.realpath(file_path)
@ -317,6 +325,7 @@ class _Context:
'profiling_options': set_profiling_options, 'profiling_options': set_profiling_options,
'variable_memory_max_size': set_variable_memory_max_size, 'variable_memory_max_size': set_variable_memory_max_size,
'max_device_memory': set_max_device_memory, 'max_device_memory': set_max_device_memory,
'mempool_block_size': set_mempool_block_size,
'print_file_path': set_print_file_path, 'print_file_path': set_print_file_path,
'env_config_path': set_env_config_path 'env_config_path': set_env_config_path
} }
@ -570,7 +579,8 @@ def _check_target_specific_cfgs(device, arg_key):
'print_file_path': ['Ascend'], 'print_file_path': ['Ascend'],
'variable_memory_max_size': ['Ascend'], 'variable_memory_max_size': ['Ascend'],
'auto_tune_mode': ['Ascend'], 'auto_tune_mode': ['Ascend'],
'max_device_memory': ['GPU'] 'max_device_memory': ['GPU'],
'mempool_block_size': ['GPU', 'Ascend']
} }
# configs not in map device_cfgs are supposed to be suitable for all devices # configs not in map device_cfgs are supposed to be suitable for all devices
if not arg_key in device_cfgs: if not arg_key in device_cfgs:
@ -583,15 +593,15 @@ def _check_target_specific_cfgs(device, arg_key):
return False return False
@args_unreset_check(device_id=int, variable_memory_max_size=str, max_device_memory=str) @args_unreset_check(device_id=int, variable_memory_max_size=str, max_device_memory=str, mempool_block_size=str)
@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool, @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
save_graphs_path=str, enable_dump=bool, auto_tune_mode=str, save_graphs_path=str, enable_dump=bool, auto_tune_mode=str,
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
enable_graph_kernel=bool, reserve_class_name_in_scope=bool, check_bprop=bool, enable_graph_kernel=bool, reserve_class_name_in_scope=bool, check_bprop=bool,
max_device_memory=str, print_file_path=str, enable_sparse=bool, max_call_depth=int, max_device_memory=str, print_file_path=str, enable_sparse=bool, max_call_depth=int,
env_config_path=str, graph_kernel_flags=str, enable_compile_cache=bool, env_config_path=str, graph_kernel_flags=str, save_compile_cache=bool,
compile_cache_path=str, grad_for_scalar=bool, pynative_synchronize=bool) load_compile_cache=bool, grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str)
def set_context(**kwargs): def set_context(**kwargs):
""" """
Set context for running environment. Set context for running environment.
@ -616,6 +626,8 @@ def set_context(**kwargs):
| | max_device_memory | GPU | | | max_device_memory | GPU |
| +------------------------------+----------------------------+ | +------------------------------+----------------------------+
| | variable_memory_max_size | Ascend | | | variable_memory_max_size | Ascend |
| +------------------------------+----------------------------+
| | mempool_block_size | GPU/Ascend |
+-------------------------+------------------------------+----------------------------+ +-------------------------+------------------------------+----------------------------+
| Debug Configuration | save_graphs | CPU/GPU/Ascend | | Debug Configuration | save_graphs | CPU/GPU/Ascend |
| +------------------------------+----------------------------+ | +------------------------------+----------------------------+
@ -672,6 +684,9 @@ def set_context(**kwargs):
The actual used memory size is the minimum of the available memory of the device and max_device_memory. The actual used memory size is the minimum of the available memory of the device and max_device_memory.
variable_memory_max_size (str): Set the maximum size of the variable memory max size. Default: "30GB". variable_memory_max_size (str): Set the maximum size of the variable memory max size. Default: "30GB".
After this parameter is set, the maximum memory used by the framework is restricted to the configured value. After this parameter is set, the maximum memory used by the framework is restricted to the configured value.
mempool_block_size (str): Set the size of the memory pool block in PyNative mode for devices.
The format is "xxGB". Default: "1GB". Minimum size is "1G". The actual used memory block size is the minimum
of the available memory of the device and mempool_block_size.
save_graphs (bool): Whether to save graphs. Default: False. save_graphs (bool): Whether to save graphs. Default: False.
When the `save_graphs` attribute is set as True, attribute of `save_graphs_path` is used to set the When the `save_graphs` attribute is set as True, attribute of `save_graphs_path` is used to set the
intermediate compilation graph storage path. By default, the graphs are saved in the current directory. intermediate compilation graph storage path. By default, the graphs are saved in the current directory.
@ -813,6 +828,7 @@ def set_context(**kwargs):
... profiling_options='{"output":"/home/data/output","training_trace":"on"}') ... profiling_options='{"output":"/home/data/output","training_trace":"on"}')
>>> context.set_context(check_bprop=True) >>> context.set_context(check_bprop=True)
>>> context.set_context(max_device_memory="3.5GB") >>> context.set_context(max_device_memory="3.5GB")
>>> context.set_context(mempool_block_size="1GB")
>>> context.set_context(print_file_path="print.pb") >>> context.set_context(print_file_path="print.pb")
>>> context.set_context(enable_sparse=True) >>> context.set_context(enable_sparse=True)
>>> context.set_context(max_call_depth=80) >>> context.set_context(max_call_depth=80)
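The "minimum of the available memory of the device and mempool_block_size" rule quoted above is just a clamp; a sketch of the arithmetic with made-up byte counts (no device is queried here):

# Sketch only: clamping the configured block size to what the device has free.
GB = 1 << 30

def effective_block_size(mempool_block_size_gb, available_device_bytes):
    return min(int(mempool_block_size_gb * GB), available_device_bytes)

# A 1GB block on a device with only 0.75GB free ends up as a 0.75GB block.
print(effective_block_size(1.0, int(0.75 * GB)) / GB)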

View File

@ -86,6 +86,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
set_param<std::string>(MS_CTX_PROFILING_OPTIONS, "training_trace"); set_param<std::string>(MS_CTX_PROFILING_OPTIONS, "training_trace");
set_param<bool>(MS_CTX_CHECK_BPROP_FLAG, false); set_param<bool>(MS_CTX_CHECK_BPROP_FLAG, false);
set_param<float>(MS_CTX_MAX_DEVICE_MEMORY, kDefaultMaxDeviceMemory); set_param<float>(MS_CTX_MAX_DEVICE_MEMORY, kDefaultMaxDeviceMemory);
set_param<float>(MS_CTX_MEMPOOL_BLOCK_SIZE, kDefaultMempoolBlockSize);
set_param<std::string>(MS_CTX_PRINT_FILE_PATH, ""); set_param<std::string>(MS_CTX_PRINT_FILE_PATH, "");
set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false); set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
set_param<bool>(MS_CTX_ENABLE_SPARSE, false); set_param<bool>(MS_CTX_ENABLE_SPARSE, false);
@ -189,5 +190,4 @@ bool MsContext::enable_dump_ir() const {
return false; return false;
#endif #endif
} }
} // namespace mindspore } // namespace mindspore

View File

@ -60,6 +60,8 @@ const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000;
const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice}; const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice};
// The default max available device memory is 1024GB. // The default max available device memory is 1024GB.
const float kDefaultMaxDeviceMemory = 1024; const float kDefaultMaxDeviceMemory = 1024;
// The default memory pool block size is 1.0GB.
const float kDefaultMempoolBlockSize = 1.0;
// enum definition for MindSpore Context Parameter // enum definition for MindSpore Context Parameter
enum MsCtxParam : unsigned { enum MsCtxParam : unsigned {
@ -109,6 +111,7 @@ enum MsCtxParam : unsigned {
// parameter of type float // parameter of type float
MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END, MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END,
MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN, MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN,
MS_CTX_MEMPOOL_BLOCK_SIZE,
MS_CTX_TYPE_FLOAT_END, MS_CTX_TYPE_FLOAT_END,
// parameter of type string // parameter of type string