forked from mindspore-Ecosystem/mindspore
!1271 refactor of memreuse allocator to adapt the control stream
Merge pull request !1271 from yangjie159/refactor_memreuse_allocator
This commit is contained in:
commit
5e2f440eed
|
@ -47,7 +47,7 @@ std::vector<int> KernelDef::GetOutputRefIndexs() const {
|
|||
return output_ref_indexs;
|
||||
}
|
||||
|
||||
std::vector<int> KernelDef::GetWkRefIndexs() const {
|
||||
std::vector<int> KernelDef::GetWorkspaceRefIndexs() const {
|
||||
std::vector<int> wk_ref_indexs;
|
||||
if (wk_space_.empty()) {
|
||||
return wk_ref_indexs;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
|
||||
namespace mindspore {
|
||||
namespace memreuse {
|
||||
|
@ -73,13 +74,15 @@ class KernelDef {
|
|||
KernelRefCountPtrList output_refs() const { return output_refs_; }
|
||||
std::vector<int> GetInputRefIndexs() const;
|
||||
std::vector<int> GetOutputRefIndexs() const;
|
||||
std::vector<int> GetWkRefIndexs() const;
|
||||
std::vector<int> GetWorkspaceRefIndexs() const;
|
||||
void set_stream_id(uint32_t stream_id) { stream_id_ = stream_id; }
|
||||
uint32_t stream_id() const { return stream_id_; }
|
||||
void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; }
|
||||
std::string kernel_name() const { return kernel_name_; }
|
||||
void set_scope_full_name(const std::string &scop_name) { scop_full_name_ = scop_name; }
|
||||
std::string scope_full_name() const { return scop_full_name_; }
|
||||
void InsertInputKernel(const std::shared_ptr<KernelDef> &input_kernel) { input_kernels_.insert(input_kernel); }
|
||||
const std::set<std::shared_ptr<KernelDef>> &input_kernels() { return input_kernels_; }
|
||||
|
||||
private:
|
||||
std::string scop_full_name_;
|
||||
|
@ -87,6 +90,7 @@ class KernelDef {
|
|||
uint32_t stream_id_{0};
|
||||
KernelRefCountPtrList input_refs_;
|
||||
KernelRefCountPtrList output_refs_;
|
||||
std::set<std::shared_ptr<KernelDef>> input_kernels_;
|
||||
};
|
||||
using KernelDefPtr = std::shared_ptr<KernelDef>;
|
||||
} // namespace memreuse
|
||||
|
|
|
@ -245,6 +245,34 @@ void MemReuseUtil::SetKernelDefMap() {
|
|||
kernel_def_ptr->set_input_refs(kernel_def_ptr->inputs_[key]);
|
||||
kernel_def_ptr->set_output_refs(kernel_def_ptr->outputs_[key]);
|
||||
kernel_def_ptr_list_.push_back(kernel_def_ptr);
|
||||
kernel_map_[key] = kernel_def_ptr;
|
||||
}
|
||||
SetKernelDefInputs();
|
||||
}
|
||||
|
||||
void MemReuseUtil::SetKernelDefInputs() {
|
||||
for (const auto &kernel : graph_->execution_order()) {
|
||||
auto key = kernel.get();
|
||||
// find kernel_def according to cnode addr
|
||||
auto iter = kernel_map_.find(key);
|
||||
if (iter == kernel_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "kernel [" << kernel->fullname_with_scope() << "] is not init.";
|
||||
}
|
||||
auto kernel_def = iter->second;
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
|
||||
auto ref_ptr = GetKernelInputRef(kernel, i);
|
||||
if (ref_ptr != nullptr) {
|
||||
// set the inputs of this kernel_def
|
||||
auto input_node = AnfAlgo::GetInputNode(kernel, i);
|
||||
auto input = AnfAlgo::VisitKernel(input_node, 0);
|
||||
auto input_key = (input.first).get();
|
||||
auto input_iter = kernel_map_.find(input_key);
|
||||
if (input_iter == kernel_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "kernel [" << (input.first)->fullname_with_scope() << "] is not init.";
|
||||
}
|
||||
kernel_def->InsertInputKernel(input_iter->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -61,6 +61,7 @@ class MemReuseUtil {
|
|||
void SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr);
|
||||
void SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr);
|
||||
void SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr);
|
||||
void SetKernelDefInputs();
|
||||
void SetReuseRefCount();
|
||||
// Set the reference count of graph output specially.
|
||||
void SetGraphOutputRefCount();
|
||||
|
@ -94,6 +95,8 @@ class MemReuseUtil {
|
|||
size_t total_workspace_size_ = 0;
|
||||
size_t total_reuseworkspace_size_ = 0;
|
||||
uint8_t *mem_base_{nullptr};
|
||||
// kernel_map_: key is the AnfNodePtr addr, value is the KernelDef
|
||||
std::map<KernelKey, KernelDefPtr> kernel_map_;
|
||||
};
|
||||
using MemReuseUtilPtr = std::shared_ptr<MemReuseUtil>;
|
||||
} // namespace memreuse
|
||||
|
|
|
@ -15,9 +15,6 @@
|
|||
*/
|
||||
|
||||
#include "pre_activate/mem_reuse/mem_reuse_allocator.h"
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include "pre_activate/mem_reuse/mem_reuse.h"
|
||||
#include "pre_activate/mem_reuse/mem_reuse_checker.h"
|
||||
|
||||
|
@ -25,9 +22,9 @@ namespace mindspore {
|
|||
namespace memreuse {
|
||||
void BestFitMemReuse::InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
|
||||
tensor_ptr_list_ = mem_reuse_util_ptr->total_refs_list();
|
||||
wk_tensor_list_ = mem_reuse_util_ptr->total_wk_ref_list();
|
||||
op_ptr_list_ = mem_reuse_util_ptr->kernel_def_ptr_list();
|
||||
set_tensor_ptr_list(mem_reuse_util_ptr->total_refs_list());
|
||||
set_workspace_ptr_list(mem_reuse_util_ptr->total_wk_ref_list());
|
||||
set_op_ptr_list(mem_reuse_util_ptr->kernel_def_ptr_list());
|
||||
// check info Correctness
|
||||
for (auto &tensor : tensor_ptr_list_) {
|
||||
tensor->size_ = AlignMemorySize(tensor->size_);
|
||||
|
@ -37,63 +34,65 @@ void BestFitMemReuse::InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr) {
|
|||
wk->size_ = AlignMemorySize(wk->size_);
|
||||
wk->ref_count_ = 1;
|
||||
}
|
||||
auto stream_reuse = std::make_shared<StreamReuse>();
|
||||
stream_reuse->SetStreamReuseResource();
|
||||
parallel_streams_map_ = stream_reuse->parallel_streams_map();
|
||||
}
|
||||
|
||||
bool BestFitMemReuse::CheckMembufIndx(const std::vector<MembufPtr> &membuf_ptr_list, size_t check_idx) const {
|
||||
return check_idx < membuf_ptr_list.size();
|
||||
}
|
||||
|
||||
bool BestFitMemReuse::IsMembufListEmpty(const std::vector<MembufPtr> &membuf_ptr_list) const {
|
||||
return membuf_ptr_list.empty();
|
||||
}
|
||||
|
||||
int BestFitMemReuse::GetFacIdx(size_t real_idx, int flag) const {
|
||||
if (flag == kDyFac) {
|
||||
return SizeToInt(real_idx);
|
||||
} else if (flag == kWkFac) {
|
||||
auto wk_fac_idx = kWkIndexFactor * SizeToInt(real_idx + 1);
|
||||
return wk_fac_idx;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "flag " << flag << " is invalid";
|
||||
void BestFitMemReuse::InitKernelDependence() {
|
||||
for (const auto &kernel : op_ptr_list_) {
|
||||
std::set<KernelDefPtr> front;
|
||||
std::queue<KernelDefPtr> to_visit;
|
||||
to_visit.push(kernel);
|
||||
// find all kernels before current kernel
|
||||
while (!to_visit.empty()) {
|
||||
auto curr = to_visit.front();
|
||||
to_visit.pop();
|
||||
if (front.count(curr)) {
|
||||
continue;
|
||||
}
|
||||
front.insert(curr);
|
||||
auto iter = kernel_front_map_.find(curr);
|
||||
if (iter != kernel_front_map_.end()) {
|
||||
auto visited_front = iter->second;
|
||||
front.insert(visited_front.begin(), visited_front.end());
|
||||
continue;
|
||||
}
|
||||
for (const auto &input : curr->input_kernels()) {
|
||||
to_visit.push(input);
|
||||
}
|
||||
}
|
||||
kernel_front_map_[kernel] = front;
|
||||
}
|
||||
}
|
||||
|
||||
int BestFitMemReuse::GetRealIdx(int fac_idx, int flag) const {
|
||||
// membuf index maybe invalid_index
|
||||
if (fac_idx == kInvalidIndex) {
|
||||
MS_LOG(EXCEPTION) << "this membuf index is invalid";
|
||||
bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const KernelDefPtr &kernel_prev) {
|
||||
// determine whether the kernel_curr can reuse kernel_prev's output tensor membuf
|
||||
MS_EXCEPTION_IF_NULL(kernel_curr);
|
||||
MS_EXCEPTION_IF_NULL(kernel_prev);
|
||||
auto curr_stream_id = kernel_curr->stream_id();
|
||||
auto prev_stream_id = kernel_prev->stream_id();
|
||||
if (curr_stream_id == prev_stream_id) {
|
||||
return true;
|
||||
}
|
||||
if (flag == kDyFac) {
|
||||
return fac_idx;
|
||||
} else if (flag == kWkFac) {
|
||||
if (fac_idx % 10 == 0) {
|
||||
auto wk_fac_idx = fac_idx / kWkIndexFactor + 1;
|
||||
return wk_fac_idx;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "fac_idx: " << fac_idx << "is invalid";
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "flag: " << flag << " is invalid";
|
||||
auto iter = kernel_front_map_.find(kernel_curr);
|
||||
if (iter == kernel_front_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << kernel_curr->scope_full_name() << " is not init.";
|
||||
}
|
||||
auto kernel_curr_front = iter->second;
|
||||
return kernel_curr_front.count(kernel_prev);
|
||||
}
|
||||
|
||||
void BestFitMemReuse::AssignNodeOutputOffset(const KernelDef *kernel_def_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_def_ptr);
|
||||
for (auto &tensor_idx : kernel_def_ptr->GetOutputRefIndexs()) {
|
||||
CheckTensorIndex(tensor_idx);
|
||||
auto tensor_desc = tensor_ptr_list_[IntToSize(tensor_idx)];
|
||||
void BestFitMemReuse::AssignNodeOutputOffset() {
|
||||
for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
|
||||
size_t index = GetTensorIndex(tensor_idx);
|
||||
auto tensor_desc = tensor_ptr_list_[index];
|
||||
MS_EXCEPTION_IF_NULL(tensor_desc);
|
||||
auto reusable_membuf_map = GetReusableMembufMap(tensor_desc->size_);
|
||||
if (!reusable_membuf_map.empty()) {
|
||||
auto membuf_index = reusable_membuf_map.begin()->second;
|
||||
// find the best suitable membuf in membuf list, and reuse it
|
||||
ReuseExistMembuf(tensor_desc.get(), membuf_index, kDyFac);
|
||||
ReuseExistMembuf(tensor_desc.get(), membuf_index, kDynamicMem);
|
||||
} else {
|
||||
// no membuf can reuse, add new membuf after the membuf_ptr_list
|
||||
AddNewMembufPtr(tensor_desc.get(), kDyFac);
|
||||
AddNewMembufPtr(tensor_desc.get(), kDynamicMem);
|
||||
#ifdef MEM_REUSE_DEBUG
|
||||
MemReuseChecker::GetInstance().IsAddNewMembuf_ = true;
|
||||
#endif
|
||||
|
@ -101,43 +100,24 @@ void BestFitMemReuse::AssignNodeOutputOffset(const KernelDef *kernel_def_ptr) {
|
|||
}
|
||||
}
|
||||
|
||||
void BestFitMemReuse::AssignNodeWkOffset(const KernelDef *kernel_def_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_def_ptr);
|
||||
for (auto &wk_idx : kernel_def_ptr->GetWkRefIndexs()) {
|
||||
if (IntToSize(wk_idx) >= wk_tensor_list_.size()) {
|
||||
MS_LOG(EXCEPTION) << "wk_idx: " << wk_idx << " is invalid";
|
||||
}
|
||||
auto wk_ref = wk_tensor_list_[IntToSize(wk_idx)];
|
||||
void BestFitMemReuse::AssignNodeWorkspaceOffset() {
|
||||
for (auto &wk_idx : current_kernel_->GetWorkspaceRefIndexs()) {
|
||||
size_t index = GetWorkspaceIndex(wk_idx);
|
||||
auto wk_ref = wk_tensor_list_[index];
|
||||
MS_EXCEPTION_IF_NULL(wk_ref);
|
||||
auto re_wk_membuf_map = GetReusableMembufMap(wk_ref->size_);
|
||||
if (!re_wk_membuf_map.empty()) {
|
||||
auto membuf_index = re_wk_membuf_map.begin()->second;
|
||||
ReuseExistMembuf(wk_ref.get(), membuf_index, kWkFac);
|
||||
ReuseExistMembuf(wk_ref.get(), membuf_index, kWorkspaceMem);
|
||||
} else {
|
||||
AddNewMembufPtr(wk_ref.get(), kWkFac);
|
||||
}
|
||||
}
|
||||
}
|
||||
// releas pre node wk
|
||||
void BestFitMemReuse::ReleasePreNodeWkSpace(const KernelDef *kernel_def_ptr) {
|
||||
for (auto &wk_idx : kernel_def_ptr->GetWkRefIndexs()) {
|
||||
auto wk_index = IntToSize(wk_idx);
|
||||
if (wk_index >= wk_tensor_list_.size()) {
|
||||
MS_LOG(EXCEPTION) << "wk_index: " << wk_index << " is larger than wk_tensor_list size" << wk_tensor_list_.size();
|
||||
}
|
||||
auto wk_tensor = wk_tensor_list_[wk_index];
|
||||
wk_tensor->ref_count_--;
|
||||
if (wk_tensor->ref_count_ == 0) {
|
||||
ReleaseMembuf(wk_index, kWkFac);
|
||||
AddNewMembufPtr(wk_ref.get(), kWorkspaceMem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BestFitMemReuse::ReuseExistMembuf(KernelRefCount *tensor_desc, size_t membuf_index, int flag) {
|
||||
MS_EXCEPTION_IF_NULL(tensor_desc);
|
||||
if (!CheckMembufIndx(membuf_ptr_list_, membuf_index)) {
|
||||
return;
|
||||
}
|
||||
CheckMembufIndx(membuf_index);
|
||||
auto membuf = membuf_ptr_list_[membuf_index];
|
||||
MS_EXCEPTION_IF_NULL(membuf);
|
||||
// first to split && then update membuf_info
|
||||
|
@ -153,11 +133,9 @@ std::map<size_t, size_t> BestFitMemReuse::GetReusableMembufMap(size_t tensor_siz
|
|||
std::map<size_t, size_t> size_map;
|
||||
for (size_t i = 0; i < membuf_ptr_list_.size(); ++i) {
|
||||
auto membuf = membuf_ptr_list_[i];
|
||||
auto called_ids = membuf->called_stream_ids_;
|
||||
auto index = i;
|
||||
bool IsMembufOk = membuf->status_ == kUnused && membuf->size_ >= tensor_size;
|
||||
bool has_parallel_id = HasParallelId(called_ids, current_stream_id_);
|
||||
if (IsMembufOk && !has_parallel_id) {
|
||||
bool is_membuf_ok = membuf->status_ == kUnused && membuf->size_ >= tensor_size;
|
||||
if (is_membuf_ok && IsUsable(current_kernel_, membuf->used_kernel_)) {
|
||||
(void)size_map.insert(std::make_pair(membuf->size_, index));
|
||||
break;
|
||||
}
|
||||
|
@ -168,13 +146,10 @@ std::map<size_t, size_t> BestFitMemReuse::GetReusableMembufMap(size_t tensor_siz
|
|||
void BestFitMemReuse::UpdateMembufInfo(KernelRefCount *tensor_desc, Membuf *membuf, int flag) {
|
||||
MS_EXCEPTION_IF_NULL(tensor_desc);
|
||||
MS_EXCEPTION_IF_NULL(membuf);
|
||||
auto fac_idx = GetFacIdx(IntToSize(tensor_desc->index_), flag);
|
||||
auto real_index = GetRealIndex(IntToSize(tensor_desc->index_), flag);
|
||||
membuf->status_ = kReused;
|
||||
membuf->stream_id_ = current_stream_id_;
|
||||
// clear before called_ids
|
||||
membuf->called_stream_ids_.clear();
|
||||
(void)membuf->called_stream_ids_.insert(current_stream_id_);
|
||||
membuf->index_ = fac_idx;
|
||||
membuf->index_ = real_index;
|
||||
membuf->used_kernel_ = current_kernel_;
|
||||
tensor_desc->offset_ = membuf->offset_;
|
||||
}
|
||||
|
||||
|
@ -182,52 +157,39 @@ bool BestFitMemReuse::IsSplit(size_t tensor_size, size_t membuf_size) const { re
|
|||
|
||||
void BestFitMemReuse::SplitMembuf(const KernelRefCount *tensor_desc, size_t membuf_index) {
|
||||
MS_EXCEPTION_IF_NULL(tensor_desc);
|
||||
if (!CheckMembufIndx(membuf_ptr_list_, membuf_index)) {
|
||||
return;
|
||||
}
|
||||
CheckMembufIndx(membuf_index);
|
||||
auto membuf = membuf_ptr_list_[membuf_index];
|
||||
MS_EXCEPTION_IF_NULL(membuf);
|
||||
auto bias = membuf->size_ - tensor_desc->size_;
|
||||
membuf->size_ = tensor_desc->size_;
|
||||
// to check if spilt membuf can be merge
|
||||
auto new_membuf =
|
||||
std::make_shared<Membuf>(current_stream_id_, kUnused, bias, membuf->offset_ + membuf->size_, kInvalidIndex);
|
||||
std::make_shared<Membuf>(kUnused, bias, membuf->offset_ + membuf->size_, kInvalidIndex, current_kernel_);
|
||||
(void)membuf_ptr_list_.insert(membuf_ptr_list_.begin() + SizeToInt(membuf_index + 1), new_membuf);
|
||||
MergeCalledIds(membuf.get(), new_membuf.get());
|
||||
}
|
||||
|
||||
void BestFitMemReuse::AddNewMembufPtr(KernelRefCount *tensor_desc, int flag) {
|
||||
MS_EXCEPTION_IF_NULL(tensor_desc);
|
||||
size_t membuf_offset = std::accumulate(membuf_ptr_list_.begin(), membuf_ptr_list_.end(), IntToSize(0),
|
||||
[](size_t sum, MembufPtr &membuf) { return sum + membuf->size_; });
|
||||
size_t membuf_size = tensor_desc->size_;
|
||||
auto fac_idx = GetFacIdx(IntToSize(tensor_desc->index_), flag);
|
||||
auto membuf = std::make_shared<Membuf>(current_stream_id_, kReused, membuf_size, membuf_offset, fac_idx);
|
||||
size_t membuf_offset = 0;
|
||||
if (!membuf_ptr_list_.empty()) {
|
||||
membuf_offset = membuf_ptr_list_.back()->offset_ + membuf_ptr_list_.back()->size_;
|
||||
}
|
||||
auto membuf_size = tensor_desc->size_;
|
||||
auto real_index = GetRealIndex(IntToSize(tensor_desc->index_), flag);
|
||||
auto membuf = std::make_shared<Membuf>(kReused, membuf_size, membuf_offset, real_index, current_kernel_);
|
||||
membuf_ptr_list_.push_back(membuf);
|
||||
tensor_desc->offset_ = membuf_offset;
|
||||
(void)membuf->called_stream_ids_.insert(current_stream_id_);
|
||||
}
|
||||
|
||||
void BestFitMemReuse::UpdateNodeInputAndMembuf(const KernelDef *kernel_def_ptr) {
|
||||
void BestFitMemReuse::UpdateNodeInputAndMembuf() {
|
||||
// process node input tensor
|
||||
for (const auto &tensor_idx : kernel_def_ptr->GetInputRefIndexs()) {
|
||||
auto tensor_index = IntToSize(tensor_idx);
|
||||
CheckTensorIndex(tensor_idx);
|
||||
for (const auto &tensor_idx : current_kernel_->GetInputRefIndexs()) {
|
||||
size_t tensor_index = GetTensorIndex(tensor_idx);
|
||||
auto tensor_desc = tensor_ptr_list_[tensor_index];
|
||||
auto fac_idx = GetFacIdx(tensor_index, kDyFac);
|
||||
MS_EXCEPTION_IF_NULL(tensor_desc);
|
||||
tensor_desc->ref_count_--;
|
||||
// find tensor_index -> membuf update it's called_ids
|
||||
for (size_t i = 0; i < membuf_ptr_list_.size(); ++i) {
|
||||
auto membuf = membuf_ptr_list_[i];
|
||||
// find it
|
||||
if (membuf->index_ == fac_idx) {
|
||||
(void)membuf->called_stream_ids_.insert(current_stream_id_);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (tensor_desc->ref_count_ == 0) {
|
||||
ReleaseMembuf(tensor_index, kDyFac);
|
||||
ReleaseMembuf(tensor_index, kDynamicMem);
|
||||
} else if (tensor_desc->ref_count_ < 0) {
|
||||
MS_LOG(EXCEPTION) << "tensor: " << tensor_desc->index_ << " refcount: " << tensor_desc->ref_count_
|
||||
<< " check error";
|
||||
|
@ -235,14 +197,13 @@ void BestFitMemReuse::UpdateNodeInputAndMembuf(const KernelDef *kernel_def_ptr)
|
|||
}
|
||||
}
|
||||
|
||||
void BestFitMemReuse::ReleaseNodeUnusedOutput(const KernelDef *kernel_def_ptr) {
|
||||
for (auto &tensor_idx : kernel_def_ptr->GetOutputRefIndexs()) {
|
||||
auto tensor_index = IntToSize(tensor_idx);
|
||||
CheckTensorIndex(tensor_idx);
|
||||
void BestFitMemReuse::ReleaseNodeUnusedOutput() {
|
||||
for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
|
||||
size_t tensor_index = GetTensorIndex(tensor_idx);
|
||||
auto tensor_desc = tensor_ptr_list_[tensor_index];
|
||||
MS_EXCEPTION_IF_NULL(tensor_desc);
|
||||
if (tensor_desc->ref_count_ == 0) {
|
||||
ReleaseMembuf(tensor_index, kDyFac);
|
||||
ReleaseMembuf(tensor_index, kDynamicMem);
|
||||
} else if (tensor_desc->ref_count_ < 0) {
|
||||
MS_LOG(EXCEPTION) << "tensor: " << tensor_desc->index_ << " refcount: " << tensor_desc->ref_count_
|
||||
<< " check error";
|
||||
|
@ -250,124 +211,57 @@ void BestFitMemReuse::ReleaseNodeUnusedOutput(const KernelDef *kernel_def_ptr) {
|
|||
}
|
||||
}
|
||||
|
||||
size_t BestFitMemReuse::FindIndx(const std::vector<MembufPtr> &membuf_ptr_list, int fac_idx) const {
|
||||
size_t membuf_index = membuf_ptr_list.size();
|
||||
for (size_t n = 0; n < membuf_ptr_list.size(); ++n) {
|
||||
auto membuf = membuf_ptr_list[n];
|
||||
MS_EXCEPTION_IF_NULL(membuf);
|
||||
if (membuf->index_ == fac_idx) {
|
||||
membuf_index = n;
|
||||
break;
|
||||
void BestFitMemReuse::ReleasePreNodeWorkspace(const KernelDef *kernel_def_ptr) {
|
||||
for (auto &workspace_index : kernel_def_ptr->GetWorkspaceRefIndexs()) {
|
||||
size_t index = GetWorkspaceIndex(workspace_index);
|
||||
auto wk_tensor = wk_tensor_list_[index];
|
||||
wk_tensor->ref_count_--;
|
||||
if (wk_tensor->ref_count_ == 0) {
|
||||
ReleaseMembuf(index, kWorkspaceMem);
|
||||
} else if (wk_tensor->ref_count_ < 0) {
|
||||
MS_LOG(EXCEPTION) << "tensor: " << wk_tensor->index_ << " refcount: " << wk_tensor->ref_count_ << " check error";
|
||||
}
|
||||
}
|
||||
return membuf_index;
|
||||
}
|
||||
|
||||
void BestFitMemReuse::ReleaseMembuf(size_t tensor_index, int flag) {
|
||||
auto fac_idex = GetFacIdx(tensor_index, flag);
|
||||
auto membuf_index = FindIndx(membuf_ptr_list_, fac_idex);
|
||||
if (!CheckMembufIndx(membuf_ptr_list_, membuf_index)) {
|
||||
if (membuf_ptr_list_.empty()) {
|
||||
return;
|
||||
}
|
||||
auto membuf = membuf_ptr_list_[membuf_index];
|
||||
auto real_index = GetRealIndex(tensor_index, flag);
|
||||
auto membuf_iter = std::find_if(membuf_ptr_list_.begin(), membuf_ptr_list_.end(),
|
||||
[real_index](const MembufPtr &membuf) { return membuf->index_ == real_index; });
|
||||
if (membuf_iter == membuf_ptr_list_.end()) {
|
||||
return;
|
||||
}
|
||||
auto membuf = (*membuf_iter);
|
||||
MS_EXCEPTION_IF_NULL(membuf);
|
||||
membuf->status_ = kUnused;
|
||||
if (membuf_index != (membuf_ptr_list_.size() - 1)) {
|
||||
auto membuf_next = membuf_ptr_list_[membuf_index + 1];
|
||||
if (membuf_iter != membuf_ptr_list_.end() - 1) {
|
||||
auto next_iter = membuf_iter + 1;
|
||||
auto membuf_next = (*next_iter);
|
||||
MS_EXCEPTION_IF_NULL(membuf_next);
|
||||
bool has_parallel_id = false;
|
||||
for (auto &cal_id : membuf->called_stream_ids_) {
|
||||
has_parallel_id = HasParallelId(membuf_next->called_stream_ids_, cal_id);
|
||||
if (has_parallel_id) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (membuf_next->status_ == kUnused && !has_parallel_id) {
|
||||
if (membuf_next->status_ == kUnused) {
|
||||
bool is_merge = IsUsable(current_kernel_, membuf_next->used_kernel_);
|
||||
if (is_merge) {
|
||||
membuf->size_ += membuf_next->size_;
|
||||
MergeCalledIds(membuf_next.get(), membuf.get());
|
||||
auto it = membuf_ptr_list_.begin() + SizeToInt(membuf_index + 1);
|
||||
(void)membuf_ptr_list_.erase(it);
|
||||
(void)membuf_ptr_list_.erase(next_iter);
|
||||
}
|
||||
}
|
||||
if (membuf_index != 0) {
|
||||
if (!CheckMembufIndx(membuf_ptr_list_, membuf_index - 1)) {
|
||||
return;
|
||||
}
|
||||
auto membuf_prev = membuf_ptr_list_[membuf_index - 1];
|
||||
if (membuf_iter != membuf_ptr_list_.begin()) {
|
||||
auto prev_iter = membuf_iter - 1;
|
||||
auto membuf_prev = (*prev_iter);
|
||||
MS_EXCEPTION_IF_NULL(membuf_prev);
|
||||
bool has_parallel_id = false;
|
||||
for (auto &cal_id : membuf->called_stream_ids_) {
|
||||
has_parallel_id = HasParallelId(membuf_prev->called_stream_ids_, cal_id);
|
||||
if (has_parallel_id) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (membuf_prev->status_ == kUnused && !has_parallel_id) {
|
||||
if (membuf_prev->status_ == kUnused) {
|
||||
bool is_merge = IsUsable(current_kernel_, membuf_prev->used_kernel_);
|
||||
if (is_merge) {
|
||||
membuf->size_ += membuf_prev->size_;
|
||||
membuf->offset_ = membuf_prev->offset_;
|
||||
MergeCalledIds(membuf_prev.get(), membuf.get());
|
||||
auto it = membuf_ptr_list_.begin() + SizeToInt(membuf_index - 1);
|
||||
(void)membuf_ptr_list_.erase(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool BestFitMemReuse::HasParallelId(const std::set<uint32_t> &called_ids, uint32_t curr_id) {
|
||||
if (called_ids.empty()) {
|
||||
MS_LOG(EXCEPTION) << "There is a invalid WkMembuf,called_ids is empty";
|
||||
}
|
||||
for (auto item : called_ids) {
|
||||
if (!IsReusableStream(curr_id, item)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void BestFitMemReuse::MergeCalledIds(const Membuf *membuf_target, Membuf *membuf) {
|
||||
MS_EXCEPTION_IF_NULL(membuf_target);
|
||||
MS_EXCEPTION_IF_NULL(membuf);
|
||||
for (auto target : membuf_target->called_stream_ids_) {
|
||||
(void)membuf->called_stream_ids_.insert(target);
|
||||
}
|
||||
}
|
||||
|
||||
void BestFitMemReuse::ReleaseParallStream() {
|
||||
std::vector<size_t> target_relea_idxs;
|
||||
for (size_t i = 0; i < membuf_ptr_list_.size(); ++i) {
|
||||
auto membuf = membuf_ptr_list_[i];
|
||||
if (membuf->status_ == kReused) {
|
||||
continue;
|
||||
}
|
||||
// for begin to end, so no need merge pre_membuf
|
||||
if (i != (membuf_ptr_list_.size() - 1)) {
|
||||
auto membuf_next = membuf_ptr_list_[i + 1];
|
||||
if (membuf_next->status_ == kReused) {
|
||||
continue;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(membuf_next);
|
||||
// judge current id no parallel fro membuf && membuf_next
|
||||
bool has_parallel_id_crr = HasParallelId(membuf->called_stream_ids_, current_stream_id_);
|
||||
bool has_parallel_id_next = HasParallelId(membuf_next->called_stream_ids_, current_stream_id_);
|
||||
if (membuf->status_ == kUnused && membuf_next->status_ == kUnused && !has_parallel_id_crr &&
|
||||
!has_parallel_id_next) {
|
||||
membuf->size_ += membuf_next->size_;
|
||||
MergeCalledIds(membuf_next.get(), membuf.get());
|
||||
target_relea_idxs.push_back(i + 1);
|
||||
(void)membuf_ptr_list_.erase(prev_iter);
|
||||
}
|
||||
}
|
||||
}
|
||||
// erase all target membuf
|
||||
std::vector<MembufPtr> membuf_ptr_list_tmp;
|
||||
for (size_t j = 0; j < membuf_ptr_list_.size(); ++j) {
|
||||
for (auto idx : target_relea_idxs) {
|
||||
if (j != idx) {
|
||||
membuf_ptr_list_tmp.push_back(membuf_ptr_list_[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
membuf_ptr_list_.clear();
|
||||
(void)std::copy(membuf_ptr_list_tmp.begin(), membuf_ptr_list_tmp.end(), back_inserter(membuf_ptr_list_));
|
||||
}
|
||||
|
||||
size_t BestFitMemReuse::AlignMemorySize(size_t size) const {
|
||||
|
@ -380,74 +274,83 @@ size_t BestFitMemReuse::GetAllocatedSize() {
|
|||
if (membuf_ptr_list_.empty()) {
|
||||
return AllocatedSize;
|
||||
}
|
||||
AllocatedSize = (*membuf_ptr_list_.rbegin())->offset_ + (*membuf_ptr_list_.rbegin())->size_;
|
||||
AllocatedSize = membuf_ptr_list_.back()->offset_ + membuf_ptr_list_.back()->size_;
|
||||
MS_LOG(INFO) << "MemReuse Allocated Dynamic Size: " << AllocatedSize;
|
||||
return AllocatedSize;
|
||||
}
|
||||
|
||||
/**
|
||||
* parallel_streams_map: key, current_stream_id; value, streams parallel to current stream
|
||||
* @param curr_stream_id
|
||||
* @param target_stream_id
|
||||
* @return bool, if the target stream can be reused by current stream
|
||||
*/
|
||||
bool BestFitMemReuse::IsReusableStream(uint32_t curr_stream_id, uint32_t target_stream_id) {
|
||||
auto iter_parall = parallel_streams_map_.find(curr_stream_id);
|
||||
if (parallel_streams_map_.empty() || (iter_parall == parallel_streams_map_.end())) {
|
||||
// no parallel stream exists
|
||||
return true;
|
||||
}
|
||||
auto curr_parallel_set = iter_parall->second;
|
||||
return curr_parallel_set.find(target_stream_id) == curr_parallel_set.end();
|
||||
}
|
||||
|
||||
bool BestFitMemReuse::IsRelease(const std::string &kernel_name) {
|
||||
bool BestFitMemReuse::IsRelease() {
|
||||
// unable_used_node include the node type that output tensor cannot be released,
|
||||
// even if its refcount is equal to zero.
|
||||
std::unordered_set<std::string> unable_used_node = {prim::kPrimBatchNorm->name(), prim::kPrimBatchNormGrad->name(),
|
||||
prim::kPrimFusedBatchNorm->name(),
|
||||
prim::kPrimFusedBatchNormGrad->name()};
|
||||
return unable_used_node.find(kernel_name) == unable_used_node.end();
|
||||
return unable_used_node.find(current_kernel_->kernel_name()) == unable_used_node.end();
|
||||
}
|
||||
|
||||
void BestFitMemReuse::CheckTensorIndex(int tensor_index) const {
|
||||
if (tensor_index < 0) {
|
||||
MS_LOG(EXCEPTION) << "warning, please check tensor info.";
|
||||
}
|
||||
if (IntToSize(tensor_index) >= tensor_ptr_list_.size()) {
|
||||
size_t BestFitMemReuse::GetTensorIndex(int index) const {
|
||||
if (index < 0 || IntToSize(index) >= tensor_ptr_list_.size()) {
|
||||
MS_LOG(WARNING) << "current cnode: " << current_kernel_->scope_full_name();
|
||||
MS_LOG(EXCEPTION) << "invalid tensor index";
|
||||
}
|
||||
return IntToSize(index);
|
||||
}
|
||||
|
||||
size_t BestFitMemReuse::GetWorkspaceIndex(int index) const {
|
||||
if (index < 0 || IntToSize(index) >= wk_tensor_list_.size()) {
|
||||
MS_LOG(WARNING) << "current cnode: " << current_kernel_->scope_full_name();
|
||||
MS_LOG(EXCEPTION) << "invalid tensor index";
|
||||
}
|
||||
return IntToSize(index);
|
||||
}
|
||||
|
||||
int BestFitMemReuse::GetRealIndex(size_t index, int flag) const {
|
||||
if (flag == kDynamicMem) {
|
||||
return SizeToInt(index);
|
||||
} else if (flag == kWorkspaceMem) {
|
||||
return kWorkspaceIndexFactor * SizeToInt(index + 1);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "flag " << flag << " is invalid";
|
||||
}
|
||||
}
|
||||
|
||||
void BestFitMemReuse::CheckMembufIndx(size_t membuf_index) const {
|
||||
if (membuf_index >= membuf_ptr_list_.size()) {
|
||||
MS_LOG(WARNING) << "current cnode: " << current_kernel_->scope_full_name();
|
||||
MS_LOG(EXCEPTION) << "invalid membuf index: " << membuf_index << ", real size: " << membuf_ptr_list_.size();
|
||||
}
|
||||
}
|
||||
|
||||
void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
|
||||
InitMemReuseInfo(mem_reuse_util_ptr);
|
||||
InitKernelDependence();
|
||||
KernelDefPtr pre_op = nullptr;
|
||||
#ifdef MEM_REUSE_DEBUG
|
||||
size_t op_num = 0;
|
||||
#endif
|
||||
for (const auto &op_def_ptr : op_ptr_list_) {
|
||||
current_stream_id_ = op_def_ptr->stream_id();
|
||||
current_kernel_ = op_def_ptr;
|
||||
// releas pre_op_def
|
||||
if (pre_op != nullptr) {
|
||||
ReleasePreNodeWkSpace(pre_op.get());
|
||||
ReleasePreNodeWorkspace(pre_op.get());
|
||||
}
|
||||
MemReuseChecker::GetInstance().IsAddNewMembuf_ = false;
|
||||
// process node output tensor
|
||||
AssignNodeOutputOffset(op_def_ptr.get());
|
||||
AssignNodeOutputOffset();
|
||||
#ifdef MEM_REUSE_DEBUG
|
||||
if (MemReuseChecker::GetInstance().IsAddNewMembuf_) {
|
||||
MemReuseChecker::GetInstance().SetAddNewMembuInfos(op_def_ptr.get(), membuf_ptr_list_, op_num);
|
||||
}
|
||||
#endif
|
||||
// deal with current op'workspace
|
||||
AssignNodeWkOffset(op_def_ptr.get());
|
||||
AssignNodeWorkspaceOffset();
|
||||
pre_op = op_def_ptr;
|
||||
// update node input tensor refcount, and membuf list status
|
||||
UpdateNodeInputAndMembuf(op_def_ptr.get());
|
||||
UpdateNodeInputAndMembuf();
|
||||
// check node output tensor which refcount is equal to zero
|
||||
if (IsRelease(op_def_ptr->kernel_name())) {
|
||||
ReleaseNodeUnusedOutput(op_def_ptr.get());
|
||||
if (IsRelease()) {
|
||||
ReleaseNodeUnusedOutput();
|
||||
}
|
||||
#ifdef MEM_REUSE_DEBUG
|
||||
MemReuseChecker::GetInstance().SetMembuInfos(op_def_ptr.get(), membuf_ptr_list_);
|
||||
|
@ -457,6 +360,8 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) {
|
|||
#ifdef MEM_REUSE_DEBUG
|
||||
MemReuseChecker::GetInstance().ExportMembufInfoIR();
|
||||
MemReuseChecker::GetInstance().ExportAddNewMmebufIR();
|
||||
MemReuseChecker::GetInstance().set_kernel_front_map(kernel_front_map_);
|
||||
MemReuseChecker::GetInstance().ExportKernelDependence();
|
||||
#endif
|
||||
}
|
||||
} // namespace memreuse
|
||||
|
|
|
@ -29,31 +29,30 @@
|
|||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <set>
|
||||
#include <queue>
|
||||
#include "pre_activate/mem_reuse/kernel_refcount.h"
|
||||
#include "pre_activate/mem_reuse/mem_reuse.h"
|
||||
#include "pre_activate/mem_reuse/stream_reuse.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace memreuse {
|
||||
static constexpr int kWkIndexFactor = -1000;
|
||||
static constexpr int kDyFac = -1;
|
||||
static constexpr int kWkFac = 1;
|
||||
static constexpr int kWorkspaceIndexFactor = -1000;
|
||||
static constexpr int kDynamicMem = -1;
|
||||
static constexpr int kWorkspaceMem = 1;
|
||||
static constexpr size_t kTotalSize = 0;
|
||||
enum Status { kUnused, kReused };
|
||||
class Membuf {
|
||||
public:
|
||||
Membuf() = default;
|
||||
Membuf(uint32_t stream_id, Status status, size_t size, size_t offset, int index)
|
||||
: stream_id_(stream_id), status_(status), size_(size), offset_(offset), index_(index) {}
|
||||
Membuf(Status status, size_t size, size_t offset, int index, const KernelDefPtr &used_kernel)
|
||||
: status_(status), size_(size), offset_(offset), index_(index), used_kernel_(used_kernel) {}
|
||||
~Membuf() = default;
|
||||
// Memory block status flags
|
||||
std::set<uint32_t> called_stream_ids_;
|
||||
uint32_t stream_id_{0};
|
||||
Status status_ = kUnused;
|
||||
size_t size_{0};
|
||||
size_t offset_{0};
|
||||
// Store the tensor index stored in this memory block at a certain moment
|
||||
int index_{0};
|
||||
KernelDefPtr used_kernel_;
|
||||
};
|
||||
using MembufPtr = std::shared_ptr<Membuf>;
|
||||
|
||||
|
@ -61,24 +60,45 @@ class BestFitMemReuse {
|
|||
public:
|
||||
BestFitMemReuse() = default;
|
||||
~BestFitMemReuse() { membuf_ptr_list_.clear(); }
|
||||
// Init all information need by memory reuse
|
||||
/**
|
||||
* Init all information need by memory reuse
|
||||
* @param mem_reuse_util_ptr, initialize in the memreuse.cc
|
||||
*/
|
||||
void InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr);
|
||||
bool CheckMembufIndx(const std::vector<MembufPtr> &membuf_ptr_list, size_t check_idx) const;
|
||||
bool IsMembufListEmpty(const std::vector<MembufPtr> &membuf_ptr_list) const;
|
||||
void AssignNodeWkOffset(const KernelDef *kernel_def_ptr);
|
||||
void ReleasePreNodeWkSpace(const KernelDef *kernel_def_ptr);
|
||||
// void assign node output tensor memory offset
|
||||
void AssignNodeOutputOffset(const KernelDef *kernel_def_ptr);
|
||||
void ReleaseParallStream();
|
||||
// update node input tensor refcount, and membuf list status
|
||||
void UpdateNodeInputAndMembuf(const KernelDef *kernel_def_ptr);
|
||||
// check node output tensor which refcount is equal to zero
|
||||
void ReleaseNodeUnusedOutput(const KernelDef *kernel_def_ptr);
|
||||
// If there are memory blocks that can be reused
|
||||
void CheckMembufIndx(size_t check_idx) const;
|
||||
void AssignNodeWorkspaceOffset();
|
||||
void ReleasePreNodeWorkspace(const KernelDef *kernel_def_ptr);
|
||||
/**
|
||||
* Assign output tensor memory offset of current kernel
|
||||
*/
|
||||
void AssignNodeOutputOffset();
|
||||
/**
|
||||
* Update input tensor's status of current kernel, and the status of membuf used by current kernel
|
||||
*/
|
||||
void UpdateNodeInputAndMembuf();
|
||||
/**
|
||||
* Check whether to release the kernel output tensor which refcount is equal to zero
|
||||
*/
|
||||
void ReleaseNodeUnusedOutput();
|
||||
/**
|
||||
* Reuse the exist membuf if possible
|
||||
* @param tensor_desc, the output tensor of current kernel
|
||||
* @param membuf_index, the index of membuf to be reused
|
||||
* @param flag
|
||||
*/
|
||||
void ReuseExistMembuf(KernelRefCount *tensor_desc, size_t membuf_index, int flag);
|
||||
// Save memory blocks that can be reused to the map
|
||||
/**
|
||||
* Get the membuf that can be reused
|
||||
* @param tensor_size, the size of the tensor ready to assign memory offset
|
||||
* @return membuf map, key: the membuf size, value: the membuf index
|
||||
*/
|
||||
std::map<size_t, size_t> GetReusableMembufMap(size_t tensor_size);
|
||||
// Update the status of the reused memory block
|
||||
/**
|
||||
* Update the status of the reused memory block
|
||||
* @param tensor_desc, the tensor ready to assign memory
|
||||
* @param membuf, the membuf to be reused
|
||||
* @param flag, distinguish dynamic memory and workspace
|
||||
*/
|
||||
void UpdateMembufInfo(KernelRefCount *tensor_desc, Membuf *membuf, int flag);
|
||||
// If the size of the memory block is greater than the size of the tensor, split the extra memory
|
||||
void SplitMembuf(const KernelRefCount *tensor_desc, size_t membuf_index);
|
||||
|
@ -88,30 +108,39 @@ class BestFitMemReuse {
|
|||
void AddNewMembufPtr(KernelRefCount *tensor_desc, int flag);
|
||||
// Merge unused membuf
|
||||
void ReleaseMembuf(size_t tensor_index, int flag);
|
||||
bool HasParallelId(const std::set<uint32_t> &called_ids, uint32_t curr_id);
|
||||
void MergeCalledIds(const Membuf *membuf_target, Membuf *membuf);
|
||||
// Memory address alignment 512
|
||||
size_t AlignMemorySize(size_t size) const;
|
||||
int GetFacIdx(size_t real_idx, int flag = kDyFac) const;
|
||||
int GetRealIdx(int fac_idx, int flag = kDyFac) const;
|
||||
size_t FindIndx(const std::vector<MembufPtr> &membuf_ptr_list, int fac_idx) const;
|
||||
void CheckTensorIndex(int tensor_index) const;
|
||||
int GetRealIndex(size_t index, int flag = kDynamicMem) const;
|
||||
size_t GetTensorIndex(int index) const;
|
||||
size_t GetWorkspaceIndex(int index) const;
|
||||
// Memory reuse main program entry
|
||||
void Reuse(const MemReuseUtil *mem_reuse_util_ptr);
|
||||
// Get the total memory that needs to be applied eventually
|
||||
size_t GetAllocatedSize();
|
||||
// If the target stream can be reused by current stream
|
||||
bool IsReusableStream(uint32_t curr_stream_id, uint32_t target_stream_id);
|
||||
// return false, when the node output cannot be released
|
||||
bool IsRelease(const std::string &kernel_name);
|
||||
bool IsRelease();
|
||||
/**
|
||||
* determine if the kernel_curr can reuse the output tensor add of kernel_prev
|
||||
* @param kernel_curr, current kernel
|
||||
* @param kernel_prev, the membuf used by this kernel
|
||||
* @return bool
|
||||
*/
|
||||
bool IsUsable(const KernelDefPtr &kernel_curr, const KernelDefPtr &kernel_prev);
|
||||
/**
|
||||
* init the dependence of all kernels in the graph
|
||||
*/
|
||||
void InitKernelDependence();
|
||||
// set tensor_def and op_def
|
||||
void set_tensor_ptr_list(const std::vector<KernelRefCountPtr> &tensor_ptr_list) {
|
||||
tensor_ptr_list_ = tensor_ptr_list;
|
||||
}
|
||||
void set_workspace_ptr_list(const std::vector<KernelRefCountPtr> &workspace_ptr_list) {
|
||||
wk_tensor_list_ = workspace_ptr_list;
|
||||
}
|
||||
void set_op_ptr_list(const std::vector<KernelDefPtr> &op_ptr_list) { op_ptr_list_ = op_ptr_list; }
|
||||
|
||||
private:
|
||||
uint32_t current_stream_id_{0};
|
||||
KernelDefPtr current_kernel_;
|
||||
// Save all tensor information
|
||||
std::vector<KernelRefCountPtr> tensor_ptr_list_;
|
||||
std::vector<KernelRefCountPtr> wk_tensor_list_;
|
||||
|
@ -119,7 +148,8 @@ class BestFitMemReuse {
|
|||
std::vector<KernelDefPtr> op_ptr_list_;
|
||||
// Memory block information sequence, temporary variables
|
||||
std::vector<MembufPtr> membuf_ptr_list_;
|
||||
std::unordered_map<uint32_t, std::unordered_set<uint32_t>> parallel_streams_map_;
|
||||
// kernel_front_map_, key: the kernel_def, value: kernels before this kernel_def
|
||||
std::map<KernelDefPtr, std::set<KernelDefPtr>> kernel_front_map_;
|
||||
};
|
||||
} // namespace memreuse
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,8 +19,6 @@
|
|||
#include <vector>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace mindspore {
|
||||
namespace memreuse {
|
||||
|
@ -188,6 +186,27 @@ void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_li
|
|||
ofs.close();
|
||||
}
|
||||
|
||||
void MemReuseChecker::ExportKernelDependence() {
|
||||
std::string filename = "./memreuse_dependence.ir";
|
||||
std::ofstream ofs(filename);
|
||||
if (!ofs.is_open()) {
|
||||
MS_LOG(ERROR) << "Open file [" << filename << "] failed!";
|
||||
return;
|
||||
}
|
||||
size_t i = 0;
|
||||
for (const auto &kernel_front : kernel_front_map_) {
|
||||
auto kernel = kernel_front.first;
|
||||
auto front = kernel_front.second;
|
||||
ofs << "[" << i++ << "] " << kernel->scope_full_name() << "\n";
|
||||
for (const auto &node : front) {
|
||||
ofs << node->scope_full_name() << "\n";
|
||||
}
|
||||
ofs << "\n\n";
|
||||
}
|
||||
|
||||
ofs.close();
|
||||
}
|
||||
|
||||
bool MemReuseChecker::CheckGraphOutputAssigned(const session::KernelGraph *graph) {
|
||||
// set real graph output node to be special who's refcount equal kMaxRefCount
|
||||
for (const auto &output : graph->outputs()) {
|
||||
|
@ -393,7 +412,7 @@ void MemReuseChecker::CheckNormalIR(const session::KernelGraph *graph) {
|
|||
void MemReuseChecker::SetMembuInfos(const KernelDef *op_def, const std::vector<MembufPtr> &membuf_ptr_list) {
|
||||
std::vector<MembufPtr> curr_mem_infos;
|
||||
for (const auto &mem : membuf_ptr_list) {
|
||||
auto mem_checker = std::make_shared<Membuf>(mem->stream_id_, mem->status_, mem->size_, mem->offset_, mem->index_);
|
||||
auto mem_checker = std::make_shared<Membuf>(mem->status_, mem->size_, mem->offset_, mem->index_, mem->used_kernel_);
|
||||
curr_mem_infos.push_back(mem_checker);
|
||||
}
|
||||
membuf_all_infos_.push_back(curr_mem_infos);
|
||||
|
@ -407,7 +426,7 @@ void MemReuseChecker::SetAddNewMembuInfos(const KernelDef *op_def, const std::ve
|
|||
std::vector<MembufPtr> add_new_curr_mem;
|
||||
|
||||
for (const auto &mem : membuf_ptr_list) {
|
||||
auto mem_checker = std::make_shared<Membuf>(mem->stream_id_, mem->status_, mem->size_, mem->offset_, mem->index_);
|
||||
auto mem_checker = std::make_shared<Membuf>(mem->status_, mem->size_, mem->offset_, mem->index_, mem->used_kernel_);
|
||||
add_new_curr_mem.push_back(mem_checker);
|
||||
}
|
||||
add_new_mem_infos_.push_back(add_new_curr_mem);
|
||||
|
@ -424,11 +443,11 @@ void MemReuseChecker::ExportMembufInfoIR() {
|
|||
if (!ofs.is_open()) {
|
||||
MS_LOG(ERROR) << "Open file [" << ir_file_name << "] failed!";
|
||||
}
|
||||
ofs << "total_ori_static_size:" << total_ori_static_size_ << "\n";
|
||||
ofs << "total_ori_weight_size:" << total_ori_input_size_ << "\n";
|
||||
ofs << "total_ori_constant_size:" << total_ori_value_size_ << "\n";
|
||||
ofs << "total_ori_dy_size:" << total_ori_dy_size_ << "\n";
|
||||
ofs << "total_ori_wkspace_size:" << total_ori_wkspace_size_ << "\n";
|
||||
ofs << "Total static size:\t" << total_ori_static_size_ << "\n";
|
||||
ofs << "Graph inputs size:\t" << total_ori_input_size_ << "\n";
|
||||
ofs << "Value nodes size:\t" << total_ori_value_size_ << "\n";
|
||||
ofs << "Total dynamic size:\t" << total_ori_dy_size_ << "\n";
|
||||
ofs << "Total workspace size:\t" << total_ori_wkspace_size_ << "\n";
|
||||
// get last membuf_list
|
||||
if (membuf_all_infos_.empty()) {
|
||||
return;
|
||||
|
@ -438,8 +457,10 @@ void MemReuseChecker::ExportMembufInfoIR() {
|
|||
auto checker_size = SizeToLong(membuf->size_);
|
||||
total_reuse_size += checker_size;
|
||||
}
|
||||
ofs << "total_reuse_size:" << total_reuse_size << "\n";
|
||||
ofs << "After reuse size:\t" << total_reuse_size << "\n\n";
|
||||
size_t i = 0;
|
||||
std::vector<size_t> each_node_used_size;
|
||||
std::vector<size_t> each_node_allocated_size;
|
||||
for (const auto &curr_membuf_list : membuf_all_infos_) {
|
||||
ofs << all_split_names_.at(i) << "\n";
|
||||
++i;
|
||||
|
@ -449,17 +470,42 @@ void MemReuseChecker::ExportMembufInfoIR() {
|
|||
<< "tensor_idex\t"
|
||||
<< "mem_size\t"
|
||||
<< "mem_head\t"
|
||||
<< "mem_tail\n";
|
||||
<< "mem_tail\t"
|
||||
<< "used_kernel\n";
|
||||
size_t curr_used = 0;
|
||||
size_t curr_allocated = 0;
|
||||
for (size_t j = 0; j < curr_membuf_list.size(); ++j) {
|
||||
auto membuf = curr_membuf_list.at(j);
|
||||
auto used_kernel = membuf->used_kernel_->scope_full_name();
|
||||
ofs << "&" << j << "\t"
|
||||
<< "streamID[@" << membuf->stream_id_ << "]"
|
||||
<< "streamID[@" << membuf->used_kernel_->stream_id() << "]"
|
||||
<< "\t"
|
||||
<< "#" << static_cast<int>(membuf->status_) << "\t%" << membuf->index_ << "T"
|
||||
<< "\t" << membuf->size_ << "\t" << membuf->offset_ << "\t" << membuf->offset_ + membuf->size_ << "\n";
|
||||
<< "\t" << membuf->size_ << "\t" << membuf->offset_ << "\t" << membuf->offset_ + membuf->size_ << "\t"
|
||||
<< GetSplitName(used_kernel) << "\n";
|
||||
if (membuf->status_ == kReused) {
|
||||
curr_used += membuf->size_;
|
||||
}
|
||||
}
|
||||
if (!curr_membuf_list.empty()) {
|
||||
curr_allocated = curr_membuf_list.back()->offset_ + curr_membuf_list.back()->size_;
|
||||
}
|
||||
each_node_used_size.push_back(curr_used);
|
||||
each_node_allocated_size.push_back(curr_allocated);
|
||||
ofs << "curr real used size: \t" << curr_used << "\n";
|
||||
ofs << "curr allocated size: \t" << curr_allocated << "\n";
|
||||
ofs << "\n\n";
|
||||
}
|
||||
ofs << "each node used size: \n";
|
||||
for (auto size : each_node_used_size) {
|
||||
ofs << size << "\t";
|
||||
}
|
||||
ofs << "\n\n";
|
||||
ofs << "each node allocated size: \n";
|
||||
for (auto size : each_node_allocated_size) {
|
||||
ofs << size << "\t";
|
||||
}
|
||||
ofs << "\n\n";
|
||||
ofs.close();
|
||||
}
|
||||
|
||||
|
@ -479,7 +525,6 @@ void MemReuseChecker::ExportAddNewMmebufIR() {
|
|||
<< "\n";
|
||||
i++;
|
||||
ofs << "mem_num\t"
|
||||
<< "stream_id\t"
|
||||
<< "status\t"
|
||||
<< "tensor_idex\t"
|
||||
<< "mem_size\t"
|
||||
|
@ -490,7 +535,6 @@ void MemReuseChecker::ExportAddNewMmebufIR() {
|
|||
for (size_t j = 0; j < curr_membuf_list.size(); ++j) {
|
||||
auto membuf = curr_membuf_list.at(j);
|
||||
ofs << "&" << j << "\t"
|
||||
<< "streamID[@" << membuf->stream_id_ << "]"
|
||||
<< "\t"
|
||||
<< "#" << static_cast<int>(membuf->status_) << "\t%" << membuf->index_ << "T"
|
||||
<< "\t" << membuf->size_ << "\t" << membuf->offset_ << "\t" << membuf->offset_ + membuf->size_ << "\t";
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
@ -59,10 +60,14 @@ class MemReuseChecker {
|
|||
void ExportMembufInfoIR();
|
||||
void SetAddNewMembuInfos(const KernelDef *op_def, const std::vector<MembufPtr> &membuf_ptr_list, size_t op_idx);
|
||||
void ExportAddNewMmebufIR();
|
||||
void set_kernel_front_map(const std::map<KernelDefPtr, std::set<KernelDefPtr>> &kernel_front_map) {
|
||||
kernel_front_map_ = kernel_front_map;
|
||||
}
|
||||
void ExportKernelDependence();
|
||||
|
||||
private:
|
||||
MemReuseChecker() = default;
|
||||
~MemReuseChecker() { MS_LOG(INFO) << "Total reused workspace size: " << total_re_wkspe_size_checker_; }
|
||||
~MemReuseChecker() {}
|
||||
size_t total_re_wkspe_size_checker_{0};
|
||||
std::vector<std::vector<MembufPtr>> membuf_all_infos_;
|
||||
std::vector<const void *> nor_output_tensors_;
|
||||
|
@ -79,6 +84,7 @@ class MemReuseChecker {
|
|||
std::vector<std::string> all_split_names_;
|
||||
std::map<int, std::vector<string>> tensor_from_;
|
||||
std::map<int, std::vector<string>> tensor_to_;
|
||||
std::map<KernelDefPtr, std::set<KernelDefPtr>> kernel_front_map_;
|
||||
int64_t total_ori_static_size_ = 0;
|
||||
int64_t total_ori_input_size_ = 0;
|
||||
int64_t total_ori_value_size_ = 0;
|
||||
|
|
|
@ -1,102 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pre_activate/mem_reuse/stream_reuse.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace memreuse {
|
||||
void StreamReuse::SetStreamReuseResource() {
|
||||
#ifdef ENABLE_D
|
||||
auto logic_physic_map = device::ascend::AscendStreamAssign::GetInstance().logic_to_physic_map();
|
||||
auto logic_independent_map = device::ascend::AscendStreamAssign::GetInstance().logic_to_independent_map();
|
||||
MS_LOG(INFO) << "stream mem reuse for Davici";
|
||||
if (!logic_independent_map.empty() && !logic_physic_map.empty()) {
|
||||
set_logic_physic_map(logic_physic_map);
|
||||
set_logic_independent_map(logic_independent_map);
|
||||
InitReusableStreamMap();
|
||||
} else {
|
||||
MS_LOG(INFO) << "Non task sink or No Parallel stream exists";
|
||||
}
|
||||
#endif
|
||||
MS_LOG(INFO) << "no need to set stream mem reuse resource";
|
||||
}
|
||||
|
||||
std::vector<std::pair<uint32_t, uint32_t>> StreamReuse::SortLogicPhysicMapToList() {
|
||||
std::vector<std::pair<uint32_t, uint32_t>> logic_physic_list;
|
||||
(void)std::transform(logic_physic_map_.begin(), logic_physic_map_.end(), std::back_inserter(logic_physic_list),
|
||||
[](std::pair<uint32_t, uint32_t> log_phy) { return log_phy; });
|
||||
std::sort(
|
||||
logic_physic_list.begin(), logic_physic_list.end(),
|
||||
[](const std::pair<uint32_t, uint32_t> &logic_phyic_pair1, const std::pair<uint32_t, uint32_t> &logic_phyic_pair2) {
|
||||
return logic_phyic_pair1.second < logic_phyic_pair2.second;
|
||||
});
|
||||
return logic_physic_list;
|
||||
}
|
||||
|
||||
std::unordered_map<int, std::set<uint32_t>> StreamReuse::GetLogicPhysicsStreamMap() {
|
||||
auto logic_physic_list = SortLogicPhysicMapToList();
|
||||
std::unordered_map<int, std::set<uint32_t>> logic_phyics_map;
|
||||
for (size_t i = 0; i < logic_physic_list.size() - IntToSize(1); ++i) {
|
||||
auto curr_logic_physic = logic_physic_list.at(i);
|
||||
auto next_logic_physic = logic_physic_list.at(i + 1);
|
||||
for (auto j = curr_logic_physic.second; j < next_logic_physic.second; ++j) {
|
||||
(void)logic_phyics_map[curr_logic_physic.first].insert(j);
|
||||
}
|
||||
}
|
||||
// sort the logic independ map by value
|
||||
std::map<uint32_t, uint32_t> temp_map;
|
||||
for (const auto &logic_independ : logic_independent_map_) {
|
||||
(void)temp_map.insert(std::make_pair(logic_independ.second, logic_independ.first));
|
||||
}
|
||||
auto first_independent_stream_id = (*temp_map.begin()).first;
|
||||
auto last_physic_logic_stream_id = (*logic_physic_list.rbegin()).second;
|
||||
for (auto i = last_physic_logic_stream_id; i < first_independent_stream_id; ++i) {
|
||||
(void)logic_phyics_map[(*logic_physic_list.rbegin()).first].insert(i);
|
||||
}
|
||||
return logic_phyics_map;
|
||||
}
|
||||
|
||||
void StreamReuse::InitReusableStreamMap() {
|
||||
// logic_phyics_map, key, logic_stream_id; value, physic_strema_ids included in that logic stream
|
||||
auto logic_phyics_map = GetLogicPhysicsStreamMap();
|
||||
// parallel_streams_map: key, current_stream_id; value, streams parallel to current stream
|
||||
for (const auto &logic_to_phyics : logic_phyics_map) {
|
||||
auto logic_stream_id = logic_to_phyics.first;
|
||||
auto iter_inde = logic_independent_map_.find(logic_stream_id);
|
||||
if (iter_inde != logic_independent_map_.end()) {
|
||||
// exist independent steam parallel to these logic streams
|
||||
auto independent_stream_id = iter_inde->second;
|
||||
auto physics_stream_id = logic_to_phyics.second;
|
||||
for (const auto &physic : physics_stream_id) {
|
||||
(void)parallel_streams_map_[physic].insert(independent_stream_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const auto &logic_to_independent : logic_independent_map_) {
|
||||
auto logic_stream_id = logic_to_independent.first;
|
||||
auto independent_stream_id = logic_to_independent.second;
|
||||
auto iter_physics = logic_phyics_map.find(logic_stream_id);
|
||||
if (iter_physics != logic_phyics_map.end()) {
|
||||
// exist logic steam parallel to these independent streams, default
|
||||
auto physics_set = iter_physics->second;
|
||||
for (const auto &physic : physics_set) {
|
||||
(void)parallel_streams_map_[independent_stream_id].insert(physic);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace memreuse
|
||||
} // namespace mindspore
|
|
@ -1,63 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_STREAM_REUSE_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_STREAM_REUSE_H_
|
||||
#include <cmath>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <fstream>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "pre_activate/mem_reuse/kernel_refcount.h"
|
||||
|
||||
#ifdef ENABLE_D
|
||||
#include "device/ascend/ascend_stream_assign.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace memreuse {
|
||||
class StreamReuse {
|
||||
public:
|
||||
StreamReuse() = default;
|
||||
~StreamReuse() = default;
|
||||
void SetStreamReuseResource();
|
||||
void InitReusableStreamMap();
|
||||
std::vector<std::pair<uint32_t, uint32_t>> SortLogicPhysicMapToList();
|
||||
std::unordered_map<int, std::set<uint32_t>> GetLogicPhysicsStreamMap();
|
||||
void set_logic_physic_map(const std::unordered_map<uint32_t, uint32_t> &logic_physic_map) {
|
||||
logic_physic_map_ = logic_physic_map;
|
||||
}
|
||||
void set_logic_independent_map(const std::unordered_map<uint32_t, uint32_t> &logic_independent_map) {
|
||||
logic_independent_map_ = logic_independent_map;
|
||||
}
|
||||
std::unordered_map<uint32_t, std::unordered_set<uint32_t>> parallel_streams_map() { return parallel_streams_map_; }
|
||||
|
||||
private:
|
||||
std::unordered_map<uint32_t, std::unordered_set<uint32_t>> parallel_streams_map_;
|
||||
std::unordered_map<uint32_t, uint32_t> logic_physic_map_;
|
||||
std::unordered_map<uint32_t, uint32_t> logic_independent_map_;
|
||||
};
|
||||
} // namespace memreuse
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_STREAM_REUSE_H_
|
|
@ -117,16 +117,13 @@ TEST_F(TestMemReuseAllocator, mem_reuse_allocator) {
|
|||
MS_LOG(INFO) << "run mem reuse success";
|
||||
size_t total_allocated_size = best_fit_mem_reuse->GetAllocatedSize();
|
||||
ASSERT_NE(total_allocated_size, 0);
|
||||
|
||||
auto is_reusable_stream = best_fit_mem_reuse->IsReusableStream(1, 3);
|
||||
ASSERT_EQ(is_reusable_stream, true);
|
||||
}
|
||||
|
||||
TEST_F(TestMemReuseAllocator, mem_reuse_allocator_add_membuf) {
|
||||
auto best_fit_mem_reuse = std::make_shared<BestFitMemReuse>();
|
||||
auto tensor_desc = std::make_shared<KernelRefCount>();
|
||||
tensor_desc->SetKernelRefCountInfo(0, 1024, kDynamicRefCount);
|
||||
best_fit_mem_reuse->AddNewMembufPtr(tensor_desc.get(), kDyFac);
|
||||
best_fit_mem_reuse->AddNewMembufPtr(tensor_desc.get(), kDynamicMem);
|
||||
auto allocated_size = best_fit_mem_reuse->GetAllocatedSize();
|
||||
ASSERT_EQ(allocated_size, 1024);
|
||||
}
|
||||
|
@ -135,7 +132,7 @@ TEST_F(TestMemReuseAllocator, mem_reuse_allocator_split_membuf) {
|
|||
auto best_fit_mem_reuse = std::make_shared<BestFitMemReuse>();
|
||||
auto tensor_0 = std::make_shared<KernelRefCount>();
|
||||
tensor_0->SetKernelRefCountInfo(0, 2048, kDynamicRefCount);
|
||||
best_fit_mem_reuse->AddNewMembufPtr(tensor_0.get(), kDyFac);
|
||||
best_fit_mem_reuse->AddNewMembufPtr(tensor_0.get(), kDynamicMem);
|
||||
|
||||
auto tensor_1 = std::make_shared<KernelRefCount>();
|
||||
tensor_1->SetKernelRefCountInfo(1, 800, kDynamicRefCount);
|
||||
|
|
|
@ -228,12 +228,6 @@ TEST_F(TestMemReuseWithPy, KernelRef) {
|
|||
ASSERT_EQ(kernel_def_ptr->dirty, false);
|
||||
MembufPtr membuf_ptr = std::make_shared<Membuf>();
|
||||
ASSERT_NE(membuf_ptr, nullptr);
|
||||
MembufPtr membuf_ptr_x = std::make_shared<Membuf>(0, memreuse::kUnused, 512, 128, 2);
|
||||
ASSERT_EQ(membuf_ptr_x->status_, memreuse::kUnused);
|
||||
ASSERT_EQ(membuf_ptr_x->size_, 512);
|
||||
ASSERT_EQ(membuf_ptr_x->offset_, 128);
|
||||
ASSERT_EQ(membuf_ptr_x->index_, 2);
|
||||
ASSERT_EQ(membuf_ptr_x->stream_id_, 0);
|
||||
}
|
||||
|
||||
TEST_F(TestMemReuseWithPy, ReuseAssignDynamicMemory) {
|
||||
|
|
|
@ -1,63 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "operator/ops.h"
|
||||
#include "pre_activate/mem_reuse/stream_reuse.h"
|
||||
#include "common/common_test.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
|
||||
using mindspore::memreuse::StreamReuse;
|
||||
|
||||
namespace mindspore {
|
||||
class TestStreamMemReuse : public UT::Common {
|
||||
public:
|
||||
TestStreamMemReuse() : getPyFun_("gtest_input.mem_reuse.TestMemReuseAllocator", true) {}
|
||||
void SetUp() {}
|
||||
|
||||
public:
|
||||
UT::PyFuncGraphFetcher getPyFun_;
|
||||
};
|
||||
|
||||
TEST_F(TestStreamMemReuse, init_reusable_stream_map_test) {
|
||||
std::unordered_map<uint32_t, uint32_t> logic_physic_map{{1, 0}, {2, 8}, {3, 3}};
|
||||
std::unordered_map<uint32_t, uint32_t> logic_independent_map{{3, 10}, {2, 11}};
|
||||
auto stream_reuse = std::make_shared<StreamReuse>();
|
||||
stream_reuse->set_logic_physic_map(logic_physic_map);
|
||||
stream_reuse->set_logic_independent_map(logic_independent_map);
|
||||
|
||||
auto logic_phyics_map = stream_reuse->GetLogicPhysicsStreamMap();
|
||||
for (const auto &logic_physics : logic_phyics_map) {
|
||||
MS_LOG(INFO) << "[logic_id: " << logic_physics.first << "]";
|
||||
for (const auto &physic : logic_physics.second) {
|
||||
MS_LOG(INFO) << "physic: " << physic;
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "===========UT logic_physic_map size: " << logic_physic_map.size() << "========";
|
||||
ASSERT_EQ(logic_physic_map.size(), 3);
|
||||
stream_reuse->InitReusableStreamMap();
|
||||
auto parallel_streams_map = stream_reuse->parallel_streams_map();
|
||||
for (const auto ¶llel_streams : parallel_streams_map) {
|
||||
MS_LOG(INFO) << "[stream id: " << parallel_streams.first << "]";
|
||||
for (const auto &stream : parallel_streams.second) {
|
||||
MS_LOG(INFO) << "parallel stream id: " << stream;
|
||||
}
|
||||
}
|
||||
ASSERT_EQ(parallel_streams_map[7].size(), 1);
|
||||
}
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue