forked from mindspore-Ecosystem/mindspore
parent
11694db238
commit
92d9bc7ccd
|
@ -41,8 +41,12 @@ class ResizeBilinearGradGpuKernel : public GpuKernel {
|
|||
T *dx = GetDeviceAddress<T>(outputs, 0);
|
||||
float h_scale = Scaling(dx_h_, dy_h_, align_corners_);
|
||||
float w_scale = Scaling(dx_w_, dy_w_, align_corners_);
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMemset(interim, 0, workspace_size_), "cudaMemset dx_interim failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMemset(dx, 0, dx_size_), "cudaMemset dx failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemsetAsync(dx, 0, dx_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemsetAsync dx failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemsetAsync(interim, 0, workspace_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemsetAsync dx_interim failed");
|
||||
CalResizeBilinearGrad(dy, n_, c_, dy_h_, dy_w_, dx_h_, dx_w_, h_scale, w_scale, dx, interim,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
|
|
Loading…
Reference in New Issue