fancy index: CPU unsorted_segment_sum change (add int64 indices support alongside int32)

This commit is contained in:
zhangzhewei 2021-06-17 16:00:20 +08:00
parent f5f2679757
commit e927c193c5
4 changed files with 48 additions and 36 deletions

View File

@ -16,25 +16,27 @@
#include "nnacl/base/unsorted_segment_sum_base.h"
#include "nnacl/errorcode.h"
#define UNSORTEDSEGMENTSUM(type) \
int UnsortedSegmentSum_##type(const type *input, int unit_num, int input_dim1, const int *indices, type *output, \
int output_dim0, int output_dim1) { \
if (input_dim1 == 0) { \
return NNACL_ERR; \
} \
for (int i = 0; i < unit_num; ++i) { \
int j = i / input_dim1; \
int k = i % input_dim1; \
\
int index = indices[j]; \
if (index < 0 || index >= output_dim0) { \
continue; \
} \
int output_index = index * output_dim1 + k; \
output[output_index] += input[i]; \
} \
return NNACL_OK; \
/*
 * UNSORTEDSEGMENTSUM(type, type1) defines the function
 *   int UnsortedSegmentSum_<type>_<type1>(input, unit_num, input_dim1, indices,
 *                                         output, output_dim0, output_dim1)
 * where `type` is the element type of input/output and `type1` is the element type of indices.
 *
 * The function walks `input` as unit_num flat elements. Element i is treated as row j = i / input_dim1,
 * column k = i % input_dim1, and is added to output[indices[j] * output_dim1 + k].
 * Rows whose index falls outside [0, output_dim0) are skipped.
 * `output` is only accumulated into (+=); it is not cleared here.
 *
 * Returns NNACL_ERR when input_dim1 is 0 (it is used as a divisor), NNACL_OK otherwise.
 *
 * The output offset is computed in int64_t: index is already known to be in [0, output_dim0), so
 * widening before the multiply keeps index * output_dim1 from overflowing when type1 is int.
 */
#define UNSORTEDSEGMENTSUM(type, type1) \
int UnsortedSegmentSum_##type##_##type1(const type *input, int unit_num, int input_dim1, const type1 *indices, \
type *output, int output_dim0, int output_dim1) { \
if (input_dim1 == 0) { \
return NNACL_ERR; \
} \
for (int i = 0; i < unit_num; ++i) { \
int j = i / input_dim1; \
int k = i % input_dim1; \
\
type1 index = indices[j]; \
if (index < 0 || index >= output_dim0) { \
continue; /* out-of-range segment id: drop this element */ \
} \
int64_t output_index = (int64_t)index * output_dim1 + k; \
output[output_index] += input[i]; \
} \
return NNACL_OK; \
}
UNSORTEDSEGMENTSUM(int)
UNSORTEDSEGMENTSUM(float)
/* Instantiate the kernel for each (value type, index type) pair: */
/* int/float values combined with int/int64_t indices.            */
UNSORTEDSEGMENTSUM(int, int)
UNSORTEDSEGMENTSUM(float, int)
UNSORTEDSEGMENTSUM(int, int64_t)
UNSORTEDSEGMENTSUM(float, int64_t)

View File

@ -17,15 +17,21 @@
#ifndef MINDSPORE_NNACL_UNSORTED_SEGMENT_SUM_BASE_H_
#define MINDSPORE_NNACL_UNSORTED_SEGMENT_SUM_BASE_H_
#include "nnacl/op_base.h"
#ifdef __cplusplus
extern "C" {
#endif
#define UnsortedSegmentSum(type, input, unit_num, input_dim1, indices, output, output_dim0, output_dim1) \
UnsortedSegmentSum_##type(input, unit_num, input_dim1, indices, output, output_dim0, output_dim1)
int UnsortedSegmentSum_int(const int *input, int unit_num, int input_dim1, const int *indices, int *output,
int output_dim0, int output_dim1);
int UnsortedSegmentSum_float(const float *input, int unit_num, int input_dim1, const int *indices, float *output,
int output_dim0, int output_dim1);
/* Dispatch helper: pastes `type` (value type) and `type1` (index type) into the function name, */
/* e.g. UnsortedSegmentSum(float, int64_t, ...) expands to UnsortedSegmentSum_float_int64_t(...). */
#define UnsortedSegmentSum(type, type1, input, unit_num, input_dim1, indices, output, output_dim0, output_dim1) \
UnsortedSegmentSum_##type##_##type1(input, unit_num, input_dim1, indices, output, output_dim0, output_dim1)
/* Prototypes for the UnsortedSegmentSum_<value type>_<index type> variants. */
/* Each takes: input (unit_num elements), input_dim1, indices, output, output_dim0, output_dim1, */
/* and returns an int status code. */

/* int values, int indices. */
int UnsortedSegmentSum_int_int(const int *input, int unit_num, int input_dim1, const int *indices, int *output,
int output_dim0, int output_dim1);
/* float values, int indices. */
int UnsortedSegmentSum_float_int(const float *input, int unit_num, int input_dim1, const int *indices, float *output,
int output_dim0, int output_dim1);
/* int values, int64_t indices. */
int UnsortedSegmentSum_int_int64_t(const int *input, int unit_num, int input_dim1, const int64_t *indices, int *output,
int output_dim0, int output_dim1);
/* float values, int64_t indices. */
int UnsortedSegmentSum_float_int64_t(const float *input, int unit_num, int input_dim1, const int64_t *indices,
float *output, int output_dim0, int output_dim1);
#ifdef __cplusplus
}
#endif

