!5427 modify arm cpu fp16 op: arithmetic

Merge pull request !5427 from 陶云浩/master
This commit is contained in:
mindspore-ci-bot 2020-09-04 09:07:16 +08:00 committed by Gitee
commit 4499d126d6
1 changed files with 574 additions and 341 deletions

View File

@ -580,27 +580,42 @@ int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, i
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
#endif #endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vsubq_f16(vin0, vin1); float16x8_t vout = vsubq_f16(vin0, vin1);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; output[i] = in0_opt - input1[i];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = in0 - in1;
} }
#endif #endif
input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; output[index] = in0_opt - input1[index];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; }
output[index] = in0 - in1; } else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vsubq_f16(vin0, vin1);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = input0[i] - in1_opt;
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] - in1_opt;
}
} }
return NNACL_OK; return NNACL_OK;
} }
@ -644,30 +659,47 @@ int ElementOptSubReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
#endif #endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vsubq_f16(vin0, vin1); float16x8_t vout = vsubq_f16(vin0, vin1);
vout = vmaxq_f16(vout, zeros); vout = vmaxq_f16(vout, zeros);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; output[i] = MSMAX(in0_opt - input1[i], 0);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = MSMAX(in0 - in1, 0);
} }
#endif #endif
input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; float16_t res = in0_opt - input1[index];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
float16_t res = in0 - in1;
output[index] = res > 0 ? res : 0; output[index] = res > 0 ? res : 0;
} }
} else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vsubq_f16(vin0, vin1);
vout = vmaxq_f16(vout, zeros);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMAX(input0[i] - in1_opt, 0);
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t res = input0[index] - in1_opt;
output[index] = res > 0 ? res : 0;
}
}
return NNACL_OK; return NNACL_OK;
} }
@ -712,30 +744,45 @@ int ElementOptSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6};
#endif #endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vsubq_f16(vin0, vin1); float16x8_t vout = vsubq_f16(vin0, vin1);
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; output[i] = MSMIN(MSMAX(in0_opt - input1[i], 0), 6);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = MSMIN(MSMAX(in0 - in1, 0), 6);
} }
#endif #endif
input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; output[index] = MSMIN(MSMAX(in0_opt - input1[index], 0), 6);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; }
output[index] = MSMIN(MSMAX(in0 - in1, 0), 6); } else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vsubq_f16(vin0, vin1);
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMIN(MSMAX(input0[i] - in1_opt, 0), 6);
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(input0[index] - in1_opt, 0), 6);
}
} }
return NNACL_OK; return NNACL_OK;
} }
@ -781,41 +828,53 @@ int ElementOptDivFp16(float16_t *input0, float16_t *input1, float16_t *output, i
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
#endif #endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
if (param->in_elements_num1_ == 1) {
if (in1_opt == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
} else {
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
if (input1[i] == 0) { if (input1[i] == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO; return NNACL_ERRCODE_DIVISOR_ZERO;
} }
} }
}
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vdivq_f16(vin0, vin1); float16x8_t vout = vdivq_f16(vin0, vin1);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; output[i] = in0_opt / input1[i];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = in0 / in1;
} }
#endif #endif
input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; if (input1[index] == 0) {
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
if (in1 == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO; return NNACL_ERRCODE_DIVISOR_ZERO;
} }
output[index] = in0 / in1; output[index] = in0_opt / input1[index];
}
} else {
if (in1_opt == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vdivq_f16(vin0, vin1);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = input0[i] / in1_opt;
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] / in1_opt;
}
} }
return NNACL_OK; return NNACL_OK;
} }
@ -867,43 +926,53 @@ int ElementOptDivReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
#endif #endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
if (param->in_elements_num1_ == 1) {
if (in1_opt == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
} else {
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
if (input1[i] == 0) { if (input1[i] == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO; return NNACL_ERRCODE_DIVISOR_ZERO;
} }
} }
}
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vdivq_f16(vin0, vin1); float16x8_t vout = vmaxq_f16(vdivq_f16(vin0, vin1), zeros);
vout = vmaxq_f16(vout, zeros);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; output[i] = MSMAX(in0_opt / input1[i], 0);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = MSMAX(in0 / in1, 0);
} }
#endif #endif
input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; if (input1[index] == 0) {
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
if (in1 == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO; return NNACL_ERRCODE_DIVISOR_ZERO;
} }
float16_t res = in0 / in1; output[index] = MSMAX(in0_opt / input1[index], 0);
output[index] = res > 0 ? res : 0; }
} else {
if (in1_opt == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vmaxq_f16(vdivq_f16(vin0, vin1), zeros);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMAX(input0[i] / in1_opt, 0);
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMAX(input0[index] / in1_opt, 0);
}
} }
return NNACL_OK; return NNACL_OK;
} }
@ -948,7 +1017,6 @@ int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0]; float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0]; float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
@ -957,42 +1025,53 @@ int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6};
#endif #endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
if (param->in_elements_num1_ == 1) {
if (in1_opt == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
} else {
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
if (input1[i] == 0) { if (input1[i] == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO; return NNACL_ERRCODE_DIVISOR_ZERO;
} }
} }
}
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vdivq_f16(vin0, vin1); float16x8_t vout = vminq_f16(vmaxq_f16(vdivq_f16(vin0, vin1), zeros), bounds);
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; output[i] = MSMIN(MSMAX(in0_opt / input1[i], 0), 6);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = MSMIN(MSMAX(in0 / in1, 0), 6);
} }
#endif #endif
input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; if (input1[index] == 0) {
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
if (in1 == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO; return NNACL_ERRCODE_DIVISOR_ZERO;
} }
output[index] = MSMIN(MSMAX(in0 / in1, 0), 6); output[index] = MSMIN(MSMAX(in0_opt / input1[index], 0), 6);
}
} else {
if (in1_opt == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vminq_f16(vmaxq_f16(vdivq_f16(vin0, vin1), zeros), bounds);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMIN(MSMAX(input0[i] / in1_opt, 0), 6);
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(input0[index] / in1_opt, 0), 6);
}
} }
return NNACL_OK; return NNACL_OK;
} }
@ -1089,39 +1168,56 @@ int ElementOptLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *ou
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1};
float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0};
uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1));
uint16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; uint16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
#endif #endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_ = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); float16x8_t vin0_ = vin0_opt;
float16x8_t vin1_ = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); float16x8_t vin1_ = vld1q_f16(input1);
uint16x8_t vin0 = vandq_u16(vreinterpretq_s16_f16(vin0_), mask); uint16x8_t vin0 = vandq_u16(vreinterpretq_s16_f16(vin0_), mask);
uint16x8_t vin1 = vandq_u16(vreinterpretq_s16_f16(vin1_), mask); uint16x8_t vin1 = vandq_u16(vreinterpretq_s16_f16(vin1_), mask);
float16x8_t vout = vbslq_f16(vceqq_u16(vandq_u16(vin0, vin1), zeros), vfalse, vtrue); float16x8_t vout = vbslq_f16(vceqq_u16(vandq_u16(vin0, vin1), zeros), vfalse, vtrue);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; output[i] = (float16_t)((bool)(in0_opt) & (bool)(input1[i]));
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = (float16_t)((bool)(in0) & (bool)(in1));
} }
#endif #endif
input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; output[index] = (float16_t)((bool)(in0_opt) & (bool)(input1[index]));
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; }
output[index] = (float16_t)((bool)(in0) & (bool)(in1)); } else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0_ = vld1q_f16(input0);
float16x8_t vin1_ = vin1_opt;
uint16x8_t vin0 = vandq_u16(vreinterpretq_s16_f16(vin0_), mask);
uint16x8_t vin1 = vandq_u16(vreinterpretq_s16_f16(vin1_), mask);
float16x8_t vout = vbslq_f16(vceqq_u16(vandq_u16(vin0, vin1), zeros), vfalse, vtrue);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = (float16_t)((bool)(input0[i]) & (bool)(in1_opt));
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = (float16_t)((bool)(input0[index]) & (bool)(in1_opt));
}
} }
return NNACL_OK; return NNACL_OK;
} }
@ -1160,39 +1256,56 @@ int ElementOptLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *out
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1};
float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0};
uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1));
uint16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; uint16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
#endif #endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_ = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); float16x8_t vin0_ = vin0_opt;
float16x8_t vin1_ = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); float16x8_t vin1_ = vld1q_f16(input1);
uint16x8_t vin0 = vandq_u16(vreinterpretq_s16_f16(vin0_), mask); uint16x8_t vin0 = vandq_u16(vreinterpretq_s16_f16(vin0_), mask);
uint16x8_t vin1 = vandq_u16(vreinterpretq_s16_f16(vin1_), mask); uint16x8_t vin1 = vandq_u16(vreinterpretq_s16_f16(vin1_), mask);
float16x8_t vout = vbslq_f16(vceqq_u16(vorrq_u16(vin0, vin1), zeros), vfalse, vtrue); float16x8_t vout = vbslq_f16(vceqq_u16(vorrq_u16(vin0, vin1), zeros), vfalse, vtrue);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; output[i] = (float16_t)((bool)(in0_opt) | (bool)(input1[i]));
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = (float16_t)((bool)(in0) | (bool)(in1));
} }
#endif #endif
input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; output[index] = (float16_t)((bool)(in0_opt) | (bool)(input1[index]));
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; }
output[index] = (float16_t)((bool)(in0) | (bool)(in1)); } else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0_ = vld1q_f16(input0);
float16x8_t vin1_ = vin1_opt;
uint16x8_t vin0 = vandq_u16(vreinterpretq_s16_f16(vin0_), mask);
uint16x8_t vin1 = vandq_u16(vreinterpretq_s16_f16(vin1_), mask);
float16x8_t vout = vbslq_f16(vceqq_u16(vorrq_u16(vin0, vin1), zeros), vfalse, vtrue);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = (float16_t)((bool)(input0[i]) | (bool)(in1_opt));
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = (float16_t)((bool)(input0[index]) | (bool)(in1_opt));
}
} }
return NNACL_OK; return NNACL_OK;
} }
@ -1234,33 +1347,48 @@ int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *outpu
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#endif #endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vmaxq_f16(vin0, vin1); float16x8_t vout = vmaxq_f16(vin0, vin1);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; output[i] = MSMAX(in0_opt, input1[i]);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = MSMAX(in0, in1);
} }
#endif #endif
input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; output[index] = MSMAX(in0_opt, input1[index]);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; }
output[index] = MSMAX(in0, in1); } else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vmaxq_f16(vin0, vin1);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMAX(input0[i], in1_opt);
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMAX(input0[index], in1_opt);
}
} }
return NNACL_OK; return NNACL_OK;
} }
@ -1292,33 +1420,48 @@ int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *outpu
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#endif #endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vminq_f16(vin0, vin1); float16x8_t vout = vminq_f16(vin0, vin1);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; output[i] = MSMIN(in0_opt, input1[i]);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = MSMIN(in0, in1);
} }
#endif #endif
input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; output[index] = MSMIN(in0_opt, input1[index]);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; }
output[index] = MSMIN(in0, in1); } else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vminq_f16(vin0, vin1);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMIN(input0[i], in1_opt);
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(input0[index], in1_opt);
}
} }
return NNACL_OK; return NNACL_OK;
} }
@ -1354,35 +1497,50 @@ int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *outp
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1};
float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0};
#endif #endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vfalse, vtrue); float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vfalse, vtrue);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; output[i] = (float16_t)(in0_opt != input1[i]);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = (float16_t)(in0 != in1);
} }
#endif #endif
input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; output[index] = (float16_t)(in0_opt != input1[index]);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; }
output[index] = (float16_t)(in0 != in1); } else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vfalse, vtrue);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = (float16_t)(input0[i] != in1_opt);
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = (float16_t)(input0[index] != in1_opt);
}
} }
return NNACL_OK; return NNACL_OK;
} }
@ -1418,35 +1576,50 @@ int ElementOptEqualFp16(float16_t *input0, float16_t *input1, float16_t *output,
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1};
float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0};
#endif #endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vtrue, vfalse); float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vtrue, vfalse);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; output[i] = (float16_t)(in0_opt == input1[i]);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = (float16_t)(in0 == in1);
} }
#endif #endif
input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; output[index] = (float16_t)(in0_opt == input1[index]);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; }
output[index] = (float16_t)(in0 == in1); } else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vbslq_f16(vceqq_f16(vin0, vin1), vtrue, vfalse);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = (float16_t)(input0[i] == in1_opt);
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = (float16_t)(input0[index] == in1_opt);
}
} }
return NNACL_OK; return NNACL_OK;
} }
@ -1482,35 +1655,50 @@ int ElementOptLessFp16(float16_t *input0, float16_t *input1, float16_t *output,
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1};
float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0};
#endif #endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vbslq_f16(vcltq_f16(vin0, vin1), vtrue, vfalse); float16x8_t vout = vbslq_f16(vcltq_f16(vin0, vin1), vtrue, vfalse);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; output[i] = (float16_t)(in0_opt < input1[i]);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = (float16_t)(in0 < in1);
} }
#endif #endif
input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; output[index] = (float16_t)(in0_opt < input1[index]);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; }
output[index] = (float16_t)(in0 < in1); } else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vbslq_f16(vcltq_f16(vin0, vin1), vtrue, vfalse);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = (float16_t)(input0[i] < in1_opt);
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = (float16_t)(input0[index] < in1_opt);
}
} }
return NNACL_OK; return NNACL_OK;
} }
@ -1546,35 +1734,50 @@ int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *out
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1};
float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0};
#endif #endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vbslq_f16(vcleq_f16(vin0, vin1), vtrue, vfalse); float16x8_t vout = vbslq_f16(vcleq_f16(vin0, vin1), vtrue, vfalse);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; output[i] = (float16_t)(in0_opt <= input1[i]);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = (float16_t)(in0 <= in1);
} }
#endif #endif
input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; output[index] = (float16_t)(in0_opt <= input1[index]);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; }
output[index] = (float16_t)(in0 <= in1); } else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vbslq_f16(vcleq_f16(vin0, vin1), vtrue, vfalse);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = (float16_t)(input0[i] <= in1_opt);
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = (float16_t)(input0[index] <= in1_opt);
}
} }
return NNACL_OK; return NNACL_OK;
} }
@ -1610,35 +1813,50 @@ int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, float16_t *outpu
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1};
float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0};
#endif #endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vbslq_f16(vcgtq_f16(vin0, vin1), vtrue, vfalse); float16x8_t vout = vbslq_f16(vcgtq_f16(vin0, vin1), vtrue, vfalse);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; output[i] = (float16_t)(in0_opt > input1[i]);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = (float16_t)(in0 > in1);
} }
#endif #endif
input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; output[index] = (float16_t)(in0_opt > input1[index]);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; }
output[index] = (float16_t)(in0 > in1); } else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vbslq_f16(vcgtq_f16(vin0, vin1), vtrue, vfalse);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = (float16_t)(input0[i] > in1_opt);
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = (float16_t)(input0[index] > in1_opt);
}
} }
return NNACL_OK; return NNACL_OK;
} }
@ -1674,35 +1892,50 @@ int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *
ArithmeticParameter *param) { ArithmeticParameter *param) {
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1};
float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0};
#endif #endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON #ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vbslq_f16(vcgeq_f16(vin0, vin1), vtrue, vfalse); float16x8_t vout = vbslq_f16(vcgeq_f16(vin0, vin1), vtrue, vfalse);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else #else
for (int i = 0; i < C8NUM; ++i) { for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; output[i] = (float16_t)(in0_opt >= input1[i]);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = (float16_t)(in0 >= in1);
} }
#endif #endif
input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
} }
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; output[index] = (float16_t)(in0_opt >= input1[index]);
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; }
output[index] = (float16_t)(in0 >= in1); } else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vbslq_f16(vcgeq_f16(vin0, vin1), vtrue, vfalse);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = (float16_t)(input0[i] >= in1_opt);
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = (float16_t)(input0[index] >= in1_opt);
}
} }
return NNACL_OK; return NNACL_OK;
} }