From 2f2ec6540b5abe9e47e10e9ec4c0a0ddecd3a168 Mon Sep 17 00:00:00 2001 From: chenjianping Date: Tue, 1 Sep 2020 21:57:41 +0800 Subject: [PATCH] optimize div --- mindspore/lite/nnacl/fp32/arithmetic.c | 59 +++++++++++++++++++ mindspore/lite/nnacl/fp32/arithmetic.h | 3 + .../src/runtime/kernel/arm/fp32/arithmetic.cc | 16 +++++ .../kernel/arm/fp32/arithmetic_fp32_tests.cc | 15 +++++ 4 files changed, 93 insertions(+) diff --git a/mindspore/lite/nnacl/fp32/arithmetic.c b/mindspore/lite/nnacl/fp32/arithmetic.c index d08ba5c92c9..a5686a59b63 100644 --- a/mindspore/lite/nnacl/fp32/arithmetic.c +++ b/mindspore/lite/nnacl/fp32/arithmetic.c @@ -470,6 +470,65 @@ int ElementOptAddRelu6(float *input0, float *input1, float *output, int element_ return NNACL_OK; } +int ElementOptDiv(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < element_size; ++index) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = input0[0] / input1[index]; + } + } else { + if (input1[0] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + for (int index = 0; index < element_size; ++index) { + output[index] = input0[index] / input1[0]; + } + } + return NNACL_OK; +} + +int ElementOptDivRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < element_size; ++index) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = input0[0] / input1[index]; + output[index] = output[index] > 0 ? output[index] : 0; + } + } else { + if (input1[0] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + for (int index = 0; index < element_size; ++index) { + output[index] = input0[index] / input1[0]; + output[index] = output[index] > 0 ? output[index] : 0; + } + } + return NNACL_OK; +} + +int ElementOptDivRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < element_size; ++index) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = MSMIN(MSMAX(input0[0] / input1[index], 0), 6); + } + } else { + if (input1[0] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + for (int index = 0; index < element_size; ++index) { + output[index] = MSMIN(MSMAX(input0[index] / input1[0], 0), 6); + } + } + return NNACL_OK; +} + int ElementMul(float *input0, float *input1, float *output, int element_size) { int block_mod = element_size % C4NUM; int block_c4 = element_size - block_mod; diff --git a/mindspore/lite/nnacl/fp32/arithmetic.h b/mindspore/lite/nnacl/fp32/arithmetic.h index ab0e0d02972..22c5d36c021 100644 --- a/mindspore/lite/nnacl/fp32/arithmetic.h +++ b/mindspore/lite/nnacl/fp32/arithmetic.h @@ -35,6 +35,9 @@ int ElementOptSubRelu6(float *input0, float *input1, float *output, int element_ int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); +int ElementOptDiv(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); +int ElementOptDivRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); +int ElementOptDivRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); int ElementMul(float *input0, float *input1, float *output, int element_size); int ElementMulRelu(float *input0, float *input1, float *output, int element_size); int ElementMulRelu6(float *input0, float *input1, float *output, int element_size); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc index 6a72842ce53..9b9ea75fb8c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc @@ -94,6 +94,22 @@ int ArithmeticCPUKernel::ReSize() { break; } break; + case PrimitiveType_Div: + switch (arithmeticParameter_->activation_type_) { + case schema::ActivationType_RELU: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptDivRelu; + break; + case schema::ActivationType_RELU6: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptDivRelu6; + break; + default: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptDiv; + break; + } + break; default: break; } diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_fp32_tests.cc index 0b0967e64f2..b66894ac882 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_fp32_tests.cc @@ -158,6 +158,21 @@ TEST_F(TestArithmeticTestFp32, DivTest) { delete div_param; } +TEST_F(TestArithmeticTestFp32, DivTest2) { + std::vector in0 = {10, 20, 30, 40, 50, 60, 70, 80, 90, 100}; + std::vector in1 = {5, 10, 2, 8, 2, 3, 7, 80, 45, 20}; + std::vector correct_out = {2, 2, 15, 5, 25, 20, 10, 1, 2, 5}; + constexpr int kOutSize = 10; + float out[kOutSize]; + ElementDiv(in0.data(), in1.data(), out, kOutSize); + std::cout << "out: "; + for (int i = 0; i < kOutSize; ++i) { + std::cout << out[i] << " "; + } + std::cout << "\n"; + CompareOutputData(out, correct_out.data(), kOutSize, 0.00001); +} + TEST_F(TestArithmeticTestFp32, FloorDivTest) { auto fdiv_param = new ArithmeticParameter(); fdiv_param->ndim_ = 4;