optimize arithmetic operators
This commit is contained in:
parent
2451a94125
commit
ee72df2cbc
|
@ -32,9 +32,9 @@ int ElementAddRelu(const float *in0, const float *in1, float *out, int size);
|
|||
int ElementAddRelu6(const float *in0, const float *in1, float *out, int size);
|
||||
int ElementAddInt(const int *in0, const int *in1, int *out, int size);
|
||||
int ElementOptAdd(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||
int ElementOptAddInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
||||
int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||
int ElementOptAddRelu6(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||
int ElementOptAddInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
||||
int BroadcastAdd(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size,
|
||||
ArithmeticParameter *param);
|
||||
|
||||
|
|
|
@ -26,6 +26,22 @@ int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output,
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param) {
|
||||
int i = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[0] == input1[i];
|
||||
}
|
||||
} else {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[i] == input1[0];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] == input1[i];
|
||||
|
@ -33,6 +49,22 @@ int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *out
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param) {
|
||||
int i = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[0] == input1[i];
|
||||
}
|
||||
} else {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[i] == input1[0];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// not equal
|
||||
int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
|
@ -41,6 +73,23 @@ int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *outpu
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// not equal
|
||||
int ElementOptNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param) {
|
||||
int i = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[0] != input1[i];
|
||||
}
|
||||
} else {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[i] != input1[0];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] != input1[i];
|
||||
|
@ -48,6 +97,22 @@ int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param) {
|
||||
int i = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[0] != input1[i];
|
||||
}
|
||||
} else {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[i] != input1[0];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// less
|
||||
int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
|
@ -56,6 +121,22 @@ int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, i
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param) {
|
||||
int i = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[0] < input1[i];
|
||||
}
|
||||
} else {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[i] < input1[0];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] < input1[i];
|
||||
|
@ -63,6 +144,22 @@ int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *outp
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param) {
|
||||
int i = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[0] < input1[i];
|
||||
}
|
||||
} else {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[i] < input1[0];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// less equal
|
||||
int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
|
@ -71,6 +168,22 @@ int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *outp
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param) {
|
||||
int i = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[0] <= input1[i];
|
||||
}
|
||||
} else {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[i] <= input1[0];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] <= input1[i];
|
||||
|
@ -78,6 +191,22 @@ int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param) {
|
||||
int i = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[0] <= input1[i];
|
||||
}
|
||||
} else {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[i] <= input1[0];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// greater
|
||||
int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
|
@ -86,6 +215,22 @@ int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param) {
|
||||
int i = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[0] > input1[i];
|
||||
}
|
||||
} else {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[i] > input1[0];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] > input1[i];
|
||||
|
@ -93,6 +238,21 @@ int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *o
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param) {
|
||||
int i = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[0] > input1[i];
|
||||
}
|
||||
} else {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[i] > input1[0];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
// greater equal
|
||||
int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
|
@ -101,9 +261,41 @@ int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *o
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param) {
|
||||
int i = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[0] >= input1[i];
|
||||
}
|
||||
} else {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[i] >= input1[0];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] >= input1[i];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param) {
|
||||
int i = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[0] >= input1[i];
|
||||
}
|
||||
} else {
|
||||
for (; i < element_size; i++) {
|
||||
output[i] = input0[i] >= input1[0];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -21,28 +21,53 @@
|
|||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/base/arithmetic_base.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||
int ElementOptEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param);
|
||||
int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||
int ElementOptEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param);
|
||||
|
||||
int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||
int ElementOptNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param);
|
||||
int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||
int ElementOptNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param);
|
||||
|
||||
int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||
int ElementOptLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param);
|
||||
int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||
int ElementOptLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param);
|
||||
|
||||
int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||
int ElementOptLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param);
|
||||
int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||
int ElementOptLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param);
|
||||
|
||||
int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||
int ElementOptGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param);
|
||||
int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||
int ElementOptGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param);
|
||||
|
||||
int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||
int ElementOptGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param);
|
||||
int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size);
|
||||
int ElementOptGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -26,6 +26,21 @@ int ElementFloorMod(const float *in0, const float *in1, float *out, int size) {
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptFloorMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
|
||||
int i = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; i < size; i++) {
|
||||
out[i] = in0[0] - floorf(in0[0] / in1[i]) * in1[i];
|
||||
}
|
||||
} else {
|
||||
for (; i < size; i++) {
|
||||
out[i] = in0[i] - floorf(in0[i] / in1[0]) * in1[0];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementFloorModInt(const int *in0, const int *in1, int *out, int size) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
|
||||
|
@ -35,6 +50,25 @@ int ElementFloorModInt(const int *in0, const int *in1, int *out, int size) {
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptFloorModInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) {
|
||||
int i = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; i < size; i++) {
|
||||
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
|
||||
int remainder = in0[0] - (in0[0] / in1[i]) * in1[i];
|
||||
out[i] = (remainder != 0) && ((in0[0] > 0) != (in1[i] > 0)) ? remainder + in1[i] : remainder;
|
||||
}
|
||||
} else {
|
||||
NNACL_CHECK_ZERO_RETURN_ERR(in1[0]);
|
||||
for (; i < size; i++) {
|
||||
int remainder = in0[i] - (in0[i] / in1[0]) * in1[0];
|
||||
out[i] = (remainder != 0) && ((in0[i] > 0) != (in1[0] > 0)) ? remainder + in1[0] : remainder;
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementMod(const float *in0, const float *in1, float *out, int size) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
out[i] = fmodf(in0[i], in1[i]);
|
||||
|
@ -42,23 +76,24 @@ int ElementMod(const float *in0, const float *in1, float *out, int size) {
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementModInt(const int *in0, const int *in1, int *out, int size) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
|
||||
out[i] = in0[i] % in1[i];
|
||||
int ElementOptMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
|
||||
int index = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; index < size; index++) {
|
||||
out[index] = fmodf(in0[0], in1[index]);
|
||||
}
|
||||
} else {
|
||||
for (; index < size; index++) {
|
||||
out[index] = fmodf(in0[index], in1[0]);
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int index = 0; index < size; index++) {
|
||||
out[index] = fmodf(in0[0], in1[index]);
|
||||
}
|
||||
} else {
|
||||
for (int index = 0; index < size; index++) {
|
||||
out[index] = fmodf(in0[index], in1[0]);
|
||||
}
|
||||
int ElementModInt(const int *in0, const int *in1, int *out, int size) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
|
||||
out[i] = in0[i] % in1[i];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
@ -85,6 +120,21 @@ int ElementFloorDiv(const float *in0, const float *in1, float *out, int size) {
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptFloorDiv(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
|
||||
int i = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; i < size; i++) {
|
||||
out[i] = floorf(in0[0] / in1[i]);
|
||||
}
|
||||
} else {
|
||||
for (; i < size; i++) {
|
||||
out[i] = floorf(in0[i] / in1[0]);
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementFloorDivInt(const int *in0, const int *in1, int *out, int size) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
|
||||
|
@ -93,6 +143,23 @@ int ElementFloorDivInt(const int *in0, const int *in1, int *out, int size) {
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptFloorDivInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) {
|
||||
int i = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; i < size; i++) {
|
||||
NNACL_CHECK_ZERO_RETURN_ERR(in1[i]);
|
||||
out[i] = in0[0] / in1[i];
|
||||
}
|
||||
} else {
|
||||
NNACL_CHECK_ZERO_RETURN_ERR(in1[0]);
|
||||
for (; i < size; i++) {
|
||||
out[i] = in0[i] / in1[0];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size) {
|
||||
int index = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
|
@ -113,6 +180,21 @@ int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size)
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptLogicalAnd(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
|
||||
int index = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; index < size; index++) {
|
||||
out[index] = (float)((bool)(in0[0]) & (bool)(in1[index]));
|
||||
}
|
||||
} else {
|
||||
for (; index < size; index++) {
|
||||
out[index] = (float)((bool)(in0[index]) & (bool)(in1[0]));
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementLogicalAndInt(const int *in0, const int *in1, int *out, int size) {
|
||||
int index = 0;
|
||||
for (; index < size; index++) {
|
||||
|
@ -121,11 +203,42 @@ int ElementLogicalAndInt(const int *in0, const int *in1, int *out, int size) {
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptLogicalAndInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) {
|
||||
int index = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; index < size; index++) {
|
||||
out[index] = (int)((unsigned int)(in0[0]) & (unsigned int)(in1[index]));
|
||||
}
|
||||
} else {
|
||||
for (; index < size; index++) {
|
||||
out[index] = (int)((unsigned int)(in0[index]) & (unsigned int)(in1[0]));
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size) {
|
||||
int index = 0;
|
||||
for (; index < size; index++) {
|
||||
out[index] = (bool)((unsigned int)(in0[index]) & (unsigned int)(in1[index]));
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size, const ArithmeticParameter *param) {
|
||||
int index = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; index < size; index++) {
|
||||
out[index] = (bool)((unsigned int)(in0[0]) & (unsigned int)(in1[index]));
|
||||
}
|
||||
} else {
|
||||
for (; index < size; index++) {
|
||||
out[index] = (bool)((unsigned int)(in0[index]) & (unsigned int)(in1[0]));
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
|
@ -149,6 +262,21 @@ int ElementLogicalOr(const float *in0, const float *in1, float *out, int size) {
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptLogicalOr(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
|
||||
int index = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; index < size; index++) {
|
||||
out[index] = (float)((bool)(in0[0]) | (bool)(in1[index]));
|
||||
}
|
||||
} else {
|
||||
for (; index < size; index++) {
|
||||
out[index] = (float)((bool)(in0[index]) | (bool)(in1[0]));
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size) {
|
||||
int index = 0;
|
||||
for (; index < size; index++) {
|
||||
|
@ -157,6 +285,21 @@ int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size)
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size, const ArithmeticParameter *param) {
|
||||
int index = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; index < size; index++) {
|
||||
out[index] = (bool)(in0[0] | in1[index]);
|
||||
}
|
||||
} else {
|
||||
for (; index < size; index++) {
|
||||
out[index] = (bool)(in0[index] | in1[0]);
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementMaximum(const float *in0, const float *in1, float *out, int size) {
|
||||
int index = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
|
@ -173,6 +316,22 @@ int ElementMaximum(const float *in0, const float *in1, float *out, int size) {
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptMaximum(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
|
||||
int index = 0;
|
||||
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; index < size; index++) {
|
||||
out[index] = in0[0] > in1[index] ? in0[0] : in1[index];
|
||||
}
|
||||
} else {
|
||||
for (; index < size; index++) {
|
||||
out[index] = in0[index] > in1[0] ? in0[index] : in1[0];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementMaximumInt(const int *in0, const int *in1, int *out, int size) {
|
||||
int index = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
|
@ -189,22 +348,53 @@ int ElementMaximumInt(const int *in0, const int *in1, int *out, int size) {
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementMinimumInt(const int *input0, const int *input1, int *output, const int element_size) {
|
||||
int ElementOptMaximumInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) {
|
||||
int index = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; index < size; index++) {
|
||||
out[index] = in0[0] > in1[index] ? in0[0] : in1[index];
|
||||
}
|
||||
} else {
|
||||
for (; index < size; index++) {
|
||||
out[index] = in0[index] > in1[0] ? in0[index] : in1[0];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementMinimumInt(const int *input0, const int *input1, int *output, int size) {
|
||||
int index = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
for (; index <= element_size - 4; index += C4NUM) {
|
||||
for (; index <= size - 4; index += C4NUM) {
|
||||
int32x4_t vin0 = vld1q_s32(input0 + index);
|
||||
int32x4_t vin1 = vld1q_s32(input1 + index);
|
||||
int32x4_t vout = vminq_s32(vin0, vin1);
|
||||
vst1q_s32(output + index, vout);
|
||||
}
|
||||
#endif
|
||||
for (; index < element_size; index++) {
|
||||
for (; index < size; index++) {
|
||||
output[index] = input0[index] > input1[index] ? input1[index] : input0[index];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptMinimumInt(const int *input0, const int *input1, int *output, int size,
|
||||
const ArithmeticParameter *param) {
|
||||
int index = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; index < size; index++) {
|
||||
output[index] = input0[0] > input1[index] ? input1[index] : input0[0];
|
||||
}
|
||||
} else {
|
||||
for (; index < size; index++) {
|
||||
output[index] = input0[index] > input1[0] ? input1[0] : input0[index];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementMinimum(const float *in0, const float *in1, float *out, int size) {
|
||||
int index = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
|
@ -221,6 +411,21 @@ int ElementMinimum(const float *in0, const float *in1, float *out, int size) {
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptMinimum(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) {
|
||||
int index = 0;
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (; index < size; index++) {
|
||||
out[index] = in0[0] > in1[index] ? in1[index] : in0[0];
|
||||
}
|
||||
} else {
|
||||
for (; index < size; index++) {
|
||||
out[index] = in0[index] > in1[0] ? in1[0] : in0[index];
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
#undef ACCURACY_DATA
|
||||
|
||||
void TileOneDimensionFp32(const float *inData, float *outData, int dim, size_t ndim, const int *inShape,
|
||||
|
|
|
@ -37,31 +37,44 @@ void TileDimensionsFp32(const float *data0, const float *data1, float *tile_data
|
|||
ArithmeticParameter *param);
|
||||
/* logical and */
|
||||
int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size);
|
||||
int ElementOptLogicalAnd(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||
int ElementLogicalAndInt(const int *in0, const int *in1, int *out, int size);
|
||||
int ElementOptLogicalAndInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
||||
int ElementLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size);
|
||||
int ElementOptLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size, const ArithmeticParameter *param);
|
||||
|
||||
/* logical or */
|
||||
int ElementLogicalOr(const float *in0, const float *in1, float *out, int size);
|
||||
int ElementOptLogicalOr(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||
int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size);
|
||||
int ElementOptLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size, const ArithmeticParameter *param);
|
||||
|
||||
/* max min */
|
||||
int ElementMaximum(const float *in0, const float *in1, float *out, int size);
|
||||
int ElementOptMaximum(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||
int ElementMinimum(const float *in0, const float *in1, float *out, int size);
|
||||
int ElementOptMinimum(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||
int ElementMaximumInt(const int *in0, const int *in1, int *out, int size);
|
||||
int ElementMinimumInt(const int *input0, const int *input1, int *output, const int element_size);
|
||||
int ElementOptMaximumInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
||||
int ElementMinimumInt(const int *input0, const int *input1, int *output, int size);
|
||||
int ElementOptMinimumInt(const int *input0, const int *input1, int *output, int size, const ArithmeticParameter *param);
|
||||
|
||||
/* floor div */
|
||||
int ElementFloorDiv(const float *in0, const float *in1, float *out, int size);
|
||||
int ElementOptFloorDiv(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||
int ElementFloorDivInt(const int *in0, const int *in1, int *out, int size);
|
||||
int ElementOptFloorDivInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
||||
|
||||
/* floor mod */
|
||||
int ElementFloorMod(const float *in0, const float *in1, float *out, int size);
|
||||
int ElementOptFloorMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||
int ElementFloorModInt(const int *in0, const int *in1, int *out, int size);
|
||||
int ElementOptFloorModInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
||||
|
||||
/* mod */
|
||||
int ElementMod(const float *in0, const float *in1, float *out, int size);
|
||||
int ElementModInt(const int *in0, const int *in1, int *out, int size);
|
||||
int ElementOptMod(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||
int ElementModInt(const int *in0, const int *in1, int *out, int size);
|
||||
int ElementOptModInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -25,4 +25,9 @@ int ElementSquaredDifference(const float *in0, const float *in1, float *out, int
|
|||
return ElementMul(out, out, out, size);
|
||||
}
|
||||
|
||||
int ElementOptSquaredDifference(const float *in0, const float *in1, float *out, int size,
|
||||
const ArithmeticParameter *param) {
|
||||
ElementOptSub(in0, in1, out, size, param);
|
||||
return ElementMul(out, out, out, size);
|
||||
}
|
||||
#endif // MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_
|
||||
|
|
|
@ -29,7 +29,8 @@ extern "C" {
|
|||
|
||||
/* Element Squared Difference */
|
||||
int ElementSquaredDifference(const float *in0, const float *in1, float *out, int size);
|
||||
|
||||
int ElementOptSquaredDifference(const float *in0, const float *in1, float *out, int size,
|
||||
const ArithmeticParameter *param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -32,9 +32,9 @@ int ElementSubInt(const int *in0, const int *in1, int *out, int size);
|
|||
int ElementSubRelu(const float *in0, const float *in1, float *out, int size);
|
||||
int ElementSubRelu6(const float *in0, const float *in1, float *out, int size);
|
||||
int ElementOptSub(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||
int ElementOptSubInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
||||
int ElementOptSubRelu(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||
int ElementOptSubRelu6(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param);
|
||||
int ElementOptSubInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -63,36 +63,6 @@ int ArithmeticFP16CPUKernel::CheckDataType() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
bool ArithmeticFP16CPUKernel::IsScalarClac() { // 2 32 240 240, 1 1 1 1
|
||||
if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && (arithmetic_opt_func_ != nullptr)) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool ArithmeticFP16CPUKernel::IsBatchScalarCalc() {
|
||||
if (arithmetic_opt_func_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
size_t break_axis = 0;
|
||||
for (size_t i = 0; i < param_->ndim_; i++) {
|
||||
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
|
||||
break_axis = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (break_axis < param_->ndim_) {
|
||||
for (size_t i = break_axis; i < param_->ndim_; i++) {
|
||||
if (param_->in_shape1_[i] != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
break_pos_ = break_axis;
|
||||
return true;
|
||||
}
|
||||
|
||||
void ArithmeticFP16CPUKernel::InitRunFunction(int primitive_type) {
|
||||
ARITHMETIC_FUNC_INFO_FP16 fun_table[] = {
|
||||
{PrimitiveType_MulFusion, schema::ActivationType_RELU, ElementMulReluFp16, ElementOptMulReluFp16},
|
||||
|
@ -171,6 +141,7 @@ int ArithmeticFP16CPUKernel::Run() {
|
|||
MS_LOG(ERROR) << "ArithmeticFP16CPUKernel check dataType failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (!input0_broadcast_) {
|
||||
input0_ptr_ = ConvertInputFp32toFp16(in_tensors_.at(0), static_cast<const lite::InnerContext *>(this->ms_context_));
|
||||
}
|
||||
|
@ -183,10 +154,16 @@ int ArithmeticFP16CPUKernel::Run() {
|
|||
FreeFp16Buffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
batch_a_ptr_ = static_cast<uint8_t *>(input0_ptr_);
|
||||
batch_b_ptr_ = static_cast<uint8_t *>(input1_ptr_);
|
||||
batch_c_ptr_ = static_cast<uint8_t *>(output_ptr_);
|
||||
auto ret = ParallelLaunch(this->ms_context_, ArithmeticsRun, this, op_parameter_->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ArithmeticsRun failed, ret : " << ret;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
|
||||
Float16ToFloat32(static_cast<float16_t *>(output_ptr_), reinterpret_cast<float *>(output_tensor->data()),
|
||||
output_tensor->ElementsNum());
|
||||
|
|
|
@ -40,8 +40,6 @@ class ArithmeticFP16CPUKernel : public ArithmeticCPUKernel {
|
|||
~ArithmeticFP16CPUKernel() = default;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
bool IsBatchScalarCalc() override;
|
||||
bool IsScalarClac() override;
|
||||
|
||||
private:
|
||||
void InitRunFunction(int primitive_type) override;
|
||||
|
|
|
@ -28,85 +28,110 @@ using mindspore::schema::PrimitiveType_LessEqual;
|
|||
using mindspore::schema::PrimitiveType_NotEqual;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
|
||||
int out_thread_stride) {
|
||||
if (dim > break_pos_) {
|
||||
if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) {
|
||||
return func_int32_(reinterpret_cast<int *>(input0) + out_thread_stride,
|
||||
reinterpret_cast<int *>(input1) + out_thread_stride,
|
||||
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
|
||||
void ArithmeticCompareCPUKernel::InitRunFunction(int primitive_type) {
|
||||
ARITHMETIC_COMEPARE_FUNC_INFO_FP32 fun_table[] = {
|
||||
{PrimitiveType_Equal, ElementEqualFp32, ElementEqualInt32, ElementOptEqualFp32, ElementOptEqualInt32},
|
||||
{PrimitiveType_NotEqual, ElementNotEqualFp32, ElementNotEqualInt32, ElementOptNotEqualFp32,
|
||||
ElementOptNotEqualInt32},
|
||||
{PrimitiveType_Less, ElementLessFp32, ElementLessInt32, ElementOptLessFp32, ElementOptLessInt32},
|
||||
{PrimitiveType_LessEqual, ElementLessEqualFp32, ElementLessEqualInt32, ElementOptLessEqualFp32,
|
||||
ElementOptLessEqualInt32},
|
||||
{PrimitiveType_Greater, ElementGreaterFp32, ElementGreaterInt32, ElementOptGreaterFp32, ElementOptGreaterInt32},
|
||||
{PrimitiveType_GreaterEqual, ElementGreaterEqualFp32, ElementGreaterEqualInt32, ElementOptGreaterEqualFp32,
|
||||
ElementOptGreaterEqualInt32}};
|
||||
size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_COMEPARE_FUNC_INFO_FP32);
|
||||
for (size_t i = 0; i < length; i++) {
|
||||
if (fun_table[i].primitive_type_ == primitive_type) {
|
||||
func_fp32_ = fun_table[i].func_;
|
||||
func_int32_ = fun_table[i].int_func_;
|
||||
opt_func_fp32_ = fun_table[i].opt_func_;
|
||||
opt_func_int32_ = fun_table[i].opt_int_func_;
|
||||
return;
|
||||
}
|
||||
return func_fp32_(reinterpret_cast<float *>(input0) + out_thread_stride,
|
||||
reinterpret_cast<float *>(input1) + out_thread_stride,
|
||||
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
|
||||
}
|
||||
for (int i = 0; i < param_->out_shape_[dim]; ++i) {
|
||||
int pos0_ = param_->in_shape0_[dim] == 1 ? 0 : i;
|
||||
int pos1_ = param_->in_shape1_[dim] == 1 ? 0 : i;
|
||||
int error_code;
|
||||
if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) {
|
||||
error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * param_->in_strides0_[dim],
|
||||
reinterpret_cast<int *>(input1) + pos1_ * param_->in_strides1_[dim],
|
||||
reinterpret_cast<uint8_t *>(output) + i * param_->out_strides_[dim], dim + 1, out_count,
|
||||
out_thread_stride);
|
||||
}
|
||||
|
||||
int ArithmeticCompareCPUKernel::Execute(const void *input0, const void *input1, void *output, int size, bool is_opt) {
|
||||
int ret = RET_OK;
|
||||
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
|
||||
if (is_opt) {
|
||||
CHECK_NULL_RETURN(opt_func_fp32_);
|
||||
ret = opt_func_fp32_(reinterpret_cast<const float *>(input0), reinterpret_cast<const float *>(input1),
|
||||
reinterpret_cast<uint8_t *>(output), size, param_);
|
||||
} else {
|
||||
error_code = BroadcastRun(reinterpret_cast<float *>(input0) + pos0_ * param_->in_strides0_[dim],
|
||||
reinterpret_cast<float *>(input1) + pos1_ * param_->in_strides1_[dim],
|
||||
reinterpret_cast<uint8_t *>(output) + i * param_->out_strides_[dim], dim + 1, out_count,
|
||||
out_thread_stride);
|
||||
CHECK_NULL_RETURN(func_fp32_);
|
||||
ret = func_fp32_(reinterpret_cast<const float *>(input0), reinterpret_cast<const float *>(input1),
|
||||
reinterpret_cast<uint8_t *>(output), size);
|
||||
}
|
||||
if (error_code != RET_OK) {
|
||||
return error_code;
|
||||
} else if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) {
|
||||
if (is_opt) {
|
||||
CHECK_NULL_RETURN(opt_func_int32_);
|
||||
ret = opt_func_int32_(reinterpret_cast<const int *>(input0), reinterpret_cast<const int *>(input1),
|
||||
reinterpret_cast<uint8_t *>(output), size, param_);
|
||||
} else {
|
||||
CHECK_NULL_RETURN(func_int32_);
|
||||
ret = func_int32_(reinterpret_cast<const int *>(input0), reinterpret_cast<const int *>(input1),
|
||||
reinterpret_cast<uint8_t *>(output), size);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Error Operator type " << kNumberTypeInt32;
|
||||
return RET_ERROR;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ArithmeticCompareCPUKernel::CalcArithmeticByBatch(int task_id) {
|
||||
if (break_pos_ > ARITHMETIC_SUPPORT_DIMS_NUM || param_->out_strides_[break_pos_ - 1] == 0) {
|
||||
MS_LOG(ERROR) << "param_->out_strides_[break_pos_ - 1] is 0 or break_pos_ is > 10";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
int batch_per_thread = UP_DIV(out_batch_, op_parameter_->thread_num_);
|
||||
int start_batch = batch_per_thread * task_id;
|
||||
int end_batch = MSMIN(start_batch + batch_per_thread, out_batch_);
|
||||
int ret = RET_ERROR;
|
||||
for (int i = start_batch; i < end_batch; i++) {
|
||||
batch_a_ptr_ = static_cast<uint8_t *>(input0_ptr_) + a_offset_[i] * a_stride_size_ * data_type_len_;
|
||||
batch_b_ptr_ = static_cast<uint8_t *>(input1_ptr_) + b_offset_[i] * b_stride_size_ * data_type_len_;
|
||||
batch_c_ptr_ = static_cast<uint8_t *>(output_ptr_) + i * c_stride_size_ * sizeof(uint8_t);
|
||||
if (batch_scalar_) {
|
||||
ret = Execute(batch_a_ptr_, batch_b_ptr_, batch_c_ptr_, c_stride_size_, true);
|
||||
} else {
|
||||
ret = Execute(batch_a_ptr_, batch_b_ptr_, batch_c_ptr_, c_stride_size_, false);
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "failed to calculate.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
|
||||
auto element_num = out_tensors_[0]->ElementsNum();
|
||||
if (split_by_batch_) {
|
||||
return CalcArithmeticByBatch(task_id);
|
||||
}
|
||||
|
||||
MS_ASSERT(op_parameter_->thread_num_ != 0);
|
||||
int64_t element_num = out_tensors_[0]->ElementsNum();
|
||||
auto ret = RET_ERROR;
|
||||
int stride = UP_DIV(element_num, op_parameter_->thread_num_);
|
||||
int count = MSMIN(stride, element_num - stride * task_id);
|
||||
if (count <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
if (func_fp32_ == nullptr) {
|
||||
MS_LOG(ERROR) << "func_fp32_ function is nullptr!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
int error_code;
|
||||
if (param_->broadcasting_) { // need broadcast
|
||||
stride = UP_DIV(outside_, op_parameter_->thread_num_);
|
||||
int out_count = MSMIN(stride, outside_ - stride * task_id);
|
||||
int out_thread_stride = stride * task_id;
|
||||
if (out_count <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
|
||||
error_code = BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
|
||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data()), 0, out_count, out_thread_stride);
|
||||
CHECK_LESS_RETURN(ARITHMETIC_SUPPORT_DIMS_NUM, param_->ndim_);
|
||||
int in_offset = stride * task_id * data_type_len_;
|
||||
int out_offset = stride * task_id * sizeof(uint8_t);
|
||||
if (scalar_) {
|
||||
if (param_->in_elements_num0_ == 1) {
|
||||
ret = Execute(batch_a_ptr_, batch_b_ptr_ + in_offset, batch_c_ptr_ + out_offset, count, true);
|
||||
} else {
|
||||
error_code = BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_),
|
||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data()), 0, out_count, out_thread_stride);
|
||||
}
|
||||
} else { // no broadcast, neither is scalar, two same shape
|
||||
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
|
||||
error_code = func_fp32_(reinterpret_cast<float *>(input0_ptr_) + stride * task_id,
|
||||
reinterpret_cast<float *>(input1_ptr_) + stride * task_id,
|
||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data()) + stride * task_id, count);
|
||||
} else {
|
||||
error_code = func_int32_(reinterpret_cast<int *>(input0_ptr_) + stride * task_id,
|
||||
reinterpret_cast<int *>(input1_ptr_) + stride * task_id,
|
||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data()) + stride * task_id, count);
|
||||
ret = Execute(batch_a_ptr_ + in_offset, batch_b_ptr_, batch_c_ptr_ + out_offset, count, true);
|
||||
}
|
||||
} else {
|
||||
ret = Execute(batch_a_ptr_ + in_offset, batch_b_ptr_ + in_offset, batch_c_ptr_ + out_offset, count, false);
|
||||
}
|
||||
if (error_code != RET_OK) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
return ret;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Equal, LiteKernelCreator<ArithmeticCompareCPUKernel>)
|
||||
|
|
|
@ -21,53 +21,38 @@
|
|||
#include "nnacl/fp32/arithmetic_compare_fp32.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||
typedef int (*ArithmeticCompareIntFunc)(const int *input0, const int *input1, uint8_t *output, int element_size);
|
||||
class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel {
|
||||
typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size);
|
||||
typedef int (*ArithmeticCompareIntFunc)(const int *input0, const int *input1, uint8_t *output, int element_size);
|
||||
typedef int (*ArithmeticOptCompareFp32Func)(const float *input0, const float *input1, uint8_t *output,
|
||||
int element_size, const ArithmeticParameter *param);
|
||||
typedef int (*ArithmeticOptCompareIntFunc)(const int *input0, const int *input1, uint8_t *output, int element_size,
|
||||
const ArithmeticParameter *param);
|
||||
typedef struct {
|
||||
int primitive_type_;
|
||||
ArithmeticCompareFp32Func func_;
|
||||
ArithmeticCompareIntFunc int_func_;
|
||||
ArithmeticOptCompareFp32Func opt_func_;
|
||||
ArithmeticOptCompareIntFunc opt_int_func_;
|
||||
} ARITHMETIC_COMEPARE_FUNC_INFO_FP32;
|
||||
|
||||
public:
|
||||
explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
|
||||
: ArithmeticCPUKernel(parameter, inputs, outputs, ctx) {
|
||||
switch (parameter->type_) {
|
||||
case PrimitiveType_Equal:
|
||||
func_fp32_ = ElementEqualFp32;
|
||||
func_int32_ = ElementEqualInt32;
|
||||
break;
|
||||
case PrimitiveType_NotEqual:
|
||||
func_fp32_ = ElementNotEqualFp32;
|
||||
func_int32_ = ElementNotEqualInt32;
|
||||
break;
|
||||
case PrimitiveType_Less:
|
||||
func_fp32_ = ElementLessFp32;
|
||||
func_int32_ = ElementLessInt32;
|
||||
break;
|
||||
case PrimitiveType_LessEqual:
|
||||
func_fp32_ = ElementLessEqualFp32;
|
||||
func_int32_ = ElementLessEqualInt32;
|
||||
break;
|
||||
case PrimitiveType_Greater:
|
||||
func_fp32_ = ElementGreaterFp32;
|
||||
func_int32_ = ElementGreaterInt32;
|
||||
break;
|
||||
case PrimitiveType_GreaterEqual:
|
||||
func_fp32_ = ElementGreaterEqualFp32;
|
||||
func_int32_ = ElementGreaterEqualInt32;
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Error Operator type " << parameter->type_;
|
||||
func_fp32_ = nullptr;
|
||||
func_int32_ = nullptr;
|
||||
break;
|
||||
}
|
||||
}
|
||||
: ArithmeticCPUKernel(parameter, inputs, outputs, ctx) {}
|
||||
~ArithmeticCompareCPUKernel() override = default;
|
||||
|
||||
int DoArithmetic(int task_id) override;
|
||||
int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride) override;
|
||||
|
||||
protected:
|
||||
void InitRunFunction(int primitive_type) override;
|
||||
int Execute(const void *input0, const void *input1, void *output, int size, bool is_opt) override;
|
||||
int CalcArithmeticByBatch(int task_id) override;
|
||||
|
||||
private:
|
||||
ArithmeticCompareFp32Func func_fp32_ = nullptr;
|
||||
ArithmeticCompareIntFunc func_int32_ = nullptr;
|
||||
ArithmeticOptCompareFp32Func opt_func_fp32_ = nullptr;
|
||||
ArithmeticOptCompareIntFunc opt_func_int32_ = nullptr;
|
||||
};
|
||||
int ArithmeticCompareRun(void *cdata, int task_id, float lhs_scale, float rhs_scale);
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -50,27 +50,133 @@ int ArithmeticCPUKernel::Prepare() {
|
|||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
bool ArithmeticCPUKernel::IsScalarClac() {
|
||||
if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
int ArithmeticCPUKernel::ReSize() {
|
||||
CalcMultiplesAndStrides(param_);
|
||||
if (param_->broadcasting_) {
|
||||
outside_ = 1;
|
||||
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0 && i < ARITHMETIC_SUPPORT_DIMS_NUM; --i) {
|
||||
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
|
||||
break_pos_ = i;
|
||||
break;
|
||||
}
|
||||
outside_ *= param_->out_shape_[i];
|
||||
scalar_ = IsScalarClac();
|
||||
int ret = RET_OK;
|
||||
if (!scalar_) {
|
||||
ret = ConstTensorBroadCast();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "failed to init const tensor";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
data_type_len_ = lite::DataTypeSize(in_tensors_.at(0)->data_type());
|
||||
int ret = RET_OK;
|
||||
if (!IsScalarClac() && !IsBatchScalarCalc() && !IsBiasCalc()) {
|
||||
ret = ConstTensorBroadCast();
|
||||
if (!scalar_ && param_->broadcasting_) {
|
||||
ret = InitIndexOffsetInfo();
|
||||
}
|
||||
data_type_len_ = lite::DataTypeSize(in_tensors_.at(0)->data_type());
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool ArithmeticCPUKernel::IsBatchScalarCalc() { // 1 32 240 240, 2 32 1 1
|
||||
int last_batch_axis0 = ARITHMETIC_SUPPORT_DIMS_NUM + 1;
|
||||
int last_batch_axis1 = ARITHMETIC_SUPPORT_DIMS_NUM + 1;
|
||||
if (param_->in_shape0_[param_->ndim_ - 1] == 1) {
|
||||
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0 && i < ARITHMETIC_SUPPORT_DIMS_NUM; --i) {
|
||||
if (param_->in_shape0_[i] != 1) {
|
||||
last_batch_axis0 = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (param_->in_shape1_[param_->ndim_ - 1] == 1) {
|
||||
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0 && i < ARITHMETIC_SUPPORT_DIMS_NUM; --i) {
|
||||
if (param_->in_shape1_[i] != 1) {
|
||||
last_batch_axis1 = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
int min_axis = MSMIN(last_batch_axis0, last_batch_axis1);
|
||||
if (min_axis < static_cast<int>(param_->ndim_) - 1) {
|
||||
last_batch_axis_ = min_axis;
|
||||
if (last_batch_axis0 < last_batch_axis1) {
|
||||
param_->in_elements_num0_ = 1;
|
||||
} else {
|
||||
param_->in_elements_num1_ = 1;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int ArithmeticCPUKernel::InitIndexOffsetInfo() {
|
||||
split_by_batch_ = true;
|
||||
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0 && i < ARITHMETIC_SUPPORT_DIMS_NUM; --i) {
|
||||
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
|
||||
break_pos_ = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> a_shape;
|
||||
std::vector<int> b_shape;
|
||||
std::vector<int> c_shape = out_tensors_[0]->shape();
|
||||
size_t dim = c_shape.size();
|
||||
for (size_t i = 0; i < dim; ++i) {
|
||||
a_shape.push_back(param_->in_shape0_[i]);
|
||||
b_shape.push_back(param_->in_shape1_[i]);
|
||||
}
|
||||
batch_scalar_ = IsBatchScalarCalc();
|
||||
|
||||
a_stride_size_ = 1;
|
||||
b_stride_size_ = 1;
|
||||
c_stride_size_ = 1;
|
||||
int last_batch_axis = batch_scalar_ ? last_batch_axis_ : break_pos_;
|
||||
for (int i = static_cast<int>(param_->ndim_) - 1; i > last_batch_axis && i < ARITHMETIC_SUPPORT_DIMS_NUM; --i) {
|
||||
a_stride_size_ *= a_shape[i];
|
||||
b_stride_size_ *= b_shape[i];
|
||||
c_stride_size_ *= c_shape[i];
|
||||
}
|
||||
|
||||
out_batch_ = 1;
|
||||
int batch_size[ARITHMETIC_SUPPORT_DIMS_NUM] = {};
|
||||
int a_batch_size[ARITHMETIC_SUPPORT_DIMS_NUM] = {};
|
||||
int b_batch_size[ARITHMETIC_SUPPORT_DIMS_NUM] = {};
|
||||
for (int i = last_batch_axis; i >= 0; --i) {
|
||||
out_batch_ *= c_shape[i];
|
||||
if (i == last_batch_axis) {
|
||||
batch_size[i] = c_shape[i];
|
||||
a_batch_size[i] = a_shape[i];
|
||||
b_batch_size[i] = b_shape[i];
|
||||
} else {
|
||||
batch_size[i] = batch_size[i + 1] * c_shape[i];
|
||||
a_batch_size[i] = a_batch_size[i + 1] * a_shape[i];
|
||||
b_batch_size[i] = b_batch_size[i + 1] * b_shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
a_offset_.resize(out_batch_, 0);
|
||||
b_offset_.resize(out_batch_, 0);
|
||||
for (int i = 0; i < out_batch_; ++i) {
|
||||
int delta = i;
|
||||
int a_offset = 0;
|
||||
int b_offset = 0;
|
||||
for (int j = 0; j <= last_batch_axis; ++j) {
|
||||
if (j > 0) {
|
||||
delta = delta % batch_size[j];
|
||||
}
|
||||
if (j < last_batch_axis) {
|
||||
a_offset += (delta / batch_size[j + 1] * a_shape[j] / MSMAX(a_shape[j], b_shape[j])) * a_batch_size[j + 1];
|
||||
b_offset += (delta / batch_size[j + 1] * b_shape[j] / MSMAX(a_shape[j], b_shape[j])) * b_batch_size[j + 1];
|
||||
} else {
|
||||
a_offset += (delta * a_shape[j] / MSMAX(a_shape[j], b_shape[j]));
|
||||
b_offset += (delta * b_shape[j] / MSMAX(a_shape[j], b_shape[j]));
|
||||
}
|
||||
}
|
||||
a_offset_[i] = a_offset;
|
||||
b_offset_[i] = b_offset;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ArithmeticCPUKernel::CheckDataType() {
|
||||
auto in0_dataType = in_tensors_.at(0)->data_type();
|
||||
auto in1_dataType = in_tensors_.at(1)->data_type();
|
||||
|
@ -85,47 +191,6 @@ int ArithmeticCPUKernel::CheckDataType() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
bool ArithmeticCPUKernel::IsScalarClac() { // 2 32 240 240, 1 1 1 1
|
||||
if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && (arithmetic_opt_run_ != nullptr)) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool ArithmeticCPUKernel::IsBatchScalarCalc() { // 2 32 240 240, 2 32 1 1
|
||||
if (arithmetic_opt_run_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
size_t break_axis = 0;
|
||||
for (size_t i = 0; i < param_->ndim_; i++) {
|
||||
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
|
||||
break_axis = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (break_axis < param_->ndim_) {
|
||||
for (size_t i = break_axis; i < param_->ndim_; i++) {
|
||||
if (param_->in_shape1_[i] != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
break_pos_ = break_axis;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ArithmeticCPUKernel::IsBiasCalc() const { // 2 240 240 32, 1 1 1 32
|
||||
int last_shape0 = param_->in_shape0_[param_->ndim_ - 1];
|
||||
int last_shape1 = param_->in_shape1_[param_->ndim_ - 1];
|
||||
if (param_->in_elements_num0_ > param_->in_elements_num1_) {
|
||||
return param_->in_elements_num1_ == last_shape1 && last_shape0 == last_shape1;
|
||||
} else if (param_->in_elements_num0_ < param_->in_elements_num1_) {
|
||||
return param_->in_elements_num0_ == last_shape0 && last_shape0 == last_shape1;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int ArithmeticCPUKernel::ConstTensorBroadCast() {
|
||||
/* if const node need broadcast and all need-broadcast-node are const, broadcast in resize */
|
||||
if (!param_->broadcasting_) {
|
||||
|
@ -168,13 +233,11 @@ int ArithmeticCPUKernel::ConstTensorBroadCast() {
|
|||
}
|
||||
}
|
||||
// broadcast input and get new break_pos_
|
||||
outside_ = 1;
|
||||
for (int i = static_cast<int>(param_->ndim_) - 1; i >= 0; --i) {
|
||||
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
|
||||
break_pos_ = i;
|
||||
break;
|
||||
}
|
||||
outside_ *= param_->out_shape_[i];
|
||||
}
|
||||
if (param_->in_elements_num0_ == param_->out_elements_num_ &&
|
||||
param_->in_elements_num1_ == param_->out_elements_num_) {
|
||||
|
@ -205,50 +268,51 @@ void ArithmeticCPUKernel::FreeConstTileBuff() {
|
|||
void ArithmeticCPUKernel::InitRunFunction(int primitive_type) {
|
||||
ARITHMETIC_FUNC_INFO_FP32 fun_table[] = {
|
||||
{PrimitiveType_MulFusion, schema::ActivationType_RELU, ElementMulRelu, ElementMulReluInt, nullptr,
|
||||
ElementOptMulRelu, ElementOptMulReluInt},
|
||||
ElementOptMulRelu, ElementOptMulReluInt, nullptr},
|
||||
{PrimitiveType_MulFusion, schema::ActivationType_RELU6, ElementMulRelu6, ElementMulRelu6Int, nullptr,
|
||||
ElementOptMulRelu6, ElementOptMulRelu6Int},
|
||||
ElementOptMulRelu6, ElementOptMulRelu6Int, nullptr},
|
||||
{PrimitiveType_MulFusion, schema::ActivationType_NO_ACTIVATION, ElementMul, ElementMulInt, nullptr, ElementOptMul,
|
||||
ElementOptMulInt},
|
||||
{PrimitiveType_AddFusion, schema::ActivationType_RELU, ElementAddRelu, nullptr, nullptr, ElementOptAddRelu,
|
||||
ElementOptMulInt, nullptr},
|
||||
{PrimitiveType_AddFusion, schema::ActivationType_RELU, ElementAddRelu, nullptr, nullptr, ElementOptAddRelu, nullptr,
|
||||
nullptr},
|
||||
{PrimitiveType_AddFusion, schema::ActivationType_RELU6, ElementAddRelu6, nullptr, nullptr, ElementOptAddRelu6,
|
||||
nullptr},
|
||||
nullptr, nullptr},
|
||||
{PrimitiveType_AddFusion, schema::ActivationType_NO_ACTIVATION, ElementAdd, ElementAddInt, nullptr, ElementOptAdd,
|
||||
ElementOptAddInt},
|
||||
{PrimitiveType_SubFusion, schema::ActivationType_RELU, ElementSubRelu, nullptr, nullptr, ElementOptSubRelu,
|
||||
ElementOptAddInt, nullptr},
|
||||
{PrimitiveType_SubFusion, schema::ActivationType_RELU, ElementSubRelu, nullptr, nullptr, ElementOptSubRelu, nullptr,
|
||||
nullptr},
|
||||
{PrimitiveType_SubFusion, schema::ActivationType_RELU6, ElementSubRelu6, nullptr, nullptr, ElementOptSubRelu6,
|
||||
nullptr},
|
||||
nullptr, nullptr},
|
||||
{PrimitiveType_SubFusion, schema::ActivationType_NO_ACTIVATION, ElementSub, ElementSubInt, nullptr, ElementOptSub,
|
||||
ElementOptSubInt},
|
||||
{PrimitiveType_DivFusion, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu,
|
||||
ElementOptSubInt, nullptr},
|
||||
{PrimitiveType_DivFusion, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr,
|
||||
nullptr},
|
||||
{PrimitiveType_DivFusion, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6,
|
||||
nullptr},
|
||||
nullptr, nullptr},
|
||||
{PrimitiveType_DivFusion, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv,
|
||||
ElementOptDivInt},
|
||||
{PrimitiveType_RealDiv, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr},
|
||||
ElementOptDivInt, nullptr},
|
||||
{PrimitiveType_RealDiv, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr,
|
||||
nullptr},
|
||||
{PrimitiveType_RealDiv, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6,
|
||||
nullptr},
|
||||
nullptr, nullptr},
|
||||
{PrimitiveType_RealDiv, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv,
|
||||
ElementOptDivInt},
|
||||
ElementOptDivInt, nullptr},
|
||||
{PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, ElementLogicalAnd, ElementLogicalAndInt,
|
||||
ElementLogicalAndBool, nullptr, nullptr},
|
||||
ElementLogicalAndBool, ElementOptLogicalAnd, ElementOptLogicalAndInt, ElementOptLogicalAndBool},
|
||||
{PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, ElementLogicalOr, nullptr, ElementLogicalOrBool,
|
||||
nullptr, nullptr},
|
||||
{PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximum, ElementMaximumInt, nullptr, nullptr,
|
||||
nullptr},
|
||||
{PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimum, ElementMinimumInt, nullptr, nullptr,
|
||||
nullptr},
|
||||
nullptr, nullptr, ElementOptLogicalOrBool},
|
||||
{PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximum, ElementMaximumInt, nullptr,
|
||||
ElementOptMaximum, ElementOptMaximumInt, nullptr},
|
||||
{PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimum, ElementMinimumInt, nullptr,
|
||||
ElementOptMinimum, ElementOptMinimumInt, nullptr},
|
||||
{PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, ElementFloorMod, ElementFloorModInt, nullptr,
|
||||
nullptr, nullptr},
|
||||
ElementOptFloorMod, ElementOptFloorModInt, nullptr},
|
||||
{PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, ElementFloorDiv, ElementFloorDivInt, nullptr,
|
||||
nullptr, nullptr},
|
||||
ElementOptFloorDiv, ElementOptFloorDivInt, nullptr},
|
||||
{PrimitiveType_Mod, schema::ActivationType_NO_ACTIVATION, ElementMod, ElementModInt, nullptr, ElementOptMod,
|
||||
ElementOptModInt},
|
||||
ElementOptModInt, nullptr},
|
||||
{PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifference, nullptr, nullptr,
|
||||
nullptr, nullptr}};
|
||||
ElementOptSquaredDifference, nullptr, nullptr}};
|
||||
|
||||
size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_FUNC_INFO_FP32);
|
||||
for (size_t i = 0; i < length; i++) {
|
||||
|
@ -258,6 +322,7 @@ void ArithmeticCPUKernel::InitRunFunction(int primitive_type) {
|
|||
arithmetic_run_bool_ = fun_table[i].bool_func_;
|
||||
arithmetic_opt_run_ = fun_table[i].opt_func_;
|
||||
arithmetic_opt_run_int_ = fun_table[i].opt_int_func_;
|
||||
arithmetic_opt_run_bool_ = fun_table[i].opt_bool_func_;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -276,9 +341,15 @@ int ArithmeticCPUKernel::Execute(const void *input0, const void *input1, void *o
|
|||
reinterpret_cast<float *>(output), size);
|
||||
}
|
||||
} else if (in_tensors_[0]->data_type() == kNumberTypeBool) {
|
||||
CHECK_NULL_RETURN(arithmetic_run_bool_);
|
||||
ret = arithmetic_run_bool_(reinterpret_cast<const bool *>(input0), reinterpret_cast<const bool *>(input1),
|
||||
reinterpret_cast<bool *>(output), size);
|
||||
if (is_opt) {
|
||||
CHECK_NULL_RETURN(arithmetic_opt_run_bool_);
|
||||
ret = arithmetic_opt_run_bool_(reinterpret_cast<const bool *>(input0), reinterpret_cast<const bool *>(input1),
|
||||
reinterpret_cast<bool *>(output), size, param_);
|
||||
} else {
|
||||
CHECK_NULL_RETURN(arithmetic_run_bool_);
|
||||
ret = arithmetic_run_bool_(reinterpret_cast<const bool *>(input0), reinterpret_cast<const bool *>(input1),
|
||||
reinterpret_cast<bool *>(output), size);
|
||||
}
|
||||
} else {
|
||||
if (is_opt) {
|
||||
CHECK_NULL_RETURN(arithmetic_opt_run_int_);
|
||||
|
@ -293,102 +364,35 @@ int ArithmeticCPUKernel::Execute(const void *input0, const void *input1, void *o
|
|||
return ret;
|
||||
}
|
||||
|
||||
int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
|
||||
int out_thread_stride) {
|
||||
if (dim > break_pos_) {
|
||||
int offset = out_thread_stride * data_type_len_;
|
||||
return Execute(static_cast<uint8_t *>(input0) + offset, static_cast<uint8_t *>(input1) + offset,
|
||||
static_cast<uint8_t *>(output) + offset, out_count, false);
|
||||
}
|
||||
int offset_size[] = {param_->in_strides0_[dim] * data_type_len_, param_->in_strides1_[dim] * data_type_len_,
|
||||
param_->out_strides_[dim] * data_type_len_};
|
||||
for (int i = 0; i < param_->out_shape_[dim]; ++i) {
|
||||
int pos0_ = param_->in_shape0_[dim] == 1 ? 0 : i;
|
||||
int pos1_ = param_->in_shape1_[dim] == 1 ? 0 : i;
|
||||
int ret = BroadcastRun(static_cast<uint8_t *>(input0) + pos0_ * offset_size[0],
|
||||
static_cast<uint8_t *>(input1) + pos1_ * offset_size[1],
|
||||
static_cast<uint8_t *>(output) + i * offset_size[2], dim + 1, out_count, out_thread_stride);
|
||||
int ArithmeticCPUKernel::CalcArithmeticByBatch(int task_id) {
|
||||
int batch_per_thread = UP_DIV(out_batch_, op_parameter_->thread_num_);
|
||||
int start_batch = batch_per_thread * task_id;
|
||||
int end_batch = MSMIN(start_batch + batch_per_thread, out_batch_);
|
||||
int ret = RET_ERROR;
|
||||
for (int i = start_batch; i < end_batch; i++) {
|
||||
batch_a_ptr_ = static_cast<uint8_t *>(input0_ptr_) + a_offset_[i] * a_stride_size_ * data_type_len_;
|
||||
batch_b_ptr_ = static_cast<uint8_t *>(input1_ptr_) + b_offset_[i] * b_stride_size_ * data_type_len_;
|
||||
batch_c_ptr_ = static_cast<uint8_t *>(output_ptr_) + i * c_stride_size_ * data_type_len_;
|
||||
if (batch_scalar_) {
|
||||
ret = Execute(batch_a_ptr_, batch_b_ptr_, batch_c_ptr_, c_stride_size_, true);
|
||||
} else {
|
||||
ret = Execute(batch_a_ptr_, batch_b_ptr_, batch_c_ptr_, c_stride_size_, false);
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {
|
||||
if (break_pos_ < 1) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (break_pos_ > ARITHMETIC_SUPPORT_DIMS_NUM || param_->out_strides_[break_pos_ - 1] == 0) {
|
||||
MS_LOG(ERROR) << "param_->out_strides_[break_pos_ - 1] is 0 or break_pos_ is > 10";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int batch = param_->out_elements_num_ / param_->out_strides_[break_pos_ - 1];
|
||||
int batch_per_thread = UP_DIV(batch, op_parameter_->thread_num_);
|
||||
|
||||
int start_batch = batch_per_thread * task_id;
|
||||
int end_batch = MSMIN(start_batch + batch_per_thread, batch);
|
||||
int batch_size = end_batch - start_batch;
|
||||
|
||||
int stride0 = param_->in_strides0_[break_pos_ - 1] * data_type_len_;
|
||||
int stride1 = param_->in_strides1_[break_pos_ - 1] * data_type_len_;
|
||||
int out_stride = param_->out_strides_[break_pos_ - 1] * data_type_len_;
|
||||
|
||||
int offset0 = stride0 * start_batch;
|
||||
int offset1 = stride1 * start_batch;
|
||||
int out_offset = out_stride * start_batch;
|
||||
|
||||
int ret = RET_OK;
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
ret = Execute(static_cast<uint8_t *>(input0_ptr_) + offset0, static_cast<uint8_t *>(input1_ptr_) + offset1,
|
||||
static_cast<uint8_t *>(output_ptr_) + out_offset, param_->out_strides_[break_pos_ - 1], true);
|
||||
offset0 += stride0;
|
||||
offset1 += stride1;
|
||||
out_offset += out_stride;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ArithmeticCPUKernel::BiasCalc(int task_id) {
|
||||
if (param_->ndim_ > ARITHMETIC_SUPPORT_DIMS_NUM || param_->out_shape_[param_->ndim_ - 1] == 0) {
|
||||
MS_LOG(ERROR) << "BiasCalc param is error!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int last_shape = param_->out_shape_[param_->ndim_ - 1];
|
||||
int batch = param_->out_elements_num_ / last_shape;
|
||||
int batch_per_thread = UP_DIV(batch, op_parameter_->thread_num_);
|
||||
|
||||
int start_batch = batch_per_thread * task_id;
|
||||
int end_batch = MSMIN(start_batch + batch_per_thread, batch);
|
||||
int batch_size = end_batch - start_batch;
|
||||
|
||||
int stride = last_shape * data_type_len_;
|
||||
int offset = stride * start_batch;
|
||||
int ret = RET_OK;
|
||||
if (param_->in_elements_num0_ > param_->in_elements_num1_) {
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
ret = Execute(static_cast<uint8_t *>(input0_ptr_) + offset, static_cast<uint8_t *>(input1_ptr_),
|
||||
static_cast<uint8_t *>(output_ptr_) + offset, last_shape, false);
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
offset += stride;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
ret = Execute(static_cast<uint8_t *>(input0_ptr_), static_cast<uint8_t *>(input1_ptr_) + offset,
|
||||
static_cast<uint8_t *>(output_ptr_) + offset, last_shape, false);
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
offset += stride;
|
||||
MS_LOG(ERROR) << "failed to calculate.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
||||
auto element_num = out_tensors_[0]->ElementsNum();
|
||||
if (split_by_batch_) {
|
||||
return CalcArithmeticByBatch(task_id);
|
||||
}
|
||||
|
||||
int64_t element_num = out_tensors_[0]->ElementsNum();
|
||||
auto ret = RET_ERROR;
|
||||
int stride = UP_DIV(element_num, op_parameter_->thread_num_);
|
||||
int count = MSMIN(stride, element_num - stride * task_id);
|
||||
if (count <= 0) {
|
||||
|
@ -396,36 +400,16 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
|||
}
|
||||
CHECK_LESS_RETURN(ARITHMETIC_SUPPORT_DIMS_NUM, param_->ndim_);
|
||||
int offset = stride * task_id * data_type_len_;
|
||||
/* run opt function, one of input is scalar */
|
||||
if (IsScalarClac()) { // 2 32 240 240, 1 1 1 1
|
||||
if (scalar_) {
|
||||
if (param_->in_elements_num0_ == 1) {
|
||||
return Execute(input0_ptr_, static_cast<uint8_t *>(input1_ptr_) + offset,
|
||||
static_cast<uint8_t *>(output_ptr_) + offset, count, true);
|
||||
} else if (param_->in_elements_num1_ == 1) {
|
||||
return Execute(static_cast<uint8_t *>(input0_ptr_) + offset, input1_ptr_,
|
||||
static_cast<uint8_t *>(output_ptr_) + offset, count, true);
|
||||
ret = Execute(batch_a_ptr_, batch_b_ptr_ + offset, batch_c_ptr_ + offset, count, true);
|
||||
} else {
|
||||
ret = Execute(batch_a_ptr_ + offset, batch_b_ptr_, batch_c_ptr_ + offset, count, true);
|
||||
}
|
||||
} else {
|
||||
ret = Execute(batch_a_ptr_ + offset, batch_b_ptr_ + offset, batch_c_ptr_ + offset, count, false);
|
||||
}
|
||||
/* run opt function, every batch one of input is scalar */
|
||||
if (IsBatchScalarCalc()) { // 2 32 240 240, 2 32 1 1
|
||||
return BatchScalarCalc(task_id);
|
||||
}
|
||||
/* each batch is eltwise calculation */
|
||||
if (IsBiasCalc()) { // 2 240 240 32, 1 1 1 32
|
||||
return BiasCalc(task_id);
|
||||
}
|
||||
/* need broadcast in runtime */
|
||||
if (param_->broadcasting_) {
|
||||
stride = UP_DIV(outside_, op_parameter_->thread_num_);
|
||||
int out_count = MSMIN(stride, outside_ - stride * task_id);
|
||||
if (out_count <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
return BroadcastRun(input0_ptr_, input1_ptr_, output_ptr_, 0, out_count, stride * task_id);
|
||||
}
|
||||
/* all elements eltwise calculation */
|
||||
return Execute(static_cast<uint8_t *>(input0_ptr_) + offset, static_cast<uint8_t *>(input1_ptr_) + offset,
|
||||
static_cast<uint8_t *>(output_ptr_) + offset, count, false);
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ArithmeticsRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
|
||||
|
@ -442,6 +426,7 @@ int ArithmeticCPUKernel::Run() {
|
|||
MS_LOG(ERROR) << "ArithmeticCPUKernel check dataType failed, kernel name: " << this->name();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (!input0_broadcast_) {
|
||||
input0_ptr_ = in_tensors_[0]->data();
|
||||
CHECK_NULL_RETURN(input0_ptr_);
|
||||
|
@ -452,7 +437,16 @@ int ArithmeticCPUKernel::Run() {
|
|||
}
|
||||
output_ptr_ = out_tensors_[0]->data();
|
||||
CHECK_NULL_RETURN(output_ptr_);
|
||||
return ParallelLaunch(this->ms_context_, ArithmeticsRun, this, op_parameter_->thread_num_);
|
||||
batch_a_ptr_ = static_cast<uint8_t *>(input0_ptr_);
|
||||
batch_b_ptr_ = static_cast<uint8_t *>(input1_ptr_);
|
||||
batch_c_ptr_ = static_cast<uint8_t *>(output_ptr_);
|
||||
auto ret = ParallelLaunch(this->ms_context_, ArithmeticsRun, this, op_parameter_->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "arithmetic failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MulFusion, LiteKernelCreator<ArithmeticCPUKernel>)
|
||||
|
|
|
@ -49,6 +49,8 @@ class ArithmeticCPUKernel : public InnerKernel {
|
|||
typedef int (*ArithmeticOptIntRun)(const int *input0, const int *input1, int *output, const int element_size,
|
||||
const ArithmeticParameter *param);
|
||||
typedef int (*ArithmeticBoolRun)(const bool *input0, const bool *input1, bool *output, const int element_size);
|
||||
typedef int (*ArithmeticOptBoolRun)(const bool *input0, const bool *input1, bool *output, const int element_size,
|
||||
const ArithmeticParameter *param);
|
||||
|
||||
typedef struct {
|
||||
int primitive_type_;
|
||||
|
@ -58,6 +60,7 @@ class ArithmeticCPUKernel : public InnerKernel {
|
|||
ArithmeticBoolRun bool_func_;
|
||||
ArithmeticOptRun opt_func_;
|
||||
ArithmeticOptIntRun opt_int_func_;
|
||||
ArithmeticOptBoolRun opt_bool_func_;
|
||||
} ARITHMETIC_FUNC_INFO_FP32;
|
||||
|
||||
public:
|
||||
|
@ -72,7 +75,6 @@ class ArithmeticCPUKernel : public InnerKernel {
|
|||
int ReSize() override;
|
||||
int Run() override;
|
||||
virtual int DoArithmetic(int task_id);
|
||||
virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
|
||||
|
||||
protected:
|
||||
virtual void InitRunFunction(int primitive_type);
|
||||
|
@ -83,17 +85,31 @@ class ArithmeticCPUKernel : public InnerKernel {
|
|||
virtual int Execute(const void *input0, const void *input1, void *output, int size, bool is_opt);
|
||||
virtual bool IsBatchScalarCalc();
|
||||
virtual bool IsScalarClac();
|
||||
virtual int CalcArithmeticByBatch(int task_id);
|
||||
bool input0_broadcast_ = false;
|
||||
bool input1_broadcast_ = false;
|
||||
void *input0_ptr_ = nullptr;
|
||||
void *input1_ptr_ = nullptr;
|
||||
void *output_ptr_ = nullptr;
|
||||
uint8_t *batch_a_ptr_ = nullptr;
|
||||
uint8_t *batch_b_ptr_ = nullptr;
|
||||
uint8_t *batch_c_ptr_ = nullptr;
|
||||
int break_pos_ = 0;
|
||||
int outside_ = 0;
|
||||
ArithmeticParameter *param_ = nullptr;
|
||||
int data_type_len_ = sizeof(float);
|
||||
int out_batch_ = 1;
|
||||
int a_stride_size_ = 1;
|
||||
int b_stride_size_ = 1;
|
||||
int c_stride_size_ = 1;
|
||||
int last_batch_axis_ = 0;
|
||||
bool scalar_ = false;
|
||||
bool batch_scalar_ = false;
|
||||
bool split_by_batch_ = false;
|
||||
std::vector<int> a_offset_;
|
||||
std::vector<int> b_offset_;
|
||||
|
||||
private:
|
||||
int InitIndexOffsetInfo();
|
||||
int BatchScalarCalc(int task_id);
|
||||
int BiasCalc(int task_id);
|
||||
void FreeConstTileBuff();
|
||||
|
@ -103,6 +119,7 @@ class ArithmeticCPUKernel : public InnerKernel {
|
|||
ArithmeticIntRun arithmetic_run_int_ = nullptr;
|
||||
ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr;
|
||||
ArithmeticBoolRun arithmetic_run_bool_ = nullptr;
|
||||
ArithmeticOptBoolRun arithmetic_opt_run_bool_ = nullptr;
|
||||
};
|
||||
int ArithmeticsRun(void *cdata, int task_id, float lhs_scale, float rhs_scale);
|
||||
} // namespace mindspore::kernel
|
||||
|
|
Loading…
Reference in New Issue