!2966 reuse communication op output's memory

Merge pull request !2966 from laiyongqiang/hcom_memreuse
This commit is contained in:
mindspore-ci-bot 2020-07-16 14:51:12 +08:00 committed by Gitee
commit 25ee322ba3
12 changed files with 186 additions and 18 deletions

View File

@ -25,7 +25,8 @@
namespace mindspore {
namespace memreuse {
enum RefCountType { kDynamicRefCount, kStaticRefCount };
enum NodeType { NORMAL, SPECIAL };
enum NodeType { COMMON_NODE, COMMUNICATION_NODE };
enum KernelRefType { COMMON, REFNODE_OUTPUT, COMM_NOTREUSE, COMM_REUSE, SUMMARY };
static constexpr int kInitIndex = -1;
class KernelRefCount {
public:
@ -36,6 +37,7 @@ class KernelRefCount {
size_t offset_;
size_t size_;
int index_;
KernelRefType type_;
// remember to reset offset
KernelRefCount()
: stream_id_(0),
@ -44,6 +46,7 @@ class KernelRefCount {
offset_(0),
size_(0),
index_(kInitIndex),
type_(COMMON),
reftype_(kStaticRefCount) {}
~KernelRefCount() = default;
void SetKernelRefCountInfo(int index, size_t size, RefCountType reftype);
@ -65,7 +68,7 @@ class KernelDef {
KernelMap inputs_;
KernelMap outputs_;
KernelMap wk_space_;
NodeType dirty = NORMAL;
NodeType type_ = COMMON_NODE;
KernelDef() = default;
~KernelDef() = default;
void set_input_refs(const KernelRefCountPtrList &kernelRefPtrList) { input_refs_ = kernelRefPtrList; }

View File

@ -46,6 +46,8 @@ bool MemReuseUtil::InitDynamicOutputKernelRef() {
if (iter == kernel_output_refs_.end()) {
auto output_sizes = kernel_mod->GetOutputSizeList();
KernelRefCountPtrList kernel_refs;
bool is_comm_op = AnfAlgo::IsCommunicationOp(kernel_cnode);
size_t output_index = 0;
for (auto size : output_sizes) {
total_dy_size_ += size;
// do not MallocDynamicMem just record this
@ -54,9 +56,20 @@ bool MemReuseUtil::InitDynamicOutputKernelRef() {
auto curr_stream_id = AnfAlgo::GetStreamId(kernel_cnode);
kernel_ref->stream_id_ = curr_stream_id;
kernel_ref->SetKernelRefCountInfo(index, size, kDynamicRefCount);
if (is_comm_op) {
kernel_ref->type_ = COMM_REUSE;
} else {
session::AnfWithOutIndex out_pair(kernel_cnode, output_index);
if (graph_->IsInRefOutputMap(out_pair)) {
kernel_ref->type_ = REFNODE_OUTPUT;
} else {
kernel_ref->type_ = COMMON;
}
}
kernel_refs.push_back(kernel_ref);
kernel_out_ref_num++;
total_refs_list_.push_back(kernel_ref);
output_index++;
}
if (!kernel_refs.empty()) {
kernel_output_refs_[key] = kernel_refs;
@ -155,9 +168,19 @@ void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(kernel_def_ptr);
auto key = kernel.get();
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
bool is_comm_op = AnfAlgo::IsCommunicationOp(kernel);
size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel);
for (size_t i = 0; i < input_tensor_num; ++i) {
auto ref_ptr = GetKernelInputRef(kernel, i);
if (ref_ptr != nullptr) {
if (is_comm_op) {
if (input_tensor_num == 1) {
ref_ptr->type_ = COMM_REUSE;
} else {
ref_ptr->type_ = COMM_NOTREUSE;
}
}
if (ref_ptr->reftype() == kStaticRefCount) {
continue;
} else if (ref_ptr->reftype() == kDynamicRefCount) {
@ -258,6 +281,11 @@ void MemReuseUtil::SetKernelDefMap() {
auto key = kernel.get();
kernel_def_ptr->set_input_refs(kernel_def_ptr->inputs_[key]);
kernel_def_ptr->set_output_refs(kernel_def_ptr->outputs_[key]);
if (AnfAlgo::IsCommunicationOp(kernel)) {
kernel_def_ptr->type_ = COMMUNICATION_NODE;
} else {
kernel_def_ptr->type_ = COMMON_NODE;
}
kernel_def_ptr_list_.push_back(kernel_def_ptr);
kernel_map_[key] = kernel_def_ptr;
}
@ -337,6 +365,7 @@ void MemReuseUtil::SetSummaryNodesRefCount() {
KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index];
kernel_ref->ref_count_ = kMaxRefCount;
kernel_ref->ref_count_dynamic_use_ = kMaxRefCount;
kernel_ref->type_ = SUMMARY;
total_summary_size += kernel_ref->size_;
MS_LOG(INFO) << "Set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index;
} else {

View File

@ -33,11 +33,11 @@ void BestFitMemReuse::InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr) {
set_op_ptr_list(mem_reuse_util_ptr->kernel_def_ptr_list());
// check info Correctness
for (auto &tensor : tensor_ptr_list_) {
tensor->size_ = AlignMemorySize(tensor->size_);
tensor->size_ = AlignCommonMemorySize(tensor->size_);
}
// align wk size to 512 && refcount == 1
for (auto &wk : wk_tensor_list_) {
wk->size_ = AlignMemorySize(wk->size_);
wk->size_ = AlignCommonMemorySize(wk->size_);
wk->ref_count_ = 1;
}
#ifdef ENABLE_D
@ -135,11 +135,23 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr
return false;
}
void BestFitMemReuse::AssignNodeOutputOffset() {
void BestFitMemReuse::AssignCommonNodeOutputOffset() {
MS_EXCEPTION_IF_NULL(current_kernel_);
for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
size_t index = GetTensorIndex(tensor_idx);
auto tensor_desc = tensor_ptr_list_[index];
MS_EXCEPTION_IF_NULL(tensor_desc);
if (tensor_desc->type_ == REFNODE_OUTPUT) {
total_refoutput_size += tensor_desc->size_;
continue;
} else if (tensor_desc->type_ == COMM_NOTREUSE) {
total_comm_not_reuse_size += tensor_desc->size_;
} else if (tensor_desc->type_ == COMM_REUSE) {
// get align size for communication op's single input
tensor_desc->size_ = AlignCommunicationMemorySize(tensor_desc->size_);
total_comm_reuse_size += tensor_desc->size_;
}
auto reusable_membuf_map = GetReusableMembufMap(tensor_desc->size_);
if (!reusable_membuf_map.empty()) {
auto membuf_index = reusable_membuf_map.begin()->second;
@ -152,6 +164,86 @@ void BestFitMemReuse::AssignNodeOutputOffset() {
MemReuseChecker::GetInstance().IsAddNewMembuf_ = true;
#endif
}
// skip left align border for communication op single input to used
if (tensor_desc->type_ == COMM_REUSE) {
tensor_desc->offset_ += kDefaultMemAlignSize;
}
}
}
void BestFitMemReuse::AssignCommunicationNodeOutputOffset() {
size_t total_kernel_output_size = 0;
size_t output_num = 0;
// get all output size
MS_EXCEPTION_IF_NULL(current_kernel_);
for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
size_t index = GetTensorIndex(tensor_idx);
auto tensor_desc = tensor_ptr_list_[index];
MS_EXCEPTION_IF_NULL(tensor_desc);
if (tensor_desc->type_ == COMM_REUSE) {
total_comm_reuse_size += tensor_desc->size_;
total_comm_output_reuse_size += tensor_desc->size_;
total_kernel_output_size += tensor_desc->size_;
} else {
MS_LOG(ERROR) << "All communication op's outputs should be memory reuse, Kernel:"
<< current_kernel_->scope_full_name();
continue;
}
}
total_kernel_output_size = AlignCommunicationMemorySize(total_kernel_output_size);
// add left align border for the first output and right align border for the last output to alloc align border memory
size_t output_index = 0;
for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
size_t index = GetTensorIndex(tensor_idx);
auto tensor_desc = tensor_ptr_list_[index];
MS_EXCEPTION_IF_NULL(tensor_desc);
if (output_index == 0 || output_index == output_num - 1) {
tensor_desc->size_ += kDefaultMemAlignSize;
}
output_index++;
}
auto reusable_membuf_map = GetReusableMembufMap(total_kernel_output_size);
if (!reusable_membuf_map.empty()) {
auto membuf_index = reusable_membuf_map.begin()->second;
output_index = 0;
for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
size_t index = GetTensorIndex(tensor_idx);
auto tensor_desc = tensor_ptr_list_[index];
MS_EXCEPTION_IF_NULL(tensor_desc);
ReuseExistMembuf(tensor_desc.get(), membuf_index + output_index, kDynamicMem);
// skip skip left align border for communication op's first output to used
if (output_index == 0) {
tensor_desc->offset_ += kDefaultMemAlignSize;
}
output_index++;
}
} else {
// no membuf can reuse, add new membuf after the membuf_ptr_list
output_index = 0;
for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
size_t index = GetTensorIndex(tensor_idx);
auto tensor_desc = tensor_ptr_list_[index];
MS_EXCEPTION_IF_NULL(tensor_desc);
AddNewMembufPtr(tensor_desc.get(), kDynamicMem);
// skip align size offset for first output to used
if (output_index == 0) {
tensor_desc->offset_ += kDefaultMemAlignSize;
}
output_index++;
#ifdef MEM_REUSE_DEBUG
MemReuseChecker::GetInstance().IsAddNewMembuf_ = true;
#endif
}
}
}
void BestFitMemReuse::AssignNodeOutputOffset() {
if (current_kernel_->type_ == COMMUNICATION_NODE) {
AssignCommunicationNodeOutputOffset();
} else {
AssignCommonNodeOutputOffset();
}
}
@ -319,11 +411,17 @@ void BestFitMemReuse::ReleaseMembuf(size_t tensor_index, int flag) {
}
}
size_t BestFitMemReuse::AlignMemorySize(size_t size) const {
size_t BestFitMemReuse::AlignCommonMemorySize(size_t size) const {
// memory size 512 align
return (size + kDefaultMemAlignSize + kAttAlignSize) / kDefaultMemAlignSize * kDefaultMemAlignSize;
}
size_t BestFitMemReuse::AlignCommunicationMemorySize(size_t size) const {
// memory size 512 align and add communication memory: left align border memory - data - right align border memory
return kDefaultMemAlignSize + (size + kDefaultMemAlignSize - 1) / kDefaultMemAlignSize * kDefaultMemAlignSize +
kDefaultMemAlignSize;
}
size_t BestFitMemReuse::GetAllocatedSize() {
size_t AllocatedSize = kTotalSize;
if (membuf_ptr_list_.empty()) {
@ -412,6 +510,9 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) {
++op_num;
#endif
}
MS_LOG(INFO) << "Special Tensor total size: RefOutput: " << total_refoutput_size
<< " CommReuse: " << total_comm_reuse_size << " CommOutputReuse: " << total_comm_output_reuse_size
<< " CommNotReuse: " << total_comm_not_reuse_size;
#ifdef MEM_REUSE_DEBUG
MemReuseChecker::GetInstance().ExportMembufInfoIR();
MemReuseChecker::GetInstance().ExportAddNewMmebufIR();

View File

@ -74,6 +74,14 @@ class BestFitMemReuse {
* Assign output tensor memory offset of current kernel
*/
void AssignNodeOutputOffset();
/**
* Assign output tensor memory offset of common kernel
*/
void AssignCommonNodeOutputOffset();
/**
* Assign output tensor memory offset of communication kernel
*/
void AssignCommunicationNodeOutputOffset();
/**
* Update input tensor's status of current kernel, and the status of membuf used by current kernel
*/
@ -110,8 +118,10 @@ class BestFitMemReuse {
void AddNewMembufPtr(KernelRefCount *tensor_desc, int flag);
// Merge unused membuf
void ReleaseMembuf(size_t tensor_index, int flag);
// Memory address alignment 512
size_t AlignMemorySize(size_t size) const;
// Memory address alignment for common memory
size_t AlignCommonMemorySize(size_t size) const;
// Memory address alignment for communication used memory
size_t AlignCommunicationMemorySize(size_t size) const;
int GetRealIndex(size_t index, int flag = kDynamicMem) const;
size_t GetTensorIndex(int index) const;
size_t GetWorkspaceIndex(int index) const;
@ -153,6 +163,10 @@ class BestFitMemReuse {
// kernel_front_map_, key: the kernel_def, value: kernels before this kernel_def
std::map<KernelDefPtr, std::set<KernelDefPtr>> kernel_front_map_;
std::vector<std::vector<uint32_t>> stream_groups_;
size_t total_refoutput_size{0};
size_t total_comm_reuse_size{0};
size_t total_comm_output_reuse_size{0};
size_t total_comm_not_reuse_size{0};
};
} // namespace memreuse
} // namespace mindspore

View File

@ -170,12 +170,14 @@ void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_li
ofs << "all_tensor_refs:\n";
ofs << "index:"
<< "\tsize:"
<< "\trefcount:\n";
<< "\trefcount:"
<< "\ttype:\n";
for (auto &ref : total_refs_list) {
ofs << "%" << ref->index_ << "T"
<< "\t"
<< "#" << ref->size_ << "S"
<< "\t" << ref->ref_count_ << "C"
<< "\t" << ref->type_ << "t"
<< "\n";
}
ofs << "kernel_def exc_order:\n";
@ -241,7 +243,7 @@ bool MemReuseChecker::CheckGraphOutputAssigned(const session::KernelGraph *graph
void MemReuseChecker::ExportMemOpIr(const KernelDef *def, std::ofstream &ofs, int def_idx) {
auto scope_name = def->scope_full_name();
std::string split_name = GetSplitName(scope_name);
ofs << "$" << def_idx << "\t" << split_name << "\t";
ofs << "$" << def_idx << "\t" << split_name << "\t" << static_cast<int>(def->type_) << "\t";
ofs << "inputs[";
for (auto &in : def->inputs_) {
for (auto &in_ref : in.second) {

View File

@ -95,6 +95,12 @@ uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_me
} else {
align_size = GetCommonAlignSize(size);
}
auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
MS_LOG(INFO) << "Malloc Memory: Static, total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
<< "] memory pool[" << device_mem_pool_offset << "])"
<< " malloc [" << align_size << "]";
if (communication_mem) {
// create protect area [kMemAlignSize -- data -- kMemAlignSize]
uint8_t *alloc_address = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
@ -111,12 +117,17 @@ uint8_t *AscendMemoryManager::MallocDynamicMem(size_t size, bool communication_m
} else {
align_size = GetCommonAlignSize(size);
}
auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
MS_LOG(INFO) << "Malloc Memory: Dynamic, total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
<< "] memory pool[" << device_mem_pool_offset << "])"
<< " malloc [" << align_size << "]";
if (dynamic_mem_offset_ < align_size) {
MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
<< "]) malloc [" << align_size << "] failed!";
}
auto new_offset = dynamic_mem_offset_ - align_size;
auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
if (new_offset <= device_mem_pool_offset) {
MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
<< "] memory pool[" << device_mem_pool_offset << "])"

View File

@ -399,7 +399,7 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
}
void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) {
AssignCommunicationNodeInputMem(node);
AssignCommunicationNodeInputMem(flag, node);
AssignCommunicationNodeOutputMem(flag, node);
}
@ -429,6 +429,11 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
total_size += mem_size;
align_size_list.emplace_back(mem_size);
}
if (flag == kReuseDynamicMem) {
// reuse communication op's all outputs' memory
flag = kReuseDynamicCommMem;
}
uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size);
for (size_t j = 0; j < align_size_list.size(); ++j) {
std::string output_format = AnfAlgo::GetOutputFormat(node, j);
@ -457,7 +462,7 @@ DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node,
return address;
}
void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
void KernelRuntime::AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &node) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
MS_EXCEPTION_IF_NULL(node);
@ -478,7 +483,7 @@ void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
total_size += mem_size;
addr_size.emplace_back(address.get(), mem_size);
}
uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, kDynamicMem, total_size);
uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size);
for (const auto &iter : addr_size) {
MS_EXCEPTION_IF_NULL(iter.first);
iter.first->set_ptr(input_ptr);

View File

@ -88,7 +88,7 @@ class KernelRuntime {
void UpdateRefNodeOutputMem(const session::KernelGraph *graph);
void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
void AssignCommunicationNodeInputMem(const AnfNodePtr &node);
void AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &node);
void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node);
#ifdef ENABLE_DUMP_E2E
bool SetDumpConf();

View File

@ -57,6 +57,9 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, in
}
if (flag == kStaticMem) {
ptr = MallocStaticMem(size, communication_mem);
} else if (flag == kReuseDynamicCommMem) {
MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_);
ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
} else {
ptr = MallocDynamicMem(size, communication_mem);
}

View File

@ -25,6 +25,7 @@ namespace device {
const int kStaticMem = 0;
const int kDynamicMem = 1;
const int kReuseDynamicMem = 2;
const int kReuseDynamicCommMem = 3;
const int kGetAllOuts = -1;
const uint64_t kMemAlignSize = 512;
using MemReuseUtilPtr = mindspore::memreuse::MemReuseUtilPtr;

View File

@ -146,7 +146,7 @@ TEST_F(TestMemReuseAllocator, mem_reuse_allocator_split_membuf) {
TEST_F(TestMemReuseAllocator, mem_reuse_allocator_align) {
auto best_fit_mem_reuse = std::make_shared<BestFitMemReuse>();
auto size = best_fit_mem_reuse->AlignMemorySize(510);
auto size = best_fit_mem_reuse->AlignCommonMemorySize(510);
ASSERT_EQ(size, 1024);
}
} // namespace memreuse

View File

@ -225,7 +225,6 @@ TEST_F(TestMemReuseWithPy, KernelRef) {
ASSERT_EQ(kernel_ref_count_ptr->size_, 512);
KernelDefPtr kernel_def_ptr = std::make_shared<KernelDef>();
ASSERT_NE(kernel_def_ptr, nullptr);
ASSERT_EQ(kernel_def_ptr->dirty, false);
MembufPtr membuf_ptr = std::make_shared<Membuf>();
ASSERT_NE(membuf_ptr, nullptr);
}