diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc index e2d7a9d68a0..373ebf78795 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc @@ -195,23 +195,104 @@ int ArithmeticFP16CPUKernel::ReSize() { } if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) { - if (arithmeticParameter_->activation_type_ == schema::ActivationType_NO_ACTIVATION) { - switch (arithmeticParameter_->op_parameter_.type_) { - case PrimitiveType_Mul: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptMulFp16; - break; - case PrimitiveType_Add: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptAddFp16; - break; - case PrimitiveType_Sub: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptSubFp16; - break; - default: - break; - } + switch (arithmeticParameter_->op_parameter_.type_) { + case PrimitiveType_Mul: + arithmeticParameter_->broadcasting_ = false; + switch (arithmeticParameter_->activation_type_) { + case schema::ActivationType_RELU: + arithmetic_opt_run_ = ElementOptMulReluFp16; + break; + case schema::ActivationType_RELU6: + arithmetic_opt_run_ = ElementOptDivRelu6Fp16; + break; + default: + arithmetic_opt_run_ = ElementOptDivFp16; + break; + } + break; + case PrimitiveType_Add: + arithmeticParameter_->broadcasting_ = false; + switch (arithmeticParameter_->activation_type_) { + case schema::ActivationType_RELU: + arithmetic_opt_run_ = ElementOptAddReluFp16; + break; + case schema::ActivationType_RELU6: + arithmetic_opt_run_ = ElementOptAddRelu6Fp16; + break; + default: + arithmetic_opt_run_ = ElementOptAddFp16; + break; + } + break; + case PrimitiveType_Sub: + arithmeticParameter_->broadcasting_ = false; + switch (arithmeticParameter_->activation_type_) { + case schema::ActivationType_RELU: + arithmetic_opt_run_ = ElementOptSubReluFp16; + break; + case schema::ActivationType_RELU6: + arithmetic_opt_run_ = ElementOptSubRelu6Fp16; + break; + default: + arithmetic_opt_run_ = ElementOptSubFp16; + break; + } + break; + case PrimitiveType_Div: + arithmeticParameter_->broadcasting_ = false; + switch (arithmeticParameter_->activation_type_) { + case schema::ActivationType_RELU: + arithmetic_opt_run_ = ElementOptDivReluFp16; + break; + case schema::ActivationType_RELU6: + arithmetic_opt_run_ = ElementOptDivRelu6Fp16; + break; + default: + arithmetic_opt_run_ = ElementOptDivFp16; + break; + } + break; + case PrimitiveType_FloorMod: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptFloorModFp16; + case PrimitiveType_FloorDiv: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptFloorDivFp16; + case PrimitiveType_LogicalAnd: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptLogicalAndFp16; + case PrimitiveType_LogicalOr: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptLogicalOrFp16; + case PrimitiveType_SquaredDifference: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptSquaredDifferenceFp16; + case PrimitiveType_Maximum: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptMaximumFp16; + case PrimitiveType_Minimum: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptMinimumFp16; + case PrimitiveType_NotEqual: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptNotEqualFp16; + case PrimitiveType_Equal: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptEqualFp16; + case PrimitiveType_Less: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptLessFp16; + case PrimitiveType_LessEqual: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptLessEqualFp16; + case PrimitiveType_Greater: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptGreaterFp16; + case PrimitiveType_GreaterEqual: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptGreaterEqualFp16; + default: + break; } } @@ -334,4 +415,17 @@ kernel::LiteKernel *CpuArithmeticFp16KernelCreator(const std::vectormultiples1_); } -int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, - ArithmeticParameter *param) { - if (param->in_elements_num0_ == 1) { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[0] * input1[i]; - } - } else if (param->in_elements_num1_ == 1) { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[i] * input1[0]; - } - } else { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[i] * input1[i]; - } - } - return NNACL_OK; -} - -int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, - ArithmeticParameter *param) { - if (param->in_elements_num0_ == 1) { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[0] - input1[i]; - } - } else if (param->in_elements_num1_ == 1) { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[i] - input1[0]; - } - } else { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[i] - input1[i]; - } - } - return NNACL_OK; -} - -int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, - ArithmeticParameter *param) { - if (param->in_elements_num0_ == 1) { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[0] + input1[i]; - } - } else if (param->in_elements_num1_ == 1) { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[i] + input1[0]; - } - } else { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[i] + input1[i]; - } - } - return NNACL_OK; -} - int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; @@ -124,6 +70,41 @@ int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int return NNACL_OK; } +int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vmulq_f16(vin0, vin1); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = in0 * in1; + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + output[index] = in0 * in1; + } + + return NNACL_OK; +} int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -158,6 +139,47 @@ int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, return NNACL_OK; } +int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vmulq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output, vout); +#else + float16_t res; + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + res = in0 * in1; + output[i] = res > 0 ? res : 0; + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + float16_t res = in0 * in1; + output[index] = res > 0 ? res : 0; + } + + return NNACL_OK; +} int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -190,6 +212,45 @@ int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, return NNACL_OK; } +int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; + float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; +#endif + + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vmulq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = MSMIN(MSMAX(in0 * in1, 0), 6); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + output[index] = MSMIN(MSMAX(in0 * in1, 0), 6); + } + + return NNACL_OK; +} int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -215,6 +276,40 @@ int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int } return NNACL_OK; } +int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vaddq_f16(vin0, vin1); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = in0 + in1; + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + output[index] = in0 + in1; + } + return NNACL_OK; +} int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -246,6 +341,44 @@ int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, } return NNACL_OK; } +int ElementOptAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vaddq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = MSMAX(in0 + in1, 0); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + float16_t res = in0 + in1; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -278,6 +411,45 @@ int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, return NNACL_OK; } +int ElementOptAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; + float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; +#endif + + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vaddq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = MSMIN(MSMAX(in0 + in1, 0), 6); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + output[index] = MSMIN(MSMAX(in0 + in1, 0), 6); + } + + return NNACL_OK; +} int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -303,6 +475,40 @@ int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int } return NNACL_OK; } +int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vsubq_f16(vin0, vin1); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = in0 - in1; + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + output[index] = in0 - in1; + } + return NNACL_OK; +} int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -332,6 +538,43 @@ int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, } return NNACL_OK; } +int ElementOptSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vsubq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = MSMAX(in0 - in1, 0); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + float16_t res = in0 - in1; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} int ElementSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -362,6 +605,44 @@ int ElementSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, return NNACL_OK; } +int ElementOptSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; + float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vsubq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = MSMIN(MSMAX(in0 - in1, 0), 6); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + output[index] = MSMIN(MSMAX(in0 - in1, 0), 6); + } + + return NNACL_OK; +} int ElementDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -376,7 +657,7 @@ int ElementDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int #ifdef ENABLE_NEON float16x8_t vin0 = vld1q_f16(input0); float16x8_t vin1 = vld1q_f16(input1); - float16x8_t vout = vsubq_f16(vin0, vin1); + float16x8_t vout = vdivq_f16(vin0, vin1); vst1q_f16(output, vout); #else for (int i = 0; i < C8NUM; ++i) { @@ -395,6 +676,54 @@ int ElementDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int } return NNACL_OK; } +int ElementOptDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num1_ == 1) { + if (in1_opt == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + } else { + for (int i = 0; i < C8NUM; ++i) { + if (input1[i] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + } + } +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vdivq_f16(vin0, vin1); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = in0 / in1; + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + if (in1 == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = in0 / in1; + } + return NNACL_OK; +} int ElementDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -411,12 +740,12 @@ int ElementDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output, #ifdef ENABLE_NEON float16x8_t vin0 = vld1q_f16(input0); float16x8_t vin1 = vld1q_f16(input1); - float16x8_t vout = vsubq_f16(vin0, vin1); + float16x8_t vout = vdivq_f16(vin0, vin1); vout = vmaxq_f16(vout, zeros); vst1q_f16(output, vout); #else for (int i = 0; i < C8NUM; ++i) { - output[i] = MSMAX(input0[i] - input1[i], 0); + output[i] = MSMAX(input0[i] / input1[i], 0); } #endif input0 += C8NUM; @@ -427,7 +756,59 @@ int ElementDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output, if (input1[index] == 0) { return NNACL_ERRCODE_DIVISOR_ZERO; } - float16_t res = input0[index] - input1[index]; + float16_t res = input0[index] / input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} +int ElementOptDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; + +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num1_ == 1) { + if (in1_opt == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + } else { + for (int i = 0; i < C8NUM; ++i) { + if (input1[i] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + } + } +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vdivq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = MSMAX(in0 / in1, 0); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + if (in1 == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + float16_t res = in0 / in1; output[index] = res > 0 ? res : 0; } return NNACL_OK; @@ -449,12 +830,12 @@ int ElementDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, #ifdef ENABLE_NEON float16x8_t vin0 = vld1q_f16(input0); float16x8_t vin1 = vld1q_f16(input1); - float16x8_t vout = vsubq_f16(vin0, vin1); + float16x8_t vout = vdivq_f16(vin0, vin1); vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); vst1q_f16(output, vout); #else for (int i = 0; i < C8NUM; ++i) { - output[i] = MSMIN(MSMAX(input0[i] - input1[i], 0), 6); + output[i] = MSMIN(MSMAX(input0[i] / input1[i], 0), 6); } #endif input0 += C8NUM; @@ -465,7 +846,59 @@ int ElementDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, if (input1[index] == 0) { return NNACL_ERRCODE_DIVISOR_ZERO; } - output[index] = MSMIN(MSMAX(input0[index] - input1[index], 0), 6); + output[index] = MSMIN(MSMAX(input0[index] / input1[index], 0), 6); + } + return NNACL_OK; +} +int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; + +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; + float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num1_ == 1) { + if (in1_opt == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + } else { + for (int i = 0; i < C8NUM; ++i) { + if (input1[i] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + } + } +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vdivq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = MSMIN(MSMAX(in0 / in1, 0), 6); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + if (in1 == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = MSMIN(MSMAX(in0 / in1, 0), 6); } return NNACL_OK; } @@ -479,6 +912,25 @@ int ElementFloorModFp16(float16_t *input0, float16_t *input1, float16_t *output, } return NNACL_OK; } +int ElementOptFloorModFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + if (param->in_elements_num1_ == 1) { + if (input1[0] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + for (int i = 0; i < element_size; ++i) { + output[i] = input0[i] - floorf(input0[i] / input1[0]) * input1[0]; + } + } else { + for (int i = 0; i < element_size; ++i) { + if (input1[i] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[i] = input0[i] - floorf(input0[i] / input1[i]) * input1[i]; + } + } + return NNACL_OK; +} int ElementFloorDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { for (int i = 0; i < element_size; ++i) { @@ -489,6 +941,25 @@ int ElementFloorDivFp16(float16_t *input0, float16_t *input1, float16_t *output, } return NNACL_OK; } +int ElementOptFloorDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + if (param->in_elements_num1_ == 1) { + if (input1[0] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + for (int i = 0; i < element_size; ++i) { + output[i] = floorf(input0[i] / input1[0]); + } + } else { + for (int i = 0; i < element_size; ++i) { + if (input1[i] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[i] = floorf(input0[i] / input1[i]); + } + } + return NNACL_OK; +} int ElementLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -520,6 +991,46 @@ int ElementLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *outpu } return NNACL_OK; } +int ElementOptLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; + float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; + float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; + uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); + uint16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0_ = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1_ = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + uint16x8_t vin0 = vandq_u16(vreinterpretq_s16_f16(vin0_), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_s16_f16(vin1_), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vandq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = (float16_t)((bool)(in0) & (bool)(in1)); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + output[index] = (float16_t)((bool)(in0) & (bool)(in1)); + } + return NNACL_OK; +} int ElementLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -551,11 +1062,56 @@ int ElementLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *output } return NNACL_OK; } +int ElementOptLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; + float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; + float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; + uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); + uint16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0_ = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1_ = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + uint16x8_t vin0 = vandq_u16(vreinterpretq_s16_f16(vin0_), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_s16_f16(vin1_), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vorrq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = (float16_t)((bool)(in0) | (bool)(in1)); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + output[index] = (float16_t)((bool)(in0) | (bool)(in1)); + } + return NNACL_OK; +} int ElementSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { ElementSubFp16(input0, input1, output, element_size); return ElementMulFp16(output, output, output, element_size); } +int ElementOptSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + ElementOptSubFp16(input0, input1, output, element_size, param); + return ElementMulFp16(output, output, output, element_size); +} int ElementMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -580,6 +1136,40 @@ int ElementMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, } return NNACL_OK; } +int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vmaxq_f16(vin0, vin1); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = MSMAX(in0, in1); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + output[index] = MSMAX(in0, in1); + } + return NNACL_OK; +} int ElementMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -604,6 +1194,40 @@ int ElementMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, } return NNACL_OK; } +int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vminq_f16(vin0, vin1); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = MSMIN(in0, in1); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + output[index] = MSMIN(in0, in1); + } + return NNACL_OK; +} int ElementNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -632,6 +1256,42 @@ int ElementNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, } return NNACL_OK; } +int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; + float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; + float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vfalse, vtrue); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = (float16_t)(in0 != in1); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + output[index] = (float16_t)(in0 != in1); + } + return NNACL_OK; +} int ElementEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -660,6 +1320,42 @@ int ElementEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, in } return NNACL_OK; } +int ElementOptEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; + float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; + float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vtrue, vfalse); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = (float16_t)(in0 == in1); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + output[index] = (float16_t)(in0 == in1); + } + return NNACL_OK; +} int ElementLessFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -688,6 +1384,42 @@ int ElementLessFp16(float16_t *input0, float16_t *input1, float16_t *output, int } return NNACL_OK; } +int ElementOptLessFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; + float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; + float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vbslq_f16(vcltq_f16(vin0, vin1), vtrue, vfalse); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = (float16_t)(in0 < in1); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + output[index] = (float16_t)(in0 < in1); + } + return NNACL_OK; +} int ElementLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -716,6 +1448,42 @@ int ElementLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *output } return NNACL_OK; } +int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; + float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; + float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vbslq_f16(vcleq_f16(vin0, vin1), vtrue, vfalse); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = (float16_t)(in0 <= in1); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + output[index] = (float16_t)(in0 <= in1); + } + return NNACL_OK; +} int ElementGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -744,6 +1512,42 @@ int ElementGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, } return NNACL_OK; } +int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; + float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; + float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vbslq_f16(vcgtq_f16(vin0, vin1), vtrue, vfalse); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = (float16_t)(in0 > in1); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + output[index] = (float16_t)(in0 > in1); + } + return NNACL_OK; +} int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; @@ -772,3 +1576,39 @@ int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *out } return NNACL_OK; } +int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param) { + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; +#ifdef ENABLE_NEON + float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; + float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; + float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; + float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); + float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); + float16x8_t vout = vbslq_f16(vcgeq_f16(vin0, vin1), vtrue, vfalse); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; + output[i] = (float16_t)(in0 >= in1); + } +#endif + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; + float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; + output[index] = (float16_t)(in0 >= in1); + } + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h index 0cbaa9e6c8e..c3369519a6f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h @@ -26,12 +26,57 @@ #ifdef __cplusplus extern "C" { #endif -int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, - ArithmeticParameter *param); -int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, - ArithmeticParameter *param); int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); +int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptFloorModFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptFloorDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptLessFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); +int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, + ArithmeticParameter *param); + int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);