!5907 optimize op maximum minimum greater

Merge pull request !5907 from 陶云浩/lite
This commit is contained in:
mindspore-ci-bot 2020-09-11 14:52:35 +08:00 committed by Gitee
commit a2002da77c
2 changed files with 39 additions and 43 deletions

View File

@ -1321,19 +1321,14 @@ int ElementOptSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float1
}
int ElementMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
#ifdef ENABLE_NEON
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vmaxq_f16(vin0, vin1);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMAX(input0[i], input1[i]);
}
#endif
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
@ -1341,6 +1336,11 @@ int ElementMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output,
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMAX(input0[index], input1[index]);
}
#else
for (int index = 0; index < element_size; ++index) {
output[index] = MSMAX(input0[index], input1[index]);
}
#endif
return NNACL_OK;
}
int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
@ -1394,19 +1394,14 @@ int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *outpu
}
int ElementMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
#ifdef ENABLE_NEON
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vminq_f16(vin0, vin1);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMIN(input0[i], input1[i]);
}
#endif
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
@ -1414,6 +1409,11 @@ int ElementMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output,
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(input0[index], input1[index]);
}
#else
for (int index = 0; index < element_size; ++index) {
output[index] = MSMIN(input0[index], input1[index]);
}
#endif
return NNACL_OK;
}
int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
@ -1783,23 +1783,18 @@ int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *out
}
int ElementGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
#ifdef ENABLE_NEON
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
#ifdef ENABLE_NEON
float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1};
float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0};
#endif
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vbslq_f16(vcgtq_f16(vin0, vin1), vtrue, vfalse);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = (float16_t)(input0[i] > input1[i]);
}
#endif
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
@ -1807,6 +1802,11 @@ int ElementGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output,
for (int index = 0; index < block_mod; ++index) {
output[index] = (float16_t)(input0[index] > input1[index]);
}
#else
for (int index = 0; index < element_size; ++index) {
output[index] = (float16_t)(input0[index] > input1[index]);
}
#endif
return NNACL_OK;
}
int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,

View File

@ -997,21 +997,15 @@ int BroadcastLogicalOr(float *input0, float *input1, float *tile_input0, float *
}
int ElementMaximum(float *input0, float *input1, float *output, int element_size) {
#ifdef ENABLE_NEON
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vmaxq_f32(vin0, vin1);
vst1q_f32(output, vout);
#else
output[0] = input0[0] > input1[0] ? input0[0] : input1[0];
output[1] = input0[1] > input1[1] ? input0[1] : input1[1];
output[2] = input0[2] > input1[2] ? input0[2] : input1[2];
output[3] = input0[3] > input1[3] ? input0[3] : input1[3];
#endif
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
@ -1019,6 +1013,11 @@ int ElementMaximum(float *input0, float *input1, float *output, int element_size
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] > input1[index] ? input0[index] : input1[index];
}
#else
for (int index = 0; index < element_size; ++index) {
output[index] = MSMAX(input0[index], input1[index]);
}
#endif
return NNACL_OK;
}
@ -1029,21 +1028,15 @@ int BroadcastMaximum(float *input0, float *input1, float *tile_input0, float *ti
}
int ElementMinimum(float *input0, float *input1, float *output, int element_size) {
#ifdef ENABLE_NEON
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vminq_f32(vin0, vin1);
vst1q_f32(output, vout);
#else
output[0] = input0[0] > input1[0] ? input1[0] : input0[0];
output[1] = input0[1] > input1[1] ? input1[1] : input0[1];
output[2] = input0[2] > input1[2] ? input1[2] : input0[2];
output[3] = input0[3] > input1[3] ? input1[3] : input0[3];
#endif
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
@ -1051,6 +1044,11 @@ int ElementMinimum(float *input0, float *input1, float *output, int element_size
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] > input1[index] ? input1[index] : input0[index];
}
#else
for (int index = 0; index < element_size; ++index) {
output[index] = MSMIN(input0[index], input1[index]);
}
#endif
return NNACL_OK;
}
@ -1217,24 +1215,17 @@ int BroadcastLessEqual(float *input0, float *input1, float *tile_input0, float *
}
int ElementGreater(float *input0, float *input1, float *output, int element_size) {
#ifdef ENABLE_NEON
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
#ifdef ENABLE_NEON
float32x4_t vtrue = {1, 1, 1, 1};
float32x4_t vfalse = {0, 0, 0, 0};
#endif
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vbslq_f32(vcgtq_f32(vin0, vin1), vtrue, vfalse);
vst1q_f32(output, vout);
#else
output[0] = (float)(input0[0] > input1[0]);
output[1] = (float)(input0[1] > input1[1]);
output[2] = (float)(input0[2] > input1[2]);
output[3] = (float)(input0[3] > input1[3]);
#endif
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
@ -1242,6 +1233,11 @@ int ElementGreater(float *input0, float *input1, float *output, int element_size
for (int index = 0; index < block_mod; ++index) {
output[index] = (float)(input0[index] > input1[index]);
}
#else
for (int index = 0; index < element_size; ++index) {
output[index] = (float)(input0[index] > input1[index]);
}
#endif
return NNACL_OK;
}