forked from mindspore-Ecosystem/mindspore
!4349 [MS][LITE][Develop]compare ops support quant
Merge pull request !4349 from chenjianping/lite_dev2
This commit is contained in:
commit
874972caf8
|
@ -15,7 +15,6 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "src/runtime/kernel/arm/int8/arithmetic_int8.h"
|
#include "src/runtime/kernel/arm/int8/arithmetic_int8.h"
|
||||||
#include "src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.h"
|
|
||||||
#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h"
|
#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h"
|
||||||
#include "schema/model_generated.h"
|
#include "schema/model_generated.h"
|
||||||
#include "src/kernel_registry.h"
|
#include "src/kernel_registry.h"
|
||||||
|
@ -42,7 +41,7 @@ int ArithmeticsInt8Launch(int thread_id, LiteParallelGroupEnv *penv, void *cdata
|
||||||
auto error_code = arithmetic_kernel->DoArithmetic(thread_id);
|
auto error_code = arithmetic_kernel->DoArithmetic(thread_id);
|
||||||
if (error_code != RET_OK) {
|
if (error_code != RET_OK) {
|
||||||
MS_LOG(ERROR) << "ArithmeticsRun error thread_id[" << thread_id << "] error_code[" << error_code << "]";
|
MS_LOG(ERROR) << "ArithmeticsRun error thread_id[" << thread_id << "] error_code[" << error_code << "]";
|
||||||
return RET_ERROR;
|
return error_code;
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
@ -79,28 +78,43 @@ ArithmeticInt8CPUKernel::~ArithmeticInt8CPUKernel() {
|
||||||
int ArithmeticInt8CPUKernel::Init() {
|
int ArithmeticInt8CPUKernel::Init() {
|
||||||
switch (op_parameter_->type_) {
|
switch (op_parameter_->type_) {
|
||||||
case PrimitiveType_Equal:
|
case PrimitiveType_Equal:
|
||||||
arithmetic_run_ = ElementEqual;
|
arithmetic_run_ = ElementEqualInt8;
|
||||||
break;
|
break;
|
||||||
case PrimitiveType_NotEqual:
|
case PrimitiveType_NotEqual:
|
||||||
arithmetic_run_ = ElementNotEqual;
|
arithmetic_run_ = ElementNotEqualInt8;
|
||||||
break;
|
break;
|
||||||
case PrimitiveType_Less:
|
case PrimitiveType_Less:
|
||||||
arithmetic_run_ = ElementLess;
|
arithmetic_run_ = ElementLessInt8;
|
||||||
break;
|
break;
|
||||||
case PrimitiveType_LessEqual:
|
case PrimitiveType_LessEqual:
|
||||||
arithmetic_run_ = ElementLessEqual;
|
arithmetic_run_ = ElementLessEqualInt8;
|
||||||
break;
|
break;
|
||||||
case PrimitiveType_Greater:
|
case PrimitiveType_Greater:
|
||||||
arithmetic_run_ = ElementGreater;
|
arithmetic_run_ = ElementGreaterInt8;
|
||||||
break;
|
break;
|
||||||
case PrimitiveType_GreaterEqual:
|
case PrimitiveType_GreaterEqual:
|
||||||
arithmetic_run_ = ElementGreaterEqual;
|
arithmetic_run_ = ElementGreaterEqualInt8;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
|
MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
|
||||||
arithmetic_run_ = nullptr;
|
arithmetic_run_ = nullptr;
|
||||||
return RET_PARAM_INVALID;
|
return RET_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto *input0_tensor = in_tensors_.at(0);
|
||||||
|
auto in0_quant_args = input0_tensor->GetQuantParams();
|
||||||
|
quant_args_.in0_args_.scale_ = in0_quant_args.front().scale;
|
||||||
|
quant_args_.in0_args_.zp_ = in0_quant_args.front().zeroPoint;
|
||||||
|
|
||||||
|
auto *input1_tensor = in_tensors_.at(1);
|
||||||
|
auto in1_quant_args = input1_tensor->GetQuantParams();
|
||||||
|
quant_args_.in1_args_.scale_ = in1_quant_args.front().scale;
|
||||||
|
quant_args_.in1_args_.zp_ = in1_quant_args.front().zeroPoint;
|
||||||
|
|
||||||
|
auto *out_tensor = out_tensors_.at(kOutputIndex);
|
||||||
|
auto out_quant_args = out_tensor->GetQuantParams();
|
||||||
|
quant_args_.out_args_.scale_ = out_quant_args.front().scale;
|
||||||
|
quant_args_.out_args_.zp_ = out_quant_args.front().zeroPoint;
|
||||||
if (!InferShapeDone()) {
|
if (!InferShapeDone()) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
@ -142,16 +156,16 @@ int ArithmeticInt8CPUKernel::DoArithmetic(int thread_id) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int error_code = arithmetic_run_(tile_data0_ + stride * thread_id, tile_data1_ + stride * thread_id,
|
int error_code = arithmetic_run_(tile_data0_ + stride * thread_id, tile_data1_ + stride * thread_id,
|
||||||
output_data + stride * thread_id, count);
|
output_data + stride * thread_id, count, &quant_args_);
|
||||||
if (error_code != RET_OK) {
|
if (error_code != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Arithmetic run fail! ret: " << error_code;
|
MS_LOG(ERROR) << "Arithmetic run fail! ret: " << error_code;
|
||||||
return RET_ERROR;
|
return error_code;
|
||||||
}
|
}
|
||||||
} else if (arithmetic_run_ != nullptr) {
|
} else if (arithmetic_run_ != nullptr) {
|
||||||
int error_code = arithmetic_run_(input0_data, input1_data1, output_data, element_num);
|
int error_code = arithmetic_run_(input0_data, input1_data1, output_data, element_num, &quant_args_);
|
||||||
if (error_code != RET_OK) {
|
if (error_code != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Arithmetic run fail!ret: " << error_code;
|
MS_LOG(ERROR) << "Arithmetic run fail!ret: " << error_code;
|
||||||
return RET_ERROR;
|
return error_code;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
|
MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
|
||||||
|
|
|
@ -20,10 +20,12 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "src/lite_kernel.h"
|
#include "src/lite_kernel.h"
|
||||||
#include "schema/model_generated.h"
|
#include "schema/model_generated.h"
|
||||||
|
#include "src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.h"
|
||||||
|
|
||||||
namespace mindspore::kernel {
|
namespace mindspore::kernel {
|
||||||
class ArithmeticInt8CPUKernel : public LiteKernel {
|
class ArithmeticInt8CPUKernel : public LiteKernel {
|
||||||
typedef int (*ArithmeticRunInt8)(int8_t *input0, int8_t *input1, int8_t *output, int element_size);
|
typedef int (*ArithmeticRunInt8)(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
|
||||||
|
ArithmeticQuantArg *quant_arg);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
ArithmeticInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
ArithmeticInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||||
|
@ -39,10 +41,10 @@ class ArithmeticInt8CPUKernel : public LiteKernel {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void FreeTileData();
|
void FreeTileData();
|
||||||
int thread_count_;
|
|
||||||
int8_t *tile_data0_;
|
int8_t *tile_data0_;
|
||||||
int8_t *tile_data1_;
|
int8_t *tile_data1_;
|
||||||
ArithmeticRunInt8 arithmetic_run_;
|
ArithmeticRunInt8 arithmetic_run_;
|
||||||
|
ArithmeticQuantArg quant_args_;
|
||||||
};
|
};
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARITHMETIC_INT8_H_
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARITHMETIC_INT8_H_
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
#include "nnacl/fp32/arithmetic.h"
|
#include "nnacl/fp32/arithmetic.h"
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
|
|
||||||
|
#define ACCURACY_DATA 0.00000001
|
||||||
|
|
||||||
int ElementMul(float *input0, float *input1, float *output, int element_size) {
|
int ElementMul(float *input0, float *input1, float *output, int element_size) {
|
||||||
int block_mod = element_size % C4NUM;
|
int block_mod = element_size % C4NUM;
|
||||||
int block_c4 = element_size - block_mod;
|
int block_c4 = element_size - block_mod;
|
||||||
|
@ -549,6 +551,14 @@ int BroadcastMinimum(float *input0, float *input1, float *tile_input0, float *ti
|
||||||
return ElementMinimum(tile_input0, tile_input1, output, element_size);
|
return ElementMinimum(tile_input0, tile_input1, output, element_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
float FloatNotEqualCheck(float in0, float in1) {
|
||||||
|
float minus = in0 - in1;
|
||||||
|
if (minus <= ACCURACY_DATA && minus >= -ACCURACY_DATA) {
|
||||||
|
return (float)false;
|
||||||
|
}
|
||||||
|
return (float)true;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementNotEqual(float *input0, float *input1, float *output, int element_size) {
|
int ElementNotEqual(float *input0, float *input1, float *output, int element_size) {
|
||||||
int block_mod = element_size % C4NUM;
|
int block_mod = element_size % C4NUM;
|
||||||
int block_c4 = element_size - block_mod;
|
int block_c4 = element_size - block_mod;
|
||||||
|
@ -563,10 +573,10 @@ int ElementNotEqual(float *input0, float *input1, float *output, int element_siz
|
||||||
float32x4_t vout = vbslq_f32(vceqq_f32(vin0, vin1), vfalse, vtrue);
|
float32x4_t vout = vbslq_f32(vceqq_f32(vin0, vin1), vfalse, vtrue);
|
||||||
vst1q_f32(output, vout);
|
vst1q_f32(output, vout);
|
||||||
#else
|
#else
|
||||||
output[0] = (float)(input0[0] != input1[0]);
|
output[0] = FloatNotEqualCheck(input0[0], input1[0]);
|
||||||
output[1] = (float)(input0[1] != input1[1]);
|
output[1] = FloatNotEqualCheck(input0[1], input1[1]);
|
||||||
output[2] = (float)(input0[2] != input1[2]);
|
output[2] = FloatNotEqualCheck(input0[2], input1[2]);
|
||||||
output[3] = (float)(input0[3] != input1[3]);
|
output[3] = FloatNotEqualCheck(input0[3], input1[3]);
|
||||||
#endif
|
#endif
|
||||||
input0 += C4NUM;
|
input0 += C4NUM;
|
||||||
input1 += C4NUM;
|
input1 += C4NUM;
|
||||||
|
@ -584,6 +594,14 @@ int BroadcastNotEqual(float *input0, float *input1, float *tile_input0, float *t
|
||||||
return ElementNotEqual(tile_input0, tile_input1, output, element_size);
|
return ElementNotEqual(tile_input0, tile_input1, output, element_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
float FloatEqualCheck(float in0, float in1) {
|
||||||
|
float minus = in0 - in1;
|
||||||
|
if (minus <= ACCURACY_DATA && minus >= -ACCURACY_DATA) {
|
||||||
|
return (float)true;
|
||||||
|
}
|
||||||
|
return (float)false;
|
||||||
|
}
|
||||||
|
|
||||||
int ElementEqual(float *input0, float *input1, float *output, int element_size) {
|
int ElementEqual(float *input0, float *input1, float *output, int element_size) {
|
||||||
int block_mod = element_size % C4NUM;
|
int block_mod = element_size % C4NUM;
|
||||||
int block_c4 = element_size - block_mod;
|
int block_c4 = element_size - block_mod;
|
||||||
|
@ -598,10 +616,10 @@ int ElementEqual(float *input0, float *input1, float *output, int element_size)
|
||||||
float32x4_t vout = vbslq_f32(vceqq_f32(vin0, vin1), vtrue, vfalse);
|
float32x4_t vout = vbslq_f32(vceqq_f32(vin0, vin1), vtrue, vfalse);
|
||||||
vst1q_f32(output, vout);
|
vst1q_f32(output, vout);
|
||||||
#else
|
#else
|
||||||
output[0] = (float)(input0[0] == input1[0]);
|
output[0] = FloatEqualCheck(input0[0], input1[0]);
|
||||||
output[1] = (float)(input0[1] == input1[1]);
|
output[1] = FloatEqualCheck(input0[1], input1[1]);
|
||||||
output[2] = (float)(input0[2] == input1[2]);
|
output[2] = FloatEqualCheck(input0[2], input1[2]);
|
||||||
output[3] = (float)(input0[3] == input1[3]);
|
output[3] = FloatEqualCheck(input0[3], input1[3]);
|
||||||
#endif
|
#endif
|
||||||
input0 += C4NUM;
|
input0 += C4NUM;
|
||||||
input1 += C4NUM;
|
input1 += C4NUM;
|
||||||
|
@ -758,3 +776,5 @@ int BroadcastGreaterEqual(float *input0, float *input1, float *tile_input0, floa
|
||||||
TileDimensions(input0, input1, tile_input0, tile_input1, param);
|
TileDimensions(input0, input1, tile_input0, tile_input1, param);
|
||||||
return ElementGreaterEqual(tile_input0, tile_input1, output, element_size);
|
return ElementGreaterEqual(tile_input0, tile_input1, output, element_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#undef ACCURACY_DATA
|
||||||
|
|
|
@ -20,44 +20,102 @@
|
||||||
#endif
|
#endif
|
||||||
#include "nnacl/errorcode.h"
|
#include "nnacl/errorcode.h"
|
||||||
|
|
||||||
int ElementNotEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size) {
|
#define ACCURACY_DATA 0.00000001
|
||||||
|
|
||||||
|
int ElementNotEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
|
||||||
|
ArithmeticQuantArg *quant_arg) {
|
||||||
|
float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
|
||||||
|
float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
|
||||||
|
float output_inverse_scale = 1.f / quant_arg->out_args_.scale_;
|
||||||
|
float out_zp = quant_arg->out_args_.zp_;
|
||||||
for (int index = 0; index < element_size; ++index) {
|
for (int index = 0; index < element_size; ++index) {
|
||||||
output[index] = (int8_t)(input0[index] != input1[index]);
|
float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
|
||||||
|
float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
|
||||||
|
float minus_inputs = in0_real - in1_real;
|
||||||
|
float out_real = (float)true;
|
||||||
|
if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) {
|
||||||
|
out_real = (float)false;
|
||||||
|
}
|
||||||
|
output[index] = (int8_t)(out_real * output_inverse_scale + out_zp);
|
||||||
}
|
}
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size) {
|
int ElementEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg) {
|
||||||
|
float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
|
||||||
|
float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
|
||||||
|
float output_inverse_scale = 1.f / quant_arg->out_args_.scale_;
|
||||||
|
float out_zp = quant_arg->out_args_.zp_;
|
||||||
for (int index = 0; index < element_size; ++index) {
|
for (int index = 0; index < element_size; ++index) {
|
||||||
output[index] = (int8_t)(input0[index] == input1[index]);
|
float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
|
||||||
|
float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
|
||||||
|
float minus_inputs = in0_real - in1_real;
|
||||||
|
float out_real = (float)false;
|
||||||
|
if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) {
|
||||||
|
out_real = (float)true;
|
||||||
|
}
|
||||||
|
output[index] = (int8_t)(out_real * output_inverse_scale + out_zp);
|
||||||
}
|
}
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementLess(int8_t *input0, int8_t *input1, int8_t *output, int element_size) {
|
int ElementLessInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg) {
|
||||||
|
float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
|
||||||
|
float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
|
||||||
|
float output_inverse_scale = 1.f / quant_arg->out_args_.scale_;
|
||||||
|
float out_zp = quant_arg->out_args_.zp_;
|
||||||
for (int index = 0; index < element_size; ++index) {
|
for (int index = 0; index < element_size; ++index) {
|
||||||
output[index] = (int8_t)(input0[index] < input1[index]);
|
float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
|
||||||
|
float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
|
||||||
|
float out_real = (float)(in0_real < in1_real);
|
||||||
|
output[index] = (int8_t)(out_real * output_inverse_scale + out_zp);
|
||||||
}
|
}
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementLessEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size) {
|
int ElementLessEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
|
||||||
|
ArithmeticQuantArg *quant_arg) {
|
||||||
|
float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
|
||||||
|
float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
|
||||||
|
float output_inverse_scale = 1.f / quant_arg->out_args_.scale_;
|
||||||
|
float out_zp = quant_arg->out_args_.zp_;
|
||||||
for (int index = 0; index < element_size; ++index) {
|
for (int index = 0; index < element_size; ++index) {
|
||||||
output[index] = (int8_t)(input0[index] <= input1[index]);
|
float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
|
||||||
|
float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
|
||||||
|
float out_real = (float)(in0_real <= in1_real);
|
||||||
|
output[index] = (int8_t)(out_real * output_inverse_scale + out_zp);
|
||||||
}
|
}
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementGreater(int8_t *input0, int8_t *input1, int8_t *output, int element_size) {
|
int ElementGreaterInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
|
||||||
|
ArithmeticQuantArg *quant_arg) {
|
||||||
|
float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
|
||||||
|
float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
|
||||||
|
float output_inverse_scale = 1.f / quant_arg->out_args_.scale_;
|
||||||
|
float out_zp = quant_arg->out_args_.zp_;
|
||||||
for (int index = 0; index < element_size; ++index) {
|
for (int index = 0; index < element_size; ++index) {
|
||||||
output[index] = (int8_t)(input0[index] > input1[index]);
|
float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
|
||||||
|
float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
|
||||||
|
float out_real = (float)(in0_real > in1_real);
|
||||||
|
output[index] = (int8_t)(out_real * output_inverse_scale + out_zp);
|
||||||
}
|
}
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementGreaterEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size) {
|
int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
|
||||||
|
ArithmeticQuantArg *quant_arg) {
|
||||||
|
float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
|
||||||
|
float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
|
||||||
|
float output_inverse_scale = 1.f / quant_arg->out_args_.scale_;
|
||||||
|
float out_zp = quant_arg->out_args_.zp_;
|
||||||
for (int index = 0; index < element_size; ++index) {
|
for (int index = 0; index < element_size; ++index) {
|
||||||
output[index] = (int8_t)(input0[index] >= input1[index]);
|
float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
|
||||||
|
float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
|
||||||
|
float out_real = (float)(in0_real >= in1_real);
|
||||||
|
output[index] = (int8_t)(out_real * output_inverse_scale + out_zp);
|
||||||
}
|
}
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#undef ACCURACY_DATA
|
||||||
|
|
|
@ -17,16 +17,21 @@
|
||||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARITHMETIC_INT8_H_
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARITHMETIC_INT8_H_
|
||||||
|
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
|
#include "nnacl/quantization/quantize.h"
|
||||||
|
|
||||||
int ElementNotEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size);
|
int ElementNotEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
|
||||||
|
ArithmeticQuantArg *quant_arg);
|
||||||
|
|
||||||
int ElementEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size);
|
int ElementEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg);
|
||||||
|
|
||||||
int ElementLess(int8_t *input0, int8_t *input1, int8_t *output, int element_size);
|
int ElementLessInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg);
|
||||||
|
|
||||||
int ElementLessEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size);
|
int ElementLessEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
|
||||||
|
ArithmeticQuantArg *quant_arg);
|
||||||
|
|
||||||
int ElementGreater(int8_t *input0, int8_t *input1, int8_t *output, int element_size);
|
int ElementGreaterInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
|
||||||
|
ArithmeticQuantArg *quant_arg);
|
||||||
|
|
||||||
int ElementGreaterEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size);
|
int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size,
|
||||||
|
ArithmeticQuantArg *quant_arg);
|
||||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARITHMETIC_INT8_H_
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARITHMETIC_INT8_H_
|
||||||
|
|
|
@ -193,6 +193,12 @@ typedef struct SubQuantArg {
|
||||||
int right_shift_out_;
|
int right_shift_out_;
|
||||||
} SubQuantArg;
|
} SubQuantArg;
|
||||||
|
|
||||||
|
typedef struct ArithmeticQuantArg {
|
||||||
|
QuantArg in0_args_;
|
||||||
|
QuantArg in1_args_;
|
||||||
|
QuantArg out_args_;
|
||||||
|
} ArithmeticQuantArg;
|
||||||
|
|
||||||
void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift);
|
void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift);
|
||||||
|
|
||||||
inline void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier,
|
inline void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier,
|
||||||
|
|
Loading…
Reference in New Issue