!20097 [MS][LITE][CPU] fill 算子arm bug修复

Merge pull request !20097 from liuzhongkai/fill1
This commit is contained in:
i-robot 2021-07-13 00:25:54 +00:00 committed by Gitee
commit b0831fa50f
2 changed files with 6 additions and 16 deletions

View File

@@ -16,7 +16,7 @@
#include "nnacl/fp16/fill_fp16.h"
int FillFp16(float16_t *output, int size, float16_t data) {
inline int FillFp16(float16_t *output, int size, float16_t data) {
for (int i = 0; i < size; ++i) {
output[i] = data;
}

View File

@@ -49,13 +49,8 @@ int FillFp16CPUKernel::DoFill(int task_id) {
return RET_OK;
}
int offset = task_id * thread_sz_stride_;
auto input_tensor = in_tensors_.at(0);
int ret = RET_OK;
if (input_tensor->data_type() == kNumberTypeFloat16) {
ret = FillFp16(fp16_out_ptr_ + offset, size, fp16_src_data_);
} else {
return RET_ERROR;
}
ret = FillFp16(fp16_out_ptr_ + offset, size, fp16_src_data_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "FillRun error task_id[" << task_id << "] error_code[" << ret << "]";
return ret;
@@ -76,15 +71,10 @@ int FillRunFp16(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
int FillFp16CPUKernel::Run() {
auto fill_input = in_tensors_.front();
auto output = out_tensors_.front();
if (fill_input->data_type() == kNumberTypeFloat16) {
auto fill_data = reinterpret_cast<float16_t *>(fill_input->MutableData());
fp16_src_data_ = fill_data[0];
fp16_out_ptr_ = reinterpret_cast<float16_t *>(output->MutableData());
} else {
MS_LOG(ERROR) << "unsupported fill data type " << fill_input->data_type();
return RET_ERROR;
}
auto ret = ParallelLaunch(this->context_, FillRunFp16, this, thread_sz_count_);
auto fill_data = reinterpret_cast<float16_t *>(fill_input->MutableData());
fp16_src_data_ = fill_data[0];
fp16_out_ptr_ = reinterpret_cast<float16_t *>(output->MutableData());
auto ret = ParallelLaunch(this->ms_context_, FillRunFp16, this, thread_sz_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "FillRun error error_code[" << ret << "]";
return ret;