forked from mindspore-Ecosystem/mindspore
fix conv2d precision problem and conv2d runtime error.
This commit is contained in:
parent
ca6756b5fe
commit
25878e62d9
|
@ -415,9 +415,9 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
|
||||||
MS_LOG(ERROR) << "Illegal dtype.";
|
MS_LOG(ERROR) << "Illegal dtype.";
|
||||||
}
|
}
|
||||||
auto shape_size = trans::ShapeSize(host_shape);
|
auto shape_size = trans::ShapeSize(host_shape);
|
||||||
auto size_tmp = device_dtype_size * shape_size;
|
size = device_dtype_size * shape_size;
|
||||||
size = GetCommonAlignSize(size_tmp);
|
|
||||||
}
|
}
|
||||||
|
size = GetCommonAlignSize(size);
|
||||||
void *output_address_ptr = nullptr;
|
void *output_address_ptr = nullptr;
|
||||||
auto ret_malloc = rtMalloc(&output_address_ptr, size, RT_MEMORY_HBM);
|
auto ret_malloc = rtMalloc(&output_address_ptr, size, RT_MEMORY_HBM);
|
||||||
if (ret_malloc != RT_ERROR_NONE) {
|
if (ret_malloc != RT_ERROR_NONE) {
|
||||||
|
@ -427,7 +427,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
|
||||||
// launch
|
// launch
|
||||||
LaunchTransData(kernel_mod_ptr, output_address_ptr, size, workspace_size_list);
|
LaunchTransData(kernel_mod_ptr, output_address_ptr, size, workspace_size_list);
|
||||||
if (type_id_ == type) {
|
if (type_id_ == type) {
|
||||||
SyncMemory(host_ptr, output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST);
|
SyncMemory(host_ptr, output_address_ptr, host_size, RT_MEMCPY_DEVICE_TO_HOST);
|
||||||
} else {
|
} else {
|
||||||
auto host = std::vector<uint8_t>(size);
|
auto host = std::vector<uint8_t>(size);
|
||||||
SyncMemory(host.data(), output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST);
|
SyncMemory(host.data(), output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST);
|
||||||
|
|
Loading…
Reference in New Issue