forked from mindspore-Ecosystem/mindspore
!9595 [MS][LITE]Fix arithmetic bug
From: @gongdaguo Reviewed-by: @zhang_xue_tong,@zhanghaibo5 Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
e7462680a9
|
@ -83,22 +83,22 @@ int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
if (data_type_ == kDataTypeFloat) {
|
if (data_type_ == kDataTypeFloat) {
|
||||||
error_code = BroadcastRun(
|
error_code =
|
||||||
reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(in_tensors_[1]->data_c()),
|
BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
|
||||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
|
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
|
||||||
} else {
|
} else {
|
||||||
error_code = BroadcastRun(
|
error_code =
|
||||||
reinterpret_cast<int *>(in_tensors_[0]->data_c()), reinterpret_cast<int *>(in_tensors_[1]->data_c()),
|
BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_),
|
||||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
|
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
|
||||||
}
|
}
|
||||||
} else { // no broadcast, neither is scalar, two same shape
|
} else { // no broadcast, neither is scalar, two same shape
|
||||||
if (data_type_ == kDataTypeFloat) {
|
if (data_type_ == kDataTypeFloat) {
|
||||||
error_code = func_fp32_(reinterpret_cast<float *>(in_tensors_[0]->data_c()) + stride * task_id,
|
error_code = func_fp32_(reinterpret_cast<float *>(input0_ptr_) + stride * task_id,
|
||||||
reinterpret_cast<float *>(in_tensors_[1]->data_c()) + stride * task_id,
|
reinterpret_cast<float *>(input1_ptr_) + stride * task_id,
|
||||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
|
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
|
||||||
} else {
|
} else {
|
||||||
error_code = func_int32_(reinterpret_cast<int *>(in_tensors_[0]->data_c()) + stride * task_id,
|
error_code = func_int32_(reinterpret_cast<int *>(input0_ptr_) + stride * task_id,
|
||||||
reinterpret_cast<int *>(in_tensors_[1]->data_c()) + stride * task_id,
|
reinterpret_cast<int *>(input1_ptr_) + stride * task_id,
|
||||||
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
|
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,8 +51,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int broadcast_size = out_tensors_[0]->Size();
|
if (out_tensors_[0]->Size() < 0) {
|
||||||
if (broadcast_size < 0) {
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,7 +68,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() {
|
||||||
|
|
||||||
if (in_tensors_[0]->data_c() != nullptr &&
|
if (in_tensors_[0]->data_c() != nullptr &&
|
||||||
arithmeticParameter_->in_elements_num1_ == arithmeticParameter_->out_elements_num_) {
|
arithmeticParameter_->in_elements_num1_ == arithmeticParameter_->out_elements_num_) {
|
||||||
input0_ptr_ = malloc(broadcast_size);
|
input0_ptr_ = malloc(arithmeticParameter_->out_elements_num_ * sizeof(float));
|
||||||
if (input0_ptr_ == nullptr) {
|
if (input0_ptr_ == nullptr) {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -81,7 +80,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() {
|
||||||
}
|
}
|
||||||
if (in_tensors_[1]->data_c() != nullptr &&
|
if (in_tensors_[1]->data_c() != nullptr &&
|
||||||
arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->out_elements_num_) {
|
arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->out_elements_num_) {
|
||||||
input1_ptr_ = malloc(broadcast_size);
|
input1_ptr_ = malloc(arithmeticParameter_->out_elements_num_ * sizeof(float));
|
||||||
if (input1_ptr_ == nullptr) {
|
if (input1_ptr_ == nullptr) {
|
||||||
FreeTmpPtr();
|
FreeTmpPtr();
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -212,6 +211,15 @@ void ArithmeticCPUKernel::InitRunFunction() {
|
||||||
case PrimitiveType_SquaredDifference:
|
case PrimitiveType_SquaredDifference:
|
||||||
arithmetic_run_ = ElementSquaredDifference;
|
arithmetic_run_ = ElementSquaredDifference;
|
||||||
break;
|
break;
|
||||||
|
case PrimitiveType_Equal:
|
||||||
|
case PrimitiveType_Less:
|
||||||
|
case PrimitiveType_Greater:
|
||||||
|
case PrimitiveType_NotEqual:
|
||||||
|
case PrimitiveType_LessEqual:
|
||||||
|
case PrimitiveType_GreaterEqual:
|
||||||
|
arithmetic_run_ = nullptr;
|
||||||
|
arithmetic_run_int_ = nullptr;
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
|
MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
|
||||||
arithmetic_run_ = nullptr;
|
arithmetic_run_ = nullptr;
|
||||||
|
|
|
@ -75,13 +75,11 @@ class ArithmeticCPUKernel : public LiteKernel {
|
||||||
int InitBroadCastCase();
|
int InitBroadCastCase();
|
||||||
void InitParamInRunTime();
|
void InitParamInRunTime();
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
bool input0_broadcast_ = false;
|
bool input0_broadcast_ = false;
|
||||||
bool input1_broadcast_ = false;
|
bool input1_broadcast_ = false;
|
||||||
void *input0_ptr_ = nullptr;
|
void *input0_ptr_ = nullptr;
|
||||||
void *input1_ptr_ = nullptr;
|
void *input1_ptr_ = nullptr;
|
||||||
|
|
||||||
protected:
|
|
||||||
int break_pos_ = 0;
|
int break_pos_ = 0;
|
||||||
int outside_ = 0;
|
int outside_ = 0;
|
||||||
int thread_count_ = 1;
|
int thread_count_ = 1;
|
||||||
|
|
Loading…
Reference in New Issue