From e17a533369df0c50594f8fc755e28d295b736bb5 Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Mon, 22 Nov 2021 19:59:10 +0800 Subject: [PATCH] fix the bug of graph and pynative shared the weight --- .../ccsrc/runtime/framework/actor/data_prepare_actor.cc | 6 ++++++ mindspore/ccsrc/runtime/framework/actor/output_actor.cc | 6 ++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc index 8d5af5fa58b..4b2611bac7f 100644 --- a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc @@ -242,6 +242,7 @@ void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vectorDeviceType() == device_address->DeviceType()) && !device_address->is_ptr_persisted()) { AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get()); + tensor_address->SetNodeIndex(input_node, 0); } } } @@ -288,6 +289,7 @@ void DataPrepareActor::PrepareDataForStepMode(const std::vector(input_tensor->device_address()); if (host_tensor_address != nullptr) { AnfAlgo::SetOutputAddr(host_tensor_address, 0, input_node.get()); + host_tensor_address->SetNodeIndex(input_node, 0); continue; } @@ -299,7 +301,9 @@ void DataPrepareActor::PrepareDataForStepMode(const std::vectorCreateDeviceAddress( nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, 0), output_type_id); + MS_EXCEPTION_IF_NULL(device_address); AnfAlgo::SetOutputAddr(device_address, 0, input_node.get()); + device_address->SetNodeIndex(input_node, 0); } auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false); input_tensor->set_device_address(device_tensor); @@ -488,6 +492,7 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node, MS_EXCEPTION_IF_NULL(host_tensor_address); if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) { AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get()); + host_tensor_address->SetNodeIndex(backend_node, 0); } else { MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType() << ", device tensor type:" << device_tensor->DeviceType(); @@ -635,6 +640,7 @@ void DataPrepareActor::PrepareHostTensorQueueForControlNode(const std::vectorDeviceType() == device_address->DeviceType())) { AnfAlgo::SetOutputAddr(tensor_address, 0, backend_node.get()); + tensor_address->SetNodeIndex(backend_node, 0); } } } diff --git a/mindspore/ccsrc/runtime/framework/actor/output_actor.cc b/mindspore/ccsrc/runtime/framework/actor/output_actor.cc index 19497037bbb..1f4ef030278 100644 --- a/mindspore/ccsrc/runtime/framework/actor/output_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/output_actor.cc @@ -107,7 +107,8 @@ TensorPtr OutputActor::CreateOutputTensor(const AnfNodePtr &output_node, size_t const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false); MS_EXCEPTION_IF_NULL(device_tensor); - if (IsPersistentDeviceTensor(output_node)) { + // In the input as output scenario, use the device tensor of node. + if (output_node->isa() || output_node->isa()) { tensor->set_device_address(device_tensor); return tensor; } @@ -145,7 +146,8 @@ void OutputActor::UpdateOutputDeviceAddress() { auto &output_node = output_nodes_[i].first; auto output_index = output_nodes_[i].second; auto &tensor = outputs_[i]; - if ((output_node == nullptr) || (IsPersistentDeviceTensor(output_node))) { + // In the input as output scenario, the output device tensor may come from the input tensor and can't be replaced. + if ((output_node == nullptr) || output_node->isa() || output_node->isa()) { continue; }