forked from mindspore-Ecosystem/mindspore
!3117 do not reuse ref node input's memory
Merge pull request !3117 from laiyongqiang/refnode_input
Commit: 72a2b7d496
@@ -25,8 +25,8 @@
 namespace mindspore {
 namespace memreuse {
 enum RefCountType { kDynamicRefCount, kStaticRefCount };
-enum NodeType { COMMON_NODE, COMMUNICATION_NODE };
-enum KernelRefType { COMMON, REFNODE_OUTPUT, COMM_NOTREUSE, COMM_REUSE, SUMMARY };
+enum NodeType { kCommonNode, kCommunicationNode };
+enum KernelRefType { kCommon, kRefNodeInput, kRefNodeOutput, kCommNotReuse, kCommReuse, kSummary };
 static constexpr int kInitIndex = -1;
 class KernelRefCount {
  public:
@@ -46,7 +46,7 @@ class KernelRefCount {
       offset_(0),
       size_(0),
       index_(kInitIndex),
-      type_(COMMON),
+      type_(kCommon),
       reftype_(kStaticRefCount) {}
   ~KernelRefCount() = default;
   void SetKernelRefCountInfo(int index, size_t size, RefCountType reftype);
@@ -68,7 +68,7 @@ class KernelDef {
   KernelMap inputs_;
   KernelMap outputs_;
   KernelMap wk_space_;
-  NodeType type_ = COMMON_NODE;
+  NodeType type_ = kCommonNode;
   KernelDef() = default;
   ~KernelDef() = default;
   void set_input_refs(const KernelRefCountPtrList &kernelRefPtrList) { input_refs_ = kernelRefPtrList; }
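Note: the KernelRefType rework above both moves the enumerators to the k-prefixed style used elsewhere in the codebase and adds a sixth category, kRefNodeInput. As a reading aid, the mapping below is my summary of how the rest of this diff treats each category; it is a standalone sketch, not code from the PR.

// Sketch only: hypothetical helper summarizing the allocation policy that
// BestFitMemReuse applies per KernelRefType after this change.
#include <cstdio>

enum KernelRefType { kCommon, kRefNodeInput, kRefNodeOutput, kCommNotReuse, kCommReuse, kSummary };

const char *AllocationPolicy(KernelRefType type) {
  switch (type) {
    case kCommon:        return "allocate; release for reuse when ref count reaches zero";
    case kRefNodeInput:  return "allocate; pinned with kMaxRefCount, never recycled";
    case kRefNodeOutput: return "skip allocation; aliases its origin tensor";
    case kCommNotReuse:  return "allocate; kept out of communication-op reuse";
    case kCommReuse:     return "allocate with alignment borders; reusable";
    case kSummary:       return "allocate; pinned with kMaxRefCount for summary readout";
  }
  return "unknown";
}

int main() { std::printf("%s\n", AllocationPolicy(kRefNodeInput)); }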
@@ -57,13 +57,22 @@ bool MemReuseUtil::InitDynamicOutputKernelRef() {
       kernel_ref->stream_id_ = curr_stream_id;
       kernel_ref->SetKernelRefCountInfo(index, size, kDynamicRefCount);
       if (is_comm_op) {
-        kernel_ref->type_ = COMM_REUSE;
+        kernel_ref->type_ = kCommReuse;
       } else {
         session::AnfWithOutIndex out_pair(kernel_cnode, output_index);
         if (graph_->IsInRefOutputMap(out_pair)) {
-          kernel_ref->type_ = REFNODE_OUTPUT;
+          kernel_ref->type_ = kRefNodeOutput;
+          auto origin_pair = graph_->GetRefCorrespondOutput(out_pair);
+          MS_EXCEPTION_IF_NULL(origin_pair.first);
+          if (origin_pair.first->isa<CNode>()) {
+            auto cnode = origin_pair.first->cast<CNodePtr>();
+            auto ref_ptr = GetKernelInputRef(cnode, origin_pair.second);
+            if (ref_ptr != nullptr) {
+              kernel_ref->type_ = kRefNodeInput;
+            }
+          }
         } else {
-          kernel_ref->type_ = COMMON;
+          kernel_ref->type_ = kCommon;
         }
       }
       kernel_refs.push_back(kernel_ref);
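Note: as I read the new branch, it walks from a ref-map output back to its origin tensor, and if that origin is itself tracked as a kernel input ref, the tensor is reclassified from kRefNodeOutput to kRefNodeInput. A condensed restatement of that control flow, with booleans standing in for the graph queries (IsInRefOutputMap, GetRefCorrespondOutput plus GetKernelInputRef); this is a hypothetical sketch, not the MindSpore API:

enum KernelRefType { kCommon, kRefNodeInput, kRefNodeOutput, kCommReuse };

KernelRefType ClassifyOutput(bool is_comm_op, bool in_ref_output_map, bool origin_is_tracked_input) {
  if (is_comm_op) return kCommReuse;
  if (!in_ref_output_map) return kCommon;
  // The output aliases an origin tensor, so by default its own allocation is
  // skipped (kRefNodeOutput). If the aliased origin is a tracked kernel input
  // ref, the buffer must stay live, so it is reclassified kRefNodeInput:
  // allocated normally and later pinned by SetRefNodesInputRefCount().
  return origin_is_tracked_input ? kRefNodeInput : kRefNodeOutput;
}

int main() { return ClassifyOutput(false, true, true) == kRefNodeInput ? 0 : 1; }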
@@ -175,9 +184,9 @@ void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr
     if (ref_ptr != nullptr) {
       if (is_comm_op) {
         if (input_tensor_num == 1) {
-          ref_ptr->type_ = COMM_REUSE;
+          ref_ptr->type_ = kCommReuse;
         } else {
-          ref_ptr->type_ = COMM_NOTREUSE;
+          ref_ptr->type_ = kCommNotReuse;
         }
       }
 
@@ -282,9 +291,9 @@ void MemReuseUtil::SetKernelDefMap() {
     kernel_def_ptr->set_input_refs(kernel_def_ptr->inputs_[key]);
     kernel_def_ptr->set_output_refs(kernel_def_ptr->outputs_[key]);
     if (AnfAlgo::IsCommunicationOp(kernel)) {
-      kernel_def_ptr->type_ = COMMUNICATION_NODE;
+      kernel_def_ptr->type_ = kCommunicationNode;
     } else {
-      kernel_def_ptr->type_ = COMMON_NODE;
+      kernel_def_ptr->type_ = kCommonNode;
     }
     kernel_def_ptr_list_.push_back(kernel_def_ptr);
     kernel_map_[key] = kernel_def_ptr;
@@ -365,7 +374,7 @@ void MemReuseUtil::SetSummaryNodesRefCount() {
       KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index];
       kernel_ref->ref_count_ = kMaxRefCount;
       kernel_ref->ref_count_dynamic_use_ = kMaxRefCount;
-      kernel_ref->type_ = SUMMARY;
+      kernel_ref->type_ = kSummary;
       total_summary_size += kernel_ref->size_;
       MS_LOG(INFO) << "Set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index;
     } else {
@@ -373,12 +382,29 @@ void MemReuseUtil::SetSummaryNodesRefCount() {
     }
   }
 #ifdef MEM_REUSE_DEBUG
-  auto graph = *graph_;
-  MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph);
+  MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph_);
 #endif
   MS_LOG(INFO) << "Special Tensor total size: SummaryNodes: " << total_summary_size;
 }
 
+void MemReuseUtil::SetRefNodesInputRefCount() {
+  size_t total_size = 0;
+  for (auto iter : kernel_output_refs_) {
+    for (auto &ref_count : iter.second) {
+      MS_EXCEPTION_IF_NULL(ref_count);
+      if (ref_count->type_ == kRefNodeInput) {
+        ref_count->ref_count_ = kMaxRefCount;
+        total_size += ref_count->size_;
+      }
+    }
+  }
+
+  MS_LOG(INFO) << "Special Tensor total size: RefNodeInput: " << total_size;
+#ifdef MEM_REUSE_DEBUG
+  MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph_);
+#endif
+}
+
 void MemReuseUtil::SetGraphOutputRefCount() {
   auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem});
   for (const auto &node : nodes) {
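Note: pinning with kMaxRefCount is what actually keeps these buffers out of the reuse pool. A membuf is only released for reuse once its tensor's remaining ref count reaches zero, and a count that starts at the maximum never gets there during a graph's lifetime. A toy model of that invariant (assumed, simplified release loop; the real bookkeeping lives in BestFitMemReuse):

#include <cassert>
#include <limits>

constexpr int kMaxRefCount = std::numeric_limits<int>::max();  // assumption: stands in for the real constant

struct Tensor {
  int ref_count_;
  bool released_ = false;
};

void ConsumeOnce(Tensor *t) {
  // When the last consumer runs, the buffer goes back to the reuse pool.
  if (--t->ref_count_ == 0) t->released_ = true;
}

int main() {
  Tensor common{2};            // ordinary output with two consumers
  Tensor pinned{kMaxRefCount}; // ref-node input pinned by SetRefNodesInputRefCount
  for (int i = 0; i < 1000; ++i) { ConsumeOnce(&common); ConsumeOnce(&pinned); }
  assert(common.released_);    // released after its second consumer
  assert(!pinned.released_);   // never released: its memory is never reused
}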
@@ -405,8 +431,7 @@ void MemReuseUtil::SetGraphOutputRefCount() {
     }
   }
 #ifdef MEM_REUSE_DEBUG
-  auto graph = *graph_;
-  MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph);
+  MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph_);
 #endif
 }
 
@@ -419,13 +444,14 @@ void MemReuseUtil::ResetDynamicUsedRefCount() {
   }
 }
 
-void MemReuseUtil::SetAllInfo(KernelGraph *graph) {
+void MemReuseUtil::SetAllInfo(const KernelGraph *graph) {
   if (!InitDynamicKernelRef(graph)) {
     MS_LOG(EXCEPTION) << "Init ReuseAssignDynamicMemory Fault";
   }
   SetKernelDefMap();
   SetReuseRefCount();
   SetSummaryNodesRefCount();
+  SetRefNodesInputRefCount();
   SetWorkSpaceList();
 #ifdef MEM_REUSE_DEBUG
   MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph);
@@ -52,7 +52,7 @@ class MemReuseUtil {
     MS_LOG(INFO) << "Total Reused WorkSpafce Memory Size: " << total_reuseworkspace_size_;
   }
 
-  void SetAllInfo(KernelGraph *graph);
+  void SetAllInfo(const KernelGraph *graph);
   bool InitDynamicOutputKernelRef();
   bool InitDynamicWorkspaceKernelRef();
   bool InitDynamicKernelRef(const KernelGraph *graph);
@@ -64,6 +64,7 @@ class MemReuseUtil {
   void SetKernelDefInputs();
   void SetReuseRefCount();
   void SetSummaryNodesRefCount();
+  void SetRefNodesInputRefCount();
   // Set the reference count of graph output specially.
   void SetGraphOutputRefCount();
   // Reset the dynamic used reference count by ref_count_.
@@ -90,7 +90,7 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr
   auto curr_stream_id = kernel_curr->stream_id();
   auto prev_stream_id = kernel_prev->stream_id();
   if (curr_stream_id == prev_stream_id) {
-    mem_buf->type_ = IN_STREAM_REUSE;
+    mem_buf->type_ = kInStreamReuse;
     return true;
   }
 
@@ -117,7 +117,7 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr
   }
 
   if (reuse_between_streams) {
-    mem_buf->type_ = BETWEEN_STREAMS_REUSE;
+    mem_buf->type_ = kBetweenStreamReuse;
     return true;
   }
 
@@ -128,7 +128,7 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr
   auto kernel_curr_front = iter->second;
   auto depend_count = kernel_curr_front.count(kernel_prev);
   if (depend_count) {
-    mem_buf->type_ = KERNEL_DEPENDENCE_REUSE;
+    mem_buf->type_ = kKernelDependenceReuse;
     return true;
   }
 
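Note: the renamed MemType values record why a membuf was judged reusable: the two kernels run on the same stream, on a stream pair allowed to share memory, or in a provable execution order. A compact sketch of that three-way check; the boolean parameters are hypothetical stand-ins for the stream-id comparison, the stream_groups_ lookup, and the kernel_front_map_ lookup in the hunks above:

enum MemType { kNew, kInStreamReuse, kBetweenStreamReuse, kKernelDependenceReuse };

struct Membuf { MemType type_ = kNew; };

bool IsUsable(bool same_stream, bool reuse_between_streams, bool curr_depends_on_prev, Membuf *buf) {
  if (same_stream) { buf->type_ = kInStreamReuse; return true; }
  if (reuse_between_streams) { buf->type_ = kBetweenStreamReuse; return true; }
  if (curr_depends_on_prev) { buf->type_ = kKernelDependenceReuse; return true; }
  return false;  // no safe ordering between the two kernels: cannot share memory
}

int main() {
  Membuf b;
  return (IsUsable(true, false, false, &b) && b.type_ == kInStreamReuse) ? 0 : 1;
}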
@@ -137,16 +137,19 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr
 
 void BestFitMemReuse::AssignCommonNodeOutputOffset() {
   MS_EXCEPTION_IF_NULL(current_kernel_);
-  for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+  for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
     size_t index = GetTensorIndex(tensor_idx);
     auto tensor_desc = tensor_ptr_list_[index];
     MS_EXCEPTION_IF_NULL(tensor_desc);
-    if (tensor_desc->type_ == REFNODE_OUTPUT) {
+    if (tensor_desc->type_ == kRefNodeInput) {
+      total_refinput_size += tensor_desc->size_;
+    } else if (tensor_desc->type_ == kRefNodeOutput) {
       total_refoutput_size += tensor_desc->size_;
+      // no need to alloc refnode output's memory
       continue;
-    } else if (tensor_desc->type_ == COMM_NOTREUSE) {
+    } else if (tensor_desc->type_ == kCommNotReuse) {
       total_comm_not_reuse_size += tensor_desc->size_;
-    } else if (tensor_desc->type_ == COMM_REUSE) {
+    } else if (tensor_desc->type_ == kCommReuse) {
       // get align size for communication op's single input
       tensor_desc->size_ = AlignCommunicationMemorySize(tensor_desc->size_);
       total_comm_reuse_size += tensor_desc->size_;
@@ -165,7 +168,7 @@ void BestFitMemReuse::AssignCommonNodeOutputOffset() {
 #endif
     }
     // skip left align border for communication op single input to used
-    if (tensor_desc->type_ == COMM_REUSE) {
+    if (tensor_desc->type_ == kCommReuse) {
       tensor_desc->offset_ += kDefaultMemAlignSize;
     }
   }
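Note on the net effect of the reworked branch: ref-node outputs still skip allocation entirely (they alias their origin tensor), while the new ref-node input category keeps its own allocation and is only tallied, so its bytes show up in total_refinput_size and, having been pinned earlier, are never handed to another tensor. A worked micro-example of the tallies, with hypothetical sizes:

#include <cstddef>
#include <cstdio>

enum KernelRefType { kCommon, kRefNodeInput, kRefNodeOutput };
struct Desc { KernelRefType type_; size_t size_; };

int main() {
  Desc outs[] = {{kRefNodeInput, 1024}, {kRefNodeOutput, 512}, {kCommon, 2048}};
  size_t total_refinput_size = 0, total_refoutput_size = 0, allocated = 0;
  for (const Desc &d : outs) {
    if (d.type_ == kRefNodeInput) {
      total_refinput_size += d.size_;   // tracked AND allocated (falls through)
    } else if (d.type_ == kRefNodeOutput) {
      total_refoutput_size += d.size_;  // aliases its origin: no allocation
      continue;
    }
    allocated += d.size_;
  }
  std::printf("refinput=%zu refoutput=%zu allocated=%zu\n",
              total_refinput_size, total_refoutput_size, allocated);
  // prints refinput=1024 refoutput=512 allocated=3072
}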
@@ -176,17 +179,18 @@ void BestFitMemReuse::AssignCommunicationNodeOutputOffset() {
   size_t output_num = 0;
   // get all output size
   MS_EXCEPTION_IF_NULL(current_kernel_);
-  for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+  for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
     size_t index = GetTensorIndex(tensor_idx);
     auto tensor_desc = tensor_ptr_list_[index];
     MS_EXCEPTION_IF_NULL(tensor_desc);
-    if (tensor_desc->type_ == COMM_REUSE) {
+    if (tensor_desc->type_ == kCommReuse) {
       total_comm_reuse_size += tensor_desc->size_;
       total_comm_output_reuse_size += tensor_desc->size_;
       total_kernel_output_size += tensor_desc->size_;
     } else {
       MS_LOG(ERROR) << "All communication op's outputs should be memory reuse, Kernel:"
-                    << current_kernel_->scope_full_name();
+                    << current_kernel_->scope_full_name() << " output index:" << tensor_idx
+                    << " tensor_type:" << tensor_desc->type_;
       continue;
     }
   }
@@ -195,7 +199,7 @@ void BestFitMemReuse::AssignCommunicationNodeOutputOffset() {
   // add left align border for the first output and right align border for the last output to alloc align border memory
   size_t output_index = 0;
   auto output_ref_indexes = current_kernel_->GetOutputRefIndexs();
-  for (auto &tensor_idx : output_ref_indexes) {
+  for (const auto &tensor_idx : output_ref_indexes) {
     size_t index = GetTensorIndex(tensor_idx);
     auto tensor_desc = tensor_ptr_list_[index];
     MS_EXCEPTION_IF_NULL(tensor_desc);
@@ -215,7 +219,7 @@ void BestFitMemReuse::AssignCommunicationNodeOutputOffset() {
   if (!reusable_membuf_map.empty()) {
     auto membuf_index = reusable_membuf_map.begin()->second;
     output_index = 0;
-    for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+    for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
       size_t index = GetTensorIndex(tensor_idx);
       auto tensor_desc = tensor_ptr_list_[index];
       MS_EXCEPTION_IF_NULL(tensor_desc);
@@ -229,7 +233,7 @@ void BestFitMemReuse::AssignCommunicationNodeOutputOffset() {
   } else {
     // no membuf can reuse, add new membuf after the membuf_ptr_list
     output_index = 0;
-    for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+    for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
       size_t index = GetTensorIndex(tensor_idx);
       auto tensor_desc = tensor_ptr_list_[index];
       MS_EXCEPTION_IF_NULL(tensor_desc);
@@ -247,7 +251,7 @@ void BestFitMemReuse::AssignCommunicationNodeOutputOffset() {
 }
 
 void BestFitMemReuse::AssignNodeOutputOffset() {
-  if (current_kernel_->type_ == COMMUNICATION_NODE) {
+  if (current_kernel_->type_ == kCommunicationNode) {
     AssignCommunicationNodeOutputOffset();
   } else {
     AssignCommonNodeOutputOffset();
@@ -330,7 +334,7 @@ void BestFitMemReuse::AddNewMembufPtr(KernelRefCount *tensor_desc, int flag) {
   }
   auto membuf_size = tensor_desc->size_;
   auto real_index = GetRealIndex(IntToSize(tensor_desc->index_), flag);
-  auto membuf = std::make_shared<Membuf>(kReused, membuf_size, membuf_offset, real_index, NEW, current_kernel_);
+  auto membuf = std::make_shared<Membuf>(kReused, membuf_size, membuf_offset, real_index, kNew, current_kernel_);
   membuf_ptr_list_.push_back(membuf);
   tensor_desc->offset_ = membuf_offset;
 }
@@ -352,7 +356,7 @@ void BestFitMemReuse::UpdateNodeInputAndMembuf() {
 }
 
 void BestFitMemReuse::ReleaseNodeUnusedOutput() {
-  for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+  for (const auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
     size_t tensor_index = GetTensorIndex(tensor_idx);
     auto tensor_desc = tensor_ptr_list_[tensor_index];
     MS_EXCEPTION_IF_NULL(tensor_desc);
@@ -517,8 +521,9 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) {
     ++op_num;
 #endif
   }
-  MS_LOG(INFO) << "Special Tensor total size: RefOutput: " << total_refoutput_size
-               << " CommReuse: " << total_comm_reuse_size << " CommOutputReuse: " << total_comm_output_reuse_size
+  MS_LOG(INFO) << "Special Tensor total size: RefInput: " << total_refinput_size
+               << " RefOutput: " << total_refoutput_size << " CommReuse: " << total_comm_reuse_size
+               << " CommOutputReuse: " << total_comm_output_reuse_size
                << " CommNotReuse: " << total_comm_not_reuse_size;
 #ifdef MEM_REUSE_DEBUG
   MemReuseChecker::GetInstance().ExportMembufInfoIR();
@@ -40,11 +40,11 @@ static constexpr int kDynamicMem = -1;
 static constexpr int kWorkspaceMem = 1;
 static constexpr size_t kTotalSize = 0;
 enum Status { kUnused, kReused };
-enum MEMTYPE { NEW, IN_STREAM_REUSE, BETWEEN_STREAMS_REUSE, KERNEL_DEPENDENCE_REUSE };
+enum MemType { kNew, kInStreamReuse, kBetweenStreamReuse, kKernelDependenceReuse };
 class Membuf {
  public:
   Membuf() = default;
-  Membuf(Status status, size_t size, size_t offset, int index, MEMTYPE type, const KernelDefPtr &used_kernel)
+  Membuf(Status status, size_t size, size_t offset, int index, MemType type, const KernelDefPtr &used_kernel)
       : status_(status), size_(size), offset_(offset), index_(index), type_(type), used_kernel_(used_kernel) {}
   ~Membuf() = default;
   // Memory block status flags
@@ -53,7 +53,7 @@ class Membuf {
   size_t offset_{0};
   // Store the tensor index stored in this memory block at a certain moment
   int index_{0};
-  MEMTYPE type_{NEW};
+  MemType type_{kNew};
   KernelDefPtr used_kernel_;
 };
 using MembufPtr = std::shared_ptr<Membuf>;
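Note: the MEMTYPE to MemType rename is purely stylistic (Membuf::type_ only labels how a block came to be used), but one practical motivation for kCamelCase enumerators, which is my assumption rather than anything stated in the PR, is that ALL_CAPS names can collide with preprocessor macros pulled in by other headers:

#define NEW 7  // imagine a third-party header defining this macro
// enum MEMTYPE { NEW, ... };  // would now expand to "enum MEMTYPE { 7, ... }" and fail to compile
enum MemType { kNew, kInStreamReuse, kBetweenStreamReuse, kKernelDependenceReuse };  // unaffected

int main() { return kNew; }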
@@ -163,6 +163,7 @@ class BestFitMemReuse {
   // kernel_front_map_, key: the kernel_def, value: kernels before this kernel_def
   std::map<KernelDefPtr, std::set<KernelDefPtr>> kernel_front_map_;
   std::vector<std::vector<uint32_t>> stream_groups_;
+  size_t total_refinput_size{0};
   size_t total_refoutput_size{0};
   size_t total_comm_reuse_size{0};
   size_t total_comm_output_reuse_size{0};
@@ -83,7 +83,7 @@ int64_t MemReuseChecker::CalculOriInput(const KernelGraph *graph) const {
   return static_input_size;
 }
 
-int64_t MemReuseChecker::CalculOriValue(KernelGraph *graph) const {
+int64_t MemReuseChecker::CalculOriValue(const KernelGraph *graph) const {
   MS_EXCEPTION_IF_NULL(graph);
   int64_t static_value_size = 0;
   for (auto &value_node : graph->graph_value_nodes()) {
@@ -101,7 +101,7 @@ int64_t MemReuseChecker::CalculOriValue(KernelGraph *graph) const {
   return static_value_size;
 }
 
-int64_t MemReuseChecker::CalculOriStatic(KernelGraph *graph) const {
+int64_t MemReuseChecker::CalculOriStatic(const KernelGraph *graph) const {
   // cal static inputs
   auto static_input_size = CalculOriInput(graph);
   // do not calcul outpput size
@@ -154,7 +154,7 @@ std::string MemReuseChecker::GetSplitName(const std::string &scope_name) const {
 }
 
 void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list,
-                                      const KernelDefPtrMaps &kernel_def_ptr_list, KernelGraph *graph) {
+                                      const KernelDefPtrMaps &kernel_def_ptr_list, const KernelGraph *graph) {
   total_ori_static_size_ = CalculOriStatic(graph);
   total_ori_input_size_ = CalculOriInput(graph);
   total_ori_value_size_ = CalculOriValue(graph);
@@ -43,10 +43,10 @@ class MemReuseChecker {
   void CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr &c_node, size_t output_idx);
   bool CheckGraphOutputAssigned(const session::KernelGraph *graph);
   void CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list, const KernelDefPtrMaps &kernel_def_ptr_list,
-                       KernelGraph *graph);
-  int64_t CalculOriStatic(KernelGraph *graph) const;
+                       const KernelGraph *graph);
+  int64_t CalculOriStatic(const KernelGraph *graph) const;
   int64_t CalculOriInput(const KernelGraph *graph) const;
-  int64_t CalculOriValue(KernelGraph *graph) const;
+  int64_t CalculOriValue(const KernelGraph *graph) const;
   int64_t CalculOriDy(const KernelGraph *graph) const;
   int64_t CalculOriWk(const KernelGraph *graph) const;
   std::string GetSplitName(const std::string &scope_name) const;
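Note: making the checker entry points take const KernelGraph * is what lets the two CheckMemReuseIR call sites earlier in this diff drop the "auto graph = *graph_;" line. That copy existed only to manufacture a mutable pointer from the stored (const) one. A minimal illustration of the shape of the fix, with a stand-in KernelGraph rather than the real type:

struct KernelGraph { int nodes[1024]; };  // hypothetical stand-in, deliberately bulky

void CheckOld(KernelGraph *) {}        // old signature: demands mutable access
void CheckNew(const KernelGraph *) {}  // new signature: promises read-only access

struct Util {
  const KernelGraph *graph_;
  void Old() {
    auto graph = *graph_;  // forced to copy the whole graph...
    CheckOld(&graph);      // ...just to obtain a non-const KernelGraph*
  }
  void New() {
    CheckNew(graph_);      // pass the stored pointer directly, no copy
  }
};

int main() {
  KernelGraph g{};
  Util u{&g};
  u.Old();
  u.New();
}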
@@ -398,12 +398,12 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
   }
 }
 
-void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) {
-  AssignCommunicationNodeInputMem(flag, node);
-  AssignCommunicationNodeOutputMem(flag, node);
+void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) {
+  AssignCommunicationNodeInputMem(type, node);
+  AssignCommunicationNodeOutputMem(type, node);
 }
 
-void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) {
+void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
   auto kernel_mod = AnfAlgo::GetKernelMod(node);
@@ -430,11 +430,11 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
     align_size_list.emplace_back(mem_size);
   }
 
-  if (flag == kReuseDynamicMem) {
+  if (type == kReuseDynamicMem) {
     // reuse communication op's all outputs' memory
-    flag = kReuseDynamicCommMem;
+    type = kReuseDynamicCommMem;
   }
-  uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size);
+  uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size);
   for (size_t j = 0; j < align_size_list.size(); ++j) {
     std::string output_format = AnfAlgo::GetOutputFormat(node, j);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
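Note: the kReuseDynamicMem to kReuseDynamicCommMem promotion matters because a communication op's outputs come from one contiguous allocation: MallocOutputMem is called once with index 0 and the summed size, and per the memory_manager.cc hunks below, the kReuseDynamicCommMem path resolves that whole block through the reuse table before per-output addresses are assigned. A toy sketch of the carve-by-offset idea, with hypothetical sizes and std::vector standing in for device memory:

#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  std::vector<size_t> align_size_list = {512, 1024, 512};  // aligned per-output sizes (hypothetical)
  size_t total_size = 0;
  for (size_t s : align_size_list) total_size += s;
  std::vector<uint8_t> block(total_size);  // stands in for MallocOutputMem(node, 0, type, total_size)
  uint8_t *output_ptr = block.data();
  std::vector<uint8_t *> outputs;
  for (size_t s : align_size_list) {
    outputs.push_back(output_ptr);  // each output address is a slice of the one block
    output_ptr += s;
  }
  return outputs.size() == 3 ? 0 : 1;
}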
@@ -458,7 +458,7 @@ DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node,
   return address;
 }
 
-void KernelRuntime::AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &node) {
+void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   MS_EXCEPTION_IF_NULL(node);
@@ -479,7 +479,7 @@ void KernelRuntime::AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &
     total_size += mem_size;
     addr_size.emplace_back(address.get(), mem_size);
   }
-  uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size);
+  uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size);
   for (const auto &iter : addr_size) {
     MS_EXCEPTION_IF_NULL(iter.first);
     iter.first->set_ptr(input_ptr);
@@ -487,12 +487,12 @@ void KernelRuntime::AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &
   }
 }
 
-void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) {
+void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
-  if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) {
+  if (AnfAlgo::IsGetNext(NOT_NULL(node)) && type == kReuseDynamicMem) {
     MS_LOG(INFO) << "GetNext disable mem_reuse";
-    flag = kDynamicMem;
+    type = kDynamicMem;
   }
   auto kernel_mod = AnfAlgo::GetKernelMod(node);
   MS_EXCEPTION_IF_NULL(kernel_mod);
@@ -509,7 +509,7 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in
       MS_LOG(INFO) << "Already malloc index:" << i;
       continue;
     }
-    auto ptr = mem_manager_->MallocOutputMem(node, i, flag, output_sizes[i]);
+    auto ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i]);
     if (ptr == nullptr) {
       // reused ptr, no need alloc, continue;
       continue;
@@ -608,10 +608,10 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   bool is_enable_mem_reuse = context_ptr->enable_mem_reuse();
-  auto mem_flag = kDynamicMem;
+  auto mem_type = kDynamicMem;
   if (is_enable_mem_reuse) {
     mem_manager_->MallocReusedDynamicMem(graph);
-    mem_flag = kReuseDynamicMem;
+    mem_type = kReuseDynamicMem;
   }
   auto &execution_nodes = graph->execution_order();
   std::vector<CNodePtr> compute_nodes;
@@ -619,7 +619,7 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
   for (auto &node : execution_nodes) {
     if (AnfAlgo::IsCommunicationOp(node)) {
       // skip if the memory is already alocated
-      AssignCommunicationNodeMem(mem_flag, node);
+      AssignCommunicationNodeMem(mem_type, node);
     } else {
       compute_nodes.emplace_back(node);
     }
@@ -627,19 +627,19 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
 
   // then compute nodes
   for (auto &node : compute_nodes) {
-    AssignNodeOutputMem(mem_flag, node, kGetAllOuts);
-    AssignWorkSpaceMem(mem_flag, node);
+    AssignNodeOutputMem(mem_type, node, kGetAllOuts);
+    AssignWorkSpaceMem(mem_type, node);
   }
 }
 
-void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) {
+void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
   auto kernel_mod = AnfAlgo::GetKernelMod(node);
   MS_EXCEPTION_IF_NULL(kernel_mod);
   size_t index = 0;
   for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
-    auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, flag, size);
+    auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size);
     AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
     index++;
   }
@@ -83,15 +83,15 @@ class KernelRuntime {
   void AssignStaticMemory(session::KernelGraph *graph);
   void AssignDynamicMemory(session::KernelGraph *graph);
   void ReuseAssignDynamicMemory(session::KernelGraph *graph);
-  void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index);
-  void AssignWorkSpaceMem(int flag, const AnfNodePtr &node);
+  void AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index);
+  void AssignWorkSpaceMem(MemType type, const AnfNodePtr &node);
   void AssignReuseWorkSpaceMem(const AnfNodePtr &node);
 
   void UpdateRefNodeOutputMem(const session::KernelGraph *graph);
 
-  void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
-  void AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &node);
-  void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node);
+  void AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node);
+  void AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node);
+  void AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node);
 #ifdef ENABLE_DUMP_E2E
   bool SetDumpConf();
 #endif
@@ -29,7 +29,7 @@ size_t MemoryManager::GetCommunicationAlignSize(size_t input_size) const {
   return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + 2 * kMemAlignSize;
 }
 
-void MemoryManager::MallocReusedDynamicMem(session::KernelGraph *graph) {
+void MemoryManager::MallocReusedDynamicMem(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared<memreuse::MemReuseUtil>();
   MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
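Note: for concreteness, the context line above rounds the size up to a multiple of kMemAlignSize and then reserves two extra aligned blocks, the left and right guard borders that AssignCommonNodeOutputOffset skips past with "tensor_desc->offset_ += kDefaultMemAlignSize" (apparently the same 512-byte granularity, though the constants live in different headers). Worked example with kMemAlignSize = 512:

#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr uint64_t kMemAlignSize = 512;

size_t GetCommunicationAlignSize(size_t input_size) {
  return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + 2 * kMemAlignSize;
}

int main() {
  assert(GetCommunicationAlignSize(1000) == 2048);  // 1000 rounds up to 1024, plus 2 * 512 guard borders
  assert(GetCommunicationAlignSize(512) == 1536);   // already aligned: 512 + 1024
}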
@@ -45,7 +45,7 @@ void MemoryManager::MallocReusedDynamicMem(session::KernelGraph *graph) {
   mem_reuse_util_ptr_->set_mem_base(base_ptr);
 }
 
-uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size) {
+uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size) {
   MS_EXCEPTION_IF_NULL(node);
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
@@ -55,9 +55,9 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, in
   if (context_ptr->enable_hccl()) {
     communication_mem = true;
   }
-  if (flag == kStaticMem) {
+  if (type == kStaticMem) {
     ptr = MallocStaticMem(size, communication_mem);
-  } else if (flag == kReuseDynamicCommMem) {
+  } else if (type == kReuseDynamicCommMem) {
     MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_);
     ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
   } else {
@@ -66,30 +66,30 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, in
     return ptr;
   }
 
-  if (flag == kStaticMem) {
+  if (type == kStaticMem) {
     ptr = MallocStaticMem(size, false);
-  } else if (flag == kDynamicMem) {
+  } else if (type == kDynamicMem) {
     ptr = MallocDynamicMem(size, false);
-  } else if (flag == kReuseDynamicMem) {
+  } else if (type == kReuseDynamicMem) {
     MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_);
     ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
   }
   return ptr;
 }
 
-uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size) {
-  if (flag == kReuseDynamicMem) {
+uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size) {
+  if (type == kReuseDynamicMem) {
     MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_);
     return mem_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index);
   }
   return MallocDynamicMem(size, false);
 }
 
-uint8_t *MemoryManager::MallocMem(int flag, size_t size) {
+uint8_t *MemoryManager::MallocMem(MemType type, size_t size) {
   uint8_t *ptr = nullptr;
-  if (flag == kStaticMem) {
+  if (type == kStaticMem) {
     ptr = MallocStaticMem(size, false);
-  } else if (flag == kDynamicMem) {
+  } else if (type == kDynamicMem) {
     ptr = MallocDynamicMem(size, false);
   }
   return ptr;
@@ -22,10 +22,7 @@
 #include "backend/optimizer/mem_reuse/mem_reuse_allocator.h"
 namespace mindspore {
 namespace device {
-const int kStaticMem = 0;
-const int kDynamicMem = 1;
-const int kReuseDynamicMem = 2;
-const int kReuseDynamicCommMem = 3;
+enum MemType { kStaticMem, kDynamicMem, kReuseDynamicMem, kReuseDynamicCommMem };
 const int kGetAllOuts = -1;
 const uint64_t kMemAlignSize = 512;
 using MemReuseUtilPtr = mindspore::memreuse::MemReuseUtilPtr;
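Note: folding the four const int flags into a MemType enum is more than cosmetic. Every signature in this PR that took "int flag" now takes MemType, so passing an unrelated integer (a size, an index, kGetAllOuts) where the memory kind belongs becomes a compile-time error instead of a silently wrong branch. A tiny demonstration with a hypothetical function, not the MindSpore API:

#include <cstddef>

enum MemType { kStaticMem, kDynamicMem, kReuseDynamicMem, kReuseDynamicCommMem };

void MallocOld(int flag, size_t size);      // MallocOld(512, kStaticMem) compiles: swapped-argument bug
void MallocNew(MemType type, size_t size);  // MallocNew(512, kStaticMem) is rejected by the compiler

int main() {
  // MallocNew(512, kStaticMem);  // error: no implicit conversion from int to MemType
  return 0;
}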
@@ -42,10 +39,10 @@ class MemoryManager {
     dynamic_mem_offset_ = 0;
   }
 
-  void MallocReusedDynamicMem(session::KernelGraph *graph);
-  uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size);
-  uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size);
-  virtual uint8_t *MallocMem(int flag, size_t size);
+  void MallocReusedDynamicMem(const session::KernelGraph *graph);
+  uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size);
+  uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size);
+  virtual uint8_t *MallocMem(MemType type, size_t size);
 
   virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size);
   virtual void *MallocMemFromMemPool(size_t size);