optimize arithmetic operators
This commit is contained in:
parent
2451a94125
commit
ee72df2cbc
|
@ -32,9 +32,9 @@ int ElementAddRelu(const float *in0, const float *in1, float *out, int size);
|
||||||
int ElementAddRelu6(const float *in0, const float *in1, float *out, int size);
|
int ElementAddRelu6(const float *in0, const float *in1, float *out, int size);
|
||||||
int ElementAddInt(const int *in0, const int *in1, int *out, int size);
|
int ElementAddInt(const int *in0, const int *in1, int *out, int size);
|
||||||
int ElementOptAdd(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
int ElementOptAdd(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||||
int ElementOptAddInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
|
||||||
int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||||
int ElementOptAddRelu6(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
int ElementOptAddRelu6(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||||
|
int ElementOptAddInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
||||||
int BroadcastAdd(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size,
|
int BroadcastAdd(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size,
|
||||||
ArithmeticParameter *param);
|
ArithmeticParameter *param);
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,22 @@ int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output,
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param) {
|
||||||
|
int i = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[0] == input1[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[i] == input1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||||
for (int i = 0; i < element_size; i++) {
|
for (int i = 0; i < element_size; i++) {
|
||||||
output[i] = input0[i] == input1[i];
|
output[i] = input0[i] == input1[i];
|
||||||
|
@ -33,6 +49,22 @@ int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *out
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param) {
|
||||||
|
int i = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[0] == input1[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[i] == input1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
// not equal
|
// not equal
|
||||||
int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
||||||
for (int i = 0; i < element_size; i++) {
|
for (int i = 0; i < element_size; i++) {
|
||||||
|
@ -41,6 +73,23 @@ int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *outpu
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// not equal
|
||||||
|
int ElementOptNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param) {
|
||||||
|
int i = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[0] != input1[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[i] != input1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||||
for (int i = 0; i < element_size; i++) {
|
for (int i = 0; i < element_size; i++) {
|
||||||
output[i] = input0[i] != input1[i];
|
output[i] = input0[i] != input1[i];
|
||||||
|
@ -48,6 +97,22 @@ int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param) {
|
||||||
|
int i = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[0] != input1[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[i] != input1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
// less
|
// less
|
||||||
int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
||||||
for (int i = 0; i < element_size; i++) {
|
for (int i = 0; i < element_size; i++) {
|
||||||
|
@ -56,6 +121,22 @@ int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, i
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param) {
|
||||||
|
int i = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[0] < input1[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[i] < input1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||||
for (int i = 0; i < element_size; i++) {
|
for (int i = 0; i < element_size; i++) {
|
||||||
output[i] = input0[i] < input1[i];
|
output[i] = input0[i] < input1[i];
|
||||||
|
@ -63,6 +144,22 @@ int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *outp
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param) {
|
||||||
|
int i = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[0] < input1[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[i] < input1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
// less equal
|
// less equal
|
||||||
int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
||||||
for (int i = 0; i < element_size; i++) {
|
for (int i = 0; i < element_size; i++) {
|
||||||
|
@ -71,6 +168,22 @@ int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *outp
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param) {
|
||||||
|
int i = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[0] <= input1[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[i] <= input1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||||
for (int i = 0; i < element_size; i++) {
|
for (int i = 0; i < element_size; i++) {
|
||||||
output[i] = input0[i] <= input1[i];
|
output[i] = input0[i] <= input1[i];
|
||||||
|
@ -78,6 +191,22 @@ int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param) {
|
||||||
|
int i = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[0] <= input1[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[i] <= input1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
// greater
|
// greater
|
||||||
int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
||||||
for (int i = 0; i < element_size; i++) {
|
for (int i = 0; i < element_size; i++) {
|
||||||
|
@ -86,6 +215,22 @@ int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param) {
|
||||||
|
int i = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[0] > input1[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[i] > input1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||||
for (int i = 0; i < element_size; i++) {
|
for (int i = 0; i < element_size; i++) {
|
||||||
output[i] = input0[i] > input1[i];
|
output[i] = input0[i] > input1[i];
|
||||||
|
@ -93,6 +238,21 @@ int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *o
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param) {
|
||||||
|
int i = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[0] > input1[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[i] > input1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
// greater equal
|
// greater equal
|
||||||
int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
||||||
for (int i = 0; i < element_size; i++) {
|
for (int i = 0; i < element_size; i++) {
|
||||||
|
@ -101,9 +261,41 @@ int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *o
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param) {
|
||||||
|
int i = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[0] >= input1[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[i] >= input1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||||
for (int i = 0; i < element_size; i++) {
|
for (int i = 0; i < element_size; i++) {
|
||||||
output[i] = input0[i] >= input1[i];
|
output[i] = input0[i] >= input1[i];
|
||||||
}
|
}
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param) {
|
||||||
|
int i = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[0] >= input1[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; i < element_size; i++) {
|
||||||
|
output[i] = input0[i] >= input1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
|
@ -21,28 +21,53 @@
|
||||||
#include <arm_neon.h>
|
#include <arm_neon.h>
|
||||||
#endif
|
#endif
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
|
#include "nnacl/base/arithmetic_base.h"
|
||||||
#include "nnacl/errorcode.h"
|
#include "nnacl/errorcode.h"
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||||
|
int ElementOptEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param);
|
||||||
int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||||
|
int ElementOptEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param);
|
||||||
|
|
||||||
int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||||
|
int ElementOptNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param);
|
||||||
int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||||
|
int ElementOptNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param);
|
||||||
|
|
||||||
int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||||
|
int ElementOptLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param);
|
||||||
int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||||
|
int ElementOptLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param);
|
||||||
|
|
||||||
int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||||
|
int ElementOptLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param);
|
||||||
int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||||
|
int ElementOptLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param);
|
||||||
|
|
||||||
int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||||
|
int ElementOptGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param);
|
||||||
int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||||
|
int ElementOptGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param);
|
||||||
|
|
||||||
int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||||
|
int ElementOptGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param);
|
||||||
int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||||
|
int ElementOptGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param);
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -26,6 +26,21 @@ int ElementFloorMod(const float *in0, const float *in1, float *out, int size) {
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptFloorMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
|
||||||
|
int i = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; i < size; i++) {
|
||||||
|
out[i] = in0[0] - floorf(in0[0] / in1[i]) * in1[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; i < size; i++) {
|
||||||
|
out[i] = in0[i] - floorf(in0[i] / in1[0]) * in1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementFloorModInt(const int *in0, const int *in1, int *out, int size) {
|
int ElementFloorModInt(const int *in0, const int *in1, int *out, int size) {
|
||||||
for (int i = 0; i < size; i++) {
|
for (int i = 0; i < size; i++) {
|
||||||
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
|
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
|
||||||
|
@ -35,6 +50,25 @@ int ElementFloorModInt(const int *in0, const int *in1, int *out, int size) {
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptFloorModInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) {
|
||||||
|
int i = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; i < size; i++) {
|
||||||
|
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
|
||||||
|
int remainder = in0[0] - (in0[0] / in1[i]) * in1[i];
|
||||||
|
out[i] = (remainder != 0) && ((in0[0] > 0) != (in1[i] > 0)) ? remainder + in1[i] : remainder;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
NNACL_CHECK_ZERO_RETURN_ERR(in1[0]);
|
||||||
|
for (; i < size; i++) {
|
||||||
|
int remainder = in0[i] - (in0[i] / in1[0]) * in1[0];
|
||||||
|
out[i] = (remainder != 0) && ((in0[i] > 0) != (in1[0] > 0)) ? remainder + in1[0] : remainder;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementMod(const float *in0, const float *in1, float *out, int size) {
|
int ElementMod(const float *in0, const float *in1, float *out, int size) {
|
||||||
for (int i = 0; i < size; i++) {
|
for (int i = 0; i < size; i++) {
|
||||||
out[i] = fmodf(in0[i], in1[i]);
|
out[i] = fmodf(in0[i], in1[i]);
|
||||||
|
@ -42,23 +76,24 @@ int ElementMod(const float *in0, const float *in1, float *out, int size) {
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementModInt(const int *in0, const int *in1, int *out, int size) {
|
int ElementOptMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
|
||||||
for (int i = 0; i < size; i++) {
|
int index = 0;
|
||||||
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
|
if (param->in_elements_num0_ == 1) {
|
||||||
out[i] = in0[i] % in1[i];
|
for (; index < size; index++) {
|
||||||
|
out[index] = fmodf(in0[0], in1[index]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = fmodf(in0[index], in1[0]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementOptMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
|
int ElementModInt(const int *in0, const int *in1, int *out, int size) {
|
||||||
if (param->in_elements_num0_ == 1) {
|
for (int i = 0; i < size; i++) {
|
||||||
for (int index = 0; index < size; index++) {
|
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
|
||||||
out[index] = fmodf(in0[0], in1[index]);
|
out[i] = in0[i] % in1[i];
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int index = 0; index < size; index++) {
|
|
||||||
out[index] = fmodf(in0[index], in1[0]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
@ -85,6 +120,21 @@ int ElementFloorDiv(const float *in0, const float *in1, float *out, int size) {
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptFloorDiv(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
|
||||||
|
int i = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; i < size; i++) {
|
||||||
|
out[i] = floorf(in0[0] / in1[i]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; i < size; i++) {
|
||||||
|
out[i] = floorf(in0[i] / in1[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementFloorDivInt(const int *in0, const int *in1, int *out, int size) {
|
int ElementFloorDivInt(const int *in0, const int *in1, int *out, int size) {
|
||||||
for (int i = 0; i < size; i++) {
|
for (int i = 0; i < size; i++) {
|
||||||
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
|
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
|
||||||
|
@ -93,6 +143,23 @@ int ElementFloorDivInt(const int *in0, const int *in1, int *out, int size) {
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptFloorDivInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) {
|
||||||
|
int i = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; i < size; i++) {
|
||||||
|
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
|
||||||
|
out[i] = in0[0] / in1[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
NNACL_CHECK_ZERO_RETURN_ERR(in1[0]);
|
||||||
|
for (; i < size; i++) {
|
||||||
|
out[i] = in0[i] / in1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size) {
|
int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size) {
|
||||||
int index = 0;
|
int index = 0;
|
||||||
#ifdef ENABLE_NEON
|
#ifdef ENABLE_NEON
|
||||||
|
@ -113,6 +180,21 @@ int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size)
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptLogicalAnd(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
|
||||||
|
int index = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = (float)((bool)(in0[0]) & (bool)(in1[index]));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = (float)((bool)(in0[index]) & (bool)(in1[0]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementLogicalAndInt(const int *in0, const int *in1, int *out, int size) {
|
int ElementLogicalAndInt(const int *in0, const int *in1, int *out, int size) {
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for (; index < size; index++) {
|
for (; index < size; index++) {
|
||||||
|
@ -121,11 +203,42 @@ int ElementLogicalAndInt(const int *in0, const int *in1, int *out, int size) {
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptLogicalAndInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) {
|
||||||
|
int index = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = (int)((unsigned int)(in0[0]) & (unsigned int)(in1[index]));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = (int)((unsigned int)(in0[index]) & (unsigned int)(in1[0]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size) {
|
int ElementLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size) {
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for (; index < size; index++) {
|
for (; index < size; index++) {
|
||||||
out[index] = (bool)((unsigned int)(in0[index]) & (unsigned int)(in1[index]));
|
out[index] = (bool)((unsigned int)(in0[index]) & (unsigned int)(in1[index]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ElementOptLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size, const ArithmeticParameter *param) {
|
||||||
|
int index = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = (bool)((unsigned int)(in0[0]) & (unsigned int)(in1[index]));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = (bool)((unsigned int)(in0[index]) & (unsigned int)(in1[0]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,6 +262,21 @@ int ElementLogicalOr(const float *in0, const float *in1, float *out, int size) {
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptLogicalOr(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
|
||||||
|
int index = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = (float)((bool)(in0[0]) | (bool)(in1[index]));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = (float)((bool)(in0[index]) | (bool)(in1[0]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size) {
|
int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size) {
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for (; index < size; index++) {
|
for (; index < size; index++) {
|
||||||
|
@ -157,6 +285,21 @@ int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size)
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size, const ArithmeticParameter *param) {
|
||||||
|
int index = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = (bool)(in0[0] | in1[index]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = (bool)(in0[index] | in1[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementMaximum(const float *in0, const float *in1, float *out, int size) {
|
int ElementMaximum(const float *in0, const float *in1, float *out, int size) {
|
||||||
int index = 0;
|
int index = 0;
|
||||||
#ifdef ENABLE_NEON
|
#ifdef ENABLE_NEON
|
||||||
|
@ -173,6 +316,22 @@ int ElementMaximum(const float *in0, const float *in1, float *out, int size) {
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptMaximum(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
|
||||||
|
int index = 0;
|
||||||
|
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = in0[0] > in1[index] ? in0[0] : in1[index];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = in0[index] > in1[0] ? in0[index] : in1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementMaximumInt(const int *in0, const int *in1, int *out, int size) {
|
int ElementMaximumInt(const int *in0, const int *in1, int *out, int size) {
|
||||||
int index = 0;
|
int index = 0;
|
||||||
#ifdef ENABLE_NEON
|
#ifdef ENABLE_NEON
|
||||||
|
@ -189,22 +348,53 @@ int ElementMaximumInt(const int *in0, const int *in1, int *out, int size) {
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementMinimumInt(const int *input0, const int *input1, int *output, const int element_size) {
|
int ElementOptMaximumInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) {
|
||||||
|
int index = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = in0[0] > in1[index] ? in0[0] : in1[index];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = in0[index] > in1[0] ? in0[index] : in1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ElementMinimumInt(const int *input0, const int *input1, int *output, int size) {
|
||||||
int index = 0;
|
int index = 0;
|
||||||
#ifdef ENABLE_NEON
|
#ifdef ENABLE_NEON
|
||||||
for (; index <= element_size - 4; index += C4NUM) {
|
for (; index <= size - 4; index += C4NUM) {
|
||||||
int32x4_t vin0 = vld1q_s32(input0 + index);
|
int32x4_t vin0 = vld1q_s32(input0 + index);
|
||||||
int32x4_t vin1 = vld1q_s32(input1 + index);
|
int32x4_t vin1 = vld1q_s32(input1 + index);
|
||||||
int32x4_t vout = vminq_s32(vin0, vin1);
|
int32x4_t vout = vminq_s32(vin0, vin1);
|
||||||
vst1q_s32(output + index, vout);
|
vst1q_s32(output + index, vout);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
for (; index < element_size; index++) {
|
for (; index < size; index++) {
|
||||||
output[index] = input0[index] > input1[index] ? input1[index] : input0[index];
|
output[index] = input0[index] > input1[index] ? input1[index] : input0[index];
|
||||||
}
|
}
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptMinimumInt(const int *input0, const int *input1, int *output, int size,
|
||||||
|
const ArithmeticParameter *param) {
|
||||||
|
int index = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
output[index] = input0[0] > input1[index] ? input1[index] : input0[0];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
output[index] = input0[index] > input1[0] ? input1[0] : input0[index];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementMinimum(const float *in0, const float *in1, float *out, int size) {
|
int ElementMinimum(const float *in0, const float *in1, float *out, int size) {
|
||||||
int index = 0;
|
int index = 0;
|
||||||
#ifdef ENABLE_NEON
|
#ifdef ENABLE_NEON
|
||||||
|
@ -221,6 +411,21 @@ int ElementMinimum(const float *in0, const float *in1, float *out, int size) {
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptMinimum(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
|
||||||
|
int index = 0;
|
||||||
|
if (param->in_elements_num0_ == 1) {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = in0[0] > in1[index] ? in1[index] : in0[0];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; index < size; index++) {
|
||||||
|
out[index] = in0[index] > in1[0] ? in1[0] : in0[index];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
#undef ACCURACY_DATA
|
#undef ACCURACY_DATA
|
||||||
|
|
||||||
void TileOneDimensionFp32(const float *inData, float *outData, int dim, size_t ndim, const int *inShape,
|
void TileOneDimensionFp32(const float *inData, float *outData, int dim, size_t ndim, const int *inShape,
|
||||||
|
|
|
@ -37,31 +37,44 @@ void TileDimensionsFp32(const float *data0, const float *data1, float *tile_data
|
||||||
ArithmeticParameter *param);
|
ArithmeticParameter *param);
|
||||||
/* logical and */
|
/* logical and */
|
||||||
int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size);
|
int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size);
|
||||||
|
int ElementOptLogicalAnd(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||||
int ElementLogicalAndInt(const int *in0, const int *in1, int *out, int size);
|
int ElementLogicalAndInt(const int *in0, const int *in1, int *out, int size);
|
||||||
|
int ElementOptLogicalAndInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
||||||
int ElementLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size);
|
int ElementLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size);
|
||||||
|
int ElementOptLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size, const ArithmeticParameter *param);
|
||||||
|
|
||||||
/* logical or */
|
/* logical or */
|
||||||
int ElementLogicalOr(const float *in0, const float *in1, float *out, int size);
|
int ElementLogicalOr(const float *in0, const float *in1, float *out, int size);
|
||||||
|
int ElementOptLogicalOr(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||||
int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size);
|
int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size);
|
||||||
|
int ElementOptLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size, const ArithmeticParameter *param);
|
||||||
|
|
||||||
/* max min */
|
/* max min */
|
||||||
int ElementMaximum(const float *in0, const float *in1, float *out, int size);
|
int ElementMaximum(const float *in0, const float *in1, float *out, int size);
|
||||||
|
int ElementOptMaximum(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||||
int ElementMinimum(const float *in0, const float *in1, float *out, int size);
|
int ElementMinimum(const float *in0, const float *in1, float *out, int size);
|
||||||
|
int ElementOptMinimum(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||||
int ElementMaximumInt(const int *in0, const int *in1, int *out, int size);
|
int ElementMaximumInt(const int *in0, const int *in1, int *out, int size);
|
||||||
int ElementMinimumInt(const int *input0, const int *input1, int *output, const int element_size);
|
int ElementOptMaximumInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
||||||
|
int ElementMinimumInt(const int *input0, const int *input1, int *output, int size);
|
||||||
|
int ElementOptMinimumInt(const int *input0, const int *input1, int *output, int size, const ArithmeticParameter *param);
|
||||||
|
|
||||||
/* floor div */
|
/* floor div */
|
||||||
int ElementFloorDiv(const float *in0, const float *in1, float *out, int size);
|
int ElementFloorDiv(const float *in0, const float *in1, float *out, int size);
|
||||||
|
int ElementOptFloorDiv(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||||
int ElementFloorDivInt(const int *in0, const int *in1, int *out, int size);
|
int ElementFloorDivInt(const int *in0, const int *in1, int *out, int size);
|
||||||
|
int ElementOptFloorDivInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
||||||
|
|
||||||
/* floor mod */
|
/* floor mod */
|
||||||
int ElementFloorMod(const float *in0, const float *in1, float *out, int size);
|
int ElementFloorMod(const float *in0, const float *in1, float *out, int size);
|
||||||
|
int ElementOptFloorMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||||
int ElementFloorModInt(const int *in0, const int *in1, int *out, int size);
|
int ElementFloorModInt(const int *in0, const int *in1, int *out, int size);
|
||||||
|
int ElementOptFloorModInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
||||||
|
|
||||||
/* mod */
|
/* mod */
|
||||||
int ElementMod(const float *in0, const float *in1, float *out, int size);
|
int ElementMod(const float *in0, const float *in1, float *out, int size);
|
||||||
int ElementModInt(const int *in0, const int *in1, int *out, int size);
|
|
||||||
int ElementOptMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
int ElementOptMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||||
|
int ElementModInt(const int *in0, const int *in1, int *out, int size);
|
||||||
int ElementOptModInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
int ElementOptModInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
|
|
@ -25,4 +25,9 @@ int ElementSquaredDifference(const float *in0, const float *in1, float *out, int
|
||||||
return ElementMul(out, out, out, size);
|
return ElementMul(out, out, out, size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementOptSquaredDifference(const float *in0, const float *in1, float *out, int size,
|
||||||
|
const ArithmeticParameter *param) {
|
||||||
|
ElementOptSub(in0, in1, out, size, param);
|
||||||
|
return ElementMul(out, out, out, size);
|
||||||
|
}
|
||||||
#endif // MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_
|
#endif // MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_
|
||||||
|
|
|
@ -29,7 +29,8 @@ extern "C" {
|
||||||
|
|
||||||
/* Element Squared Difference */
|
/* Element Squared Difference */
|
||||||
int ElementSquaredDifference(const float *in0, const float *in1, float *out, int size);
|
int ElementSquaredDifference(const float *in0, const float *in1, float *out, int size);
|
||||||
|
int ElementOptSquaredDifference(const float *in0, const float *in1, float *out, int size,
|
||||||
|
const ArithmeticParameter *param);
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -32,9 +32,9 @@ int ElementSubInt(const int *in0, const int *in1, int *out, int size);
|
||||||
int ElementSubRelu(const float *in0, const float *in1, float *out, int size);
|
int ElementSubRelu(const float *in0, const float *in1, float *out, int size);
|
||||||
int ElementSubRelu6(const float *in0, const float *in1, float *out, int size);
|
int ElementSubRelu6(const float *in0, const float *in1, float *out, int size);
|
||||||
int ElementOptSub(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
int ElementOptSub(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||||
|
int ElementOptSubInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
||||||
int ElementOptSubRelu(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
int ElementOptSubRelu(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||||
int ElementOptSubRelu6(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
int ElementOptSubRelu6(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||||
int ElementOptSubInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,36 +63,6 @@ int ArithmeticFP16CPUKernel::CheckDataType() {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ArithmeticFP16CPUKernel::IsScalarClac() { // 2 32 240 240, 1 1 1 1
|
|
||||||
if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && (arithmetic_opt_func_ != nullptr)) {
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool ArithmeticFP16CPUKernel::IsBatchScalarCalc() {
|
|
||||||
if (arithmetic_opt_func_ == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
size_t break_axis = 0;
|
|
||||||
for (size_t i = 0; i < param_->ndim_; i++) {
|
|
||||||
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
|
|
||||||
break_axis = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (break_axis < param_->ndim_) {
|
|
||||||
for (size_t i = break_axis; i < param_->ndim_; i++) {
|
|
||||||
if (param_->in_shape1_[i] != 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break_pos_ = break_axis;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void ArithmeticFP16CPUKernel::InitRunFunction(int primitive_type) {
|
void ArithmeticFP16CPUKernel::InitRunFunction(int primitive_type) {
|
||||||
ARITHMETIC_FUNC_INFO_FP16 fun_table[] = {
|
ARITHMETIC_FUNC_INFO_FP16 fun_table[] = {
|
||||||
{PrimitiveType_MulFusion, schema::ActivationType_RELU, ElementMulReluFp16, ElementOptMulReluFp16},
|
{PrimitiveType_MulFusion, schema::ActivationType_RELU, ElementMulReluFp16, ElementOptMulReluFp16},
|
||||||
|
@ -171,6 +141,7 @@ int ArithmeticFP16CPUKernel::Run() {
|
||||||
MS_LOG(ERROR) << "ArithmeticFP16CPUKernel check dataType failed.";
|
MS_LOG(ERROR) << "ArithmeticFP16CPUKernel check dataType failed.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!input0_broadcast_) {
|
if (!input0_broadcast_) {
|
||||||
input0_ptr_ = ConvertInputFp32toFp16(in_tensors_.at(0), static_cast<const lite::InnerContext *>(this->ms_context_));
|
input0_ptr_ = ConvertInputFp32toFp16(in_tensors_.at(0), static_cast<const lite::InnerContext *>(this->ms_context_));
|
||||||
}
|
}
|
||||||
|
@ -183,10 +154,16 @@ int ArithmeticFP16CPUKernel::Run() {
|
||||||
FreeFp16Buffer();
|
FreeFp16Buffer();
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
batch_a_ptr_ = static_cast<uint8_t *>(input0_ptr_);
|
||||||
|
batch_b_ptr_ = static_cast<uint8_t *>(input1_ptr_);
|
||||||
|
batch_c_ptr_ = static_cast<uint8_t *>(output_ptr_);
|
||||||
auto ret = ParallelLaunch(this->ms_context_, ArithmeticsRun, this, op_parameter_->thread_num_);
|
auto ret = ParallelLaunch(this->ms_context_, ArithmeticsRun, this, op_parameter_->thread_num_);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "ArithmeticsRun failed, ret : " << ret;
|
MS_LOG(ERROR) << "ArithmeticsRun failed, ret : " << ret;
|
||||||
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
|
if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
|
||||||
Float16ToFloat32(static_cast<float16_t *>(output_ptr_), reinterpret_cast<float *>(output_tensor->data()),
|
Float16ToFloat32(static_cast<float16_t *>(output_ptr_), reinterpret_cast<float *>(output_tensor->data()),
|
||||||
output_tensor->ElementsNum());
|
output_tensor->ElementsNum());
|
||||||
|
|
|
@ -40,8 +40,6 @@ class ArithmeticFP16CPUKernel : public ArithmeticCPUKernel {
|
||||||
~ArithmeticFP16CPUKernel() = default;
|
~ArithmeticFP16CPUKernel() = default;
|
||||||
int ReSize() override;
|
int ReSize() override;
|
||||||
int Run() override;
|
int Run() override;
|
||||||
bool IsBatchScalarCalc() override;
|
|
||||||
bool IsScalarClac() override;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void InitRunFunction(int primitive_type) override;
|
void InitRunFunction(int primitive_type) override;
|
||||||
|
|
|
@ -28,85 +28,110 @@ using mindspore::schema::PrimitiveType_LessEqual;
|
||||||
using mindspore::schema::PrimitiveType_NotEqual;
|
using mindspore::schema::PrimitiveType_NotEqual;
|
||||||
|
|
||||||
namespace mindspore::kernel {
|
namespace mindspore::kernel {
|
||||||
int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
|
void ArithmeticCompareCPUKernel::InitRunFunction(int primitive_type) {
|
||||||
int out_thread_stride) {
|
ARITHMETIC_COMEPARE_FUNC_INFO_FP32 fun_table[] = {
|
||||||
if (dim > break_pos_) {
|
{PrimitiveType_Equal, ElementEqualFp32, ElementEqualInt32, ElementOptEqualFp32, ElementOptEqualInt32},
|
||||||
if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) {
|
{PrimitiveType_NotEqual, ElementNotEqualFp32, ElementNotEqualInt32, ElementOptNotEqualFp32,
|
||||||
return func_int32_(reinterpret_cast<int *>(input0) + out_thread_stride,
|
ElementOptNotEqualInt32},
|
||||||
reinterpret_cast<int *>(input1) + out_thread_stride,
|
{PrimitiveType_Less, ElementLessFp32, ElementLessInt32, ElementOptLessFp32, ElementOptLessInt32},
|
||||||
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
|
{PrimitiveType_LessEqual, ElementLessEqualFp32, ElementLessEqualInt32, ElementOptLessEqualFp32,
|
||||||
|
ElementOptLessEqualInt32},
|
||||||
|
{PrimitiveType_Greater, ElementGreaterFp32, ElementGreaterInt32, ElementOptGreaterFp32, ElementOptGreaterInt32},
|
||||||
|
{PrimitiveType_GreaterEqual, ElementGreaterEqualFp32, ElementGreaterEqualInt32, ElementOptGreaterEqualFp32,
|
||||||
|
ElementOptGreaterEqualInt32}};
|
||||||
|
size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_COMEPARE_FUNC_INFO_FP32);
|
||||||
|
for (size_t i = 0; i < length; i++) {
|
||||||
|
if (fun_table[i].primitive_type_ == primitive_type) {
|
||||||
|
func_fp32_ = fun_table[i].func_;
|
||||||
|
func_int32_ = fun_table[i].int_func_;
|
||||||
|
opt_func_fp32_ = fun_table[i].opt_func_;
|
||||||
|
opt_func_int32_ = fun_table[i].opt_int_func_;
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
return func_fp32_(reinterpret_cast<float *>(input0) + out_thread_stride,
|
|
||||||
reinterpret_cast<float *>(input1) + out_thread_stride,
|
|
||||||
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
|
|
||||||
}
|
}
|
||||||
for (int i = 0; i < param_->out_shape_[dim]; ++i) {
|
}
|
||||||
int pos0_ = param_->in_shape0_[dim] == 1 ? 0 : i;
|
|
||||||
int pos1_ = param_->in_shape1_[dim] == 1 ? 0 : i;
|
int ArithmeticCompareCPUKernel::Execute(const void *input0, const void *input1, void *output, int size, bool is_opt) {
|
||||||
int error_code;
|
int ret = RET_OK;
|
||||||
if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) {
|
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
|
||||||
error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * param_->in_strides0_[dim],
|
if (is_opt) {
|
||||||
reinterpret_cast<int *>(input1) + pos1_ * param_->in_strides1_[dim],
|
CHECK_NULL_RETURN(opt_func_fp32_);
|
||||||
reinterpret_cast<uint8_t *>(output) + i * param_->out_strides_[dim], dim + 1, out_count,
|
ret = opt_func_fp32_(reinterpret_cast<const float *>(input0), reinterpret_cast<const float *>(input1),
|
||||||
out_thread_stride);
|
reinterpret_cast<uint8_t *>(output), size, param_);
|
||||||
} else {
|
} else {
|
||||||
error_code = BroadcastRun(reinterpret_cast<float *>(input0) + pos0_ * param_->in_strides0_[dim],
|
CHECK_NULL_RETURN(func_fp32_);
|
||||||
reinterpret_cast<float *>(input1) + pos1_ * param_->in_strides1_[dim],
|
ret = func_fp32_(reinterpret_cast<const float *>(input0), reinterpret_cast<const float *>(input1),
|
||||||
reinterpret_cast<uint8_t *>(output) + i * param_->out_strides_[dim], dim + 1, out_count,
|
reinterpret_cast<uint8_t *>(output), size);
|
||||||
out_thread_stride);
|
|
||||||
}
|
}
|
||||||
if (error_code != RET_OK) {
|
} else if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) {
|
||||||
return error_code;
|
if (is_opt) {
|
||||||
|
CHECK_NULL_RETURN(opt_func_int32_);
|
||||||
|
ret = opt_func_int32_(reinterpret_cast<const int *>(input0), reinterpret_cast<const int *>(input1),
|
||||||
|
reinterpret_cast<uint8_t *>(output), size, param_);
|
||||||
|
} else {
|
||||||
|
CHECK_NULL_RETURN(func_int32_);
|
||||||
|
ret = func_int32_(reinterpret_cast<const int *>(input0), reinterpret_cast<const int *>(input1),
|
||||||
|
reinterpret_cast<uint8_t *>(output), size);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Error Operator type " << kNumberTypeInt32;
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ArithmeticCompareCPUKernel::CalcArithmeticByBatch(int task_id) {
|
||||||
|
if (break_pos_ > ARITHMETIC_SUPPORT_DIMS_NUM || param_->out_strides_[break_pos_ - 1] == 0) {
|
||||||
|
MS_LOG(ERROR) << "param_->out_strides_[break_pos_ - 1] is 0 or break_pos_ is > 10";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
int batch_per_thread = UP_DIV(out_batch_, op_parameter_->thread_num_);
|
||||||
|
int start_batch = batch_per_thread * task_id;
|
||||||
|
int end_batch = MSMIN(start_batch + batch_per_thread, out_batch_);
|
||||||
|
int ret = RET_ERROR;
|
||||||
|
for (int i = start_batch; i < end_batch; i++) {
|
||||||
|
batch_a_ptr_ = static_cast<uint8_t *>(input0_ptr_) + a_offset_[i] * a_stride_size_ * data_type_len_;
|
||||||
|
batch_b_ptr_ = static_cast<uint8_t *>(input1_ptr_) + b_offset_[i] * b_stride_size_ * data_type_len_;
|
||||||
|
batch_c_ptr_ = static_cast<uint8_t *>(output_ptr_) + i * c_stride_size_ * sizeof(uint8_t);
|
||||||
|
if (batch_scalar_) {
|
||||||
|
ret = Execute(batch_a_ptr_, batch_b_ptr_, batch_c_ptr_, c_stride_size_, true);
|
||||||
|
} else {
|
||||||
|
ret = Execute(batch_a_ptr_, batch_b_ptr_, batch_c_ptr_, c_stride_size_, false);
|
||||||
|
}
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "failed to calculate.";
|
||||||
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
|
int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
|
||||||
auto element_num = out_tensors_[0]->ElementsNum();
|
if (split_by_batch_) {
|
||||||
|
return CalcArithmeticByBatch(task_id);
|
||||||
|
}
|
||||||
|
|
||||||
MS_ASSERT(op_parameter_->thread_num_ != 0);
|
int64_t element_num = out_tensors_[0]->ElementsNum();
|
||||||
|
auto ret = RET_ERROR;
|
||||||
int stride = UP_DIV(element_num, op_parameter_->thread_num_);
|
int stride = UP_DIV(element_num, op_parameter_->thread_num_);
|
||||||
int count = MSMIN(stride, element_num - stride * task_id);
|
int count = MSMIN(stride, element_num - stride * task_id);
|
||||||
if (count <= 0) {
|
if (count <= 0) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
CHECK_LESS_RETURN(ARITHMETIC_SUPPORT_DIMS_NUM, param_->ndim_);
|
||||||
if (func_fp32_ == nullptr) {
|
int in_offset = stride * task_id * data_type_len_;
|
||||||
MS_LOG(ERROR) << "func_fp32_ function is nullptr!";
|
int out_offset = stride * task_id * sizeof(uint8_t);
|
||||||
return RET_ERROR;
|
if (scalar_) {
|
||||||
}
|
if (param_->in_elements_num0_ == 1) {
|
||||||
|
ret = Execute(batch_a_ptr_, batch_b_ptr_ + in_offset, batch_c_ptr_ + out_offset, count, true);
|
||||||
int error_code;
|
|
||||||
if (param_->broadcasting_) { // need broadcast
|
|
||||||
stride = UP_DIV(outside_, op_parameter_->thread_num_);
|
|
||||||
int out_count = MSMIN(stride, outside_ - stride * task_id);
|
|
||||||
int out_thread_stride = stride * task_id;
|
|
||||||
if (out_count <= 0) {
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
|
|
||||||
error_code = BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
|
|
||||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data()), 0, out_count, out_thread_stride);
|
|
||||||
} else {
|
} else {
|
||||||
error_code = BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_),
|
ret = Execute(batch_a_ptr_ + in_offset, batch_b_ptr_, batch_c_ptr_ + out_offset, count, true);
|
||||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data()), 0, out_count, out_thread_stride);
|
|
||||||
}
|
|
||||||
} else { // no broadcast, neither is scalar, two same shape
|
|
||||||
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
|
|
||||||
error_code = func_fp32_(reinterpret_cast<float *>(input0_ptr_) + stride * task_id,
|
|
||||||
reinterpret_cast<float *>(input1_ptr_) + stride * task_id,
|
|
||||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data()) + stride * task_id, count);
|
|
||||||
} else {
|
|
||||||
error_code = func_int32_(reinterpret_cast<int *>(input0_ptr_) + stride * task_id,
|
|
||||||
reinterpret_cast<int *>(input1_ptr_) + stride * task_id,
|
|
||||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data()) + stride * task_id, count);
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
ret = Execute(batch_a_ptr_ + in_offset, batch_b_ptr_ + in_offset, batch_c_ptr_ + out_offset, count, false);
|
||||||
}
|
}
|
||||||
if (error_code != RET_OK) {
|
return ret;
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
return RET_OK;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Equal, LiteKernelCreator<ArithmeticCompareCPUKernel>)
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Equal, LiteKernelCreator<ArithmeticCompareCPUKernel>)
|
||||||
|
|
|
@ -21,53 +21,38 @@
|
||||||
#include "nnacl/fp32/arithmetic_compare_fp32.h"
|
#include "nnacl/fp32/arithmetic_compare_fp32.h"
|
||||||
|
|
||||||
namespace mindspore::kernel {
|
namespace mindspore::kernel {
|
||||||
typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size);
|
|
||||||
typedef int (*ArithmeticCompareIntFunc)(const int *input0, const int *input1, uint8_t *output, int element_size);
|
|
||||||
class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel {
|
class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel {
|
||||||
|
typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||||
|
typedef int (*ArithmeticCompareIntFunc)(const int *input0, const int *input1, uint8_t *output, int element_size);
|
||||||
|
typedef int (*ArithmeticOptCompareFp32Func)(const float *input0, const float *input1, uint8_t *output,
|
||||||
|
int element_size, const ArithmeticParameter *param);
|
||||||
|
typedef int (*ArithmeticOptCompareIntFunc)(const int *input0, const int *input1, uint8_t *output, int element_size,
|
||||||
|
const ArithmeticParameter *param);
|
||||||
|
typedef struct {
|
||||||
|
int primitive_type_;
|
||||||
|
ArithmeticCompareFp32Func func_;
|
||||||
|
ArithmeticCompareIntFunc int_func_;
|
||||||
|
ArithmeticOptCompareFp32Func opt_func_;
|
||||||
|
ArithmeticOptCompareIntFunc opt_int_func_;
|
||||||
|
} ARITHMETIC_COMEPARE_FUNC_INFO_FP32;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
|
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
|
||||||
: ArithmeticCPUKernel(parameter, inputs, outputs, ctx) {
|
: ArithmeticCPUKernel(parameter, inputs, outputs, ctx) {}
|
||||||
switch (parameter->type_) {
|
|
||||||
case PrimitiveType_Equal:
|
|
||||||
func_fp32_ = ElementEqualFp32;
|
|
||||||
func_int32_ = ElementEqualInt32;
|
|
||||||
break;
|
|
||||||
case PrimitiveType_NotEqual:
|
|
||||||
func_fp32_ = ElementNotEqualFp32;
|
|
||||||
func_int32_ = ElementNotEqualInt32;
|
|
||||||
break;
|
|
||||||
case PrimitiveType_Less:
|
|
||||||
func_fp32_ = ElementLessFp32;
|
|
||||||
func_int32_ = ElementLessInt32;
|
|
||||||
break;
|
|
||||||
case PrimitiveType_LessEqual:
|
|
||||||
func_fp32_ = ElementLessEqualFp32;
|
|
||||||
func_int32_ = ElementLessEqualInt32;
|
|
||||||
break;
|
|
||||||
case PrimitiveType_Greater:
|
|
||||||
func_fp32_ = ElementGreaterFp32;
|
|
||||||
func_int32_ = ElementGreaterInt32;
|
|
||||||
break;
|
|
||||||
case PrimitiveType_GreaterEqual:
|
|
||||||
func_fp32_ = ElementGreaterEqualFp32;
|
|
||||||
func_int32_ = ElementGreaterEqualInt32;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
MS_LOG(ERROR) << "Error Operator type " << parameter->type_;
|
|
||||||
func_fp32_ = nullptr;
|
|
||||||
func_int32_ = nullptr;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
~ArithmeticCompareCPUKernel() override = default;
|
~ArithmeticCompareCPUKernel() override = default;
|
||||||
|
|
||||||
int DoArithmetic(int task_id) override;
|
int DoArithmetic(int task_id) override;
|
||||||
int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride) override;
|
|
||||||
|
protected:
|
||||||
|
void InitRunFunction(int primitive_type) override;
|
||||||
|
int Execute(const void *input0, const void *input1, void *output, int size, bool is_opt) override;
|
||||||
|
int CalcArithmeticByBatch(int task_id) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ArithmeticCompareFp32Func func_fp32_ = nullptr;
|
ArithmeticCompareFp32Func func_fp32_ = nullptr;
|
||||||
ArithmeticCompareIntFunc func_int32_ = nullptr;
|
ArithmeticCompareIntFunc func_int32_ = nullptr;
|
||||||
|
ArithmeticOptCompareFp32Func opt_func_fp32_ = nullptr;
|
||||||
|
ArithmeticOptCompareIntFunc opt_func_int32_ = nullptr;
|
||||||
};
|
};
|
||||||
int ArithmeticCompareRun(void *cdata, int task_id, float lhs_scale, float rhs_scale);
|
int ArithmeticCompareRun(void *cdata, int task_id, float lhs_scale, float rhs_scale);
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
|
@ -50,27 +50,133 @@ int ArithmeticCPUKernel::Prepare() {
|
||||||
}
|
}
|
||||||
return ReSize();
|
return ReSize();
|
||||||
}
|
}
|
||||||
|
bool ArithmeticCPUKernel::IsScalarClac() {
|
||||||
|
if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
int ArithmeticCPUKernel::ReSize() {
|
int ArithmeticCPUKernel::ReSize() {
|
||||||
CalcMultiplesAndStrides(param_);
|
CalcMultiplesAndStrides(param_);
|
||||||
if (param_->broadcasting_) {
|
scalar_ = IsScalarClac();
|
||||||
outside_ = 1;
|
int ret = RET_OK;
|
||||||
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0 && i < ARITHMETIC_SUPPORT_DIMS_NUM; --i) {
|
if (!scalar_) {
|
||||||
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
|
ret = ConstTensorBroadCast();
|
||||||
break_pos_ = i;
|
if (ret != RET_OK) {
|
||||||
break;
|
MS_LOG(ERROR) << "failed to init const tensor";
|
||||||
}
|
return ret;
|
||||||
outside_ *= param_->out_shape_[i];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
data_type_len_ = lite::DataTypeSize(in_tensors_.at(0)->data_type());
|
if (!scalar_ && param_->broadcasting_) {
|
||||||
int ret = RET_OK;
|
ret = InitIndexOffsetInfo();
|
||||||
if (!IsScalarClac() && !IsBatchScalarCalc() && !IsBiasCalc()) {
|
|
||||||
ret = ConstTensorBroadCast();
|
|
||||||
}
|
}
|
||||||
|
data_type_len_ = lite::DataTypeSize(in_tensors_.at(0)->data_type());
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ArithmeticCPUKernel::IsBatchScalarCalc() { // 1 32 240 240, 2 32 1 1
|
||||||
|
int last_batch_axis0 = ARITHMETIC_SUPPORT_DIMS_NUM + 1;
|
||||||
|
int last_batch_axis1 = ARITHMETIC_SUPPORT_DIMS_NUM + 1;
|
||||||
|
if (param_->in_shape0_[param_->ndim_ - 1] == 1) {
|
||||||
|
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0 && i < ARITHMETIC_SUPPORT_DIMS_NUM; --i) {
|
||||||
|
if (param_->in_shape0_[i] != 1) {
|
||||||
|
last_batch_axis0 = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (param_->in_shape1_[param_->ndim_ - 1] == 1) {
|
||||||
|
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0 && i < ARITHMETIC_SUPPORT_DIMS_NUM; --i) {
|
||||||
|
if (param_->in_shape1_[i] != 1) {
|
||||||
|
last_batch_axis1 = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int min_axis = MSMIN(last_batch_axis0, last_batch_axis1);
|
||||||
|
if (min_axis < static_cast<int>(param_->ndim_) - 1) {
|
||||||
|
last_batch_axis_ = min_axis;
|
||||||
|
if (last_batch_axis0 < last_batch_axis1) {
|
||||||
|
param_->in_elements_num0_ = 1;
|
||||||
|
} else {
|
||||||
|
param_->in_elements_num1_ = 1;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ArithmeticCPUKernel::InitIndexOffsetInfo() {
|
||||||
|
split_by_batch_ = true;
|
||||||
|
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0 && i < ARITHMETIC_SUPPORT_DIMS_NUM; --i) {
|
||||||
|
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
|
||||||
|
break_pos_ = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> a_shape;
|
||||||
|
std::vector<int> b_shape;
|
||||||
|
std::vector<int> c_shape = out_tensors_[0]->shape();
|
||||||
|
size_t dim = c_shape.size();
|
||||||
|
for (size_t i = 0; i < dim; ++i) {
|
||||||
|
a_shape.push_back(param_->in_shape0_[i]);
|
||||||
|
b_shape.push_back(param_->in_shape1_[i]);
|
||||||
|
}
|
||||||
|
batch_scalar_ = IsBatchScalarCalc();
|
||||||
|
|
||||||
|
a_stride_size_ = 1;
|
||||||
|
b_stride_size_ = 1;
|
||||||
|
c_stride_size_ = 1;
|
||||||
|
int last_batch_axis = batch_scalar_ ? last_batch_axis_ : break_pos_;
|
||||||
|
for (int i = static_cast<int>(param_->ndim_) - 1; i > last_batch_axis && i < ARITHMETIC_SUPPORT_DIMS_NUM; --i) {
|
||||||
|
a_stride_size_ *= a_shape[i];
|
||||||
|
b_stride_size_ *= b_shape[i];
|
||||||
|
c_stride_size_ *= c_shape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
out_batch_ = 1;
|
||||||
|
int batch_size[ARITHMETIC_SUPPORT_DIMS_NUM] = {};
|
||||||
|
int a_batch_size[ARITHMETIC_SUPPORT_DIMS_NUM] = {};
|
||||||
|
int b_batch_size[ARITHMETIC_SUPPORT_DIMS_NUM] = {};
|
||||||
|
for (int i = last_batch_axis; i >= 0; --i) {
|
||||||
|
out_batch_ *= c_shape[i];
|
||||||
|
if (i == last_batch_axis) {
|
||||||
|
batch_size[i] = c_shape[i];
|
||||||
|
a_batch_size[i] = a_shape[i];
|
||||||
|
b_batch_size[i] = b_shape[i];
|
||||||
|
} else {
|
||||||
|
batch_size[i] = batch_size[i + 1] * c_shape[i];
|
||||||
|
a_batch_size[i] = a_batch_size[i + 1] * a_shape[i];
|
||||||
|
b_batch_size[i] = b_batch_size[i + 1] * b_shape[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
a_offset_.resize(out_batch_, 0);
|
||||||
|
b_offset_.resize(out_batch_, 0);
|
||||||
|
for (int i = 0; i < out_batch_; ++i) {
|
||||||
|
int delta = i;
|
||||||
|
int a_offset = 0;
|
||||||
|
int b_offset = 0;
|
||||||
|
for (int j = 0; j <= last_batch_axis; ++j) {
|
||||||
|
if (j > 0) {
|
||||||
|
delta = delta % batch_size[j];
|
||||||
|
}
|
||||||
|
if (j < last_batch_axis) {
|
||||||
|
a_offset += (delta / batch_size[j + 1] * a_shape[j] / MSMAX(a_shape[j], b_shape[j])) * a_batch_size[j + 1];
|
||||||
|
b_offset += (delta / batch_size[j + 1] * b_shape[j] / MSMAX(a_shape[j], b_shape[j])) * b_batch_size[j + 1];
|
||||||
|
} else {
|
||||||
|
a_offset += (delta * a_shape[j] / MSMAX(a_shape[j], b_shape[j]));
|
||||||
|
b_offset += (delta * b_shape[j] / MSMAX(a_shape[j], b_shape[j]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
a_offset_[i] = a_offset;
|
||||||
|
b_offset_[i] = b_offset;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int ArithmeticCPUKernel::CheckDataType() {
|
int ArithmeticCPUKernel::CheckDataType() {
|
||||||
auto in0_dataType = in_tensors_.at(0)->data_type();
|
auto in0_dataType = in_tensors_.at(0)->data_type();
|
||||||
auto in1_dataType = in_tensors_.at(1)->data_type();
|
auto in1_dataType = in_tensors_.at(1)->data_type();
|
||||||
|
@ -85,47 +191,6 @@ int ArithmeticCPUKernel::CheckDataType() {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ArithmeticCPUKernel::IsScalarClac() { // 2 32 240 240, 1 1 1 1
|
|
||||||
if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && (arithmetic_opt_run_ != nullptr)) {
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool ArithmeticCPUKernel::IsBatchScalarCalc() { // 2 32 240 240, 2 32 1 1
|
|
||||||
if (arithmetic_opt_run_ == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
size_t break_axis = 0;
|
|
||||||
for (size_t i = 0; i < param_->ndim_; i++) {
|
|
||||||
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
|
|
||||||
break_axis = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (break_axis < param_->ndim_) {
|
|
||||||
for (size_t i = break_axis; i < param_->ndim_; i++) {
|
|
||||||
if (param_->in_shape1_[i] != 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break_pos_ = break_axis;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool ArithmeticCPUKernel::IsBiasCalc() const { // 2 240 240 32, 1 1 1 32
|
|
||||||
int last_shape0 = param_->in_shape0_[param_->ndim_ - 1];
|
|
||||||
int last_shape1 = param_->in_shape1_[param_->ndim_ - 1];
|
|
||||||
if (param_->in_elements_num0_ > param_->in_elements_num1_) {
|
|
||||||
return param_->in_elements_num1_ == last_shape1 && last_shape0 == last_shape1;
|
|
||||||
} else if (param_->in_elements_num0_ < param_->in_elements_num1_) {
|
|
||||||
return param_->in_elements_num0_ == last_shape0 && last_shape0 == last_shape1;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
int ArithmeticCPUKernel::ConstTensorBroadCast() {
|
int ArithmeticCPUKernel::ConstTensorBroadCast() {
|
||||||
/* if const node need broadcast and all need-broadcast-node are const, broadcast in resize */
|
/* if const node need broadcast and all need-broadcast-node are const, broadcast in resize */
|
||||||
if (!param_->broadcasting_) {
|
if (!param_->broadcasting_) {
|
||||||
|
@ -168,13 +233,11 @@ int ArithmeticCPUKernel::ConstTensorBroadCast() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// broadcast input and get new break_pos_
|
// broadcast input and get new break_pos_
|
||||||
outside_ = 1;
|
|
||||||
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0; --i) {
|
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0; --i) {
|
||||||
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
|
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
|
||||||
break_pos_ = i;
|
break_pos_ = i;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
outside_ *= param_->out_shape_[i];
|
|
||||||
}
|
}
|
||||||
if (param_->in_elements_num0_ == param_->out_elements_num_ &&
|
if (param_->in_elements_num0_ == param_->out_elements_num_ &&
|
||||||
param_->in_elements_num1_ == param_->out_elements_num_) {
|
param_->in_elements_num1_ == param_->out_elements_num_) {
|
||||||
|
@ -205,50 +268,51 @@ void ArithmeticCPUKernel::FreeConstTileBuff() {
|
||||||
void ArithmeticCPUKernel::InitRunFunction(int primitive_type) {
|
void ArithmeticCPUKernel::InitRunFunction(int primitive_type) {
|
||||||
ARITHMETIC_FUNC_INFO_FP32 fun_table[] = {
|
ARITHMETIC_FUNC_INFO_FP32 fun_table[] = {
|
||||||
{PrimitiveType_MulFusion, schema::ActivationType_RELU, ElementMulRelu, ElementMulReluInt, nullptr,
|
{PrimitiveType_MulFusion, schema::ActivationType_RELU, ElementMulRelu, ElementMulReluInt, nullptr,
|
||||||
ElementOptMulRelu, ElementOptMulReluInt},
|
ElementOptMulRelu, ElementOptMulReluInt, nullptr},
|
||||||
{PrimitiveType_MulFusion, schema::ActivationType_RELU6, ElementMulRelu6, ElementMulRelu6Int, nullptr,
|
{PrimitiveType_MulFusion, schema::ActivationType_RELU6, ElementMulRelu6, ElementMulRelu6Int, nullptr,
|
||||||
ElementOptMulRelu6, ElementOptMulRelu6Int},
|
ElementOptMulRelu6, ElementOptMulRelu6Int, nullptr},
|
||||||
{PrimitiveType_MulFusion, schema::ActivationType_NO_ACTIVATION, ElementMul, ElementMulInt, nullptr, ElementOptMul,
|
{PrimitiveType_MulFusion, schema::ActivationType_NO_ACTIVATION, ElementMul, ElementMulInt, nullptr, ElementOptMul,
|
||||||
ElementOptMulInt},
|
ElementOptMulInt, nullptr},
|
||||||
{PrimitiveType_AddFusion, schema::ActivationType_RELU, ElementAddRelu, nullptr, nullptr, ElementOptAddRelu,
|
{PrimitiveType_AddFusion, schema::ActivationType_RELU, ElementAddRelu, nullptr, nullptr, ElementOptAddRelu, nullptr,
|
||||||
nullptr},
|
nullptr},
|
||||||
{PrimitiveType_AddFusion, schema::ActivationType_RELU6, ElementAddRelu6, nullptr, nullptr, ElementOptAddRelu6,
|
{PrimitiveType_AddFusion, schema::ActivationType_RELU6, ElementAddRelu6, nullptr, nullptr, ElementOptAddRelu6,
|
||||||
nullptr},
|
nullptr, nullptr},
|
||||||
{PrimitiveType_AddFusion, schema::ActivationType_NO_ACTIVATION, ElementAdd, ElementAddInt, nullptr, ElementOptAdd,
|
{PrimitiveType_AddFusion, schema::ActivationType_NO_ACTIVATION, ElementAdd, ElementAddInt, nullptr, ElementOptAdd,
|
||||||
ElementOptAddInt},
|
ElementOptAddInt, nullptr},
|
||||||
{PrimitiveType_SubFusion, schema::ActivationType_RELU, ElementSubRelu, nullptr, nullptr, ElementOptSubRelu,
|
{PrimitiveType_SubFusion, schema::ActivationType_RELU, ElementSubRelu, nullptr, nullptr, ElementOptSubRelu, nullptr,
|
||||||
nullptr},
|
nullptr},
|
||||||
{PrimitiveType_SubFusion, schema::ActivationType_RELU6, ElementSubRelu6, nullptr, nullptr, ElementOptSubRelu6,
|
{PrimitiveType_SubFusion, schema::ActivationType_RELU6, ElementSubRelu6, nullptr, nullptr, ElementOptSubRelu6,
|
||||||
nullptr},
|
nullptr, nullptr},
|
||||||
{PrimitiveType_SubFusion, schema::ActivationType_NO_ACTIVATION, ElementSub, ElementSubInt, nullptr, ElementOptSub,
|
{PrimitiveType_SubFusion, schema::ActivationType_NO_ACTIVATION, ElementSub, ElementSubInt, nullptr, ElementOptSub,
|
||||||
ElementOptSubInt},
|
ElementOptSubInt, nullptr},
|
||||||
{PrimitiveType_DivFusion, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu,
|
{PrimitiveType_DivFusion, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr,
|
||||||
nullptr},
|
nullptr},
|
||||||
{PrimitiveType_DivFusion, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6,
|
{PrimitiveType_DivFusion, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6,
|
||||||
nullptr},
|
nullptr, nullptr},
|
||||||
{PrimitiveType_DivFusion, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv,
|
{PrimitiveType_DivFusion, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv,
|
||||||
ElementOptDivInt},
|
ElementOptDivInt, nullptr},
|
||||||
{PrimitiveType_RealDiv, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr},
|
{PrimitiveType_RealDiv, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr,
|
||||||
|
nullptr},
|
||||||
{PrimitiveType_RealDiv, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6,
|
{PrimitiveType_RealDiv, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6,
|
||||||
nullptr},
|
nullptr, nullptr},
|
||||||
{PrimitiveType_RealDiv, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv,
|
{PrimitiveType_RealDiv, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv,
|
||||||
ElementOptDivInt},
|
ElementOptDivInt, nullptr},
|
||||||
{PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, ElementLogicalAnd, ElementLogicalAndInt,
|
{PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, ElementLogicalAnd, ElementLogicalAndInt,
|
||||||
ElementLogicalAndBool, nullptr, nullptr},
|
ElementLogicalAndBool, ElementOptLogicalAnd, ElementOptLogicalAndInt, ElementOptLogicalAndBool},
|
||||||
{PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, ElementLogicalOr, nullptr, ElementLogicalOrBool,
|
{PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, ElementLogicalOr, nullptr, ElementLogicalOrBool,
|
||||||
nullptr, nullptr},
|
nullptr, nullptr, ElementOptLogicalOrBool},
|
||||||
{PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximum, ElementMaximumInt, nullptr, nullptr,
|
{PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximum, ElementMaximumInt, nullptr,
|
||||||
nullptr},
|
ElementOptMaximum, ElementOptMaximumInt, nullptr},
|
||||||
{PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimum, ElementMinimumInt, nullptr, nullptr,
|
{PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimum, ElementMinimumInt, nullptr,
|
||||||
nullptr},
|
ElementOptMinimum, ElementOptMinimumInt, nullptr},
|
||||||
{PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, ElementFloorMod, ElementFloorModInt, nullptr,
|
{PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, ElementFloorMod, ElementFloorModInt, nullptr,
|
||||||
nullptr, nullptr},
|
ElementOptFloorMod, ElementOptFloorModInt, nullptr},
|
||||||
{PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, ElementFloorDiv, ElementFloorDivInt, nullptr,
|
{PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, ElementFloorDiv, ElementFloorDivInt, nullptr,
|
||||||
nullptr, nullptr},
|
ElementOptFloorDiv, ElementOptFloorDivInt, nullptr},
|
||||||
{PrimitiveType_Mod, schema::ActivationType_NO_ACTIVATION, ElementMod, ElementModInt, nullptr, ElementOptMod,
|
{PrimitiveType_Mod, schema::ActivationType_NO_ACTIVATION, ElementMod, ElementModInt, nullptr, ElementOptMod,
|
||||||
ElementOptModInt},
|
ElementOptModInt, nullptr},
|
||||||
{PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifference, nullptr, nullptr,
|
{PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifference, nullptr, nullptr,
|
||||||
nullptr, nullptr}};
|
ElementOptSquaredDifference, nullptr, nullptr}};
|
||||||
|
|
||||||
size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_FUNC_INFO_FP32);
|
size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_FUNC_INFO_FP32);
|
||||||
for (size_t i = 0; i < length; i++) {
|
for (size_t i = 0; i < length; i++) {
|
||||||
|
@ -258,6 +322,7 @@ void ArithmeticCPUKernel::InitRunFunction(int primitive_type) {
|
||||||
arithmetic_run_bool_ = fun_table[i].bool_func_;
|
arithmetic_run_bool_ = fun_table[i].bool_func_;
|
||||||
arithmetic_opt_run_ = fun_table[i].opt_func_;
|
arithmetic_opt_run_ = fun_table[i].opt_func_;
|
||||||
arithmetic_opt_run_int_ = fun_table[i].opt_int_func_;
|
arithmetic_opt_run_int_ = fun_table[i].opt_int_func_;
|
||||||
|
arithmetic_opt_run_bool_ = fun_table[i].opt_bool_func_;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -276,9 +341,15 @@ int ArithmeticCPUKernel::Execute(const void *input0, const void *input1, void *o
|
||||||
reinterpret_cast<float *>(output), size);
|
reinterpret_cast<float *>(output), size);
|
||||||
}
|
}
|
||||||
} else if (in_tensors_[0]->data_type() == kNumberTypeBool) {
|
} else if (in_tensors_[0]->data_type() == kNumberTypeBool) {
|
||||||
CHECK_NULL_RETURN(arithmetic_run_bool_);
|
if (is_opt) {
|
||||||
ret = arithmetic_run_bool_(reinterpret_cast<const bool *>(input0), reinterpret_cast<const bool *>(input1),
|
CHECK_NULL_RETURN(arithmetic_opt_run_bool_);
|
||||||
reinterpret_cast<bool *>(output), size);
|
ret = arithmetic_opt_run_bool_(reinterpret_cast<const bool *>(input0), reinterpret_cast<const bool *>(input1),
|
||||||
|
reinterpret_cast<bool *>(output), size, param_);
|
||||||
|
} else {
|
||||||
|
CHECK_NULL_RETURN(arithmetic_run_bool_);
|
||||||
|
ret = arithmetic_run_bool_(reinterpret_cast<const bool *>(input0), reinterpret_cast<const bool *>(input1),
|
||||||
|
reinterpret_cast<bool *>(output), size);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if (is_opt) {
|
if (is_opt) {
|
||||||
CHECK_NULL_RETURN(arithmetic_opt_run_int_);
|
CHECK_NULL_RETURN(arithmetic_opt_run_int_);
|
||||||
|
@ -293,102 +364,35 @@ int ArithmeticCPUKernel::Execute(const void *input0, const void *input1, void *o
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
|
int ArithmeticCPUKernel::CalcArithmeticByBatch(int task_id) {
|
||||||
int out_thread_stride) {
|
int batch_per_thread = UP_DIV(out_batch_, op_parameter_->thread_num_);
|
||||||
if (dim > break_pos_) {
|
int start_batch = batch_per_thread * task_id;
|
||||||
int offset = out_thread_stride * data_type_len_;
|
int end_batch = MSMIN(start_batch + batch_per_thread, out_batch_);
|
||||||
return Execute(static_cast<uint8_t *>(input0) + offset, static_cast<uint8_t *>(input1) + offset,
|
int ret = RET_ERROR;
|
||||||
static_cast<uint8_t *>(output) + offset, out_count, false);
|
for (int i = start_batch; i < end_batch; i++) {
|
||||||
}
|
batch_a_ptr_ = static_cast<uint8_t *>(input0_ptr_) + a_offset_[i] * a_stride_size_ * data_type_len_;
|
||||||
int offset_size[] = {param_->in_strides0_[dim] * data_type_len_, param_->in_strides1_[dim] * data_type_len_,
|
batch_b_ptr_ = static_cast<uint8_t *>(input1_ptr_) + b_offset_[i] * b_stride_size_ * data_type_len_;
|
||||||
param_->out_strides_[dim] * data_type_len_};
|
batch_c_ptr_ = static_cast<uint8_t *>(output_ptr_) + i * c_stride_size_ * data_type_len_;
|
||||||
for (int i = 0; i < param_->out_shape_[dim]; ++i) {
|
if (batch_scalar_) {
|
||||||
int pos0_ = param_->in_shape0_[dim] == 1 ? 0 : i;
|
ret = Execute(batch_a_ptr_, batch_b_ptr_, batch_c_ptr_, c_stride_size_, true);
|
||||||
int pos1_ = param_->in_shape1_[dim] == 1 ? 0 : i;
|
} else {
|
||||||
int ret = BroadcastRun(static_cast<uint8_t *>(input0) + pos0_ * offset_size[0],
|
ret = Execute(batch_a_ptr_, batch_b_ptr_, batch_c_ptr_, c_stride_size_, false);
|
||||||
static_cast<uint8_t *>(input1) + pos1_ * offset_size[1],
|
}
|
||||||
static_cast<uint8_t *>(output) + i * offset_size[2], dim + 1, out_count, out_thread_stride);
|
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
return ret;
|
MS_LOG(ERROR) << "failed to calculate.";
|
||||||
}
|
return RET_ERROR;
|
||||||
}
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {
|
|
||||||
if (break_pos_ < 1) {
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
if (break_pos_ > ARITHMETIC_SUPPORT_DIMS_NUM || param_->out_strides_[break_pos_ - 1] == 0) {
|
|
||||||
MS_LOG(ERROR) << "param_->out_strides_[break_pos_ - 1] is 0 or break_pos_ is > 10";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
int batch = param_->out_elements_num_ / param_->out_strides_[break_pos_ - 1];
|
|
||||||
int batch_per_thread = UP_DIV(batch, op_parameter_->thread_num_);
|
|
||||||
|
|
||||||
int start_batch = batch_per_thread * task_id;
|
|
||||||
int end_batch = MSMIN(start_batch + batch_per_thread, batch);
|
|
||||||
int batch_size = end_batch - start_batch;
|
|
||||||
|
|
||||||
int stride0 = param_->in_strides0_[break_pos_ - 1] * data_type_len_;
|
|
||||||
int stride1 = param_->in_strides1_[break_pos_ - 1] * data_type_len_;
|
|
||||||
int out_stride = param_->out_strides_[break_pos_ - 1] * data_type_len_;
|
|
||||||
|
|
||||||
int offset0 = stride0 * start_batch;
|
|
||||||
int offset1 = stride1 * start_batch;
|
|
||||||
int out_offset = out_stride * start_batch;
|
|
||||||
|
|
||||||
int ret = RET_OK;
|
|
||||||
for (int i = 0; i < batch_size; i++) {
|
|
||||||
ret = Execute(static_cast<uint8_t *>(input0_ptr_) + offset0, static_cast<uint8_t *>(input1_ptr_) + offset1,
|
|
||||||
static_cast<uint8_t *>(output_ptr_) + out_offset, param_->out_strides_[break_pos_ - 1], true);
|
|
||||||
offset0 += stride0;
|
|
||||||
offset1 += stride1;
|
|
||||||
out_offset += out_stride;
|
|
||||||
}
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
int ArithmeticCPUKernel::BiasCalc(int task_id) {
|
|
||||||
if (param_->ndim_ > ARITHMETIC_SUPPORT_DIMS_NUM || param_->out_shape_[param_->ndim_ - 1] == 0) {
|
|
||||||
MS_LOG(ERROR) << "BiasCalc param is error!";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
int last_shape = param_->out_shape_[param_->ndim_ - 1];
|
|
||||||
int batch = param_->out_elements_num_ / last_shape;
|
|
||||||
int batch_per_thread = UP_DIV(batch, op_parameter_->thread_num_);
|
|
||||||
|
|
||||||
int start_batch = batch_per_thread * task_id;
|
|
||||||
int end_batch = MSMIN(start_batch + batch_per_thread, batch);
|
|
||||||
int batch_size = end_batch - start_batch;
|
|
||||||
|
|
||||||
int stride = last_shape * data_type_len_;
|
|
||||||
int offset = stride * start_batch;
|
|
||||||
int ret = RET_OK;
|
|
||||||
if (param_->in_elements_num0_ > param_->in_elements_num1_) {
|
|
||||||
for (int i = 0; i < batch_size; i++) {
|
|
||||||
ret = Execute(static_cast<uint8_t *>(input0_ptr_) + offset, static_cast<uint8_t *>(input1_ptr_),
|
|
||||||
static_cast<uint8_t *>(output_ptr_) + offset, last_shape, false);
|
|
||||||
if (ret != RET_OK) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
offset += stride;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int i = 0; i < batch_size; i++) {
|
|
||||||
ret = Execute(static_cast<uint8_t *>(input0_ptr_), static_cast<uint8_t *>(input1_ptr_) + offset,
|
|
||||||
static_cast<uint8_t *>(output_ptr_) + offset, last_shape, false);
|
|
||||||
if (ret != RET_OK) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
offset += stride;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
||||||
auto element_num = out_tensors_[0]->ElementsNum();
|
if (split_by_batch_) {
|
||||||
|
return CalcArithmeticByBatch(task_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t element_num = out_tensors_[0]->ElementsNum();
|
||||||
|
auto ret = RET_ERROR;
|
||||||
int stride = UP_DIV(element_num, op_parameter_->thread_num_);
|
int stride = UP_DIV(element_num, op_parameter_->thread_num_);
|
||||||
int count = MSMIN(stride, element_num - stride * task_id);
|
int count = MSMIN(stride, element_num - stride * task_id);
|
||||||
if (count <= 0) {
|
if (count <= 0) {
|
||||||
|
@ -396,36 +400,16 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
||||||
}
|
}
|
||||||
CHECK_LESS_RETURN(ARITHMETIC_SUPPORT_DIMS_NUM, param_->ndim_);
|
CHECK_LESS_RETURN(ARITHMETIC_SUPPORT_DIMS_NUM, param_->ndim_);
|
||||||
int offset = stride * task_id * data_type_len_;
|
int offset = stride * task_id * data_type_len_;
|
||||||
/* run opt function, one of input is scalar */
|
if (scalar_) {
|
||||||
if (IsScalarClac()) { // 2 32 240 240, 1 1 1 1
|
|
||||||
if (param_->in_elements_num0_ == 1) {
|
if (param_->in_elements_num0_ == 1) {
|
||||||
return Execute(input0_ptr_, static_cast<uint8_t *>(input1_ptr_) + offset,
|
ret = Execute(batch_a_ptr_, batch_b_ptr_ + offset, batch_c_ptr_ + offset, count, true);
|
||||||
static_cast<uint8_t *>(output_ptr_) + offset, count, true);
|
} else {
|
||||||
} else if (param_->in_elements_num1_ == 1) {
|
ret = Execute(batch_a_ptr_ + offset, batch_b_ptr_, batch_c_ptr_ + offset, count, true);
|
||||||
return Execute(static_cast<uint8_t *>(input0_ptr_) + offset, input1_ptr_,
|
|
||||||
static_cast<uint8_t *>(output_ptr_) + offset, count, true);
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
ret = Execute(batch_a_ptr_ + offset, batch_b_ptr_ + offset, batch_c_ptr_ + offset, count, false);
|
||||||
}
|
}
|
||||||
/* run opt function, every batch one of input is scalar */
|
return ret;
|
||||||
if (IsBatchScalarCalc()) { // 2 32 240 240, 2 32 1 1
|
|
||||||
return BatchScalarCalc(task_id);
|
|
||||||
}
|
|
||||||
/* each batch is eltwise calculation */
|
|
||||||
if (IsBiasCalc()) { // 2 240 240 32, 1 1 1 32
|
|
||||||
return BiasCalc(task_id);
|
|
||||||
}
|
|
||||||
/* need broadcast in runtime */
|
|
||||||
if (param_->broadcasting_) {
|
|
||||||
stride = UP_DIV(outside_, op_parameter_->thread_num_);
|
|
||||||
int out_count = MSMIN(stride, outside_ - stride * task_id);
|
|
||||||
if (out_count <= 0) {
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
return BroadcastRun(input0_ptr_, input1_ptr_, output_ptr_, 0, out_count, stride * task_id);
|
|
||||||
}
|
|
||||||
/* all elements eltwise calculation */
|
|
||||||
return Execute(static_cast<uint8_t *>(input0_ptr_) + offset, static_cast<uint8_t *>(input1_ptr_) + offset,
|
|
||||||
static_cast<uint8_t *>(output_ptr_) + offset, count, false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int ArithmeticsRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
|
int ArithmeticsRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
|
||||||
|
@ -442,6 +426,7 @@ int ArithmeticCPUKernel::Run() {
|
||||||
MS_LOG(ERROR) << "ArithmeticCPUKernel check dataType failed, kernel name: " << this->name();
|
MS_LOG(ERROR) << "ArithmeticCPUKernel check dataType failed, kernel name: " << this->name();
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!input0_broadcast_) {
|
if (!input0_broadcast_) {
|
||||||
input0_ptr_ = in_tensors_[0]->data();
|
input0_ptr_ = in_tensors_[0]->data();
|
||||||
CHECK_NULL_RETURN(input0_ptr_);
|
CHECK_NULL_RETURN(input0_ptr_);
|
||||||
|
@ -452,7 +437,16 @@ int ArithmeticCPUKernel::Run() {
|
||||||
}
|
}
|
||||||
output_ptr_ = out_tensors_[0]->data();
|
output_ptr_ = out_tensors_[0]->data();
|
||||||
CHECK_NULL_RETURN(output_ptr_);
|
CHECK_NULL_RETURN(output_ptr_);
|
||||||
return ParallelLaunch(this->ms_context_, ArithmeticsRun, this, op_parameter_->thread_num_);
|
batch_a_ptr_ = static_cast<uint8_t *>(input0_ptr_);
|
||||||
|
batch_b_ptr_ = static_cast<uint8_t *>(input1_ptr_);
|
||||||
|
batch_c_ptr_ = static_cast<uint8_t *>(output_ptr_);
|
||||||
|
auto ret = ParallelLaunch(this->ms_context_, ArithmeticsRun, this, op_parameter_->thread_num_);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "arithmetic failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MulFusion, LiteKernelCreator<ArithmeticCPUKernel>)
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MulFusion, LiteKernelCreator<ArithmeticCPUKernel>)
|
||||||
|
|
|
@ -49,6 +49,8 @@ class ArithmeticCPUKernel : public InnerKernel {
|
||||||
typedef int (*ArithmeticOptIntRun)(const int *input0, const int *input1, int *output, const int element_size,
|
typedef int (*ArithmeticOptIntRun)(const int *input0, const int *input1, int *output, const int element_size,
|
||||||
const ArithmeticParameter *param);
|
const ArithmeticParameter *param);
|
||||||
typedef int (*ArithmeticBoolRun)(const bool *input0, const bool *input1, bool *output, const int element_size);
|
typedef int (*ArithmeticBoolRun)(const bool *input0, const bool *input1, bool *output, const int element_size);
|
||||||
|
typedef int (*ArithmeticOptBoolRun)(const bool *input0, const bool *input1, bool *output, const int element_size,
|
||||||
|
const ArithmeticParameter *param);
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int primitive_type_;
|
int primitive_type_;
|
||||||
|
@ -58,6 +60,7 @@ class ArithmeticCPUKernel : public InnerKernel {
|
||||||
ArithmeticBoolRun bool_func_;
|
ArithmeticBoolRun bool_func_;
|
||||||
ArithmeticOptRun opt_func_;
|
ArithmeticOptRun opt_func_;
|
||||||
ArithmeticOptIntRun opt_int_func_;
|
ArithmeticOptIntRun opt_int_func_;
|
||||||
|
ArithmeticOptBoolRun opt_bool_func_;
|
||||||
} ARITHMETIC_FUNC_INFO_FP32;
|
} ARITHMETIC_FUNC_INFO_FP32;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@ -72,7 +75,6 @@ class ArithmeticCPUKernel : public InnerKernel {
|
||||||
int ReSize() override;
|
int ReSize() override;
|
||||||
int Run() override;
|
int Run() override;
|
||||||
virtual int DoArithmetic(int task_id);
|
virtual int DoArithmetic(int task_id);
|
||||||
virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual void InitRunFunction(int primitive_type);
|
virtual void InitRunFunction(int primitive_type);
|
||||||
|
@ -83,17 +85,31 @@ class ArithmeticCPUKernel : public InnerKernel {
|
||||||
virtual int Execute(const void *input0, const void *input1, void *output, int size, bool is_opt);
|
virtual int Execute(const void *input0, const void *input1, void *output, int size, bool is_opt);
|
||||||
virtual bool IsBatchScalarCalc();
|
virtual bool IsBatchScalarCalc();
|
||||||
virtual bool IsScalarClac();
|
virtual bool IsScalarClac();
|
||||||
|
virtual int CalcArithmeticByBatch(int task_id);
|
||||||
bool input0_broadcast_ = false;
|
bool input0_broadcast_ = false;
|
||||||
bool input1_broadcast_ = false;
|
bool input1_broadcast_ = false;
|
||||||
void *input0_ptr_ = nullptr;
|
void *input0_ptr_ = nullptr;
|
||||||
void *input1_ptr_ = nullptr;
|
void *input1_ptr_ = nullptr;
|
||||||
void *output_ptr_ = nullptr;
|
void *output_ptr_ = nullptr;
|
||||||
|
uint8_t *batch_a_ptr_ = nullptr;
|
||||||
|
uint8_t *batch_b_ptr_ = nullptr;
|
||||||
|
uint8_t *batch_c_ptr_ = nullptr;
|
||||||
int break_pos_ = 0;
|
int break_pos_ = 0;
|
||||||
int outside_ = 0;
|
|
||||||
ArithmeticParameter *param_ = nullptr;
|
ArithmeticParameter *param_ = nullptr;
|
||||||
int data_type_len_ = sizeof(float);
|
int data_type_len_ = sizeof(float);
|
||||||
|
int out_batch_ = 1;
|
||||||
|
int a_stride_size_ = 1;
|
||||||
|
int b_stride_size_ = 1;
|
||||||
|
int c_stride_size_ = 1;
|
||||||
|
int last_batch_axis_ = 0;
|
||||||
|
bool scalar_ = false;
|
||||||
|
bool batch_scalar_ = false;
|
||||||
|
bool split_by_batch_ = false;
|
||||||
|
std::vector<int> a_offset_;
|
||||||
|
std::vector<int> b_offset_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
int InitIndexOffsetInfo();
|
||||||
int BatchScalarCalc(int task_id);
|
int BatchScalarCalc(int task_id);
|
||||||
int BiasCalc(int task_id);
|
int BiasCalc(int task_id);
|
||||||
void FreeConstTileBuff();
|
void FreeConstTileBuff();
|
||||||
|
@ -103,6 +119,7 @@ class ArithmeticCPUKernel : public InnerKernel {
|
||||||
ArithmeticIntRun arithmetic_run_int_ = nullptr;
|
ArithmeticIntRun arithmetic_run_int_ = nullptr;
|
||||||
ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr;
|
ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr;
|
||||||
ArithmeticBoolRun arithmetic_run_bool_ = nullptr;
|
ArithmeticBoolRun arithmetic_run_bool_ = nullptr;
|
||||||
|
ArithmeticOptBoolRun arithmetic_opt_run_bool_ = nullptr;
|
||||||
};
|
};
|
||||||
int ArithmeticsRun(void *cdata, int task_id, float lhs_scale, float rhs_scale);
|
int ArithmeticsRun(void *cdata, int task_id, float lhs_scale, float rhs_scale);
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
Loading…
Reference in New Issue