forked from mindspore-Ecosystem/mindspore
!26635 Fix the bug that occurs when graph mode and PyNative mode share a weight
Merge pull request !26635 from limingqi107/new_actor_runtime
This commit is contained in:
commit
cfb21d3f73
|
@ -242,6 +242,7 @@ void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vect
|
||||||
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType()) &&
|
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType()) &&
|
||||||
!device_address->is_ptr_persisted()) {
|
!device_address->is_ptr_persisted()) {
|
||||||
AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
|
AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
|
||||||
|
tensor_address->SetNodeIndex(input_node, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -288,6 +289,7 @@ void DataPrepareActor::PrepareDataForStepMode(const std::vector<std::vector<Tens
|
||||||
auto host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
|
auto host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
|
||||||
if (host_tensor_address != nullptr) {
|
if (host_tensor_address != nullptr) {
|
||||||
AnfAlgo::SetOutputAddr(host_tensor_address, 0, input_node.get());
|
AnfAlgo::SetOutputAddr(host_tensor_address, 0, input_node.get());
|
||||||
|
host_tensor_address->SetNodeIndex(input_node, 0);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -299,7 +301,9 @@ void DataPrepareActor::PrepareDataForStepMode(const std::vector<std::vector<Tens
|
||||||
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, 0);
|
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, 0);
|
||||||
auto device_address = device_context->CreateDeviceAddress(
|
auto device_address = device_context->CreateDeviceAddress(
|
||||||
nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, 0), output_type_id);
|
nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, 0), output_type_id);
|
||||||
|
MS_EXCEPTION_IF_NULL(device_address);
|
||||||
AnfAlgo::SetOutputAddr(device_address, 0, input_node.get());
|
AnfAlgo::SetOutputAddr(device_address, 0, input_node.get());
|
||||||
|
device_address->SetNodeIndex(input_node, 0);
|
||||||
}
|
}
|
||||||
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
|
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
|
||||||
input_tensor->set_device_address(device_tensor);
|
input_tensor->set_device_address(device_tensor);
|
||||||
|
@ -488,6 +492,7 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
|
||||||
MS_EXCEPTION_IF_NULL(host_tensor_address);
|
MS_EXCEPTION_IF_NULL(host_tensor_address);
|
||||||
if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) {
|
if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) {
|
||||||
AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
|
AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
|
||||||
|
host_tensor_address->SetNodeIndex(backend_node, 0);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
|
MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
|
||||||
<< ", device tensor type:" << device_tensor->DeviceType();
|
<< ", device tensor type:" << device_tensor->DeviceType();
|
||||||
|
@ -635,6 +640,7 @@ void DataPrepareActor::PrepareHostTensorQueueForControlNode(const std::vector<Te
|
||||||
MS_EXCEPTION_IF_NULL(device_address);
|
MS_EXCEPTION_IF_NULL(device_address);
|
||||||
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) {
|
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) {
|
||||||
AnfAlgo::SetOutputAddr(tensor_address, 0, backend_node.get());
|
AnfAlgo::SetOutputAddr(tensor_address, 0, backend_node.get());
|
||||||
|
tensor_address->SetNodeIndex(backend_node, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,7 +107,8 @@ TensorPtr OutputActor::CreateOutputTensor(const AnfNodePtr &output_node, size_t
|
||||||
|
|
||||||
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
|
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
|
||||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||||
if (IsPersistentDeviceTensor(output_node)) {
|
// In the input as output scenario, use the device tensor of node.
|
||||||
|
if (output_node->isa<ValueNode>() || output_node->isa<Parameter>()) {
|
||||||
tensor->set_device_address(device_tensor);
|
tensor->set_device_address(device_tensor);
|
||||||
return tensor;
|
return tensor;
|
||||||
}
|
}
|
||||||
|
@ -145,7 +146,8 @@ void OutputActor::UpdateOutputDeviceAddress() {
|
||||||
auto &output_node = output_nodes_[i].first;
|
auto &output_node = output_nodes_[i].first;
|
||||||
auto output_index = output_nodes_[i].second;
|
auto output_index = output_nodes_[i].second;
|
||||||
auto &tensor = outputs_[i];
|
auto &tensor = outputs_[i];
|
||||||
if ((output_node == nullptr) || (IsPersistentDeviceTensor(output_node))) {
|
// In the input as output scenario, the output device tensor may come from the input tensor and can't be replaced.
|
||||||
|
if ((output_node == nullptr) || output_node->isa<ValueNode>() || output_node->isa<Parameter>()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue