forked from OSSInnovation/mindspore
!6413 [MSLITE][Develop] fix reduce relu6
Merge pull request !6413 from sunsuodong/fix_reduce_relu6
This commit is contained in:
commit
e6c738d3b6
|
@ -34,28 +34,20 @@ int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
|||
}
|
||||
|
||||
int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num) {
|
||||
int eight_block = UP_DIV(ele_num, C8NUM);
|
||||
int i;
|
||||
for (i = 0; i < eight_block - 1; i++) {
|
||||
int index = i * C8NUM;
|
||||
int offset = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t relu6_data = vld1q_f16(data + index);
|
||||
float16x8_t zero_data = vdupq_n_f16(0);
|
||||
float16x8_t six_data = vdupq_n_f16(6);
|
||||
float16x8_t zero_data = vdupq_n_f16(0);
|
||||
float16x8_t six_data = vdupq_n_f16(6);
|
||||
for (; offset <= ele_num - C8NUM; offset += C8NUM) {
|
||||
float16x8_t relu6_data = vld1q_f16(data + offset);
|
||||
relu6_data = vmaxq_f16(relu6_data, zero_data);
|
||||
relu6_data = vminq_f16(relu6_data, six_data);
|
||||
vst1q_f16(dst + index, relu6_data);
|
||||
#else
|
||||
int j;
|
||||
for (j = 0; j < C8NUM; ++j) {
|
||||
dst[index + j] = data[index + j] < 0 ? 0 : data[index + j];
|
||||
dst[index + j] = dst[index + j] > 6 ? 6 : dst[index + j];
|
||||
}
|
||||
#endif
|
||||
vst1q_f16(dst + offset, relu6_data);
|
||||
}
|
||||
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
|
||||
dst[j] = data[j] < 0 ? 0 : data[j];
|
||||
dst[j] = dst[j] > 6 ? 6 : dst[j];
|
||||
#endif
|
||||
for (; offset < ele_num; offset++) {
|
||||
dst[offset] = data[offset] < 0 ? 0 : data[offset];
|
||||
dst[offset] = dst[offset] > 6 ? 6 : dst[offset];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -82,14 +82,7 @@ int ReduceCPUKernel::Init() {
|
|||
return ReSize();
|
||||
}
|
||||
|
||||
int ReduceCPUKernel::ReSize() {
|
||||
if (in_tensors().at(0)->data_type() == kNumberTypeFloat32) {
|
||||
data_type_ = kDataTypeFloat;
|
||||
} else {
|
||||
data_type_ = kDataTypeInt;
|
||||
}
|
||||
return ReduceBaseCPUKernel::ReSize();
|
||||
}
|
||||
int ReduceCPUKernel::ReSize() { return ReduceBaseCPUKernel::ReSize(); }
|
||||
|
||||
int ReduceCPUKernel::CallReduceUnit(int task_id) {
|
||||
int ret;
|
||||
|
@ -120,6 +113,11 @@ int ReduceCPUKernel::Run() {
|
|||
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
|
||||
return prepare_ret;
|
||||
}
|
||||
if (in_tensors().at(0)->data_type() == kNumberTypeFloat32) {
|
||||
data_type_ = kDataTypeFloat;
|
||||
} else {
|
||||
data_type_ = kDataTypeInt;
|
||||
}
|
||||
auto ret = MallocTmpBuffer();
|
||||
if (ret != RET_OK) {
|
||||
FreeTmpBuffer();
|
||||
|
|
Loading…
Reference in New Issue