forked from mindspore-Ecosystem/mindspore
!27663 Check Format when Sync DeviceToDevice
Merge pull request !27663 from hwjiaorui/fix-bug-nullptr
This commit is contained in:
commit
2d7e4adf44
|
@ -293,14 +293,14 @@ std::string FetchActorName(KernelTransformType kernel_type, const std::string &a
|
|||
return actor_name;
|
||||
}
|
||||
|
||||
bool CheckMemcpyInDevice(const DeviceTensor *dst_device_addr, const DeviceTensor *src_device_addr) {
|
||||
bool NeedSyncByTensor(const DeviceTensor *dst_device_addr, const DeviceTensor *src_device_addr) {
|
||||
MS_EXCEPTION_IF_NULL(dst_device_addr);
|
||||
if (src_device_addr == nullptr) {
|
||||
MS_EXCEPTION_IF_NULL(src_device_addr);
|
||||
if (src_device_addr->DeviceType() != dst_device_addr->DeviceType()) {
|
||||
return false;
|
||||
}
|
||||
return (src_device_addr->DeviceType() == dst_device_addr->DeviceType() &&
|
||||
src_device_addr->format() == dst_device_addr->format() &&
|
||||
src_device_addr->type_id() == dst_device_addr->type_id());
|
||||
return (src_device_addr->format() != dst_device_addr->format() ||
|
||||
src_device_addr->type_id() != dst_device_addr->type_id());
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -211,7 +211,7 @@ KernelTransformType FetchKernelTransformType(const AnfNodePtr &node, const Kerne
|
|||
std::string FetchActorName(KernelTransformType kernel_type, const std::string &actor_set_name,
|
||||
const AnfNodePtr &node = nullptr, const KernelGraphPtr &graph = nullptr);
|
||||
|
||||
bool CheckMemcpyInDevice(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor);
|
||||
bool NeedSyncByTensor(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor);
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -239,13 +239,17 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cons
|
|||
auto tensor_device_address = std::dynamic_pointer_cast<DeviceTensor>(host_tensor->device_address());
|
||||
// Sync data from host_tensor_device_address to device_tensor.
|
||||
if (tensor_device_address != nullptr) {
|
||||
if (CheckMemcpyInDevice(device_tensor, tensor_device_address.get())) {
|
||||
if ((tensor_device_address.get() != device_tensor) && (!Copy(device_tensor, tensor_device_address.get()))) {
|
||||
if (tensor_device_address.get() == device_tensor) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (NeedSyncByTensor(device_tensor, tensor_device_address.get())) {
|
||||
host_tensor->data_sync(false);
|
||||
} else {
|
||||
if ((!Copy(device_tensor, tensor_device_address.get()))) {
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Copy data failed.");
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
host_tensor->data_sync(false);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue