!26635 fix the bug of graph and pynative shared the weight

Merge pull request !26635 from limingqi107/new_actor_runtime
This commit is contained in:
i-robot 2021-11-22 19:02:41 +00:00 committed by Gitee
commit cfb21d3f73
2 changed files with 10 additions and 2 deletions

View File

@ -242,6 +242,7 @@ void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vect
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType()) && if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType()) &&
!device_address->is_ptr_persisted()) { !device_address->is_ptr_persisted()) {
AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get()); AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
tensor_address->SetNodeIndex(input_node, 0);
} }
} }
} }
@ -288,6 +289,7 @@ void DataPrepareActor::PrepareDataForStepMode(const std::vector<std::vector<Tens
auto host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address()); auto host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
if (host_tensor_address != nullptr) { if (host_tensor_address != nullptr) {
AnfAlgo::SetOutputAddr(host_tensor_address, 0, input_node.get()); AnfAlgo::SetOutputAddr(host_tensor_address, 0, input_node.get());
host_tensor_address->SetNodeIndex(input_node, 0);
continue; continue;
} }
@ -299,7 +301,9 @@ void DataPrepareActor::PrepareDataForStepMode(const std::vector<std::vector<Tens
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, 0); size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, 0);
auto device_address = device_context->CreateDeviceAddress( auto device_address = device_context->CreateDeviceAddress(
nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, 0), output_type_id); nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, 0), output_type_id);
MS_EXCEPTION_IF_NULL(device_address);
AnfAlgo::SetOutputAddr(device_address, 0, input_node.get()); AnfAlgo::SetOutputAddr(device_address, 0, input_node.get());
device_address->SetNodeIndex(input_node, 0);
} }
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false); auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
input_tensor->set_device_address(device_tensor); input_tensor->set_device_address(device_tensor);
@ -488,6 +492,7 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
MS_EXCEPTION_IF_NULL(host_tensor_address); MS_EXCEPTION_IF_NULL(host_tensor_address);
if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) { if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) {
AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get()); AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
host_tensor_address->SetNodeIndex(backend_node, 0);
} else { } else {
MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType() MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
<< ", device tensor type:" << device_tensor->DeviceType(); << ", device tensor type:" << device_tensor->DeviceType();
@ -635,6 +640,7 @@ void DataPrepareActor::PrepareHostTensorQueueForControlNode(const std::vector<Te
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) { if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) {
AnfAlgo::SetOutputAddr(tensor_address, 0, backend_node.get()); AnfAlgo::SetOutputAddr(tensor_address, 0, backend_node.get());
tensor_address->SetNodeIndex(backend_node, 0);
} }
} }
} }

View File

@ -107,7 +107,8 @@ TensorPtr OutputActor::CreateOutputTensor(const AnfNodePtr &output_node, size_t
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false); const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
MS_EXCEPTION_IF_NULL(device_tensor); MS_EXCEPTION_IF_NULL(device_tensor);
if (IsPersistentDeviceTensor(output_node)) { // In the input as output scenario, use the device tensor of node.
if (output_node->isa<ValueNode>() || output_node->isa<Parameter>()) {
tensor->set_device_address(device_tensor); tensor->set_device_address(device_tensor);
return tensor; return tensor;
} }
@ -145,7 +146,8 @@ void OutputActor::UpdateOutputDeviceAddress() {
auto &output_node = output_nodes_[i].first; auto &output_node = output_nodes_[i].first;
auto output_index = output_nodes_[i].second; auto output_index = output_nodes_[i].second;
auto &tensor = outputs_[i]; auto &tensor = outputs_[i];
if ((output_node == nullptr) || (IsPersistentDeviceTensor(output_node))) { // In the input as output scenario, the output device tensor may come from the input tensor and can't be replaced.
if ((output_node == nullptr) || output_node->isa<ValueNode>() || output_node->isa<Parameter>()) {
continue; continue;
} }