optimize arithmetic operators

This commit is contained in:
wangyanling 2022-01-27 14:19:40 +08:00
parent 2451a94125
commit ee72df2cbc
14 changed files with 791 additions and 354 deletions

View File

@ -32,9 +32,9 @@ int ElementAddRelu(const float *in0, const float *in1, float *out, int size);
int ElementAddRelu6(const float *in0, const float *in1, float *out, int size); int ElementAddRelu6(const float *in0, const float *in1, float *out, int size);
int ElementAddInt(const int *in0, const int *in1, int *out, int size); int ElementAddInt(const int *in0, const int *in1, int *out, int size);
int ElementOptAdd(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param); int ElementOptAdd(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
int ElementOptAddInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param); int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
int ElementOptAddRelu6(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param); int ElementOptAddRelu6(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
int ElementOptAddInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
int BroadcastAdd(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, int BroadcastAdd(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size,
ArithmeticParameter *param); ArithmeticParameter *param);

View File

@ -26,6 +26,22 @@ int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output,
return NNACL_OK; return NNACL_OK;
} }
// Broadcast form of element-wise "==" on float data: one operand contributes a
// single value, the other contributes element_size values.
//   - param->in_elements_num0_ == 1: input0[0] is compared with every input1[k].
//   - otherwise:                     every input0[k] is compared with input1[0].
// output[k] receives the comparison result (1 when equal, 0 when not).
// Always returns NNACL_OK.
int ElementOptEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
                        const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < element_size; ++k) {
      output[k] = (*input0 == input1[k]);
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < element_size; ++k) {
    output[k] = (input0[k] == *input1);
  }
  return NNACL_OK;
}
int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
for (int i = 0; i < element_size; i++) { for (int i = 0; i < element_size; i++) {
output[i] = input0[i] == input1[i]; output[i] = input0[i] == input1[i];
@ -33,6 +49,22 @@ int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *out
return NNACL_OK; return NNACL_OK;
} }
// Broadcast form of element-wise "==" on int32 data: one operand contributes a
// single value, the other contributes element_size values.
//   - param->in_elements_num0_ == 1: input0[0] is compared with every input1[k].
//   - otherwise:                     every input0[k] is compared with input1[0].
// output[k] receives the comparison result (1 when equal, 0 when not).
// Always returns NNACL_OK.
int ElementOptEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
                         const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < element_size; ++k) {
      output[k] = (*input0 == input1[k]);
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < element_size; ++k) {
    output[k] = (input0[k] == *input1);
  }
  return NNACL_OK;
}
// not equal // not equal
int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
for (int i = 0; i < element_size; i++) { for (int i = 0; i < element_size; i++) {
@ -41,6 +73,23 @@ int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *outpu
return NNACL_OK; return NNACL_OK;
} }
// not equal
// Broadcast form of element-wise "!=" on float data: one operand contributes a
// single value, the other contributes element_size values.
//   - param->in_elements_num0_ == 1: input0[0] is compared with every input1[k].
//   - otherwise:                     every input0[k] is compared with input1[0].
// output[k] receives the comparison result (1 when the values differ, 0 when not).
// Always returns NNACL_OK.
int ElementOptNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
                           const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < element_size; ++k) {
      output[k] = (*input0 != input1[k]);
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < element_size; ++k) {
    output[k] = (input0[k] != *input1);
  }
  return NNACL_OK;
}
int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
for (int i = 0; i < element_size; i++) { for (int i = 0; i < element_size; i++) {
output[i] = input0[i] != input1[i]; output[i] = input0[i] != input1[i];
@ -48,6 +97,22 @@ int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *
return NNACL_OK; return NNACL_OK;
} }
// Broadcast form of element-wise "!=" on int32 data: one operand contributes a
// single value, the other contributes element_size values.
//   - param->in_elements_num0_ == 1: input0[0] is compared with every input1[k].
//   - otherwise:                     every input0[k] is compared with input1[0].
// output[k] receives the comparison result (1 when the values differ, 0 when not).
// Always returns NNACL_OK.
int ElementOptNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
                            const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < element_size; ++k) {
      output[k] = (*input0 != input1[k]);
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < element_size; ++k) {
    output[k] = (input0[k] != *input1);
  }
  return NNACL_OK;
}
// less // less
int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
for (int i = 0; i < element_size; i++) { for (int i = 0; i < element_size; i++) {
@ -56,6 +121,22 @@ int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, i
return NNACL_OK; return NNACL_OK;
} }
// Broadcast form of element-wise "<" on float data: one operand contributes a
// single value, the other contributes element_size values.
//   - param->in_elements_num0_ == 1: output[k] = input0[0] < input1[k].
//   - otherwise:                     output[k] = input0[k] < input1[0].
// Each output element is 1 when the relation holds and 0 when it does not.
// Always returns NNACL_OK.
int ElementOptLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
                       const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < element_size; ++k) {
      output[k] = (*input0 < input1[k]);
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < element_size; ++k) {
    output[k] = (input0[k] < *input1);
  }
  return NNACL_OK;
}
int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
for (int i = 0; i < element_size; i++) { for (int i = 0; i < element_size; i++) {
output[i] = input0[i] < input1[i]; output[i] = input0[i] < input1[i];
@ -63,6 +144,22 @@ int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *outp
return NNACL_OK; return NNACL_OK;
} }
// Broadcast form of element-wise "<" on int32 data: one operand contributes a
// single value, the other contributes element_size values.
//   - param->in_elements_num0_ == 1: output[k] = input0[0] < input1[k].
//   - otherwise:                     output[k] = input0[k] < input1[0].
// Each output element is 1 when the relation holds and 0 when it does not.
// Always returns NNACL_OK.
int ElementOptLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
                        const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < element_size; ++k) {
      output[k] = (*input0 < input1[k]);
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < element_size; ++k) {
    output[k] = (input0[k] < *input1);
  }
  return NNACL_OK;
}
// less equal // less equal
int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
for (int i = 0; i < element_size; i++) { for (int i = 0; i < element_size; i++) {
@ -71,6 +168,22 @@ int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *outp
return NNACL_OK; return NNACL_OK;
} }
// Broadcast form of element-wise "<=" on float data: one operand contributes a
// single value, the other contributes element_size values.
//   - param->in_elements_num0_ == 1: output[k] = input0[0] <= input1[k].
//   - otherwise:                     output[k] = input0[k] <= input1[0].
// Each output element is 1 when the relation holds and 0 when it does not.
// Always returns NNACL_OK.
int ElementOptLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
                            const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < element_size; ++k) {
      output[k] = (*input0 <= input1[k]);
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < element_size; ++k) {
    output[k] = (input0[k] <= *input1);
  }
  return NNACL_OK;
}
int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
for (int i = 0; i < element_size; i++) { for (int i = 0; i < element_size; i++) {
output[i] = input0[i] <= input1[i]; output[i] = input0[i] <= input1[i];
@ -78,6 +191,22 @@ int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t
return NNACL_OK; return NNACL_OK;
} }
// Broadcast form of element-wise "<=" on int32 data: one operand contributes a
// single value, the other contributes element_size values.
//   - param->in_elements_num0_ == 1: output[k] = input0[0] <= input1[k].
//   - otherwise:                     output[k] = input0[k] <= input1[0].
// Each output element is 1 when the relation holds and 0 when it does not.
// Always returns NNACL_OK.
int ElementOptLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
                             const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < element_size; ++k) {
      output[k] = (*input0 <= input1[k]);
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < element_size; ++k) {
    output[k] = (input0[k] <= *input1);
  }
  return NNACL_OK;
}
// greater // greater
int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
for (int i = 0; i < element_size; i++) { for (int i = 0; i < element_size; i++) {
@ -86,6 +215,22 @@ int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output
return NNACL_OK; return NNACL_OK;
} }
// Broadcast form of element-wise ">" on float data: one operand contributes a
// single value, the other contributes element_size values.
//   - param->in_elements_num0_ == 1: output[k] = input0[0] > input1[k].
//   - otherwise:                     output[k] = input0[k] > input1[0].
// Each output element is 1 when the relation holds and 0 when it does not.
// Always returns NNACL_OK.
int ElementOptGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
                          const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < element_size; ++k) {
      output[k] = (*input0 > input1[k]);
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < element_size; ++k) {
    output[k] = (input0[k] > *input1);
  }
  return NNACL_OK;
}
int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
for (int i = 0; i < element_size; i++) { for (int i = 0; i < element_size; i++) {
output[i] = input0[i] > input1[i]; output[i] = input0[i] > input1[i];
@ -93,6 +238,21 @@ int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *o
return NNACL_OK; return NNACL_OK;
} }
// Broadcast form of element-wise ">" on int32 data: one operand contributes a
// single value, the other contributes element_size values.
//   - param->in_elements_num0_ == 1: output[k] = input0[0] > input1[k].
//   - otherwise:                     output[k] = input0[k] > input1[0].
// Each output element is 1 when the relation holds and 0 when it does not.
// Always returns NNACL_OK.
int ElementOptGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
                           const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < element_size; ++k) {
      output[k] = (*input0 > input1[k]);
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < element_size; ++k) {
    output[k] = (input0[k] > *input1);
  }
  return NNACL_OK;
}
// greater equal // greater equal
int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
for (int i = 0; i < element_size; i++) { for (int i = 0; i < element_size; i++) {
@ -101,9 +261,41 @@ int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *o
return NNACL_OK; return NNACL_OK;
} }
// Broadcast form of element-wise ">=" on float data: one operand contributes a
// single value, the other contributes element_size values.
//   - param->in_elements_num0_ == 1: output[k] = input0[0] >= input1[k].
//   - otherwise:                     output[k] = input0[k] >= input1[0].
// Each output element is 1 when the relation holds and 0 when it does not.
// Always returns NNACL_OK.
int ElementOptGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
                               const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < element_size; ++k) {
      output[k] = (*input0 >= input1[k]);
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < element_size; ++k) {
    output[k] = (input0[k] >= *input1);
  }
  return NNACL_OK;
}
int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
for (int i = 0; i < element_size; i++) { for (int i = 0; i < element_size; i++) {
output[i] = input0[i] >= input1[i]; output[i] = input0[i] >= input1[i];
} }
return NNACL_OK; return NNACL_OK;
} }
// Broadcast form of element-wise ">=" on int32 data: one operand contributes a
// single value, the other contributes element_size values.
//   - param->in_elements_num0_ == 1: output[k] = input0[0] >= input1[k].
//   - otherwise:                     output[k] = input0[k] >= input1[0].
// Each output element is 1 when the relation holds and 0 when it does not.
// Always returns NNACL_OK.
int ElementOptGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
                                const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < element_size; ++k) {
      output[k] = (*input0 >= input1[k]);
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < element_size; ++k) {
    output[k] = (input0[k] >= *input1);
  }
  return NNACL_OK;
}

