!19329 fix bug of actor runtime for host-device multi-graphs

Merge pull request !19329 from limingqi107/bug_fix
i-robot 2021-07-03 14:07:12 +00:00 committed by Gitee
commit 427e1077a6
2 changed files with 50 additions and 31 deletions


@@ -39,12 +39,8 @@ void ComputeThreadNums(size_t *actor_thread_num, size_t *OMP_thread_num) {
*actor_thread_num = *actor_thread_num > kActorThreadMaxNum ? kActorThreadMaxNum : *actor_thread_num;
}
const size_t kOMPThreadNumThreshold = 16;
if (cpu_core_num <= kOMPThreadNumThreshold) {
*OMP_thread_num = cpu_core_num;
} else {
*OMP_thread_num = cpu_core_num / (*actor_thread_num - 1);
}
const size_t kOMPThreadMaxNum = 8;
*OMP_thread_num = cpu_core_num < kOMPThreadMaxNum ? cpu_core_num : kOMPThreadMaxNum;
}
bool IsDeviceQueueDSActor(const AnfNodePtr &node, GraphExecutionStrategy strategy) {
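
Not part of the commit: a minimal, self-contained sketch of the new OMP thread-count rule in this hunk. Only the constant name is taken from the diff; the old division by (*actor_thread_num - 1) is gone, and the count is simply capped at kOMPThreadMaxNum.

#include <cstddef>
#include <iostream>

// Mirrors the replacement logic: min(cpu_core_num, kOMPThreadMaxNum).
size_t OmpThreadNum(size_t cpu_core_num) {
  const size_t kOMPThreadMaxNum = 8;
  return cpu_core_num < kOMPThreadMaxNum ? cpu_core_num : kOMPThreadMaxNum;
}

int main() {
  // A 4-core host keeps 4 OMP threads; a 32-core host is now capped at 8.
  std::cout << OmpThreadNum(4) << " " << OmpThreadNum(32) << std::endl;  // prints: 4 8
  return 0;
}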


@@ -170,7 +170,6 @@ void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &
MS_EXCEPTION_IF_NULL(backend_node);
MS_EXCEPTION_IF_NULL(front_node);
MS_EXCEPTION_IF_NULL(device_context);
if (tensor == nullptr) {
return;
}
@@ -187,26 +186,30 @@ void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &
UpdateRefCount(host_tensor_address.get(), true);
}
MS_EXCEPTION_IF_NULL(host_tensor_address);
AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
DeviceTensorStore::GetInstance().Insert(front_node.get(), host_tensor_address);
if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) {
AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
} else {
MS_LOG(ERROR) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
<< ", device tensor type:" << device_tensor->DeviceType();
}
}
// If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
if (host_tensor_address->GetPtr() != nullptr) {
return;
}
MS_LOG(INFO) << "Prepare device data for weight node: " << backend_node->fullname_with_scope();
if (host_tensor_address->GetPtr() == nullptr) {
MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope()
<< ", device type:" << host_tensor_address->DeviceType();
// Allocate device memory and copy data from host tensor to device.
if (!device_context->AllocateMemory(host_tensor_address.get(), host_tensor_address->GetSize())) {
MS_LOG(EXCEPTION) << "Device(id:" << device_context->device_context_key().device_id_
<< ") memory isn't enough and alloc failed, node name: " << backend_node->fullname_with_scope();
}
if (!host_tensor_address->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c(),
tensor->device_info().host_format_)) {
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c(), tensor->device_info().host_format_)) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << backend_node->fullname_with_scope();
}
}
// Allocate another device memory and copy data from host tensor to another device(if exist).
const auto &device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
@@ -217,14 +220,24 @@ void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &
const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{device::kDeviceTypeToName.at(another_device_type), device_context->device_context_key().device_id_});
MS_EXCEPTION_IF_NULL(another_device_context);
if (another_device_tensor->GetPtr() == nullptr) {
if (!another_device_context->AllocateMemory(another_device_tensor.get(), another_device_tensor->GetSize())) {
MS_LOG(EXCEPTION) << "Device(id:" << another_device_context->device_context_key().device_id_
<< ") memory isn't enough and alloc failed, node name: " << backend_node->fullname_with_scope();
<< ") memory isn't enough and alloc failed, node name: "
<< backend_node->fullname_with_scope();
}
if (!another_device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << backend_node->fullname_with_scope();
}
MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope()
<< ", device type:" << another_device_type;
if (host_tensor_address->DeviceType() == device::DeviceAddressType::kCPU) {
// CPU device tensor copy to other device tensor.
(void)another_device_tensor->SyncHostToDevice(host_tensor_address->GetSize(), host_tensor_address->GetPtr());
} else if (another_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
// Other device tensor copy to CPU device tensor.
(void)host_tensor_address->SyncDeviceToHost(another_device_tensor->GetSize(),
another_device_tensor->GetMutablePtr());
} else {
MS_LOG(EXCEPTION) << "Invalid device type for sync data.";
}
}
}
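
Not part of the commit: a self-contained sketch of the copy-direction rule this hunk adds when a weight has device tensors on two different device types. The enum values and the helper below are illustrative placeholders; in the real code the branches call SyncHostToDevice / SyncDeviceToHost on the DeviceAddress objects, and any pairing without a CPU side raises an exception.

#include <iostream>
#include <stdexcept>

enum class DeviceAddressType { kCPU, kGPU, kAscend };

// Returns which sync the runtime would issue for (host-side type, other-side type),
// mirroring the if/else chain in the diff above.
const char *SyncDirection(DeviceAddressType host_side, DeviceAddressType other_side) {
  if (host_side == DeviceAddressType::kCPU) {
    return "SyncHostToDevice: copy the CPU weight into the other device tensor";
  }
  if (other_side == DeviceAddressType::kCPU) {
    return "SyncDeviceToHost: copy the device weight back into the CPU tensor";
  }
  throw std::runtime_error("Invalid device type for sync data.");
}

int main() {
  std::cout << SyncDirection(DeviceAddressType::kCPU, DeviceAddressType::kGPU) << "\n";
  std::cout << SyncDirection(DeviceAddressType::kAscend, DeviceAddressType::kCPU) << "\n";
  return 0;
}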
@@ -2640,12 +2653,22 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
for (auto &input_node : graph->input_nodes()) {
MS_EXCEPTION_IF_NULL(input_node);
if (!IsPersistentDeviceTensor(input_node)) {
AnfNodePtr front_node = nullptr;
if (IsInternalParameter(input_node, graph)) {
auto front_node_with_index = graph->GetFrontNodeByInternalParameter(input_node);
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
const auto &front_output_with_index =
AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, front_node_with_index.second, false);
front_node = front_output_with_index.first;
} else if (IsPersistentDeviceTensor(input_node)) {
front_node = FetchFrontNodeByBackendNode(input_node, graph);
}
if (front_node == nullptr) {
continue;
}
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
MS_EXCEPTION_IF_NULL(device_tensor);
const auto &front_node = FetchFrontNodeByBackendNode(input_node, graph);
DeviceTensorStore::GetInstance().Insert(front_node.get(), device_tensor);
UpdateRefCount(device_tensor.get(), true);
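
Not part of the commit: a minimal sketch of the front-node resolution this hunk introduces in PersistDeviceTensor. The enum and strings are placeholders; the real code maps an internal parameter through GetFrontNodeByInternalParameter plus VisitKernelWithReturnType, maps a persistent tensor through FetchFrontNodeByBackendNode, and skips everything else (front_node stays null and the loop continues).

#include <iostream>
#include <optional>
#include <string>

enum class InputKind { kInternalParameter, kPersistentTensor, kOther };

// Returns the front node a backend input would be persisted under, or nothing if it is skipped.
std::optional<std::string> FetchFrontNodeName(InputKind kind) {
  switch (kind) {
    case InputKind::kInternalParameter:
      return "front output that feeds this internal parameter";
    case InputKind::kPersistentTensor:
      return "front parameter bound to this weight";
    default:
      return std::nullopt;  // no device tensor is persisted for it
  }
}

int main() {
  for (auto kind : {InputKind::kInternalParameter, InputKind::kPersistentTensor, InputKind::kOther}) {
    auto name = FetchFrontNodeName(kind);
    std::cout << (name ? *name : std::string("skip")) << "\n";
  }
  return 0;
}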