From 1fb6f1b65d50310b0957bce063573edca5a7bbbd Mon Sep 17 00:00:00 2001 From: tao_yunhao Date: Mon, 31 Aug 2020 09:31:18 +0800 Subject: [PATCH] modify arm cpu fp16&fp32 op: Arithmetic --- mindspore/lite/nnacl/fp16/arithmetic_fp16.c | 915 ++++++++++++-------- 1 file changed, 574 insertions(+), 341 deletions(-) diff --git a/mindspore/lite/nnacl/fp16/arithmetic_fp16.c b/mindspore/lite/nnacl/fp16/arithmetic_fp16.c index ef8ae6fd64e..7168d32c61b 100644 --- a/mindspore/lite/nnacl/fp16/arithmetic_fp16.c +++ b/mindspore/lite/nnacl/fp16/arithmetic_fp16.c @@ -580,27 +580,42 @@ int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, i float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vsubq_f16(vin0, vin1); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vsubq_f16(vin0, vin1); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = in0 - in1; - } + for (int i = 0; i < C8NUM; ++i) { + output[i] = in0_opt - input1[i]; + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - output[index] = in0 - in1; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = in0_opt - input1[index]; + } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vsubq_f16(vin0, vin1); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = input0[i] - in1_opt; + } +#endif + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] - in1_opt; + } } return NNACL_OK; } @@ -644,29 +659,46 @@ int ElementOptSubReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vsubq_f16(vin0, vin1); - vout = vmaxq_f16(vout, zeros); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vsubq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = MSMAX(in0 - in1, 0); - } + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMAX(in0_opt - input1[i], 0); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - float16_t res = in0 - in1; - output[index] = res > 0 ? res : 0; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t res = in0_opt - input1[index]; + output[index] = res > 0 ? res : 0; + } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vsubq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMAX(input0[i] - in1_opt, 0); + } +#endif + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t res = input0[index] - in1_opt; + output[index] = res > 0 ? res : 0; + } } return NNACL_OK; } @@ -712,30 +744,45 @@ int ElementOptSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vsubq_f16(vin0, vin1); - vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vsubq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = MSMIN(MSMAX(in0 - in1, 0), 6); - } + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMIN(MSMAX(in0_opt - input1[i], 0), 6); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(in0_opt - input1[index], 0), 6); + } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vsubq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMIN(MSMAX(input0[i] - in1_opt, 0), 6); + } +#endif + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(input0[index] - in1_opt, 0), 6); + } } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - output[index] = MSMIN(MSMAX(in0 - in1, 0), 6); - } - return NNACL_OK; } @@ -781,41 +828,53 @@ int ElementOptDivFp16(float16_t *input0, float16_t *input1, float16_t *output, i float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { - if (param->in_elements_num1_ == 1) { - if (in1_opt == 0) { - return NNACL_ERRCODE_DIVISOR_ZERO; - } - } else { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { for (int i = 0; i < C8NUM; ++i) { if (input1[i] == 0) { return NNACL_ERRCODE_DIVISOR_ZERO; } } - } #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vdivq_f16(vin0, vin1); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vdivq_f16(vin0, vin1); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = in0 / in1; - } + for (int i = 0; i < C8NUM; ++i) { + output[i] = in0_opt / input1[i]; + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - if (in1 == 0) { + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = in0_opt / input1[index]; + } + } else { + if (in1_opt == 0) { return NNACL_ERRCODE_DIVISOR_ZERO; } - output[index] = in0 / in1; + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vdivq_f16(vin0, vin1); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = input0[i] / in1_opt; + } +#endif + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] / in1_opt; + } } return NNACL_OK; } @@ -867,43 +926,53 @@ int ElementOptDivReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { - if (param->in_elements_num1_ == 1) { - if (in1_opt == 0) { - return NNACL_ERRCODE_DIVISOR_ZERO; - } - } else { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { for (int i = 0; i < C8NUM; ++i) { if (input1[i] == 0) { return NNACL_ERRCODE_DIVISOR_ZERO; } } - } #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vdivq_f16(vin0, vin1); - vout = vmaxq_f16(vout, zeros); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vmaxq_f16(vdivq_f16(vin0, vin1), zeros); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = MSMAX(in0 / in1, 0); - } + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMAX(in0_opt / input1[i], 0); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - if (in1 == 0) { + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = MSMAX(in0_opt / input1[index], 0); + } + } else { + if (in1_opt == 0) { return NNACL_ERRCODE_DIVISOR_ZERO; } - float16_t res = in0 / in1; - output[index] = res > 0 ? res : 0; + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vmaxq_f16(vdivq_f16(vin0, vin1), zeros); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMAX(input0[i] / in1_opt, 0); + } +#endif + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMAX(input0[index] / in1_opt, 0); + } } return NNACL_OK; } @@ -948,7 +1017,6 @@ int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; - float16_t in0_opt = input0[0]; float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON @@ -957,42 +1025,53 @@ int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { - if (param->in_elements_num1_ == 1) { - if (in1_opt == 0) { - return NNACL_ERRCODE_DIVISOR_ZERO; - } - } else { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { for (int i = 0; i < C8NUM; ++i) { if (input1[i] == 0) { return NNACL_ERRCODE_DIVISOR_ZERO; } } - } #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vdivq_f16(vin0, vin1); - vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vminq_f16(vmaxq_f16(vdivq_f16(vin0, vin1), zeros), bounds); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = MSMIN(MSMAX(in0 / in1, 0), 6); - } + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMIN(MSMAX(in0_opt / input1[i], 0), 6); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - if (in1 == 0) { + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = MSMIN(MSMAX(in0_opt / input1[index], 0), 6); + } + } else { + if (in1_opt == 0) { return NNACL_ERRCODE_DIVISOR_ZERO; } - output[index] = MSMIN(MSMAX(in0 / in1, 0), 6); + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vminq_f16(vmaxq_f16(vdivq_f16(vin0, vin1), zeros), bounds); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMIN(MSMAX(input0[i] / in1_opt, 0), 6); + } +#endif + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(input0[index] / in1_opt, 0), 6); + } } return NNACL_OK; } @@ -1089,39 +1168,56 @@ int ElementOptLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *ou ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); uint16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON - float16x8_t vin0_ = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1_ = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - uint16x8_t vin0 = vandq_u16(vreinterpretq_s16_f16(vin0_), mask); - uint16x8_t vin1 = vandq_u16(vreinterpretq_s16_f16(vin1_), mask); - float16x8_t vout = vbslq_f16(vceqq_u16(vandq_u16(vin0, vin1), zeros), vfalse, vtrue); - vst1q_f16(output, vout); + float16x8_t vin0_ = vin0_opt; + float16x8_t vin1_ = vld1q_f16(input1); + uint16x8_t vin0 = vandq_u16(vreinterpretq_s16_f16(vin0_), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_s16_f16(vin1_), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vandq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = (float16_t)((bool)(in0) & (bool)(in1)); - } + for (int i = 0; i < C8NUM; ++i) { + output[i] = (float16_t)((bool)(in0_opt) & (bool)(input1[i])); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - output[index] = (float16_t)((bool)(in0) & (bool)(in1)); + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float16_t)((bool)(in0_opt) & (bool)(input1[index])); + } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0_ = vld1q_f16(input0); + float16x8_t vin1_ = vin1_opt; + uint16x8_t vin0 = vandq_u16(vreinterpretq_s16_f16(vin0_), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_s16_f16(vin1_), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vandq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = (float16_t)((bool)(input0[i]) & (bool)(in1_opt)); + } +#endif + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float16_t)((bool)(input0[index]) & (bool)(in1_opt)); + } } return NNACL_OK; } @@ -1160,39 +1256,56 @@ int ElementOptLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *out ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); uint16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON - float16x8_t vin0_ = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1_ = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - uint16x8_t vin0 = vandq_u16(vreinterpretq_s16_f16(vin0_), mask); - uint16x8_t vin1 = vandq_u16(vreinterpretq_s16_f16(vin1_), mask); - float16x8_t vout = vbslq_f16(vceqq_u16(vorrq_u16(vin0, vin1), zeros), vfalse, vtrue); - vst1q_f16(output, vout); + float16x8_t vin0_ = vin0_opt; + float16x8_t vin1_ = vld1q_f16(input1); + uint16x8_t vin0 = vandq_u16(vreinterpretq_s16_f16(vin0_), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_s16_f16(vin1_), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vorrq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = (float16_t)((bool)(in0) | (bool)(in1)); - } + for (int i = 0; i < C8NUM; ++i) { + output[i] = (float16_t)((bool)(in0_opt) | (bool)(input1[i])); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - output[index] = (float16_t)((bool)(in0) | (bool)(in1)); + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float16_t)((bool)(in0_opt) | (bool)(input1[index])); + } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0_ = vld1q_f16(input0); + float16x8_t vin1_ = vin1_opt; + uint16x8_t vin0 = vandq_u16(vreinterpretq_s16_f16(vin0_), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_s16_f16(vin1_), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vorrq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = (float16_t)((bool)(input0[i]) | (bool)(in1_opt)); + } +#endif + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float16_t)((bool)(input0[index]) | (bool)(in1_opt)); + } } return NNACL_OK; } @@ -1234,33 +1347,48 @@ int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *outpu ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; #endif - for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vmaxq_f16(vin0, vin1); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vmaxq_f16(vin0, vin1); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = MSMAX(in0, in1); - } + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMAX(in0_opt, input1[i]); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - output[index] = MSMAX(in0, in1); + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMAX(in0_opt, input1[index]); + } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vmaxq_f16(vin0, vin1); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMAX(input0[i], in1_opt); + } +#endif + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMAX(input0[index], in1_opt); + } } return NNACL_OK; } @@ -1292,33 +1420,48 @@ int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *outpu ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; #endif - for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vminq_f16(vin0, vin1); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vminq_f16(vin0, vin1); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = MSMIN(in0, in1); - } + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMIN(in0_opt, input1[i]); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - output[index] = MSMIN(in0, in1); + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(in0_opt, input1[index]); + } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vminq_f16(vin0, vin1); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMIN(input0[i], in1_opt); + } +#endif + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(input0[index], in1_opt); + } } return NNACL_OK; } @@ -1354,35 +1497,50 @@ int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *outp ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vfalse, vtrue); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vfalse, vtrue); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = (float16_t)(in0 != in1); - } + for (int i = 0; i < C8NUM; ++i) { + output[i] = (float16_t)(in0_opt != input1[i]); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - output[index] = (float16_t)(in0 != in1); + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float16_t)(in0_opt != input1[index]); + } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vfalse, vtrue); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = (float16_t)(input0[i] != in1_opt); + } +#endif + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float16_t)(input0[index] != in1_opt); + } } return NNACL_OK; } @@ -1418,35 +1576,50 @@ int ElementOptEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vtrue, vfalse); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vtrue, vfalse); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = (float16_t)(in0 == in1); - } + for (int i = 0; i < C8NUM; ++i) { + output[i] = (float16_t)(in0_opt == input1[i]); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - output[index] = (float16_t)(in0 == in1); + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float16_t)(in0_opt == input1[index]); + } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vtrue, vfalse); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = (float16_t)(input0[i] == in1_opt); + } +#endif + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float16_t)(input0[index] == in1_opt); + } } return NNACL_OK; } @@ -1482,35 +1655,50 @@ int ElementOptLessFp16(float16_t *input0, float16_t *input1, float16_t *output, ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vbslq_f16(vcltq_f16(vin0, vin1), vtrue, vfalse); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vbslq_f16(vcltq_f16(vin0, vin1), vtrue, vfalse); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = (float16_t)(in0 < in1); - } + for (int i = 0; i < C8NUM; ++i) { + output[i] = (float16_t)(in0_opt < input1[i]); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - output[index] = (float16_t)(in0 < in1); + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float16_t)(in0_opt < input1[index]); + } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vbslq_f16(vcltq_f16(vin0, vin1), vtrue, vfalse); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = (float16_t)(input0[i] < in1_opt); + } +#endif + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float16_t)(input0[index] < in1_opt); + } } return NNACL_OK; } @@ -1546,35 +1734,50 @@ int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *out ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vbslq_f16(vcleq_f16(vin0, vin1), vtrue, vfalse); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vbslq_f16(vcleq_f16(vin0, vin1), vtrue, vfalse); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = (float16_t)(in0 <= in1); - } + for (int i = 0; i < C8NUM; ++i) { + output[i] = (float16_t)(in0_opt <= input1[i]); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - output[index] = (float16_t)(in0 <= in1); + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float16_t)(in0_opt <= input1[index]); + } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vbslq_f16(vcleq_f16(vin0, vin1), vtrue, vfalse); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = (float16_t)(input0[i] <= in1_opt); + } +#endif + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float16_t)(input0[index] <= in1_opt); + } } return NNACL_OK; } @@ -1610,35 +1813,50 @@ int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, float16_t *outpu ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vbslq_f16(vcgtq_f16(vin0, vin1), vtrue, vfalse); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vbslq_f16(vcgtq_f16(vin0, vin1), vtrue, vfalse); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = (float16_t)(in0 > in1); - } + for (int i = 0; i < C8NUM; ++i) { + output[i] = (float16_t)(in0_opt > input1[i]); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - output[index] = (float16_t)(in0 > in1); + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float16_t)(in0_opt > input1[index]); + } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vbslq_f16(vcgtq_f16(vin0, vin1), vtrue, vfalse); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = (float16_t)(input0[i] > in1_opt); + } +#endif + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float16_t)(input0[index] > in1_opt); + } } return NNACL_OK; } @@ -1674,35 +1892,50 @@ int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t * ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vbslq_f16(vcgeq_f16(vin0, vin1), vtrue, vfalse); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vbslq_f16(vcgeq_f16(vin0, vin1), vtrue, vfalse); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = (float16_t)(in0 >= in1); - } + for (int i = 0; i < C8NUM; ++i) { + output[i] = (float16_t)(in0_opt >= input1[i]); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - output[index] = (float16_t)(in0 >= in1); + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float16_t)(in0_opt >= input1[index]); + } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vbslq_f16(vcgeq_f16(vin0, vin1), vtrue, vfalse); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = (float16_t)(input0[i] >= in1_opt); + } +#endif + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float16_t)(input0[index] >= in1_opt); + } } return NNACL_OK; }