!9595 [MS][LITE]Fix arithmetic bug

From: @gongdaguo
Reviewed-by: @zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2020-12-08 14:16:40 +08:00 committed by Gitee
commit e7462680a9
3 changed files with 23 additions and 17 deletions

View File

@ -83,22 +83,22 @@ int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
return RET_OK; return RET_OK;
} }
if (data_type_ == kDataTypeFloat) { if (data_type_ == kDataTypeFloat) {
error_code = BroadcastRun( error_code =
reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(in_tensors_[1]->data_c()), BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
} else { } else {
error_code = BroadcastRun( error_code =
reinterpret_cast<int *>(in_tensors_[0]->data_c()), reinterpret_cast<int *>(in_tensors_[1]->data_c()), BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_),
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
} }
} else { // no broadcast, neither is scalar, two same shape } else { // no broadcast, neither is scalar, two same shape
if (data_type_ == kDataTypeFloat) { if (data_type_ == kDataTypeFloat) {
error_code = func_fp32_(reinterpret_cast<float *>(in_tensors_[0]->data_c()) + stride * task_id, error_code = func_fp32_(reinterpret_cast<float *>(input0_ptr_) + stride * task_id,
reinterpret_cast<float *>(in_tensors_[1]->data_c()) + stride * task_id, reinterpret_cast<float *>(input1_ptr_) + stride * task_id,
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count); reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
} else { } else {
error_code = func_int32_(reinterpret_cast<int *>(in_tensors_[0]->data_c()) + stride * task_id, error_code = func_int32_(reinterpret_cast<int *>(input0_ptr_) + stride * task_id,
reinterpret_cast<int *>(in_tensors_[1]->data_c()) + stride * task_id, reinterpret_cast<int *>(input1_ptr_) + stride * task_id,
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count); reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
} }
} }

View File

@ -51,8 +51,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() {
return RET_OK; return RET_OK;
} }
int broadcast_size = out_tensors_[0]->Size(); if (out_tensors_[0]->Size() < 0) {
if (broadcast_size < 0) {
return RET_OK; return RET_OK;
} }
@ -69,7 +68,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() {
if (in_tensors_[0]->data_c() != nullptr && if (in_tensors_[0]->data_c() != nullptr &&
arithmeticParameter_->in_elements_num1_ == arithmeticParameter_->out_elements_num_) { arithmeticParameter_->in_elements_num1_ == arithmeticParameter_->out_elements_num_) {
input0_ptr_ = malloc(broadcast_size); input0_ptr_ = malloc(arithmeticParameter_->out_elements_num_ * sizeof(float));
if (input0_ptr_ == nullptr) { if (input0_ptr_ == nullptr) {
return RET_ERROR; return RET_ERROR;
} }
@ -81,7 +80,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() {
} }
if (in_tensors_[1]->data_c() != nullptr && if (in_tensors_[1]->data_c() != nullptr &&
arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->out_elements_num_) { arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->out_elements_num_) {
input1_ptr_ = malloc(broadcast_size); input1_ptr_ = malloc(arithmeticParameter_->out_elements_num_ * sizeof(float));
if (input1_ptr_ == nullptr) { if (input1_ptr_ == nullptr) {
FreeTmpPtr(); FreeTmpPtr();
return RET_ERROR; return RET_ERROR;
@ -212,6 +211,15 @@ void ArithmeticCPUKernel::InitRunFunction() {
case PrimitiveType_SquaredDifference: case PrimitiveType_SquaredDifference:
arithmetic_run_ = ElementSquaredDifference; arithmetic_run_ = ElementSquaredDifference;
break; break;
case PrimitiveType_Equal:
case PrimitiveType_Less:
case PrimitiveType_Greater:
case PrimitiveType_NotEqual:
case PrimitiveType_LessEqual:
case PrimitiveType_GreaterEqual:
arithmetic_run_ = nullptr;
arithmetic_run_int_ = nullptr;
break;
default: default:
MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_; MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
arithmetic_run_ = nullptr; arithmetic_run_ = nullptr;

View File

@ -75,13 +75,11 @@ class ArithmeticCPUKernel : public LiteKernel {
int InitBroadCastCase(); int InitBroadCastCase();
void InitParamInRunTime(); void InitParamInRunTime();
private: protected:
bool input0_broadcast_ = false; bool input0_broadcast_ = false;
bool input1_broadcast_ = false; bool input1_broadcast_ = false;
void *input0_ptr_ = nullptr; void *input0_ptr_ = nullptr;
void *input1_ptr_ = nullptr; void *input1_ptr_ = nullptr;
protected:
int break_pos_ = 0; int break_pos_ = 0;
int outside_ = 0; int outside_ = 0;
int thread_count_ = 1; int thread_count_ = 1;