forked from mindspore-Ecosystem/mindspore
!9595 [MS][LITE]Fix arithmetic bug
From: @gongdaguo Reviewed-by: @zhang_xue_tong,@zhanghaibo5 Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
e7462680a9
|
@ -83,22 +83,22 @@ int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
|
|||
return RET_OK;
|
||||
}
|
||||
if (data_type_ == kDataTypeFloat) {
|
||||
error_code = BroadcastRun(
|
||||
reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(in_tensors_[1]->data_c()),
|
||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
|
||||
error_code =
|
||||
BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
|
||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
|
||||
} else {
|
||||
error_code = BroadcastRun(
|
||||
reinterpret_cast<int *>(in_tensors_[0]->data_c()), reinterpret_cast<int *>(in_tensors_[1]->data_c()),
|
||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
|
||||
error_code =
|
||||
BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_),
|
||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
|
||||
}
|
||||
} else { // no broadcast, neither is scalar, two same shape
|
||||
if (data_type_ == kDataTypeFloat) {
|
||||
error_code = func_fp32_(reinterpret_cast<float *>(in_tensors_[0]->data_c()) + stride * task_id,
|
||||
reinterpret_cast<float *>(in_tensors_[1]->data_c()) + stride * task_id,
|
||||
error_code = func_fp32_(reinterpret_cast<float *>(input0_ptr_) + stride * task_id,
|
||||
reinterpret_cast<float *>(input1_ptr_) + stride * task_id,
|
||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
|
||||
} else {
|
||||
error_code = func_int32_(reinterpret_cast<int *>(in_tensors_[0]->data_c()) + stride * task_id,
|
||||
reinterpret_cast<int *>(in_tensors_[1]->data_c()) + stride * task_id,
|
||||
error_code = func_int32_(reinterpret_cast<int *>(input0_ptr_) + stride * task_id,
|
||||
reinterpret_cast<int *>(input1_ptr_) + stride * task_id,
|
||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -51,8 +51,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int broadcast_size = out_tensors_[0]->Size();
|
||||
if (broadcast_size < 0) {
|
||||
if (out_tensors_[0]->Size() < 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -69,7 +68,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() {
|
|||
|
||||
if (in_tensors_[0]->data_c() != nullptr &&
|
||||
arithmeticParameter_->in_elements_num1_ == arithmeticParameter_->out_elements_num_) {
|
||||
input0_ptr_ = malloc(broadcast_size);
|
||||
input0_ptr_ = malloc(arithmeticParameter_->out_elements_num_ * sizeof(float));
|
||||
if (input0_ptr_ == nullptr) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -81,7 +80,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() {
|
|||
}
|
||||
if (in_tensors_[1]->data_c() != nullptr &&
|
||||
arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->out_elements_num_) {
|
||||
input1_ptr_ = malloc(broadcast_size);
|
||||
input1_ptr_ = malloc(arithmeticParameter_->out_elements_num_ * sizeof(float));
|
||||
if (input1_ptr_ == nullptr) {
|
||||
FreeTmpPtr();
|
||||
return RET_ERROR;
|
||||
|
@ -212,6 +211,15 @@ void ArithmeticCPUKernel::InitRunFunction() {
|
|||
case PrimitiveType_SquaredDifference:
|
||||
arithmetic_run_ = ElementSquaredDifference;
|
||||
break;
|
||||
case PrimitiveType_Equal:
|
||||
case PrimitiveType_Less:
|
||||
case PrimitiveType_Greater:
|
||||
case PrimitiveType_NotEqual:
|
||||
case PrimitiveType_LessEqual:
|
||||
case PrimitiveType_GreaterEqual:
|
||||
arithmetic_run_ = nullptr;
|
||||
arithmetic_run_int_ = nullptr;
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
|
||||
arithmetic_run_ = nullptr;
|
||||
|
|
|
@ -75,13 +75,11 @@ class ArithmeticCPUKernel : public LiteKernel {
|
|||
int InitBroadCastCase();
|
||||
void InitParamInRunTime();
|
||||
|
||||
private:
|
||||
protected:
|
||||
bool input0_broadcast_ = false;
|
||||
bool input1_broadcast_ = false;
|
||||
void *input0_ptr_ = nullptr;
|
||||
void *input1_ptr_ = nullptr;
|
||||
|
||||
protected:
|
||||
int break_pos_ = 0;
|
||||
int outside_ = 0;
|
||||
int thread_count_ = 1;
|
||||
|
|
Loading…
Reference in New Issue