!44932 fix the address bug of output ref node

Merge pull request !44932 from limingqi107/bug_fix4
This commit is contained in:
i-robot 2022-11-01 12:48:58 +00:00 committed by Gitee
commit b4e6ec78be
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 29 additions and 16 deletions

View File

@ -914,10 +914,6 @@ bool KernelGraph::IsRefOutputMapValue(const AnfWithOutIndex &pair) const {
}
AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const {
if (!IsInRefOutputMap(out_pair)) {
MS_LOG(EXCEPTION) << "Out_pair is not in RefOutputMap, node is " << out_pair.first->DebugString() << ", index is "
<< out_pair.second;
}
return ref_out_in_map_.at(out_pair);
}

View File

@ -374,14 +374,14 @@ inline T *GetDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t inde
}
if (addr_list[index]->addr == nullptr) {
MS_LOG(ERROR) << "The memory of device address is nullptr, address index: " << index
<< ", and the length of 'addr_list' is " << addr_list.size();
MS_LOG(WARNING) << "The memory of device address is nullptr, address index: " << index
<< ", and the length of 'addr_list' is " << addr_list.size();
return nullptr;
}
if (addr_list[index]->size == 0) {
MS_LOG(ERROR) << "The size of device address is zero, address index: " << index
<< ", and the length of 'addr_list' is " << addr_list.size();
MS_LOG(WARNING) << "The size of device address is zero, address index: " << index
<< ", and the length of 'addr_list' is " << addr_list.size();
return nullptr;
}

View File

@ -369,7 +369,13 @@ void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_no
}
auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
if ((tensor_address == nullptr) || (tensor_address == device_address)) {
if (tensor_address == nullptr) {
return;
}
if (tensor_address == device_address) {
tensor_address->SetNodeIndex(input_node, 0);
tensor_address->set_original_ref_count(SIZE_MAX);
tensor_address->ResetRefCount();
return;
}

View File

@ -655,7 +655,8 @@ bool KernelActor::LaunchKernel(OpContext<DeviceTensor> *const) {
MS_EXCEPTION_IF_NULL(launch_info_.inputs_[input_index]);
MS_EXCEPTION_IF_NULL(launch_info_.outputs_[output_index]);
if (launch_info_.inputs_[input_index]->addr != launch_info_.outputs_[output_index]->addr) {
MS_LOG(ERROR) << "Input address and output address are not equal of ref kernel actor: " << GetAID().Name();
MS_LOG(ERROR) << "Input address and output address are not equal of ref kernel actor: " << GetAID().Name()
<< ", input index: " << input_index << ", output index: " << output_index;
return false;
}
}

View File

@ -25,24 +25,34 @@ namespace runtime {
using distributed::collective::CollectiveManager;
using distributed::recovery::RecoveryContext;
bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const AnfNodePtr &output_node) {
MS_EXCEPTION_IF_NULL(output_node);
bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const KernelWithIndex &output_node) {
MS_EXCEPTION_IF_NULL(output_node.first);
MS_EXCEPTION_IF_NULL(output_device_tensor);
// The persisted address can't be replaced.
if (output_device_tensor->is_ptr_persisted()) {
return true;
}
if (output_node->isa<ValueNode>()) {
if (output_node.first->isa<ValueNode>()) {
return true;
}
// The device address of parameter may come from the device address of input tensor.
// In order to avoid mistakenly cleaning up the device data of input tensor, return it as persisted address.
if (output_node->isa<Parameter>()) {
if (output_node.first->isa<Parameter>()) {
return true;
}
// Ref node need check the origin node.
const auto &graph = AnfAlgo::FetchKernelGraph(output_node.first.get());
if ((graph != nullptr) && graph->IsInRefOutputMap(output_node)) {
const auto &origin_node = graph->GetRefCorrespondOutput(output_node).first;
MS_EXCEPTION_IF_NULL(origin_node);
if (origin_node->isa<ValueNode>() || origin_node->isa<Parameter>()) {
return true;
}
}
return false;
}
@ -84,7 +94,7 @@ void OutputActor::FreeOutputNodeMem() {
if ((output_node == nullptr) || (output_device_tensor == nullptr) || (output_device_tensor->GetPtr() == nullptr)) {
return;
}
if (!IsOutputAddressPersisted(output_device_tensor, output_node)) {
if (!IsOutputAddressPersisted(output_device_tensor, output_nodes_[i])) {
FreeMemoryByDeviceContext(output_device_tensor, device_contexts_[i]);
}
}
@ -295,7 +305,7 @@ void OutputActor::UpdateOutputDeviceAddress() {
}
// If the output node whose output address ptr can't be changed, then alloc the new device memory and copy the data:
if (IsOutputAddressPersisted(device_tensor, output_node)) {
if (IsOutputAddressPersisted(device_tensor, output_nodes_[i])) {
auto device_context = device_contexts_[i];
MS_EXCEPTION_IF_NULL(device_context);
device::DynamicMemAllocatorDebugInfo::SetDebugInfo(GetAID().Name(), device::AllocatorType::kOther);