!6413 [MSLITE][Develop] fix reduce relu6

Merge pull request !6413 from sunsuodong/fix_reduce_relu6
This commit is contained in:
mindspore-ci-bot 2020-09-17 20:35:38 +08:00 committed by Gitee
commit e6c738d3b6
2 changed files with 16 additions and 26 deletions

View File

@ -34,28 +34,20 @@ int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) {
}
int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num) {
int eight_block = UP_DIV(ele_num, C8NUM);
int i;
for (i = 0; i < eight_block - 1; i++) {
int index = i * C8NUM;
int offset = 0;
#ifdef ENABLE_NEON
float16x8_t relu6_data = vld1q_f16(data + index);
float16x8_t zero_data = vdupq_n_f16(0);
float16x8_t six_data = vdupq_n_f16(6);
float16x8_t zero_data = vdupq_n_f16(0);
float16x8_t six_data = vdupq_n_f16(6);
for (; offset <= ele_num - C8NUM; offset += C8NUM) {
float16x8_t relu6_data = vld1q_f16(data + offset);
relu6_data = vmaxq_f16(relu6_data, zero_data);
relu6_data = vminq_f16(relu6_data, six_data);
vst1q_f16(dst + index, relu6_data);
#else
int j;
for (j = 0; j < C8NUM; ++j) {
dst[index + j] = data[index + j] < 0 ? 0 : data[index + j];
dst[index + j] = dst[index + j] > 6 ? 6 : dst[index + j];
}
#endif
vst1q_f16(dst + offset, relu6_data);
}
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
dst[j] = data[j] < 0 ? 0 : data[j];
dst[j] = dst[j] > 6 ? 6 : dst[j];
#endif
for (; offset < ele_num; offset++) {
dst[offset] = data[offset] < 0 ? 0 : data[offset];
dst[offset] = dst[offset] > 6 ? 6 : dst[offset];
}
return NNACL_OK;
}

View File

@ -82,14 +82,7 @@ int ReduceCPUKernel::Init() {
return ReSize();
}
int ReduceCPUKernel::ReSize() {
if (in_tensors().at(0)->data_type() == kNumberTypeFloat32) {
data_type_ = kDataTypeFloat;
} else {
data_type_ = kDataTypeInt;
}
return ReduceBaseCPUKernel::ReSize();
}
int ReduceCPUKernel::ReSize() { return ReduceBaseCPUKernel::ReSize(); }
int ReduceCPUKernel::CallReduceUnit(int task_id) {
int ret;
@ -120,6 +113,11 @@ int ReduceCPUKernel::Run() {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
if (in_tensors().at(0)->data_type() == kNumberTypeFloat32) {
data_type_ = kDataTypeFloat;
} else {
data_type_ = kDataTypeInt;
}
auto ret = MallocTmpBuffer();
if (ret != RET_OK) {
FreeTmpBuffer();