View File

@ -21,28 +21,53 @@
#include <arm_neon.h> #include <arm_neon.h>
#endif #endif
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
#include "nnacl/base/arithmetic_base.h"
#include "nnacl/errorcode.h" #include "nnacl/errorcode.h"
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
int ElementOptEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
int ElementOptEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
int ElementOptNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
int ElementOptNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size); int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
int ElementOptLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
int ElementOptLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
int ElementOptLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
int ElementOptLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size); int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
int ElementOptGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
int ElementOptGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
int ElementOptGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
int ElementOptGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -26,6 +26,21 @@ int ElementFloorMod(const float *in0, const float *in1, float *out, int size) {
return NNACL_OK; return NNACL_OK;
} }
// Broadcast floor-mod on float data, computed as a - floorf(a / b) * b.
//   - param->in_elements_num0_ == 1: a is in0[0], b runs over in1[0..size).
//   - otherwise:                     a runs over in0[0..size), b is in1[0].
// No check is made for a zero divisor. Always returns NNACL_OK.
int ElementOptFloorMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Fixed dividend, varying divisor.
    for (int k = 0; k < size; ++k) {
      const float dividend = in0[0];
      const float divisor = in1[k];
      out[k] = dividend - floorf(dividend / divisor) * divisor;
    }
    return NNACL_OK;
  }
  // Varying dividend, fixed divisor.
  for (int k = 0; k < size; ++k) {
    const float dividend = in0[k];
    const float divisor = in1[0];
    out[k] = dividend - floorf(dividend / divisor) * divisor;
  }
  return NNACL_OK;
}
int ElementFloorModInt(const int *in0, const int *in1, int *out, int size) { int ElementFloorModInt(const int *in0, const int *in1, int *out, int size) {
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
@ -35,6 +50,25 @@ int ElementFloorModInt(const int *in0, const int *in1, int *out, int size) {
return NNACL_OK; return NNACL_OK;
} }
// Broadcast floor-mod on int data.
//   - param->in_elements_num0_ == 1: dividend is in0[0], divisor runs over in1[0..size).
//   - otherwise:                     dividend runs over in0[0..size), divisor is in1[0].
// The C remainder (a - (a / b) * b) is computed first; when it is non-zero and the
// operands disagree on "is positive", the divisor is added to it.
// Every divisor is passed through NNACL_CHECK_ZERO_RETURN_ERR before use (macro defined
// elsewhere; presumably it returns an error code on zero). In the varying-divisor
// case that check happens per element, so earlier outputs may already be written.
// Returns NNACL_OK when all elements were processed.
int ElementOptFloorModInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Fixed dividend, varying divisor: validate each divisor as it is reached.
    for (int k = 0; k < size; ++k) {
      NNACL_CHECK_ZERO_RETURN_ERR(in1[k]);
      const int dividend = in0[0];
      const int divisor = in1[k];
      const int rem = dividend - (dividend / divisor) * divisor;
      const int sign_mismatch = ((dividend > 0) != (divisor > 0));
      if ((rem != 0) && sign_mismatch) {
        out[k] = rem + divisor;
      } else {
        out[k] = rem;
      }
    }
    return NNACL_OK;
  }
  // Varying dividend, fixed divisor: one up-front validation is enough.
  NNACL_CHECK_ZERO_RETURN_ERR(in1[0]);
  for (int k = 0; k < size; ++k) {
    const int dividend = in0[k];
    const int divisor = in1[0];
    const int rem = dividend - (dividend / divisor) * divisor;
    const int sign_mismatch = ((dividend > 0) != (divisor > 0));
    if ((rem != 0) && sign_mismatch) {
      out[k] = rem + divisor;
    } else {
      out[k] = rem;
    }
  }
  return NNACL_OK;
}
int ElementMod(const float *in0, const float *in1, float *out, int size) { int ElementMod(const float *in0, const float *in1, float *out, int size) {
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
out[i] = fmodf(in0[i], in1[i]); out[i] = fmodf(in0[i], in1[i]);
@ -42,23 +76,24 @@ int ElementMod(const float *in0, const float *in1, float *out, int size) {
return NNACL_OK; return NNACL_OK;
} }
int ElementModInt(const int *in0, const int *in1, int *out, int size) { int ElementOptMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
for (int i = 0; i < size; i++) { int index = 0;
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); if (param->in_elements_num0_ == 1) {
out[i] = in0[i] % in1[i]; for (; index < size; index++) {
out[index] = fmodf(in0[0], in1[index]);
}
} else {
for (; index < size; index++) {
out[index] = fmodf(in0[index], in1[0]);
}
} }
return NNACL_OK; return NNACL_OK;
} }
int ElementOptMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) { int ElementModInt(const int *in0, const int *in1, int *out, int size) {
if (param->in_elements_num0_ == 1) { for (int i = 0; i < size; i++) {
for (int index = 0; index < size; index++) { NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
out[index] = fmodf(in0[0], in1[index]); out[i] = in0[i] % in1[i];
}
} else {
for (int index = 0; index < size; index++) {
out[index] = fmodf(in0[index], in1[0]);
}
} }
return NNACL_OK; return NNACL_OK;
} }
@ -85,6 +120,21 @@ int ElementFloorDiv(const float *in0, const float *in1, float *out, int size) {
return NNACL_OK; return NNACL_OK;
} }
// Broadcast floor-division on float data: out[k] = floorf(a / b).
//   - param->in_elements_num0_ == 1: a is in0[0], b runs over in1[0..size).
//   - otherwise:                     a runs over in0[0..size), b is in1[0].
// No check is made for a zero divisor. Always returns NNACL_OK.
int ElementOptFloorDiv(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Fixed dividend, varying divisor.
    for (int k = 0; k < size; ++k) {
      const float quotient = in0[0] / in1[k];
      out[k] = floorf(quotient);
    }
    return NNACL_OK;
  }
  // Varying dividend, fixed divisor.
  for (int k = 0; k < size; ++k) {
    const float quotient = in0[k] / in1[0];
    out[k] = floorf(quotient);
  }
  return NNACL_OK;
}
int ElementFloorDivInt(const int *in0, const int *in1, int *out, int size) { int ElementFloorDivInt(const int *in0, const int *in1, int *out, int size) {
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
@ -93,6 +143,23 @@ int ElementFloorDivInt(const int *in0, const int *in1, int *out, int size) {
return NNACL_OK; return NNACL_OK;
} }
// Broadcast integer division: out[k] = a / b using the C '/' operator.
//   - param->in_elements_num0_ == 1: a is in0[0], b runs over in1[0..size).
//   - otherwise:                     a runs over in0[0..size), b is in1[0].
// Every divisor is passed through NNACL_CHECK_ZERO_RETURN_ERR before use (macro defined
// elsewhere; presumably it returns an error code on zero). In the varying-divisor
// case that check happens per element, so earlier outputs may already be written.
// Returns NNACL_OK when all elements were processed.
//
// NOTE(review): integer '/' truncates toward zero, so for operands of opposite sign
// the result is not a mathematical floor (e.g. -7 / 2 yields -3, not -4). Confirm
// whether this is intended and whether it matches the non-broadcast ElementFloorDivInt.
int ElementOptFloorDivInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Fixed dividend, varying divisor: validate each divisor as it is reached.
    for (int k = 0; k < size; ++k) {
      NNACL_CHECK_ZERO_RETURN_ERR(in1[k]);
      const int dividend = in0[0];
      const int divisor = in1[k];
      out[k] = dividend / divisor;
    }
    return NNACL_OK;
  }
  // Varying dividend, fixed divisor: one up-front validation is enough.
  NNACL_CHECK_ZERO_RETURN_ERR(in1[0]);
  for (int k = 0; k < size; ++k) {
    const int dividend = in0[k];
    const int divisor = in1[0];
    out[k] = dividend / divisor;
  }
  return NNACL_OK;
}
int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size) { int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size) {
int index = 0; int index = 0;
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
@ -113,6 +180,21 @@ int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size)
return NNACL_OK; return NNACL_OK;
} }
// Broadcast logical AND on float data. Each operand is first converted to bool
// (non-zero becomes true), the two truth values are AND-ed, and the result is
// stored as a float (1.0f or 0.0f).
//   - param->in_elements_num0_ == 1: in0[0] is combined with every in1[k].
//   - otherwise:                     every in0[k] is combined with in1[0].
// Always returns NNACL_OK.
int ElementOptLogicalAnd(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < size; ++k) {
      const bool lhs_true = (bool)(in0[0]);
      const bool rhs_true = (bool)(in1[k]);
      out[k] = (float)(lhs_true & rhs_true);
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < size; ++k) {
    const bool lhs_true = (bool)(in0[k]);
    const bool rhs_true = (bool)(in1[0]);
    out[k] = (float)(lhs_true & rhs_true);
  }
  return NNACL_OK;
}
int ElementLogicalAndInt(const int *in0, const int *in1, int *out, int size) { int ElementLogicalAndInt(const int *in0, const int *in1, int *out, int size) {
int index = 0; int index = 0;
for (; index < size; index++) { for (; index < size; index++) {
@ -121,11 +203,42 @@ int ElementLogicalAndInt(const int *in0, const int *in1, int *out, int size) {
return NNACL_OK; return NNACL_OK;
} }
// Broadcast AND on int data. The operands are reinterpreted as unsigned int and
// combined with the bitwise '&' operator, so the result is the bitwise AND of the
// two integers (not a 0/1 truth value).
//   - param->in_elements_num0_ == 1: in0[0] is combined with every in1[k].
//   - otherwise:                     every in0[k] is combined with in1[0].
// Always returns NNACL_OK.
int ElementOptLogicalAndInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < size; ++k) {
      const unsigned int lhs_bits = (unsigned int)(in0[0]);
      const unsigned int rhs_bits = (unsigned int)(in1[k]);
      out[k] = (int)(lhs_bits & rhs_bits);
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < size; ++k) {
    const unsigned int lhs_bits = (unsigned int)(in0[k]);
    const unsigned int rhs_bits = (unsigned int)(in1[0]);
    out[k] = (int)(lhs_bits & rhs_bits);
  }
  return NNACL_OK;
}
int ElementLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size) { int ElementLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size) {
int index = 0; int index = 0;
for (; index < size; index++) { for (; index < size; index++) {
out[index] = (bool)((unsigned int)(in0[index]) & (unsigned int)(in1[index])); out[index] = (bool)((unsigned int)(in0[index]) & (unsigned int)(in1[index]));
} }
return NNACL_OK;
}
// Broadcast logical AND on bool data. Both operands are widened to unsigned int,
// AND-ed, and the result is converted back to bool.
//   - param->in_elements_num0_ == 1: in0[0] is combined with every in1[k].
//   - otherwise:                     every in0[k] is combined with in1[0].
// Always returns NNACL_OK.
int ElementOptLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size, const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < size; ++k) {
      const unsigned int lhs_bits = (unsigned int)(in0[0]);
      const unsigned int rhs_bits = (unsigned int)(in1[k]);
      out[k] = (bool)(lhs_bits & rhs_bits);
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < size; ++k) {
    const unsigned int lhs_bits = (unsigned int)(in0[k]);
    const unsigned int rhs_bits = (unsigned int)(in1[0]);
    out[k] = (bool)(lhs_bits & rhs_bits);
  }
  return NNACL_OK;
}
@ -149,6 +262,21 @@ int ElementLogicalOr(const float *in0, const float *in1, float *out, int size) {
return NNACL_OK; return NNACL_OK;
} }
// Broadcast logical OR on float data. Each operand is first converted to bool
// (non-zero becomes true), the two truth values are OR-ed, and the result is
// stored as a float (1.0f or 0.0f).
//   - param->in_elements_num0_ == 1: in0[0] is combined with every in1[k].
//   - otherwise:                     every in0[k] is combined with in1[0].
// Always returns NNACL_OK.
int ElementOptLogicalOr(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < size; ++k) {
      const bool lhs_true = (bool)(in0[0]);
      const bool rhs_true = (bool)(in1[k]);
      out[k] = (float)(lhs_true | rhs_true);
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < size; ++k) {
    const bool lhs_true = (bool)(in0[k]);
    const bool rhs_true = (bool)(in1[0]);
    out[k] = (float)(lhs_true | rhs_true);
  }
  return NNACL_OK;
}
int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size) { int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size) {
int index = 0; int index = 0;
for (; index < size; index++) { for (; index < size; index++) {
@ -157,6 +285,21 @@ int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size)
return NNACL_OK; return NNACL_OK;
} }
// Broadcast logical OR on bool data: the two operands are OR-ed and the result is
// converted back to bool.
//   - param->in_elements_num0_ == 1: in0[0] is combined with every in1[k].
//   - otherwise:                     every in0[k] is combined with in1[0].
// Always returns NNACL_OK.
int ElementOptLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size, const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < size; ++k) {
      const int merged = in0[0] | in1[k];
      out[k] = (bool)merged;
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < size; ++k) {
    const int merged = in0[k] | in1[0];
    out[k] = (bool)merged;
  }
  return NNACL_OK;
}
int ElementMaximum(const float *in0, const float *in1, float *out, int size) { int ElementMaximum(const float *in0, const float *in1, float *out, int size) {
int index = 0; int index = 0;
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
@ -173,6 +316,22 @@ int ElementMaximum(const float *in0, const float *in1, float *out, int size) {
return NNACL_OK; return NNACL_OK;
} }
// Broadcast maximum on float data: out[k] is the in0-side value when it is strictly
// greater than the in1-side value, and the in1-side value otherwise.
//   - param->in_elements_num0_ == 1: in0[0] is paired with every in1[k].
//   - otherwise:                     every in0[k] is paired with in1[0].
// Always returns NNACL_OK.
int ElementOptMaximum(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < size; ++k) {
      const float lhs = in0[0];
      const float rhs = in1[k];
      if (lhs > rhs) {
        out[k] = lhs;
      } else {
        out[k] = rhs;
      }
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < size; ++k) {
    const float lhs = in0[k];
    const float rhs = in1[0];
    if (lhs > rhs) {
      out[k] = lhs;
    } else {
      out[k] = rhs;
    }
  }
  return NNACL_OK;
}
int ElementMaximumInt(const int *in0, const int *in1, int *out, int size) { int ElementMaximumInt(const int *in0, const int *in1, int *out, int size) {
int index = 0; int index = 0;
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
@ -189,22 +348,53 @@ int ElementMaximumInt(const int *in0, const int *in1, int *out, int size) {
return NNACL_OK; return NNACL_OK;
} }
int ElementMinimumInt(const int *input0, const int *input1, int *output, const int element_size) { int ElementOptMaximumInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) {
int index = 0;
if (param->in_elements_num0_ == 1) {
for (; index < size; index++) {
out[index] = in0[0] > in1[index] ? in0[0] : in1[index];
}
} else {
for (; index < size; index++) {
out[index] = in0[index] > in1[0] ? in0[index] : in1[0];
}
}
return NNACL_OK;
}
int ElementMinimumInt(const int *input0, const int *input1, int *output, int size) {
int index = 0; int index = 0;
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
for (; index <= element_size - 4; index += C4NUM) { for (; index <= size - 4; index += C4NUM) {
int32x4_t vin0 = vld1q_s32(input0 + index); int32x4_t vin0 = vld1q_s32(input0 + index);
int32x4_t vin1 = vld1q_s32(input1 + index); int32x4_t vin1 = vld1q_s32(input1 + index);
int32x4_t vout = vminq_s32(vin0, vin1); int32x4_t vout = vminq_s32(vin0, vin1);
vst1q_s32(output + index, vout); vst1q_s32(output + index, vout);
} }
#endif #endif
for (; index < element_size; index++) { for (; index < size; index++) {
output[index] = input0[index] > input1[index] ? input1[index] : input0[index]; output[index] = input0[index] > input1[index] ? input1[index] : input0[index];
} }
return NNACL_OK; return NNACL_OK;
} }
// Broadcast minimum on int data: output[k] is the input1-side value when the
// input0-side value is strictly greater, and the input0-side value otherwise.
//   - param->in_elements_num0_ == 1: input0[0] is paired with every input1[k].
//   - otherwise:                     every input0[k] is paired with input1[0].
// Always returns NNACL_OK.
int ElementOptMinimumInt(const int *input0, const int *input1, int *output, int size,
                         const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < size; ++k) {
      const int lhs = input0[0];
      const int rhs = input1[k];
      if (lhs > rhs) {
        output[k] = rhs;
      } else {
        output[k] = lhs;
      }
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < size; ++k) {
    const int lhs = input0[k];
    const int rhs = input1[0];
    if (lhs > rhs) {
      output[k] = rhs;
    } else {
      output[k] = lhs;
    }
  }
  return NNACL_OK;
}
int ElementMinimum(const float *in0, const float *in1, float *out, int size) { int ElementMinimum(const float *in0, const float *in1, float *out, int size) {
int index = 0; int index = 0;
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
@ -221,6 +411,21 @@ int ElementMinimum(const float *in0, const float *in1, float *out, int size) {
return NNACL_OK; return NNACL_OK;
} }
// Broadcast minimum on float data: out[k] is the in1-side value when the in0-side
// value is strictly greater, and the in0-side value otherwise.
//   - param->in_elements_num0_ == 1: in0[0] is paired with every in1[k].
//   - otherwise:                     every in0[k] is paired with in1[0].
// Always returns NNACL_OK.
int ElementOptMinimum(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    // Single value on the left-hand side.
    for (int k = 0; k < size; ++k) {
      const float lhs = in0[0];
      const float rhs = in1[k];
      if (lhs > rhs) {
        out[k] = rhs;
      } else {
        out[k] = lhs;
      }
    }
    return NNACL_OK;
  }
  // Single value on the right-hand side.
  for (int k = 0; k < size; ++k) {
    const float lhs = in0[k];
    const float rhs = in1[0];
    if (lhs > rhs) {
      out[k] = rhs;
    } else {
      out[k] = lhs;
    }
  }
  return NNACL_OK;
}
#undef ACCURACY_DATA #undef ACCURACY_DATA
void TileOneDimensionFp32(const float *inData, float *outData, int dim, size_t ndim, const int *inShape, void TileOneDimensionFp32(const float *inData, float *outData, int dim, size_t ndim, const int *inShape,

View File

@ -37,31 +37,44 @@ void TileDimensionsFp32(const float *data0, const float *data1, float *tile_data
ArithmeticParameter *param); ArithmeticParameter *param);
/* logical and */ /* logical and */
int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size); int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size);
int ElementOptLogicalAnd(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
int ElementLogicalAndInt(const int *in0, const int *in1, int *out, int size); int ElementLogicalAndInt(const int *in0, const int *in1, int *out, int size);
int ElementOptLogicalAndInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
int ElementLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size); int ElementLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size);
int ElementOptLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size, const ArithmeticParameter *param);
/* logical or */ /* logical or */
int ElementLogicalOr(const float *in0, const float *in1, float *out, int size); int ElementLogicalOr(const float *in0, const float *in1, float *out, int size);
int ElementOptLogicalOr(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size); int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size);
int ElementOptLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size, const ArithmeticParameter *param);
/* max min */ /* max min */
int ElementMaximum(const float *in0, const float *in1, float *out, int size); int ElementMaximum(const float *in0, const float *in1, float *out, int size);
int ElementOptMaximum(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
int ElementMinimum(const float *in0, const float *in1, float *out, int size); int ElementMinimum(const float *in0, const float *in1, float *out, int size);
int ElementOptMinimum(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
int ElementMaximumInt(const int *in0, const int *in1, int *out, int size); int ElementMaximumInt(const int *in0, const int *in1, int *out, int size);
int ElementMinimumInt(const int *input0, const int *input1, int *output, const int element_size); int ElementOptMaximumInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
int ElementMinimumInt(const int *input0, const int *input1, int *output, int size);
int ElementOptMinimumInt(const int *input0, const int *input1, int *output, int size, const ArithmeticParameter *param);
/* floor div */ /* floor div */
int ElementFloorDiv(const float *in0, const float *in1, float *out, int size); int ElementFloorDiv(const float *in0, const float *in1, float *out, int size);
int ElementOptFloorDiv(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
int ElementFloorDivInt(const int *in0, const int *in1, int *out, int size); int ElementFloorDivInt(const int *in0, const int *in1, int *out, int size);
int ElementOptFloorDivInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
/* floor mod */ /* floor mod */
int ElementFloorMod(const float *in0, const float *in1, float *out, int size); int ElementFloorMod(const float *in0, const float *in1, float *out, int size);
int ElementOptFloorMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
int ElementFloorModInt(const int *in0, const int *in1, int *out, int size); int ElementFloorModInt(const int *in0, const int *in1, int *out, int size);
int ElementOptFloorModInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
/* mod */ /* mod */
int ElementMod(const float *in0, const float *in1, float *out, int size); int ElementMod(const float *in0, const float *in1, float *out, int size);
int ElementModInt(const int *in0, const int *in1, int *out, int size);
int ElementOptMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param); int ElementOptMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
int ElementModInt(const int *in0, const int *in1, int *out, int size);
int ElementOptModInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param); int ElementOptModInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
#ifdef __cplusplus #ifdef __cplusplus

View File

@ -25,4 +25,9 @@ int ElementSquaredDifference(const float *in0, const float *in1, float *out, int
return ElementMul(out, out, out, size); return ElementMul(out, out, out, size);
} }
/* Squared difference built from two element-wise steps:
 *   1. ElementOptSub(in0, in1, out, size, param) writes its result into `out`;
 *   2. ElementMul(out, out, out, size) multiplies `out` by itself in place.
 *
 * in0, in1 : input buffers forwarded unchanged to ElementOptSub
 * out      : output buffer, also used as scratch between the two steps
 * size     : element count forwarded to both steps
 * param    : arithmetic parameter forwarded to ElementOptSub
 *
 * Returns the status of ElementOptSub if that step reports a failure,
 * otherwise the status of ElementMul. */
