forked from mindspore-Ecosystem/mindspore
!3286 fix conv2d precision problem and conv2d runtime error
Merge pull request !3286 from lvchangquan/transdata
This commit is contained in:
commit
4cdc404805
|
@ -415,9 +415,9 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
|
|||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
}
|
||||
auto shape_size = trans::ShapeSize(host_shape);
|
||||
auto size_tmp = device_dtype_size * shape_size;
|
||||
size = GetCommonAlignSize(size_tmp);
|
||||
size = device_dtype_size * shape_size;
|
||||
}
|
||||
size = GetCommonAlignSize(size);
|
||||
void *output_address_ptr = nullptr;
|
||||
auto ret_malloc = rtMalloc(&output_address_ptr, size, RT_MEMORY_HBM);
|
||||
if (ret_malloc != RT_ERROR_NONE) {
|
||||
|
@ -427,7 +427,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
|
|||
// launch
|
||||
LaunchTransData(kernel_mod_ptr, output_address_ptr, size, workspace_size_list);
|
||||
if (type_id_ == type) {
|
||||
SyncMemory(host_ptr, output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST);
|
||||
SyncMemory(host_ptr, output_address_ptr, host_size, RT_MEMCPY_DEVICE_TO_HOST);
|
||||
} else {
|
||||
auto host = std::vector<uint8_t>(size);
|
||||
SyncMemory(host.data(), output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST);
|
||||
|
|
Loading…
Reference in New Issue