!19344 fix bug of host device from different graph
Merge pull request !19344 from limingqi107/bug_fix2
This commit is contained in:
commit
a09f38399f
|
@ -113,5 +113,25 @@ bool IsGatherActor(const AnfNodePtr &front_node,
|
|||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Copy(DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(dst_device_tensor);
|
||||
MS_EXCEPTION_IF_NULL(src_device_tensor);
|
||||
|
||||
// Exist the size alignment in some device, so get the min device size.
|
||||
size_t copy_size = std::min(src_device_tensor->GetSize(), dst_device_tensor->GetSize());
|
||||
|
||||
if (src_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
|
||||
// CPU device tensor copy to other device tensor.
|
||||
return dst_device_tensor->SyncHostToDevice(copy_size, src_device_tensor->GetPtr());
|
||||
} else if (dst_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
|
||||
// Other device tensor copy to CPU device tensor.
|
||||
return src_device_tensor->SyncDeviceToHost(copy_size, dst_device_tensor->GetMutablePtr());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Invalid device type, src device type: " << src_device_tensor->DeviceType()
|
||||
<< ", dst device type: " << dst_device_tensor->DeviceType();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <thread>
|
||||
#include <algorithm>
|
||||
#include "mindrt/include/actor/op_actor.h"
|
||||
#include "runtime/device/device_address.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
|
@ -86,6 +87,9 @@ bool IsPersistentDeviceTensor(const AnfNodePtr &node);
|
|||
// Judge whether the front node is in a gather actor.
|
||||
bool IsGatherActor(const AnfNodePtr &front_node,
|
||||
const std::unordered_map<std::string, OpActor<DeviceTensor> *> &actor_name_to_actor);
|
||||
|
||||
// Copy data from src_device_tensor to dst_device_tensor.
|
||||
bool Copy(DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor);
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -89,22 +89,6 @@ void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
|
|||
SendOutput(context);
|
||||
}
|
||||
|
||||
bool CopyActor::Copy(DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(dst_device_tensor);
|
||||
MS_EXCEPTION_IF_NULL(src_device_tensor);
|
||||
|
||||
if (src_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
|
||||
// CPU device tensor copy to other device tensor.
|
||||
return dst_device_tensor->SyncHostToDevice(src_device_tensor->GetSize(), src_device_tensor->GetPtr());
|
||||
} else if (dst_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
|
||||
// Other device tensor copy to CPU device tensor.
|
||||
return src_device_tensor->SyncDeviceToHost(dst_device_tensor->GetSize(), dst_device_tensor->GetMutablePtr());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Invalid device type for copy actor: " << GetAID().Name();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool CopyActor::CheckCopyCondition(OpContext<DeviceTensor> *context) const {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input_datas_num_ != 0) {
|
||||
|
|
|
@ -65,8 +65,6 @@ class CopyActor : public MemoryAwareActor {
|
|||
// Fetch the device tensor for copy.
|
||||
void FetchDeviceTensor(OpContext<DeviceTensor> *context);
|
||||
|
||||
// Copy data from src_device_tensor to dst_device_tensor.
|
||||
bool Copy(DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor);
|
||||
// Send output data and output controls when finish copy.
|
||||
void SendOutput(OpContext<DeviceTensor> *context) const;
|
||||
// Erase input data and input controls when finish copy.
|
||||
|
|
|
@ -258,13 +258,15 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cont
|
|||
MS_EXCEPTION_IF_NULL(host_tensor);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
auto tensor_device_address = std::dynamic_pointer_cast<DeviceTensor>(host_tensor->device_address());
|
||||
// Sync data from host_tensor_device_address to device_tensor.
|
||||
if (tensor_device_address != nullptr) {
|
||||
if (tensor_device_address.get() != device_tensor) {
|
||||
MS_LOG(EXCEPTION) << "The device tensor of host queue node should be equal to device address of input tensor";
|
||||
if ((tensor_device_address.get() != device_tensor) && (!Copy(device_tensor, tensor_device_address.get()))) {
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Copy data failed.");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Sync data from host_tensor to device_tensor.
|
||||
if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(data_nodes_[i], 0),
|
||||
LongToSize(host_tensor->data().nbytes()), host_tensor->data_type(),
|
||||
host_tensor->data_c(), host_tensor->device_info().host_format_)) {
|
||||
|
|
|
@ -231,15 +231,8 @@ void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &
|
|||
}
|
||||
MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope()
|
||||
<< ", device type:" << another_device_type;
|
||||
if (host_tensor_address->DeviceType() == device::DeviceAddressType::kCPU) {
|
||||
// CPU device tensor copy to other device tensor.
|
||||
(void)another_device_tensor->SyncHostToDevice(host_tensor_address->GetSize(), host_tensor_address->GetPtr());
|
||||
} else if (another_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
|
||||
// Other device tensor copy to CPU device tensor.
|
||||
(void)host_tensor_address->SyncDeviceToHost(another_device_tensor->GetSize(),
|
||||
another_device_tensor->GetMutablePtr());
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid device type for sync data.";
|
||||
if (!Copy(another_device_tensor.get(), host_tensor_address.get())) {
|
||||
MS_LOG(EXCEPTION) << "Sync data error.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -312,10 +305,11 @@ void PrepareDataForHostDataSourceActor(const std::unordered_map<AnfNodePtr, size
|
|||
}
|
||||
|
||||
(*host_tensors)[iter->second] = tensor;
|
||||
auto device_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
|
||||
if (device_address != nullptr) {
|
||||
AnfAlgo::SetOutputAddr(device_address, 0, node.get());
|
||||
return;
|
||||
auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(node, 0, false);
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) {
|
||||
AnfAlgo::SetOutputAddr(tensor_address, 0, node.get());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1565,8 +1559,14 @@ void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor
|
|||
|
||||
// Set the member of the copy actor.
|
||||
MS_EXCEPTION_IF_NULL(from_device_tensor);
|
||||
auto to_kernel_mod = AnfAlgo::GetKernelMod(to_kernel_with_input_idx.first);
|
||||
MS_EXCEPTION_IF_NULL(to_kernel_mod);
|
||||
auto input_sizes = to_kernel_mod->GetInputSizeList();
|
||||
if (to_input_index >= input_sizes.size()) {
|
||||
MS_LOG(EXCEPTION) << "To input index(" << to_input_index << ") is out of size: " << input_sizes.size();
|
||||
}
|
||||
copy_actor->output_ = to_devcie_context->CreateDeviceAddress(
|
||||
nullptr, from_device_tensor->GetSize(), from_device_tensor->format(), from_device_tensor->type_id());
|
||||
nullptr, input_sizes[to_input_index], from_device_tensor->format(), from_device_tensor->type_id());
|
||||
MS_EXCEPTION_IF_NULL(from_devcie_context);
|
||||
copy_actor->input_device_context_ = from_devcie_context;
|
||||
copy_actor->output_device_context_ = to_devcie_context;
|
||||
|
|
Loading…
Reference in New Issue