!46351 修复 SparseSlice 算子多线程计算结果不正确的问题
Merge pull request !46351 from yujialiang/sparse
This commit is contained in:
commit
b2b88fc65e
|
@ -19,6 +19,7 @@
|
|||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_slice_impl.cuh"
|
||||
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
|
||||
|
||||
template <typename DataType, typename IndexType>
|
||||
__global__ void SparseSliceKernel(const IndexType *indices_ptr, const DataType *values_ptr,
|
||||
|
@ -26,6 +27,7 @@ __global__ void SparseSliceKernel(const IndexType *indices_ptr, const DataType *
|
|||
IndexType *y_indices_ptr, DataType *y_values_ptr, IndexType *out_shape_ptr,
|
||||
int64_t *sum_count_ptr, size_t input_nnz_, size_t num_dim_, size_t out_size_) {
|
||||
IndexType non_zeros_ = 0;
|
||||
int64_t addnum = 1;
|
||||
for (int a = 0; a < out_size_; a += 1) {
|
||||
out_shape_ptr[a] = size_ptr[a];
|
||||
}
|
||||
|
@ -50,7 +52,7 @@ __global__ void SparseSliceKernel(const IndexType *indices_ptr, const DataType *
|
|||
y_indices_ptr[non_zeros_ * num_dim_ + dim] = new_index;
|
||||
}
|
||||
non_zeros_ += 1;
|
||||
*sum_count_ptr += 1;
|
||||
MsAtomicAdd(sum_count_ptr, addnum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -61,7 +63,7 @@ CUDA_LIB_EXPORT void SparseSlice(const IndexType *indices_ptr, const DataType *v
|
|||
IndexType *y_indices_ptr, DataType *y_values_ptr, IndexType *out_shape_ptr,
|
||||
int64_t *sum_count_ptr, size_t input_nnz_, size_t num_dim_, size_t out_size_,
|
||||
uint32_t device_id, cudaStream_t cuda_stream) {
|
||||
SparseSliceKernel<<<GET_BLOCKS(input_nnz_), 1, 0, cuda_stream>>>(
|
||||
SparseSliceKernel<<<GET_BLOCKS(input_nnz_), GET_THREADS, 0, cuda_stream>>>(
|
||||
indices_ptr, values_ptr, x_ptr, start_ptr, size_ptr, y_indices_ptr, y_values_ptr, out_shape_ptr, sum_count_ptr,
|
||||
input_nnz_, num_dim_, out_size_);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue