!19344 fix bug of host device from different graph

Merge pull request !19344 from limingqi107/bug_fix2
This commit is contained in:
i-robot 2021-07-04 16:58:51 +00:00 committed by Gitee
commit a09f38399f
6 changed files with 42 additions and 34 deletions

View File

@ -113,5 +113,25 @@ bool IsGatherActor(const AnfNodePtr &front_node,
} }
return false; return false;
} }
bool Copy(DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor) {
MS_EXCEPTION_IF_NULL(dst_device_tensor);
MS_EXCEPTION_IF_NULL(src_device_tensor);
// Exist the size alignment in some device, so get the min device size.
size_t copy_size = std::min(src_device_tensor->GetSize(), dst_device_tensor->GetSize());
if (src_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
// CPU device tensor copy to other device tensor.
return dst_device_tensor->SyncHostToDevice(copy_size, src_device_tensor->GetPtr());
} else if (dst_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
// Other device tensor copy to CPU device tensor.
return src_device_tensor->SyncDeviceToHost(copy_size, dst_device_tensor->GetMutablePtr());
} else {
MS_LOG(ERROR) << "Invalid device type, src device type: " << src_device_tensor->DeviceType()
<< ", dst device type: " << dst_device_tensor->DeviceType();
return false;
}
}
} // namespace runtime } // namespace runtime
} // namespace mindspore } // namespace mindspore

View File

@ -22,6 +22,7 @@
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <thread> #include <thread>
#include <algorithm>
#include "mindrt/include/actor/op_actor.h" #include "mindrt/include/actor/op_actor.h"
#include "runtime/device/device_address.h" #include "runtime/device/device_address.h"
#include "backend/session/kernel_graph.h" #include "backend/session/kernel_graph.h"
@ -86,6 +87,9 @@ bool IsPersistentDeviceTensor(const AnfNodePtr &node);
// Judge whether the front node is in a gather actor. // Judge whether the front node is in a gather actor.
bool IsGatherActor(const AnfNodePtr &front_node, bool IsGatherActor(const AnfNodePtr &front_node,
const std::unordered_map<std::string, OpActor<DeviceTensor> *> &actor_name_to_actor); const std::unordered_map<std::string, OpActor<DeviceTensor> *> &actor_name_to_actor);
// Copy data from src_device_tensor to dst_device_tensor.
bool Copy(DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor);
} // namespace runtime } // namespace runtime
} // namespace mindspore } // namespace mindspore

View File

