forked from mindspore-Ecosystem/mindspore
fix refnode input type assign
This commit is contained in:
parent
7cb567ebbe
commit
d99786e938
|
@ -64,11 +64,14 @@ bool MemReuseUtil::InitDynamicOutputKernelRef() {
|
|||
kernel_ref->type_ = kRefNodeOutput;
|
||||
auto origin_pair = graph_->GetRefCorrespondOutput(out_pair);
|
||||
MS_EXCEPTION_IF_NULL(origin_pair.first);
|
||||
MS_LOG(INFO) << "REF origin op is " << origin_pair.first->fullname_with_scope() << ", output index is "
|
||||
<< origin_pair.second << ", cur op is " << kernel_cnode->fullname_with_scope()
|
||||
<< ", out index is " << output_index;
|
||||
if (origin_pair.first->isa<CNode>()) {
|
||||
auto cnode = origin_pair.first->cast<CNodePtr>();
|
||||
auto ref_ptr = GetKernelInputRef(cnode, origin_pair.second);
|
||||
auto ref_ptr = GetRef(cnode, origin_pair.second);
|
||||
if (ref_ptr != nullptr) {
|
||||
kernel_ref->type_ = kRefNodeInput;
|
||||
ref_ptr->type_ = kRefNodeInput;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -96,6 +96,12 @@ uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_me
|
|||
} else {
|
||||
align_size = GetCommonAlignSize(size);
|
||||
}
|
||||
|
||||
auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
|
||||
MS_LOG(INFO) << "Malloc Memory: Static, total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
|
||||
<< "] memory pool[" << device_mem_pool_offset << "])"
|
||||
<< " malloc [" << align_size << "] communication_mem: " << communication_mem;
|
||||
|
||||
if (communication_mem) {
|
||||
// create protect area [kMemAlignSize -- data -- kMemAlignSize]
|
||||
uint8_t *alloc_address = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
|
||||
|
@ -112,12 +118,17 @@ uint8_t *AscendMemoryManager::MallocDynamicMem(size_t size, bool communication_m
|
|||
} else {
|
||||
align_size = GetCommonAlignSize(size);
|
||||
}
|
||||
|
||||
auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
|
||||
MS_LOG(INFO) << "Malloc Memory: Dynamic, total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
|
||||
<< "] memory pool[" << device_mem_pool_offset << "])"
|
||||
<< " malloc [" << align_size << "] communication_mem: " << communication_mem;
|
||||
|
||||
if (dynamic_mem_offset_ < align_size) {
|
||||
MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
|
||||
<< "]) malloc [" << align_size << "] failed!";
|
||||
}
|
||||
auto new_offset = dynamic_mem_offset_ - align_size;
|
||||
auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
|
||||
if (new_offset <= device_mem_pool_offset) {
|
||||
MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
|
||||
<< "] memory pool[" << device_mem_pool_offset << "])"
|
||||
|
|
Loading…
Reference in New Issue