diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_mean_with_num_segments_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_mean_with_num_segments_impl.cu index 134fb4e7122..485f9632dda 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_mean_with_num_segments_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_mean_with_num_segments_impl.cu @@ -125,8 +125,8 @@ inline int Log2Ceil64(uint64_t n) { } template -__global__ void CoalesceKernelCheck(IndexType *indices_ptr, IndexType *segment_ids_ptr, IndexType *num_segments_ptr, - size_t outer_size, int *ret_flag, size_t indices_size) { +__global__ void InputValidCheck(IndexType *indices_ptr, IndexType *segment_ids_ptr, IndexType *num_segments_ptr, + size_t outer_size, int *ret_flag, size_t indices_size) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < indices_size; i += gridDim.x * blockDim.x) { if ((i != indices_size - 1) && (segment_ids_ptr[i] > segment_ids_ptr[i + 1])) { *ret_flag = 1; @@ -144,32 +144,32 @@ __global__ void CoalesceKernelCheck(IndexType *indices_ptr, IndexType *segment_i } template -CUDA_LIB_EXPORT void CalSparseSegmentMeanWithNumSegments(const DataType *x_ptr, const IndexType *indices_ptr, - const IndexType *segment_ids_ptr, - const IndexType *num_segments_ptr, size_t *segment_pos_ptr, - DataType *y_ptr, size_t outer_size, size_t inner_size, - size_t indices_size, size_t segment_size, size_t x_size, - size_t y_size, size_t batch_size, int *ret_flag_host, - uint32_t device_id, cudaStream_t cuda_stream) { +CUDA_LIB_EXPORT int CalSparseSegmentMeanWithNumSegments(const DataType *x_ptr, const IndexType *indices_ptr, + const IndexType *segment_ids_ptr, + const IndexType *num_segments_ptr, size_t *segment_pos_ptr, + DataType *y_ptr, size_t outer_size, size_t inner_size, + size_t indices_size, size_t segment_size, size_t x_size, + size_t y_size, size_t batch_size, int *ret_flag_device, + uint32_t device_id, cudaStream_t cuda_stream) { // Get start position of each segment and set to segment_pos_ptr. // The last element of segment_pos_ptr must equal to segment_size. - int *ret_flag_device = nullptr; - (void)cudaMalloc(&ret_flag_device, sizeof(int)); - (void)cudaMemset(ret_flag_device, 0, sizeof(int)); - CoalesceKernelCheck<<>>( + int ret_flag_host = 0; + int thread_num = indices_size + 1 > 256 ? 256 : (indices_size + 1); + (void)cudaMemsetAsync(ret_flag_device, 0, sizeof(int), cuda_stream); + InputValidCheck<<>>( indices_ptr, segment_ids_ptr, num_segments_ptr, outer_size, ret_flag_device, indices_size); - (void)cudaMemcpy(ret_flag_host, ret_flag_device, sizeof(int), cudaMemcpyDeviceToHost); - (void)cudaFree(ret_flag_device); - if (*ret_flag_host != 0) { - return; + (void)cudaMemcpyAsync(&ret_flag_host, ret_flag_device, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream); + cudaStreamSynchronize(cuda_stream); + if (ret_flag_host != 0) { + return ret_flag_host; } - SparseSegmentPosKernel<<>>( + SparseSegmentPosKernel<<>>( segment_ids_ptr, segment_pos_ptr, indices_size, segment_size); const unsigned int max_grid_x = (1u << 31) - 1; const unsigned int max_grid_y = (1u << 16) - 1; - const unsigned int max_block_x = 1024; - const unsigned int max_block_y = 64; + const unsigned int max_block_x = 64; + const unsigned int max_block_y = 8; unsigned int inner_power2 = 1u << Log2Ceil64(inner_size); unsigned int avg_reduce_size = UP_DIV(outer_size, segment_size); unsigned int avg_reduce_size_power2 = 1u << Log2Ceil64(avg_reduce_size); @@ -188,32 +188,32 @@ CUDA_LIB_EXPORT void CalSparseSegmentMeanWithNumSegments(const DataType *x_ptr, SparseSegmentMeanWithNumSegmentsKernel<<>>( batch_x_ptr, batch_indices_ptr, segment_pos_ptr, batch_y_ptr, outer_size, inner_size, segment_size); } - return; + return ret_flag_host; } -template CUDA_LIB_EXPORT void CalSparseSegmentMeanWithNumSegments( +template CUDA_LIB_EXPORT int CalSparseSegmentMeanWithNumSegments( const half *x_ptr, const int32_t *indices_ptr, const int32_t *segment_ids_ptr, const int32_t *num_segments_ptr, size_t *segment_pos_ptr, half *y_ptr, size_t outer_size, size_t inner_size, size_t indices_size, size_t segment_size, - size_t x_size, size_t y_size, size_t batch_size, int *ret_flag_host, uint32_t device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CalSparseSegmentMeanWithNumSegments( + size_t x_size, size_t y_size, size_t batch_size, int *ret_flag_device, uint32_t device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT int CalSparseSegmentMeanWithNumSegments( const float *x_ptr, const int32_t *indices_ptr, const int32_t *segment_ids_ptr, const int32_t *num_segments_ptr, size_t *segment_pos_ptr, float *y_ptr, size_t outer_size, size_t inner_size, size_t indices_size, size_t segment_size, - size_t x_size, size_t y_size, size_t batch_size, int *ret_flag_host, uint32_t device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CalSparseSegmentMeanWithNumSegments( + size_t x_size, size_t y_size, size_t batch_size, int *ret_flag_device, uint32_t device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT int CalSparseSegmentMeanWithNumSegments( const double *x_ptr, const int32_t *indices_ptr, const int32_t *segment_ids_ptr, const int32_t *num_segments_ptr, size_t *segment_pos_ptr, double *y_ptr, size_t outer_size, size_t inner_size, size_t indices_size, - size_t segment_size, size_t x_size, size_t y_size, size_t batch_size, int *ret_flag_host, uint32_t device_id, + size_t segment_size, size_t x_size, size_t y_size, size_t batch_size, int *ret_flag_device, uint32_t device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CalSparseSegmentMeanWithNumSegments( +template CUDA_LIB_EXPORT int CalSparseSegmentMeanWithNumSegments( const half *x_ptr, const int64_t *indices_ptr, const int64_t *segment_ids_ptr, const int64_t *num_segments_ptr, size_t *segment_pos_ptr, half *y_ptr, size_t outer_size, size_t inner_size, size_t indices_size, size_t segment_size, - size_t x_size, size_t y_size, size_t batch_size, int *ret_flag_host, uint32_t device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CalSparseSegmentMeanWithNumSegments( + size_t x_size, size_t y_size, size_t batch_size, int *ret_flag_device, uint32_t device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT int CalSparseSegmentMeanWithNumSegments( const float *x_ptr, const int64_t *indices_ptr, const int64_t *segment_ids_ptr, const int64_t *num_segments_ptr, size_t *segment_pos_ptr, float *y_ptr, size_t outer_size, size_t inner_size, size_t indices_size, size_t segment_size, - size_t x_size, size_t y_size, size_t batch_size, int *ret_flag_host, uint32_t device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CalSparseSegmentMeanWithNumSegments( + size_t x_size, size_t y_size, size_t batch_size, int *ret_flag_device, uint32_t device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT int CalSparseSegmentMeanWithNumSegments( const double *x_ptr, const int64_t *indices_ptr, const int64_t *segment_ids_ptr, const int64_t *num_segments_ptr, size_t *segment_pos_ptr, double *y_ptr, size_t outer_size, size_t inner_size, size_t indices_size, - size_t segment_size, size_t x_size, size_t y_size, size_t batch_size, int *ret_flag_host, uint32_t device_id, + size_t segment_size, size_t x_size, size_t y_size, size_t batch_size, int *ret_flag_device, uint32_t device_id, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_mean_with_num_segments_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_mean_with_num_segments_impl.cuh index 0818baa19a6..d7909d134ad 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_mean_with_num_segments_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_mean_with_num_segments_impl.cuh @@ -20,12 +20,12 @@ #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" template -CUDA_LIB_EXPORT void CalSparseSegmentMeanWithNumSegments(const DataType *x_ptr, const IndexType *indices_ptr, - const IndexType *segment_ids_ptr, - const IndexType *num_segments_ptr, size_t *segment_pos_ptr, - DataType *y_ptr, size_t outer_size, size_t inner_size, - size_t indices_size, size_t segment_size, size_t x_size, - size_t y_size, size_t batch_size, int *ret_flag_host, - uint32_t device_id, cudaStream_t cuda_stream); +CUDA_LIB_EXPORT int CalSparseSegmentMeanWithNumSegments(const DataType *x_ptr, const IndexType *indices_ptr, + const IndexType *segment_ids_ptr, + const IndexType *num_segments_ptr, size_t *segment_pos_ptr, + DataType *y_ptr, size_t outer_size, size_t inner_size, + size_t indices_size, size_t segment_size, size_t x_size, + size_t y_size, size_t batch_size, int *ret_flag_device, + uint32_t device_id, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_MEAN_WITH_NUM_SEGMENTS_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/sparse_segment_mean_with_num_segments_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/sparse_segment_mean_with_num_segments_gpu_kernel.cc index f794c75aef7..32d3b0e892c 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/sparse_segment_mean_with_num_segments_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/sparse_segment_mean_with_num_segments_gpu_kernel.cc @@ -56,6 +56,7 @@ int SparseSegmentMeanWithNumSegmentsGpuKernelMod::Resize(const BaseOperatorPtr & y_size_ = std::accumulate(y_shape.begin() + batch_rank_, y_shape.end(), size_t(1), std::multiplies{}); segment_size_ = LongToSize(y_shape.at(batch_rank_)); workspace_size_list_.push_back((segment_size_ + 1) * sizeof(size_t)); + workspace_size_list_.push_back(sizeof(int)); return ret; } @@ -71,15 +72,16 @@ bool SparseSegmentMeanWithNumSegmentsGpuKernelMod::LaunchKernel(const std::vecto auto segment_ids_ptr = GetDeviceAddress(inputs, kIndex2); auto num_segments_ptr = GetDeviceAddress(inputs, kIndex3); auto segment_pos_ptr = GetDeviceAddress(workspace, kIndex0); + auto ret_flag_device = GetDeviceAddress(workspace, kIndex1); auto y_ptr = GetDeviceAddress(outputs, kIndex0); auto any = [](auto... args) -> bool { return ((args == nullptr) || ...); }; if (any(x_ptr, indices_ptr, segment_ids_ptr, num_segments_ptr, segment_pos_ptr, y_ptr)) { - return false; + cudaMemset(y_ptr, 0, outputs[0]->size); + return true; } - int ret_flag_host = 0; - CalSparseSegmentMeanWithNumSegments(x_ptr, indices_ptr, segment_ids_ptr, num_segments_ptr, segment_pos_ptr, y_ptr, - outer_size_, inner_size_, indices_size_, segment_size_, x_size_, y_size_, - batch_size_, &ret_flag_host, device_id_, cuda_stream); + int ret_flag_host = CalSparseSegmentMeanWithNumSegments( + x_ptr, indices_ptr, segment_ids_ptr, num_segments_ptr, segment_pos_ptr, y_ptr, outer_size_, inner_size_, + indices_size_, segment_size_, x_size_, y_size_, batch_size_, ret_flag_device, device_id_, cuda_stream); int FALSE_1 = 1; int FALSE_2 = 2; int FALSE_3 = 3; diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_mean_grad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_mean_grad_gpu_kernel.cc index aa62852ff5d..b82cdeb94af 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_mean_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_mean_grad_gpu_kernel.cc @@ -99,7 +99,8 @@ bool SparseSegmentMeanGradGpuKernelMod::LaunchKernel(const std::vector(workspace, kIndex0); auto any = [](auto... args) -> bool { return ((args == nullptr) || ...); }; if (any(grad_ptr, indices_ptr, segment_ids_ptr, segment_pos_ptr, y_ptr)) { - return false; + cudaMemset(y_ptr, 0, outputs[0]->size); + return true; } cudaStream_t stream = reinterpret_cast(cuda_stream_); std::vector indices_host; diff --git a/mindspore/core/ops/grad/sparse_segment_mean_grad.cc b/mindspore/core/ops/grad/sparse_segment_mean_grad.cc index 44ce8cc2f0e..93e9c12ac04 100644 --- a/mindspore/core/ops/grad/sparse_segment_mean_grad.cc +++ b/mindspore/core/ops/grad/sparse_segment_mean_grad.cc @@ -65,8 +65,8 @@ abstract::ShapePtr SparseSegmentMeanGradInferShape(const PrimitivePtr &prim, auto output_dim0_value_ptr_tensor = CheckAndConvertUtils::CheckTensorIntValue("output_dim0", output_dim0_value_ptr, prim_name); int dim_zero = output_dim0_value_ptr_tensor[kShapeNum0]; - if (dim_zero <= kDimNum0) { - MS_EXCEPTION(ValueError) << "Input output_dim0 must > 0!"; + if (dim_zero < kDimNum0) { + MS_EXCEPTION(ValueError) << "Input output_dim0 must >= 0!"; } else { y_shape[kShapeNum0] = dim_zero; }