!26635 fix the bug of graph and pynative shared the weight

Merge pull request !26635 from limingqi107/new_actor_runtime
This commit is contained in:
i-robot 2021-11-22 19:02:41 +00:00 committed by Gitee
commit cfb21d3f73
2 changed files with 10 additions and 2 deletions

View File

@ -242,6 +242,7 @@ void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vect
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType()) &&
!device_address->is_ptr_persisted()) {
AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
tensor_address->SetNodeIndex(input_node, 0);
}
}
}
@ -288,6 +289,7 @@ void DataPrepareActor::PrepareDataForStepMode(const std::vector<std::vector<Tens
auto host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
if (host_tensor_address != nullptr) {
AnfAlgo::SetOutputAddr(host_tensor_address, 0, input_node.get());
host_tensor_address->SetNodeIndex(input_node, 0);
continue;
}
@ -299,7 +301,9 @@ void DataPrepareActor::PrepareDataForStepMode(const std::vector<std::vector<Tens
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, 0);
auto device_address = device_context->CreateDeviceAddress(
nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, 0), output_type_id);
MS_EXCEPTION_IF_NULL(device_address);
AnfAlgo::SetOutputAddr(device_address, 0, input_node.get());
device_address->SetNodeIndex(input_node, 0);
}
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
input_tensor->set_device_address(device_tensor);
@ -488,6 +492,7 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
MS_EXCEPTION_IF_NULL(host_tensor_address);
if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) {
AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
host_tensor_address->SetNodeIndex(backend_node, 0);
} else {
MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
<< ", device tensor type:" << device_tensor->DeviceType();
@ -635,6 +640,7 @@ void DataPrepareActor::PrepareHostTensorQueueForControlNode(const std::vector<Te
MS_EXCEPTION_IF_NULL(device_address);
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) {
AnfAlgo::SetOutputAddr(tensor_address, 0, backend_node.get());
tensor_address->SetNodeIndex(backend_node, 0);
}
}
}

View File

@ -107,7 +107,8 @@ TensorPtr OutputActor::CreateOutputTensor(const AnfNodePtr &output_node, size_t
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
MS_EXCEPTION_IF_NULL(device_tensor);
if (IsPersistentDeviceTensor(output_node)) {
// In the input as output scenario, use the device tensor of node.
if (output_node->isa<ValueNode>() || output_node->isa<Parameter>()) {
tensor->set_device_address(device_tensor);
return tensor;
}
@ -145,7 +146,8 @@ void OutputActor::UpdateOutputDeviceAddress() {
auto &output_node = output_nodes_[i].first;
auto output_index = output_nodes_[i].second;
auto &tensor = outputs_[i];
if ((output_node == nullptr) || (IsPersistentDeviceTensor(output_node))) {
// In the input as output scenario, the output device tensor may come from the input tensor and can't be replaced.
if ((output_node == nullptr) || output_node->isa<ValueNode>() || output_node->isa<Parameter>()) {
continue;
}