diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc index a4eb11adba1..5ec96a7467c 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc @@ -415,9 +415,9 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const MS_LOG(ERROR) << "Illegal dtype."; } auto shape_size = trans::ShapeSize(host_shape); - auto size_tmp = device_dtype_size * shape_size; - size = GetCommonAlignSize(size_tmp); + size = device_dtype_size * shape_size; } + size = GetCommonAlignSize(size); void *output_address_ptr = nullptr; auto ret_malloc = rtMalloc(&output_address_ptr, size, RT_MEMORY_HBM); if (ret_malloc != RT_ERROR_NONE) { @@ -427,7 +427,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const // launch LaunchTransData(kernel_mod_ptr, output_address_ptr, size, workspace_size_list); if (type_id_ == type) { - SyncMemory(host_ptr, output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST); + SyncMemory(host_ptr, output_address_ptr, host_size, RT_MEMCPY_DEVICE_TO_HOST); } else { auto host = std::vector(size); SyncMemory(host.data(), output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST);