From d99786e938b53fb25e7be96073c7dd2caee28d88 Mon Sep 17 00:00:00 2001 From: laiyongqiang Date: Tue, 28 Jul 2020 20:11:43 +0800 Subject: [PATCH] fix refnode input type assign --- .../ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc | 7 +++++-- .../runtime/device/ascend/ascend_memory_manager.cc | 13 ++++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc index 02a277f2243..c45504e214e 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc @@ -64,11 +64,14 @@ bool MemReuseUtil::InitDynamicOutputKernelRef() { kernel_ref->type_ = kRefNodeOutput; auto origin_pair = graph_->GetRefCorrespondOutput(out_pair); MS_EXCEPTION_IF_NULL(origin_pair.first); + MS_LOG(INFO) << "REF origin op is " << origin_pair.first->fullname_with_scope() << ", output index is " + << origin_pair.second << ", cur op is " << kernel_cnode->fullname_with_scope() + << ", out index is " << output_index; if (origin_pair.first->isa()) { auto cnode = origin_pair.first->cast(); - auto ref_ptr = GetKernelInputRef(cnode, origin_pair.second); + auto ref_ptr = GetRef(cnode, origin_pair.second); if (ref_ptr != nullptr) { - kernel_ref->type_ = kRefNodeInput; + ref_ptr->type_ = kRefNodeInput; } } } else { diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc index ffedeaa6a1a..6d11900a46e 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc @@ -96,6 +96,12 @@ uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_me } else { align_size = GetCommonAlignSize(size); } + + auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset(); + MS_LOG(INFO) << "Malloc Memory: Static, total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_ + << "] memory pool[" << device_mem_pool_offset << "])" + << " malloc [" << align_size << "] communication_mem: " << communication_mem; + if (communication_mem) { // create protect area [kMemAlignSize -- data -- kMemAlignSize] uint8_t *alloc_address = reinterpret_cast(AscendMemoryPool::GetInstance().AllocTensorMem(align_size)); @@ -112,12 +118,17 @@ uint8_t *AscendMemoryManager::MallocDynamicMem(size_t size, bool communication_m } else { align_size = GetCommonAlignSize(size); } + + auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset(); + MS_LOG(INFO) << "Malloc Memory: Dynamic, total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_ + << "] memory pool[" << device_mem_pool_offset << "])" + << " malloc [" << align_size << "] communication_mem: " << communication_mem; + if (dynamic_mem_offset_ < align_size) { MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_ << "]) malloc [" << align_size << "] failed!"; } auto new_offset = dynamic_mem_offset_ - align_size; - auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset(); if (new_offset <= device_mem_pool_offset) { MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_ << "] memory pool[" << device_mem_pool_offset << "])"