forked from mindspore-Ecosystem/mindspore
optimize div
This commit is contained in:
parent
dff02cc282
commit
2f2ec6540b
|
@ -470,6 +470,65 @@ int ElementOptAddRelu6(float *input0, float *input1, float *output, int element_
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Division kernel for the case where one operand supplies a single value.
//
// When param->in_elements_num0_ is 1, input0[0] is the dividend for every
// element and input1[i] is the per-element divisor. Otherwise input0[i] is
// the per-element dividend and input1[0] is the divisor for every element.
//
// Writes element_size quotients to output and returns NNACL_OK, or returns
// NNACL_ERRCODE_DIVISOR_ZERO as soon as a zero divisor is seen. In the
// single-dividend case the divisors are checked one by one, so elements
// before the offending one have already been written when the error is
// returned.
int ElementOptDiv(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
  int i;

  if (param->in_elements_num0_ != 1) {
    // One shared divisor: validate it once before touching the output.
    if (input1[0] == 0) {
      return NNACL_ERRCODE_DIVISOR_ZERO;
    }
    for (i = 0; i < element_size; i++) {
      output[i] = input0[i] / input1[0];
    }
    return NNACL_OK;
  }

  // One shared dividend: every divisor is validated right before it is used.
  for (i = 0; i < element_size; i++) {
    if (input1[i] == 0) {
      return NNACL_ERRCODE_DIVISOR_ZERO;
    }
    output[i] = input0[0] / input1[i];
  }
  return NNACL_OK;
}
|
||||||
|
|
||||||
|
// Division kernel with a ReLU applied to each quotient, for the case where
// one operand supplies a single value.
//
// When param->in_elements_num0_ is 1, input0[0] is the dividend for every
// element and input1[i] is the per-element divisor. Otherwise input0[i] is
// the per-element dividend and input1[0] is the divisor for every element.
// Each quotient that is not greater than zero is replaced by 0.
//
// Returns NNACL_OK on success, or NNACL_ERRCODE_DIVISOR_ZERO as soon as a
// zero divisor is seen. In the single-dividend case the divisors are checked
// one by one, so elements before the offending one have already been written
// when the error is returned.
int ElementOptDivRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
  int i;

  if (param->in_elements_num0_ != 1) {
    // One shared divisor: validate it once before touching the output.
    if (input1[0] == 0) {
      return NNACL_ERRCODE_DIVISOR_ZERO;
    }
    for (i = 0; i < element_size; i++) {
      output[i] = input0[i] / input1[0];
      // ReLU: keep strictly positive results, zero everything else.
      if (!(output[i] > 0)) {
        output[i] = 0;
      }
    }
    return NNACL_OK;
  }

  // One shared dividend: every divisor is validated right before it is used.
  for (i = 0; i < element_size; i++) {
    if (input1[i] == 0) {
      return NNACL_ERRCODE_DIVISOR_ZERO;
    }
    output[i] = input0[0] / input1[i];
    // ReLU: keep strictly positive results, zero everything else.
    if (!(output[i] > 0)) {
      output[i] = 0;
    }
  }
  return NNACL_OK;
}
|
||||||
|
|
||||||
|
// Division kernel whose quotients are passed through MSMAX(q, 0) and then
// MSMIN(..., 6), for the case where one operand supplies a single value.
// NOTE(review): MSMIN/MSMAX are defined outside this block; presumably plain
// min/max macros, which would clamp each result to [0, 6] - confirm.
//
// When param->in_elements_num0_ is 1, input0[0] is the dividend for every
// element and input1[i] is the per-element divisor. Otherwise input0[i] is
// the per-element dividend and input1[0] is the divisor for every element.
//
// Returns NNACL_OK on success, or NNACL_ERRCODE_DIVISOR_ZERO as soon as a
// zero divisor is seen. In the single-dividend case the divisors are checked
// one by one, so elements before the offending one have already been written
// when the error is returned.
int ElementOptDivRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
  int i;

  if (param->in_elements_num0_ != 1) {
    // One shared divisor: validate it once before touching the output.
    if (input1[0] == 0) {
      return NNACL_ERRCODE_DIVISOR_ZERO;
    }
    for (i = 0; i < element_size; i++) {
      output[i] = MSMIN(MSMAX(input0[i] / input1[0], 0), 6);
    }
    return NNACL_OK;
  }

  // One shared dividend: every divisor is validated right before it is used.
  for (i = 0; i < element_size; i++) {
    if (input1[i] == 0) {
      return NNACL_ERRCODE_DIVISOR_ZERO;
    }
    output[i] = MSMIN(MSMAX(input0[0] / input1[i], 0), 6);
  }
  return NNACL_OK;
}
|
||||||
|
|
||||||
int ElementMul(float *input0, float *input1, float *output, int element_size) {
|
int ElementMul(float *input0, float *input1, float *output, int element_size) {
|
||||||
int block_mod = element_size % C4NUM;
|
int block_mod = element_size % C4NUM;
|
||||||
int block_c4 = element_size - block_mod;
|
int block_c4 = element_size - block_mod;
|
||||||
|
|
|
@ -35,6 +35,9 @@ int ElementOptSubRelu6(float *input0, float *input1, float *output, int element_
|
||||||
int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||||
int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||||
int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||||
|
int ElementOptDiv(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||||
|
int ElementOptDivRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||||
|
int ElementOptDivRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||||
int ElementMul(float *input0, float *input1, float *output, int element_size);
|
int ElementMul(float *input0, float *input1, float *output, int element_size);
|
||||||
int ElementMulRelu(float *input0, float *input1, float *output, int element_size);
|
int ElementMulRelu(float *input0, float *input1, float *output, int element_size);
|
||||||
int ElementMulRelu6(float *input0, float *input1, float *output, int element_size);
|
int ElementMulRelu6(float *input0, float *input1, float *output, int element_size);
|
||||||
|
|
|
@ -94,6 +94,22 @@ int ArithmeticCPUKernel::ReSize() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case PrimitiveType_Div:
|
||||||
|
switch (arithmeticParameter_->activation_type_) {
|
||||||
|
case schema::ActivationType_RELU:
|
||||||
|
arithmeticParameter_->broadcasting_ = false;
|
||||||
|
arithmetic_opt_run_ = ElementOptDivRelu;
|
||||||
|
break;
|
||||||
|
case schema::ActivationType_RELU6:
|
||||||
|
arithmeticParameter_->broadcasting_ = false;
|
||||||
|
arithmetic_opt_run_ = ElementOptDivRelu6;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
arithmeticParameter_->broadcasting_ = false;
|
||||||
|
arithmetic_opt_run_ = ElementOptDiv;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
|
@ -158,6 +158,21 @@ TEST_F(TestArithmeticTestFp32, DivTest) {
|
||||||
delete div_param;
|
delete div_param;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs ElementDiv over two equally sized ten-element inputs, prints the
// results, and hands them to CompareOutputData together with the expected
// element-wise quotients.
TEST_F(TestArithmeticTestFp32, DivTest2) {
  constexpr int kOutSize = 10;
  std::vector<float> dividends = {10, 20, 30, 40, 50, 60, 70, 80, 90, 100};
  std::vector<float> divisors = {5, 10, 2, 8, 2, 3, 7, 80, 45, 20};
  // expected[i] == dividends[i] / divisors[i]
  std::vector<float> expected = {2, 2, 15, 5, 25, 20, 10, 1, 2, 5};

  float result[kOutSize];
  ElementDiv(dividends.data(), divisors.data(), result, kOutSize);

  // Dump the computed values so a failing run shows what was produced.
  std::cout << "out: ";
  for (float value : result) {
    std::cout << value << " ";
  }
  std::cout << "\n";

  // NOTE(review): the last argument is presumably an absolute tolerance - confirm.
  CompareOutputData(result, expected.data(), kOutSize, 0.00001);
}
|
||||||
|
|
||||||
TEST_F(TestArithmeticTestFp32, FloorDivTest) {
|
TEST_F(TestArithmeticTestFp32, FloorDivTest) {
|
||||||
auto fdiv_param = new ArithmeticParameter();
|
auto fdiv_param = new ArithmeticParameter();
|
||||||
fdiv_param->ndim_ = 4;
|
fdiv_param->ndim_ = 4;
|
||||||
|
|
Loading…
Reference in New Issue