forked from mindspore-Ecosystem/mindspore
!44932 fix the address bug of output ref node
Merge pull request !44932 from limingqi107/bug_fix4
This commit is contained in:
commit
b4e6ec78be
|
@ -914,10 +914,6 @@ bool KernelGraph::IsRefOutputMapValue(const AnfWithOutIndex &pair) const {
|
|||
}
|
||||
|
||||
// Returns the pair stored in ref_out_in_map_ for the given ref output pair.
//
// out_pair: (node, output index) pair to look up.
// Raises through MS_LOG(EXCEPTION) when IsInRefOutputMap(out_pair) reports that the pair is not registered.
AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const {
  if (!IsInRefOutputMap(out_pair)) {
    // Guard the node pointer: building the diagnostic must not itself dereference a null node, otherwise the
    // caller gets a crash instead of the intended exception.
    const auto node_info = (out_pair.first != nullptr) ? out_pair.first->DebugString() : std::string("nullptr");
    MS_LOG(EXCEPTION) << "Out_pair is not in RefOutputMap, node is " << node_info << ", index is "
                      << out_pair.second;
  }
  return ref_out_in_map_.at(out_pair);
}
|
||||
|
||||
|
|
|
@ -374,14 +374,14 @@ inline T *GetDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t inde
|
|||
}
|
||||
|
||||
if (addr_list[index]->addr == nullptr) {
|
||||
MS_LOG(ERROR) << "The memory of device address is nullptr, address index: " << index
|
||||
<< ", and the length of 'addr_list' is " << addr_list.size();
|
||||
MS_LOG(WARNING) << "The memory of device address is nullptr, address index: " << index
|
||||
<< ", and the length of 'addr_list' is " << addr_list.size();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (addr_list[index]->size == 0) {
|
||||
MS_LOG(ERROR) << "The size of device address is zero, address index: " << index
|
||||
<< ", and the length of 'addr_list' is " << addr_list.size();
|
||||
MS_LOG(WARNING) << "The size of device address is zero, address index: " << index
|
||||
<< ", and the length of 'addr_list' is " << addr_list.size();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -369,7 +369,13 @@ void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_no
|
|||
}
|
||||
|
||||
auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
|
||||
if ((tensor_address == nullptr) || (tensor_address == device_address)) {
|
||||
if (tensor_address == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (tensor_address == device_address) {
|
||||
tensor_address->SetNodeIndex(input_node, 0);
|
||||
tensor_address->set_original_ref_count(SIZE_MAX);
|
||||
tensor_address->ResetRefCount();
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -655,7 +655,8 @@ bool KernelActor::LaunchKernel(OpContext<DeviceTensor> *const) {
|
|||
MS_EXCEPTION_IF_NULL(launch_info_.inputs_[input_index]);
|
||||
MS_EXCEPTION_IF_NULL(launch_info_.outputs_[output_index]);
|
||||
if (launch_info_.inputs_[input_index]->addr != launch_info_.outputs_[output_index]->addr) {
|
||||
MS_LOG(ERROR) << "Input address and output address are not equal of ref kernel actor: " << GetAID().Name();
|
||||
MS_LOG(ERROR) << "Input address and output address are not equal of ref kernel actor: " << GetAID().Name()
|
||||
<< ", input index: " << input_index << ", output index: " << output_index;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,24 +25,34 @@ namespace runtime {
|
|||
using distributed::collective::CollectiveManager;
|
||||
using distributed::recovery::RecoveryContext;
|
||||
|
||||
bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const AnfNodePtr &output_node) {
|
||||
MS_EXCEPTION_IF_NULL(output_node);
|
||||
bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const KernelWithIndex &output_node) {
|
||||
MS_EXCEPTION_IF_NULL(output_node.first);
|
||||
MS_EXCEPTION_IF_NULL(output_device_tensor);
|
||||
// The persisted address can't be replaced.
|
||||
if (output_device_tensor->is_ptr_persisted()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (output_node->isa<ValueNode>()) {
|
||||
if (output_node.first->isa<ValueNode>()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// The device address of parameter may come from the device address of input tensor.
|
||||
// In order to avoid mistakenly cleaning up the device data of input tensor, return it as persisted address.
|
||||
if (output_node->isa<Parameter>()) {
|
||||
if (output_node.first->isa<Parameter>()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Ref node need check the origin node.
|
||||
const auto &graph = AnfAlgo::FetchKernelGraph(output_node.first.get());
|
||||
if ((graph != nullptr) && graph->IsInRefOutputMap(output_node)) {
|
||||
const auto &origin_node = graph->GetRefCorrespondOutput(output_node).first;
|
||||
MS_EXCEPTION_IF_NULL(origin_node);
|
||||
if (origin_node->isa<ValueNode>() || origin_node->isa<Parameter>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -84,7 +94,7 @@ void OutputActor::FreeOutputNodeMem() {
|
|||
if ((output_node == nullptr) || (output_device_tensor == nullptr) || (output_device_tensor->GetPtr() == nullptr)) {
|
||||
return;
|
||||
}
|
||||
if (!IsOutputAddressPersisted(output_device_tensor, output_node)) {
|
||||
if (!IsOutputAddressPersisted(output_device_tensor, output_nodes_[i])) {
|
||||
FreeMemoryByDeviceContext(output_device_tensor, device_contexts_[i]);
|
||||
}
|
||||
}
|
||||
|
@ -295,7 +305,7 @@ void OutputActor::UpdateOutputDeviceAddress() {
|
|||
}
|
||||
|
||||
// If the output node whose output address ptr can't be changed, then alloc the new device memory and copy the data:
|
||||
if (IsOutputAddressPersisted(device_tensor, output_node)) {
|
||||
if (IsOutputAddressPersisted(device_tensor, output_nodes_[i])) {
|
||||
auto device_context = device_contexts_[i];
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
device::DynamicMemAllocatorDebugInfo::SetDebugInfo(GetAID().Name(), device::AllocatorType::kOther);
|
||||
|
|
Loading…
Reference in New Issue