!47493 solve the bug of cumsum res zero

Merge pull request !47493 from zong_shuai/cumsum_debug
This commit is contained in:
i-robot 2023-01-05 01:44:14 +00:00 committed by Gitee
commit bd2141ce00
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 11 additions and 10 deletions

View File

@ -117,27 +117,28 @@ template <typename T>
void CumSum(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, size_t stride,
size_t stride2, bool exclusive_, bool reverse_, const uint32_t &device_id, cudaStream_t stream) {
int size = dim0 * dim2;
int block_num = size > 256 ? 256 : size;
if (exclusive_) {
if (reverse_) {
RightMoveSum<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, stream>>>(input, output, dim0, dim1,
dim2, stride, stride2);
RightMoveSum<<<CUDA_BLOCKS_CAL(device_id, size, block_num), block_num, 0, stream>>>(input, output, dim0, dim1,
dim2, stride, stride2);
Copy<<<CUDA_BLOCKS(device_id, size * dim1), CUDA_THREADS(device_id), 0, stream>>>(workspace, output, size * dim1);
CumSumKernelReverse<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, stream>>>(
CumSumKernelReverse<<<CUDA_BLOCKS_CAL(device_id, size, block_num), block_num, 0, stream>>>(
workspace, output, dim0, dim1, dim2, stride, stride2);
} else {
LeftMoveSum<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, stream>>>(input, output, dim0, dim1, dim2,
stride, stride2);
Copy<<<CUDA_BLOCKS(device_id, size * dim1), CUDA_THREADS(device_id), 0, stream>>>(workspace, output, size * dim1);
CumSumKernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, stream>>>(workspace, output, dim0, dim1,
LeftMoveSum<<<CUDA_BLOCKS_CAL(device_id, size, block_num), block_num, 0, stream>>>(input, output, dim0, dim1,
dim2, stride, stride2);
Copy<<<CUDA_BLOCKS(device_id, size * dim1), CUDA_THREADS(device_id), 0, stream>>>(workspace, output, size * dim1);
CumSumKernel<<<CUDA_BLOCKS_CAL(device_id, size, block_num), block_num, 0, stream>>>(workspace, output, dim0, dim1,
dim2, stride, stride2);
}
} else {
if (reverse_) {
CumSumKernelReverse<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, stream>>>(
CumSumKernelReverse<<<CUDA_BLOCKS_CAL(device_id, size, block_num), block_num, 0, stream>>>(
input, output, dim0, dim1, dim2, stride, stride2);
} else {
CumSumKernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, stream>>>(input, output, dim0, dim1,
dim2, stride, stride2);
CumSumKernel<<<CUDA_BLOCKS_CAL(device_id, size, block_num), block_num, 0, stream>>>(input, output, dim0, dim1,
dim2, stride, stride2);
}
}
return;