forked from mindspore-Ecosystem/mindspore
!20097 [MS][LITE][CPU] fill 算子arm bug修复
Merge pull request !20097 from liuzhongkai/fill1
This commit is contained in:
commit
b0831fa50f
|
@ -16,7 +16,7 @@
|
|||
|
||||
#include "nnacl/fp16/fill_fp16.h"
|
||||
|
||||
int FillFp16(float16_t *output, int size, float16_t data) {
|
||||
inline int FillFp16(float16_t *output, int size, float16_t data) {
|
||||
for (int i = 0; i < size; ++i) {
|
||||
output[i] = data;
|
||||
}
|
||||
|
|
|
@ -49,13 +49,8 @@ int FillFp16CPUKernel::DoFill(int task_id) {
|
|||
return RET_OK;
|
||||
}
|
||||
int offset = task_id * thread_sz_stride_;
|
||||
auto input_tensor = in_tensors_.at(0);
|
||||
int ret = RET_OK;
|
||||
if (input_tensor->data_type() == kNumberTypeFloat16) {
|
||||
ret = FillFp16(fp16_out_ptr_ + offset, size, fp16_src_data_);
|
||||
} else {
|
||||
return RET_ERROR;
|
||||
}
|
||||
ret = FillFp16(fp16_out_ptr_ + offset, size, fp16_src_data_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "FillRun error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
return ret;
|
||||
|
@ -76,15 +71,10 @@ int FillRunFp16(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
|
|||
int FillFp16CPUKernel::Run() {
|
||||
auto fill_input = in_tensors_.front();
|
||||
auto output = out_tensors_.front();
|
||||
if (fill_input->data_type() == kNumberTypeFloat16) {
|
||||
auto fill_data = reinterpret_cast<float16_t *>(fill_input->MutableData());
|
||||
fp16_src_data_ = fill_data[0];
|
||||
fp16_out_ptr_ = reinterpret_cast<float16_t *>(output->MutableData());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported fill data type " << fill_input->data_type();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret = ParallelLaunch(this->context_, FillRunFp16, this, thread_sz_count_);
|
||||
auto fill_data = reinterpret_cast<float16_t *>(fill_input->MutableData());
|
||||
fp16_src_data_ = fill_data[0];
|
||||
fp16_out_ptr_ = reinterpret_cast<float16_t *>(output->MutableData());
|
||||
auto ret = ParallelLaunch(this->ms_context_, FillRunFp16, this, thread_sz_count_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "FillRun error error_code[" << ret << "]";
|
||||
return ret;
|
||||
|
|
Loading…
Reference in New Issue