fancy index cpu unsorted_segment_sum change
This commit is contained in:
parent
f5f2679757
commit
e927c193c5
|
@ -16,25 +16,27 @@
|
|||
#include "nnacl/base/unsorted_segment_sum_base.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
#define UNSORTEDSEGMENTSUM(type) \
|
||||
int UnsortedSegmentSum_##type(const type *input, int unit_num, int input_dim1, const int *indices, type *output, \
|
||||
int output_dim0, int output_dim1) { \
|
||||
if (input_dim1 == 0) { \
|
||||
return NNACL_ERR; \
|
||||
} \
|
||||
for (int i = 0; i < unit_num; ++i) { \
|
||||
int j = i / input_dim1; \
|
||||
int k = i % input_dim1; \
|
||||
\
|
||||
int index = indices[j]; \
|
||||
if (index < 0 || index >= output_dim0) { \
|
||||
continue; \
|
||||
} \
|
||||
int output_index = index * output_dim1 + k; \
|
||||
output[output_index] += input[i]; \
|
||||
} \
|
||||
return NNACL_OK; \
|
||||
#define UNSORTEDSEGMENTSUM(type, type1) \
|
||||
int UnsortedSegmentSum_##type##_##type1(const type *input, int unit_num, int input_dim1, const type1 *indices, \
|
||||
type *output, int output_dim0, int output_dim1) { \
|
||||
if (input_dim1 == 0) { \
|
||||
return NNACL_ERR; \
|
||||
} \
|
||||
for (int i = 0; i < unit_num; ++i) { \
|
||||
int j = i / input_dim1; \
|
||||
int k = i % input_dim1; \
|
||||
\
|
||||
type1 index = indices[j]; \
|
||||
if (index < 0 || index >= output_dim0) { \
|
||||
continue; \
|
||||
} \
|
||||
type1 output_index = index * output_dim1 + k; \
|
||||
output[output_index] += input[i]; \
|
||||
} \
|
||||
return NNACL_OK; \
|
||||
}
|
||||
|
||||
UNSORTEDSEGMENTSUM(int)
|
||||
UNSORTEDSEGMENTSUM(float)
|
||||
UNSORTEDSEGMENTSUM(int, int)
|
||||
UNSORTEDSEGMENTSUM(float, int)
|
||||
UNSORTEDSEGMENTSUM(int, int64_t)
|
||||
UNSORTEDSEGMENTSUM(float, int64_t)
|
||||
|
|
|
@ -17,15 +17,21 @@
|
|||
#ifndef MINDSPORE_NNACL_UNSORTED_SEGMENT_SUM_BASE_H_
|
||||
#define MINDSPORE_NNACL_UNSORTED_SEGMENT_SUM_BASE_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
#define UnsortedSegmentSum(type, input, unit_num, input_dim1, indices, output, output_dim0, output_dim1) \
|
||||
UnsortedSegmentSum_##type(input, unit_num, input_dim1, indices, output, output_dim0, output_dim1)
|
||||
int UnsortedSegmentSum_int(const int *input, int unit_num, int input_dim1, const int *indices, int *output,
|
||||
int output_dim0, int output_dim1);
|
||||
int UnsortedSegmentSum_float(const float *input, int unit_num, int input_dim1, const int *indices, float *output,
|
||||
int output_dim0, int output_dim1);
|
||||
#define UnsortedSegmentSum(type, type1, input, unit_num, input_dim1, indices, output, output_dim0, output_dim1) \
|
||||
UnsortedSegmentSum_##type##_##type1(input, unit_num, input_dim1, indices, output, output_dim0, output_dim1)
|
||||
int UnsortedSegmentSum_int_int(const int *input, int unit_num, int input_dim1, const int *indices, int *output,
|
||||
int output_dim0, int output_dim1);
|
||||
int UnsortedSegmentSum_float_int(const float *input, int unit_num, int input_dim1, const int *indices, float *output,
|
||||
int output_dim0, int output_dim1);
|
||||
int UnsortedSegmentSum_int_int64_t(const int *input, int unit_num, int input_dim1, const int64_t *indices, int *output,
|
||||
int output_dim0, int output_dim1);
|
||||
int UnsortedSegmentSum_float_int64_t(const float *input, int unit_num, int input_dim1, const int64_t *indices,
|
||||
float *output, int output_dim0, int output_dim1);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -53,7 +53,7 @@ bool UnsortedSegmentSumCPUKernel::Launch(const std::vector<kernel::AddressPtr> &
|
|||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool ret{true};
|
||||
void *input_addr = inputs[0]->addr;
|
||||
const int *indices_addr = reinterpret_cast<const int *>(inputs[1]->addr);
|
||||
void *indices_addr = inputs[1]->addr;
|
||||
void *output_addr = outputs[0]->addr;
|
||||
auto ret1 = memset_s(output_addr, outputs[0]->size, 0, outputs[0]->size);
|
||||
if (ret1 != EOK) {
|
||||
|
@ -62,17 +62,21 @@ bool UnsortedSegmentSumCPUKernel::Launch(const std::vector<kernel::AddressPtr> &
|
|||
}
|
||||
|
||||
if (dtype_ == kNumberTypeInt32 && segment_ids_dtype_ == kNumberTypeInt32) {
|
||||
ret1 = UnsortedSegmentSum(int, static_cast<const int *>(input_addr), unit_num_, input_dim1_, indices_addr,
|
||||
static_cast<int *>(output_addr), output_dim0_, output_dim1_);
|
||||
ret1 = UnsortedSegmentSum(int, int, static_cast<const int *>(input_addr), unit_num_, input_dim1_,
|
||||
static_cast<const int *>(indices_addr), static_cast<int *>(output_addr), output_dim0_,
|
||||
output_dim1_);
|
||||
} else if (dtype_ == kNumberTypeFloat32 && segment_ids_dtype_ == kNumberTypeInt32) {
|
||||
ret1 = UnsortedSegmentSum(float, static_cast<const float *>(input_addr), unit_num_, input_dim1_, indices_addr,
|
||||
static_cast<float *>(output_addr), output_dim0_, output_dim1_);
|
||||
ret1 = UnsortedSegmentSum(float, int, static_cast<const float *>(input_addr), unit_num_, input_dim1_,
|
||||
static_cast<const int *>(indices_addr), static_cast<float *>(output_addr), output_dim0_,
|
||||
output_dim1_);
|
||||
} else if (dtype_ == kNumberTypeInt32 && segment_ids_dtype_ == kNumberTypeInt64) {
|
||||
ret1 = UnsortedSegmentSum(int, static_cast<const int *>(input_addr), unit_num_, input_dim1_, indices_addr,
|
||||
static_cast<int *>(output_addr), output_dim0_, output_dim1_);
|
||||
ret1 = UnsortedSegmentSum(int, int64_t, static_cast<const int *>(input_addr), unit_num_, input_dim1_,
|
||||
static_cast<const int64_t *>(indices_addr), static_cast<int *>(output_addr), output_dim0_,
|
||||
output_dim1_);
|
||||
} else if (dtype_ == kNumberTypeFloat32 && segment_ids_dtype_ == kNumberTypeInt64) {
|
||||
ret1 = UnsortedSegmentSum(float, static_cast<const float *>(input_addr), unit_num_, input_dim1_, indices_addr,
|
||||
static_cast<float *>(output_addr), output_dim0_, output_dim1_);
|
||||
ret1 = UnsortedSegmentSum(float, int64_t, static_cast<const float *>(input_addr), unit_num_, input_dim1_,
|
||||
static_cast<const int64_t *>(indices_addr), static_cast<float *>(output_addr),
|
||||
output_dim0_, output_dim1_);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Only support input_x int32 and float32, indices int32 and int64";
|
||||
return false;
|
||||
|
|
|
@ -86,7 +86,7 @@ int UnsortedSegmentSumCPUKernel::Execute(int task_id) {
|
|||
int *indices = reinterpret_cast<int *>(indices_tensor->data_c());
|
||||
float *output = reinterpret_cast<float *>(output_tensor->MutableData());
|
||||
std::fill(output, output + output_tensor->ElementsNum(), 0.f);
|
||||
ret = UnsortedSegmentSum(float, input, unit_num_, input_dim1_, indices, output, output_dim0_, output_dim1_);
|
||||
ret = UnsortedSegmentSum(float, int, input, unit_num_, input_dim1_, indices, output, output_dim0_, output_dim1_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "StridedSliceGrad error error_code[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
|
|
Loading…
Reference in New Issue