modify arm cpu op: Arithmetic_fp16

This commit is contained in:
tao_yunhao 2020-08-18 18:11:26 +08:00
parent ff2851a25c
commit d1c8f967ac
3 changed files with 1061 additions and 82 deletions

View File

@ -195,23 +195,104 @@ int ArithmeticFP16CPUKernel::ReSize() {
}
if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
if (arithmeticParameter_->activation_type_ == schema::ActivationType_NO_ACTIVATION) {
switch (arithmeticParameter_->op_parameter_.type_) {
case PrimitiveType_Mul:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMulFp16;
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_opt_run_ = ElementOptMulReluFp16;
break;
case schema::ActivationType_RELU6:
arithmetic_opt_run_ = ElementOptDivRelu6Fp16;
break;
default:
arithmetic_opt_run_ = ElementOptDivFp16;
break;
}
break;
case PrimitiveType_Add:
arithmeticParameter_->broadcasting_ = false;
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_opt_run_ = ElementOptAddReluFp16;
break;
case schema::ActivationType_RELU6:
arithmetic_opt_run_ = ElementOptAddRelu6Fp16;
break;
default:
arithmetic_opt_run_ = ElementOptAddFp16;
break;
}
break;
case PrimitiveType_Sub:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptSubFp16;
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_opt_run_ = ElementOptSubReluFp16;
break;
case schema::ActivationType_RELU6:
arithmetic_opt_run_ = ElementOptSubRelu6Fp16;
break;
default:
arithmetic_opt_run_ = ElementOptSubFp16;
break;
}
break;
case PrimitiveType_Div:
arithmeticParameter_->broadcasting_ = false;
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_opt_run_ = ElementOptDivReluFp16;
break;
case schema::ActivationType_RELU6:
arithmetic_opt_run_ = ElementOptDivRelu6Fp16;
break;
default:
arithmetic_opt_run_ = ElementOptDivFp16;
break;
}
break;
case PrimitiveType_FloorMod:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptFloorModFp16;
case PrimitiveType_FloorDiv:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptFloorDivFp16;
case PrimitiveType_LogicalAnd:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptLogicalAndFp16;
case PrimitiveType_LogicalOr:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptLogicalOrFp16;
case PrimitiveType_SquaredDifference:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptSquaredDifferenceFp16;
case PrimitiveType_Maximum:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMaximumFp16;
case PrimitiveType_Minimum:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMinimumFp16;
case PrimitiveType_NotEqual:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptNotEqualFp16;
case PrimitiveType_Equal:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptEqualFp16;
case PrimitiveType_Less:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptLessFp16;
case PrimitiveType_LessEqual:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptLessEqualFp16;
case PrimitiveType_Greater:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptGreaterFp16;
case PrimitiveType_GreaterEqual:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptGreaterEqualFp16;
default:
break;
}
}
@ -333,4 +414,17 @@ kernel::LiteKernel *CpuArithmeticFp16KernelCreator(const std::vector<lite::tenso
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Mul, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Add, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Sub, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Div, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorMod, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorDiv, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalAnd, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalOr, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Maximum, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Minimum, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_NotEqual, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Equal, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticFp16KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticFp16KernelCreator)
} // namespace mindspore::kernel

View File

@ -26,12 +26,57 @@
#ifdef __cplusplus
extern "C" {
#endif
int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptFloorModFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptFloorDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptLessFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);