diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.cc
index bc4df65f4a8..760c151277e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.cc
@@ -18,60 +18,131 @@
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(
+MS_REG_GPU_KERNEL_TWO(
   UnsortedSegmentMax,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
-  UnsortedSegmentMaxGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(
+  UnsortedSegmentMaxGpuKernel, float, int)
+
+MS_REG_GPU_KERNEL_TWO(
+  UnsortedSegmentMax,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
+  UnsortedSegmentMaxGpuKernel, float, int64_t)
+
+MS_REG_GPU_KERNEL_TWO(
   UnsortedSegmentMax,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
-  UnsortedSegmentMaxGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(
+  UnsortedSegmentMaxGpuKernel, half, int)
+
+MS_REG_GPU_KERNEL_TWO(
+  UnsortedSegmentMax,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
+  UnsortedSegmentMaxGpuKernel, half, int64_t)
+
+MS_REG_GPU_KERNEL_TWO(
   UnsortedSegmentMax,
   KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  UnsortedSegmentMaxGpuKernel, int)
+  UnsortedSegmentMaxGpuKernel, int, int)
+
+MS_REG_GPU_KERNEL_TWO(
+  UnsortedSegmentMax,
+  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
+  UnsortedSegmentMaxGpuKernel, int, int64_t)
+
 // Dynamic Mode - registered for int32/int64 3rd input
-MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
+MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeInt32)
                         .AddInputAttr(kNumberTypeInt32)
                         .AddOutputAttr(kNumberTypeFloat32),
-                      UnsortedSegmentMaxGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
+                      UnsortedSegmentMaxGpuKernel, float, int)
+
+MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      UnsortedSegmentMaxGpuKernel, float, int64_t)
+
+MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeInt32)
                         .AddInputAttr(kNumberTypeInt64)
                         .AddOutputAttr(kNumberTypeFloat32),
-                      UnsortedSegmentMaxGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
+                      UnsortedSegmentMaxGpuKernel, float, int)
+
+MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      UnsortedSegmentMaxGpuKernel, float, int64_t)
+
+MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat16)
                         .AddInputAttr(kNumberTypeInt32)
                         .AddInputAttr(kNumberTypeInt32)
                         .AddOutputAttr(kNumberTypeFloat16),
-                      UnsortedSegmentMaxGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
+                      UnsortedSegmentMaxGpuKernel, half, int)
+
+MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      UnsortedSegmentMaxGpuKernel, half, int64_t)
+
+MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat16)
                         .AddInputAttr(kNumberTypeInt32)
                         .AddInputAttr(kNumberTypeInt64)
                         .AddOutputAttr(kNumberTypeFloat16),
-                      UnsortedSegmentMaxGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
+                      UnsortedSegmentMaxGpuKernel, half, int)
+
+MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      UnsortedSegmentMaxGpuKernel, half, int64_t)
+
+MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeInt32)
                         .AddInputAttr(kNumberTypeInt32)
                         .AddInputAttr(kNumberTypeInt32)
                         .AddOutputAttr(kNumberTypeInt32),
-                      UnsortedSegmentMaxGpuKernel, int)
-MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
+                      UnsortedSegmentMaxGpuKernel, int, int)
+
+MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeInt32),
+                      UnsortedSegmentMaxGpuKernel, int, int64_t)
+
+MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeInt32)
                         .AddInputAttr(kNumberTypeInt32)
                         .AddInputAttr(kNumberTypeInt64)
                         .AddOutputAttr(kNumberTypeInt32),
-                      UnsortedSegmentMaxGpuKernel, int)
+                      UnsortedSegmentMaxGpuKernel, int, int)
+
+MS_REG_GPU_KERNEL_TWO(UnsortedSegmentMax,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeInt32),
+                      UnsortedSegmentMaxGpuKernel, int, int64_t)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h
index 2cda65112a8..f08838ec71c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h
@@ -25,7 +25,7 @@
 
 namespace mindspore {
 namespace kernel {
-template <typename T>
+template <typename T, typename S>
 class UnsortedSegmentMaxGpuKernel : public GpuKernel {
  public:
   UnsortedSegmentMaxGpuKernel() { ResetResource(); }
@@ -41,7 +41,7 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel {
       return true;
     }
     T *input_addr = GetDeviceAddress<T>(inputs, 0);
-    int *indices_addr = GetDeviceAddress<int>(inputs, 1);
+    S *indices_addr = GetDeviceAddress<S>(inputs, 1);
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
 
     CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cu
index c1eb49f00ba..53403e4b5f9 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cu
@@ -17,21 +17,21 @@
 #include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh"
 #include <limits>
 
-template <typename T>
-__global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments,
-                                   size_t outer_size, size_t inner_size, bool fp16_flag, T init_K, T *output) {
+template <typename T, typename S>
+__global__ void UnsortedSegmentMax(const T *input, const S *segment_ids, const int64_t num_segments, size_t outer_size,
+                                   size_t inner_size, bool fp16_flag, T init_K, T *output) {
   if (fp16_flag) {
     init_K = __int2half_rd(-65504);  // min value representable by float16
   }
-  for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < KWARPSIZE * num_segments * inner_size;
+  for (size_t t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < KWARPSIZE * num_segments * inner_size;
        t_idx += blockDim.x * gridDim.x) {
-    int segment_id = t_idx / KWARPSIZE / inner_size;
-    int inner_id = t_idx / KWARPSIZE % inner_size;
-    int lane_id = threadIdx.x % KWARPSIZE;
+    size_t segment_id = t_idx / KWARPSIZE / inner_size;
+    size_t inner_id = t_idx / KWARPSIZE % inner_size;
+    size_t lane_id = threadIdx.x % KWARPSIZE;
     T threadK = init_K;
-    for (int i = lane_id; i < outer_size; i += KWARPSIZE) {
+    for (size_t i = lane_id; i < outer_size; i += KWARPSIZE) {
       if (segment_ids[i] != segment_id) continue;
       T other_K = input[i * inner_size + inner_id];
       if (threadK < other_K) {
@@ -40,7 +40,7 @@ __global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const
     }
     __syncwarp();
 
-    for (int offset = KWARPSIZE / 2; offset > 0; offset /= 2) {
+    for (size_t offset = KWARPSIZE / 2; offset > 0; offset /= 2) {
       T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
       if (threadK < other_K) {
         threadK = other_K;
@@ -56,10 +56,10 @@ __global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const
   }
 }
 
-template <typename T>
-void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size,
+template <typename T, typename S>
+void CalUnsortedSegmentMax(const T *input, const S *segment_ids, const int64_t num_segments, size_t outer_size,
                            size_t inner_size, T *output, cudaStream_t stream) {
-  int size = (inner_size * KWARPSIZE * num_segments);
+  size_t size = (inner_size * KWARPSIZE * num_segments);
   bool fp16_flag = false;
   // handle fp16 min value
   if (std::is_same<T, half>::value) {
@@ -71,9 +71,19 @@ void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t
   return;
 }
 
-template void CalUnsortedSegmentMax<float>(const float *input, const int *segment_ids, const int64_t num_segments,
-                                           size_t outer_size, size_t inner_size, float *output, cudaStream_t stream);
-template void CalUnsortedSegmentMax<half>(const half *input, const int *segment_ids, const int64_t num_segments,
-                                          size_t outer_size, size_t inner_size, half *output, cudaStream_t stream);
-template void CalUnsortedSegmentMax<int>(const int *input, const int *segment_ids, const int64_t num_segments,
-                                         size_t outer_size, size_t inner_size, int *output, cudaStream_t stream);
+template void CalUnsortedSegmentMax<float, int>(const float *input, const int *segment_ids, const int64_t num_segments,
+                                                size_t outer_size, size_t inner_size, float *output,
+                                                cudaStream_t stream);
+template void CalUnsortedSegmentMax<float, int64_t>(const float *input, const int64_t *segment_ids,
+                                                    const int64_t num_segments, size_t outer_size, size_t inner_size,
+                                                    float *output, cudaStream_t stream);
+template void CalUnsortedSegmentMax<half, int>(const half *input, const int *segment_ids, const int64_t num_segments,
+                                               size_t outer_size, size_t inner_size, half *output, cudaStream_t stream);
+template void CalUnsortedSegmentMax<half, int64_t>(const half *input, const int64_t *segment_ids,
+                                                   const int64_t num_segments, size_t outer_size, size_t inner_size,
+                                                   half *output, cudaStream_t stream);
+template void CalUnsortedSegmentMax<int, int>(const int *input, const int *segment_ids, const int64_t num_segments,
+                                              size_t outer_size, size_t inner_size, int *output, cudaStream_t stream);
+template void CalUnsortedSegmentMax<int, int64_t>(const int *input, const int64_t *segment_ids,
+                                                  const int64_t num_segments, size_t outer_size, size_t inner_size,
+                                                  int *output, cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh
index caab13ce65d..859f1f5181a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh
@@ -22,8 +22,8 @@
 // Setting warp size to sync data across threads
 #define KWARPSIZE 32
 
-template <typename T>
-void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size,
+template <typename T, typename S>
+void CalUnsortedSegmentMax(const T *input, const S *segment_ids, const int64_t num_segments, size_t outer_size,
                            size_t inner_size, T *output, cudaStream_t stream);
 
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_MAX_H_
diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc
index b607e4214cc..445fa9510af 100644
--- a/mindspore/core/abstract/prim_arrays.cc
+++ b/mindspore/core/abstract/prim_arrays.cc
@@ -306,7 +306,7 @@ AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const Pri
   MS_EXCEPTION_IF_NULL(segment_ids->shape());
   auto segment_ids_shape = segment_ids->shape()->shape();
   (void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentMax should be %s");
-  (void)CheckTensorDType(segment_ids, {kInt32}, "Input 1 (segment_ids) for UnsortedSegmentMax should be %s");
+  (void)CheckTensorDType(segment_ids, {kInt32, kInt64}, "Input 1 (segment_ids) for UnsortedSegmentMax should be %s");
   // check if dynamic shape
   bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());
   bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index cac9d391729..02f6ff0d111 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -2001,7 +2001,8 @@ class UnsortedSegmentMax(PrimitiveWithCheck):
         segment_ids_shape = segment_ids['shape']
         valid_type = [mstype.float16, mstype.float32, mstype.int32]
         validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
-        validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
+        validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']},
+                                                      [mstype.int32, mstype.int64], self.name)
         validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
         num_segments_type = num_segments['dtype']
         validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
diff --git a/tests/st/ops/gpu/test_unsorted_segment_max.py b/tests/st/ops/gpu/test_unsorted_segment_max.py
index 36c3ebdb022..fa1e1a32a81 100644
--- a/tests/st/ops/gpu/test_unsorted_segment_max.py
+++ b/tests/st/ops/gpu/test_unsorted_segment_max.py
@@ -71,12 +71,12 @@ def test_2d_int32():
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_3d_float16():
-    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+def test_3d_float16_int64():
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
     input_x = Tensor(np.arange(
         4 * 5 * 3, dtype=np.float16).reshape(4, 5, 3), dtype=mindspore.float16)
-    segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
-    num_segments = 5
+    segment_ids = Tensor([2, 1, 1, -1], mstype.int64)
+    num_segments = Tensor(5, dtype=mstype.int64)
     net = UnsortedSegmentMaxNet(num_segments)
     output = net(input_x, segment_ids).asnumpy()
     expect = np.array([[[-6.55e+04, -6.55e+04, -6.55e+04],
@@ -110,12 +110,12 @@
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_3d_float32():
+def test_3d_float32_int64():
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
     input_x = Tensor(np.arange(
         4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
-    segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
-    num_segments = 3
+    segment_ids = Tensor([2, 1, 1, -1], mstype.int64)
+    num_segments = Tensor(3, dtype=mstype.int64)
     net = UnsortedSegmentMaxNet(num_segments)
     output = net(input_x, segment_ids).asnumpy()
     expect = np.array([[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
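Note: the following is a minimal usage sketch of the int64 segment_ids path that this patch enables on GPU, mirroring the updated tests above. The Net wrapper and variable names are illustrative only; they are not the UnsortedSegmentMaxNet helper defined in the test file.

import numpy as np
import mindspore
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor, context
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target='GPU')


class Net(nn.Cell):
    """Illustrative wrapper: num_segments is held as a Tensor, as in the updated tests."""
    def __init__(self, num_segments):
        super(Net, self).__init__()
        self.unsorted_segment_max = P.UnsortedSegmentMax()
        self.num_segments = num_segments

    def construct(self, x, segment_ids):
        return self.unsorted_segment_max(x, segment_ids, self.num_segments)


input_x = Tensor(np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
segment_ids = Tensor([2, 1, 1, -1], mstype.int64)  # int64 ids: newly supported on GPU by this change
num_segments = Tensor(3, dtype=mstype.int64)
output = Net(num_segments)(input_x, segment_ids)
print(output.shape)  # expected: (3, 5, 3)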