int ElementOptSquaredDifference(const float *in0, const float *in1, float *out, int size,
                                const ArithmeticParameter *param) {
  int sub_ret = ElementOptSub(in0, in1, out, size, param);
  /* NOTE(review): assumes the nnacl success code (NNACL_OK) is 0 - confirm against errorcode.h. */
  if (sub_ret != 0) {
    /* Do not square a buffer the subtraction step failed to produce. */
    return sub_ret;
  }
  return ElementMul(out, out, out, size);
}
#endif // MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_ #endif // MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_

View File

@ -29,7 +29,8 @@ extern "C" {
/* Element Squared Difference */ /* Element Squared Difference */
int ElementSquaredDifference(const float *in0, const float *in1, float *out, int size); int ElementSquaredDifference(const float *in0, const float *in1, float *out, int size);
int ElementOptSquaredDifference(const float *in0, const float *in1, float *out, int size,
const ArithmeticParameter *param);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -32,9 +32,9 @@ int ElementSubInt(const int *in0, const int *in1, int *out, int size);
int ElementSubRelu(const float *in0, const float *in1, float *out, int size); int ElementSubRelu(const float *in0, const float *in1, float *out, int size);
int ElementSubRelu6(const float *in0, const float *in1, float *out, int size); int ElementSubRelu6(const float *in0, const float *in1, float *out, int size);
int ElementOptSub(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param); int ElementOptSub(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
int ElementOptSubInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
int ElementOptSubRelu(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param); int ElementOptSubRelu(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
int ElementOptSubRelu6(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param); int ElementOptSubRelu6(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
int ElementOptSubInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -63,36 +63,6 @@ int ArithmeticFP16CPUKernel::CheckDataType() {
return RET_OK; return RET_OK;
} }
// Tells whether the "opt" function can be used because one operand is a single element.
// Example shapes given by the original author: 2 32 240 240 against 1 1 1 1.
// True only when an opt function is registered AND at least one input has exactly one element.
bool ArithmeticFP16CPUKernel::IsScalarClac() {
  if (arithmetic_opt_func_ == nullptr) {
    return false;
  }
  const bool lhs_single = (param_->in_elements_num0_ == 1);
  const bool rhs_single = (param_->in_elements_num1_ == 1);
  return lhs_single || rhs_single;
}
bool ArithmeticFP16CPUKernel::IsBatchScalarCalc() {
if (arithmetic_opt_func_ == nullptr) {
return false;
}
size_t break_axis = 0;
for (size_t i = 0; i < param_->ndim_; i++) {
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
break_axis = i;
break;
}
}
if (break_axis < param_->ndim_) {
for (size_t i = break_axis; i < param_->ndim_; i++) {
if (param_->in_shape1_[i] != 1) {
return false;
}
}
}
break_pos_ = break_axis;
return true;
}
void ArithmeticFP16CPUKernel::InitRunFunction(int primitive_type) { void ArithmeticFP16CPUKernel::InitRunFunction(int primitive_type) {
ARITHMETIC_FUNC_INFO_FP16 fun_table[] = { ARITHMETIC_FUNC_INFO_FP16 fun_table[] = {
{PrimitiveType_MulFusion, schema::ActivationType_RELU, ElementMulReluFp16, ElementOptMulReluFp16}, {PrimitiveType_MulFusion, schema::ActivationType_RELU, ElementMulReluFp16, ElementOptMulReluFp16},
@ -171,6 +141,7 @@ int ArithmeticFP16CPUKernel::Run() {
MS_LOG(ERROR) << "ArithmeticFP16CPUKernel check dataType failed."; MS_LOG(ERROR) << "ArithmeticFP16CPUKernel check dataType failed.";
return RET_ERROR; return RET_ERROR;
} }
if (!input0_broadcast_) { if (!input0_broadcast_) {
input0_ptr_ = ConvertInputFp32toFp16(in_tensors_.at(0), static_cast<const lite::InnerContext *>(this->ms_context_)); input0_ptr_ = ConvertInputFp32toFp16(in_tensors_.at(0), static_cast<const lite::InnerContext *>(this->ms_context_));
} }
@ -183,10 +154,16 @@ int ArithmeticFP16CPUKernel::Run() {
FreeFp16Buffer(); FreeFp16Buffer();
return RET_ERROR; return RET_ERROR;
} }
batch_a_ptr_ = static_cast<uint8_t *>(input0_ptr_);
batch_b_ptr_ = static_cast<uint8_t *>(input1_ptr_);
batch_c_ptr_ = static_cast<uint8_t *>(output_ptr_);
auto ret = ParallelLaunch(this->ms_context_, ArithmeticsRun, this, op_parameter_->thread_num_); auto ret = ParallelLaunch(this->ms_context_, ArithmeticsRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "ArithmeticsRun failed, ret : " << ret; MS_LOG(ERROR) << "ArithmeticsRun failed, ret : " << ret;
return RET_ERROR;
} }
if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) { if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
Float16ToFloat32(static_cast<float16_t *>(output_ptr_), reinterpret_cast<float *>(output_tensor->data()), Float16ToFloat32(static_cast<float16_t *>(output_ptr_), reinterpret_cast<float *>(output_tensor->data()),
output_tensor->ElementsNum()); output_tensor->ElementsNum());

View File

@ -40,8 +40,6 @@ class ArithmeticFP16CPUKernel : public ArithmeticCPUKernel {
~ArithmeticFP16CPUKernel() = default; ~ArithmeticFP16CPUKernel() = default;
int ReSize() override; int ReSize() override;
int Run() override; int Run() override;
bool IsBatchScalarCalc() override;
bool IsScalarClac() override;
private: private:
void InitRunFunction(int primitive_type) override; void InitRunFunction(int primitive_type) override;

View File

@ -28,85 +28,110 @@ using mindspore::schema::PrimitiveType_LessEqual;
using mindspore::schema::PrimitiveType_NotEqual; using mindspore::schema::PrimitiveType_NotEqual;
namespace mindspore::kernel { namespace mindspore::kernel {
int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, void ArithmeticCompareCPUKernel::InitRunFunction(int primitive_type) {
int out_thread_stride) { ARITHMETIC_COMEPARE_FUNC_INFO_FP32 fun_table[] = {
if (dim > break_pos_) { {PrimitiveType_Equal, ElementEqualFp32, ElementEqualInt32, ElementOptEqualFp32, ElementOptEqualInt32},
if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) { {PrimitiveType_NotEqual, ElementNotEqualFp32, ElementNotEqualInt32, ElementOptNotEqualFp32,
return func_int32_(reinterpret_cast<int *>(input0) + out_thread_stride, ElementOptNotEqualInt32},
reinterpret_cast<int *>(input1) + out_thread_stride, {PrimitiveType_Less, ElementLessFp32, ElementLessInt32, ElementOptLessFp32, ElementOptLessInt32},
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count); {PrimitiveType_LessEqual, ElementLessEqualFp32, ElementLessEqualInt32, ElementOptLessEqualFp32,
ElementOptLessEqualInt32},
{PrimitiveType_Greater, ElementGreaterFp32, ElementGreaterInt32, ElementOptGreaterFp32, ElementOptGreaterInt32},
{PrimitiveType_GreaterEqual, ElementGreaterEqualFp32, ElementGreaterEqualInt32, ElementOptGreaterEqualFp32,
ElementOptGreaterEqualInt32}};
size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_COMEPARE_FUNC_INFO_FP32);
for (size_t i = 0; i < length; i++) {
if (fun_table[i].primitive_type_ == primitive_type) {
func_fp32_ = fun_table[i].func_;
func_int32_ = fun_table[i].int_func_;
opt_func_fp32_ = fun_table[i].opt_func_;
opt_func_int32_ = fun_table[i].opt_int_func_;
return;
} }
return func_fp32_(reinterpret_cast<float *>(input0) + out_thread_stride,
reinterpret_cast<float *>(input1) + out_thread_stride,
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
} }
for (int i = 0; i < param_->out_shape_[dim]; ++i) { }
int pos0_ = param_->in_shape0_[dim] == 1 ? 0 : i;
int pos1_ = param_->in_shape1_[dim] == 1 ? 0 : i; int ArithmeticCompareCPUKernel::Execute(const void *input0, const void *input1, void *output, int size, bool is_opt) {
int error_code; int ret = RET_OK;
if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) { if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * param_->in_strides0_[dim], if (is_opt) {
reinterpret_cast<int *>(input1) + pos1_ * param_->in_strides1_[dim], CHECK_NULL_RETURN(opt_func_fp32_);
reinterpret_cast<uint8_t *>(output) + i * param_->out_strides_[dim], dim + 1, out_count, ret = opt_func_fp32_(reinterpret_cast<const float *>(input0), reinterpret_cast<const float *>(input1),
out_thread_stride); reinterpret_cast<uint8_t *>(output), size, param_);
} else { } else {
error_code = BroadcastRun(reinterpret_cast<float *>(input0) + pos0_ * param_->in_strides0_[dim], CHECK_NULL_RETURN(func_fp32_);
reinterpret_cast<float *>(input1) + pos1_ * param_->in_strides1_[dim], ret = func_fp32_(reinterpret_cast<const float *>(input0), reinterpret_cast<const float *>(input1),
reinterpret_cast<uint8_t *>(output) + i * param_->out_strides_[dim], dim + 1, out_count, reinterpret_cast<uint8_t *>(output), size);
out_thread_stride);
} }
if (error_code != RET_OK) { } else if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) {
return error_code; if (is_opt) {
CHECK_NULL_RETURN(opt_func_int32_);
ret = opt_func_int32_(reinterpret_cast<const int *>(input0), reinterpret_cast<const int *>(input1),
reinterpret_cast<uint8_t *>(output), size, param_);
} else {
CHECK_NULL_RETURN(func_int32_);
ret = func_int32_(reinterpret_cast<const int *>(input0), reinterpret_cast<const int *>(input1),
reinterpret_cast<uint8_t *>(output), size);
}
} else {
MS_LOG(ERROR) << "Error Operator type " << kNumberTypeInt32;
return RET_ERROR;
}
return ret;
}
int ArithmeticCompareCPUKernel::CalcArithmeticByBatch(int task_id) {
if (break_pos_ > ARITHMETIC_SUPPORT_DIMS_NUM || param_->out_strides_[break_pos_ - 1] == 0) {
MS_LOG(ERROR) << "param_->out_strides_[break_pos_ - 1] is 0 or break_pos_ is > 10";
return RET_ERROR;
}
int batch_per_thread = UP_DIV(out_batch_, op_parameter_->thread_num_);
int start_batch = batch_per_thread * task_id;
int end_batch = MSMIN(start_batch + batch_per_thread, out_batch_);
int ret = RET_ERROR;
for (int i = start_batch; i < end_batch; i++) {
batch_a_ptr_ = static_cast<uint8_t *>(input0_ptr_) + a_offset_[i] * a_stride_size_ * data_type_len_;
batch_b_ptr_ = static_cast<uint8_t *>(input1_ptr_) + b_offset_[i] * b_stride_size_ * data_type_len_;
batch_c_ptr_ = static_cast<uint8_t *>(output_ptr_) + i * c_stride_size_ * sizeof(uint8_t);
if (batch_scalar_) {
ret = Execute(batch_a_ptr_, batch_b_ptr_, batch_c_ptr_, c_stride_size_, true);
} else {
ret = Execute(batch_a_ptr_, batch_b_ptr_, batch_c_ptr_, c_stride_size_, false);
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "failed to calculate.";
return RET_ERROR;
} }
} }
return RET_OK; return ret;
} }
int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) { int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
auto element_num = out_tensors_[0]->ElementsNum(); if (split_by_batch_) {
return CalcArithmeticByBatch(task_id);
}
MS_ASSERT(op_parameter_->thread_num_ != 0); int64_t element_num = out_tensors_[0]->ElementsNum();
auto ret = RET_ERROR;
int stride = UP_DIV(element_num, op_parameter_->thread_num_); int stride = UP_DIV(element_num, op_parameter_->thread_num_);
int count = MSMIN(stride, element_num - stride * task_id); int count = MSMIN(stride, element_num - stride * task_id);
if (count <= 0) { if (count <= 0) {
return RET_OK; return RET_OK;
} }
CHECK_LESS_RETURN(ARITHMETIC_SUPPORT_DIMS_NUM, param_->ndim_);
if (func_fp32_ == nullptr) { int in_offset = stride * task_id * data_type_len_;
MS_LOG(ERROR) << "func_fp32_ function is nullptr!"; int out_offset = stride * task_id * sizeof(uint8_t);
return RET_ERROR; if (scalar_) {
} if (param_->in_elements_num0_ == 1) {
ret = Execute(batch_a_ptr_, batch_b_ptr_ + in_offset, batch_c_ptr_ + out_offset, count, true);
int error_code;
if (param_->broadcasting_) { // need broadcast
stride = UP_DIV(outside_, op_parameter_->thread_num_);
int out_count = MSMIN(stride, outside_ - stride * task_id);
int out_thread_stride = stride * task_id;
if (out_count <= 0) {
return RET_OK;
}
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
error_code = BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
reinterpret_cast<uint8_t *>(out_tensors_[0]->data()), 0, out_count, out_thread_stride);
} else { } else {
error_code = BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_), ret = Execute(batch_a_ptr_ + in_offset, batch_b_ptr_, batch_c_ptr_ + out_offset, count, true);
reinterpret_cast<uint8_t *>(out_tensors_[0]->data()), 0, out_count, out_thread_stride);
}
} else { // no broadcast, neither is scalar, two same shape
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
error_code = func_fp32_(reinterpret_cast<float *>(input0_ptr_) + stride * task_id,
reinterpret_cast<float *>(input1_ptr_) + stride * task_id,
reinterpret_cast<uint8_t *>(out_tensors_[0]->data()) + stride * task_id, count);
} else {
error_code = func_int32_(reinterpret_cast<int *>(input0_ptr_) + stride * task_id,
reinterpret_cast<int *>(input1_ptr_) + stride * task_id,
reinterpret_cast<uint8_t *>(out_tensors_[0]->data()) + stride * task_id, count);
} }
} else {
ret = Execute(batch_a_ptr_ + in_offset, batch_b_ptr_ + in_offset, batch_c_ptr_ + out_offset, count, false);
} }
if (error_code != RET_OK) { return ret;
return RET_ERROR;
}
return RET_OK;
} }
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Equal, LiteKernelCreator<ArithmeticCompareCPUKernel>) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Equal, LiteKernelCreator<ArithmeticCompareCPUKernel>)

