forked from mindspore-Ecosystem/mindspore
modify arm cpu op: Arithmetic_fp16
This commit is contained in:
parent
ff2851a25c
commit
d1c8f967ac
|
@ -195,23 +195,104 @@ int ArithmeticFP16CPUKernel::ReSize() {
|
|||
}
|
||||
|
||||
if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
|
||||
if (arithmeticParameter_->activation_type_ == schema::ActivationType_NO_ACTIVATION) {
|
||||
switch (arithmeticParameter_->op_parameter_.type_) {
|
||||
case PrimitiveType_Mul:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptMulFp16;
|
||||
break;
|
||||
case PrimitiveType_Add:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptAddFp16;
|
||||
break;
|
||||
case PrimitiveType_Sub:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptSubFp16;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
switch (arithmeticParameter_->op_parameter_.type_) {
|
||||
case PrimitiveType_Mul:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
switch (arithmeticParameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmetic_opt_run_ = ElementOptMulReluFp16;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmetic_opt_run_ = ElementOptDivRelu6Fp16;
|
||||
break;
|
||||
default:
|
||||
arithmetic_opt_run_ = ElementOptDivFp16;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case PrimitiveType_Add:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
switch (arithmeticParameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmetic_opt_run_ = ElementOptAddReluFp16;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmetic_opt_run_ = ElementOptAddRelu6Fp16;
|
||||
break;
|
||||
default:
|
||||
arithmetic_opt_run_ = ElementOptAddFp16;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case PrimitiveType_Sub:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
switch (arithmeticParameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmetic_opt_run_ = ElementOptSubReluFp16;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmetic_opt_run_ = ElementOptSubRelu6Fp16;
|
||||
break;
|
||||
default:
|
||||
arithmetic_opt_run_ = ElementOptSubFp16;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case PrimitiveType_Div:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
switch (arithmeticParameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmetic_opt_run_ = ElementOptDivReluFp16;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmetic_opt_run_ = ElementOptDivRelu6Fp16;
|
||||
break;
|
||||
default:
|
||||
arithmetic_opt_run_ = ElementOptDivFp16;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case PrimitiveType_FloorMod:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptFloorModFp16;
|
||||
case PrimitiveType_FloorDiv:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptFloorDivFp16;
|
||||
case PrimitiveType_LogicalAnd:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptLogicalAndFp16;
|
||||
case PrimitiveType_LogicalOr:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptLogicalOrFp16;
|
||||
case PrimitiveType_SquaredDifference:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptSquaredDifferenceFp16;
|
||||
case PrimitiveType_Maximum:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptMaximumFp16;
|
||||
case PrimitiveType_Minimum:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptMinimumFp16;
|
||||
case PrimitiveType_NotEqual:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptNotEqualFp16;
|
||||
case PrimitiveType_Equal:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptEqualFp16;
|
||||
case PrimitiveType_Less:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptLessFp16;
|
||||
case PrimitiveType_LessEqual:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptLessEqualFp16;
|
||||
case PrimitiveType_Greater:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptGreaterFp16;
|
||||
case PrimitiveType_GreaterEqual:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptGreaterEqualFp16;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -333,4 +414,17 @@ kernel::LiteKernel *CpuArithmeticFp16KernelCreator(const std::vector<lite::tenso
|
|||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Mul, CpuArithmeticFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Add, CpuArithmeticFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Sub, CpuArithmeticFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Div, CpuArithmeticFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorMod, CpuArithmeticFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorDiv, CpuArithmeticFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalAnd, CpuArithmeticFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalOr, CpuArithmeticFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Maximum, CpuArithmeticFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Minimum, CpuArithmeticFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_NotEqual, CpuArithmeticFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Equal, CpuArithmeticFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticFp16KernelCreator)
|
||||
} // namespace mindspore::kernel
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -26,12 +26,57 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptFloorModFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptFloorDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptLessFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
|
||||
int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
|
||||
int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
|
||||
int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
|
||||
|
|
Loading…
Reference in New Issue