diff --git a/mindspore/lite/nnacl/fp16/arithmetic_self_fp16.c b/mindspore/lite/nnacl/fp16/arithmetic_self_fp16.c
new file mode 100644
index 00000000000..3296b5bdd51
--- /dev/null
+++ b/mindspore/lite/nnacl/fp16/arithmetic_self_fp16.c
@@ -0,0 +1,110 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <math.h>
+#include "nnacl/fp16/arithmetic_self_fp16.h"
+
+int ElementAbsFp16(float16_t *input, float16_t *output, int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    output[i] = fabsf(input[i]);
+  }
+  return NNACL_OK;
+}
+
+int ElementCosFp16(float16_t *input, float16_t *output, int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    output[i] = cosf(input[i]);
+  }
+  return NNACL_OK;
+}
+
+int ElementLogFp16(float16_t *input, float16_t *output, int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    if (input[i] <= 0) {
+      return NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO;
+    }
+    output[i] = logf(input[i]);
+  }
+  return NNACL_OK;
+}
+
+int ElementSquareFp16(float16_t *input, float16_t *output, int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    output[i] = input[i] * input[i];
+  }
+  return NNACL_OK;
+}
+
+int ElementSqrtFp16(float16_t *input, float16_t *output, int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    if (input[i] < 0) {
+      return NNACL_ERRCODE_SQRT_NEGATIVE;
+    }
+    output[i] = sqrtf(input[i]);
+  }
+  return NNACL_OK;
+}
+
+int ElementRsqrtFp16(float16_t *input, float16_t *output, int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    if (input[i] <= 0) {
+      return NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO;
+    }
+    output[i] = 1.f / sqrtf(input[i]);
+  }
+  return NNACL_OK;
+}
+
+int ElementSinFp16(float16_t *input, float16_t *output, int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    output[i] = sinf(input[i]);
+  }
+  return NNACL_OK;
+}
+
+int ElementLogicalNotFp16(float16_t *input, float16_t *output, int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    output[i] = (float)(!((bool)(input[i])));
+  }
+  return NNACL_OK;
+}
+
+int ElementRoundFp16(float16_t *input, float16_t *output, int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    output[i] = roundf(input[i]);  // roundf: keep math in float, like floorf/sqrtf above
+  }
+  return NNACL_OK;
+}
+
+int ElementFloorFp16(float16_t *input, float16_t *output, int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    output[i] = floorf(input[i]);
+  }
+  return NNACL_OK;
+}
+
+int ElementCeilFp16(float16_t *input, float16_t *output, int number) {
+  for (int i = 0; i < number; ++i) {
+    output[i] = ceilf(input[i]);  // ceilf: avoid float->double->float round-trip
+  }
+  return NNACL_OK;
+}
+
+int ElementNegativeFp16(float16_t *input, float16_t *output, int element_size) {
+  for (int i = 0; i < element_size; ++i) {
+    output[i] = -input[i];
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore/lite/nnacl/fp16/arithmetic_self_fp16.h b/mindspore/lite/nnacl/fp16/arithmetic_self_fp16.h
new file mode 100644
index 00000000000..21590a6b24f
--- /dev/null
+++ b/mindspore/lite/nnacl/fp16/arithmetic_self_fp16.h
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_NNACL_FP16_ARITHMETIC_SELF_FP16_H_
+#define MINDSPORE_LITE_NNACL_FP16_ARITHMETIC_SELF_FP16_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl/op_base.h"
+#include "nnacl/errorcode.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int ElementAbsFp16(float16_t *input, float16_t *output, int element_size);
+
+int ElementCosFp16(float16_t *input, float16_t *output, int element_size);
+
+int ElementLogFp16(float16_t *input, float16_t *output, int element_size);
+
+int ElementSquareFp16(float16_t *input, float16_t *output, int element_size);
+
+int ElementSqrtFp16(float16_t *input, float16_t *output, int element_size);
+
+int ElementRsqrtFp16(float16_t *input, float16_t *output, int element_size);
+
+int ElementSinFp16(float16_t *input, float16_t *output, int element_size);
+
+int ElementLogicalNotFp16(float16_t *input, float16_t *output, int element_size);
+
+int ElementRoundFp16(float16_t *input, float16_t *output, int element_size);
+
+int ElementFloorFp16(float16_t *input, float16_t *output, int element_size);
+
+int ElementCeilFp16(float16_t *input, float16_t *output, int number);
+
+int ElementNegativeFp16(float16_t *input, float16_t *output, int element_size);
+#ifdef __cplusplus
}
+#endif
+
+#endif  // MINDSPORE_LITE_NNACL_FP16_ARITHMETIC_SELF_FP16_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_self_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_self_fp16.cc
new file mode 100644
index 00000000000..6140e18b046
--- /dev/null
+++
b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_self_fp16.cc
@@ -0,0 +1,142 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/runtime/kernel/arm/fp16/arithmetic_self_fp16.h"
+#include "src/runtime/kernel/arm/fp16/common_fp16.h"
+#include "src/kernel_registry.h"
+#include "nnacl/fp16/cast_fp16.h"
+#include "nnacl/fp16/arithmetic_self_fp16.h"
+
+using mindspore::lite::KernelRegistrar;
+
+namespace mindspore::kernel {
+namespace {
+typedef struct {
+  int primitive_type_;
+  ArithmeticSelfFp16Func func_;
+} TYPE_FUNC_INFO;
+}  // namespace
+
+// Maps a schema primitive type to its fp16 element-wise kernel; nullptr if unsupported.
+ArithmeticSelfFp16Func ArithmeticSelfFp16CPUKernel::GetArithmeticSelfFp16Fun(int primitive_type) {
+  TYPE_FUNC_INFO type_func_table[] = {{mindspore::schema::PrimitiveType_Abs, ElementAbsFp16},
+                                      {mindspore::schema::PrimitiveType_Cos, ElementCosFp16},
+                                      {mindspore::schema::PrimitiveType_Log, ElementLogFp16},
+                                      {mindspore::schema::PrimitiveType_Square, ElementSquareFp16},
+                                      {mindspore::schema::PrimitiveType_Sqrt, ElementSqrtFp16},
+                                      {mindspore::schema::PrimitiveType_Rsqrt, ElementRsqrtFp16},
+                                      {mindspore::schema::PrimitiveType_Sin, ElementSinFp16},
+                                      {mindspore::schema::PrimitiveType_LogicalNot, ElementLogicalNotFp16},
+                                      {mindspore::schema::PrimitiveType_Floor, ElementFloorFp16},
+                                      {mindspore::schema::PrimitiveType_Ceil, ElementCeilFp16},
+                                      {mindspore::schema::PrimitiveType_Round, ElementRoundFp16},
+                                      {mindspore::schema::PrimitiveType_Neg, ElementNegativeFp16}};
+  // FIX: iterate over element count, not byte count; sizeof(type_func_table) alone reads past the array.
+  for (size_t i = 0; i < sizeof(type_func_table) / sizeof(TYPE_FUNC_INFO); i++) {
+    if (type_func_table[i].primitive_type_ == primitive_type) {
+      return type_func_table[i].func_;
+    }
+  }
+  return nullptr;
+}
+
+// Processes this task's slice [offset, offset + count) of the flattened tensor.
+int ArithmeticSelfFp16CPUKernel::DoExecute(int task_id) {
+  int elements_num = in_tensors_.at(0)->ElementsNum();
+  int stride = UP_DIV(elements_num, op_parameter_->thread_num_);
+  int offset = task_id * stride;
+  int count = MSMIN(stride, elements_num - offset);
+  if (count <= 0) {
+    return RET_OK;
+  }
+  if (fp16_func_ == nullptr) {
+    MS_LOG(ERROR) << "Run function is null! ";
+    return RET_ERROR;
+  }
+  auto ret = fp16_func_(input_fp16_ptr_ + offset, output_fp16_ptr_ + offset, count);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Run failed, illegal input! ";
+  }
+  return ret;
+}
+
+// Releases temporary fp16 buffers that were allocated for fp32-typed tensors.
+void ArithmeticSelfFp16CPUKernel::FreeInputAndOutput() {
+  if (in_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
+    context_->allocator->Free(input_fp16_ptr_);
+    input_fp16_ptr_ = nullptr;
+  }
+  if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
+    context_->allocator->Free(output_fp16_ptr_);
+    output_fp16_ptr_ = nullptr;
+  }
+}
+
+int ArithmeticSelfFp16CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail! ret: " << ret;
+    return ret;
+  }
+  auto input_tensor = in_tensors_.at(0);
+  auto output_tensor = out_tensors_.at(0);
+  input_fp16_ptr_ = ConvertInputFp32toFp16(input_tensor, context_);
+  output_fp16_ptr_ = MallocOutputFp16(output_tensor, context_);
+  if (input_fp16_ptr_ == nullptr || output_fp16_ptr_ == nullptr) {
+    FreeInputAndOutput();
+    MS_LOG(ERROR) << "input or output is nullptr";
+    return RET_ERROR;
+  }
+  ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticSelfRun, this, op_parameter_->thread_num_);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]";
+  }
+  // FIX: only convert the result back to fp32 when the parallel run succeeded.
+  if (ret == RET_OK && out_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
+    Float16ToFloat32(output_fp16_ptr_, reinterpret_cast<float *>(output_tensor->MutableData()),
+                     output_tensor->ElementsNum());
+  }
+  FreeInputAndOutput();
+  return ret;
+}
+
+kernel::LiteKernel *CpuArithmeticSelfFp16KernelCreator(const std::vector<lite::Tensor *> &inputs,
+                                                       const std::vector<lite::Tensor *> &outputs,
+                                                       OpParameter *parameter, const lite::InnerContext *ctx,
+                                                       const kernel::KernelKey &desc,
+                                                       const mindspore::lite::PrimitiveC *primitive) {
+  auto *kernel = new (std::nothrow) ArithmeticSelfFp16CPUKernel(parameter, inputs, outputs, ctx, primitive);
+  if (kernel == nullptr) {
+    MS_LOG(ERROR) << "new ArithmeticSelfFp16CPUKernel fail!";
+    return nullptr;
+  }
+  auto ret = kernel->Init();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
+                  << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
+    delete kernel;
+    return nullptr;
+  }
+  return kernel;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Abs, CpuArithmeticSelfFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Cos, CpuArithmeticSelfFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Log, CpuArithmeticSelfFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Square, CpuArithmeticSelfFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Sqrt, CpuArithmeticSelfFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Rsqrt, CpuArithmeticSelfFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Sin, CpuArithmeticSelfFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalNot, CpuArithmeticSelfFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Floor, CpuArithmeticSelfFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Ceil, CpuArithmeticSelfFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Round, CpuArithmeticSelfFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Neg, CpuArithmeticSelfFp16KernelCreator)
+}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_self_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_self_fp16.h
new file mode 100644
index 00000000000..a16b6287e11
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_self_fp16.h
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_SELF_FP16_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_SELF_FP16_H_
+
+#include <vector>
+#include "src/runtime/kernel/arm/fp32/arithmetic_self.h"
+
+namespace mindspore::kernel {
+typedef int (*ArithmeticSelfFp16Func)(float16_t *input, float16_t *output, int element_size);
+// fp16 variant: converts fp32 tensors to fp16 buffers, runs the fp16 element kernel, converts back.
+class ArithmeticSelfFp16CPUKernel : public ArithmeticSelfCPUKernel {
+ public:
+  explicit ArithmeticSelfFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                                       const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
+                                       const mindspore::lite::PrimitiveC *primitive)
+      : ArithmeticSelfCPUKernel(parameter, inputs, outputs, ctx, primitive) {
+    fp16_func_ = GetArithmeticSelfFp16Fun(parameter->type_);
+  }
+  ~ArithmeticSelfFp16CPUKernel() override = default;
+
+  int Run() override;
+  int DoExecute(int task_id) override;
+
+ private:
+  void FreeInputAndOutput();
+  ArithmeticSelfFp16Func GetArithmeticSelfFp16Fun(int primitive_type);
+  ArithmeticSelfFp16Func fp16_func_ = nullptr;
+  float16_t *input_fp16_ptr_ = nullptr;
+  float16_t *output_fp16_ptr_ = nullptr;
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_SELF_FP16_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc
index 98bf1fbd29c..bb1bd6411b4 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc
@@ -13,99 +13,107 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
*/ - #include "src/runtime/kernel/arm/fp32/arithmetic_self.h" -#include "schema/model_generated.h" #include "src/kernel_registry.h" -#include "include/errorcode.h" -#include "src/runtime/runtime_api.h" +#include "nnacl/fp32/arithmetic_self.h" -using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; namespace mindspore::kernel { +namespace { +typedef struct { + int primitive_type_; + ArithmeticSelfFunc func_; +} TYPE_FUNC_INFO; +} // namespace + +ArithmeticSelfFunc ArithmeticSelfCPUKernel::GetArithmeticSelfFun(int primitive_type) { + TYPE_FUNC_INFO type_func_table[] = {{mindspore::schema::PrimitiveType_Abs, ElementAbs}, + {mindspore::schema::PrimitiveType_Cos, ElementCos}, + {mindspore::schema::PrimitiveType_Log, ElementLog}, + {mindspore::schema::PrimitiveType_Square, ElementSquare}, + {mindspore::schema::PrimitiveType_Sqrt, ElementSqrt}, + {mindspore::schema::PrimitiveType_Rsqrt, ElementRsqrt}, + {mindspore::schema::PrimitiveType_Sin, ElementSin}, + {mindspore::schema::PrimitiveType_LogicalNot, ElementLogicalNot}, + {mindspore::schema::PrimitiveType_Floor, ElementFloor}, + {mindspore::schema::PrimitiveType_Ceil, ElementCeil}, + {mindspore::schema::PrimitiveType_Round, ElementRound}, + {mindspore::schema::PrimitiveType_Neg, ElementNegative}}; + for (size_t i = 0; i < sizeof(type_func_table); i++) { + if (type_func_table[i].primitive_type_ == primitive_type) { + return type_func_table[i].func_; + } + } + return nullptr; +} + int ArithmeticSelfCPUKernel::Init() { if (!InferShapeDone()) { return RET_OK; } - return ReSize(); } -int ArithmeticSelfCPUKernel::ReSize() { - data_size_ = in_tensors_[0]->ElementsNum(); - thread_sz_count_ = MSMIN(thread_count_, static_cast(data_size_)); - thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); - return RET_OK; -} +int ArithmeticSelfCPUKernel::ReSize() { return RET_OK; } -int ArithmeticSelfRuns(void *cdata, int task_id) { - auto 
g_kernel = reinterpret_cast(cdata); - auto ret = g_kernel->DoArithmeticSelf(task_id); - if (ret != RET_OK) { - MS_LOG(ERROR) << "ArithmeticSelfRuns error task_id[" << task_id << "] error_code[" << ret << "]"; - return ret; - } - return RET_OK; -} - -int ArithmeticSelfCPUKernel::DoArithmeticSelf(int task_id) { - int size = MSMIN(thread_sz_stride_, static_cast(data_size_ - task_id * thread_sz_stride_)); - if (size <= 0) { +int ArithmeticSelfCPUKernel::DoExecute(int task_id) { + int elements_num = in_tensors_.at(0)->ElementsNum(); + int stride = UP_DIV(elements_num, op_parameter_->thread_num_); + int offset = task_id * stride; + int count = MSMIN(stride, elements_num - offset); + if (count <= 0) { return RET_OK; } - int offset = task_id * thread_sz_stride_; - if (arithmeticSelf_run_) { - auto ret = arithmeticSelf_run_(in_ptr_ + offset, out_ptr_ + offset, size); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Run failed, illegal input! "; - return ret; - } - } else { + if (func_ == nullptr) { MS_LOG(ERROR) << "Run function is null! "; return RET_ERROR; } - return RET_OK; + float *input_ptr = reinterpret_cast(in_tensors_.at(0)->MutableData()); + float *output_ptr = reinterpret_cast(out_tensors_.at(0)->MutableData()); + auto ret = func_(input_ptr + offset, output_ptr + offset, count); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Run failed, illegal input! "; + } + return ret; } + +int ArithmeticSelfRun(void *cdata, int task_id) { + auto kernel = reinterpret_cast(cdata); + auto ret = kernel->DoExecute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ArithmeticSelfRuns error task_id[" << task_id << "] error_code[" << ret << "]"; + } + return ret; +} + int ArithmeticSelfCPUKernel::Run() { auto ret = Prepare(); if (ret != RET_OK) { - MS_LOG(ERROR) << "Prepare fail!ret: " << ret; + MS_LOG(ERROR) << "Prepare fail! 
ret: " << ret; return ret; } - auto input_tensor = in_tensors_.at(0); - auto out_tensor = out_tensors_.at(0); - in_ptr_ = reinterpret_cast(input_tensor->MutableData()); - out_ptr_ = reinterpret_cast(out_tensor->MutableData()); - ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticSelfRuns, this, thread_sz_count_); + ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticSelfRun, this, op_parameter_->thread_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]"; - return ret; } - return RET_OK; + return ret; } kernel::LiteKernel *CpuArithmeticSelfFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, - OpParameter *opParameter, const lite::InnerContext *ctx, + OpParameter *parameter, const lite::InnerContext *ctx, const kernel::KernelKey &desc, const mindspore::lite::PrimitiveC *primitive) { - MS_ASSERT(opParameter != nullptr); - if (opParameter == nullptr) { - MS_LOG(ERROR) << "Creator failed, opParameter is nullptr!"; - return nullptr; - } - auto *kernel = new (std::nothrow) ArithmeticSelfCPUKernel(opParameter, inputs, outputs, ctx, primitive); + auto *kernel = new (std::nothrow) ArithmeticSelfCPUKernel(parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new ArithmeticSelfCPUKernel fail!"; return nullptr; } auto ret = kernel->Init(); if (ret != RET_OK) { - MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); delete kernel; return nullptr; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h index 50216d5e848..4059152644b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h +++ 
b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h
@@ -13,18 +13,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_SELF_H_
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_SELF_H_
 
 #include <vector>
 #include "src/lite_kernel.h"
-#include "nnacl/fp32/arithmetic_self.h"
-#include "nnacl/arithmetic_self_parameter.h"
-#include "schema/model_generated.h"
-#include "include/context.h"
 
-using mindspore::lite::InnerContext;
 using mindspore::schema::PrimitiveType_Abs;
 using mindspore::schema::PrimitiveType_Ceil;
 using mindspore::schema::PrimitiveType_Cos;
@@ -39,73 +33,27 @@ using mindspore::schema::PrimitiveType_Sqrt;
 using mindspore::schema::PrimitiveType_Square;
 
 namespace mindspore::kernel {
+typedef int (*ArithmeticSelfFunc)(float *input, float *output, int element_size);
 class ArithmeticSelfCPUKernel : public LiteKernel {
-  typedef int (*ArithmeticSelfRun)(float *input, float *output, int element_size);
-
  public:
   explicit ArithmeticSelfCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                    const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                                    const mindspore::lite::PrimitiveC *primitive)
-      : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
-    switch (parameter->type_) {
-      case PrimitiveType_Abs:
-        arithmeticSelf_run_ = ElementAbs;
-        break;
-      case PrimitiveType_Cos:
-        arithmeticSelf_run_ = ElementCos;
-        break;
-      case PrimitiveType_Log:
-        arithmeticSelf_run_ = ElementLog;
-        break;
-      case PrimitiveType_Square:
-        arithmeticSelf_run_ = ElementSquare;
-        break;
-      case PrimitiveType_Sqrt:
-        arithmeticSelf_run_ = ElementSqrt;
-        break;
-      case PrimitiveType_Rsqrt:
-        arithmeticSelf_run_ = ElementRsqrt;
-        break;
-      case PrimitiveType_Sin:
-        arithmeticSelf_run_ = ElementSin;
-        break;
-      case PrimitiveType_LogicalNot:
-        arithmeticSelf_run_ = ElementLogicalNot;
-        break;
-      case PrimitiveType_Floor:
-        arithmeticSelf_run_ = ElementFloor;
-        break;
-      case PrimitiveType_Ceil:
-        arithmeticSelf_run_ = ElementCeil;
-        break;
-      case PrimitiveType_Round:
-        arithmeticSelf_run_ = ElementRound;
-        break;
-      case PrimitiveType_Neg:
-        arithmeticSelf_run_ = ElementNegative;
-        break;
-      default:
-        break;
-    }
-    arithmeticSelfParameter_ = reinterpret_cast<ArithmeticSelfParameter *>(parameter);
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
+    func_ = GetArithmeticSelfFun(parameter->type_);
   }
   ~ArithmeticSelfCPUKernel() override = default;
 
   int Init() override;
   int ReSize() override;
   int Run() override;
-  int DoArithmeticSelf(int task_id);
+  virtual int DoExecute(int task_id);
 
  private:
-  int thread_sz_count_;
-  int thread_sz_stride_;
-  size_t data_size_;
-  ArithmeticSelfParameter *arithmeticSelfParameter_;
-  ArithmeticSelfRun arithmeticSelf_run_;
-  int thread_count_;
-  float *in_ptr_;
-  float *out_ptr_;
+  ArithmeticSelfFunc GetArithmeticSelfFun(int primitive_type);
+  ArithmeticSelfFunc func_ = nullptr;  // FIX: default-init; lookup may fail and DoExecute checks for nullptr
 };
+int ArithmeticSelfRun(void *cdata, int task_id);
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_SELF_H_