!37133 modify MaxPool3DGradWithArgmax cudaMemset
Merge pull request !37133 from 范吉斌/fix_maxpool3dgradwithargmax
This commit is contained in:
commit
ba6f6bb615
|
@ -35,8 +35,9 @@ bool MaxPool3DGradWithArgmaxGpuKernelMod::LaunchKernel(const std::vector<Address
|
|||
T *dy_addr = GetDeviceAddress<T>(inputs, kIndex1);
|
||||
S *index_addr = GetDeviceAddress<S>(inputs, kIndex2);
|
||||
T *dx_addr = GetDeviceAddress<T>(outputs, kIndex0);
|
||||
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemset(dx_addr, 0, outputs[kIndex0]->size),
|
||||
"For 'MaxPool3DWithArgmaxGrad' failed to cudaMemset");
|
||||
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
|
||||
cudaMemsetAsync(dx_addr, 0, outputs[kIndex0]->size, reinterpret_cast<cudaStream_t>(cuda_stream_)),
|
||||
"For 'MaxPool3DWithArgmaxGrad' failed to cudaMemsetAsync");
|
||||
CalMaxPool3DGradWithArgmax(dy_addr, index_addr, x_dhw_, dy_dhw_, dy_ncdhw_, dx_addr, device_id_,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream_));
|
||||
return true;
|
||||
|
|
Loading…
Reference in New Issue