optimize div

This commit is contained in:
chenjianping 2020-09-01 21:57:41 +08:00
parent dff02cc282
commit 2f2ec6540b
4 changed files with 93 additions and 0 deletions

View File

@ -470,6 +470,65 @@ int ElementOptAddRelu6(float *input0, float *input1, float *output, int element_
return NNACL_OK;
}
int ElementOptDiv(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < element_size; ++index) {
if (input1[index] == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
output[index] = input0[0] / input1[index];
}
} else {
if (input1[0] == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
for (int index = 0; index < element_size; ++index) {
output[index] = input0[index] / input1[0];
}
}
return NNACL_OK;
}
int ElementOptDivRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < element_size; ++index) {
if (input1[index] == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
output[index] = input0[0] / input1[index];
output[index] = output[index] > 0 ? output[index] : 0;
}
} else {
if (input1[0] == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
for (int index = 0; index < element_size; ++index) {
output[index] = input0[index] / input1[0];
output[index] = output[index] > 0 ? output[index] : 0;
}
}
return NNACL_OK;
}
int ElementOptDivRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < element_size; ++index) {
if (input1[index] == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
output[index] = MSMIN(MSMAX(input0[0] / input1[index], 0), 6);
}
} else {
if (input1[0] == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
for (int index = 0; index < element_size; ++index) {
output[index] = MSMIN(MSMAX(input0[index] / input1[0], 0), 6);
}
}
return NNACL_OK;
}
int ElementMul(float *input0, float *input1, float *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;

View File

@ -35,6 +35,9 @@ int ElementOptSubRelu6(float *input0, float *input1, float *output, int element_
int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptDiv(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptDivRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptDivRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementMul(float *input0, float *input1, float *output, int element_size);
int ElementMulRelu(float *input0, float *input1, float *output, int element_size);
int ElementMulRelu6(float *input0, float *input1, float *output, int element_size);

View File

@ -94,6 +94,22 @@ int ArithmeticCPUKernel::ReSize() {
break;
}
break;
case PrimitiveType_Div:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptDivRelu;
break;
case schema::ActivationType_RELU6:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptDivRelu6;
break;
default:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptDiv;
break;
}
break;
default:
break;
}

View File

@ -158,6 +158,21 @@ TEST_F(TestArithmeticTestFp32, DivTest) {
delete div_param;
}
TEST_F(TestArithmeticTestFp32, DivTest2) {
std::vector<float> in0 = {10, 20, 30, 40, 50, 60, 70, 80, 90, 100};
std::vector<float> in1 = {5, 10, 2, 8, 2, 3, 7, 80, 45, 20};
std::vector<float> correct_out = {2, 2, 15, 5, 25, 20, 10, 1, 2, 5};
constexpr int kOutSize = 10;
float out[kOutSize];
ElementDiv(in0.data(), in1.data(), out, kOutSize);
std::cout << "out: ";
for (int i = 0; i < kOutSize; ++i) {
std::cout << out[i] << " ";
}
std::cout << "\n";
CompareOutputData(out, correct_out.data(), kOutSize, 0.00001);
}
TEST_F(TestArithmeticTestFp32, FloorDivTest) {
auto fdiv_param = new ArithmeticParameter();
fdiv_param->ndim_ = 4;