@ -89,22 +89,6 @@ void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
SendOutput(context); SendOutput(context);
} }
bool CopyActor::Copy(DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor) {
MS_EXCEPTION_IF_NULL(dst_device_tensor);
MS_EXCEPTION_IF_NULL(src_device_tensor);
if (src_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
// CPU device tensor copy to other device tensor.
return dst_device_tensor->SyncHostToDevice(src_device_tensor->GetSize(), src_device_tensor->GetPtr());
} else if (dst_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
// Other device tensor copy to CPU device tensor.
return src_device_tensor->SyncDeviceToHost(dst_device_tensor->GetSize(), dst_device_tensor->GetMutablePtr());
} else {
MS_LOG(ERROR) << "Invalid device type for copy actor: " << GetAID().Name();
return false;
}
}
bool CopyActor::CheckCopyCondition(OpContext<DeviceTensor> *context) const { bool CopyActor::CheckCopyCondition(OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) { if (input_datas_num_ != 0) {

View File

@ -65,8 +65,6 @@ class CopyActor : public MemoryAwareActor {
// Fetch the device tensor for copy. // Fetch the device tensor for copy.
void FetchDeviceTensor(OpContext<DeviceTensor> *context); void FetchDeviceTensor(OpContext<DeviceTensor> *context);
// Copy data from src_device_tensor to dst_device_tensor.
bool Copy(DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor);
// Send output data and output controls when finish copy. // Send output data and output controls when finish copy.
void SendOutput(OpContext<DeviceTensor> *context) const; void SendOutput(OpContext<DeviceTensor> *context) const;
// Erase input data and input controls when finish copy. // Erase input data and input controls when finish copy.

View File

@ -258,13 +258,15 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cont
MS_EXCEPTION_IF_NULL(host_tensor); MS_EXCEPTION_IF_NULL(host_tensor);
MS_EXCEPTION_IF_NULL(device_tensor); MS_EXCEPTION_IF_NULL(device_tensor);
auto tensor_device_address = std::dynamic_pointer_cast<DeviceTensor>(host_tensor->device_address()); auto tensor_device_address = std::dynamic_pointer_cast<DeviceTensor>(host_tensor->device_address());
// Sync data from host_tensor_device_address to device_tensor.
if (tensor_device_address != nullptr) { if (tensor_device_address != nullptr) {
if (tensor_device_address.get() != device_tensor) { if ((tensor_device_address.get() != device_tensor) && (!Copy(device_tensor, tensor_device_address.get()))) {
MS_LOG(EXCEPTION) << "The device tensor of host queue node should be equal to device address of input tensor"; SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Copy data failed.");
} }
continue; continue;
} }
// Sync data from host_tensor to device_tensor.
if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(data_nodes_[i], 0), if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(data_nodes_[i], 0),
LongToSize(host_tensor->data().nbytes()), host_tensor->data_type(), LongToSize(host_tensor->data().nbytes()), host_tensor->data_type(),
host_tensor->data_c(), host_tensor->device_info().host_format_)) { host_tensor->data_c(), host_tensor->device_info().host_format_)) {

View File

@ -231,15 +231,8 @@ void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &
} }
MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope() MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope()
<< ", device type:" << another_device_type; << ", device type:" << another_device_type;
if (host_tensor_address->DeviceType() == device::DeviceAddressType::kCPU) { if (!Copy(another_device_tensor.get(), host_tensor_address.get())) {
// CPU device tensor copy to other device tensor. MS_LOG(EXCEPTION) << "Sync data error.";
(void)another_device_tensor->SyncHostToDevice(host_tensor_address->GetSize(), host_tensor_address->GetPtr());
} else if (another_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
// Other device tensor copy to CPU device tensor.
(void)host_tensor_address->SyncDeviceToHost(another_device_tensor->GetSize(),
another_device_tensor->GetMutablePtr());
} else {
MS_LOG(EXCEPTION) << "Invalid device type for sync data.";
} }
} }
} }
@ -312,10 +305,11 @@ void PrepareDataForHostDataSourceActor(const std::unordered_map<AnfNodePtr, size
} }
(*host_tensors)[iter->second] = tensor; (*host_tensors)[iter->second] = tensor;
auto device_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address()); auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
if (device_address != nullptr) { auto device_address = AnfAlgo::GetMutableOutputAddr(node, 0, false);
AnfAlgo::SetOutputAddr(device_address, 0, node.get()); MS_EXCEPTION_IF_NULL(device_address);
return; if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) {
AnfAlgo::SetOutputAddr(tensor_address, 0, node.get());
} }
} }
@ -1565,8 +1559,14 @@ void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor
// Set the member of the copy actor. // Set the member of the copy actor.
MS_EXCEPTION_IF_NULL(from_device_tensor); MS_EXCEPTION_IF_NULL(from_device_tensor);
auto to_kernel_mod = AnfAlgo::GetKernelMod(to_kernel_with_input_idx.first);
MS_EXCEPTION_IF_NULL(to_kernel_mod);
auto input_sizes = to_kernel_mod->GetInputSizeList();
if (to_input_index >= input_sizes.size()) {
MS_LOG(EXCEPTION) << "To input index(" << to_input_index << ") is out of size: " << input_sizes.size();
}
copy_actor->output_ = to_devcie_context->CreateDeviceAddress( copy_actor->output_ = to_devcie_context->CreateDeviceAddress(
nullptr, from_device_tensor->GetSize(), from_device_tensor->format(), from_device_tensor->type_id()); nullptr, input_sizes[to_input_index], from_device_tensor->format(), from_device_tensor->type_id());
MS_EXCEPTION_IF_NULL(from_devcie_context); MS_EXCEPTION_IF_NULL(from_devcie_context);
copy_actor->input_device_context_ = from_devcie_context; copy_actor->input_device_context_ = from_devcie_context;
copy_actor->output_device_context_ = to_devcie_context; copy_actor->output_device_context_ = to_devcie_context;