!46879 recover the error exception of scatter gpu func

Merge pull request !46879 from zhengzuohe/scatter_gpu_error
This commit is contained in:
i-robot 2022-12-23 03:13:57 +00:00 committed by Gitee
commit 9e13bb9fcf
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 15 additions and 10 deletions

View File

@ -23,8 +23,9 @@ __global__ void ScatterUpdateKernel(S size_limit, const size_t inner_size, const
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
CUDA_KERNEL_ASSERT(indices[index] >= 0 && indices[index] < size_limit
&& "For 'ScatterUpdate', the value of indices is out of range.");
if (indices[index] < 0 || indices[index] >= size_limit) {
continue;
}
const size_t current_pos = indices[index] * inner_size + offset;
input[current_pos] = updates[pos];
}
@ -36,8 +37,9 @@ __global__ void ScatterAddKernel(S size_limit, const size_t inner_size, const si
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
CUDA_KERNEL_ASSERT(indices[index] >= 0 && indices[index] < size_limit
&& "For 'ScatterAdd', the value of indices is out of range.");
if (indices[index] < 0 || indices[index] >= size_limit) {
continue;
}
const size_t current_pos = indices[index] * inner_size + offset;
MsAtomicAdd(&input[current_pos], updates[pos]);
}
@ -49,8 +51,9 @@ __global__ void ScatterSubKernel(S size_limit, const size_t inner_size, const si
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
CUDA_KERNEL_ASSERT(indices[index] >= 0 && indices[index] < size_limit
&& "For 'ScatterSub', the value of indices is out of range.");
if (indices[index] < 0 || indices[index] >= size_limit) {
continue;
}
const size_t current_pos = indices[index] * inner_size + offset;
MsAtomicSub(&input[current_pos], updates[pos]);
}
@ -62,8 +65,9 @@ __global__ void ScatterMaxKernel(S size_limit, const size_t inner_size, const si
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
CUDA_KERNEL_ASSERT(indices[index] >= 0 && indices[index] < size_limit
&& "For 'ScatterMax', the value of indices is out of range.");
if (indices[index] < 0 || indices[index] >= size_limit) {
continue;
}
const size_t current_pos = indices[index] * inner_size + offset;
MsAtomicMax(&input[current_pos], updates[pos]);
}
@ -75,8 +79,9 @@ __global__ void ScatterMinKernel(S size_limit, const size_t inner_size, const si
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
CUDA_KERNEL_ASSERT(indices[index] >= 0 && indices[index] < size_limit
&& "For 'ScatterMin', the value of indices is out of range.");
if (indices[index] < 0 || indices[index] >= size_limit) {
continue;
}
const size_t current_pos = indices[index] * inner_size + offset;
MsAtomicMin(&input[current_pos], updates[pos]);
}