View File

@ -21,53 +21,38 @@
#include "nnacl/fp32/arithmetic_compare_fp32.h" #include "nnacl/fp32/arithmetic_compare_fp32.h"
namespace mindspore::kernel { namespace mindspore::kernel {
typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size);
typedef int (*ArithmeticCompareIntFunc)(const int *input0, const int *input1, uint8_t *output, int element_size);
class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel { class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel {
typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size);
typedef int (*ArithmeticCompareIntFunc)(const int *input0, const int *input1, uint8_t *output, int element_size);
typedef int (*ArithmeticOptCompareFp32Func)(const float *input0, const float *input1, uint8_t *output,
int element_size, const ArithmeticParameter *param);
typedef int (*ArithmeticOptCompareIntFunc)(const int *input0, const int *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
typedef struct {
int primitive_type_;
ArithmeticCompareFp32Func func_;
ArithmeticCompareIntFunc int_func_;
ArithmeticOptCompareFp32Func opt_func_;
ArithmeticOptCompareIntFunc opt_int_func_;
} ARITHMETIC_COMEPARE_FUNC_INFO_FP32;
public: public:
explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: ArithmeticCPUKernel(parameter, inputs, outputs, ctx) { : ArithmeticCPUKernel(parameter, inputs, outputs, ctx) {}
switch (parameter->type_) {
case PrimitiveType_Equal:
func_fp32_ = ElementEqualFp32;
func_int32_ = ElementEqualInt32;
break;
case PrimitiveType_NotEqual:
func_fp32_ = ElementNotEqualFp32;
func_int32_ = ElementNotEqualInt32;
break;
case PrimitiveType_Less:
func_fp32_ = ElementLessFp32;
func_int32_ = ElementLessInt32;
break;
case PrimitiveType_LessEqual:
func_fp32_ = ElementLessEqualFp32;
func_int32_ = ElementLessEqualInt32;
break;
case PrimitiveType_Greater:
func_fp32_ = ElementGreaterFp32;
func_int32_ = ElementGreaterInt32;
break;
case PrimitiveType_GreaterEqual:
func_fp32_ = ElementGreaterEqualFp32;
func_int32_ = ElementGreaterEqualInt32;
break;
default:
MS_LOG(ERROR) << "Error Operator type " << parameter->type_;
func_fp32_ = nullptr;
func_int32_ = nullptr;
break;
}
}
~ArithmeticCompareCPUKernel() override = default; ~ArithmeticCompareCPUKernel() override = default;
int DoArithmetic(int task_id) override; int DoArithmetic(int task_id) override;
int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride) override;
protected:
void InitRunFunction(int primitive_type) override;
int Execute(const void *input0, const void *input1, void *output, int size, bool is_opt) override;
int CalcArithmeticByBatch(int task_id) override;
private: private:
ArithmeticCompareFp32Func func_fp32_ = nullptr; ArithmeticCompareFp32Func func_fp32_ = nullptr;
ArithmeticCompareIntFunc func_int32_ = nullptr; ArithmeticCompareIntFunc func_int32_ = nullptr;
ArithmeticOptCompareFp32Func opt_func_fp32_ = nullptr;
ArithmeticOptCompareIntFunc opt_func_int32_ = nullptr;
}; };
int ArithmeticCompareRun(void *cdata, int task_id, float lhs_scale, float rhs_scale); int ArithmeticCompareRun(void *cdata, int task_id, float lhs_scale, float rhs_scale);
} // namespace mindspore::kernel } // namespace mindspore::kernel

