forked from mindspore-Ecosystem/mindspore
!2966 reuse communication op output's memory
Merge pull request !2966 from laiyongqiang/hcom_memreuse
This commit is contained in:
commit
25ee322ba3
|
@ -25,7 +25,8 @@
|
|||
namespace mindspore {
|
||||
namespace memreuse {
|
||||
enum RefCountType { kDynamicRefCount, kStaticRefCount };
|
||||
enum NodeType { NORMAL, SPECIAL };
|
||||
enum NodeType { COMMON_NODE, COMMUNICATION_NODE };
|
||||
enum KernelRefType { COMMON, REFNODE_OUTPUT, COMM_NOTREUSE, COMM_REUSE, SUMMARY };
|
||||
static constexpr int kInitIndex = -1;
|
||||
class KernelRefCount {
|
||||
public:
|
||||
|
@ -36,6 +37,7 @@ class KernelRefCount {
|
|||
size_t offset_;
|
||||
size_t size_;
|
||||
int index_;
|
||||
KernelRefType type_;
|
||||
// remember to reset offset
|
||||
KernelRefCount()
|
||||
: stream_id_(0),
|
||||
|
@ -44,6 +46,7 @@ class KernelRefCount {
|
|||
offset_(0),
|
||||
size_(0),
|
||||
index_(kInitIndex),
|
||||
type_(COMMON),
|
||||
reftype_(kStaticRefCount) {}
|
||||
~KernelRefCount() = default;
|
||||
void SetKernelRefCountInfo(int index, size_t size, RefCountType reftype);
|
||||
|
@ -65,7 +68,7 @@ class KernelDef {
|
|||
KernelMap inputs_;
|
||||
KernelMap outputs_;
|
||||
KernelMap wk_space_;
|
||||
NodeType dirty = NORMAL;
|
||||
NodeType type_ = COMMON_NODE;
|
||||
KernelDef() = default;
|
||||
~KernelDef() = default;
|
||||
void set_input_refs(const KernelRefCountPtrList &kernelRefPtrList) { input_refs_ = kernelRefPtrList; }
|
||||
|
|
|
@ -46,6 +46,8 @@ bool MemReuseUtil::InitDynamicOutputKernelRef() {
|
|||
if (iter == kernel_output_refs_.end()) {
|
||||
auto output_sizes = kernel_mod->GetOutputSizeList();
|
||||
KernelRefCountPtrList kernel_refs;
|
||||
bool is_comm_op = AnfAlgo::IsCommunicationOp(kernel_cnode);
|
||||
size_t output_index = 0;
|
||||
for (auto size : output_sizes) {
|
||||
total_dy_size_ += size;
|
||||
// do not MallocDynamicMem just record this
|
||||
|
@ -54,9 +56,20 @@ bool MemReuseUtil::InitDynamicOutputKernelRef() {
|
|||
auto curr_stream_id = AnfAlgo::GetStreamId(kernel_cnode);
|
||||
kernel_ref->stream_id_ = curr_stream_id;
|
||||
kernel_ref->SetKernelRefCountInfo(index, size, kDynamicRefCount);
|
||||
if (is_comm_op) {
|
||||
kernel_ref->type_ = COMM_REUSE;
|
||||
} else {
|
||||
session::AnfWithOutIndex out_pair(kernel_cnode, output_index);
|
||||
if (graph_->IsInRefOutputMap(out_pair)) {
|
||||
kernel_ref->type_ = REFNODE_OUTPUT;
|
||||
} else {
|
||||
kernel_ref->type_ = COMMON;
|
||||
}
|
||||
}
|
||||
kernel_refs.push_back(kernel_ref);
|
||||
kernel_out_ref_num++;
|
||||
total_refs_list_.push_back(kernel_ref);
|
||||
output_index++;
|
||||
}
|
||||
if (!kernel_refs.empty()) {
|
||||
kernel_output_refs_[key] = kernel_refs;
|
||||
|
@ -155,9 +168,19 @@ void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr
|
|||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
MS_EXCEPTION_IF_NULL(kernel_def_ptr);
|
||||
auto key = kernel.get();
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
|
||||
bool is_comm_op = AnfAlgo::IsCommunicationOp(kernel);
|
||||
size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel);
|
||||
for (size_t i = 0; i < input_tensor_num; ++i) {
|
||||
auto ref_ptr = GetKernelInputRef(kernel, i);
|
||||
if (ref_ptr != nullptr) {
|
||||
if (is_comm_op) {
|
||||
if (input_tensor_num == 1) {
|
||||
ref_ptr->type_ = COMM_REUSE;
|
||||
} else {
|
||||
ref_ptr->type_ = COMM_NOTREUSE;
|
||||
}
|
||||
}
|
||||
|
||||
if (ref_ptr->reftype() == kStaticRefCount) {
|
||||
continue;
|
||||
} else if (ref_ptr->reftype() == kDynamicRefCount) {
|
||||
|
@ -258,6 +281,11 @@ void MemReuseUtil::SetKernelDefMap() {
|
|||
auto key = kernel.get();
|
||||
kernel_def_ptr->set_input_refs(kernel_def_ptr->inputs_[key]);
|
||||
kernel_def_ptr->set_output_refs(kernel_def_ptr->outputs_[key]);
|
||||
if (AnfAlgo::IsCommunicationOp(kernel)) {
|
||||
kernel_def_ptr->type_ = COMMUNICATION_NODE;
|
||||
} else {
|
||||
kernel_def_ptr->type_ = COMMON_NODE;
|
||||
}
|
||||
kernel_def_ptr_list_.push_back(kernel_def_ptr);
|
||||
kernel_map_[key] = kernel_def_ptr;
|
||||
}
|
||||
|
@ -337,6 +365,7 @@ void MemReuseUtil::SetSummaryNodesRefCount() {
|
|||
KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index];
|
||||
kernel_ref->ref_count_ = kMaxRefCount;
|
||||
kernel_ref->ref_count_dynamic_use_ = kMaxRefCount;
|
||||
kernel_ref->type_ = SUMMARY;
|
||||
total_summary_size += kernel_ref->size_;
|
||||
MS_LOG(INFO) << "Set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index;
|
||||
} else {
|
||||
|
|
|
@ -33,11 +33,11 @@ void BestFitMemReuse::InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr) {
|
|||
set_op_ptr_list(mem_reuse_util_ptr->kernel_def_ptr_list());
|
||||
// check info Correctness
|
||||
for (auto &tensor : tensor_ptr_list_) {
|
||||
tensor->size_ = AlignMemorySize(tensor->size_);
|
||||
tensor->size_ = AlignCommonMemorySize(tensor->size_);
|
||||
}
|
||||
// align wk size to 512 && refcount == 1
|
||||
for (auto &wk : wk_tensor_list_) {
|
||||
wk->size_ = AlignMemorySize(wk->size_);
|
||||
wk->size_ = AlignCommonMemorySize(wk->size_);
|
||||
wk->ref_count_ = 1;
|
||||
}
|
||||
#ifdef ENABLE_D
|
||||
|
@ -135,11 +135,23 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr
|
|||
return false;
|
||||
}
|
||||
|
||||
void BestFitMemReuse::AssignNodeOutputOffset() {
|
||||
void BestFitMemReuse::AssignCommonNodeOutputOffset() {
|
||||
MS_EXCEPTION_IF_NULL(current_kernel_);
|
||||
for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
|
||||
size_t index = GetTensorIndex(tensor_idx);
|
||||
auto tensor_desc = tensor_ptr_list_[index];
|
||||
MS_EXCEPTION_IF_NULL(tensor_desc);
|
||||
if (tensor_desc->type_ == REFNODE_OUTPUT) {
|
||||
total_refoutput_size += tensor_desc->size_;
|
||||
continue;
|
||||
} else if (tensor_desc->type_ == COMM_NOTREUSE) {
|
||||
total_comm_not_reuse_size += tensor_desc->size_;
|
||||
} else if (tensor_desc->type_ == COMM_REUSE) {
|
||||
// get align size for communication op's single input
|
||||
tensor_desc->size_ = AlignCommunicationMemorySize(tensor_desc->size_);
|
||||
total_comm_reuse_size += tensor_desc->size_;
|
||||
}
|
||||
|
||||
auto reusable_membuf_map = GetReusableMembufMap(tensor_desc->size_);
|
||||
if (!reusable_membuf_map.empty()) {
|
||||
auto membuf_index = reusable_membuf_map.begin()->second;
|
||||
|
@ -152,6 +164,86 @@ void BestFitMemReuse::AssignNodeOutputOffset() {
|
|||
MemReuseChecker::GetInstance().IsAddNewMembuf_ = true;
|
||||
#endif
|
||||
}
|
||||
// skip left align border for communication op single input to used
|
||||
if (tensor_desc->type_ == COMM_REUSE) {
|
||||
tensor_desc->offset_ += kDefaultMemAlignSize;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BestFitMemReuse::AssignCommunicationNodeOutputOffset() {
|
||||
size_t total_kernel_output_size = 0;
|
||||
size_t output_num = 0;
|
||||
// get all output size
|
||||
MS_EXCEPTION_IF_NULL(current_kernel_);
|
||||
for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
|
||||
size_t index = GetTensorIndex(tensor_idx);
|
||||
auto tensor_desc = tensor_ptr_list_[index];
|
||||
MS_EXCEPTION_IF_NULL(tensor_desc);
|
||||
if (tensor_desc->type_ == COMM_REUSE) {
|
||||
total_comm_reuse_size += tensor_desc->size_;
|
||||
total_comm_output_reuse_size += tensor_desc->size_;
|
||||
total_kernel_output_size += tensor_desc->size_;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "All communication op's outputs should be memory reuse, Kernel:"
|
||||
<< current_kernel_->scope_full_name();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
total_kernel_output_size = AlignCommunicationMemorySize(total_kernel_output_size);
|
||||
|
||||
// add left align border for the first output and right align border for the last output to alloc align border memory
|
||||
size_t output_index = 0;
|
||||
for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
|
||||
size_t index = GetTensorIndex(tensor_idx);
|
||||
auto tensor_desc = tensor_ptr_list_[index];
|
||||
MS_EXCEPTION_IF_NULL(tensor_desc);
|
||||
if (output_index == 0 || output_index == output_num - 1) {
|
||||
tensor_desc->size_ += kDefaultMemAlignSize;
|
||||
}
|
||||
output_index++;
|
||||
}
|
||||
|
||||
auto reusable_membuf_map = GetReusableMembufMap(total_kernel_output_size);
|
||||
if (!reusable_membuf_map.empty()) {
|
||||
auto membuf_index = reusable_membuf_map.begin()->second;
|
||||
output_index = 0;
|
||||
for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
|
||||
size_t index = GetTensorIndex(tensor_idx);
|
||||
auto tensor_desc = tensor_ptr_list_[index];
|
||||
MS_EXCEPTION_IF_NULL(tensor_desc);
|
||||
ReuseExistMembuf(tensor_desc.get(), membuf_index + output_index, kDynamicMem);
|
||||
// skip skip left align border for communication op's first output to used
|
||||
if (output_index == 0) {
|
||||
tensor_desc->offset_ += kDefaultMemAlignSize;
|
||||
}
|
||||
output_index++;
|
||||
}
|
||||
} else {
|
||||
// no membuf can reuse, add new membuf after the membuf_ptr_list
|
||||
output_index = 0;
|
||||
for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
|
||||
size_t index = GetTensorIndex(tensor_idx);
|
||||
auto tensor_desc = tensor_ptr_list_[index];
|
||||
MS_EXCEPTION_IF_NULL(tensor_desc);
|
||||
AddNewMembufPtr(tensor_desc.get(), kDynamicMem);
|
||||
// skip align size offset for first output to used
|
||||
if (output_index == 0) {
|
||||
tensor_desc->offset_ += kDefaultMemAlignSize;
|
||||
}
|
||||
output_index++;
|
||||
#ifdef MEM_REUSE_DEBUG
|
||||
MemReuseChecker::GetInstance().IsAddNewMembuf_ = true;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BestFitMemReuse::AssignNodeOutputOffset() {
|
||||
if (current_kernel_->type_ == COMMUNICATION_NODE) {
|
||||
AssignCommunicationNodeOutputOffset();
|
||||
} else {
|
||||
AssignCommonNodeOutputOffset();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -319,11 +411,17 @@ void BestFitMemReuse::ReleaseMembuf(size_t tensor_index, int flag) {
|
|||
}
|
||||
}
|
||||
|
||||
size_t BestFitMemReuse::AlignMemorySize(size_t size) const {
|
||||
size_t BestFitMemReuse::AlignCommonMemorySize(size_t size) const {
|
||||
// memory size 512 align
|
||||
return (size + kDefaultMemAlignSize + kAttAlignSize) / kDefaultMemAlignSize * kDefaultMemAlignSize;
|
||||
}
|
||||
|
||||
size_t BestFitMemReuse::AlignCommunicationMemorySize(size_t size) const {
|
||||
// memory size 512 align and add communication memory: left align border memory - data - right align border memory
|
||||
return kDefaultMemAlignSize + (size + kDefaultMemAlignSize - 1) / kDefaultMemAlignSize * kDefaultMemAlignSize +
|
||||
kDefaultMemAlignSize;
|
||||
}
|
||||
|
||||
size_t BestFitMemReuse::GetAllocatedSize() {
|
||||
size_t AllocatedSize = kTotalSize;
|
||||
if (membuf_ptr_list_.empty()) {
|
||||
|
@ -412,6 +510,9 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) {
|
|||
++op_num;
|
||||
#endif
|
||||
}
|
||||
MS_LOG(INFO) << "Special Tensor total size: RefOutput: " << total_refoutput_size
|
||||
<< " CommReuse: " << total_comm_reuse_size << " CommOutputReuse: " << total_comm_output_reuse_size
|
||||
<< " CommNotReuse: " << total_comm_not_reuse_size;
|
||||
#ifdef MEM_REUSE_DEBUG
|
||||
MemReuseChecker::GetInstance().ExportMembufInfoIR();
|
||||
MemReuseChecker::GetInstance().ExportAddNewMmebufIR();
|
||||
|
|
|
@ -74,6 +74,14 @@ class BestFitMemReuse {
|
|||
* Assign output tensor memory offset of current kernel
|
||||
*/
|
||||
void AssignNodeOutputOffset();
|
||||
/**
|
||||
* Assign output tensor memory offset of common kernel
|
||||
*/
|
||||
void AssignCommonNodeOutputOffset();
|
||||
/**
|
||||
* Assign output tensor memory offset of communication kernel
|
||||
*/
|
||||
void AssignCommunicationNodeOutputOffset();
|
||||
/**
|
||||
* Update input tensor's status of current kernel, and the status of membuf used by current kernel
|
||||
*/
|
||||
|
@ -110,8 +118,10 @@ class BestFitMemReuse {
|
|||
void AddNewMembufPtr(KernelRefCount *tensor_desc, int flag);
|
||||
// Merge unused membuf
|
||||
void ReleaseMembuf(size_t tensor_index, int flag);
|
||||
// Memory address alignment 512
|
||||
size_t AlignMemorySize(size_t size) const;
|
||||
// Memory address alignment for common memory
|
||||
size_t AlignCommonMemorySize(size_t size) const;
|
||||
// Memory address alignment for communication used memory
|
||||
size_t AlignCommunicationMemorySize(size_t size) const;
|
||||
int GetRealIndex(size_t index, int flag = kDynamicMem) const;
|
||||
size_t GetTensorIndex(int index) const;
|
||||
size_t GetWorkspaceIndex(int index) const;
|
||||
|
@ -153,6 +163,10 @@ class BestFitMemReuse {
|
|||
// kernel_front_map_, key: the kernel_def, value: kernels before this kernel_def
|
||||
std::map<KernelDefPtr, std::set<KernelDefPtr>> kernel_front_map_;
|
||||
std::vector<std::vector<uint32_t>> stream_groups_;
|
||||
size_t total_refoutput_size{0};
|
||||
size_t total_comm_reuse_size{0};
|
||||
size_t total_comm_output_reuse_size{0};
|
||||
size_t total_comm_not_reuse_size{0};
|
||||
};
|
||||
} // namespace memreuse
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -170,12 +170,14 @@ void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_li
|
|||
ofs << "all_tensor_refs:\n";
|
||||
ofs << "index:"
|
||||
<< "\tsize:"
|
||||
<< "\trefcount:\n";
|
||||
<< "\trefcount:"
|
||||
<< "\ttype:\n";
|
||||
for (auto &ref : total_refs_list) {
|
||||
ofs << "%" << ref->index_ << "T"
|
||||
<< "\t"
|
||||
<< "#" << ref->size_ << "S"
|
||||
<< "\t" << ref->ref_count_ << "C"
|
||||
<< "\t" << ref->type_ << "t"
|
||||
<< "\n";
|
||||
}
|
||||
ofs << "kernel_def exc_order:\n";
|
||||
|
@ -241,7 +243,7 @@ bool MemReuseChecker::CheckGraphOutputAssigned(const session::KernelGraph *graph
|
|||
void MemReuseChecker::ExportMemOpIr(const KernelDef *def, std::ofstream &ofs, int def_idx) {
|
||||
auto scope_name = def->scope_full_name();
|
||||
std::string split_name = GetSplitName(scope_name);
|
||||
ofs << "$" << def_idx << "\t" << split_name << "\t";
|
||||
ofs << "$" << def_idx << "\t" << split_name << "\t" << static_cast<int>(def->type_) << "\t";
|
||||
ofs << "inputs[";
|
||||
for (auto &in : def->inputs_) {
|
||||
for (auto &in_ref : in.second) {
|
||||
|
|
|
@ -95,6 +95,12 @@ uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_me
|
|||
} else {
|
||||
align_size = GetCommonAlignSize(size);
|
||||
}
|
||||
|
||||
auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
|
||||
MS_LOG(INFO) << "Malloc Memory: Static, total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
|
||||
<< "] memory pool[" << device_mem_pool_offset << "])"
|
||||
<< " malloc [" << align_size << "]";
|
||||
|
||||
if (communication_mem) {
|
||||
// create protect area [kMemAlignSize -- data -- kMemAlignSize]
|
||||
uint8_t *alloc_address = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
|
||||
|
@ -111,12 +117,17 @@ uint8_t *AscendMemoryManager::MallocDynamicMem(size_t size, bool communication_m
|
|||
} else {
|
||||
align_size = GetCommonAlignSize(size);
|
||||
}
|
||||
|
||||
auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
|
||||
MS_LOG(INFO) << "Malloc Memory: Dynamic, total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
|
||||
<< "] memory pool[" << device_mem_pool_offset << "])"
|
||||
<< " malloc [" << align_size << "]";
|
||||
|
||||
if (dynamic_mem_offset_ < align_size) {
|
||||
MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
|
||||
<< "]) malloc [" << align_size << "] failed!";
|
||||
}
|
||||
auto new_offset = dynamic_mem_offset_ - align_size;
|
||||
auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
|
||||
if (new_offset <= device_mem_pool_offset) {
|
||||
MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
|
||||
<< "] memory pool[" << device_mem_pool_offset << "])"
|
||||
|
|
|
@ -399,7 +399,7 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
|
|||
}
|
||||
|
||||
void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) {
|
||||
AssignCommunicationNodeInputMem(node);
|
||||
AssignCommunicationNodeInputMem(flag, node);
|
||||
AssignCommunicationNodeOutputMem(flag, node);
|
||||
}
|
||||
|
||||
|
@ -429,6 +429,11 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
|
|||
total_size += mem_size;
|
||||
align_size_list.emplace_back(mem_size);
|
||||
}
|
||||
|
||||
if (flag == kReuseDynamicMem) {
|
||||
// reuse communication op's all outputs' memory
|
||||
flag = kReuseDynamicCommMem;
|
||||
}
|
||||
uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size);
|
||||
for (size_t j = 0; j < align_size_list.size(); ++j) {
|
||||
std::string output_format = AnfAlgo::GetOutputFormat(node, j);
|
||||
|
@ -457,7 +462,7 @@ DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node,
|
|||
return address;
|
||||
}
|
||||
|
||||
void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
|
||||
void KernelRuntime::AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &node) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -478,7 +483,7 @@ void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
|
|||
total_size += mem_size;
|
||||
addr_size.emplace_back(address.get(), mem_size);
|
||||
}
|
||||
uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, kDynamicMem, total_size);
|
||||
uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size);
|
||||
for (const auto &iter : addr_size) {
|
||||
MS_EXCEPTION_IF_NULL(iter.first);
|
||||
iter.first->set_ptr(input_ptr);
|
||||
|
|
|
@ -88,7 +88,7 @@ class KernelRuntime {
|
|||
void UpdateRefNodeOutputMem(const session::KernelGraph *graph);
|
||||
|
||||
void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
|
||||
void AssignCommunicationNodeInputMem(const AnfNodePtr &node);
|
||||
void AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &node);
|
||||
void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node);
|
||||
#ifdef ENABLE_DUMP_E2E
|
||||
bool SetDumpConf();
|
||||
|
|
|
@ -57,6 +57,9 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, in
|
|||
}
|
||||
if (flag == kStaticMem) {
|
||||
ptr = MallocStaticMem(size, communication_mem);
|
||||
} else if (flag == kReuseDynamicCommMem) {
|
||||
MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_);
|
||||
ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
|
||||
} else {
|
||||
ptr = MallocDynamicMem(size, communication_mem);
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ namespace device {
|
|||
const int kStaticMem = 0;
|
||||
const int kDynamicMem = 1;
|
||||
const int kReuseDynamicMem = 2;
|
||||
const int kReuseDynamicCommMem = 3;
|
||||
const int kGetAllOuts = -1;
|
||||
const uint64_t kMemAlignSize = 512;
|
||||
using MemReuseUtilPtr = mindspore::memreuse::MemReuseUtilPtr;
|
||||
|
|
|
@ -146,7 +146,7 @@ TEST_F(TestMemReuseAllocator, mem_reuse_allocator_split_membuf) {
|
|||
|
||||
TEST_F(TestMemReuseAllocator, mem_reuse_allocator_align) {
|
||||
auto best_fit_mem_reuse = std::make_shared<BestFitMemReuse>();
|
||||
auto size = best_fit_mem_reuse->AlignMemorySize(510);
|
||||
auto size = best_fit_mem_reuse->AlignCommonMemorySize(510);
|
||||
ASSERT_EQ(size, 1024);
|
||||
}
|
||||
} // namespace memreuse
|
||||
|
|
|
@ -225,7 +225,6 @@ TEST_F(TestMemReuseWithPy, KernelRef) {
|
|||
ASSERT_EQ(kernel_ref_count_ptr->size_, 512);
|
||||
KernelDefPtr kernel_def_ptr = std::make_shared<KernelDef>();
|
||||
ASSERT_NE(kernel_def_ptr, nullptr);
|
||||
ASSERT_EQ(kernel_def_ptr->dirty, false);
|
||||
MembufPtr membuf_ptr = std::make_shared<Membuf>();
|
||||
ASSERT_NE(membuf_ptr, nullptr);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue