forked from mindspore-Ecosystem/mindspore
optimize div
This commit is contained in:
parent
dff02cc282
commit
2f2ec6540b
|
@ -470,6 +470,65 @@ int ElementOptAddRelu6(float *input0, float *input1, float *output, int element_
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptDiv(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int index = 0; index < element_size; ++index) {
|
||||
if (input1[index] == 0) {
|
||||
return NNACL_ERRCODE_DIVISOR_ZERO;
|
||||
}
|
||||
output[index] = input0[0] / input1[index];
|
||||
}
|
||||
} else {
|
||||
if (input1[0] == 0) {
|
||||
return NNACL_ERRCODE_DIVISOR_ZERO;
|
||||
}
|
||||
for (int index = 0; index < element_size; ++index) {
|
||||
output[index] = input0[index] / input1[0];
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptDivRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int index = 0; index < element_size; ++index) {
|
||||
if (input1[index] == 0) {
|
||||
return NNACL_ERRCODE_DIVISOR_ZERO;
|
||||
}
|
||||
output[index] = input0[0] / input1[index];
|
||||
output[index] = output[index] > 0 ? output[index] : 0;
|
||||
}
|
||||
} else {
|
||||
if (input1[0] == 0) {
|
||||
return NNACL_ERRCODE_DIVISOR_ZERO;
|
||||
}
|
||||
for (int index = 0; index < element_size; ++index) {
|
||||
output[index] = input0[index] / input1[0];
|
||||
output[index] = output[index] > 0 ? output[index] : 0;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptDivRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int index = 0; index < element_size; ++index) {
|
||||
if (input1[index] == 0) {
|
||||
return NNACL_ERRCODE_DIVISOR_ZERO;
|
||||
}
|
||||
output[index] = MSMIN(MSMAX(input0[0] / input1[index], 0), 6);
|
||||
}
|
||||
} else {
|
||||
if (input1[0] == 0) {
|
||||
return NNACL_ERRCODE_DIVISOR_ZERO;
|
||||
}
|
||||
for (int index = 0; index < element_size; ++index) {
|
||||
output[index] = MSMIN(MSMAX(input0[index] / input1[0], 0), 6);
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementMul(float *input0, float *input1, float *output, int element_size) {
|
||||
int block_mod = element_size % C4NUM;
|
||||
int block_c4 = element_size - block_mod;
|
||||
|
|
|
@ -35,6 +35,9 @@ int ElementOptSubRelu6(float *input0, float *input1, float *output, int element_
|
|||
int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||
int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||
int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||
// Division with one single-value operand: when param->in_elements_num0_ == 1,
// output[i] = input0[0] / input1[i]; otherwise output[i] = input0[i] / input1[0].
// Returns NNACL_ERRCODE_DIVISOR_ZERO if a divisor compares equal to 0, else NNACL_OK.
int ElementOptDiv(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||
// Same operand selection and zero-divisor check as ElementOptDiv; each quotient
// that is not greater than 0 is stored as 0.
int ElementOptDivRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||
// Same operand selection and zero-divisor check as ElementOptDiv; each quotient
// is stored as MSMIN(MSMAX(quotient, 0), 6).
int ElementOptDivRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||
int ElementMul(float *input0, float *input1, float *output, int element_size);
|
||||
int ElementMulRelu(float *input0, float *input1, float *output, int element_size);
|
||||
int ElementMulRelu6(float *input0, float *input1, float *output, int element_size);
|
||||
|
|
|
@ -94,6 +94,22 @@ int ArithmeticCPUKernel::ReSize() {
|
|||
break;
|
||||
}
|
||||
break;
|
||||
case PrimitiveType_Div:
|
||||
switch (arithmeticParameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptDivRelu;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptDivRelu6;
|
||||
break;
|
||||
default:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptDiv;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -158,6 +158,21 @@ TEST_F(TestArithmeticTestFp32, DivTest) {
|
|||
delete div_param;
|
||||
}
|
||||
|
||||
// Runs ElementDiv over ten dividend/divisor pairs whose quotients are exact
// integers, prints the computed values, and compares them with the expected
// quotients using a tolerance of 0.00001.
TEST_F(TestArithmeticTestFp32, DivTest2) {
  constexpr int kOutSize = 10;

  std::vector<float> dividend = {10, 20, 30, 40, 50, 60, 70, 80, 90, 100};
  std::vector<float> divisor = {5, 10, 2, 8, 2, 3, 7, 80, 45, 20};
  std::vector<float> expected = {2, 2, 15, 5, 25, 20, 10, 1, 2, 5};

  float actual[kOutSize];
  ElementDiv(dividend.data(), divisor.data(), actual, kOutSize);

  // Dump the computed quotients on a single line.
  std::cout << "out: ";
  for (float value : actual) {
    std::cout << value << " ";
  }
  std::cout << "\n";

  CompareOutputData(actual, expected.data(), kOutSize, 0.00001);
}
|
||||
|
||||
TEST_F(TestArithmeticTestFp32, FloorDivTest) {
|
||||
auto fdiv_param = new ArithmeticParameter();
|
||||
fdiv_param->ndim_ = 4;
|
||||
|
|
Loading…
Reference in New Issue