From 83151509dca7674582158e1d0bacbf499a94ea6f Mon Sep 17 00:00:00 2001 From: wilfChen Date: Sat, 16 May 2020 17:06:47 +0800 Subject: [PATCH] UnsortedSegmentSum kernel support Nd --- .../gpu/arrays/unsorted_segment_sum_gpu_kernel.h | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h index d773422c276..24c1f09097f 100644 --- a/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h @@ -50,16 +50,21 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel { bool Init(const CNodePtr &kernel_node) override { auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); - input_dim0_ = input_shapes[0]; - for (size_t i = 1; i < input_shapes.size(); i++) { - input_dim1_ *= input_shapes[i]; + auto axis = ids_shapes.size(); + for (size_t i = 0; i < input_shapes.size(); i++) { + if (i < axis) { + input_dim0_ *= input_shapes[i]; + } else { + input_dim1_ *= input_shapes[i]; + } } output_dim0_ = output_shapes[0]; - for (size_t i = 1; i < output_shapes.size(); i++) { - output_dim1_ *= output_shapes[i]; + for (size_t j = 1; j < output_shapes.size(); j++) { + output_dim1_ *= output_shapes[j]; } InitSizeLists();