View File

@ -50,27 +50,133 @@ int ArithmeticCPUKernel::Prepare() {
} }
return ReSize(); return ReSize();
} }
// True when either input tensor holds exactly one element.
bool ArithmeticCPUKernel::IsScalarClac() {
  const bool lhs_is_single = (param_->in_elements_num0_ == 1);
  const bool rhs_is_single = (param_->in_elements_num1_ == 1);
  return lhs_is_single || rhs_is_single;
}
int ArithmeticCPUKernel::ReSize() { int ArithmeticCPUKernel::ReSize() {
CalcMultiplesAndStrides(param_); CalcMultiplesAndStrides(param_);
if (param_->broadcasting_) { scalar_ = IsScalarClac();
outside_ = 1; int ret = RET_OK;
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0 && i < ARITHMETIC_SUPPORT_DIMS_NUM; --i) { if (!scalar_) {
if (param_->in_shape0_[i] != param_->in_shape1_[i]) { ret = ConstTensorBroadCast();
break_pos_ = i; if (ret != RET_OK) {
break; MS_LOG(ERROR) << "failed to init const tensor";
} return ret;
outside_ *= param_->out_shape_[i];
} }
} }
data_type_len_ = lite::DataTypeSize(in_tensors_.at(0)->data_type()); if (!scalar_ && param_->broadcasting_) {
int ret = RET_OK; ret = InitIndexOffsetInfo();
if (!IsScalarClac() && !IsBatchScalarCalc() && !IsBiasCalc()) {
ret = ConstTensorBroadCast();
} }
data_type_len_ = lite::DataTypeSize(in_tensors_.at(0)->data_type());
return ret; return ret;
} }
bool ArithmeticCPUKernel::IsBatchScalarCalc() { // 1 32 240 240, 2 32 1 1
int last_batch_axis0 = ARITHMETIC_SUPPORT_DIMS_NUM + 1;
int last_batch_axis1 = ARITHMETIC_SUPPORT_DIMS_NUM + 1;
if (param_->in_shape0_[param_->ndim_ - 1] == 1) {
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0 && i < ARITHMETIC_SUPPORT_DIMS_NUM; --i) {
if (param_->in_shape0_[i] != 1) {
last_batch_axis0 = i;
break;
}
}
}
if (param_->in_shape1_[param_->ndim_ - 1] == 1) {
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0 && i < ARITHMETIC_SUPPORT_DIMS_NUM; --i) {
if (param_->in_shape1_[i] != 1) {
last_batch_axis1 = i;
break;
}
}
}
int min_axis = MSMIN(last_batch_axis0, last_batch_axis1);
if (min_axis < static_cast<int>(param_->ndim_) - 1) {
last_batch_axis_ = min_axis;
if (last_batch_axis0 < last_batch_axis1) {
param_->in_elements_num0_ = 1;
} else {
param_->in_elements_num1_ = 1;
}
return true;
}
return false;
}
int ArithmeticCPUKernel::InitIndexOffsetInfo() {
split_by_batch_ = true;
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0 && i < ARITHMETIC_SUPPORT_DIMS_NUM; --i) {
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
break_pos_ = i;
break;
}
}
std::vector<int> a_shape;
std::vector<int> b_shape;
std::vector<int> c_shape = out_tensors_[0]->shape();
size_t dim = c_shape.size();
for (size_t i = 0; i < dim; ++i) {
a_shape.push_back(param_->in_shape0_[i]);
b_shape.push_back(param_->in_shape1_[i]);
}
batch_scalar_ = IsBatchScalarCalc();
a_stride_size_ = 1;
b_stride_size_ = 1;
c_stride_size_ = 1;
int last_batch_axis = batch_scalar_ ? last_batch_axis_ : break_pos_;
for (int i = static_cast<int>(param_->ndim_) - 1; i > last_batch_axis && i < ARITHMETIC_SUPPORT_DIMS_NUM; --i) {
a_stride_size_ *= a_shape[i];
b_stride_size_ *= b_shape[i];
c_stride_size_ *= c_shape[i];
}
out_batch_ = 1;
int batch_size[ARITHMETIC_SUPPORT_DIMS_NUM] = {};
int a_batch_size[ARITHMETIC_SUPPORT_DIMS_NUM] = {};
int b_batch_size[ARITHMETIC_SUPPORT_DIMS_NUM] = {};
for (int i = last_batch_axis; i >= 0; --i) {
out_batch_ *= c_shape[i];
if (i == last_batch_axis) {
batch_size[i] = c_shape[i];
a_batch_size[i] = a_shape[i];
b_batch_size[i] = b_shape[i];
} else {
batch_size[i] = batch_size[i + 1] * c_shape[i];
a_batch_size[i] = a_batch_size[i + 1] * a_shape[i];
b_batch_size[i] = b_batch_size[i + 1] * b_shape[i];
}
}
a_offset_.resize(out_batch_, 0);
b_offset_.resize(out_batch_, 0);
for (int i = 0; i < out_batch_; ++i) {
int delta = i;
int a_offset = 0;
int b_offset = 0;
for (int j = 0; j <= last_batch_axis; ++j) {
if (j > 0) {
delta = delta % batch_size[j];
}
if (j < last_batch_axis) {
a_offset += (delta / batch_size[j + 1] * a_shape[j] / MSMAX(a_shape[j], b_shape[j])) * a_batch_size[j + 1];
b_offset += (delta / batch_size[j + 1] * b_shape[j] / MSMAX(a_shape[j], b_shape[j])) * b_batch_size[j + 1];
} else {
a_offset += (delta * a_shape[j] / MSMAX(a_shape[j], b_shape[j]));
b_offset += (delta * b_shape[j] / MSMAX(a_shape[j], b_shape[j]));
}
}
a_offset_[i] = a_offset;
b_offset_[i] = b_offset;
}
return RET_OK;
}
int ArithmeticCPUKernel::CheckDataType() { int ArithmeticCPUKernel::CheckDataType() {
auto in0_dataType = in_tensors_.at(0)->data_type(); auto in0_dataType = in_tensors_.at(0)->data_type();
auto in1_dataType = in_tensors_.at(1)->data_type(); auto in1_dataType = in_tensors_.at(1)->data_type();
@ -85,47 +191,6 @@ int ArithmeticCPUKernel::CheckDataType() {
return RET_OK; return RET_OK;
} }
// Tells whether the "opt" function applies: one of the inputs must hold exactly
// one element and an opt function must be registered.
// Example shapes given by the original author: 2 32 240 240 against 1 1 1 1.
bool ArithmeticCPUKernel::IsScalarClac() {
  const bool any_single_element = (param_->in_elements_num0_ == 1) || (param_->in_elements_num1_ == 1);
  const bool has_opt_kernel = (arithmetic_opt_run_ != nullptr);
  return any_single_element && has_opt_kernel;
}
bool ArithmeticCPUKernel::IsBatchScalarCalc() { // 2 32 240 240, 2 32 1 1
if (arithmetic_opt_run_ == nullptr) {
return false;
}
size_t break_axis = 0;
for (size_t i = 0; i < param_->ndim_; i++) {
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
break_axis = i;
break;
}
}
if (break_axis < param_->ndim_) {
for (size_t i = break_axis; i < param_->ndim_; i++) {
if (param_->in_shape1_[i] != 1) {
return false;
}
}
}
break_pos_ = break_axis;
return true;
}
// Tells whether the smaller input consists of its innermost axis only and both
// inputs share the same innermost extent.
// Example shapes given by the original author: 2 240 240 32 against 1 1 1 32.
// Inputs with equal element counts never qualify.
bool ArithmeticCPUKernel::IsBiasCalc() const {
  int inner0 = param_->in_shape0_[param_->ndim_ - 1];
  int inner1 = param_->in_shape1_[param_->ndim_ - 1];
  // Both branches of the original require matching innermost extents.
  if (inner0 != inner1) {
    return false;
  }
  if (param_->in_elements_num0_ > param_->in_elements_num1_) {
    return param_->in_elements_num1_ == inner1;
  }
  if (param_->in_elements_num0_ < param_->in_elements_num1_) {
    return param_->in_elements_num0_ == inner0;
  }
  return false;
}
int ArithmeticCPUKernel::ConstTensorBroadCast() { int ArithmeticCPUKernel::ConstTensorBroadCast() {
/* if const node need broadcast and all need-broadcast-node are const, broadcast in resize */ /* if const node need broadcast and all need-broadcast-node are const, broadcast in resize */
if (!param_->broadcasting_) { if (!param_->broadcasting_) {
@ -168,13 +233,11 @@ int ArithmeticCPUKernel::ConstTensorBroadCast() {
} }
} }
// broadcast input and get new break_pos_ // broadcast input and get new break_pos_
outside_ = 1;
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0; --i) { for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0; --i) {
if (param_->in_shape0_[i] != param_->in_shape1_[i]) { if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
break_pos_ = i; break_pos_ = i;
break; break;
} }
outside_ *= param_->out_shape_[i];
} }
if (param_->in_elements_num0_ == param_->out_elements_num_ && if (param_->in_elements_num0_ == param_->out_elements_num_ &&
param_->in_elements_num1_ == param_->out_elements_num_) { param_->in_elements_num1_ == param_->out_elements_num_) {
@ -205,50 +268,51 @@ void ArithmeticCPUKernel::FreeConstTileBuff() {
void ArithmeticCPUKernel::InitRunFunction(int primitive_type) { void ArithmeticCPUKernel::InitRunFunction(int primitive_type) {
ARITHMETIC_FUNC_INFO_FP32 fun_table[] = { ARITHMETIC_FUNC_INFO_FP32 fun_table[] = {
{PrimitiveType_MulFusion, schema::ActivationType_RELU, ElementMulRelu, ElementMulReluInt, nullptr, {PrimitiveType_MulFusion, schema::ActivationType_RELU, ElementMulRelu, ElementMulReluInt, nullptr,
ElementOptMulRelu, ElementOptMulReluInt}, ElementOptMulRelu, ElementOptMulReluInt, nullptr},
{PrimitiveType_MulFusion, schema::ActivationType_RELU6, ElementMulRelu6, ElementMulRelu6Int, nullptr, {PrimitiveType_MulFusion, schema::ActivationType_RELU6, ElementMulRelu6, ElementMulRelu6Int, nullptr,
ElementOptMulRelu6, ElementOptMulRelu6Int}, ElementOptMulRelu6, ElementOptMulRelu6Int, nullptr},
{PrimitiveType_MulFusion, schema::ActivationType_NO_ACTIVATION, ElementMul, ElementMulInt, nullptr, ElementOptMul, {PrimitiveType_MulFusion, schema::ActivationType_NO_ACTIVATION, ElementMul, ElementMulInt, nullptr, ElementOptMul,
ElementOptMulInt}, ElementOptMulInt, nullptr},
{PrimitiveType_AddFusion, schema::ActivationType_RELU, ElementAddRelu, nullptr, nullptr, ElementOptAddRelu, {PrimitiveType_AddFusion, schema::ActivationType_RELU, ElementAddRelu, nullptr, nullptr, ElementOptAddRelu, nullptr,
nullptr}, nullptr},
{PrimitiveType_AddFusion, schema::ActivationType_RELU6, ElementAddRelu6, nullptr, nullptr, ElementOptAddRelu6, {PrimitiveType_AddFusion, schema::ActivationType_RELU6, ElementAddRelu6, nullptr, nullptr, ElementOptAddRelu6,
nullptr}, nullptr, nullptr},
{PrimitiveType_AddFusion, schema::ActivationType_NO_ACTIVATION, ElementAdd, ElementAddInt, nullptr, ElementOptAdd, {PrimitiveType_AddFusion, schema::ActivationType_NO_ACTIVATION, ElementAdd, ElementAddInt, nullptr, ElementOptAdd,
ElementOptAddInt}, ElementOptAddInt, nullptr},
{PrimitiveType_SubFusion, schema::ActivationType_RELU, ElementSubRelu, nullptr, nullptr, ElementOptSubRelu, {PrimitiveType_SubFusion, schema::ActivationType_RELU, ElementSubRelu, nullptr, nullptr, ElementOptSubRelu, nullptr,
nullptr}, nullptr},
{PrimitiveType_SubFusion, schema::ActivationType_RELU6, ElementSubRelu6, nullptr, nullptr, ElementOptSubRelu6, {PrimitiveType_SubFusion, schema::ActivationType_RELU6, ElementSubRelu6, nullptr, nullptr, ElementOptSubRelu6,
nullptr}, nullptr, nullptr},
{PrimitiveType_SubFusion, schema::ActivationType_NO_ACTIVATION, ElementSub, ElementSubInt, nullptr, ElementOptSub, {PrimitiveType_SubFusion, schema::ActivationType_NO_ACTIVATION, ElementSub, ElementSubInt, nullptr, ElementOptSub,
ElementOptSubInt}, ElementOptSubInt, nullptr},
{PrimitiveType_DivFusion, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, {PrimitiveType_DivFusion, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr,
nullptr}, nullptr},
{PrimitiveType_DivFusion, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6, {PrimitiveType_DivFusion, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6,
nullptr}, nullptr, nullptr},
{PrimitiveType_DivFusion, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv, {PrimitiveType_DivFusion, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv,
ElementOptDivInt}, ElementOptDivInt, nullptr},
{PrimitiveType_RealDiv, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr}, {PrimitiveType_RealDiv, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr,
nullptr},
{PrimitiveType_RealDiv, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6, {PrimitiveType_RealDiv, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6,
nullptr}, nullptr, nullptr},
{PrimitiveType_RealDiv, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv, {PrimitiveType_RealDiv, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv,
ElementOptDivInt}, ElementOptDivInt, nullptr},
{PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, ElementLogicalAnd, ElementLogicalAndInt, {PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, ElementLogicalAnd, ElementLogicalAndInt,
ElementLogicalAndBool, nullptr, nullptr}, ElementLogicalAndBool, ElementOptLogicalAnd, ElementOptLogicalAndInt, ElementOptLogicalAndBool},
{PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, ElementLogicalOr, nullptr, ElementLogicalOrBool, {PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, ElementLogicalOr, nullptr, ElementLogicalOrBool,
nullptr, nullptr}, nullptr, nullptr, ElementOptLogicalOrBool},
{PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximum, ElementMaximumInt, nullptr, nullptr, {PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximum, ElementMaximumInt, nullptr,
nullptr}, ElementOptMaximum, ElementOptMaximumInt, nullptr},
{PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimum, ElementMinimumInt, nullptr, nullptr, {PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimum, ElementMinimumInt, nullptr,
nullptr}, ElementOptMinimum, ElementOptMinimumInt, nullptr},
{PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, ElementFloorMod, ElementFloorModInt, nullptr, {PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, ElementFloorMod, ElementFloorModInt, nullptr,
nullptr, nullptr}, ElementOptFloorMod, ElementOptFloorModInt, nullptr},
{PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, ElementFloorDiv, ElementFloorDivInt, nullptr, {PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, ElementFloorDiv, ElementFloorDivInt, nullptr,
nullptr, nullptr}, ElementOptFloorDiv, ElementOptFloorDivInt, nullptr},
{PrimitiveType_Mod, schema::ActivationType_NO_ACTIVATION, ElementMod, ElementModInt, nullptr, ElementOptMod, {PrimitiveType_Mod, schema::ActivationType_NO_ACTIVATION, ElementMod, ElementModInt, nullptr, ElementOptMod,
ElementOptModInt}, ElementOptModInt, nullptr},
{PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifference, nullptr, nullptr, {PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifference, nullptr, nullptr,
nullptr, nullptr}}; ElementOptSquaredDifference, nullptr, nullptr}};
size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_FUNC_INFO_FP32); size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_FUNC_INFO_FP32);
for (size_t i = 0; i < length; i++) { for (size_t i = 0; i < length; i++) {
@ -258,6 +322,7 @@ void ArithmeticCPUKernel::InitRunFunction(int primitive_type) {
arithmetic_run_bool_ = fun_table[i].bool_func_; arithmetic_run_bool_ = fun_table[i].bool_func_;
arithmetic_opt_run_ = fun_table[i].opt_func_; arithmetic_opt_run_ = fun_table[i].opt_func_;
arithmetic_opt_run_int_ = fun_table[i].opt_int_func_; arithmetic_opt_run_int_ = fun_table[i].opt_int_func_;
arithmetic_opt_run_bool_ = fun_table[i].opt_bool_func_;
return; return;
} }
} }
@ -276,9 +341,15 @@ int ArithmeticCPUKernel::Execute(const void *input0, const void *input1, void *o
reinterpret_cast<float *>(output), size); reinterpret_cast<float *>(output), size);
} }
} else if (in_tensors_[0]->data_type() == kNumberTypeBool) { } else if (in_tensors_[0]->data_type() == kNumberTypeBool) {
CHECK_NULL_RETURN(arithmetic_run_bool_); if (is_opt) {
ret = arithmetic_run_bool_(reinterpret_cast<const bool *>(input0), reinterpret_cast<const bool *>(input1), CHECK_NULL_RETURN(arithmetic_opt_run_bool_);
reinterpret_cast<bool *>(output), size); ret = arithmetic_opt_run_bool_(reinterpret_cast<const bool *>(input0), reinterpret_cast<const bool *>(input1),
reinterpret_cast<bool *>(output), size, param_);
} else {
CHECK_NULL_RETURN(arithmetic_run_bool_);
ret = arithmetic_run_bool_(reinterpret_cast<const bool *>(input0), reinterpret_cast<const bool *>(input1),
reinterpret_cast<bool *>(output), size);
}
} else { } else {
if (is_opt) { if (is_opt) {
CHECK_NULL_RETURN(arithmetic_opt_run_int_); CHECK_NULL_RETURN(arithmetic_opt_run_int_);
@ -293,102 +364,35 @@ int ArithmeticCPUKernel::Execute(const void *input0, const void *input1, void *o
return ret; return ret;
} }
int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int ArithmeticCPUKernel::CalcArithmeticByBatch(int task_id) {
int out_thread_stride) { int batch_per_thread = UP_DIV(out_batch_, op_parameter_->thread_num_);
if (dim > break_pos_) { int start_batch = batch_per_thread * task_id;
int offset = out_thread_stride * data_type_len_; int end_batch = MSMIN(start_batch + batch_per_thread, out_batch_);
return Execute(static_cast<uint8_t *>(input0) + offset, static_cast<uint8_t *>(input1) + offset, int ret = RET_ERROR;
static_cast<uint8_t *>(output) + offset, out_count, false); for (int i = start_batch; i < end_batch; i++) {
} batch_a_ptr_ = static_cast<uint8_t *>(input0_ptr_) + a_offset_[i] * a_stride_size_ * data_type_len_;
int offset_size[] = {param_->in_strides0_[dim] * data_type_len_, param_->in_strides1_[dim] * data_type_len_, batch_b_ptr_ = static_cast<uint8_t *>(input1_ptr_) + b_offset_[i] * b_stride_size_ * data_type_len_;
param_->out_strides_[dim] * data_type_len_}; batch_c_ptr_ = static_cast<uint8_t *>(output_ptr_) + i * c_stride_size_ * data_type_len_;
for (int i = 0; i < param_->out_shape_[dim]; ++i) { if (batch_scalar_) {
int pos0_ = param_->in_shape0_[dim] == 1 ? 0 : i; ret = Execute(batch_a_ptr_, batch_b_ptr_, batch_c_ptr_, c_stride_size_, true);
int pos1_ = param_->in_shape1_[dim] == 1 ? 0 : i; } else {
int ret = BroadcastRun(static_cast<uint8_t *>(input0) + pos0_ * offset_size[0], ret = Execute(batch_a_ptr_, batch_b_ptr_, batch_c_ptr_, c_stride_size_, false);
static_cast<uint8_t *>(input1) + pos1_ * offset_size[1], }
static_cast<uint8_t *>(output) + i * offset_size[2], dim + 1, out_count, out_thread_stride);
if (ret != RET_OK) { if (ret != RET_OK) {
return ret; MS_LOG(ERROR) << "failed to calculate.";
} return RET_ERROR;
}
return RET_OK;
}
int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {
if (break_pos_ < 1) {
return RET_ERROR;
}
if (break_pos_ > ARITHMETIC_SUPPORT_DIMS_NUM || param_->out_strides_[break_pos_ - 1] == 0) {
MS_LOG(ERROR) << "param_->out_strides_[break_pos_ - 1] is 0 or break_pos_ is > 10";
return RET_ERROR;
}
int batch = param_->out_elements_num_ / param_->out_strides_[break_pos_ - 1];
int batch_per_thread = UP_DIV(batch, op_parameter_->thread_num_);
int start_batch = batch_per_thread * task_id;
int end_batch = MSMIN(start_batch + batch_per_thread, batch);
int batch_size = end_batch - start_batch;
int stride0 = param_->in_strides0_[break_pos_ - 1] * data_type_len_;
int stride1 = param_->in_strides1_[break_pos_ - 1] * data_type_len_;
int out_stride = param_->out_strides_[break_pos_ - 1] * data_type_len_;
int offset0 = stride0 * start_batch;
int offset1 = stride1 * start_batch;
int out_offset = out_stride * start_batch;
int ret = RET_OK;
for (int i = 0; i < batch_size; i++) {
ret = Execute(static_cast<uint8_t *>(input0_ptr_) + offset0, static_cast<uint8_t *>(input1_ptr_) + offset1,
static_cast<uint8_t *>(output_ptr_) + out_offset, param_->out_strides_[break_pos_ - 1], true);
offset0 += stride0;
offset1 += stride1;
out_offset += out_stride;
}
return ret;
}
int ArithmeticCPUKernel::BiasCalc(int task_id) {
if (param_->ndim_ > ARITHMETIC_SUPPORT_DIMS_NUM || param_->out_shape_[param_->ndim_ - 1] == 0) {
MS_LOG(ERROR) << "BiasCalc param is error!";
return RET_ERROR;
}
int last_shape = param_->out_shape_[param_->ndim_ - 1];
int batch = param_->out_elements_num_ / last_shape;
int batch_per_thread = UP_DIV(batch, op_parameter_->thread_num_);
int start_batch = batch_per_thread * task_id;
int end_batch = MSMIN(start_batch + batch_per_thread, batch);
int batch_size = end_batch - start_batch;
int stride = last_shape * data_type_len_;
int offset = stride * start_batch;
int ret = RET_OK;
if (param_->in_elements_num0_ > param_->in_elements_num1_) {
for (int i = 0; i < batch_size; i++) {
ret = Execute(static_cast<uint8_t *>(input0_ptr_) + offset, static_cast<uint8_t *>(input1_ptr_),
static_cast<uint8_t *>(output_ptr_) + offset, last_shape, false);
if (ret != RET_OK) {
return ret;
}
offset += stride;
}
} else {
for (int i = 0; i < batch_size; i++) {
ret = Execute(static_cast<uint8_t *>(input0_ptr_), static_cast<uint8_t *>(input1_ptr_) + offset,
static_cast<uint8_t *>(output_ptr_) + offset, last_shape, false);
if (ret != RET_OK) {
return ret;
}
offset += stride;
} }
} }
return ret; return ret;
} }
int ArithmeticCPUKernel::DoArithmetic(int task_id) { int ArithmeticCPUKernel::DoArithmetic(int task_id) {
auto element_num = out_tensors_[0]->ElementsNum(); if (split_by_batch_) {
return CalcArithmeticByBatch(task_id);
}
int64_t element_num = out_tensors_[0]->ElementsNum();
auto ret = RET_ERROR;
int stride = UP_DIV(element_num, op_parameter_->thread_num_); int stride = UP_DIV(element_num, op_parameter_->thread_num_);
int count = MSMIN(stride, element_num - stride * task_id); int count = MSMIN(stride, element_num - stride * task_id);
if (count <= 0) { if (count <= 0) {
@ -396,36 +400,16 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
} }
CHECK_LESS_RETURN(ARITHMETIC_SUPPORT_DIMS_NUM, param_->ndim_); CHECK_LESS_RETURN(ARITHMETIC_SUPPORT_DIMS_NUM, param_->ndim_);
int offset = stride * task_id * data_type_len_; int offset = stride * task_id * data_type_len_;
/* run opt function, one of input is scalar */ if (scalar_) {
if (IsScalarClac()) { // 2 32 240 240, 1 1 1 1
if (param_->in_elements_num0_ == 1) { if (param_->in_elements_num0_ == 1) {
return Execute(input0_ptr_, static_cast<uint8_t *>(input1_ptr_) + offset, ret = Execute(batch_a_ptr_, batch_b_ptr_ + offset, batch_c_ptr_ + offset, count, true);
static_cast<uint8_t *>(output_ptr_) + offset, count, true); } else {
} else if (param_->in_elements_num1_ == 1) { ret = Execute(batch_a_ptr_ + offset, batch_b_ptr_, batch_c_ptr_ + offset, count, true);
return Execute(static_cast<uint8_t *>(input0_ptr_) + offset, input1_ptr_,
static_cast<uint8_t *>(output_ptr_) + offset, count, true);
} }
} else {
ret = Execute(batch_a_ptr_ + offset, batch_b_ptr_ + offset, batch_c_ptr_ + offset, count, false);
} }
/* run opt function, every batch one of input is scalar */ return ret;
if (IsBatchScalarCalc()) { // 2 32 240 240, 2 32 1 1
return BatchScalarCalc(task_id);
}
/* each batch is eltwise calculation */
if (IsBiasCalc()) { // 2 240 240 32, 1 1 1 32
return BiasCalc(task_id);
}
/* need broadcast in runtime */
if (param_->broadcasting_) {
stride = UP_DIV(outside_, op_parameter_->thread_num_);
int out_count = MSMIN(stride, outside_ - stride * task_id);
if (out_count <= 0) {
return RET_OK;
}
return BroadcastRun(input0_ptr_, input1_ptr_, output_ptr_, 0, out_count, stride * task_id);
}
/* all elements eltwise calculation */
return Execute(static_cast<uint8_t *>(input0_ptr_) + offset, static_cast<uint8_t *>(input1_ptr_) + offset,
static_cast<uint8_t *>(output_ptr_) + offset, count, false);
} }
int ArithmeticsRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) { int ArithmeticsRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
@ -442,6 +426,7 @@ int ArithmeticCPUKernel::Run() {
MS_LOG(ERROR) << "ArithmeticCPUKernel check dataType failed, kernel name: " << this->name(); MS_LOG(ERROR) << "ArithmeticCPUKernel check dataType failed, kernel name: " << this->name();
return RET_ERROR; return RET_ERROR;
} }
if (!input0_broadcast_) { if (!input0_broadcast_) {
input0_ptr_ = in_tensors_[0]->data(); input0_ptr_ = in_tensors_[0]->data();
CHECK_NULL_RETURN(input0_ptr_); CHECK_NULL_RETURN(input0_ptr_);
@ -452,7 +437,16 @@ int ArithmeticCPUKernel::Run() {
} }
output_ptr_ = out_tensors_[0]->data(); output_ptr_ = out_tensors_[0]->data();
CHECK_NULL_RETURN(output_ptr_); CHECK_NULL_RETURN(output_ptr_);
return ParallelLaunch(this->ms_context_, ArithmeticsRun, this, op_parameter_->thread_num_); batch_a_ptr_ = static_cast<uint8_t *>(input0_ptr_);
batch_b_ptr_ = static_cast<uint8_t *>(input1_ptr_);
batch_c_ptr_ = static_cast<uint8_t *>(output_ptr_);
auto ret = ParallelLaunch(this->ms_context_, ArithmeticsRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "arithmetic failed";
return RET_ERROR;
}
return RET_OK;
} }
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MulFusion, LiteKernelCreator<ArithmeticCPUKernel>) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MulFusion, LiteKernelCreator<ArithmeticCPUKernel>)

