!37133 modify MaxPool3DGradWithArgmax cudaMemset

Merge pull request !37133 from 范吉斌/fix_maxpool3dgradwithargmax
This commit is contained in:
i-robot 2022-07-05 06:58:08 +00:00 committed by Gitee
commit ba6f6bb615
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed file with 3 additions and 2 deletions

View File

@@ -35,8 +35,9 @@ bool MaxPool3DGradWithArgmaxGpuKernelMod::LaunchKernel(const std::vector<Address
T *dy_addr = GetDeviceAddress<T>(inputs, kIndex1);
S *index_addr = GetDeviceAddress<S>(inputs, kIndex2);
T *dx_addr = GetDeviceAddress<T>(outputs, kIndex0);
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemset(dx_addr, 0, outputs[kIndex0]->size),
"For 'MaxPool3DWithArgmaxGrad' failed to cudaMemset");
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
cudaMemsetAsync(dx_addr, 0, outputs[kIndex0]->size, reinterpret_cast<cudaStream_t>(cuda_stream_)),
"For 'MaxPool3DWithArgmaxGrad' failed to cudaMemsetAsync");
CalMaxPool3DGradWithArgmax(dy_addr, index_addr, x_dhw_, dy_dhw_, dy_ncdhw_, dx_addr, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream_));
return true;