!19329 fix bug of actor runtime host device multi-graphs
Merge pull request !19329 from limingqi107/bug_fix
This commit is contained in:
commit
427e1077a6
|
@ -39,12 +39,8 @@ void ComputeThreadNums(size_t *actor_thread_num, size_t *OMP_thread_num) {
|
|||
*actor_thread_num = *actor_thread_num > kActorThreadMaxNum ? kActorThreadMaxNum : *actor_thread_num;
|
||||
}
|
||||
|
||||
const size_t kOMPThreadNumThreshold = 16;
|
||||
if (cpu_core_num <= kOMPThreadNumThreshold) {
|
||||
*OMP_thread_num = cpu_core_num;
|
||||
} else {
|
||||
*OMP_thread_num = cpu_core_num / (*actor_thread_num - 1);
|
||||
}
|
||||
const size_t kOMPThreadMaxNum = 8;
|
||||
*OMP_thread_num = cpu_core_num < kOMPThreadMaxNum ? cpu_core_num : kOMPThreadMaxNum;
|
||||
}
|
||||
|
||||
bool IsDeviceQueueDSActor(const AnfNodePtr &node, GraphExecutionStrategy strategy) {
|
||||
|
|
|
@ -170,7 +170,6 @@ void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &
|
|||
MS_EXCEPTION_IF_NULL(backend_node);
|
||||
MS_EXCEPTION_IF_NULL(front_node);
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
|
||||
if (tensor == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
@ -187,26 +186,30 @@ void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &
|
|||
UpdateRefCount(host_tensor_address.get(), true);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(host_tensor_address);
|
||||
AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
|
||||
DeviceTensorStore::GetInstance().Insert(front_node.get(), host_tensor_address);
|
||||
if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) {
|
||||
AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
|
||||
<< ", device tensor type:" << device_tensor->DeviceType();
|
||||
}
|
||||
}
|
||||
|
||||
// If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
|
||||
if (host_tensor_address->GetPtr() != nullptr) {
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Prepare device data for weight node: " << backend_node->fullname_with_scope();
|
||||
|
||||
if (host_tensor_address->GetPtr() == nullptr) {
|
||||
MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope()
|
||||
<< ", device type:" << host_tensor_address->DeviceType();
|
||||
// Allocate device memory and copy data from host tensor to device.
|
||||
if (!device_context->AllocateMemory(host_tensor_address.get(), host_tensor_address->GetSize())) {
|
||||
MS_LOG(EXCEPTION) << "Device(id:" << device_context->device_context_key().device_id_
|
||||
<< ") memory isn't enough and alloc failed, node name: " << backend_node->fullname_with_scope();
|
||||
}
|
||||
if (!host_tensor_address->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_node, 0),
|
||||
LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c(),
|
||||
tensor->device_info().host_format_)) {
|
||||
LongToSize(tensor->data().nbytes()), tensor->data_type(),
|
||||
tensor->data_c(), tensor->device_info().host_format_)) {
|
||||
MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << backend_node->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate another device memory and copy data from host tensor to another device(if exist).
|
||||
const auto &device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
|
||||
|
@ -217,14 +220,24 @@ void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &
|
|||
const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||
{device::kDeviceTypeToName.at(another_device_type), device_context->device_context_key().device_id_});
|
||||
MS_EXCEPTION_IF_NULL(another_device_context);
|
||||
if (another_device_tensor->GetPtr() == nullptr) {
|
||||
if (!another_device_context->AllocateMemory(another_device_tensor.get(), another_device_tensor->GetSize())) {
|
||||
MS_LOG(EXCEPTION) << "Device(id:" << another_device_context->device_context_key().device_id_
|
||||
<< ") memory isn't enough and alloc failed, node name: " << backend_node->fullname_with_scope();
|
||||
<< ") memory isn't enough and alloc failed, node name: "
|
||||
<< backend_node->fullname_with_scope();
|
||||
}
|
||||
if (!another_device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_node, 0),
|
||||
LongToSize(tensor->data().nbytes()), tensor->data_type(),
|
||||
tensor->data_c())) {
|
||||
MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << backend_node->fullname_with_scope();
|
||||
}
|
||||
MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope()
|
||||
<< ", device type:" << another_device_type;
|
||||
if (host_tensor_address->DeviceType() == device::DeviceAddressType::kCPU) {
|
||||
// CPU device tensor copy to other device tensor.
|
||||
(void)another_device_tensor->SyncHostToDevice(host_tensor_address->GetSize(), host_tensor_address->GetPtr());
|
||||
} else if (another_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
|
||||
// Other device tensor copy to CPU device tensor.
|
||||
(void)host_tensor_address->SyncDeviceToHost(another_device_tensor->GetSize(),
|
||||
another_device_tensor->GetMutablePtr());
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid device type for sync data.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2640,12 +2653,22 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
|
|||
|
||||
for (auto &input_node : graph->input_nodes()) {
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (!IsPersistentDeviceTensor(input_node)) {
|
||||
AnfNodePtr front_node = nullptr;
|
||||
if (IsInternalParameter(input_node, graph)) {
|
||||
auto front_node_with_index = graph->GetFrontNodeByInternalParameter(input_node);
|
||||
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
|
||||
const auto &front_output_with_index =
|
||||
AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, front_node_with_index.second, false);
|
||||
front_node = front_output_with_index.first;
|
||||
} else if (IsPersistentDeviceTensor(input_node)) {
|
||||
front_node = FetchFrontNodeByBackendNode(input_node, graph);
|
||||
}
|
||||
if (front_node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
const auto &front_node = FetchFrontNodeByBackendNode(input_node, graph);
|
||||
DeviceTensorStore::GetInstance().Insert(front_node.get(), device_tensor);
|
||||
UpdateRefCount(device_tensor.get(), true);
|
||||
|
||||
|
|
Loading…
Reference in New Issue