View File

@ -49,6 +49,8 @@ class ArithmeticCPUKernel : public InnerKernel {
typedef int (*ArithmeticOptIntRun)(const int *input0, const int *input1, int *output, const int element_size, typedef int (*ArithmeticOptIntRun)(const int *input0, const int *input1, int *output, const int element_size,
const ArithmeticParameter *param); const ArithmeticParameter *param);
typedef int (*ArithmeticBoolRun)(const bool *input0, const bool *input1, bool *output, const int element_size); typedef int (*ArithmeticBoolRun)(const bool *input0, const bool *input1, bool *output, const int element_size);
typedef int (*ArithmeticOptBoolRun)(const bool *input0, const bool *input1, bool *output, const int element_size,
const ArithmeticParameter *param);
typedef struct { typedef struct {
int primitive_type_; int primitive_type_;
@ -58,6 +60,7 @@ class ArithmeticCPUKernel : public InnerKernel {
ArithmeticBoolRun bool_func_; ArithmeticBoolRun bool_func_;
ArithmeticOptRun opt_func_; ArithmeticOptRun opt_func_;
ArithmeticOptIntRun opt_int_func_; ArithmeticOptIntRun opt_int_func_;
ArithmeticOptBoolRun opt_bool_func_;
} ARITHMETIC_FUNC_INFO_FP32; } ARITHMETIC_FUNC_INFO_FP32;
public: public:
@ -72,7 +75,6 @@ class ArithmeticCPUKernel : public InnerKernel {
int ReSize() override; int ReSize() override;
int Run() override; int Run() override;
virtual int DoArithmetic(int task_id); virtual int DoArithmetic(int task_id);
virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
protected: protected:
virtual void InitRunFunction(int primitive_type); virtual void InitRunFunction(int primitive_type);
@ -83,17 +85,31 @@ class ArithmeticCPUKernel : public InnerKernel {
virtual int Execute(const void *input0, const void *input1, void *output, int size, bool is_opt); virtual int Execute(const void *input0, const void *input1, void *output, int size, bool is_opt);
virtual bool IsBatchScalarCalc(); virtual bool IsBatchScalarCalc();
virtual bool IsScalarClac(); virtual bool IsScalarClac();
virtual int CalcArithmeticByBatch(int task_id);
bool input0_broadcast_ = false; bool input0_broadcast_ = false;
bool input1_broadcast_ = false; bool input1_broadcast_ = false;
void *input0_ptr_ = nullptr; void *input0_ptr_ = nullptr;
void *input1_ptr_ = nullptr; void *input1_ptr_ = nullptr;
void *output_ptr_ = nullptr; void *output_ptr_ = nullptr;
uint8_t *batch_a_ptr_ = nullptr;
uint8_t *batch_b_ptr_ = nullptr;
uint8_t *batch_c_ptr_ = nullptr;
int break_pos_ = 0; int break_pos_ = 0;
int outside_ = 0;
ArithmeticParameter *param_ = nullptr; ArithmeticParameter *param_ = nullptr;
int data_type_len_ = sizeof(float); int data_type_len_ = sizeof(float);
int out_batch_ = 1;
int a_stride_size_ = 1;
int b_stride_size_ = 1;
int c_stride_size_ = 1;
int last_batch_axis_ = 0;
bool scalar_ = false;
bool batch_scalar_ = false;
bool split_by_batch_ = false;
std::vector<int> a_offset_;
std::vector<int> b_offset_;
private: private:
int InitIndexOffsetInfo();
int BatchScalarCalc(int task_id); int BatchScalarCalc(int task_id);
int BiasCalc(int task_id); int BiasCalc(int task_id);
void FreeConstTileBuff(); void FreeConstTileBuff();
@ -103,6 +119,7 @@ class ArithmeticCPUKernel : public InnerKernel {
ArithmeticIntRun arithmetic_run_int_ = nullptr; ArithmeticIntRun arithmetic_run_int_ = nullptr;
ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr; ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr;
ArithmeticBoolRun arithmetic_run_bool_ = nullptr; ArithmeticBoolRun arithmetic_run_bool_ = nullptr;
ArithmeticOptBoolRun arithmetic_opt_run_bool_ = nullptr;
}; };
int ArithmeticsRun(void *cdata, int task_id, float lhs_scale, float rhs_scale); int ArithmeticsRun(void *cdata, int task_id, float lhs_scale, float rhs_scale);
} // namespace mindspore::kernel } // namespace mindspore::kernel