From e927c193c56424890965c3813b8e237154a76e9d Mon Sep 17 00:00:00 2001 From: zhangzhewei Date: Thu, 17 Jun 2021 16:00:20 +0800 Subject: [PATCH] fancy index cpu unsorted_segment_sum change --- .../nnacl/base/unsorted_segment_sum_base.c | 42 ++++++++++--------- .../nnacl/base/unsorted_segment_sum_base.h | 18 +++++--- .../cpu/unsorted_segment_sum_cpu_kernel.cc | 22 ++++++---- .../arm/fp32_grad/unsorted_segment_sum.cc | 2 +- 4 files changed, 48 insertions(+), 36 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/unsorted_segment_sum_base.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/unsorted_segment_sum_base.c index 1df4ff5bf86..5cf3e2d8fbf 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/unsorted_segment_sum_base.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/unsorted_segment_sum_base.c @@ -16,25 +16,27 @@ #include "nnacl/base/unsorted_segment_sum_base.h" #include "nnacl/errorcode.h" -#define UNSORTEDSEGMENTSUM(type) \ - int UnsortedSegmentSum_##type(const type *input, int unit_num, int input_dim1, const int *indices, type *output, \ - int output_dim0, int output_dim1) { \ - if (input_dim1 == 0) { \ - return NNACL_ERR; \ - } \ - for (int i = 0; i < unit_num; ++i) { \ - int j = i / input_dim1; \ - int k = i % input_dim1; \ - \ - int index = indices[j]; \ - if (index < 0 || index >= output_dim0) { \ - continue; \ - } \ - int output_index = index * output_dim1 + k; \ - output[output_index] += input[i]; \ - } \ - return NNACL_OK; \ +#define UNSORTEDSEGMENTSUM(type, type1) \ + int UnsortedSegmentSum_##type##_##type1(const type *input, int unit_num, int input_dim1, const type1 *indices, \ + type *output, int output_dim0, int output_dim1) { \ + if (input_dim1 == 0) { \ + return NNACL_ERR; \ + } \ + for (int i = 0; i < unit_num; ++i) { \ + int j = i / input_dim1; \ + int k = i % input_dim1; \ + \ + type1 index = indices[j]; \ + if (index < 0 || index >= output_dim0) { \ + continue; \ + } \ + type1 output_index = index * output_dim1 + k; \ + output[output_index] += input[i]; \ + } \ + return NNACL_OK; \ } -UNSORTEDSEGMENTSUM(int) -UNSORTEDSEGMENTSUM(float) +UNSORTEDSEGMENTSUM(int, int) +UNSORTEDSEGMENTSUM(float, int) +UNSORTEDSEGMENTSUM(int, int64_t) +UNSORTEDSEGMENTSUM(float, int64_t) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/unsorted_segment_sum_base.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/unsorted_segment_sum_base.h index d20f7e4bce6..1a92c072098 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/unsorted_segment_sum_base.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/unsorted_segment_sum_base.h @@ -17,15 +17,21 @@ #ifndef MINDSPORE_NNACL_UNSORTED_SEGMENT_SUM_BASE_H_ #define MINDSPORE_NNACL_UNSORTED_SEGMENT_SUM_BASE_H_ +#include "nnacl/op_base.h" + #ifdef __cplusplus extern "C" { #endif -#define UnsortedSegmentSum(type, input, unit_num, input_dim1, indices, output, output_dim0, output_dim1) \ - UnsortedSegmentSum_##type(input, unit_num, input_dim1, indices, output, output_dim0, output_dim1) -int UnsortedSegmentSum_int(const int *input, int unit_num, int input_dim1, const int *indices, int *output, - int output_dim0, int output_dim1); -int UnsortedSegmentSum_float(const float *input, int unit_num, int input_dim1, const int *indices, float *output, - int output_dim0, int output_dim1); +#define UnsortedSegmentSum(type, type1, input, unit_num, input_dim1, indices, output, output_dim0, output_dim1) \ + UnsortedSegmentSum_##type##_##type1(input, unit_num, input_dim1, indices, output, output_dim0, output_dim1) +int UnsortedSegmentSum_int_int(const int *input, int unit_num, int input_dim1, const int *indices, int *output, + int output_dim0, int output_dim1); +int UnsortedSegmentSum_float_int(const float *input, int unit_num, int input_dim1, const int *indices, float *output, + int output_dim0, int output_dim1); +int UnsortedSegmentSum_int_int64_t(const int *input, int unit_num, int input_dim1, const int64_t *indices, int *output, + int output_dim0, int output_dim1); +int UnsortedSegmentSum_float_int64_t(const float *input, int unit_num, int input_dim1, const int64_t *indices, + float *output, int output_dim0, int output_dim1); #ifdef __cplusplus } #endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc index 950fc045d7b..18267c9a895 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc @@ -53,7 +53,7 @@ bool UnsortedSegmentSumCPUKernel::Launch(const std::vector & const std::vector &outputs) { bool ret{true}; void *input_addr = inputs[0]->addr; - const int *indices_addr = reinterpret_cast(inputs[1]->addr); + void *indices_addr = inputs[1]->addr; void *output_addr = outputs[0]->addr; auto ret1 = memset_s(output_addr, outputs[0]->size, 0, outputs[0]->size); if (ret1 != EOK) { @@ -62,17 +62,21 @@ bool UnsortedSegmentSumCPUKernel::Launch(const std::vector & } if (dtype_ == kNumberTypeInt32 && segment_ids_dtype_ == kNumberTypeInt32) { - ret1 = UnsortedSegmentSum(int, static_cast(input_addr), unit_num_, input_dim1_, indices_addr, - static_cast(output_addr), output_dim0_, output_dim1_); + ret1 = UnsortedSegmentSum(int, int, static_cast(input_addr), unit_num_, input_dim1_, + static_cast(indices_addr), static_cast(output_addr), output_dim0_, + output_dim1_); } else if (dtype_ == kNumberTypeFloat32 && segment_ids_dtype_ == kNumberTypeInt32) { - ret1 = UnsortedSegmentSum(float, static_cast(input_addr), unit_num_, input_dim1_, indices_addr, - static_cast(output_addr), output_dim0_, output_dim1_); + ret1 = UnsortedSegmentSum(float, int, static_cast(input_addr), unit_num_, input_dim1_, + static_cast(indices_addr), static_cast(output_addr), output_dim0_, + output_dim1_); } else if (dtype_ == kNumberTypeInt32 && segment_ids_dtype_ == kNumberTypeInt64) { - ret1 = UnsortedSegmentSum(int, static_cast(input_addr), unit_num_, input_dim1_, indices_addr, - static_cast(output_addr), output_dim0_, output_dim1_); + ret1 = UnsortedSegmentSum(int, int64_t, static_cast(input_addr), unit_num_, input_dim1_, + static_cast(indices_addr), static_cast(output_addr), output_dim0_, + output_dim1_); } else if (dtype_ == kNumberTypeFloat32 && segment_ids_dtype_ == kNumberTypeInt64) { - ret1 = UnsortedSegmentSum(float, static_cast(input_addr), unit_num_, input_dim1_, indices_addr, - static_cast(output_addr), output_dim0_, output_dim1_); + ret1 = UnsortedSegmentSum(float, int64_t, static_cast(input_addr), unit_num_, input_dim1_, + static_cast(indices_addr), static_cast(output_addr), + output_dim0_, output_dim1_); } else { MS_LOG(ERROR) << "Only support input_x int32 and float32, indices int32 and int64"; return false; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.cc index 80eea9d0a83..dedca0c8585 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.cc @@ -86,7 +86,7 @@ int UnsortedSegmentSumCPUKernel::Execute(int task_id) { int *indices = reinterpret_cast(indices_tensor->data_c()); float *output = reinterpret_cast(output_tensor->MutableData()); std::fill(output, output + output_tensor->ElementsNum(), 0.f); - ret = UnsortedSegmentSum(float, input, unit_num_, input_dim1_, indices, output, output_dim0_, output_dim1_); + ret = UnsortedSegmentSum(float, int, input, unit_num_, input_dim1_, indices, output, output_dim0_, output_dim1_); if (ret != RET_OK) { MS_LOG(ERROR) << "StridedSliceGrad error error_code[" << ret << "]"; return RET_ERROR;