!27663 Check Format when Sync DeviceToDevice

Merge pull request !27663 from hwjiaorui/fix-bug-nullptr
This commit is contained in:
i-robot 2021-12-20 07:24:09 +00:00 committed by Gitee
commit 2d7e4adf44
3 changed files with 14 additions and 10 deletions

View File

@ -293,14 +293,14 @@ std::string FetchActorName(KernelTransformType kernel_type, const std::string &a
return actor_name;
}
bool CheckMemcpyInDevice(const DeviceTensor *dst_device_addr, const DeviceTensor *src_device_addr) {
bool NeedSyncByTensor(const DeviceTensor *dst_device_addr, const DeviceTensor *src_device_addr) {
MS_EXCEPTION_IF_NULL(dst_device_addr);
if (src_device_addr == nullptr) {
MS_EXCEPTION_IF_NULL(src_device_addr);
if (src_device_addr->DeviceType() != dst_device_addr->DeviceType()) {
return false;
}
return (src_device_addr->DeviceType() == dst_device_addr->DeviceType() &&
src_device_addr->format() == dst_device_addr->format() &&
src_device_addr->type_id() == dst_device_addr->type_id());
return (src_device_addr->format() != dst_device_addr->format() ||
src_device_addr->type_id() != dst_device_addr->type_id());
}
} // namespace runtime
} // namespace mindspore

View File

@ -211,7 +211,7 @@ KernelTransformType FetchKernelTransformType(const AnfNodePtr &node, const Kerne
std::string FetchActorName(KernelTransformType kernel_type, const std::string &actor_set_name,
const AnfNodePtr &node = nullptr, const KernelGraphPtr &graph = nullptr);
bool CheckMemcpyInDevice(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor);
bool NeedSyncByTensor(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor);
} // namespace runtime
} // namespace mindspore

View File

@ -239,13 +239,17 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cons
auto tensor_device_address = std::dynamic_pointer_cast<DeviceTensor>(host_tensor->device_address());
// Sync data from host_tensor_device_address to device_tensor.
if (tensor_device_address != nullptr) {
if (CheckMemcpyInDevice(device_tensor, tensor_device_address.get())) {
if ((tensor_device_address.get() != device_tensor) && (!Copy(device_tensor, tensor_device_address.get()))) {
if (tensor_device_address.get() == device_tensor) {
continue;
}
if (NeedSyncByTensor(device_tensor, tensor_device_address.get())) {
host_tensor->data_sync(false);
} else {
if ((!Copy(device_tensor, tensor_device_address.get()))) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Copy data failed.");
}
continue;
} else {
host_tensor->data_sync(false);
}
}