View File

@ -53,7 +53,7 @@ bool UnsortedSegmentSumCPUKernel::Launch(const std::vector<kernel::AddressPtr> &
const std::vector<kernel::AddressPtr> &outputs) {
bool ret{true};
void *input_addr = inputs[0]->addr;
const int *indices_addr = reinterpret_cast<const int *>(inputs[1]->addr);
void *indices_addr = inputs[1]->addr;
void *output_addr = outputs[0]->addr;
auto ret1 = memset_s(output_addr, outputs[0]->size, 0, outputs[0]->size);
if (ret1 != EOK) {
@ -62,17 +62,21 @@ bool UnsortedSegmentSumCPUKernel::Launch(const std::vector<kernel::AddressPtr> &
}
if (dtype_ == kNumberTypeInt32 && segment_ids_dtype_ == kNumberTypeInt32) {
ret1 = UnsortedSegmentSum(int, static_cast<const int *>(input_addr), unit_num_, input_dim1_, indices_addr,
static_cast<int *>(output_addr), output_dim0_, output_dim1_);
ret1 = UnsortedSegmentSum(int, int, static_cast<const int *>(input_addr), unit_num_, input_dim1_,
static_cast<const int *>(indices_addr), static_cast<int *>(output_addr), output_dim0_,
output_dim1_);
} else if (dtype_ == kNumberTypeFloat32 && segment_ids_dtype_ == kNumberTypeInt32) {
ret1 = UnsortedSegmentSum(float, static_cast<const float *>(input_addr), unit_num_, input_dim1_, indices_addr,
static_cast<float *>(output_addr), output_dim0_, output_dim1_);
ret1 = UnsortedSegmentSum(float, int, static_cast<const float *>(input_addr), unit_num_, input_dim1_,
static_cast<const int *>(indices_addr), static_cast<float *>(output_addr), output_dim0_,
output_dim1_);
} else if (dtype_ == kNumberTypeInt32 && segment_ids_dtype_ == kNumberTypeInt64) {
ret1 = UnsortedSegmentSum(int, static_cast<const int *>(input_addr), unit_num_, input_dim1_, indices_addr,
static_cast<int *>(output_addr), output_dim0_, output_dim1_);
ret1 = UnsortedSegmentSum(int, int64_t, static_cast<const int *>(input_addr), unit_num_, input_dim1_,
static_cast<const int64_t *>(indices_addr), static_cast<int *>(output_addr), output_dim0_,
output_dim1_);
} else if (dtype_ == kNumberTypeFloat32 && segment_ids_dtype_ == kNumberTypeInt64) {
ret1 = UnsortedSegmentSum(float, static_cast<const float *>(input_addr), unit_num_, input_dim1_, indices_addr,
static_cast<float *>(output_addr), output_dim0_, output_dim1_);
ret1 = UnsortedSegmentSum(float, int64_t, static_cast<const float *>(input_addr), unit_num_, input_dim1_,
static_cast<const int64_t *>(indices_addr), static_cast<float *>(output_addr),
output_dim0_, output_dim1_);
} else {
MS_LOG(ERROR) << "Only support input_x int32 and float32, indices int32 and int64";
return false;

View File

@ -86,7 +86,7 @@ int UnsortedSegmentSumCPUKernel::Execute(int task_id) {
int *indices = reinterpret_cast<int *>(indices_tensor->data_c());
float *output = reinterpret_cast<float *>(output_tensor->MutableData());
std::fill(output, output + output_tensor->ElementsNum(), 0.f);
ret = UnsortedSegmentSum(float, input, unit_num_, input_dim1_, indices, output, output_dim0_, output_dim1_);
ret = UnsortedSegmentSum(float, int, input, unit_num_, input_dim1_, indices, output, output_dim0_, output_dim1_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "StridedSliceGrad error error_code[" << ret << "]";
return RET_ERROR;