forked from mindspore-Ecosystem/mindspore
!1170 fix memreuse to support large batchsize
Merge pull request !1170 from yangjie159/fix_memreuse_to_support_large_batchsize
commit 2445ffdf7b
@@ -21,8 +21,8 @@
 namespace mindspore {
 namespace device {
 namespace ascend {
-const uint64_t kAscendDeviceMemGB = 24;
-const uint64_t kAscendMemPoolGB = 6;
+const uint64_t kAscendDeviceMemGB = 26;
+const uint64_t kAscendMemPoolGB = 4;
 const uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << 30);
 const uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << 30);
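For context (not part of the commit): the two constants still sum to 30 GB, so the change only moves 2 GB from the dynamic memory pool into the statically planned device memory region, presumably to give the mem-reuse planner more room for large batch sizes. A minimal standalone sketch of what the new constants evaluate to:

    #include <cstdint>
    #include <iostream>

    int main() {
      // Values after this commit: 26 GB statically planned device memory, 4 GB dynamic pool.
      const uint64_t kAscendDeviceMemGB = 26;
      const uint64_t kAscendMemPoolGB = 4;
      const uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << 30);  // 27917287424 bytes
      const uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << 30);      //  4294967296 bytes

      // Total budget is unchanged: 24 + 6 == 26 + 4 == 30 GB; only the split moves.
      std::cout << (kAscendDeviceMemSize + kAscendMemPoolSize) << " bytes total\n";  // 32212254720
      return 0;
    }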
@@ -401,6 +401,15 @@ bool BestFitMemReuse::IsReusableStream(uint32_t curr_stream_id, uint32_t target_
   return curr_parallel_set.find(target_stream_id) == curr_parallel_set.end();
 }
 
+bool BestFitMemReuse::IsRelease(const std::string &kernel_name) {
+  // unable_used_node include the node type that output tensor cannot be released,
+  // even if its refcount is equal to zero.
+  std::unordered_set<std::string> unable_used_node = {prim::kPrimBatchNorm->name(), prim::kPrimBatchNormGrad->name(),
+                                                      prim::kPrimFusedBatchNorm->name(),
+                                                      prim::kPrimFusedBatchNormGrad->name()};
+  return unable_used_node.find(kernel_name) == unable_used_node.end();
+}
+
 void BestFitMemReuse::CheckTensorIndex(int tensor_index) const {
   if (tensor_index < 0) {
     MS_LOG(EXCEPTION) << "warning, please check tensor info.";
@@ -437,6 +446,9 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) {
     // update node input tensor refcount, and membuf list status
     UpdateNodeInputAndMembuf(op_def_ptr.get());
     // check node output tensor which refcount is equal to zero
+    if (IsRelease(op_def_ptr->kernel_name())) {
+      ReleaseNodeUnusedOutput(op_def_ptr.get());
+    }
 #ifdef MEM_REUSE_DEBUG
     MemReuseChecker::GetInstance().SetMembuInfos(op_def_ptr.get(), membuf_ptr_list_);
     ++op_num;
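To illustrate the semantics of the new check (a standalone sketch, not the MindSpore API): IsRelease() is a plain deny-list lookup, so it returns false only for the BatchNorm family of kernels and true for everything else, which is what the Reuse() loop uses to decide whether a node's unused output buffers may be reclaimed. The primitive names are assumed here to resolve to strings such as "BatchNorm":

    #include <iostream>
    #include <string>
    #include <unordered_set>

    // Sketch of the deny-list check; the real code builds the set from
    // prim::kPrim*->name(), assumed here to be strings like "BatchNorm".
    bool IsRelease(const std::string &kernel_name) {
      static const std::unordered_set<std::string> unable_used_node = {
          "BatchNorm", "BatchNormGrad", "FusedBatchNorm", "FusedBatchNormGrad"};
      return unable_used_node.find(kernel_name) == unable_used_node.end();
    }

    int main() {
      std::cout << std::boolalpha;
      std::cout << IsRelease("Conv2D") << "\n";     // true: output can be released at refcount zero
      std::cout << IsRelease("BatchNorm") << "\n";  // false: output is kept even at refcount zero
      return 0;
    }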
@@ -102,6 +102,8 @@ class BestFitMemReuse {
   size_t GetAllocatedSize();
   // If the target stream can be reused by current stream
   bool IsReusableStream(uint32_t curr_stream_id, uint32_t target_stream_id);
+  // return false, when the node output cannot be released
+  bool IsRelease(const std::string &kernel_name);
   // set tensor_def and op_def
   void set_tensor_ptr_list(const std::vector<KernelRefCountPtr> &tensor_ptr_list) {
     tensor_ptr_list_ = tensor_ptr_list;