!46351 修复 SparseSlice 算子多线程计算结果不正确的问题

Merge pull request !46351 from yujialiang/sparse
This commit is contained in:
i-robot 2022-12-02 06:54:36 +00:00 committed by Gitee
commit b2b88fc65e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed file with 4 additions and 2 deletions

View File

@ -19,6 +19,7 @@
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_slice_impl.cuh"
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
template <typename DataType, typename IndexType>
__global__ void SparseSliceKernel(const IndexType *indices_ptr, const DataType *values_ptr,
@ -26,6 +27,7 @@ __global__ void SparseSliceKernel(const IndexType *indices_ptr, const DataType *
IndexType *y_indices_ptr, DataType *y_values_ptr, IndexType *out_shape_ptr,
int64_t *sum_count_ptr, size_t input_nnz_, size_t num_dim_, size_t out_size_) {
IndexType non_zeros_ = 0;
int64_t addnum = 1;
for (int a = 0; a < out_size_; a += 1) {
out_shape_ptr[a] = size_ptr[a];
}
@ -50,7 +52,7 @@ __global__ void SparseSliceKernel(const IndexType *indices_ptr, const DataType *
y_indices_ptr[non_zeros_ * num_dim_ + dim] = new_index;
}
non_zeros_ += 1;
*sum_count_ptr += 1;
MsAtomicAdd(sum_count_ptr, addnum);
}
}
}
@ -61,7 +63,7 @@ CUDA_LIB_EXPORT void SparseSlice(const IndexType *indices_ptr, const DataType *v
IndexType *y_indices_ptr, DataType *y_values_ptr, IndexType *out_shape_ptr,
int64_t *sum_count_ptr, size_t input_nnz_, size_t num_dim_, size_t out_size_,
uint32_t device_id, cudaStream_t cuda_stream) {
SparseSliceKernel<<<GET_BLOCKS(input_nnz_), 1, 0, cuda_stream>>>(
SparseSliceKernel<<<GET_BLOCKS(input_nnz_), GET_THREADS, 0, cuda_stream>>>(
indices_ptr, values_ptr, x_ptr, start_ptr, size_ptr, y_indices_ptr, y_values_ptr, out_shape_ptr, sum_count_ptr,
input_nnz_, num_dim_, out_size_);
}