From 29f75e241794f6cd04718f26961dceb524dcd2d9 Mon Sep 17 00:00:00 2001 From: gongdaguo Date: Mon, 7 Dec 2020 16:37:02 +0800 Subject: [PATCH] fix arithmetic bug --- .../arm/fp32/arithmetic_compare_fp32.cc | 20 +++++++++---------- .../kernel/arm/fp32/arithmetic_fp32.cc | 16 +++++++++++---- .../runtime/kernel/arm/fp32/arithmetic_fp32.h | 4 +--- 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc index d3b7238c9a7..ffa69ecc122 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc @@ -83,22 +83,22 @@ int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) { return RET_OK; } if (data_type_ == kDataTypeFloat) { - error_code = BroadcastRun( - reinterpret_cast(in_tensors_[0]->data_c()), reinterpret_cast(in_tensors_[1]->data_c()), - reinterpret_cast(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); + error_code = + BroadcastRun(reinterpret_cast(input0_ptr_), reinterpret_cast(input1_ptr_), + reinterpret_cast(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); } else { - error_code = BroadcastRun( - reinterpret_cast(in_tensors_[0]->data_c()), reinterpret_cast(in_tensors_[1]->data_c()), - reinterpret_cast(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); + error_code = + BroadcastRun(reinterpret_cast(input0_ptr_), reinterpret_cast(input1_ptr_), + reinterpret_cast(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); } } else { // no broadcast, neither is scalar, two same shape if (data_type_ == kDataTypeFloat) { - error_code = func_fp32_(reinterpret_cast(in_tensors_[0]->data_c()) + stride * task_id, - reinterpret_cast(in_tensors_[1]->data_c()) + stride * task_id, + error_code = func_fp32_(reinterpret_cast(input0_ptr_) + stride * task_id, + reinterpret_cast(input1_ptr_) + stride * task_id, reinterpret_cast(out_tensors_[0]->data_c()) + stride * task_id, count); } else { - error_code = func_int32_(reinterpret_cast(in_tensors_[0]->data_c()) + stride * task_id, - reinterpret_cast(in_tensors_[1]->data_c()) + stride * task_id, + error_code = func_int32_(reinterpret_cast(input0_ptr_) + stride * task_id, + reinterpret_cast(input1_ptr_) + stride * task_id, reinterpret_cast(out_tensors_[0]->data_c()) + stride * task_id, count); } } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc index 2921df87534..210db97044a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc @@ -51,8 +51,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() { return RET_OK; } - int broadcast_size = out_tensors_[0]->Size(); - if (broadcast_size < 0) { + if (out_tensors_[0]->Size() < 0) { return RET_OK; } @@ -69,7 +68,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() { if (in_tensors_[0]->data_c() != nullptr && arithmeticParameter_->in_elements_num1_ == arithmeticParameter_->out_elements_num_) { - input0_ptr_ = malloc(broadcast_size); + input0_ptr_ = malloc(arithmeticParameter_->out_elements_num_ * sizeof(float)); if (input0_ptr_ == nullptr) { return RET_ERROR; } @@ -81,7 +80,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() { } if (in_tensors_[1]->data_c() != nullptr && arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->out_elements_num_) { - input1_ptr_ = malloc(broadcast_size); + input1_ptr_ = malloc(arithmeticParameter_->out_elements_num_ * sizeof(float)); if (input1_ptr_ == nullptr) { FreeTmpPtr(); return RET_ERROR; @@ -212,6 +211,15 @@ void ArithmeticCPUKernel::InitRunFunction() { case PrimitiveType_SquaredDifference: arithmetic_run_ = ElementSquaredDifference; break; + case PrimitiveType_Equal: + case PrimitiveType_Less: + case PrimitiveType_Greater: + case PrimitiveType_NotEqual: + case PrimitiveType_LessEqual: + case PrimitiveType_GreaterEqual: + arithmetic_run_ = nullptr; + arithmetic_run_int_ = nullptr; + break; default: MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_; arithmetic_run_ = nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h index fee65a2ba2d..fcee9252d55 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h @@ -75,13 +75,11 @@ class ArithmeticCPUKernel : public LiteKernel { int InitBroadCastCase(); void InitParamInRunTime(); - private: + protected: bool input0_broadcast_ = false; bool input1_broadcast_ = false; void *input0_ptr_ = nullptr; void *input1_ptr_ = nullptr; - - protected: int break_pos_ = 0; int outside_ = 0; int thread_count_ = 1;