forked from mindspore-Ecosystem/mindspore

Optimize pynative device memory use

Add gradient to pynative unique

Signed-off-by: zjun <zhangjun0@huawei.com>

parent 4711194524
commit eb450dd31f

@@ -24,6 +24,8 @@ MindSpore context, used to configure the current execution environment, including the execution mode, ...
 |                         | max_device_memory            | GPU                        |
 |                         +------------------------------+----------------------------+
 |                         | variable_memory_max_size     | Ascend                     |
+|                         +------------------------------+----------------------------+
+|                         | mempool_block_size           | GPU/Ascend                 |
 +-------------------------+------------------------------+----------------------------+
 | Debug configuration     | save_graphs                  | CPU/GPU/Ascend             |
 +-------------------------+------------------------------+----------------------------+
@@ -76,6 +78,7 @@ MindSpore context, used to configure the current execution environment, including the execution mode, ...
 - **device_target** (str) - The target device to run on; Ascend, GPU and CPU are supported. If the device target is not set, the device corresponding to the installed MindSpore package version is used.
 - **max_device_memory** (str) - Sets the maximum memory available to the device. Currently only supported on GPU. The format is "xxGB". Default: 1024GB. The actual memory used is the minimum of the device's available memory and the `max_device_memory` value.
 - **variable_memory_max_size** (str) - Sets the maximum size of variable memory. Default: 30GB. Once this parameter is set, the maximum memory used by the framework is limited to the configured value.
+- **mempool_block_size** (str) - Sets the block size of the device memory pool in PyNative mode. The format is "xxGB". Default: 1GB. The minimum value is 1GB. The actual memory pool block size used is the minimum of the device's available memory and the `mempool_block_size` value.
 - **save_graphs** (bool) - Whether to save graphs. Default: False. When `save_graphs` is set to True, `save_graphs_path` is used to set the storage path of intermediate compiled graphs. By default, graphs are saved in the current directory.
 - **save_graphs_path** (str) - The path for saving graphs. Default: ".". If the specified directory does not exist, it is created automatically. In distributed training, graphs are saved under the `save_graphs_path/rank_${rank_id}/` directory, where `rank_id` is the ID of the current device in the cluster.
 - **enable_dump** (bool) - This parameter is deprecated and will be removed in the next release.
@@ -155,6 +158,7 @@ MindSpore context, used to configure the current execution environment, including the execution mode, ...
 ...                     profiling_options='{"output":"/home/data/output","training_trace":"on"}')
 >>> context.set_context(check_bprop=True)
 >>> context.set_context(max_device_memory="3.5GB")
+>>> context.set_context(mempool_block_size="1GB")
 >>> context.set_context(print_file_path="print.pb")
 >>> context.set_context(enable_sparse=True)
 >>> context.set_context(max_call_depth=80)

@@ -15,39 +15,51 @@
 */
 
 #include "backend/optimizer/mem_reuse/mem_dynamic_allocator.h"
+#include <string>
 #include "utils/ms_utils.h"
 #include "utils/convert_utils.h"
 #include "utils/log_adapter.h"
+#include "utils/ms_context.h"
 
 namespace mindspore {
 namespace device {
-DynamicMemPoolBestFit::~DynamicMemPoolBestFit() {
-  global_mem_block_list_.clear();
-  global_idle_mem_buf_map_.clear();
+static const char kPynativeParamMem[] = "Pynative unique mem";
+static const char kCommonMem[] = "Common mem";
+const size_t kGBToByte = 1024 << 20;
+
+static bool IsPynativeMode() {
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  return ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
 }
 
-DeviceMemPtr DynamicMemPoolBestFit::AllocTensorMem(size_t size) {
+DynamicMemPoolBestFit::~DynamicMemPoolBestFit() {
+  persistent_mem_->clear();
+  common_mem_->clear();
+}
+
+DeviceMemPtr DynamicMemPoolBestFit::AllocTensorMem(size_t size, bool from_persistent_mem) {
   size_t align_size = AlignMemorySize(size);
   std::lock_guard<std::mutex> locker(mutex_);
   // Find the idle memory buf by tensor size, if not find, then add new memory block and memory buf.
-  DeviceMemPtr device_addr = FindIdleMemBuf(align_size);
+  DeviceMemPtr device_addr = FindIdleMemBuf(align_size, from_persistent_mem);
   if (!device_addr) {
-    device_addr = AddMemBlockAndMemBuf(align_size);
+    device_addr = AddMemBlockAndMemBuf(align_size, from_persistent_mem);
   }
   return device_addr;
 }
 
 std::vector<DeviceMemPtr> DynamicMemPoolBestFit::AllocContinuousTensorMem(size_t total_size,
-                                                                          std::vector<size_t> size_list) {
+                                                                          const std::vector<size_t> &size_list) {
   std::vector<DeviceMemPtr> device_addr_list;
   // Pre-alloc the one whole piece memory.
-  auto device_addr = AllocTensorMem(total_size);
+  auto device_addr = AllocTensorMem(total_size, false);
   if (!device_addr) {
     return device_addr_list;
   }
   std::lock_guard<std::mutex> locker(mutex_);
   // Remove the pre-alloc memory.
-  const auto &mem_block = FindMemBlock(device_addr);
+  const auto &mem_block = FindMemBlock(device_addr, common_mem_);
   MS_EXCEPTION_IF_NULL(mem_block);
   const auto &iter = mem_block->block_all_mem_buf_map_.find(device_addr);
   if (iter == mem_block->block_all_mem_buf_map_.end()) {
@@ -63,11 +75,11 @@ std::vector<DeviceMemPtr> DynamicMemPoolBestFit::AllocContinuousTensorMem(size_t
   // Split the pre-alloc memory into continuous memory by the size list.
   DynamicMemBufPtr continuous_mem_buf;
   auto buf_addr = device_addr;
-  for (size_t i = 0; i < size_list.size(); i++) {
-    continuous_mem_buf = std::make_shared<DynamicMemBuf>(buf_addr, kMemBufUsed, size_list[i]);
+  for (size_t i : size_list) {
+    continuous_mem_buf = std::make_shared<DynamicMemBuf>(buf_addr, kMemBufUsed, i);
     (void)mem_block->block_all_mem_buf_map_.emplace(buf_addr, continuous_mem_buf);
     device_addr_list.emplace_back(buf_addr);
-    buf_addr = AddressOffset(buf_addr, size_list[i]);
+    buf_addr = AddressOffset(buf_addr, i);
   }
   // Update the size of the last memory buf.
   continuous_mem_buf->size_ += rest_size;
@@ -81,9 +93,14 @@ size_t DynamicMemPoolBestFit::AlignMemorySize(size_t size) const {
   return ((size + DYNAMIC_MEM_ALIGN_SIZE - 1) / DYNAMIC_MEM_ALIGN_SIZE) * DYNAMIC_MEM_ALIGN_SIZE;
 }
 
-DeviceMemPtr DynamicMemPoolBestFit::FindIdleMemBuf(size_t size) {
-  const auto &iter = global_idle_mem_buf_map_.lower_bound(size);
-  if (iter != global_idle_mem_buf_map_.end()) {
+DeviceMemPtr DynamicMemPoolBestFit::FindIdleMemBuf(size_t size, bool from_persistent_mem) {
+  auto mem_mng = common_mem_;
+  if (from_persistent_mem) {
+    mem_mng = persistent_mem_;
+  }
+  MS_EXCEPTION_IF_NULL(mem_mng);
+  const auto &iter = mem_mng->idle_mem_buf_map_.lower_bound(size);
+  if (iter != mem_mng->idle_mem_buf_map_.end()) {
    auto mem_buf = iter->second;
    MS_EXCEPTION_IF_NULL(mem_buf);
    if (mem_buf->status_ != kMemBufIdle) {
@@ -92,24 +109,69 @@ DeviceMemPtr DynamicMemPoolBestFit::FindIdleMemBuf(size_t size) {
     }
     mem_buf->status_ = kMemBufUsed;
     // Remove map of old idle memory buf
-    (void)global_idle_mem_buf_map_.erase(iter);
+    (void)mem_mng->idle_mem_buf_map_.erase(iter);
     // Divide memory buf
-    if (IsDivide(size, mem_buf->size_)) {
-      DivideMemBuf(size, mem_buf);
+    if (IsSplit(size, mem_buf->size_)) {
+      SplitMemBuf(size, mem_buf, mem_mng);
     }
     // Memory statistics
-    total_used_mem_statistics_ += mem_buf->size_;
-    if (total_used_mem_statistics_ > used_mem_peak_statistics_) {
-      used_mem_peak_statistics_ = total_used_mem_statistics_;
+    mem_mng->mps_.total_used_mem_size_ += mem_buf->size_;
+    if (mem_mng->mps_.total_used_mem_size_ > mem_mng->mps_.used_mem_peak_size_) {
+      mem_mng->mps_.used_mem_peak_size_ = mem_mng->mps_.total_used_mem_size_;
     }
     return mem_buf->device_addr_;
   }
   return nullptr;
 }
 
-DeviceMemPtr DynamicMemPoolBestFit::AddMemBlockAndMemBuf(size_t size) {
-  size_t alloc_mem_size = CalMemBlockAllocSize(size);
+size_t DynamicMemPoolBestFit::MemAllocUnitSize(bool from_persistent_mem) const {
+  return from_persistent_mem ? persistent_mem_->unit_size_ : common_mem_->unit_size_;
+}
+
+void DynamicMemPoolBestFit::SetMemAllocUintSize(size_t size) {
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  persistent_mem_->unit_size_ = DYNAMIC_MEM_ALLOC_UNIT_SIZE;
+  common_mem_->unit_size_ = size - persistent_mem_->unit_size_;
+  MS_LOG(INFO) << "Set mem alloc unit size " << size;
+}
+
+void DynamicMemPoolBestFit::SetMempoolBlockSize(size_t device_mem_size) {
+  if (!IsPynativeMode()) {
+    return;
+  }
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  float mem_block_size = ms_context->get_param<float>(MS_CTX_MEMPOOL_BLOCK_SIZE);
+  if (mem_block_size == kDefaultMempoolBlockSize) {
+    return;
+  }
+  size_t config_size = FloatToSize(mem_block_size * kGBToByte);
+  size_t real_block_size = std::min(config_size, device_mem_size);
+  if (config_size > device_mem_size) {
+    MS_LOG(WARNING) << "Memory pool block size " << config_size << " is bigger than currently available maximum memory "
+                    << device_mem_size << ", and the actual effective value will be " << device_mem_size;
+  }
+  SetMemAllocUintSize(real_block_size);
+}
+
+DeviceMemPtr DynamicMemPoolBestFit::AddMemBlockAndMemBuf(size_t size, bool from_persistent_mem) {
+  // Pynative unique mem is not enough, find from common
+  if (from_persistent_mem && !persistent_mem_->mem_block_list_.empty()) {
+    auto mem_addr = FindIdleMemBuf(size, false);
+    if (mem_addr != nullptr) {
+      return mem_addr;
+    }
+    from_persistent_mem = false;
+  }
+  size_t alloc_mem_size = CalMemBlockAllocSize(size, from_persistent_mem);
   if (alloc_mem_size == 0) {
+    MS_LOG(DEBUG) << "Try to find in other mem";
+    auto mem_addr = FindIdleMemBuf(size, !from_persistent_mem);
+    if (mem_addr != nullptr) {
+      return mem_addr;
+    }
+    DumpDynamicMemPoolInfo();
     return nullptr;
   }
   // Add new memory block
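
The hunk above defines the allocation policy of the split pool: a request first searches the idle list of the pool it asked for, falls back to the other pool before growing, and the PyNative block size configured via the context is capped by the free device memory. The following is a simplified, standalone sketch of that policy only; names such as TwoPoolAllocator are hypothetical and are not part of MindSpore.

// Simplified sketch of the two-pool policy above; all names are hypothetical.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <map>

struct Pool {
  std::size_t unit_size = 1ULL << 30;    // default block size: 1 GB
  std::multimap<std::size_t, int> idle;  // idle buf size -> dummy id
};

class TwoPoolAllocator {
 public:
  explicit TwoPoolAllocator(std::size_t device_free) : device_free_(device_free) {}

  // Counterpart of SetMempoolBlockSize: the configured size is capped by free device memory.
  void SetBlockSize(std::size_t config_size) { common_.unit_size = std::min(config_size, device_free_); }

  // Counterpart of AllocTensorMem/AddMemBlockAndMemBuf: idle buf in the requested pool,
  // then idle buf in the other pool, then a fresh block in the requested pool.
  bool Alloc(std::size_t size, bool from_persistent) {
    Pool &first = from_persistent ? persistent_ : common_;
    Pool &other = from_persistent ? common_ : persistent_;
    if (TakeIdle(&first, size) || TakeIdle(&other, size)) {
      return true;
    }
    std::size_t block = std::max(first.unit_size, size);
    if (block > device_free_) {
      return false;  // device memory exhausted
    }
    device_free_ -= block;
    if (block > size) {
      first.idle.emplace(block - size, 0);  // the remainder becomes an idle buf
    }
    return true;
  }

 private:
  static bool TakeIdle(Pool *pool, std::size_t size) {
    auto it = pool->idle.lower_bound(size);
    if (it == pool->idle.end()) {
      return false;
    }
    pool->idle.erase(it);
    return true;
  }

  Pool persistent_;
  Pool common_;
  std::size_t device_free_;
};

int main() {
  TwoPoolAllocator pool(8ULL << 30);  // pretend 8 GB of free device memory
  pool.SetBlockSize(2ULL << 30);      // corresponds to mempool_block_size="2GB"
  std::cout << pool.Alloc(512ULL << 20, true) << std::endl;  // persistent (parameter/gradient) memory
  std::cout << pool.Alloc(3ULL << 30, false) << std::endl;   // common memory
  return 0;
}
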
@@ -118,40 +180,50 @@ DeviceMemPtr DynamicMemPoolBestFit::AddMemBlockAndMemBuf(size_t size) {
   if (real_alloc_size < size) {
     MS_LOG(WARNING) << "Memory not enough: alloc size[" << real_alloc_size << "] is smaller than required size[" << size
                     << "].";
+    DumpDynamicMemPoolInfo();
     return nullptr;
   }
-  mem_alloc_unit_size_ = DYNAMIC_MEM_ALLOC_UNIT_SIZE;
+  // In graph mode, unit_size is set once using an estimated memory size, and subsequent memory requests use the
+  // default size
+  if (!IsPynativeMode()) {
+    common_mem_->unit_size_ = DYNAMIC_MEM_ALLOC_UNIT_SIZE;
+  }
+  auto mem_mng = common_mem_;
+  if (from_persistent_mem) {
+    mem_mng = persistent_mem_;
+  }
+  MS_EXCEPTION_IF_NULL(mem_mng);
   auto mem_block = std::make_shared<DynamicMemBlock>(device_addr, real_alloc_size);
   MS_EXCEPTION_IF_NULL(mem_block);
   const auto &iter =
-    std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock);
-  (void)global_mem_block_list_.insert(iter, mem_block);
+    std::upper_bound(mem_mng->mem_block_list_.begin(), mem_mng->mem_block_list_.end(), device_addr, CmpMemBlock);
+  (void)mem_mng->mem_block_list_.insert(iter, mem_block);
   // Add new memory buf
   auto mem_buf = std::make_shared<DynamicMemBuf>(device_addr, kMemBufUsed, real_alloc_size);
   MS_EXCEPTION_IF_NULL(mem_buf);
   // Add map of new memory buf in the block
   (void)mem_block->block_all_mem_buf_map_.emplace(device_addr, mem_buf);
-  // Divide memory buf
-  if (IsDivide(size, mem_buf->size_)) {
-    DivideMemBuf(size, mem_buf);
+  // Split memory buf
+  if (IsSplit(size, mem_buf->size_)) {
+    SplitMemBuf(size, mem_buf, mem_mng);
   }
   // Memory statistics
-  total_mem_statistics_ += real_alloc_size;
-  total_used_mem_statistics_ += mem_buf->size_;
-  if (total_used_mem_statistics_ > used_mem_peak_statistics_) {
-    used_mem_peak_statistics_ = total_used_mem_statistics_;
+  mem_mng->mps_.total_mem_size_ += real_alloc_size;
+  mem_mng->mps_.total_used_mem_size_ += mem_buf->size_;
+  if (mem_mng->mps_.total_used_mem_size_ > mem_mng->mps_.used_mem_peak_size_) {
+    mem_mng->mps_.used_mem_peak_size_ = mem_mng->mps_.total_used_mem_size_;
   }
   return mem_buf->device_addr_;
 }
 
-size_t DynamicMemPoolBestFit::CalMemBlockAllocSize(size_t size) {
+size_t DynamicMemPoolBestFit::CalMemBlockAllocSize(size_t size, bool from_persistent_mem) {
   auto device_free_mem_size = free_mem_size();
   if (device_free_mem_size < size) {
     MS_LOG(WARNING) << "Memory not enough: current free memory size[" << device_free_mem_size
                     << "] is smaller than required size[" << size << "].";
     return 0;
   }
-  auto alloc_mem_size = mem_alloc_unit_size();
+  auto alloc_mem_size = MemAllocUnitSize(from_persistent_mem);
   // Growing at twice of alloc size
   constexpr size_t kDouble = 2;
   while (alloc_mem_size < size) {
@@ -161,13 +233,14 @@ size_t DynamicMemPoolBestFit::CalMemBlockAllocSize(size_t size) {
   return alloc_mem_size;
 }
 
-bool DynamicMemPoolBestFit::IsDivide(size_t tensor_size, size_t mem_buf_size) const {
+bool DynamicMemPoolBestFit::IsSplit(size_t tensor_size, size_t mem_buf_size) const {
   return mem_buf_size - tensor_size >= DYNAMIC_MEM_ALIGN_SIZE;
 }
 
-void DynamicMemPoolBestFit::DivideMemBuf(size_t size, const DynamicMemBufPtr &mem_buf) {
+void DynamicMemPoolBestFit::SplitMemBuf(size_t size, const DynamicMemBufPtr &mem_buf,
+                                        const MemStatusManagerPtr &mem_mng) {
   MS_EXCEPTION_IF_NULL(mem_buf);
-  const auto &mem_block = FindMemBlock(mem_buf->device_addr_);
+  const auto &mem_block = FindMemBlock(mem_buf->device_addr_, mem_mng);
   MS_EXCEPTION_IF_NULL(mem_block);
   // Divide new memory buf
   if (mem_buf->size_ < size) {
@@ -180,7 +253,7 @@ void DynamicMemPoolBestFit::DivideMemBuf(size_t size, const DynamicMemBufPtr &me
   // Add map of new memory buf in the block
   (void)mem_block->block_all_mem_buf_map_.emplace(newbuf_addr, new_mem_buf);
   // Add map of new idle memory buf
-  (void)global_idle_mem_buf_map_.emplace(newbuf_size, new_mem_buf);
+  (void)mem_mng->idle_mem_buf_map_.emplace(newbuf_size, new_mem_buf);
 }
 
 bool DynamicMemPoolBestFit::CmpMemBlock(const DeviceMemPtr &device_addr, const DynamicMemBlockPtr &mem_block) {
@@ -189,11 +262,12 @@ bool DynamicMemPoolBestFit::CmpMemBlock(const DeviceMemPtr &device_addr, const D
   return device_addr < mem_block->device_addr();
 }
 
-DynamicMemBlockPtr DynamicMemPoolBestFit::FindMemBlock(const DeviceMemPtr &device_addr) {
+DynamicMemBlockPtr DynamicMemPoolBestFit::FindMemBlock(const DeviceMemPtr &device_addr,
+                                                       const MemStatusManagerPtr &mem_mng) {
   MS_EXCEPTION_IF_NULL(device_addr);
   auto &&iter =
-    std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock);
-  if (iter != global_mem_block_list_.begin()) {
+    std::upper_bound(mem_mng->mem_block_list_.begin(), mem_mng->mem_block_list_.end(), device_addr, CmpMemBlock);
+  if (iter != mem_mng->mem_block_list_.begin()) {
     return *(--iter);
   }
   return nullptr;
@@ -202,16 +276,32 @@ DynamicMemBlockPtr DynamicMemPoolBestFit::FindMemBlock(const DeviceMemPtr &devic
 void DynamicMemPoolBestFit::FreeTensorMem(const DeviceMemPtr &device_addr) {
   MS_EXCEPTION_IF_NULL(device_addr);
   std::lock_guard<std::mutex> locker(mutex_);
-  const auto &mem_block = FindMemBlock(device_addr);
+  auto fn = [this](const MemStatusManagerPtr &mem_mng, const DeviceMemPtr &device_addr) -> DynamicMemBlockPtr {
+    auto mem_block = FindMemBlock(device_addr, mem_mng);
+    if (mem_block != nullptr) {
+      const auto &iter = mem_block->block_all_mem_buf_map_.find(device_addr);
+      if (iter != mem_block->block_all_mem_buf_map_.end()) {
+        return mem_block;
+      }
+    }
+    return nullptr;
+  };
+  auto mem_block = fn(common_mem_, device_addr);
   if (mem_block == nullptr) {
-    // May be destroy the memory pool first, then destroy the address, so this is normal case.
-    MS_LOG(DEBUG) << "Can't find the mem_block of the device address[" << device_addr << "].";
-    return;
+    mem_block = fn(persistent_mem_, device_addr);
+    if (mem_block == nullptr) {
+      // May be destroy the memory pool first, then destroy the address, so this is normal case.
+      MS_LOG(DEBUG) << "Can't find the mem_block of the device address[" << device_addr << "].";
+      return;
+    }
+    CombineMemBuf(mem_block, device_addr, persistent_mem_);
+  } else {
+    CombineMemBuf(mem_block, device_addr, common_mem_);
   }
-  CombineMemBuf(mem_block, device_addr);
 }
 
-void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr &device_addr) {
+void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr &device_addr,
+                                          const MemStatusManagerPtr &mem_mng) {
   MS_EXCEPTION_IF_NULL(mem_block);
   MS_EXCEPTION_IF_NULL(device_addr);
   const auto &iter = mem_block->block_all_mem_buf_map_.find(device_addr);
@@ -224,10 +314,10 @@ void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, c
     MS_LOG(EXCEPTION) << "Find the mem_buf is not used, mem_buf_address[" << mem_buf->device_addr_ << "].";
   }
   mem_buf->status_ = kMemBufIdle;
-  if (total_used_mem_statistics_ < mem_buf->size_) {
+  if (mem_mng->mps_.total_used_mem_size_ < mem_buf->size_) {
     MS_LOG(EXCEPTION) << "The total used mem size is less than the size of membuf.";
   }
-  total_used_mem_statistics_ -= mem_buf->size_;
+  mem_mng->mps_.total_used_mem_size_ -= mem_buf->size_;
   // Combine backward(combine the next_mem_buf to mem_buf)
   auto next_iter = iter;
   (void)next_iter++;
@@ -236,7 +326,7 @@ void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, c
     MS_EXCEPTION_IF_NULL(next_mem_buf);
     if (next_mem_buf->status_ == kMemBufIdle) {
       mem_buf->size_ += next_mem_buf->size_;
-      EraseIdleMemBuf(next_mem_buf->size_, next_mem_buf->device_addr_);
+      EraseIdleMemBuf(next_mem_buf->size_, next_mem_buf->device_addr_, mem_mng);
       (void)mem_block->block_all_mem_buf_map_.erase(next_iter);
     }
   }
@@ -249,7 +339,7 @@ void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, c
     prev_mem_buf = prev_iter->second;
     MS_EXCEPTION_IF_NULL(prev_mem_buf);
     if (prev_mem_buf->status_ == kMemBufIdle) {
-      EraseIdleMemBuf(prev_mem_buf->size_, prev_mem_buf->device_addr_);
+      EraseIdleMemBuf(prev_mem_buf->size_, prev_mem_buf->device_addr_, mem_mng);
       prev_mem_buf->size_ += mem_buf->size_;
       (void)mem_block->block_all_mem_buf_map_.erase(iter);
       forward_combine = true;
@@ -257,20 +347,21 @@ void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, c
   }
   // Add map of new idle memory
   if (forward_combine) {
-    (void)global_idle_mem_buf_map_.emplace(prev_mem_buf->size_, prev_mem_buf);
+    (void)mem_mng->idle_mem_buf_map_.emplace(prev_mem_buf->size_, prev_mem_buf);
   } else {
-    (void)global_idle_mem_buf_map_.emplace(mem_buf->size_, mem_buf);
+    (void)mem_mng->idle_mem_buf_map_.emplace(mem_buf->size_, mem_buf);
   }
 }
 
-void DynamicMemPoolBestFit::EraseIdleMemBuf(size_t size, const DeviceMemPtr &device_addr) {
+void DynamicMemPoolBestFit::EraseIdleMemBuf(size_t size, const DeviceMemPtr &device_addr,
+                                            const MemStatusManagerPtr &mem_mng) {
   MS_EXCEPTION_IF_NULL(device_addr);
-  auto &&iter = global_idle_mem_buf_map_.equal_range(size);
+  auto &&iter = mem_mng->idle_mem_buf_map_.equal_range(size);
   while (iter.first != iter.second) {
     MS_EXCEPTION_IF_NULL(iter.first->second);
     // Remove map of the idle memory buf by size and device address
     if (iter.first->second->device_addr_ == device_addr) {
-      (void)global_idle_mem_buf_map_.erase(iter.first);
+      (void)mem_mng->idle_mem_buf_map_.erase(iter.first);
       return;
     }
     (void)iter.first++;
@@ -280,70 +371,48 @@ void DynamicMemPoolBestFit::EraseIdleMemBuf(size_t size, const DeviceMemPtr &dev
 
 void DynamicMemPoolBestFit::ReleaseDeviceRes() {
   std::lock_guard<std::mutex> locker(mutex_);
-  MS_LOG(INFO) << "The dynamic memory pool total size is " << total_mem_statistics_ << ", total used size is "
-               << total_used_mem_statistics_ << ", used peak size is " << used_mem_peak_statistics_ << ".";
-  for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) {
-    auto &device_addr = (*iter)->device_addr_base_;
-    if (device_addr != nullptr) {
-      if (!FreeDeviceMem(device_addr)) {
-        MS_LOG(EXCEPTION) << "Free device memory[" << device_addr << "] error.";
+  auto fn = [this](const MemStatusManagerPtr &mem_mng) {
+    for (auto &iter : mem_mng->mem_block_list_) {
+      auto &device_addr = iter->device_addr_base_;
+      if (device_addr != nullptr) {
+        if (!FreeDeviceMem(device_addr)) {
+          MS_LOG(EXCEPTION) << "Free device memory[" << device_addr << "] error.";
+        }
+        device_addr = nullptr;
       }
-      device_addr = nullptr;
     }
-  }
-  global_mem_block_list_.clear();
-  global_idle_mem_buf_map_.clear();
+    mem_mng->mem_block_list_.clear();
+    mem_mng->idle_mem_buf_map_.clear();
+  };
+  fn(common_mem_);
+  fn(persistent_mem_);
 }
 
 void DynamicMemPoolBestFit::DumpDynamicMemPoolInfo() {
-  std::lock_guard<std::mutex> locker(mutex_);
-  MS_LOG(INFO) << "Start dump dynamic memory pool info.";
-  DeviceAddrMapMemBuf mem_block_map;
-  DynamicMemBufPtr mem_buf;
-  size_t total_mem = 0;
-  size_t total_used_mem = 0;
-  size_t total_idle_mem1 = 0;
-  size_t total_idle_mem2 = 0;
-  // Dump the memory block info and memory buf info
-  MS_LOG(INFO) << "Dump all mem_block info: counts[" << global_mem_block_list_.size() << "].";
-  for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) {
-    total_mem += (*iter)->size();
-    mem_block_map = (*iter)->block_all_mem_buf_map_;
-    MS_LOG(INFO) << "MemBlock info: number[" << iter - global_mem_block_list_.begin() << "] mem_buf_counts["
-                 << mem_block_map.size() << "] base_address[" << (*iter)->device_addr() << "] block_size["
-                 << (*iter)->size() << "].";
-    for (auto iter_mem_buf = mem_block_map.begin(); iter_mem_buf != mem_block_map.end(); ++iter_mem_buf) {
-      mem_buf = iter_mem_buf->second;
-      MS_EXCEPTION_IF_NULL(mem_buf);
-      if (mem_buf->status_ == kMemBufIdle) {
-        total_idle_mem1 += mem_buf->size_;
-      } else {
-        total_used_mem += mem_buf->size_;
-      }
-      MS_LOG(INFO) << "MemBuf info: address[" << mem_buf->device_addr_ << "] size[" << mem_buf->size_ << "] status["
-                   << mem_buf->status_ << "].";
-    }
-  }
-  // Dump all the idle memory buf info
-  MS_LOG(INFO) << "Dump all idle mem_buf info: counts[" << global_idle_mem_buf_map_.size() << "].";
-  for (auto iter_idle = global_idle_mem_buf_map_.begin(); iter_idle != global_idle_mem_buf_map_.end(); ++iter_idle) {
-    mem_buf = iter_idle->second;
-    MS_EXCEPTION_IF_NULL(mem_buf);
-    total_idle_mem2 += mem_buf->size_;
-    MS_LOG(INFO) << "Idle mem_buf info: size[" << mem_buf->size_ << "] address[" << mem_buf->device_addr_ << "] status["
-                 << mem_buf->status_ << "].";
-  }
-  // Dump the memory statistical info
-  MS_LOG(INFO) << "Total allocated memory[" << total_mem << "], used memory[" << total_used_mem << "], idle memory["
-               << total_idle_mem1 << "].";
-  if (total_idle_mem1 != total_idle_mem2) {
-    MS_LOG(ERROR) << "Check error: the idle memory in the mem_block is not equal the global idle memory.";
-  }
-  if (total_mem != total_used_mem + total_idle_mem1) {
-    MS_LOG(ERROR) << "Check error: the total memory is not equal the sum of used memory and idle memory.";
-  }
-  MS_LOG(INFO) << "Finish dump dynamic memory pool info.";
+  auto fn = [](const MemStatusManagerPtr &mem_mng, const std::string &mem_type) {
+    if (mem_mng->mem_block_list_.empty()) {
+      return;
+    }
+    std::ostringstream buf;
+    for (size_t i = 0; i < mem_mng->mem_block_list_.size(); ++i) {
+      size_t idle_size = 0;
+      for (auto mb = mem_mng->mem_block_list_[i]->block_all_mem_buf_map_.begin();
+           mb != mem_mng->mem_block_list_[i]->block_all_mem_buf_map_.end(); ++mb) {
+        if (mb->second->status_ == kMemBufIdle) {
+          idle_size += mb->second->size_;
+        }
+      }
+      buf << ", block[" << i << "] idle size " << idle_size;
+    }
+    // Dump all the memory buf info
+    MS_LOG(WARNING) << mem_type << "pool info: block size " << mem_mng->unit_size_ << ", block counts "
+                    << mem_mng->mem_block_list_.size() << buf.str() << ". Total allocated mem "
+                    << mem_mng->mps_.total_mem_size_ << ", peak used mem " << mem_mng->mps_.used_mem_peak_size_
+                    << ", in used mem " << mem_mng->mps_.total_used_mem_size_ << ", total idle mem "
+                    << mem_mng->mps_.total_mem_size_ - mem_mng->mps_.total_used_mem_size_;
+  };
+  fn(common_mem_, std::string(kCommonMem));
+  fn(persistent_mem_, std::string(kPynativeParamMem));
 }
 } // namespace device
 } // namespace mindspore

@@ -36,7 +36,7 @@ enum DynamicMemBufStatus : int { kMemBufIdle, kMemBufUsed };
 static const size_t DYNAMIC_MEM_ALIGN_SIZE = 512;
 
 // The minimum unit size (1G) of memory block used for dynamic extend.
-static const size_t DYNAMIC_MEM_ALLOC_UNIT_SIZE = 1024 << 20;
+static const size_t DYNAMIC_MEM_ALLOC_UNIT_SIZE = 1073741824;
 
 // The Comparator of device address from small to large.
 struct DeviceAddrCmp {
@@ -77,16 +77,40 @@ class DynamicMemBlock {
 };
 using DynamicMemBlockPtr = std::shared_ptr<DynamicMemBlock>;
 
+struct DeviceState {
+  // Memory allocated from device
+  size_t total_mem_size_{0};
+  // Memory in use
+  size_t total_used_mem_size_{0};
+  // Maximum peak memory usage
+  size_t used_mem_peak_size_{0};
+};
+
+struct MemStatusManager {
+  size_t unit_size_{DYNAMIC_MEM_ALLOC_UNIT_SIZE};
+  // Mempool state
+  DeviceState mps_;
+  std::vector<DynamicMemBlockPtr> mem_block_list_;
+  // The map of all idle memory buf by size.
+  SizeMapMemBuf idle_mem_buf_map_;
+  void clear() {
+    mem_block_list_.clear();
+    idle_mem_buf_map_.clear();
+  }
+};
+using MemStatusManagerPtr = std::shared_ptr<MemStatusManager>;
+
 // The main class of dynamic memory pool.
 class DynamicMemPoolBestFit {
  public:
-  DynamicMemPoolBestFit() = default;
+  DynamicMemPoolBestFit()
+      : persistent_mem_(std::make_shared<MemStatusManager>()), common_mem_(std::make_shared<MemStatusManager>()) {}
   virtual ~DynamicMemPoolBestFit();
 
   // The main program entry of memory alloc.
-  DeviceMemPtr AllocTensorMem(size_t size);
+  DeviceMemPtr AllocTensorMem(size_t size, bool from_persistent_mem = false);
   // The main program entry of continuous memory alloc.
-  std::vector<DeviceMemPtr> AllocContinuousTensorMem(size_t total_size, std::vector<size_t> size_list);
+  std::vector<DeviceMemPtr> AllocContinuousTensorMem(size_t total_size, const std::vector<size_t> &size_list);
   // The main program entry of memory free.
   void FreeTensorMem(const DeviceMemPtr &device_addr);
 
@@ -94,21 +118,22 @@ class DynamicMemPoolBestFit {
   void ReleaseDeviceRes();
   // Display the information of memory block and memory buf.
   void DumpDynamicMemPoolInfo();
-  // Get the map of global idle mem buf and size.
-  SizeMapMemBuf global_idle_mem_buf_map() {
-    std::lock_guard<std::mutex> locker(mutex_);
-    return global_idle_mem_buf_map_;
-  }
 
   // Get the minimum memory unit size using for dynamic extend.
-  size_t mem_alloc_unit_size() const { return mem_alloc_unit_size_; }
+  size_t MemAllocUnitSize(bool from_persistent_mem = false) const;
   // Set the minimum memory unit size using for dynamic extend.
-  void set_mem_alloc_unit_size(const size_t &size) { mem_alloc_unit_size_ = size; }
-
-  // Get the related memory statistics information.
-  size_t total_mem_statistics() const { return total_mem_statistics_; }
-  size_t used_mem_statistics() const { return total_used_mem_statistics_; }
-  size_t used_mem_peak_statistics() const { return used_mem_peak_statistics_; }
+  void SetMemAllocUintSize(size_t size);
+  // Set mempool init percent in pynative mode
+  void SetMempoolBlockSize(size_t device_mem_size);
+  size_t TotalMemStatistics() const {
+    return common_mem_->mps_.total_mem_size_ + persistent_mem_->mps_.total_mem_size_;
+  }
+  size_t TotalUsedMemStatistics() const {
+    return common_mem_->mps_.total_used_mem_size_ + persistent_mem_->mps_.total_used_mem_size_;
+  }
+  size_t UsedMemPeakStatistics() const {
+    return common_mem_->mps_.used_mem_peak_size_ + persistent_mem_->mps_.used_mem_peak_size_;
+  }
 
   // The related interface of device memory real operation, needs override by device type.
   virtual size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) = 0;
@@ -116,43 +141,35 @@ class DynamicMemPoolBestFit {
   virtual size_t free_mem_size() = 0;
 
  protected:
+  MemStatusManagerPtr &common_mem() { return common_mem_; }
+  MemStatusManagerPtr &persistent_mem() { return persistent_mem_; }
   // The real size by memory alloc aligned.
   virtual size_t AlignMemorySize(size_t size) const;
   // Calculate memory block required alloc size when adding the memory block.
-  virtual size_t CalMemBlockAllocSize(size_t size);
+  virtual size_t CalMemBlockAllocSize(size_t size, bool from_persistent_mem);
 
  private:
   // Find the idle memory buf by aligned size when memory alloc.
-  DeviceMemPtr FindIdleMemBuf(size_t size);
+  DeviceMemPtr FindIdleMemBuf(size_t size, bool from_persistent_mem);
   // Add the memory block and memory buf when memory alloc not find the idle memory buf.
-  DeviceMemPtr AddMemBlockAndMemBuf(size_t size);
-  // Judge whether need divide the memory buf by alloc size and memory buf size.
-  bool IsDivide(size_t tensor_size, size_t mem_buf_size) const;
-  // Divide the memory buf by alloc size.
-  void DivideMemBuf(size_t size, const DynamicMemBufPtr &mem_buf);
+  DeviceMemPtr AddMemBlockAndMemBuf(size_t size, bool from_persistent_mem);
+  // Judge whether need split the memory buf by alloc size and memory buf size.
+  bool IsSplit(size_t tensor_size, size_t mem_buf_size) const;
+  // Split the memory buf by alloc size.
+  void SplitMemBuf(size_t size, const DynamicMemBufPtr &mem_buf, const MemStatusManagerPtr &mem_mng);
   // Find the memory block by device address.
-  DynamicMemBlockPtr FindMemBlock(const DeviceMemPtr &device_addr);
+  DynamicMemBlockPtr FindMemBlock(const DeviceMemPtr &device_addr, const MemStatusManagerPtr &mem_mgr);
   // The Comparator of memory block by device address, because memory blocks are arranged in order by device address.
   static bool CmpMemBlock(const DeviceMemPtr &device_addr, const DynamicMemBlockPtr &mem_block);
 
   // Combine the memory buf when memory free, to avoid the memory fragmentation.
-  void CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr &device_addr);
+  void CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr &device_addr,
+                     const MemStatusManagerPtr &mem_mng);
   // Erase the idle memory buf by size and device address when idle memory buf is combined.
-  void EraseIdleMemBuf(size_t size, const DeviceMemPtr &device_addr);
+  void EraseIdleMemBuf(size_t size, const DeviceMemPtr &device_addr, const MemStatusManagerPtr &mem_mng);
 
-  // The global memory block list which is arranged in order by base device address of memory block.
-  std::vector<DynamicMemBlockPtr> global_mem_block_list_;
-  // The map of all idle memory buf by size.
-  SizeMapMemBuf global_idle_mem_buf_map_;
-
-  // The related memory statistics information.
-  size_t total_mem_statistics_{0};
-  size_t total_used_mem_statistics_{0};
-  size_t used_mem_peak_statistics_{0};
-
-  // The minimum memory unit size.
-  size_t mem_alloc_unit_size_{DYNAMIC_MEM_ALLOC_UNIT_SIZE};
-
+  MemStatusManagerPtr persistent_mem_{nullptr};
+  MemStatusManagerPtr common_mem_{nullptr};
   // Support multi-thread.
   std::mutex mutex_;
 };

@@ -864,7 +864,7 @@ void AscendSession::RunOpImplOrigin(const GraphInfo &graph_info, OpRunInfo *op_r
 
   // malloc mem
   RunOpRemoveNopNode(graph);
-  RunOpMemoryAlloc(*input_tensors, graph.get());
+  RunOpMemoryAlloc(*input_tensors, graph.get(), op_run_info->is_gradient_out);
   RunOpGenKernelEvent(graph.get());
   AnfAlgo::CacheAddrForGraph(graph);
   // Build dynamic kernel
@@ -998,7 +998,7 @@ void AscendSession::BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfN
     GetOpInputStubTensors(kernel, parameter_index, graph_inputs, op_output_info, &input_tensor_info);
     // Get OpRunInfo and GraphInfo
     const GraphInfo &graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors);
-    OpRunInfo op_run_info = GetSingleOpRunInfo(kernel, graph_info, input_tensor_info);
+    OpRunInfo op_run_info = GetSingleOpRunInfo(kernel, graph_info, input_tensor_info, nullptr);
     if (op_run_info.is_dynamic_shape) {
       MS_LOG(INFO) << "BuildOpsInGraph stop, op " << op_run_info.op_name << " is dynamic shape.";
       break;
@@ -1318,19 +1318,19 @@ void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const {
   MS_LOG(INFO) << "Status record: end memory alloc. graph id: " << kernel_graph->graph_id()
                << ", Memory Statistics:" << device::ascend::AscendMemAdapter::GetInstance().DevMemStatistics();
   MS_LOG(INFO) << "The dynamic memory pool total size is "
-               << device::ascend::AscendMemoryPool::GetInstance().total_mem_statistics() / kMBToByte
+               << device::ascend::AscendMemoryPool::GetInstance().TotalMemStatistics() / kMBToByte
                << "M, total used size is "
-               << device::ascend::AscendMemoryPool::GetInstance().used_mem_statistics() / kMBToByte
+               << device::ascend::AscendMemoryPool::GetInstance().TotalUsedMemStatistics() / kMBToByte
                << "M, used peak size is "
-               << device::ascend::AscendMemoryPool::GetInstance().used_mem_peak_statistics() / kMBToByte << "M.";
+               << device::ascend::AscendMemoryPool::GetInstance().UsedMemPeakStatistics() / kMBToByte << "M.";
 }
 
-void AscendSession::RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors,
-                                     KernelGraph *kernel_graph) const {
+void AscendSession::RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph,
+                                     bool is_gradient_out) const {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
-  runtime_instance->RunOpAssignMemory(input_tensors, *kernel_graph);
+  runtime_instance->RunOpAssignMemory(input_tensors, *kernel_graph, is_gradient_out);
 }
 
 void AscendSession::RunOpMemoryAllocNew(const std::vector<tensor::TensorPtr> &input_tensors,
@@ -1338,7 +1338,7 @@ void AscendSession::RunOpMemoryAllocNew(const std::vector<tensor::TensorPtr> &in
                                         const KernelGraph &kernel_graph) const {
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
-  runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph, tensor_to_node);
+  runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph, false, tensor_to_node);
 }
 
 void AscendSession::RunOpGenKernelEvent(const KernelGraph *graph) const {

@@ -103,7 +103,8 @@ class AscendSession : public SessionBasic {
   static void BuildKernel(const std::vector<CNodePtr> &kernels);
   void BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   void MemoryAlloc(KernelGraph *kernel_graph) const;
-  void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
+  void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph,
+                        bool is_gradient_out) const;
   void RunOpMemoryAllocNew(const std::vector<tensor::TensorPtr> &input_tensors,
                            const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
                            const KernelGraph &kernel_graph) const;

@@ -255,11 +255,11 @@ void GPUSession::AllocateMemory(const KernelGraph *kernel_graph) const {
 }
 
 void GPUSession::RunOpAllocateMemory(const std::vector<tensor::TensorPtr> &input_tensors,
-                                     const KernelGraph *kernel_graph) const {
+                                     const KernelGraph *kernel_graph, bool is_gradient_out) const {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
-  runtime_instance->RunOpAssignMemory(input_tensors, *kernel_graph);
+  runtime_instance->RunOpAssignMemory(input_tensors, *kernel_graph, is_gradient_out);
 }
 
 void GPUSession::RunOpGenKernelEvent(const KernelGraph *graph) const {
@@ -701,7 +701,7 @@ void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
   // run op
   MS_EXCEPTION_IF_NULL(kernel_graph);
   RunOpRemoveNopNode(kernel_graph);
-  RunOpAllocateMemory(*input_tensors, kernel_graph.get());
+  RunOpAllocateMemory(*input_tensors, kernel_graph.get(), op_run_info->is_gradient_out);
   RunOpGenKernelEvent(kernel_graph.get());
   // Execute the computation
   LoadInputData(kernel_graph, *input_tensors);

@@ -82,7 +82,8 @@ class GPUSession : public SessionBasic {
 
   void AllocateMemory(const KernelGraph *kernel_graph) const;
 
-  void RunOpAllocateMemory(const std::vector<tensor::TensorPtr> &input_tensors, const KernelGraph *kernel_graph) const;
+  void RunOpAllocateMemory(const std::vector<tensor::TensorPtr> &input_tensors, const KernelGraph *kernel_graph,
+                           bool is_gradient_out) const;
 
   void RunOpClearMemory(const KernelGraph *kernel_graph) const;
 

@@ -54,6 +54,7 @@
 #endif
 #include "backend/session/session_factory.h"
 #include "backend/session/pynative_task_manager.h"
+#include "pipeline/pynative/pynative_execute.h"
 
 namespace mindspore {
 namespace session {
@@ -1220,7 +1221,8 @@ GraphInfo SessionBasic::GetSingleOpGraphInfo(const CNodePtr &kernel,
 }
 
 OpRunInfo SessionBasic::GetSingleOpRunInfo(const CNodePtr &cnode, const GraphInfo &graph_info,
-                                           const InputTensorInfo &tensor_info) {
+                                           const InputTensorInfo &tensor_info,
+                                           GraphOutputInfo *const graph_output_info) {
   MS_EXCEPTION_IF_NULL(cnode);
   auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
   const auto &abstract = cnode->abstract();
@@ -1230,7 +1232,14 @@ OpRunInfo SessionBasic::GetSingleOpRunInfo(const CNodePtr &cnode, const GraphInf
   const auto &shape = abstract->BuildShape();
   MS_EXCEPTION_IF_NULL(shape);
 
-  OpRunInfo op_run_info = {.op_name = primitive->name(),
+  bool is_gradient_out =
+    graph_output_info != nullptr &&
+    std::any_of(graph_output_info->output_indexes.begin(), graph_output_info->output_indexes.end(),
+                [cnode](const std::pair<KernelWithIndex, std::vector<std::vector<size_t>>> &output_index) {
+                  return output_index.first.first == cnode;
+                });
+  OpRunInfo op_run_info = {.is_gradient_out = is_gradient_out,
+                           .op_name = primitive->name(),
                            .primitive = primitive.get(),
                            .abstract = abstract,
                            .is_dynamic_shape = shape->IsDynamic(),
@@ -1304,16 +1313,63 @@ void SessionBasic::GetRefCount(const KernelGraph *graph, std::map<KernelWithInde
   }
 }
 
+void SessionBasic::GetForwardOutputRefCount(const KernelGraph *graph,
+                                            std::map<AnfNodePtr, size_t> *forward_output_refcount) {
+  if (!pynative::PynativeExecutor::GetInstance()->grad_flag()) {
+    return;
+  }
+  const auto &forward_value_nodes_id = pynative::PynativeExecutor::GetInstance()->grad_executor()->forward_outputs_id();
+  for (const auto &kernel : graph->execution_order()) {
+    for (size_t i = 1; i < kernel->inputs().size(); i += 1) {
+      const auto &input = kernel->input(i);
+      if (!input->isa<ValueNode>()) {
+        continue;
+      }
+      auto value_node = input->cast<ValueNodePtr>();
+      auto value = GetValueNode(value_node);
+      MS_EXCEPTION_IF_NULL(value);
+      if (value->isa<tensor::Tensor>()) {
+        auto tensor = value->cast<tensor::TensorPtr>();
+        MS_EXCEPTION_IF_NULL(tensor);
+        if (forward_value_nodes_id.find(tensor->id()) != forward_value_nodes_id.end()) {
+          (*forward_output_refcount)[input] += 1;
+        }
+      }
+    }
+  }
+}
+
 void SessionBasic::HandleOpInputs(const std::set<KernelWithIndex> &input_kernel,
                                   std::map<KernelWithIndex, size_t> *ref_count,
+                                  std::map<AnfNodePtr, size_t> *forward_output_refcount,
                                   std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) {
   MS_EXCEPTION_IF_NULL(ref_count);
+  MS_EXCEPTION_IF_NULL(forward_output_refcount);
   MS_EXCEPTION_IF_NULL(op_output_map);
   for (auto &kernel_with_index : input_kernel) {
-    MS_EXCEPTION_IF_NULL(kernel_with_index.first);
     if (!kernel_with_index.first->isa<CNode>()) {
       continue;
    }
+
+    // Release forward output
+    auto cnode = kernel_with_index.first->cast<CNodePtr>();
+    for (size_t i = 1; i < cnode->inputs().size(); i += 1) {
+      const auto &input = cnode->input(i);
+      auto it = forward_output_refcount->find(input);
+      if (it != forward_output_refcount->end()) {
+        if (--(it->second) == 0) {
+          auto value_node = input->cast<ValueNodePtr>();
+          auto value = GetValueNode(value_node);
+          MS_EXCEPTION_IF_NULL(value);
+          auto tensor = value->cast<tensor::TensorPtr>();
+          MS_EXCEPTION_IF_NULL(tensor);
+          tensor->set_device_address(nullptr);
+          forward_output_refcount->erase(it);
+        }
+      }
+    }
+
+    // Release previous output
    auto ref_iter = ref_count->find(kernel_with_index);
    if (ref_iter == ref_count->end()) {
      MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in cnode reference count map, input cnode = "
@ -2329,7 +2385,9 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
|
||||||
graph_output_info.graph_outputs = outputs;
|
graph_output_info.graph_outputs = outputs;
|
||||||
CreateOutputPlaceholder(kernel_graph, inputs, graph_output_info.graph_outputs, &graph_output_info.output_indexes);
|
CreateOutputPlaceholder(kernel_graph, inputs, graph_output_info.graph_outputs, &graph_output_info.output_indexes);
|
||||||
std::map<KernelWithIndex, size_t> cnode_refcount;
|
std::map<KernelWithIndex, size_t> cnode_refcount;
|
||||||
|
std::map<AnfNodePtr, size_t> forward_output_refcount;
|
||||||
GetRefCount(kernel_graph.get(), &cnode_refcount);
|
GetRefCount(kernel_graph.get(), &cnode_refcount);
|
||||||
|
GetForwardOutputRefCount(kernel_graph.get(), &forward_output_refcount);
|
||||||
BuildOpsInGraph(graph_id, parameter_index, inputs, cnode_refcount);
|
BuildOpsInGraph(graph_id, parameter_index, inputs, cnode_refcount);
|
||||||
|
|
||||||
// Clear bucket resources every step
|
// Clear bucket resources every step
|
||||||
|
@ -2346,14 +2404,14 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
|
||||||
VectorRef op_outputs;
|
VectorRef op_outputs;
|
||||||
// Get OpRunInfo and GraphInfo
|
// Get OpRunInfo and GraphInfo
|
||||||
GraphInfo graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors);
|
GraphInfo graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors);
|
||||||
OpRunInfo run_info = GetSingleOpRunInfo(kernel, graph_info, input_tensor_info);
|
OpRunInfo run_info = GetSingleOpRunInfo(kernel, graph_info, input_tensor_info, &graph_output_info);
|
||||||
|
|
||||||
// Build and run current single op
|
// Build and run current single op
|
||||||
RunOpImplOrigin(graph_info, &run_info, &input_tensor_info.input_tensors, &op_outputs,
|
RunOpImplOrigin(graph_info, &run_info, &input_tensor_info.input_tensors, &op_outputs,
|
||||||
input_tensor_info.input_tensors_mask);
|
input_tensor_info.input_tensors_mask);
|
||||||
graph_output_info.graph_output_tensors.clear();
|
graph_output_info.graph_output_tensors.clear();
|
||||||
// Handle inputs and outputs of current op
|
// Handle inputs and outputs of current op
|
||||||
HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &op_output_map);
|
HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &forward_output_refcount, &op_output_map);
|
||||||
HandleOpOutputs(kernel, op_outputs, cnode_refcount, &op_output_map, &graph_output_info);
|
HandleOpOutputs(kernel, op_outputs, cnode_refcount, &op_output_map, &graph_output_info);
|
||||||
// Save grad node to Bucket
|
// Save grad node to Bucket
|
||||||
if (kernel_graph->is_bprop()) {
|
if (kernel_graph->is_bprop()) {
|
||||||
|
|
|
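The hunk above threads a second reference count through HandleOpInputs so that forward outputs kept alive for the bprop graph drop their device memory as soon as their last consumer runs. A minimal standalone sketch of that reference-counting idea, using illustrative types rather than MindSpore's tensor::Tensor and AnfNodePtr:

// Toy version of the release-at-zero logic: count the consumers of each forward output,
// decrement as inputs are handled, and drop the device buffer when the count reaches zero.
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct FakeTensor {
  std::string id;
  std::shared_ptr<int> device_address;  // stands in for the real device memory handle
};

int main() {
  FakeTensor t{"tensor-0", std::make_shared<int>(42)};
  // Pretend two kernels in the graph read this forward output.
  std::map<std::string, size_t> forward_output_refcount{{t.id, 2}};

  auto consume = [&](FakeTensor *tensor) {
    auto it = forward_output_refcount.find(tensor->id);
    if (it != forward_output_refcount.end() && --(it->second) == 0) {
      tensor->device_address = nullptr;  // release device memory after the last use
      forward_output_refcount.erase(it);
    }
  };

  consume(&t);
  std::cout << "after first use, held: " << (t.device_address != nullptr) << "\n";   // 1
  consume(&t);
  std::cout << "after second use, held: " << (t.device_address != nullptr) << "\n";  // 0
  return 0;
}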
@@ -57,6 +57,7 @@ using AnyList = std::vector<Any>;
 using AnyListPtr = std::shared_ptr<AnyList>;

 struct OpRunInfo {
+  bool is_gradient_out = false;
   std::string op_name;
   Primitive *primitive;
   AbstractBasePtr abstract;
@@ -193,7 +194,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
                                VectorRef *const outputs,
                                std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes);
   void GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count);
+  void GetForwardOutputRefCount(const KernelGraph *graph, std::map<AnfNodePtr, size_t> *forward_output_refcount);
   void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<KernelWithIndex, size_t> *ref_count,
+                      std::map<AnfNodePtr, size_t> *forward_output_refcount,
                       std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map);

   void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
@@ -280,7 +283,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph);
   // Generate graph info for a single op graph
   GraphInfo GetSingleOpGraphInfo(const CNodePtr &kernel, const std::vector<tensor::TensorPtr> &input_tensors);
-  OpRunInfo GetSingleOpRunInfo(const CNodePtr &cnode, const GraphInfo &graph_info, const InputTensorInfo &tensor_info);
+  OpRunInfo GetSingleOpRunInfo(const CNodePtr &cnode, const GraphInfo &graph_info, const InputTensorInfo &tensor_info,
+                               GraphOutputInfo *const graph_output_info);
   tensor::TensorPtr GetValueNodeOutputTensor(const AnfNodePtr &node, size_t output_index);
   tensor::TensorPtr GetParameterOutputTensor(const AnfNodePtr &node,
                                              const std::map<AnfNodePtr, size_t> &parameter_index,

@@ -50,6 +50,7 @@ enum PynativeStatusCode {
 enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };

 struct OpExecInfo {
+  bool is_nop_prim = false;
   bool is_dynamic_shape = false;
   bool is_mixed_precision_cast = false;
   size_t next_input_index = 0;
@@ -880,6 +880,37 @@ py::object GetDstType(const TypeId &type_id) {
 bool IsPyObjTypeInvalid(const py::object &obj) {
   return !py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj);
 }
+
+inline bool IsNopPrim(const std::string &op_name) {
+  static std::set<std::string> nop_prim = {prim::kPrimReshape->name(), kExpandDimsOpName, prim::kPrimSqueeze->name(),
+                                           prim::kPrimFlatten->name(), kFlattenGradOpName, prim::kPrimReformat->name()};
+  return nop_prim.find(op_name) != nop_prim.end();
+}
+
+// Shallow Copy Value and change shape
+ValuePtr ShallowCopyValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value) {
+  MS_EXCEPTION_IF_NULL(op_exec_info);
+  MS_EXCEPTION_IF_NULL(value);
+  auto tensor_abs = op_exec_info->abstract;
+  if (tensor_abs->isa<abstract::AbstractRef>()) {
+    tensor_abs = tensor_abs->cast<abstract::AbstractRefPtr>()->CloneAsTensor();
+  }
+  auto new_shape = tensor_abs->BuildShape()->cast<abstract::ShapePtr>();
+  MS_EXCEPTION_IF_NULL(new_shape);
+  if (value->isa<mindspore::tensor::Tensor>()) {
+    auto tensor_value = value->cast<mindspore::tensor::TensorPtr>();
+    return std::make_shared<mindspore::tensor::Tensor>(tensor_value->data_type(), new_shape->shape(),
+                                                       tensor_value->data_c(), tensor_value->Size());
+  } else if (value->isa<ValueTuple>()) {
+    std::vector<ValuePtr> values;
+    auto value_tuple = value->cast<ValueTuplePtr>();
+    (void)std::transform(value_tuple->value().begin(), value_tuple->value().end(), std::back_inserter(values),
+                         [op_exec_info](const ValuePtr &elem) { return ShallowCopyValue(op_exec_info, elem); });
+    return std::make_shared<ValueTuple>(values);
+  } else {
+    return value;
+  }
+}
 }  // namespace

 py::object RealRunOp(const py::args &args) {
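IsNopPrim and ShallowCopyValue above treat reshape-like primitives as metadata-only operations: the output value reuses the input's data buffer and only the shape taken from the inferred abstract changes. A standalone sketch of the same buffer-sharing idea (View and Reshape are illustrative stand-ins, not MindSpore APIs):

// A "nop" primitive does not touch the bytes, so the output can alias the input storage
// and differ only in shape metadata; the shared_ptr use count shows the sharing.
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

struct View {
  std::vector<int64_t> shape;
  std::shared_ptr<std::vector<float>> data;  // shared, not copied
};

View Reshape(const View &in, std::vector<int64_t> new_shape) {
  // Reuse the storage, swap the shape.
  return View{std::move(new_shape), in.data};
}

int main() {
  View a{{2, 3}, std::make_shared<std::vector<float>>(6, 1.0f)};
  View b = Reshape(a, {3, 2});
  (*b.data)[0] = 7.0f;  // visible through both views: same underlying buffer
  std::cout << (*a.data)[0] << ", use_count=" << a.data.use_count() << "\n";  // 7, use_count=2
  return 0;
}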
@@ -957,6 +988,7 @@ void TopCellInfo::Clear() {
   k_pynative_cell_ptr_ = nullptr;
   graph_info_map_.clear();
   sub_cell_list_.clear();
+  outputs_id_.clear();
   op_info_with_tensor_id_.clear();
   tensor_id_with_tensor_object_.clear();
   op_info_with_ms_func_forward_tensors_.clear();
@@ -992,6 +1024,7 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) {
   const auto &op_exec_info = std::make_shared<OpExecInfo>();
   const auto &op_name = py::cast<std::string>(args[PY_NAME]);
   op_exec_info->op_name = op_name;
+  op_exec_info->is_nop_prim = false;  // IsNopPrim(op_exec_info->op_name);

   const auto &adapter = py::cast<PrimitivePyAdapterPtr>(args[PY_PRIM]);
   MS_EXCEPTION_IF_NULL(adapter);
@@ -1013,7 +1046,7 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) {
 void ForwardExecutor::SetCastForInputs(const OpExecInfoPtr &op_exec_info) {
   MS_EXCEPTION_IF_NULL(op_exec_info);
   // No need cast self
-  if (op_exec_info->op_name == prim::kPrimCast->name()) {
+  if (op_exec_info->op_name == prim::kPrimCast->name() || op_exec_info->is_nop_prim) {
     return;
   }

@@ -1167,6 +1200,21 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
   }
 }

+void ForwardExecutor::DoNopOutput(const OpExecInfoPtr &op_exec_info, ValuePtr *out_real_value) {
+  MS_EXCEPTION_IF_NULL(op_exec_info);
+  // Get First input
+  if (op_exec_info->op_inputs.empty()) {
+    MS_LOG(EXCEPTION) << "Inputs of " << op_exec_info->op_name << " is empty";
+  }
+  const auto &obj = op_exec_info->op_inputs[0];
+  if (!py::isinstance<tensor::Tensor>(obj)) {
+    MS_LOG(EXCEPTION) << "First input of " << op_exec_info->op_name << " must be a tensor";
+  }
+  const auto &tensor_ptr = py::cast<tensor::TensorPtr>(obj);
+  *out_real_value = ShallowCopyValue(op_exec_info, tensor_ptr);
+  MS_LOG(DEBUG) << "New copy value is " << (*out_real_value)->ToString();
+}
+
 void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
                                   const abstract::AbstractBasePtrList &args_spec_list, const CNodePtr &cnode,
                                   bool prim_cache_hit, py::object *ret) {
@@ -1194,25 +1242,32 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
     out[args_spec_list].abs = op_exec_info->abstract;
     out[args_spec_list].attrs = prim->evaluate_added_attrs();
   }
-  // run op with selected backend
-  auto result = RunOpWithInitBackendPolicy(op_exec_info);
-  py::object out_real = result;
-  if (result.size() == 1 && op_exec_info->abstract != nullptr &&
-      !op_exec_info->abstract->isa<abstract::AbstractSequence>()) {
-    out_real = result[0];
-  }
-  // get output value
+  // Run op with selected backend, nop is no need run backend
   ValuePtr out_real_value = nullptr;
-  if (grad()->grad_flag()) {
-    out_real_value = PyObjToValue(out_real);
+  if (op_exec_info->is_nop_prim) {
+    DoNopOutput(op_exec_info, &out_real_value);
+    *ret = BaseRefToPyData(out_real_value);
+  } else {
+    auto result = RunOpWithInitBackendPolicy(op_exec_info);
+    py::object out_real = result;
+    if (result.size() == 1 && op_exec_info->abstract != nullptr &&
+        !op_exec_info->abstract->isa<abstract::AbstractSequence>()) {
+      out_real = result[0];
+    }
+    // get output value
+    if (grad()->grad_flag()) {
+      out_real_value = PyObjToValue(out_real);
+    }
+    *ret = out_real;
   }
-  // Save cnode info and build grad graph
   if (grad()->need_construct_graph() && !grad()->in_cell_with_custom_bprop_()) {
     MS_EXCEPTION_IF_NULL(cnode);
-    const auto &obj_id = GetId(out_real);
+    const auto &obj_id = GetId(*ret);
     cnode->set_abstract(op_exec_info->abstract);
     node_abs_map_[obj_id] = op_exec_info->abstract;
-    grad()->SaveOutputNodeMap(obj_id, out_real, cnode);
+    grad()->SaveOutputNodeMap(obj_id, *ret, cnode);
     grad()->DoOpGrad(op_exec_info, cnode, out_real_value);
   } else {
     node_abs_map_.clear();
@@ -1220,7 +1275,6 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
   // Record op info for judge whether the construct of cell has been changed
   grad()->RecordGradOpInfo(op_exec_info, out_real_value);
   grad()->UpdateForwardTensorInfoInBpropGraph(op_exec_info, out_real_value);
-  *ret = out_real;
 }

 py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name,
@@ -1701,7 +1755,7 @@ void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object
     return;
   }
   MS_EXCEPTION_IF_NULL(cnode);
-  MS_LOG(DEBUG) << "Cnode is " << cnode->DebugString() << " id " << obj_id;
+  MS_LOG(DEBUG) << "Cnode is " << cnode->DebugString() << ", out value id " << obj_id;
   if (py::isinstance<py::tuple>(out_real)) {
     auto value = py::cast<py::tuple>(out_real);
     auto size = static_cast<int64_t>(value.size());
@@ -1925,6 +1979,7 @@ void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr
       continue;
     }
     tensor_id_with_tensor_object[tensor->id()].emplace_back(tensor);
+    top_cell()->outputs_id().insert(tensor->id());
     MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
                   << " device address: " << tensor->device_address() << " shape and dtype "
                   << tensor->GetShapeAndDataTypeInfo();
@@ -2090,7 +2145,8 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
   // get graph info for checking it whether existing in the cache
   GetSingleOpGraphInfo(op_exec_info, input_tensors, tensors_mask, &graph_info);
 #if defined(__APPLE__)
-  session::OpRunInfo op_run_info = {op_exec_info->op_name,
+  session::OpRunInfo op_run_info = {false,
+                                    op_exec_info->op_name,
                                     op_exec_info->py_primitive.get(),
                                     op_exec_info->abstract,
                                     op_exec_info->is_dynamic_shape,
@@ -2102,7 +2158,8 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
                                     tensors_mask,
                                     input_tensors};
 #else
-  session::OpRunInfo op_run_info = {op_exec_info->op_name,
+  session::OpRunInfo op_run_info = {false,
+                                    op_exec_info->op_name,
                                     op_exec_info->py_primitive.get(),
                                     op_exec_info->abstract,
                                     op_exec_info->is_dynamic_shape,
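RunOpInMs above has to prepend `false` to both brace-initializer lists because `is_gradient_out` became the first member of session::OpRunInfo, while GetSingleOpRunInfo, which uses designated initializers, is unaffected by the new ordering. A small sketch of that difference with a stand-in struct (not the real OpRunInfo; the designated form needs C++20 or the common compiler extension):

#include <iostream>
#include <string>

struct OpInfo {
  bool is_gradient_out = false;  // newly added first member shifts positional initializers
  std::string op_name;
  bool is_dynamic_shape = false;
};

int main() {
  OpInfo positional = {false, "MatMul", false};                         // must now supply the flag first
  OpInfo designated = {.is_gradient_out = true, .op_name = "MatMul"};   // named, so ordering changes are safe
  std::cout << positional.op_name << " " << designated.is_gradient_out << "\n";  // MatMul 1
  return 0;
}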
@@ -100,6 +100,7 @@ class TopCellInfo {
   const std::string &grad_operation() const { return grad_operation_; }
   void set_grad_operation(const std::string &grad_operation) { grad_operation_ = grad_operation; }
   mindspore::HashSet<std::string> &sub_cell_list() { return sub_cell_list_; }
+  std::set<std::string> &outputs_id() { return outputs_id_; }
   bool IsSubCell(const std::string &cell_id) const;
   OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() { return graph_info_map_; }
   OpInfoWithTensorId &op_info_with_tensor_id() { return op_info_with_tensor_id_; }
@@ -139,6 +140,8 @@ class TopCellInfo {
   std::string grad_operation_;
   OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
   mindspore::HashSet<std::string> sub_cell_list_;
+  // Record op output tensor id
+  std::set<std::string> outputs_id_;
   OpInfoWithTensorId op_info_with_tensor_id_;
   TensorIdWithTensorObject tensor_id_with_tensor_object_;
   OpInfoWithMsFuncForwardTensors op_info_with_ms_func_forward_tensors_;
@@ -195,6 +198,7 @@ class GradExecutor {
   void set_grad_flag(bool flag) { grad_flag_ = flag; }
   void set_graph_phase(const std::string &graph_phase) { graph_phase_ = graph_phase; }
   bool in_cell_with_custom_bprop_() const { return custom_bprop_cell_count_ > 0; }
+  std::set<std::string> &forward_outputs_id() const { return top_cell()->outputs_id(); }
   AnfNodePtr GetInput(const py::object &obj, bool op_mask);
   std::string GetCellId(const py::object &obj, const py::args &args);
   void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out);
@@ -347,6 +351,7 @@ class ForwardExecutor {
                            bool *prim_cache_hit);
   void GetOpOutput(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
                    const CNodePtr &cnode, bool prim_cache_hit, py::object *ret);
+  void DoNopOutput(const OpExecInfoPtr &op_exec_info, ValuePtr *out_real_value);
   // Mix precision and Implicit transform
   void SetCastForInputs(const OpExecInfoPtr &op_exec_info);
   void SetTensorMixPrecisionCast(const OpExecInfoPtr &op_exec_info);
@@ -84,6 +84,7 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
                          .value("save_graphs", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)
                          .value("enable_parallel_split", MsCtxParam::MS_CTX_ENABLE_PARALLEL_SPLIT)
                          .value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY)
+                         .value("mempool_block_size", MsCtxParam::MS_CTX_MEMPOOL_BLOCK_SIZE)
                          .value("mode", MsCtxParam::MS_CTX_EXECUTION_MODE)
                          .value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET)
                          .value("_graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
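The .value() chain above is how the new `mempool_block_size` context key becomes visible to the Python context API: a C++ enum member is exported to Python by name through pybind11. A hedged sketch of that binding pattern with an illustrative module and enum (not MindSpore's real MsCtxParam registration):

#include <pybind11/pybind11.h>

namespace py = pybind11;

enum class CtxParam { kMaxDeviceMemory, kMempoolBlockSize, kExecutionMode };

// Exposes the enum so Python code can refer to ctx_demo.ctx_param.mempool_block_size.
PYBIND11_MODULE(ctx_demo, m) {
  py::enum_<CtxParam>(m, "ctx_param")
      .value("max_device_memory", CtxParam::kMaxDeviceMemory)
      .value("mempool_block_size", CtxParam::kMempoolBlockSize)
      .value("mode", CtxParam::kExecutionMode);
}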
@@ -77,6 +77,7 @@ bool AscendMemAdapter::Initialize() {
   static_mem_offset_ = ms_used_hbm_size_;
   cur_dynamic_mem_offset_ = 0;
   max_dynamic_mem_offset_ = 0;
+  AscendMemoryPool::GetInstance().SetMempoolBlockSize(ms_used_hbm_size_);
   MS_LOG(INFO) << "Ascend Memory Adapter initialize success, Memory Statistics:" << DevMemStatistics();
   initialized_ = true;
   return true;
@@ -109,7 +110,7 @@ bool AscendMemAdapter::DeInitialize() {
   return ret;
 }

-uint8_t *AscendMemAdapter::MallocStaticDevMem(size_t size, std::string tag) {
+uint8_t *AscendMemAdapter::MallocStaticDevMem(size_t size, const std::string &tag) {
   std::lock_guard<std::mutex> locker(mutex_);
   auto new_static_offset = static_mem_offset_ - size;
   if (new_static_offset < max_dynamic_mem_offset_) {
@@ -122,11 +123,10 @@ uint8_t *AscendMemAdapter::MallocStaticDevMem(size_t size, std::string tag) {
   auto memory_block_ptr = device_mem_base_addr_ + new_static_offset;
   static_mem_offset_ = new_static_offset;
   static_memory_block_list_.push_back(std::make_shared<MemoryBlock>(memory_block_ptr, size, tag));
-
   return memory_block_ptr;
 }

-uint8_t *AscendMemAdapter::MallocDynamicDevMem(size_t size, std::string tag) {
+uint8_t *AscendMemAdapter::MallocDynamicDevMem(size_t size, const std::string &tag) {
   std::lock_guard<std::mutex> locker(mutex_);
   auto new_dynamic_offset = cur_dynamic_mem_offset_ + size;
   if (new_dynamic_offset > static_mem_offset_) {
@@ -22,6 +22,7 @@
 #include <memory>
 #include <vector>
 #include "utils/ms_context.h"
+#include "runtime/device/ascend/ascend_memory_pool.h"

 namespace mindspore {
 namespace device {
@@ -37,8 +38,8 @@ class AscendMemAdapter {
   bool Initialize();
   bool DeInitialize();

-  uint8_t *MallocStaticDevMem(size_t size, std::string tag = "");
-  uint8_t *MallocDynamicDevMem(size_t size, std::string tag = "");
+  uint8_t *MallocStaticDevMem(size_t size, const std::string &tag = "");
+  uint8_t *MallocDynamicDevMem(size_t size, const std::string &tag = "");
   bool FreeStaticDevMem(void *devPtr) { return true; }
   void ResetDynamicMemory();

@@ -76,12 +77,13 @@ class AscendMemAdapter {
   uint8_t *device_mem_base_addr_{nullptr};
   uint64_t ms_used_hbm_size_{0};

-  // dynamic memory info
+  // dynamic memory info, from a low address to a high address
   uint64_t cur_dynamic_mem_offset_{0};
+  // Maximum dynamic memory have already allocated, dynamically updated
   uint64_t max_dynamic_mem_offset_{0};
   std::vector<std::shared_ptr<MemoryBlock>> dynamic_memory_block_list_;

-  // static memory info
+  // static memory info, from a high address to a low address
   uint64_t static_mem_offset_{0};
   std::vector<std::shared_ptr<MemoryBlock>> static_memory_block_list_;
 };
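The new comments in AscendMemAdapter document a single device region used from both ends: dynamic (graph) memory grows up from offset 0 while static memory grows down from the top, and an allocation fails when the two offsets would cross. A standalone sketch of that arithmetic, with an illustrative Arena class rather than the real adapter:

#include <cstdint>
#include <iostream>

class Arena {
 public:
  explicit Arena(uint64_t total) : static_offset_(total) {}

  // Static block: carve from the high end, moving the static offset downwards.
  bool MallocStatic(uint64_t size) {
    if (static_offset_ < size || static_offset_ - size < max_dynamic_offset_) return false;
    static_offset_ -= size;
    return true;
  }

  // Dynamic block: carve from the low end, moving the dynamic offset upwards.
  bool MallocDynamic(uint64_t size) {
    if (cur_dynamic_offset_ + size > static_offset_) return false;
    cur_dynamic_offset_ += size;
    if (cur_dynamic_offset_ > max_dynamic_offset_) max_dynamic_offset_ = cur_dynamic_offset_;
    return true;
  }

  void ResetDynamic() { cur_dynamic_offset_ = 0; }  // the high-water mark is kept

 private:
  uint64_t cur_dynamic_offset_{0};
  uint64_t max_dynamic_offset_{0};
  uint64_t static_offset_;
};

int main() {
  Arena arena(1024);
  // Third request would cross the static region, so it is rejected.
  std::cout << arena.MallocStatic(512) << arena.MallocDynamic(256) << arena.MallocDynamic(512) << "\n";  // 110
  return 0;
}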
@@ -15,7 +15,6 @@
  */
 #include <string>
 #include "runtime/device/ascend/ascend_memory_manager.h"
-#include "runtime/device/ascend/ascend_memory_pool.h"
 #include "runtime/device/ascend/ascend_memory_adapter.h"
 #include "utils/ms_context.h"
 #include "runtime/mem.h"
@@ -49,9 +48,9 @@ void *AscendMemoryManager::MallocDevice(size_t size) {
   return AscendMemoryPool::GetInstance().AllocTensorMem(align_size);
 }

-void *AscendMemoryManager::MallocMemFromMemPool(size_t size) {
+void *AscendMemoryManager::MallocMemFromMemPool(size_t size, bool from_persistent_mem) {
   auto align_size = GetCommonAlignSize(size);
-  const auto device_addr = AscendMemoryPool::GetInstance().AllocTensorMem(align_size);
+  const auto device_addr = AscendMemoryPool::GetInstance().AllocTensorMem(align_size, from_persistent_mem);
   if (device_addr == nullptr) {
     MS_LOG(EXCEPTION) << "Fail to alloc memory, size: " << align_size
                       << ", memory statistics:" << AscendMemAdapter::GetInstance().DevMemStatistics();
@@ -132,8 +131,8 @@ uint8_t *AscendMemoryManager::MallocCommunicationMemFromMemPool(size_t size) {

 size_t AscendMemoryManager::GetAvailableMemSize() {
   auto available_mem_size = AscendMemoryPool::GetInstance().free_mem_size() +
-                            AscendMemoryPool::GetInstance().total_mem_statistics() -
-                            AscendMemoryPool::GetInstance().used_mem_statistics();
+                            AscendMemoryPool::GetInstance().TotalMemStatistics() -
+                            AscendMemoryPool::GetInstance().TotalUsedMemStatistics();
   return available_mem_size;
 }

@@ -32,7 +32,7 @@ class AscendMemoryManager : public MemoryManager {
   void FreeDeviceMemory() override;
   void ResetDynamicMemory() override;
   void ClearGlobalIdleMem() override;
-  void *MallocMemFromMemPool(size_t size) override;
+  void *MallocMemFromMemPool(size_t size, bool from_persistent_mem) override;
   void *MallocDevice(size_t size) override;
   void FreeMemFromMemPool(void *device_ptr) override;
   uint64_t GetMsMaxMemSize();
@@ -24,37 +24,37 @@
 namespace mindspore {
 namespace device {
 namespace ascend {
-// The minimum unit size (256MB) of memory block used for dynamic extend.
-static const size_t ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE = 256 << 20;
 // The minimum unit size (8MB) of memory block used for dynamic extend in graph mode.
 static const size_t ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE_FOR_GRAPH = 8 << 20;

-size_t AscendMemoryPool::CalMemBlockAllocSize(size_t size) {
+size_t AscendMemoryPool::CalMemBlockAllocSize(size_t size, bool from_persistent_mem) {
   auto device_free_mem_size = free_mem_size();
   if (device_free_mem_size < size) {
     MS_LOG(WARNING) << "The dynamic memory pool total size is "
-                    << device::ascend::AscendMemoryPool::GetInstance().total_mem_statistics() / kMBToByte
+                    << device::ascend::AscendMemoryPool::GetInstance().TotalMemStatistics() / kMBToByte
                     << "M, total used size is "
-                    << device::ascend::AscendMemoryPool::GetInstance().used_mem_statistics() / kMBToByte
+                    << device::ascend::AscendMemoryPool::GetInstance().TotalUsedMemStatistics() / kMBToByte
                     << "M, used peak size is "
-                    << device::ascend::AscendMemoryPool::GetInstance().used_mem_peak_statistics() / kMBToByte << "M.";
-    MS_LOG(WARNING) << "Out of Memory. Request memory size: " << size
+                    << device::ascend::AscendMemoryPool::GetInstance().UsedMemPeakStatistics() / kMBToByte << "M.";
+    MS_LOG(WARNING) << "Out of Memory. Request memory size: " << size << ", device free size " << device_free_mem_size
                     << ", Memory Statistic:" << AscendMemAdapter::GetInstance().DevMemStatistics()
                     << "Please try to reduce 'batch_size' or check whether exists extra large shape. More "
                        "details can be found in MindSpore's FAQ with keyword 'Out of Memory'.";
     return 0;
   }
-  auto alloc_mem_size = ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE;
+  size_t alloc_mem_size = MemAllocUnitSize(from_persistent_mem);
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
   if (pynative_mode) {
     // Growing at twice of alloc size
+    MS_LOG(DEBUG) << "Get unit block size " << alloc_mem_size;
     constexpr size_t kDouble = 2;
     while (alloc_mem_size < size) {
       alloc_mem_size = alloc_mem_size * kDouble;
     }
   } else {
+    // The graph mode controls itself independently
     alloc_mem_size = ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE_FOR_GRAPH;
     while (alloc_mem_size < size) {
       alloc_mem_size = alloc_mem_size + ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE_FOR_GRAPH;
@@ -69,7 +69,6 @@ size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
   if (size == 0) {
     MS_LOG(EXCEPTION) << "Failed to alloc memory pool resource, the size is zero!";
   }
-
   *addr = AscendMemAdapter::GetInstance().MallocStaticDevMem(size);
   if (*addr == nullptr) {
     MS_LOG(EXCEPTION) << "Alloc device memory pool address is nullptr, failed to alloc memory pool resource!";
@@ -83,11 +82,17 @@ bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) {
 }

 void AscendMemoryPool::ResetIdleMemBuf() {
-  auto idle_mem_buf_map = DynamicMemPoolBestFit::global_idle_mem_buf_map();
-  for (auto &it : idle_mem_buf_map) {
-    MS_EXCEPTION_IF_NULL(it.second);
-    (void)rtMemset(it.second->device_addr_, it.first, 0, it.first);
-  }
+  auto fn = [this](const MemStatusManagerPtr &mem_mng) {
+    if (mem_mng->mem_block_list_.empty()) {
+      return;
+    }
+    for (auto &it : mem_mng->idle_mem_buf_map_) {
+      MS_EXCEPTION_IF_NULL(it.second);
+      (void)rtMemset(it.second->device_addr_, it.first, 0, it.first);
+    }
+  };
+  fn(persistent_mem());
+  fn(common_mem());
 }

 size_t AscendMemoryPool::free_mem_size() { return AscendMemAdapter::GetInstance().FreeDevMemSize(); }
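CalMemBlockAllocSize above now distinguishes the two execution modes: PyNative mode starts from the configured unit size and doubles it until the request fits, while graph mode grows linearly in fixed 8 MB steps. A standalone sketch of those two growth policies (constants and names are illustrative):

#include <cstddef>
#include <iostream>

constexpr size_t kGraphUnit = 8ULL << 20;  // 8 MB step used in graph mode

size_t CalcBlockSize(size_t request, size_t unit_size, bool pynative_mode) {
  size_t block = pynative_mode ? unit_size : kGraphUnit;
  while (block < request) {
    block = pynative_mode ? block * 2 : block + kGraphUnit;  // double vs. add one unit
  }
  return block;
}

int main() {
  const size_t request = 100ULL << 20;  // 100 MB request
  std::cout << (CalcBlockSize(request, 1ULL << 30, true) >> 20) << " MB (pynative, 1 GB unit)\n";  // 1024 MB
  std::cout << (CalcBlockSize(request, 1ULL << 30, false) >> 20) << " MB (graph, 8 MB steps)\n";   // 104 MB
  return 0;
}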
@@ -42,7 +42,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {

 protected:
   // Calculate memory block required alloc size when adding the memory block.
-  size_t CalMemBlockAllocSize(size_t size) override;
+  size_t CalMemBlockAllocSize(size_t size, bool from_persistent_mem) override;

 private:
   AscendMemoryPool() = default;
@@ -112,6 +112,7 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
       size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
       DeviceAddressPtr address = nullptr;
       address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, output_type_id);
+      address->set_from_persistent_mem(tensor->is_parameter());
       MS_EXCEPTION_IF_NULL(address);
       if (tensor->data_type() == output_type_id) {
         address->ptr_ = tensor->data_c();
@@ -146,6 +147,7 @@ void CPUKernelRuntime::AssignInputNodeAddress(const session::KernelGraph *kernel
                                : std::accumulate(fmt_shape.begin(), fmt_shape.end(), type_size, std::multiplies<size_t>());
       auto format = AnfAlgo::GetOutputFormat(item, index);
       auto address = CreateDeviceAddress(nullptr, tensor_size, format, output_type_id);
+      address->set_from_persistent_mem(true);
       AnfAlgo::SetOutputAddr(address, index, item.get());
     }
   }

@@ -46,7 +46,9 @@ class CPUMemoryManager : public MemoryManager {
   void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
   void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);

-  void *MallocMemFromMemPool(size_t size) override { return CPUMemoryPool::GetInstance().AllocTensorMem(size); }
+  void *MallocMemFromMemPool(size_t size, bool from_persistent_mem) override {
+    return CPUMemoryPool::GetInstance().AllocTensorMem(size, from_persistent_mem);
+  }
   void FreeMemFromMemPool(void *device_ptr) override { CPUMemoryPool::GetInstance().FreeTensorMem(device_ptr); }
   std::vector<void *> MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list) override {
     return CPUMemoryPool::GetInstance().AllocContinuousTensorMem(total_size, size_list);
@@ -101,6 +101,8 @@ class DeviceAddress : public mindspore::DeviceSync {
   bool is_ptr_persisted() const { return is_ptr_persisted_; }
   void set_is_ptr_persisted(bool is_ptr_persisted) { is_ptr_persisted_ = is_ptr_persisted; }
   void set_host_shape(const ShapeVector &shape) { host_shape_ = shape; }
+  bool from_persistent_mem() const { return from_persistent_mem_; }
+  void set_from_persistent_mem(bool from_persistent_mem) { from_persistent_mem_ = from_persistent_mem; }
   virtual void set_status(DeviceAddressStatus status) {}
   virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
   virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
@@ -145,6 +147,7 @@ class DeviceAddress : public mindspore::DeviceSync {
   // The key of device context.
   std::string device_name_{""};
   uint32_t device_id_{0};
+  bool from_persistent_mem_{false};

   friend class KernelRuntime;
   friend class MemoryManager;
@@ -422,7 +422,7 @@ void GPUKernelRuntime::FetchMemUnitSize(const session::KernelGraph *graph) {
       current_sum_size = 0;
     }
   }
-  if (max_sum_size > GPUMemoryAllocator::GetInstance().mem_alloc_unit_size()) {
+  if (max_sum_size > GPUMemoryAllocator::GetInstance().MemAllocUnitSize()) {
     size_t unit_size = (max_sum_size / DYNAMIC_MEM_ALLOC_UNIT_SIZE + 1) * DYNAMIC_MEM_ALLOC_UNIT_SIZE;
     if (unit_size < DYNAMIC_MEM_ALLOC_UNIT_SIZE) {
       MS_LOG(WARNING) << "Current memory unit size [" << unit_size << "] is too small.";
@@ -432,7 +432,7 @@ void GPUKernelRuntime::FetchMemUnitSize(const session::KernelGraph *graph) {
     constexpr float kValidMemoryRatio = 0.9;
     free_mem_size = kValidMemoryRatio * free_mem_size;
     unit_size = std::min(unit_size, free_mem_size);
-    GPUMemoryAllocator::GetInstance().set_mem_alloc_unit_size(unit_size);
+    GPUMemoryAllocator::GetInstance().SetMemAllocUintSize(unit_size);
   }
 }

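FetchMemUnitSize above rounds the largest per-step sum of output and workspace sizes up to whole allocation units before handing it to the allocator. A short sketch of that rounding, with an illustrative unit size; note that the `/ unit + 1` form always adds one unit, even when the size is already an exact multiple:

#include <cstddef>
#include <iostream>

constexpr size_t kUnit = 256ULL << 20;  // illustrative 256 MB unit

size_t RoundUpToUnit(size_t size) { return (size / kUnit + 1) * kUnit; }

int main() {
  std::cout << (RoundUpToUnit(300ULL << 20) >> 20) << " MB\n";  // 512 MB
  std::cout << (RoundUpToUnit(512ULL << 20) >> 20) << " MB\n";  // 768 MB: exact multiples still grow
  return 0;
}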
@@ -25,6 +25,7 @@ namespace mindspore {
 namespace device {
 namespace gpu {
 const size_t kGBToByte = 1024 << 20;
+constexpr float kReservedMemoryRatio = 0.0625;  // 1/16

 bool GPUMemoryAllocator::Init() {
   size_t total_size = CudaDriver::total_mem_size();
@@ -42,6 +43,13 @@ bool GPUMemoryAllocator::Init() {
                  << total_size << ", current free memory size " << free_size << ", set max available memory size "
                  << available_device_memory_ << ".";
   }
+  // In gpu mode, recommend 1/16 reserved for other cuda functions
+  if (available_device_memory_ > total_size) {
+    size_t recommend_mem_size_for_others = FloatToSize(total_size * kReservedMemoryRatio);
+    SetMempoolBlockSize(std::min(available_device_memory_, total_size - recommend_mem_size_for_others));
+  } else {
+    SetMempoolBlockSize(std::min(available_device_memory_, total_size));
+  }
   return true;
 }

@@ -71,7 +79,7 @@ bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) {
   auto alloc_size = AllocDeviceMem(size, addr);
   buffer_q_addr_ = *addr;
   // Buffer queue needs to ensure that the alloc_size and size is equal.
-  return (alloc_size == size) ? true : false;
+  return alloc_size == size;
 }

 size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
@@ -90,8 +98,9 @@ size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
   }
   total_used_device_memory_ += alloc_size;
   available_device_memory_ -= alloc_size;
-  MS_LOG(INFO) << "Current free memory size[" << free_size - alloc_size << "], current alloc size[" << alloc_size
-               << "], total used size[" << total_used_device_memory_ << "].";
+  MS_LOG(INFO) << "Cuda current free memory size[" << free_size << "], alloc size[" << alloc_size
+               << "], left free memory size[" << free_size - alloc_size << "]"
+               << ".Total used size[" << total_used_device_memory_ << "].";
   return alloc_size;
 }

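The block added to GPUMemoryAllocator::Init sizes the memory-pool block from the configured limit while leaving roughly 1/16 of physical device memory for other CUDA users whenever the limit exceeds the physical total. A sketch of that arithmetic with illustrative numbers, not the real allocator:

#include <algorithm>
#include <cstddef>
#include <iostream>

constexpr float kReservedRatio = 0.0625f;  // 1/16 kept free for other CUDA consumers

size_t PickMempoolBlockSize(size_t available_limit, size_t total_device_mem) {
  if (available_limit > total_device_mem) {
    const size_t reserved = static_cast<size_t>(total_device_mem * kReservedRatio);
    return std::min(available_limit, total_device_mem - reserved);
  }
  return std::min(available_limit, total_device_mem);
}

int main() {
  const size_t gb = 1ULL << 30;
  std::cout << PickMempoolBlockSize(1024 * gb, 16 * gb) / gb << " GB\n";  // 15 GB: 1 GB left for CUDA
  std::cout << PickMempoolBlockSize(8 * gb, 16 * gb) / gb << " GB\n";     // 8 GB: the user limit wins
  return 0;
}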
@@ -24,8 +24,8 @@
 namespace mindspore {
 namespace device {
 namespace gpu {
-void *GPUMemoryManager::MallocMemFromMemPool(size_t size) {
-  return GPUMemoryAllocator::GetInstance().AllocTensorMem(size);
+void *GPUMemoryManager::MallocMemFromMemPool(size_t size, bool from_persistent_mem) {
+  return GPUMemoryAllocator::GetInstance().AllocTensorMem(size, from_persistent_mem);
 }

 void GPUMemoryManager::FreeMemFromMemPool(void *device_ptr) {
@@ -39,7 +39,7 @@ std::vector<void *> GPUMemoryManager::MallocContinuousMemFromMemPool(size_t tota
 bool GPUMemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList &addr_list, size_t total_size,
                                                       std::vector<size_t> size_list) {
   auto device_ptr_list = MallocContinuousMemFromMemPool(total_size, size_list);
-  if (device_ptr_list.size() == 0) {
+  if (device_ptr_list.empty()) {
     return false;
   }
   if (addr_list.size() != device_ptr_list.size()) {
@@ -76,7 +76,7 @@ void GPUMemoryManager::MallocDeviceMemory() {
   if (ps::ps_cache_instance.initialized_ps_cache()) {
     return;
   }
-  auto device_addr = MallocMemFromMemPool(1);
+  auto device_addr = MallocMemFromMemPool(1, false);
   if (!device_addr) {
     MS_LOG(EXCEPTION) << "Dynamic memory pool init error.";
   }
@@ -87,7 +87,7 @@ void GPUMemoryManager::FreeDeviceMemory() { GPUMemoryAllocator::GetInstance().Re
 uint8_t *GPUMemoryManager::MallocStaticMem(size_t size, bool, uint32_t) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
-  auto device_ptr = MallocMemFromMemPool(size);
+  auto device_ptr = MallocMemFromMemPool(size, false);
   if (device_ptr == nullptr) {
     MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << size;
   }

@@ -29,7 +29,7 @@ class GPUMemoryManager : public MemoryManager {
   void MallocDeviceMemory() override;
   void FreeDeviceMemory() override;

-  void *MallocMemFromMemPool(size_t size) override;
+  void *MallocMemFromMemPool(size_t size, bool from_persistent_mem) override;
   void FreeMemFromMemPool(void *device_ptr) override;
   std::vector<void *> MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list) override;
   bool MallocContinuousMemFromMemPool(const DeviceAddressPtrList &addr_list, size_t total_size,
@ -258,6 +258,7 @@ void KernelRuntime::RunOpMallocPre(const session::KernelGraph &graph,
|
||||||
auto output_format = op_runtime_info->output_format(index);
|
auto output_format = op_runtime_info->output_format(index);
|
||||||
auto device_address =
|
auto device_address =
|
||||||
CreateDeviceAddress(nullptr, output_tensor_size, output_format, output_type_id, {item, index});
|
CreateDeviceAddress(nullptr, output_tensor_size, output_format, output_type_id, {item, index});
|
||||||
|
device_address->set_from_persistent_mem(current_tensor->is_parameter());
|
||||||
AnfAlgo::SetOutputAddr(device_address, index, item.get());
|
AnfAlgo::SetOutputAddr(device_address, index, item.get());
|
||||||
current_tensor->set_device_address(device_address);
|
current_tensor->set_device_address(device_address);
|
||||||
current_tensor->set_sync_status(kNeedSyncHostToDevice);
|
current_tensor->set_sync_status(kNeedSyncHostToDevice);
|
||||||
|
@ -287,6 +288,7 @@ void KernelRuntime::ResetNodeAddress(const session::KernelGraph &kernel_graph) {
|
||||||
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, index);
|
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, index);
|
||||||
auto device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, index),
|
auto device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, index),
|
||||||
output_type_id, {input_node, index});
|
output_type_id, {input_node, index});
|
||||||
|
device_address->set_from_persistent_mem(input_node->isa<Parameter>());
|
||||||
AnfAlgo::SetOutputAddr(device_address, index, input_node.get());
|
AnfAlgo::SetOutputAddr(device_address, index, input_node.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -306,7 +308,7 @@ void KernelRuntime::ResetNodeAddress(const session::KernelGraph &kernel_graph) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
|
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
|
||||||
const session::KernelGraph &graph,
|
const session::KernelGraph &graph, bool is_gradient_out,
|
||||||
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
|
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
|
||||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||||
mem_manager_->ResetDynamicMemory();
|
mem_manager_->ResetDynamicMemory();
|
||||||
|
@ -319,7 +321,7 @@ void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &inpu
|
||||||
   RunOpAssignInputMemory(input_tensors, graph);
   AssignStaticMemoryValueNode(graph);
   for (const auto &node : graph.execution_order()) {
-    RunOpAssignOutputMemory(node, tensor_to_node);
+    RunOpAssignOutputMemory(node, tensor_to_node, is_gradient_out);
     RunOpAssignWorkSpaceMemory(node);
   }
   UpdateRefNodeOutputMem(graph);

@@ -398,6 +400,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
     auto current_tensor = input_tensors[input_index];
     MS_EXCEPTION_IF_NULL(current_tensor);
     auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(current_tensor->device_address());
+    // The device address has already been created.
     if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) {
       if (output_address->ptr_ == nullptr) {
         if (!mem_manager_->MallocMemFromMemPool(output_address, output_address->size())) {

@@ -413,10 +416,12 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
       output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
     }
     auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
+    // Create a new device address.
     auto device_address =
       CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, {item, index});
     MS_EXCEPTION_IF_NULL(device_address);
     MS_EXCEPTION_IF_NULL(mem_manager_);
+    device_address->set_from_persistent_mem(true);
     auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
     if (!ret) {
       MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << tensor_size;

@@ -426,8 +431,9 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
   }
 }

-void KernelRuntime::RunOpAssignOutputMemory(
-  const AnfNodePtr &kernel, const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
+void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel,
+                                            const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
+                                            bool is_gradient_out) {
   MS_EXCEPTION_IF_NULL(kernel);
   MS_EXCEPTION_IF_NULL(mem_manager_);
   auto kernel_mod = AnfAlgo::GetKernelMod(kernel);

@@ -464,8 +470,11 @@ void KernelRuntime::RunOpAssignOutputMemory(
     std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
     auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {kernel, i});
-    device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
     MS_EXCEPTION_IF_NULL(device_address);
+    device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
+    if (is_gradient_out) {
+      device_address->set_from_persistent_mem(true);
+    }
     auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
     if (!ret) {
       MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << output_sizes[i];
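The tagging rule applied in these hunks can be summarized outside the framework. The sketch below uses simplified stand-in types (the DeviceAddress struct, NodeKind enum and TagForPersistentMemory function are assumptions for illustration, not MindSpore's API): single-op inputs created from host tensors, value nodes, and gradient outputs are long-lived in PyNative mode, so their device addresses are marked before any device memory is requested.

#include <cstddef>

// Hypothetical, simplified stand-in for the framework's device address type.
struct DeviceAddress {
  void *ptr = nullptr;
  size_t size = 0;
  bool from_persistent_mem = false;  // mirrors set_from_persistent_mem(true)
};

enum class NodeKind { kParameter, kValueNode, kKernelOutput };

// Parameters and value nodes always go to persistent memory; a kernel output
// is only marked persistent when it is a gradient output that must outlive
// the single-op launch.
void TagForPersistentMemory(DeviceAddress *addr, NodeKind kind, bool is_gradient_out) {
  switch (kind) {
    case NodeKind::kParameter:
    case NodeKind::kValueNode:
      addr->from_persistent_mem = true;
      break;
    case NodeKind::kKernelOutput:
      addr->from_persistent_mem = is_gradient_out;
      break;
  }
}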
@@ -917,6 +926,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
     auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
     DeviceAddressPtr address =
       CreateDeviceAddress(nullptr, node_size, output_format, output_type_id, {value_node, output_idx});
+    address->set_from_persistent_mem(true);
     MS_EXCEPTION_IF_NULL(address);
     if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
         !mem_manager_->MallocMemFromMemPool(address, node_size)) {

@@ -979,6 +989,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(const session::KernelGraph &grap
         ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
       auto address = CreateDeviceAddressForStringValue(node_value, use_mem_from_memory_pool, graph.graph_id());
       MS_EXCEPTION_IF_NULL(address);
+      address->set_from_persistent_mem(true);
       AnfAlgo::SetOutputAddr(address, 0, value_node.get());
     }
   }

@@ -991,6 +1002,7 @@ DeviceAddressPtr KernelRuntime::CreateDeviceAddressForStringValue(const ValuePtr
   size_t tensor_size = value_string.size();
   DeviceAddressPtr address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
   MS_EXCEPTION_IF_NULL(address);
+  address->set_from_persistent_mem(true);
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   if (use_mem_pool && !mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
@@ -56,6 +56,7 @@ class KernelRuntime {
   virtual bool Init() = 0;
   virtual void AssignMemory(const session::KernelGraph &graph);
   void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph &graph,
+                         bool is_gradient_out,
                          const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node = {});
   void AssignCommunicationOutputFromMemoryPool(const AnfNodePtr &node) const;
   void AssignCommunicationInputFromMemoryPool(const AnfNodePtr &node) const;

@@ -173,7 +174,8 @@ class KernelRuntime {
                          const std::shared_ptr<MemScheduler> &mem_schedule = nullptr);
   void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph &graph);
   void RunOpAssignOutputMemory(const AnfNodePtr &kernel,
-                               const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node = {});
+                               const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
+                               bool is_gradient_out);
   void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel);
   void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, const session::KernelGraph &graph);
   void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
@@ -141,9 +141,9 @@ uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
   return nullptr;
 }

-bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr address, size_t size) {
+bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr &address, size_t size) {
   MS_EXCEPTION_IF_NULL(address);
-  auto device_ptr = MallocMemFromMemPool(size);
+  auto device_ptr = MallocMemFromMemPool(size, address->from_persistent_mem_);
   if (!device_ptr) {
     return false;
   }

@@ -154,7 +154,7 @@ bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr address, size_t
   return true;
 }

-void *MemoryManager::MallocMemFromMemPool(size_t size) {
+void *MemoryManager::MallocMemFromMemPool(size_t size, bool from_persistent_mem) {
   if (size == 0) {
     MS_LOG(ERROR) << "MallocMemFromMemPool size is 0.";
   }
@@ -49,8 +49,10 @@ class MemoryManager : public MemHandler {
   virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address,
                              uint32_t graph_id = kInvalidGraphId);

-  virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size);
-  virtual void *MallocMemFromMemPool(size_t size);
+  // param address is the device address of the corresponding device type
+  // param from_persistent_mem indicates whether the tensor is a parameter in PyNative mode
+  virtual bool MallocMemFromMemPool(const DeviceAddressPtr &address, size_t size);
+  virtual void *MallocMemFromMemPool(size_t size, bool from_persistent_mem);
   virtual uint8_t *MallocCommunicationMemFromMemPool(size_t size) { return nullptr; }
   virtual void FreeMemFromMemPool(const DeviceAddressPtr address);
   virtual void FreeMemFromMemPool(void *device_ptr);

@@ -62,7 +64,7 @@ class MemoryManager : public MemHandler {
   static size_t GetCommunicationAlignSize(size_t input_size);

   // swap manager interface
-  void *MallocDevice(size_t mem_size) override { return MallocMemFromMemPool(mem_size); }
+  void *MallocDevice(size_t mem_size) override { return MallocMemFromMemPool(mem_size, false); }
   void FreeDevice(void *ptr) override {
     MS_EXCEPTION_IF_NULL(ptr);
     FreeMemFromMemPool(ptr);
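On the allocator side, the extra bool decides which part of the device memory pool serves the request. The sketch below is not MindSpore's pool implementation; it is a minimal two-arena model under assumed names (Arena, MemoryPool), written only to show how MallocMemFromMemPool(size, from_persistent_mem) can keep long-lived tensors (parameters, gradient outputs) apart from short-lived single-op memory so the latter can be recycled wholesale between op launches.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// A deliberately small bump allocator standing in for one pool arena.
class Arena {
 public:
  explicit Arena(size_t capacity) : buffer_(capacity), offset_(0) {}
  void *Alloc(size_t size) {
    if (offset_ + size > buffer_.size()) return nullptr;  // arena exhausted
    void *ptr = buffer_.data() + offset_;
    offset_ += size;
    return ptr;
  }
  void Reset() { offset_ = 0; }  // the short-lived arena can be recycled as a whole

 private:
  std::vector<uint8_t> buffer_;
  size_t offset_;
};

// Two arenas: persistent memory for parameters and gradient outputs, common
// memory for everything that dies right after the single-op launch.
class MemoryPool {
 public:
  MemoryPool(size_t persistent_cap, size_t common_cap)
      : persistent_(persistent_cap), common_(common_cap) {}

  void *MallocMemFromMemPool(size_t size, bool from_persistent_mem) {
    return from_persistent_mem ? persistent_.Alloc(size) : common_.Alloc(size);
  }
  void RecycleCommon() { common_.Reset(); }  // persistent tensors stay untouched

 private:
  Arena persistent_;
  Arena common_;
};

int main() {
  MemoryPool pool(1 << 20, 1 << 20);
  void *weight = pool.MallocMemFromMemPool(4096, /*from_persistent_mem=*/true);
  void *scratch = pool.MallocMemFromMemPool(4096, /*from_persistent_mem=*/false);
  pool.RecycleCommon();  // scratch memory is reclaimed, weight remains valid
  std::cout << (weight != nullptr) << " " << (scratch != nullptr) << std::endl;
  return 0;
}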
@@ -607,6 +607,7 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
       (device_tensor->DeviceType() != device_context->GetDeviceAddressType())) {
     host_tensor_address = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(),
                                                               device_tensor->format(), device_tensor->type_id());
+    host_tensor_address->set_from_persistent_mem(tensor->is_parameter());
   } else {
     host_tensor_address = device_tensor;
   }
@@ -100,6 +100,7 @@ void CreateParameterDeviceAddress(const DeviceContext *device_context, const Ker
     size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
     auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size,
                                                               AnfAlgo::GetOutputFormat(item, index), output_type_id);
+    device_address->set_from_persistent_mem(item->isa<Parameter>());
     MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(item) << " addr:" << device_address;
     AnfAlgo::SetOutputAddr(device_address, index, item.get());
   }

@@ -143,6 +144,7 @@ void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, cons
       device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
     MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
     MS_EXCEPTION_IF_NULL(address);
+    address->set_from_persistent_mem(true);
     AnfAlgo::SetOutputAddr(address, output_idx++, value_node.get());
   }
 }

@@ -165,6 +167,7 @@ void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const Ker
     size_t tensor_size = value.size();
     auto address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
     MS_EXCEPTION_IF_NULL(address);
+    address->set_from_persistent_mem(true);
     MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;

     AnfAlgo::SetOutputAddr(address, 0, value_node.get());

@@ -172,7 +175,8 @@ void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const Ker
   }
 }

-void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
+void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph,
+                                     bool is_gradient_out) {
   MS_EXCEPTION_IF_NULL(device_context);
   MS_EXCEPTION_IF_NULL(graph);
   const std::vector<CNodePtr> &kernels = graph->execution_order();

@@ -191,6 +195,9 @@ void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const
       auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
       auto address_size = AnfAlgo::GetOutputTensorMemSize(kernel, i);
       auto device_address = device_context->CreateDeviceAddress(nullptr, address_size, output_format, output_type);
+      if (is_gradient_out) {
+        device_address->set_from_persistent_mem(true);
+      }
       MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(kernel) << " addr:" << device_address;
       AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
     }
@@ -434,7 +441,7 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
   device_context->PreprocessBeforeRunGraph(graph);

   // Create device address for all anf nodes of graph.
-  CreateDeviceAddress(graph, device_context);
+  CreateDeviceAddress(graph, device_context, false);

   graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));

@@ -498,7 +505,7 @@ GraphId GraphCompiler::CompileGraph(const session::OpRunInfo &op_run_info, bool
   device_context->OptimizeSingleOpGraph(graph);

   // Create device address for all anf nodes of graph.
-  CreateDeviceAddressWithoutWorkspace(graph, device_context);
+  CreateDeviceAddressWithoutWorkspace(graph, device_context, op_run_info.is_gradient_out);

   graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
   run_op_graphs_[op_run_info.graph_info] = graph;

@@ -546,20 +553,22 @@ KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const {
   return iter->second;
 }

-void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
+void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context,
+                                        bool is_gradient_out) const {
   CreateParameterDeviceAddress(device_context, graph);
   CreateValueNodeDeviceAddress(device_context, graph);
-  CreateKernelOutputDeviceAddress(device_context, graph);
+  CreateKernelOutputDeviceAddress(device_context, graph, is_gradient_out);
   CreateKernelWorkspaceDeviceAddress(device_context, graph);
   UpdateDeviceAddressForInplaceNode(graph);
   UpdateDeviceAddressForRefNode(graph);
 }

 void GraphCompiler::CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph,
-                                                        const DeviceContext *device_context) const {
+                                                        const DeviceContext *device_context,
+                                                        bool is_gradient_out) const {
   CreateParameterDeviceAddress(device_context, graph);
   CreateValueNodeDeviceAddress(device_context, graph);
-  CreateKernelOutputDeviceAddress(device_context, graph);
+  CreateKernelOutputDeviceAddress(device_context, graph, is_gradient_out);
   UpdateDeviceAddressForInplaceNode(graph);
   UpdateDeviceAddressForRefNode(graph);
 }
@@ -593,24 +602,27 @@ TensorPtr GraphCompiler::GetSingleOpInputTensorByIndex(const CNodePtr &kernel,
 }

 void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const InputTensorInfo &tensor_info,
-                                                   OpRunInfo *run_info, GraphInfo *graph_info) {
+                                                   OpRunInfo *run_info, GraphInfo *graph_info,
+                                                   GraphOutputInfo *const graph_output_info) {
   MS_EXCEPTION_IF_NULL(session_);
   MS_EXCEPTION_IF_NULL(graph_info);

   *graph_info = session_->GetSingleOpGraphInfo(kernel, tensor_info.input_tensors);
-  *run_info = session_->GetSingleOpRunInfo(kernel, *graph_info, tensor_info);
+  *run_info = session_->GetSingleOpRunInfo(kernel, *graph_info, tensor_info, graph_output_info);
 }

-void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const {
+void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count,
+                                      std::map<AnfNodePtr, size_t> *forward_output_refcount) const {
   MS_EXCEPTION_IF_NULL(session_);
   session_->GetRefCount(graph.get(), ref_count);
+  session_->GetForwardOutputRefCount(graph.get(), forward_output_refcount);
 }

 void GraphCompiler::UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index,
                                    std::map<KernelWithIndex, size_t> *ref_count,
+                                   std::map<AnfNodePtr, size_t> *forward_output_refcount,
                                    std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) const {
   MS_EXCEPTION_IF_NULL(session_);
-  session_->HandleOpInputs(input_kernels_with_index, ref_count, op_output_map);
+  session_->HandleOpInputs(input_kernels_with_index, ref_count, forward_output_refcount, op_output_map);
 }

 void GraphCompiler::RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs,

@@ -139,14 +139,16 @@ class GraphCompiler {

   // Get OpRunInfo and GraphInfo for single op compile and run.
   void GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const InputTensorInfo &tensor_info, OpRunInfo *run_info,
-                                      GraphInfo *graph_info);
+                                      GraphInfo *graph_info, GraphOutputInfo *const graph_output_info);

   // Calculate ref count of PyNative back propagation operators.
-  void CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const;
+  void CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count,
+                         std::map<AnfNodePtr, size_t> *forward_output_refcount) const;

   // Update ref count of PyNative back propagation operators.
   void UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index,
                       std::map<KernelWithIndex, size_t> *ref_count,
+                      std::map<AnfNodePtr, size_t> *forward_output_refcount,
                       std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) const;

   // Handle single op output tensor and recover output of original complete kernel graph.

@@ -182,10 +184,12 @@ class GraphCompiler {
   GraphId CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const;

   // Create device address for all anf nodes of graph.
-  void CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) const;
+  void CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context,
+                           bool is_gradient_out) const;

   // Create device address for input and output of ops.
-  void CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph, const DeviceContext *device_context) const;
+  void CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph, const DeviceContext *device_context,
+                                           bool is_gradient_out) const;

   // Set Graph's dependencies for pre_graph and post_graph.
   void SetGraphDependency(const KernelGraphPtr &graph, const GraphSegmentPtr &segment) const;
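The forward_output_refcount map threaded through CalculateRefCount, UpdateRefCount and HandleOpInputs follows the usual reference-counting idea: record how many backward operators still need each forward output, decrement as those inputs are consumed, and release the device memory once the count reaches zero. The sketch below is a generic illustration of that bookkeeping with assumed names (ForwardOutputTracker, FreeDeviceMemory hook); it is not the session's implementation.

#include <functional>
#include <iostream>
#include <map>
#include <string>

// Node is just a name here; the real code keys the map by AnfNodePtr.
using Node = std::string;

class ForwardOutputTracker {
 public:
  // Register how many backward ops will read this forward output.
  void SetRefCount(const Node &node, size_t count) { refcount_[node] = count; }

  // Called each time a backward op consumes the output; frees it on the last use.
  void OnInputConsumed(const Node &node, const std::function<void(const Node &)> &free_device_memory) {
    auto it = refcount_.find(node);
    if (it == refcount_.end()) return;  // not a tracked forward output
    if (--(it->second) == 0) {
      free_device_memory(node);
      refcount_.erase(it);
    }
  }

 private:
  std::map<Node, size_t> refcount_;
};

int main() {
  ForwardOutputTracker tracker;
  tracker.SetRefCount("conv_out", 2);
  auto free_mem = [](const Node &n) { std::cout << "free " << n << std::endl; };
  tracker.OnInputConsumed("conv_out", free_mem);  // still referenced, nothing freed
  tracker.OnInputConsumed("conv_out", free_mem);  // last consumer, prints "free conv_out"
  return 0;
}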
@@ -1763,6 +1763,7 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
       auto other_type_device_tensor = device_context->CreateDeviceAddress(
         nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id());
       other_type_device_tensor->SetNodeIndex(input_node, 0);
+      other_type_device_tensor->set_from_persistent_mem(input_node->isa<Parameter>());
       AddDeviceTensorStore(front_node.get(), other_type_device_tensor);
     }
   }
@@ -511,7 +511,7 @@ bool AscendDeviceContext::AllocateMemory(DeviceAddress *const &address, size_t s
   MS_EXCEPTION_IF_NULL(address);
   MS_EXCEPTION_IF_NULL(runtime_instance_);
   runtime_instance_->SetContext();
-  auto device_ptr = mem_manager_->MallocMemFromMemPool(size);
+  auto device_ptr = mem_manager_->MallocMemFromMemPool(size, address->from_persistent_mem_);
   if (!device_ptr) {
     return false;
   }

@@ -89,7 +89,7 @@ bool CPUDeviceContext::AllocateMemory(DeviceAddress *const &address, size_t size
     MS_LOG(EXCEPTION) << "The device address type is wrong: " << address->DeviceType();
   }

-  auto device_ptr = mem_manager_->MallocMemFromMemPool(size);
+  auto device_ptr = mem_manager_->MallocMemFromMemPool(size, 0);
   if (!device_ptr) {
     return false;
   }

@@ -173,7 +173,7 @@ bool GPUDeviceContext::AllocateMemory(DeviceAddress *const &address, size_t size
   if (!BindDeviceToCurrentThread()) {
     return false;
   }
-  auto device_ptr = mem_manager_->MallocMemFromMemPool(size);
+  auto device_ptr = mem_manager_->MallocMemFromMemPool(size, address->from_persistent_mem_);
   if (!device_ptr) {
     return false;
   }
@@ -234,7 +234,7 @@ void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, Vec
   }
 }

-void ClearGraphDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) {
+void ClearGraphDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context, bool is_gradient_out) {
   MS_EXCEPTION_IF_NULL(graph);
   for (const auto &node : graph->execution_order()) {
     auto output_address_num = AnfAlgo::GetOutputAddressNum(node);

@@ -252,6 +252,9 @@ void ClearGraphDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *d
       MS_EXCEPTION_IF_NULL(new_device_address);
       new_device_address->set_original_ref_count(device_address->original_ref_count());
       new_device_address->ResetRefCount();
+      if (is_gradient_out) {
+        new_device_address->set_from_persistent_mem(true);
+      }
       AnfAlgo::SetOutputAddr(new_device_address, i, node.get());
     }
   }
@@ -800,9 +803,10 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
                                      &graph_output_info.output_indexes);

     std::map<KernelWithIndex, size_t> cnode_ref_count;
+    std::map<AnfNodePtr, size_t> forward_output_refcount;
     auto iter = cnode_ref_counts_.find(graph->graph_id());
     if (iter == cnode_ref_counts_.end()) {
-      graph_compiler_->CalculateRefCount(graph, &cnode_ref_count);
+      graph_compiler_->CalculateRefCount(graph, &cnode_ref_count, &forward_output_refcount);
       (void)cnode_ref_counts_.emplace(graph->graph_id(), cnode_ref_count);
     } else {
       cnode_ref_count = iter->second;

@@ -822,7 +826,8 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
         GraphInfo graph_info;
         graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index],
                                                  &input_tensor_info);
-        graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info, &op_run_info, &graph_info);
+        graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info, &op_run_info, &graph_info,
+                                                        &graph_output_info);

         RunOp(&op_run_info, &op_outputs);
       } else {

@@ -832,7 +837,8 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
         runtime::OpLazyBuilder::GetInstance().ExecuteRemainingTasks();
       }

-      graph_compiler_->UpdateRefCount(input_tensor_info.input_kernel, &cnode_ref_count, &op_output_map);
+      graph_compiler_->UpdateRefCount(input_tensor_info.input_kernel, &cnode_ref_count, &forward_output_refcount,
+                                      &op_output_map);

       graph_output_info.graph_output_tensors.clear();
       graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);
@@ -1246,7 +1252,8 @@ void MindRTBackend::LazyExecuteTaskCallback() {
       const auto &context = op_run_task->context();
       RunSingleOpGraph(context->graph(), context->output_nodes(), context->op_run_info(),
                        context->graph_compiler_info(), context->device_context());
-      ClearGraphDeviceAddress(context->graph(), context->device_context());
+      ClearGraphDeviceAddress(context->graph(), context->device_context(), false);

      UpdateInputDeviceAddress(context->graph());

      op_lazy_builder.PopOpRunTask();

@@ -1304,7 +1311,7 @@ void MindRTBackend::RunOpInternal(bool single_op_cache_hit, GraphCompilerInfo *g
   }
   RunSingleOpGraph(graph, output_nodes, *op_run_info, graph_compiler_info, device_context);
   UpdateOutput(output_nodes, outputs);
-  ClearGraphDeviceAddress(graph, device_context);
+  ClearGraphDeviceAddress(graph, device_context, false);
   UpdateInputDeviceAddress(graph);
   if (op_run_info->is_dynamic_shape) {
     UpdateOutputAbstract(graph, op_run_info);
@@ -273,6 +273,14 @@ class _Context:
             raise ValueError("For 'context.set_context', the argument 'max_device_memory' should not be \"0GB\".")
         self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value)

+    def set_mempool_block_size(self, mempool_block_size):
+        if not Validator.check_str_by_regular(mempool_block_size, _re_pattern):
+            raise ValueError("Context param mempool_block_size should be in correct format! Such as \"10GB\"")
+        mempool_block_size_value = float(mempool_block_size[:-2])
+        if mempool_block_size_value < 1.0:
+            raise ValueError("Context param mempool_block_size should be greater or equal to \"1GB\"")
+        self.set_param(ms_ctx_param.mempool_block_size, mempool_block_size_value)
+
     def set_print_file_path(self, file_path):
         """Add timestamp suffix to file name. Sets print file path."""
         print_file_path = os.path.realpath(file_path)

@@ -317,6 +325,7 @@ class _Context:
            'profiling_options': set_profiling_options,
            'variable_memory_max_size': set_variable_memory_max_size,
            'max_device_memory': set_max_device_memory,
+           'mempool_block_size': set_mempool_block_size,
            'print_file_path': set_print_file_path,
            'env_config_path': set_env_config_path
        }

@@ -570,7 +579,8 @@ def _check_target_specific_cfgs(device, arg_key):
        'print_file_path': ['Ascend'],
        'variable_memory_max_size': ['Ascend'],
        'auto_tune_mode': ['Ascend'],
-       'max_device_memory': ['GPU']
+       'max_device_memory': ['GPU'],
+       'mempool_block_size': ['GPU', 'Ascend']
    }
    # configs not in map device_cfgs are supposed to be suitable for all devices
    if not arg_key in device_cfgs:
@@ -583,15 +593,15 @@ def _check_target_specific_cfgs(device, arg_key):
     return False


-@args_unreset_check(device_id=int, variable_memory_max_size=str, max_device_memory=str)
+@args_unreset_check(device_id=int, variable_memory_max_size=str, max_device_memory=str, mempool_block_size=str)
 @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
                  save_graphs_path=str, enable_dump=bool, auto_tune_mode=str,
                  save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
                  enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
                  enable_graph_kernel=bool, reserve_class_name_in_scope=bool, check_bprop=bool,
                  max_device_memory=str, print_file_path=str, enable_sparse=bool, max_call_depth=int,
-                 env_config_path=str, graph_kernel_flags=str, enable_compile_cache=bool,
-                 compile_cache_path=str, grad_for_scalar=bool, pynative_synchronize=bool)
+                 env_config_path=str, graph_kernel_flags=str, save_compile_cache=bool,
+                 load_compile_cache=bool, grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str)
 def set_context(**kwargs):
     """
     Set context for running environment.
@@ -616,6 +626,8 @@ def set_context(**kwargs):
     | | max_device_memory | GPU |
     | +------------------------------+----------------------------+
     | | variable_memory_max_size | Ascend |
+    | +------------------------------+----------------------------+
+    | | mempool_block_size | GPU/Ascend |
     +-------------------------+------------------------------+----------------------------+
     | Debug Configuration | save_graphs | CPU/GPU/Ascend |
     | +------------------------------+----------------------------+

@@ -672,6 +684,9 @@ def set_context(**kwargs):
             The actual used memory size is the minimum of the available memory of the device and max_device_memory.
         variable_memory_max_size (str): Set the maximum size of the variable memory max size. Default: "30GB".
             After this parameter is set, the maximum memory used by the framework is restricted to the configured value.
+        mempool_block_size (str): Set the size of the memory pool block for devices in PyNative mode.
+            The format is "xxGB". Default: "1GB". The minimum size is "1GB". The actual used memory block size is the
+            minimum of the available memory of the device and mempool_block_size.
         save_graphs (bool): Whether to save graphs. Default: False.
             When the `save_graphs` attribute is set as True, attribute of `save_graphs_path` is used to set the
             intermediate compilation graph storage path. By default, the graphs are saved in the current directory.
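The docstring above only describes the knob. One plausible model of how such a block size is used, assuming a pool that grows in fixed-size chunks (this BlockPool is an illustration, not the GPU/Ascend pool implementation): the pool asks the device for memory one mempool_block_size chunk at a time, so a smaller block in PyNative mode wastes less memory on per-op pools at the cost of more frequent block allocations.

#include <cstddef>
#include <iostream>
#include <vector>

// A pool that grows in fixed-size blocks; block_size_bytes plays the role of
// the configured mempool_block_size.
class BlockPool {
 public:
  explicit BlockPool(size_t block_size_bytes) : block_size_(block_size_bytes) {}

  // Ensures `size` bytes can be served and returns the pool capacity afterwards.
  size_t Reserve(size_t size) {
    while (capacity_ - used_ < size) {
      blocks_.push_back(block_size_);  // in a real pool this would allocate one device block
      capacity_ += block_size_;
    }
    used_ += size;
    return capacity_;
  }

 private:
  size_t block_size_;
  size_t capacity_ = 0;
  size_t used_ = 0;
  std::vector<size_t> blocks_;
};

int main() {
  constexpr size_t kGB = size_t(1) << 30;
  BlockPool pool(1 * kGB);                       // mempool_block_size="1GB"
  std::cout << pool.Reserve(300 << 20) << "\n";  // one 1 GiB block covers a 300 MiB tensor
  std::cout << pool.Reserve(900 << 20) << "\n";  // pool grows to 2 GiB for the next request
  return 0;
}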
@@ -813,6 +828,7 @@ def set_context(**kwargs):
        ...                     profiling_options='{"output":"/home/data/output","training_trace":"on"}')
        >>> context.set_context(check_bprop=True)
        >>> context.set_context(max_device_memory="3.5GB")
+       >>> context.set_context(mempool_block_size="1GB")
        >>> context.set_context(print_file_path="print.pb")
        >>> context.set_context(enable_sparse=True)
        >>> context.set_context(max_call_depth=80)
@@ -86,6 +86,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
   set_param<std::string>(MS_CTX_PROFILING_OPTIONS, "training_trace");
   set_param<bool>(MS_CTX_CHECK_BPROP_FLAG, false);
   set_param<float>(MS_CTX_MAX_DEVICE_MEMORY, kDefaultMaxDeviceMemory);
+  set_param<float>(MS_CTX_MEMPOOL_BLOCK_SIZE, kDefaultMempoolBlockSize);
   set_param<std::string>(MS_CTX_PRINT_FILE_PATH, "");
   set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
   set_param<bool>(MS_CTX_ENABLE_SPARSE, false);

@@ -189,5 +190,4 @@ bool MsContext::enable_dump_ir() const {
   return false;
 #endif
 }
-
 }  // namespace mindspore
@@ -60,6 +60,8 @@ const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000;
 const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice};
 // The default max available device memory is 1024GB.
 const float kDefaultMaxDeviceMemory = 1024;
+// The default memory pool block size is 1.0GB.
+const float kDefaultMempoolBlockSize = 1.0;

 // enum definition for MindSpore Context Parameter
 enum MsCtxParam : unsigned {

@@ -109,6 +111,7 @@ enum MsCtxParam : unsigned {
   // parameter of type float
   MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END,
   MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN,
+  MS_CTX_MEMPOOL_BLOCK_SIZE,
   MS_CTX_TYPE_FLOAT_END,

   // parameter of type string
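For completeness, a sketch of how a float "number of GB" value such as kDefaultMempoolBlockSize could be consumed when a runtime initializes its pool. The context accessor here is a simplified stand-in; only the unit conversion, the documented 1GB floor, and clamping to the device's available memory mirror the change.

#include <algorithm>
#include <cstddef>
#include <iostream>

// Stand-in for reading MS_CTX_MEMPOOL_BLOCK_SIZE from the context (a float, in GB).
float GetMempoolBlockSizeGB() { return 1.0F; }  // kDefaultMempoolBlockSize

// Convert the configured value to bytes, clamping to the 1GB minimum and to the
// device's free memory, mirroring how the docstring describes the effective size.
size_t ResolveBlockSizeBytes(float config_gb, size_t device_free_bytes) {
  constexpr size_t kGB = size_t(1) << 30;
  float gb = std::max(config_gb, 1.0F);
  auto bytes = static_cast<size_t>(gb * static_cast<float>(kGB));
  return std::min(bytes, device_free_bytes);
}

int main() {
  size_t free_bytes = size_t(16) << 30;  // pretend the device reports 16GB free
  std::cout << ResolveBlockSizeBytes(GetMempoolBlockSizeGB(), free_bytes) << std::endl;
  return 0;
}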