diff --git a/mindspore/lite/src/kernel_factory.cc b/mindspore/lite/src/kernel_factory.cc index aa81d23c905..b835506bec1 100644 --- a/mindspore/lite/src/kernel_factory.cc +++ b/mindspore/lite/src/kernel_factory.cc @@ -43,19 +43,11 @@ LiteKernel *KernelFactory::GetKernel(const std::vector &inputs MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(primitive->Type()); return nullptr; } - auto creator = KernelRegistry::GetInstance()->GetKernelCreator(key); + auto creator = KernelRegistry::GetInstance()->GetCreator(key); if (creator != nullptr) { - auto *kernel = creator(inputs, outputs, parameter, ctx, key); - if (kernel != nullptr) { - return kernel; - } else { - MS_LOG(ERROR) << "Creator kernel failed for " << schema::EnumNamePrimitiveType(key.type); - return nullptr; - } - } else { - MS_LOG(ERROR) << "Can not find OpCreator for " << schema::EnumNamePrimitiveType(key.type); - return nullptr; + auto kernel = creator(inputs, outputs, parameter, ctx, key); + return kernel; } + return nullptr; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/kernel_factory.h b/mindspore/lite/src/kernel_factory.h index 136008959bf..2065e208d0c 100644 --- a/mindspore/lite/src/kernel_factory.h +++ b/mindspore/lite/src/kernel_factory.h @@ -38,4 +38,3 @@ class KernelFactory { } // namespace mindspore::lite #endif // MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_ - diff --git a/mindspore/lite/src/kernel_registry.cc b/mindspore/lite/src/kernel_registry.cc index 9f48c6d5952..2b16fc9ef7d 100644 --- a/mindspore/lite/src/kernel_registry.cc +++ b/mindspore/lite/src/kernel_registry.cc @@ -13,47 +13,105 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "ir/dtype/type_id.h" +#ifdef ENABLE_ARM64 +#include +#include "common/utils.h" +#include "utils/log_adapter.h" +#include "src/runtime/kernel/arm/opclib/optimized_kernel.h" +#endif +using mindspore::kernel::kCPU; +using mindspore::kernel::KERNEL_ARCH; using mindspore::kernel::KernelCreator; using mindspore::kernel::KernelKey; -using mindspore::kernel::KERNEL_ARCH; +using mindspore::kernel::kKernelArch_MAX; +using mindspore::kernel::kKernelArch_MIN; +using mindspore::schema::PrimitiveType_MAX; +using mindspore::schema::PrimitiveType_MIN; namespace mindspore::lite { KernelRegistry::KernelRegistry() {} -KernelRegistry::~KernelRegistry() {} +KernelRegistry::~KernelRegistry() { FreeCreatorArray(); } KernelRegistry *KernelRegistry::GetInstance() { static KernelRegistry instance; return &instance; } -KernelCreator KernelRegistry::GetKernelCreator(const KernelKey &desc) { - auto it = creators.find(desc); - if (it != creators.end()) { - return it->second; +int KernelRegistry::Init() { + lock_.lock(); + if (creator_arrays_ != nullptr) { + lock_.unlock(); + return RET_OK; } + device_type_length_ = kKernelArch_MAX - kKernelArch_MIN; + data_type_length_ = kNumberTypeEnd - kNumberTypeBegin; + op_type_length_ = PrimitiveType_MAX - PrimitiveType_MIN; + // malloc an array contain creator functions of kernel + auto total_len = device_type_length_ * data_type_length_ * op_type_length_; + creator_arrays_ = (kernel::KernelCreator *)malloc(total_len * sizeof(kernel::KernelCreator)); + if (creator_arrays_ == nullptr) { + MS_LOG(ERROR) << "malloc creator_arrays_ failed."; + lock_.unlock(); + return RET_ERROR; + } + for (int i = 0; i < total_len; ++i) { + creator_arrays_[i] = nullptr; + } +#ifdef ENABLE_ARM64 + void 
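
The Init() routine above takes and releases lock_ manually on every return path. Below is a minimal RAII sketch of the same one-time table setup, using only the standard library; RegistrySketch and its members are illustrative stand-ins, not the patch's names.

#include <cstddef>
#include <mutex>
#include <vector>

struct RegistrySketch {
  std::mutex lock_;
  std::vector<void (*)()> table_;  // stands in for creator_arrays_

  // Returns 0 on success, playing the role of RET_OK in the patch.
  int Init(std::size_t total_len) {
    std::lock_guard<std::mutex> guard(lock_);  // released on every return
    if (!table_.empty()) {
      return 0;  // a previous call already built the table
    }
    table_.assign(total_len, nullptr);  // all slots start unregistered
    return 0;
  }
};

A lock_guard cannot forget an unlock on an early return, which is easy to get wrong with the paired lock_.lock() / lock_.unlock() calls the patch uses.
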
*optimized_lib_handler = OptimizeModule::GetInstance()->optimized_op_handler_; + if (optimized_lib_handler != nullptr) { + MS_LOG(INFO) << "load optimize lib success."; + } else { + MS_LOG(INFO) << "load optimize lib failed."; + } +#endif + lock_.unlock(); + return RET_OK; +} - // if not find, use cpu kernel - KernelKey cpuDesc {kernel::KERNEL_ARCH::kCPU, desc.type}; - it = creators.find(cpuDesc); - if (it != creators.end()) { - return it->second; +void KernelRegistry::FreeCreatorArray() { + if (creator_arrays_ != nullptr) { + free(creator_arrays_); + creator_arrays_ = nullptr; + } +} + +kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { + int index = GetCreatorFuncIndex(desc); + auto it = creator_arrays_[index]; + if (it != nullptr) { + return it; } return nullptr; } -void KernelRegistry::RegKernel(const KernelKey desc, KernelCreator creator) { creators[desc] = creator; } +int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) { + int index; + int device_index = static_cast(desc.arch); + int dType_index = static_cast(desc.data_type); + int op_index = static_cast(desc.type); + index = device_index * data_type_length_ * op_type_length_ + dType_index * op_type_length_ + op_index; + return index; +} -void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const schema::PrimitiveType type, KernelCreator creator) { - KernelKey desc = {arch, type}; - creators[desc] = creator; +void KernelRegistry::RegKernel(const KernelKey desc, kernel::KernelCreator creator) { + int index = GetCreatorFuncIndex(desc); + creator_arrays_[index] = creator; +} + +void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType op_type, + kernel::KernelCreator creator) { + KernelKey desc = {arch, data_type, op_type}; + int index = GetCreatorFuncIndex(desc); + creator_arrays_[index] = creator; } bool KernelRegistry::Merge(const std::unordered_map &newCreators) { return false; } -const std::map &KernelRegistry::GetKernelCreators() { return creators; } +const kernel::KernelCreator *KernelRegistry::GetCreatorArrays() { return creator_arrays_; } } // namespace mindspore::lite - diff --git a/mindspore/lite/src/kernel_registry.h b/mindspore/lite/src/kernel_registry.h index 7873ac2a751..eab7d03a53e 100644 --- a/mindspore/lite/src/kernel_registry.h +++ b/mindspore/lite/src/kernel_registry.h @@ -30,16 +30,22 @@ class KernelRegistry { virtual ~KernelRegistry(); static KernelRegistry *GetInstance(); - virtual kernel::KernelCreator GetKernelCreator(const kernel::KernelKey &desc); - - const std::map &GetKernelCreators(); - + int Init(); + void FreeCreatorArray(); + virtual kernel::KernelCreator GetCreator(const kernel::KernelKey &desc); + const kernel::KernelCreator *GetCreatorArrays(); + int GetCreatorFuncIndex(const kernel::KernelKey desc); void RegKernel(const kernel::KernelKey desc, kernel::KernelCreator creator); - void RegKernel(const kernel::KERNEL_ARCH arch, const schema::PrimitiveType type, kernel::KernelCreator creator); + void RegKernel(const kernel::KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType type, + kernel::KernelCreator creator); bool Merge(const std::unordered_map &newCreators); protected: - std::map creators; + kernel::KernelCreator *creator_arrays_ = nullptr; + int device_type_length_; + int data_type_length_; + int op_type_length_; + std::mutex lock_; }; class KernelRegistrar { @@ -48,14 +54,14 @@ class KernelRegistrar { KernelRegistry::GetInstance()->RegKernel(desc, creator); } - KernelRegistrar(const 
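
GetCreatorFuncIndex above flattens the three-level key into one row-major offset. A self-contained sketch of that arithmetic follows; the extents and the 0-based-index assumption are mine, while the patch derives its lengths from kKernelArch_MIN/MAX, kNumberTypeBegin/End and PrimitiveType_MIN/MAX and indexes with the raw enum values.

#include <cstddef>

constexpr int kArchCount = 3;    // illustrative extents, not the real counts
constexpr int kDTypeCount = 48;
constexpr int kOpCount = 210;

struct KeySketch {
  int arch;       // device axis, slowest varying
  int data_type;  // dtype axis
  int op_type;    // operator axis, fastest varying
};

// index = arch * D * O + dtype * O + op, the same formula as the patch.
// Assumption: all three fields are already 0-based indices.
inline std::size_t CreatorIndex(const KeySketch &k) {
  return (static_cast<std::size_t>(k.arch) * kDTypeCount + k.data_type) * kOpCount +
         k.op_type;
}

Lookup becomes two multiply-adds and one load, versus a tree walk through the old std::map; the cost is a dense table with a slot for every (arch, dtype, op) combination, registered or not.
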
kernel::KERNEL_ARCH arch, const schema::PrimitiveType type, kernel::KernelCreator creator) { - KernelRegistry::GetInstance()->RegKernel(arch, type, creator); + KernelRegistrar(const kernel::KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType op_type, + kernel::KernelCreator creator) { + KernelRegistry::GetInstance()->RegKernel(arch, data_type, op_type, creator); } }; -#define REG_KERNEL(arch, type, kernelCreater) \ - static KernelRegistrar g_##arch##type##kernelReg(arch, type, kernelCreater); +#define REG_KERNEL(arch, data_type, op_type, kernelCreater) \ + static KernelRegistrar g_##arch##data_type##op_type##kernelReg(arch, data_type, op_type, kernelCreater); } // namespace mindspore::lite #endif // MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_ - diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h index ea6ba756483..30dbe88a038 100644 --- a/mindspore/lite/src/lite_kernel.h +++ b/mindspore/lite/src/lite_kernel.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_LITE_KERNEL_H_ #include #include -#ifdef ENABLE_FP16 +#ifdef ENABLE_ARM #include #endif #include "src/runtime/kernel/arm/opclib/op_base.h" @@ -35,14 +35,17 @@ using FLOAT_t = float; // using mindspore::kernel::AddressPtr; namespace mindspore::kernel { -enum KERNEL_ARCH { kCPU, kGPU, kNPU, kInferShape }; +enum KERNEL_ARCH { kCPU, kGPU, kNPU, kKernelArch_MIN = kCPU, kKernelArch_MAX = kNPU }; struct KernelKey { KERNEL_ARCH arch; + TypeId data_type; schema::PrimitiveType type; bool operator<(const KernelKey &dst) const { if (arch != dst.arch) { return arch < dst.arch; + } else if (data_type != dst.data_type) { + return data_type < dst.data_type; } else { return type < dst.type; } @@ -179,4 +182,3 @@ class LiteKernelUtil { } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_LITE_KERNEL_H_ - diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 69e5753e154..c5db4580dfe 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -24,6 +24,7 @@ #include "src/executor.h" #include "src/common/utils.h" #include "src/common/graph_util.h" +#include "src/kernel_registry.h" #if SUPPORT_GPU #include "src/runtime/opencl/opencl_runtime.h" #endif @@ -197,7 +198,11 @@ void LiteSession::Init(Context *context) { this->context->deviceCtx.type = context->deviceCtx.type; this->context->allocator = std::make_shared(); ConfigThreadPool(context->cpuBindMode, context->threadNum); - + auto ret = KernelRegistry::GetInstance()->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "KernelRegistry Init Failed."; + return; + } #if SUPPORT_GPU if (context->deviceCtx.type == DT_GPU) { auto opencl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); @@ -228,6 +233,7 @@ LiteSession::~LiteSession() { delete kernel; } } + std::vector LiteSession::GetInputsByName(std::string name) { return input_map[name]; } diff --git a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt index a2a39274c72..79342ea9ece 100644 --- a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt +++ b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt @@ -25,13 +25,5 @@ if (PLATFORM_ARM32) set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) endif() -if (ENABLE_FP16) - file(GLOB FP6_SRC - ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc - ${CMAKE_CURRENT_SOURCE_DIR}/opclib/fp16/*.cc - ) - set(KERNEL_SRC ${KERNEL_SRC} ${FP6_SRC}) -endif () - add_library(cpu_kernel_mid_ OBJECT ${KERNEL_SRC}) add_subdirectory(opclib) diff --git 
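
REG_KERNEL works by defining a file-scope static whose constructor registers the creator before main() runs. Here is a stripped-down sketch of the idiom, assuming a construct-on-first-use table so static initialization order cannot bite; all names and enum values are illustrative.

#include <map>
#include <tuple>

using CreatorFn = void *(*)();  // stand-in for kernel::KernelCreator

std::map<std::tuple<int, int, int>, CreatorFn> &Table() {
  static std::map<std::tuple<int, int, int>, CreatorFn> t;  // built on first use
  return t;
}

struct RegistrarSketch {
  RegistrarSketch(int arch, int dtype, int op, CreatorFn fn) {
    Table()[std::make_tuple(arch, dtype, op)] = fn;  // runs at static-init time
  }
};

// Token pasting makes each registration a uniquely named global object,
// just as the REG_KERNEL macro above does.
#define REG_SKETCH(arch, dtype, op, fn) \
  static RegistrarSketch g_##arch##_##dtype##_##op##_reg(arch, dtype, op, fn)

constexpr int kCPU_ = 0;        // illustrative values
constexpr int kFloat32_ = 43;   // assumption, not the real TypeId value
constexpr int kConcat_ = 13;    // assumption, not the real PrimitiveType value
void *CreateConcatFp32() { return nullptr; }  // stand-in creator
REG_SKETCH(kCPU_, kFloat32_, kConcat_, CreateConcatFp32);

Because registration is a side effect of linking the object file, adding a kernel never requires editing a central dispatch list, which is what lets this patch delete the per-op switch statements further down.
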
a/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc index bfb0689ffc4..15a7b417cbd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc @@ -36,7 +36,8 @@ int ConcatBaseCPUKernel::Init() { kernel::LiteKernel *CpuConcatInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, - OpParameter *opParameter, const Context *ctx) { + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { if (opParameter == nullptr) { MS_LOG(ERROR) << "Input opParameter is nullptr!"; return nullptr; @@ -47,51 +48,6 @@ kernel::LiteKernel *CpuConcatInt8KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const Context *ctx) { - if (opParameter == nullptr) { - MS_LOG(ERROR) << "Input opParameter is nullptr!"; - return nullptr; - } - MS_ASSERT(desc.type == schema::PrimitiveType_Concat); - auto *kernel = new(std::nothrow) ConcatCPUKernel(opParameter, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; - return nullptr; - } - return kernel; -} - -kernel::LiteKernel *CpuConcatKernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, - const lite::Context *ctx, const kernel::KernelKey &desc) { - MS_ASSERT(opParameter != nullptr); - MS_ASSERT(desc.type == schema::PrimitiveType_Concat); - auto input_tensor = inputs.at(kInputIndex); - auto data_type = input_tensor->data_type(); - kernel::LiteKernel *kernel = nullptr; - switch (data_type) { - case kNumberTypeInt8: - case kNumberTypeUInt8: - kernel = CpuConcatInt8KernelCreator(inputs, outputs, opParameter, ctx); - break; - case kNumberTypeInt32: - case kNumberTypeFloat32: - kernel = CpuConcatFp32OrInt32KernelCreator(inputs, outputs, opParameter, ctx); - break; - default: - break; - } - - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } auto ret = kernel->Init(); if (ret != RET_OK) { delete kernel; @@ -102,6 +58,56 @@ kernel::LiteKernel *CpuConcatKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Concat); + auto *kernel = new(std::nothrow) ConcatCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuConcatFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) {; + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Concat); + auto *kernel = new(std::nothrow) ConcatCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel 
failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Concat, CpuConcatInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Concat, CpuConcatInt32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Concat, CpuConcatFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc index bb615d0ae73..07c1671ae41 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc @@ -15,24 +15,6 @@ */ #include "src/runtime/kernel/arm/base/convolution_base.h" -#include "src/runtime/kernel/arm/fp32/convolution.h" -#include "src/runtime/kernel/arm/fp32/convolution_winograd.h" -#include "src/runtime/kernel/arm/fp32/deconvolution.h" -#include "src/runtime/kernel/arm/fp32/convolution_1x1.h" -#include "src/runtime/kernel/arm/fp32/convolution_3x3.h" -#include "src/runtime/kernel/arm/fp32/convolution_depthwise.h" -#include "src/runtime/kernel/arm/fp32/deconvolution_depthwise.h" -#ifdef ENABLE_FP16 -#include "src/runtime/kernel/arm/fp16/convolution_fp16.h" -#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h" -#include "src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.h" -#include "src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.h" -#endif -#include "src/runtime/kernel/arm/int8/deconvolution_int8.h" -#include "src/runtime/kernel/arm/int8/convolution_int8.h" -#include "src/runtime/kernel/arm/int8/convolution_3x3_int8.h" -#include "src/runtime/kernel/arm/int8/convolution_depthwise_int8.h" -#include "src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.h" #include "schema/model_generated.h" #include "src/kernel_factory.h" #include "include/errorcode.h" @@ -42,10 +24,6 @@ using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; using mindspore::schema::ActivationType; using mindspore::schema::PadMode; -using mindspore::schema::PrimitiveType_Conv2D; -using mindspore::schema::PrimitiveType_DeConv2D; -using mindspore::schema::PrimitiveType_DeDepthwiseConv2D; -using mindspore::schema::PrimitiveType_DepthwiseConv2D; namespace mindspore::kernel { ConvolutionBaseCPUKernel::~ConvolutionBaseCPUKernel() { @@ -192,352 +170,4 @@ void ComputeQuantOutRange(ConvParameter *conv_param) { conv_param->conv_quant_arg_.out_act_min_[0] = min; conv_param->conv_quant_arg_.out_act_max_[0] = max; } - -void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param, - InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func) { - if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && - conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) { - *output_unit = SelectOutputUnit(conv_param); - if (*output_unit > 1) { - *use_winograd = true; - int input_unit = conv_param->kernel_h_ + *output_unit - 1; - input_trans_func = GetInputTransFunc(input_unit); - if (input_trans_func == nullptr) { - MS_LOG(INFO) << "No matching input trans func. Turn back to common conv."; - *use_winograd = false; - } - output_trans_func = GetOutputTransFunc(input_unit, *output_unit); - if (output_trans_func == nullptr) { - MS_LOG(INFO) << "No matching output trans func. 
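
The three Concat creators registered above differ only in the class they construct; the same new / Init / delete-on-failure shape repeats for FullConnection, Pad, Pooling and Reshape later in the diff. A hypothetical template that would factor the pattern out is sketched below, under assumed constructor and RET_OK conventions; it is not a refactor the patch itself performs.

#include <new>
#include <vector>

struct OpParameter;  // stand-in forward declarations
struct Tensor;
struct Context;
struct KernelKey;

struct LiteKernelSketch {
  virtual ~LiteKernelSketch() = default;
  virtual int Init() = 0;  // 0 plays the role of RET_OK
};

// Hypothetical helper: one template folds the repeated
// new / Init / delete-on-failure sequence.
template <typename KernelT>
LiteKernelSketch *GenericCreator(const std::vector<Tensor *> &inputs,
                                 const std::vector<Tensor *> &outputs,
                                 OpParameter *param, const Context *ctx,
                                 const KernelKey & /*desc*/) {
  if (param == nullptr) {
    return nullptr;  // the patch logs "Input opParameter is nullptr!" here
  }
  auto *kernel = new (std::nothrow) KernelT(param, inputs, outputs, ctx);
  if (kernel == nullptr) {
    return nullptr;  // allocation failure, logged in the patch
  }
  if (kernel->Init() != 0) {
    delete kernel;  // mirrors the patch: a failed Init frees the kernel
    return nullptr;
  }
  return kernel;
}

A registration would then read REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Concat, GenericCreator<ConcatCPUKernel>), at the cost of the kernel-specific log strings.
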
Turn back to common conv."; - *use_winograd = false; - } - } else { - *use_winograd = false; - } - } else { - *use_winograd = false; - } -} - -bool CheckSupportFP16() { - bool support_fp16 = false; -#ifdef ENABLE_ARM64 - void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_; - if (optimize_op_handler != nullptr) { - support_fp16 = true; - MS_LOG(INFO) << "Support FP16."; - } else { - support_fp16 = false; - MS_LOG(INFO) << "Your machine doesn't support fp16, return back to float32 kernel."; - } -#endif - return support_fp16; -} - -kernel::LiteKernel *CpuConvFloatKernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const Context *ctx) { - auto conv_param = reinterpret_cast(opParameter); - int kernel_h = conv_param->kernel_h_; - int kernel_w = conv_param->kernel_w_; - int stride_h = conv_param->stride_h_; - int stride_w = conv_param->stride_w_; - int dilation_h = conv_param->dilation_h_; - int dilation_w = conv_param->dilation_w_; - conv_param->input_h_ = inputs.front()->Height(); - conv_param->input_w_ = inputs.front()->Width(); - conv_param->output_h_ = outputs.front()->Height(); - conv_param->output_w_ = outputs.front()->Width(); - bool use_winograd; - int out_unit; - InputTransformUnitFunc input_trans_func = nullptr; - OutputTransformUnitFunc output_trans_func = nullptr; - CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func); - bool support_fp16 = CheckSupportFP16(); - - if (kernel_h == 1 && kernel_w == 1) { - auto kernel = new (std::nothrow) Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx); - return kernel; - } else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { - if (support_fp16) { -#ifdef ENABLE_FP16 - auto kernel = new (std::nothrow) Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx); - return kernel; -#endif - } - auto kernel = new (std::nothrow) Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx); - return kernel; - } else if (use_winograd) { - auto kernel = new (std::nothrow) ConvolutionWinogradCPUKernel(opParameter, inputs, outputs, ctx, out_unit); - return kernel; - } else { - if (support_fp16) { -#ifdef ENABLE_FP16 - auto kernel = new (std::nothrow) ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx); - return kernel; -#endif - } - auto kernel = new (std::nothrow) ConvolutionCPUKernel(opParameter, inputs, outputs, ctx); - return kernel; - } -} - -kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const Context *ctx) { - auto conv_param = reinterpret_cast(opParameter); - int kernel_h = conv_param->kernel_h_; - int kernel_w = conv_param->kernel_w_; - int stride_h = conv_param->stride_h_; - int stride_w = conv_param->stride_w_; - int dilation_h = conv_param->dilation_h_; - int dilation_w = conv_param->dilation_w_; - - if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { - auto kernel = new (std::nothrow) Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx); - return kernel; - } else { - auto kernel = new (std::nothrow) ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx); - return kernel; - } -} - -kernel::LiteKernel *CpuConvKernelCreator(const std::vector &inputs, - const std::vector &outputs, OpParameter *opParameter, - const lite::Context *ctx, const kernel::KernelKey &desc) { - MS_ASSERT(opParameter != nullptr); - 
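
CheckSupportFP16, deleted here, probed the dlopen'd optimized-kernel library on every creator call; the patch moves that probe into KernelRegistry::Init() and logs it once. The probe itself reduces to a handle test. A sketch assuming POSIX dlopen and an illustrative library name follows; the real code reads OptimizeModule::GetInstance()->optimized_op_handler_ instead.

#include <dlfcn.h>

// True when the optional ARM64 optimized-kernel library can be loaded.
// Cached in a function-local static so the dlopen cost is paid once,
// much like the singleton handle in the patch.
bool OptimizedOpsAvailable() {
  static void *handle = dlopen("liboptimize.so", RTLD_LAZY | RTLD_LOCAL);
  return handle != nullptr;
}

With FP16 kernels now registered under kNumberTypeFloat16 in their own files, this capability check no longer needs to live inside every float creator.
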
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); - auto input_tensor = inputs.at(kInputIndex); - auto data_type = input_tensor->data_type(); - kernel::LiteKernel *kernel = nullptr; - switch (data_type) { - case kNumberTypeInt8: - kernel = CpuConvInt8KernelCreator(inputs, outputs, opParameter, ctx); - break; - case kNumberTypeFloat32: - kernel = CpuConvFloatKernelCreator(inputs, outputs, opParameter, ctx); - break; - default: - break; - } - - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } - auto ret = kernel->Init(); - if (ret != RET_OK) { - delete kernel; - MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); - return nullptr; - } - return kernel; -} - -kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const Context *ctx) { - auto kernel = new (std::nothrow) ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } - return kernel; -} - -#ifdef ENABLE_FP16 -kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const Context *ctx) { - auto kernel = new (std::nothrow) ConvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } - return kernel; -} -#endif - -kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const Context *ctx) { - auto kernel = new (std::nothrow) ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } - return kernel; -} - -kernel::LiteKernel *CpuConvDwKernelCreator(const std::vector &inputs, - const std::vector &outputs, OpParameter *opParameter, - const lite::Context *ctx, const kernel::KernelKey &desc) { - MS_ASSERT(opParameter != nullptr); - MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); - auto input_tensor = inputs.at(kInputIndex); - auto data_type = input_tensor->data_type(); - kernel::LiteKernel *kernel = nullptr; - switch (data_type) { - case kNumberTypeInt8: - kernel = CpuConvDwInt8KernelCreator(inputs, outputs, opParameter, ctx); - break; - case kNumberTypeFloat32: -#ifdef ENABLE_FP16 - kernel = CpuConvDwFp16KernelCreator(inputs, outputs, opParameter, ctx); -#else - kernel = CpuConvDwFp32KernelCreator(inputs, outputs, opParameter, ctx); -#endif - break; - default: - break; - } - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } - auto ret = kernel->Init(); - if (ret != RET_OK) { - delete kernel; - MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); - return nullptr; - } - return kernel; -} - -kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx) { - auto kernel = new (std::nothrow) DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } - return kernel; -} - -#ifdef ENABLE_FP16 -kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const 
std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx) { - auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } - return kernel; -} -#endif - -kernel::LiteKernel *CpuDeconvDwInt8KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx) { - auto kernel = new (std::nothrow) DeconvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } - return kernel; -} - -kernel::LiteKernel *CpuDeconvDwKernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { - MS_ASSERT(opParameter != nullptr); - MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); - auto input_tensor = inputs.at(kInputIndex); - auto data_type = input_tensor->data_type(); - kernel::LiteKernel *kernel = nullptr; - switch (data_type) { - case kNumberTypeInt8: - kernel = CpuDeconvDwInt8KernelCreator(inputs, outputs, opParameter, ctx); - break; - case kNumberTypeFloat32: -#ifdef ENABLE_FP16 - kernel = CpuDeconvDwFp16KernelCreator(inputs, outputs, opParameter, ctx); -#else - kernel = CpuDeconvDwFp32KernelCreator(inputs, outputs, opParameter, ctx); -#endif - break; - default: - break; - } - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } - auto ret = kernel->Init(); - if (ret != RET_OK) { - delete kernel; - MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); - return nullptr; - } - return kernel; -} - -kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx) { - auto kernel = new (std::nothrow) DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } - return kernel; -} - -kernel::LiteKernel *CpuDeConvInt8KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx) { - auto kernel = new (std::nothrow) DeConvInt8CPUKernel(opParameter, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } - return kernel; -} - -kernel::LiteKernel *CpuDeConvKernelCreator(const std::vector &inputs, - const std::vector &outputs, OpParameter *opParameter, - const lite::Context *ctx, const kernel::KernelKey &desc) { - MS_ASSERT(opParameter != nullptr); - MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); - auto input_tensor = inputs.at(kInputIndex); - auto data_type = input_tensor->data_type(); - kernel::LiteKernel *kernel = nullptr; - switch (data_type) { - case kNumberTypeInt8: - kernel = CpuDeConvInt8KernelCreator(inputs, outputs, opParameter, ctx); - break; -#ifdef ENABLE_FP16 - case kNumberTypeFloat16: - break; -#endif - case kNumberTypeFloat32: - kernel = CpuDeConvFp32KernelCreator(inputs, outputs, opParameter, ctx); - break; - default: - break; - } - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } - auto ret = kernel->Init(); - if (ret != RET_OK) { - delete kernel; - MS_LOG(ERROR) << "Init kernel failed, 
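
Each creator removed in this file re-implemented the same switch over the input tensor's data_type(). After the patch, that dispatch happens once, at the table lookup keyed on (arch, data_type, op_type). Continuing the KeySketch/CreatorIndex sketch from earlier, the lookup site reduces to the following; the fallback comment is an assumption, since the scheduler's policy is outside this diff.

// Continuing the earlier KeySketch / CreatorIndex / CreatorFn sketches.
CreatorFn LookupCreator(CreatorFn *table, const KeySketch &key) {
  if (CreatorFn c = table[CreatorIndex(key)]) {
    return c;  // exact (arch, data_type, op_type) hit
  }
  // Assumption: a caller may rebuild the key with a wider dtype
  // (e.g. fp16 -> fp32) instead of branching inside each creator.
  return nullptr;
}
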
name: " << opParameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); - return nullptr; - } - return kernel; -} - -REG_KERNEL(kCPU, PrimitiveType_Conv2D, CpuConvKernelCreator) -REG_KERNEL(kCPU, PrimitiveType_DeConv2D, CpuDeConvKernelCreator) -REG_KERNEL(kCPU, PrimitiveType_DepthwiseConv2D, CpuConvDwKernelCreator) -REG_KERNEL(kCPU, PrimitiveType_DeDepthwiseConv2D, CpuDeconvDwKernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h index dec5e737344..0640012eef4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h @@ -28,7 +28,6 @@ #include "src/lite_kernel.h" #include "include/context.h" #include "src/runtime/kernel/arm/base/layout_transform.h" -#include "src/runtime/kernel/arm/opclib/optimized_kernel.h" using mindspore::lite::Context; using mindspore::schema::PadMode; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc index 29fc6ec95b4..4f74e943600 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc @@ -32,39 +32,17 @@ int FullconnectionBaseCPUKernel::Init() { return RET_OK; } -kernel::LiteKernel *CpuFullConnectionKernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { +kernel::LiteKernel *CpuFullConnectionInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_Concat); - auto input_tensor = inputs.at(kInputIndex); - auto data_type = input_tensor->data_type(); - kernel::LiteKernel *kernel = nullptr; - switch (data_type) { - case kNumberTypeInt8: - case kNumberTypeUInt8: { - kernel = new (std::nothrow) FullconnectionInt8CPUKernel(opParameter, inputs, outputs, ctx); - if (!kernel) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } - break; - } - - case kNumberTypeFloat32: { - kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx); - if (!kernel) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } - break; - } - - default: - break; + auto kernel = new (std::nothrow) FullconnectionInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (!kernel) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; } - auto ret = kernel->Init(); if (ret != RET_OK) { delete kernel; @@ -75,5 +53,27 @@ kernel::LiteKernel *CpuFullConnectionKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Concat); + auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx); + if (!kernel) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; 
+} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_FullConnection, CpuFullConnectionInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FullConnection, CpuFullConnectionFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.cc b/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.cc index c495f19b0d8..ba8a82cc134 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.cc @@ -15,6 +15,7 @@ */ #include "src/runtime/kernel/arm/base/layout_transform.h" +#include "mindspore/core/utils/log_adapter.h" using mindspore::schema::Format; namespace mindspore::kernel { diff --git a/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.h b/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.h index b09e533bbb6..7a6c4e93f6d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.h @@ -21,7 +21,8 @@ #include #endif #include "src/runtime/kernel/arm/opclib/pack.h" -#include "src/ir/tensor.h" +#include "ir/dtype/type_id.h" +#include "schema/ops_generated.h" namespace mindspore::kernel { typedef void (*LayoutConvertor)(const void *src, void *dst, int batch, int plane, int channel); diff --git a/mindspore/lite/src/runtime/kernel/arm/base/pad.cc b/mindspore/lite/src/runtime/kernel/arm/base/pad.cc index 053e70cee36..723657b6037 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/pad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/pad.cc @@ -32,50 +32,13 @@ kernel::LiteKernel *CpuPadInt8KernelCreator(const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Pad); auto *kernel = new (std::nothrow) PadInt8CPUKernel(opParameter, inputs, outputs, ctx); if (kernel == nullptr) { MS_LOG(ERROR) << "new PadCPUKernel failed."; return nullptr; } - return kernel; -} - -kernel::LiteKernel *CpuPadFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { - auto *kernel = new (std::nothrow) PadCPUKernel(opParameter, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "new PadCPUKernel failed."; - return nullptr; - } - return kernel; -} - -kernel::LiteKernel *CpuPadKernelCreator(const std::vector &inputs, - const std::vector &outputs, OpParameter *opParameter, - const lite::Context *ctx, const kernel::KernelKey &desc) { - MS_ASSERT(opParameter != nullptr); - MS_ASSERT(desc.type == schema::PrimitiveType_Concat); - auto input_tensor = inputs.at(kInputIndex); - auto data_type = input_tensor->data_type(); - kernel::LiteKernel *kernel = nullptr; - switch (data_type) { - case kNumberTypeInt8: - kernel = CpuPadInt8KernelCreator(inputs, outputs, opParameter, ctx, desc); - break; - case kNumberTypeFloat32: - kernel = CpuPadFp32KernelCreator(inputs, outputs, opParameter, ctx, desc); - break; - default: - break; - } - - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } - auto ret = kernel->Init(); if (ret != RET_OK) { delete kernel; @@ -86,5 +49,27 @@ kernel::LiteKernel *CpuPadKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter 
!= nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Pad); + auto *kernel = new (std::nothrow) PadCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new PadCPUKernel failed."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Pad, CpuPadInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Pad, CpuPadFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc index bdeef575a07..767e6126116 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc @@ -81,7 +81,8 @@ int PoolingBaseCPUKernel::Init() { kernel::LiteKernel *CpuPoolingInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, - OpParameter *opParameter, const Context *ctx) { + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { if (opParameter == nullptr) { MS_LOG(ERROR) << "Input opParameter is nullptr!"; return nullptr; @@ -92,50 +93,6 @@ kernel::LiteKernel *CpuPoolingInt8KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const Context *ctx) { - if (opParameter == nullptr) { - MS_LOG(ERROR) << "Input opParameter is nullptr!"; - return nullptr; - } - MS_ASSERT(desc.type == schema::PrimitiveType_Pooling); - auto *kernel = new (std::nothrow) PoolingCPUKernel(opParameter, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "new PoolingCPUKernel fail!"; - return nullptr; - } - return kernel; -} - -kernel::LiteKernel *CpuPoolingKernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { - MS_ASSERT(opParameter != nullptr); - MS_ASSERT(desc.type == schema::PrimitiveType_Pooing); - auto input_tensor = inputs.at(kInputIndex); - auto data_type = input_tensor->data_type(); - kernel::LiteKernel *kernel = nullptr; - switch (data_type) { - case kNumberTypeInt8: - case kNumberTypeUInt8: - kernel = CpuPoolingInt8KernelCreator(inputs, outputs, opParameter, ctx); - break; - case kNumberTypeFloat32: - kernel = CpuPoolingFp32KernelCreator(inputs, outputs, opParameter, ctx); - break; - default: - break; - } - - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } auto ret = kernel->Init(); if (ret != RET_OK) { delete kernel; @@ -146,5 +103,30 @@ kernel::LiteKernel *CpuPoolingKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Pooling); + auto *kernel = new (std::nothrow) PoolingCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new PoolingCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << 
schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Pooling, CpuPoolingInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Pooling, CpuPoolingFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc index f3c7d79a425..5065a8aa373 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc @@ -35,62 +35,18 @@ int ReshapeBaseCPUKernel::Init() { kernel::LiteKernel *CpuReshapeInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, - OpParameter *opParameter, const Context *ctx) { + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { if (opParameter == nullptr) { MS_LOG(ERROR) << "Input opParameter is nullptr!"; return nullptr; } - MS_ASSERT(desc.type == schema::PrimitiveType_Concat); - auto *kernel = new(std::nothrow) ReshapeInt8CPUKernel(opParameter, inputs, outputs, ctx); + MS_ASSERT(desc.type == schema::PrimitiveType_Reshape); + auto *kernel = new (std::nothrow) ReshapeInt8CPUKernel(opParameter, inputs, outputs, ctx); if (kernel == nullptr) { MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; return nullptr; } - return kernel; -} - -kernel::LiteKernel *CpuReshapeFp32OrInt32KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const Context *ctx) { - if (opParameter == nullptr) { - MS_LOG(ERROR) << "Input opParameter is nullptr!"; - return nullptr; - } - MS_ASSERT(desc.type == schema::PrimitiveType_Concat); - auto *kernel = new(std::nothrow) ReshapeCPUKernel(opParameter, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; - return nullptr; - } - return kernel; -} - -kernel::LiteKernel *CpuReshapeKernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { - MS_ASSERT(opParameter != nullptr); - MS_ASSERT(desc.type == schema::PrimitiveType_Concat); - auto input_tensor = inputs.at(kInputIndex); - auto data_type = input_tensor->data_type(); - kernel::LiteKernel *kernel = nullptr; - switch (data_type) { - case kNumberTypeInt8: - case kNumberTypeUInt8: - kernel = CpuReshapeInt8KernelCreator(inputs, outputs, opParameter, ctx); - break; - case kNumberTypeInt32: - case kNumberTypeFloat32: - kernel = CpuReshapeFp32OrInt32KernelCreator(inputs, outputs, opParameter, ctx); - break; - default: - break; - } - - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel is nullptr."; - return nullptr; - } auto ret = kernel->Init(); if (ret != RET_OK) { delete kernel; @@ -101,6 +57,55 @@ kernel::LiteKernel *CpuReshapeKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Reshape); + auto *kernel = new (std::nothrow) ReshapeCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << 
schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} +kernel::LiteKernel *CpuReshapeFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Reshape); + auto *kernel = new (std::nothrow) ReshapeCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Reshape, CpuReshapeInt8KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Reshape, CpuReshapeInt32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reshape, CpuReshapeFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc index 20bbb7ecd17..22039321154 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc @@ -16,6 +16,8 @@ #include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h" #include "src/runtime/kernel/arm/opclib/fp16/conv_fp16.h" +#include "src/runtime/kernel/arm/opclib/fp16/winograd_transform_fp16.h" +#include "src/runtime/kernel/arm/opclib/fp16/pack_fp16.h" #include "src/runtime/kernel/arm/base/layout_transform.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" @@ -165,7 +167,6 @@ void Convolution3x3FP16CPUKernel::ConfigInputOutput() { } int Convolution3x3FP16CPUKernel::Init() { - ConvolutionBaseCPUKernel::Init(); auto ret = ConvolutionBaseCPUKernel::Init(); if (ret != RET_OK) { MS_LOG(ERROR) << "ConvolutionBase init failed."; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h index 5cf60d1c1e1..b2db5674aa3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h @@ -20,8 +20,6 @@ #include #include #include "src/lite_kernel.h" - -#include "src/runtime/kernel/arm/opclib/winograd_transform.h" #include "src/runtime/kernel/arm/base/convolution_base.h" #include "src/runtime/kernel/arm/opclib/optimized_kernel.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc index dbbdf19fcc9..6b00c60b591 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc @@ -15,6 +15,7 @@ */ #include "src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.h" +#include "src/runtime/kernel/arm/opclib/fp16/pack_fp16.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "include/errorcode.h" @@ -161,4 +162,27 @@ int ConvolutionDepthwiseFp16CPUKernel::Run() { return RET_OK; } + +kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector &inputs, 
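
The Reshape registrations above map Int32 and Float32 to creators that construct the same ReshapeCPUKernel, which shows that per-dtype keys do not force per-dtype kernel classes. In the registrar sketch from earlier, that is simply two slots holding one function; the enum values below are illustrative.

void *CreateReshape() { return nullptr; }  // stand-in creator

// Same creator stored under two dtype keys, as Reshape does above.
static RegistrarSketch g_reshape_i32(/*kCPU*/ 0, /*Int32*/ 34, /*Reshape*/ 22, CreateReshape);
static RegistrarSketch g_reshape_f32(/*kCPU*/ 0, /*Float32*/ 43, /*Reshape*/ 22, CreateReshape);
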
+ const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); + auto kernel = new (std::nothrow) ConvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DepthwiseConv2D, CpuConvDwFp16KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc index c73d85a9d1d..c4a775a0304 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc @@ -15,7 +15,9 @@ */ #include "src/runtime/kernel/arm/fp16/convolution_fp16.h" +#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h" #include "src/runtime/kernel/arm/opclib/fp16/conv_fp16.h" +#include "src/runtime/kernel/arm/opclib/fp16/pack_fp16.h" #include "src/runtime/kernel/arm/base/layout_transform.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" @@ -218,5 +220,42 @@ int ConvolutionFP16CPUKernel::Run() { } return RET_OK; } -} // namespace mindspore::kernel +kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); + auto conv_param = reinterpret_cast(opParameter); + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + conv_param->input_h_ = inputs.front()->Height(); + conv_param->input_w_ = inputs.front()->Width(); + conv_param->output_h_ = outputs.front()->Height(); + conv_param->output_w_ = outputs.front()->Width(); + kernel::LiteKernel *kernel; + if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { + kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx); + } else { + kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx); + } + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create conv fp16 kernel failed."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Conv2D, CpuConvFp16KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc index f32de057814..ce758358d7d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc +++ 
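
CpuConvFp16KernelCreator above picks Convolution3x3FP16CPUKernel purely from the convolution geometry, and CpuConvFp32KernelCreator later in the diff applies the exact same test before its Winograd check. The predicate is compact enough to isolate; this sketch reuses the ConvParameter field names from the patch.

struct ConvGeomSketch {  // subset of ConvParameter
  int kernel_h_, kernel_w_;
  int stride_h_, stride_w_;
  int dilation_h_, dilation_w_;
};

// Fast 3x3 path: square 3x3 kernel, unit stride, no dilation.
bool Use3x3Path(const ConvGeomSketch &p) {
  return p.kernel_h_ == 3 && p.kernel_w_ == 3 &&
         p.stride_h_ == 1 && p.stride_w_ == 1 &&
         p.dilation_h_ == 1 && p.dilation_w_ == 1;
}
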
b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc @@ -15,6 +15,7 @@ */ #include "src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.h" +#include "src/runtime/kernel/arm/opclib/fp16/pack_fp16.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "include/errorcode.h" @@ -24,7 +25,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; -using mindspore::schema::PrimitiveType_DepthwiseConv2D; +using mindspore::schema::PrimitiveType_DeDepthwiseConv2D; namespace mindspore::kernel { int DeconvolutionDepthwiseFp16CPUKernel::InitSlideParam() { @@ -171,4 +172,27 @@ int DeconvolutionDepthwiseFp16CPUKernel::Run() { conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); return RET_OK; } + +kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); + auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DeDepthwiseConv2D, CpuDeconvDwFp16KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc index b12b5ce5656..9d726e7ad7e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc @@ -93,6 +93,10 @@ kernel::LiteKernel *CpuActivationFp32KernelCreator(const std::vectorInit(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " @@ -101,6 +105,5 @@ kernel::LiteKernel *CpuActivationFp32KernelCreator(const std::vectorthread_num_ = ctx->threadNum; auto *kernel = new (std::nothrow) AddNCPUKernel(op_parameter, inputs, outputs); if (kernel == nullptr) { @@ -117,5 +117,5 @@ kernel::LiteKernel *CpuAddNFp32KernelCreator(const std::vector &outputs, OpParameter *parameter, const lite::Context *ctx, const kernel::KernelKey &desc) { - MS_ASSERT(parameter); - MS_ASSERT(inputs.at(0)); - auto data_type = inputs.at(0)->data_type(); - kernel::LiteKernel *kernel = nullptr; - switch (data_type) { - case kNumberTypeFloat32: - kernel = new (std::nothrow) ArithmeticCPUKernel(parameter, inputs, outputs, ctx); - break; - case kNumberTypeInt8: - if (desc.type == schema::PrimitiveType_Add) { - kernel = new (std::nothrow) QuantizedAddCPUKernel(parameter, inputs, outputs, ctx); - } else if (desc.type == schema::PrimitiveType_Mul) { - kernel = new (std::nothrow) MulInt8CPUKernel(parameter, inputs, outputs, ctx); - } else { - } - break; - default: - break; - } + MS_ASSERT(parameter != nullptr); + auto kernel = new (std::nothrow) ArithmeticCPUKernel(parameter, inputs, outputs, ctx); if (kernel == nullptr) { MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; return nullptr; @@ -145,24 +128,23 @@ kernel::LiteKernel 
*CpuArithmeticFp32KernelCreator(const std::vectorin_shape1_[i] = 1; bias_param_->out_shape_[i] = dims[i]; } - bias_param_->in_shape1_[bias_param_->ndim_ - 1] = dims[bias_param_->ndim_ - 1]; + bias_param_->in_shape1_[bias_param_->ndim_ - 1] = dims[bias_param_->ndim_ - 1]; return RET_OK; } @@ -61,19 +61,7 @@ kernel::LiteKernel *CpuBiasFp32KernelCreator(const std::vectordata_type(); - kernel::LiteKernel *kernel = nullptr; - switch (data_type) { - case kNumberTypeFloat32: - kernel = new (std::nothrow) BiasCPUKernel(parameter, inputs, outputs); - break; - case kNumberTypeInt8: - kernel = new (std::nothrow) BiasAddInt8CPUKernel(parameter, inputs, outputs, ctx); - break; - default: - break; - } + auto kernel = new (std::nothrow) BiasCPUKernel(parameter, inputs, outputs); if (kernel == nullptr) { MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; return nullptr; @@ -89,6 +77,5 @@ kernel::LiteKernel *CpuBiasFp32KernelCreator(const std::vectorkernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && + conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) { + *output_unit = SelectOutputUnit(conv_param); + if (*output_unit > 1) { + *use_winograd = true; + int input_unit = conv_param->kernel_h_ + *output_unit - 1; + input_trans_func = GetInputTransFunc(input_unit); + if (input_trans_func == nullptr) { + MS_LOG(INFO) << "No matching input trans func. Turn back to common conv."; + *use_winograd = false; + } + output_trans_func = GetOutputTransFunc(input_unit, *output_unit); + if (output_trans_func == nullptr) { + MS_LOG(INFO) << "No matching output trans func. Turn back to common conv."; + *use_winograd = false; + } + } else { + *use_winograd = false; + } + } else { + *use_winograd = false; + } +} + +kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); + auto conv_param = reinterpret_cast(opParameter); + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + conv_param->input_h_ = inputs.front()->Height(); + conv_param->input_w_ = inputs.front()->Width(); + conv_param->output_h_ = outputs.front()->Height(); + conv_param->output_w_ = outputs.front()->Width(); + bool use_winograd; + int out_unit; + InputTransformUnitFunc input_trans_func = nullptr; + OutputTransformUnitFunc output_trans_func = nullptr; + CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func); + kernel::LiteKernel *kernel; + if (kernel_h == 1 && kernel_w == 1) { + kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx); + } else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { + kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx); + } else if (use_winograd) { + kernel = new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(opParameter, inputs, outputs, ctx, out_unit); + } else { + kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx); + } + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = 
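
CheckIfUseWinograd, reinstated here next to the fp32 kernels it serves, gates Winograd on a square kernel with unit stride and dilation, then sizes the input tile from the chosen output unit. The tile arithmetic is worth spelling out; SelectOutputUnit's policy is not shown in this diff, so treat the example values as illustrative.

// Winograd F(m, r): an m x m output tile from an r x r kernel reads an
// (m + r - 1) x (m + r - 1) input tile; hence the patch computes
// input_unit = kernel_h_ + output_unit - 1.
inline int WinogradInputUnit(int kernel, int output_unit) {
  return kernel + output_unit - 1;
}
// Example: kernel 3, output_unit 4 -> 6x6 input tiles, i.e. F(4, 3).
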
kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Conv2D, CpuConvFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc index 27cb024a4bd..75856bfba36 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc @@ -145,5 +145,29 @@ int ConvolutionDepthwiseCPUKernel::Run() { } return RET_OK; } + + +kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); + auto kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DepthwiseConv2D, CpuConvDwFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc index 32abb501b9a..a2f8389b848 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc @@ -37,12 +37,12 @@ int CropLaunch(int thread_id, LiteParallelGroupEnv *penv, void *cdata) { auto kernel = reinterpret_cast(cdata); return kernel->CropParallelRun(thread_id); } -} +} // namespace int CropCPUKernel::Init() { schema::Format input0_format = inputs_[0]->GetFormat(); if (input0_format != schema::Format_NCHW && input0_format != schema::Format_NHWC) { - MS_LOG(ERROR) << "Unsupport format " << input0_format; + MS_LOG(ERROR) << "Unsupport format " << input0_format; return RET_FORMAT_ERR; } outputs_[0]->SetFormat(input0_format); @@ -90,7 +90,7 @@ kernel::LiteKernel *CpuCropFp32KernelCreator(const std::vectorthread_num_ = ctx->threadNum; auto *kernel = new (std::nothrow) CropCPUKernel(op_parameter, inputs, outputs); if (kernel == nullptr) { @@ -108,5 +108,5 @@ kernel::LiteKernel *CpuCropFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); + auto kernel = new (std::nothrow) kernel::DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DeConv2D, 
CpuDeConvFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc index 82ca58b2a9c..7f09307a43a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc @@ -24,7 +24,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; -using mindspore::schema::PrimitiveType_DepthwiseConv2D; +using mindspore::schema::PrimitiveType_DeDepthwiseConv2D; namespace mindspore::kernel { int DeconvolutionDepthwiseCPUKernel::InitSlideParam() { @@ -158,5 +158,28 @@ int DeconvolutionDepthwiseCPUKernel::Run() { } return RET_OK; } + +kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); + auto kernel = new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DeDepthwiseConv2D, CpuDeconvDwFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc index a96c598b0b6..c8b6d028f93 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc @@ -67,6 +67,7 @@ kernel::LiteKernel *CpuDepthToSpaceFp32KernelCreator(const std::vectorInit(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " @@ -92,6 +96,6 @@ kernel::LiteKernel *CpuExpandsDimsFp32KernelCreator(const std::vectorInit(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " @@ -102,6 +106,6 @@ kernel::LiteKernel *CpuFillFp32KernelCreator(const std::vectorInit(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " @@ -66,6 +70,6 @@ kernel::LiteKernel *CpuFlattenFp32KernelCreator(const std::vectorInit(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; } return kernel; } -REG_KERNEL(kCPU, PrimitiveType_FusedBatchNorm, CpuFusedBatchnormKernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FusedBatchNorm, CpuFusedBatchnormKernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc index dd0d899f8f1..fd073a9f00c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc @@ -121,6 +121,6 @@ kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vectorInit(); if (ret != 
RET_OK) { MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " @@ -105,6 +109,5 @@ kernel::LiteKernel *CpuLocalResponseNormFp32KernelCreator(const std::vectorInit(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " @@ -48,6 +52,6 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vectorInit(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " @@ -53,6 +57,6 @@ kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vectorInit(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " @@ -53,6 +57,6 @@ kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector(opParameter), inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new PowerCPUKernel fail!"; + return nullptr; + } auto ret = kernel->Init(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " @@ -76,5 +80,5 @@ kernel::LiteKernel *CpuPowerFp32KernelCreator(const std::vectorInit(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " @@ -64,7 +68,7 @@ kernel::LiteKernel *CpuRangeFp32KernelCreator(const std::vectorInit(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " @@ -63,6 +67,6 @@ kernel::LiteKernel *CpuRankFp32KernelCreator(const std::vectorname_ << ", type: " @@ -156,6 +157,6 @@ kernel::LiteKernel *CpuReverseFp32KernelCreator(const std::vectorname_; @@ -111,6 +112,6 @@ kernel::LiteKernel *CpuReverseSequenceFp32KernelCreator(const std::vector(cdata); return kernel->SliceParallelRun(thread_id); } -} +} // namespace int SliceCPUKernel::Init() { auto *param = reinterpret_cast(opParameter); @@ -106,7 +106,7 @@ kernel::LiteKernel *CpuSliceFp32KernelCreator(const std::vectorthread_num_ = ctx->threadNum; auto *kernel = new (std::nothrow) SliceCPUKernel(op_parameter, inputs, outputs); if (kernel == nullptr) { @@ -124,5 +124,5 @@ kernel::LiteKernel *CpuSliceFp32KernelCreator(const std::vectorInit(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " @@ -74,6 +78,6 @@ kernel::LiteKernel *CpuSoftmaxFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *parameter, const lite::Context *ctx, const KernelKey &desc) { - MS_EXCEPTION_IF_NULL(parameter); + MS_ASSERT(parameter != nullptr); MS_ASSERT(desc.type == PrimitiveType_Tile); auto *kernel = new (std::nothrow) TopKCPUKernel(parameter, inputs, outputs); - MS_EXCEPTION_IF_NULL(kernel); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new TopKCPUKernel fail!"; + return nullptr; + } auto ret = kernel->Init(); if (ret != RET_OK) { @@ -68,6 +71,6 @@ kernel::LiteKernel *CpuTopKFp32KernelCreator(const std::vectorInit(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " @@ -91,6 +95,6 @@ kernel::LiteKernel *CpuUnsqueezeFp32KernelCreator(const std::vectorInit(); if (0 != ret) { MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ @@ -142,5 +144,6 @@ kernel::LiteKernel *CpuAddInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); + auto kernel = new (std::nothrow) 
kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_DepthwiseConv2D, CpuConvDwInt8KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc index 80d12848715..48502d19fc4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc @@ -15,6 +15,7 @@ */ #include "src/runtime/kernel/arm/int8/convolution_int8.h" +#include "src/runtime/kernel/arm/int8/convolution_3x3_int8.h" #include "src/runtime/kernel/arm/opclib/int8/conv_int8.h" #include "src/runtime/kernel/arm/base/layout_transform.h" #include "schema/model_generated.h" @@ -36,7 +37,7 @@ void ConvolutionInt8CPUKernel::CheckSupportOptimize() { support_optimize_ = false; #endif -#ifdef __aarch64__ +#ifdef ENABLE_ARM64 void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_; if (optimize_op_handler != nullptr) { dlerror(); @@ -383,4 +384,39 @@ int ConvolutionInt8CPUKernel::Run() { } return RET_OK; } + +kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, + const std::vector<lite::tensor::Tensor *> &outputs, + OpParameter *opParameter, const Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); + auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + kernel::LiteKernel *kernel; + if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { + kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx); + } else { + kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx); + } + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Conv2D, CpuConvInt8KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.h index cfae7492ec2..cc264392ee0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.h @@ -72,4 +72,3 @@ class ConvolutionInt8CPUKernel : public ConvolutionBaseCPUKernel { } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_INT8_H_ -
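A sketch for orientation, not part of the patch: every creator added in these files repeats the same shape, so the registration boilerplate can be read as one pattern: allocate with new (std::nothrow), return nullptr on allocation failure, call Init(), and delete the kernel when Init() fails so callers never receive a half-initialized object. The CreateAndInit helper below is hypothetical; only the pattern it condenses comes from the patch.

// Hypothetical condensation of the create/Init/cleanup pattern shared by
// CpuConvInt8KernelCreator above and the other creators in this patch.
template <typename KernelT>
kernel::LiteKernel *CreateAndInit(OpParameter *param, const std::vector<lite::tensor::Tensor *> &inputs,
                                  const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx) {
  auto *kernel = new (std::nothrow) KernelT(param, inputs, outputs, ctx);
  if (kernel == nullptr) {
    return nullptr;  // allocation failed; nothing to clean up
  }
  if (kernel->Init() != RET_OK) {
    delete kernel;  // do not leak a kernel whose Init() failed
    return nullptr;
  }
  return kernel;
}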
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc index 4cc9a774584..52a0b1ffde4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc @@ -25,6 +25,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DeDepthwiseConv2D; namespace mindspore::kernel { int DeconvolutionDepthwiseInt8CPUKernel::InitWeightBias() { @@ -63,9 +64,9 @@ int DeconvolutionDepthwiseInt8CPUKernel::InitSlideParam() { sliding->in_h_step_ = conv_param_->input_w_ * C4NUM; sliding->in_sh_step_ = conv_param_->input_w_ * C4NUM * conv_param_->stride_h_; // stride H - sliding->in_sw_step_ = C4NUM * conv_param_->stride_h_; // stride W + sliding->in_sw_step_ = C4NUM * conv_param_->stride_h_; // stride W sliding->in_kh_step_ = conv_param_->input_w_ * C4NUM * conv_param_->dilation_h_; // kernel H - sliding->in_kw_step_ = C4NUM * conv_param_->dilation_w_; // kernel W + sliding->in_kw_step_ = C4NUM * conv_param_->dilation_w_; // kernel W return RET_OK; } @@ -171,4 +172,27 @@ int DeconvolutionDepthwiseInt8CPUKernel::Run() { } return RET_OK; } + +kernel::LiteKernel *CpuDeconvDwInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, + const std::vector<lite::tensor::Tensor *> &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); + auto kernel = new (std::nothrow) kernel::DeconvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_DeDepthwiseConv2D, CpuDeconvDwInt8KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc index a3f10aad238..cd69a8911ca 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc @@ -17,8 +17,10 @@ #include "src/runtime/kernel/arm/int8/deconvolution_int8.h" #include "src/runtime/kernel/arm/opclib/quantization/fixed_point.h" #include "src/runtime/runtime_api.h" +#include "src/kernel_registry.h" using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_MEMORY_FAILED; using mindspore::lite::RET_OK; @@ -216,5 +218,27 @@ int DeConvInt8CPUKernel::Run() { return RET_OK; } -} // namespace mindspore::kernel +kernel::LiteKernel *CpuDeConvInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, + const std::vector<lite::tensor::Tensor *> &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); + auto kernel = new (std::nothrow) kernel::DeConvInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret
!= RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_DeConv2D, CpuDeConvInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc index 00e2fe5d518..d51751d33d6 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc @@ -129,4 +129,5 @@ kernel::LiteKernel *CpuMulInt8KernelCreator(const std::vector /*conv depthwise fp16 begin*/ @@ -299,4 +298,3 @@ void DeconvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const f } /*deconv depthwise fp16 end*/ -#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_depthwise_fp16.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_depthwise_fp16.h index a0a8cac8ebe..e686df1f08a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_depthwise_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_depthwise_fp16.h @@ -20,7 +20,6 @@ #include "src/runtime/kernel/arm/opclib/conv_parameter.h" #include "src/runtime/kernel/arm/opclib/fp32/conv_depthwise.h" -#ifdef ENABLE_FP16 void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); @@ -28,6 +27,5 @@ void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const flo void DeconvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); -#endif #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_CONV_DEPTHWISE_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.cc index a112cd15536..fb658287c0d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.cc @@ -15,20 +15,17 @@ */ #include "src/runtime/kernel/arm/opclib/fp16/conv_fp16.h" #include -#include "src/runtime/kernel/arm/opclib/pack.h" -#include "src/runtime/kernel/arm/opclib/winograd_transform.h" +#include "src/runtime/kernel/arm/opclib/fp16/pack_fp16.h" +#include "src/runtime/kernel/arm/opclib/fp16/winograd_transform_fp16.h" extern "C" { #ifdef ENABLE_ARM64 -#ifdef ENABLE_FP16 void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, size_t relu6); #endif -#endif } -#ifdef ENABLE_FP16 #ifndef ENABLE_NEON void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC4, size_t relu, @@ -215,5 +212,5 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 } } } -#endif + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.h index 457e483a983..a5426dd6f99 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.h +++ 
b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.h @@ -16,12 +16,9 @@ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_CONV_FP16_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_CONV_FP16_H_ -#ifdef ENABLE_FP16 #include -#endif #include "src/runtime/kernel/arm/opclib/conv_parameter.h" -#ifdef ENABLE_FP16 #ifndef ENABLE_NEON void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, @@ -36,7 +33,6 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data, float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out, int task_id, ConvParameter *conv_param); -#endif #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_CONV_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/pack_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/pack_fp16.cc new file mode 100644 index 00000000000..4d6bd7344ef --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/pack_fp16.cc @@ -0,0 +1,342 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/runtime/kernel/arm/opclib/fp16/pack_fp16.h" +#include +#include + +void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num, + int block_index) { + // input format : nhwc + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int pad_h = conv_param->pad_h_; + int pad_w = conv_param->pad_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int channel_block = UP_DIV(in_channel, 4); + int kernel_plane = kernel_h * kernel_w; + + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * stride_h - pad_h; + int input_w = block_start % out_w * stride_w - pad_w; + for (int j = 0; j < kernel_h; j++) { + int input_y = input_h + j * dilation_h; + if (input_y < 0 || input_y >= in_h) { + continue; + } + int input_y_stride = input_y * in_w * channel_block * C4NUM; + for (int n = 0; n < kernel_w; n++) { + int input_x = input_w + n * dilation_w; + if (input_x < 0 || input_x >= in_w) { + continue; + } + int input_x_stride = input_y_stride + input_x * channel_block * C4NUM; + int input_plane_offset = (j * kernel_w + n) * 16 * C4NUM * channel_block + i * C4NUM; + for (int m = 0; m < channel_block; m++) { + int channel_block_stride = input_x_stride + m * C4NUM; + int channel_block_offset = input_plane_offset + m * 16 * C4NUM; + (packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0]; + (packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1]; + (packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2]; + (packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3]; + } // channel_block loop + } // kernel_w loop + } // kernel_h loop + } // tile num loop +} + +void PackWeightFp16(float16_t *weight_data, ConvParameter *conv_param, float16_t *packed_weight) { + // original weight format : ohwi + int tile_num = 8; + int inchannel_block = 4; + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_channel = conv_param->input_channel_; + int out_channel = conv_param->output_channel_; + int kernel_block = UP_DIV(out_channel, tile_num); + int channel_block = UP_DIV(in_channel, inchannel_block); + int kernel_plane = kernel_h * kernel_w; + int pack_weight_size = kernel_block * channel_block * tile_num * inchannel_block * kernel_plane; + + int unit_size = tile_num * inchannel_block; + int block_size = pack_weight_size / kernel_block; + + for (int m = 0; m < kernel_plane; m++) { + int kernel_plane_stride = m * in_channel; + int packed_kernel_plane_stride = m * unit_size * channel_block; + for (int i = 0; i < channel_block; i++) { + int channel_block_stride = kernel_plane_stride + i * inchannel_block; + int packed_channel_block_size = packed_kernel_plane_stride + i * unit_size; + int ic_remainder = in_channel - i * inchannel_block; + int real_ic_num = ic_remainder < inchannel_block ? 
ic_remainder : inchannel_block; + for (int h = 0; h < real_ic_num; h++) { + int block_stride = channel_block_stride + h; + int packed_block_stride = packed_channel_block_size + h * tile_num; + for (int j = 0; j < kernel_block; j++) { + int kernel_block_stride = block_stride + j * tile_num * kernel_plane * in_channel; + int packed_kernel_block_size = packed_block_stride + j * block_size; + int oc_remainder = out_channel - j * tile_num; + int real_oc_num = oc_remainder < tile_num ? oc_remainder : tile_num; + for (int k = 0; k < real_oc_num; k++) { + float16_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel; + float16_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k; + *packed_data_ptr = *origin_data_ptr; + } + } // kernel block loop + } // inchannel block loop + } // channel block loop + } // kernel plane loop +} + +void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) { + // origin weight format : ohwi + int input_channel = conv_param->input_channel_; + int ic8 = UP_DIV(input_channel, C8NUM); + int output_channel = conv_param->output_channel_; + int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; + + for (int k = 0; k < kernel_plane; k++) { + int src_kernel_offset = k * input_channel; + int dst_kernel_offset = k * C8NUM; + for (int o = 0; o < output_channel; o++) { + int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; + int dst_oc_offset = dst_kernel_offset + o * ic8 * kernel_plane * C8NUM; + for (int i = 0; i < input_channel; i++) { + int c8_block_num = i / C8NUM; + int c8_block_rem = i % C8NUM; + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + c8_block_num * kernel_plane * C8NUM + c8_block_rem; + (packed_weight_data + dst_ic_offset)[0] = (origin_weight_data + src_ic_offset)[0]; + } + } + } +} + +void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) { + // origin weight format : ohwi + int input_channel = conv_param->input_channel_; + int ic4 = UP_DIV(input_channel, C4NUM); + int output_channel = conv_param->output_channel_; + int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; + + for (int k = 0; k < kernel_plane; k++) { + int src_kernel_offset = k * input_channel; + int dst_kernel_offset = k * C4NUM; + for (int o = 0; o < output_channel; o++) { + int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; + int dst_oc_offset = dst_kernel_offset + o * ic4 * kernel_plane * C4NUM; + for (int i = 0; i < input_channel; i++) { + int c4_block_num = i / C4NUM; + int c4_block_rem = i % C4NUM; + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + c4_block_num * kernel_plane * C4NUM + c4_block_rem; + (packed_weight_data + dst_ic_offset)[0] = (origin_weight_data + src_ic_offset)[0]; + } + } + } +} + +void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c4 * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C4NUM; + for (int i = 0; i < channel; i++) { + int c4_block_num = i / C4NUM; + int c4_block_rem = i % C4NUM; + int src_ic_offset = src_kernel_offset + i; + int dst_ic_offset = dst_kernel_offset + c4_block_num * plane * C4NUM + 
c4_block_rem; + ((float16_t *)dst + dst_ic_offset)[0] = ((float16_t *)src + src_ic_offset)[0]; + } + } + } +} + +void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c4 * C4NUM; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_rem = c % C4NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int ic4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = ic4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc4_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + memcpy((float16_t *)dst + nhwc4_batch_offset + i * ic4 * C4NUM, (float16_t *)src + batch_offset + i * channel, + channel * sizeof(float16_t)); + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float16_t); + memcpy(dst, src, ori_input_size); + } +} + +void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int nhwc4_batch_offset = 0; + int ic4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = ic4 * C4NUM * plane; + + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int c = 0; c < channel; c++) { + int src_c_offset = batch_offset + c * plane; + int dst_c_offset = nhwc4_batch_offset + c; + for (int i = 0; i < plane; i++) { + int src_plane_offset = src_c_offset + i; + int dst_plane_offset = dst_c_offset + i * ic4 * C4NUM; + ((float16_t *)dst)[dst_plane_offset] = ((float16_t *)src)[src_plane_offset]; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } +} + +void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int 
dst_kernel_offset = dst_c_offset + k * channel; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c * plane; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c8 * C8NUM; + for (int c = 0; c < channel; c++) { + int c8_block_num = c / C8NUM; + int c8_block_rem = c % C8NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c8_block_num * plane * C8NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C8NUM * k + c8_block_rem; + (dst + dst_kernel_offset)[0] = (float16_t)(src + src_kernel_offset)[0]; + } + } + } +} + +void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int nhwc8_batch_unit_offset = c8 * C8NUM * plane; + int nhwc8_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + for (int c = 0; c < channel; c++) { + (dst + nhwc8_batch_offset + i * c8 * C8NUM)[c] = (float16_t)(src + batch_offset + i * channel)[c]; + } + } + nhwc8_batch_offset += nhwc8_batch_unit_offset; + } +} + +void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int nhwc_batch_unit_offset = channel * plane; + int nhwc_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * c8 * C8NUM * plane; + for (int i = 0; i < plane; i++) { + for (int c = 0; c < channel; c++) { + (dst + nhwc_batch_offset + i * channel)[c] = (float)(src + batch_offset + i * c8 * C8NUM)[c]; + } + } + nhwc_batch_offset += nhwc_batch_unit_offset; + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/pack_fp16.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/pack_fp16.h new file mode 100644 index 00000000000..2aa57548112 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/pack_fp16.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_PACK_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_PACK_FP16_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" + +void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num, + int block_index); + +void PackWeightFp16(float16_t *weight_data, ConvParameter *conv_param, float16_t *packed_weight); + +void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param); + +void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param); + +void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC8HW8ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_PACK_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/winograd_transform_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/winograd_transform_fp16.cc new file mode 100644 index 00000000000..2a3c84c3f80 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/winograd_transform_fp16.cc @@ -0,0 +1,527 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/runtime/kernel/arm/opclib/fp16/winograd_transform_fp16.h" + +// for fp16 convolution 3x3 filter/input/output transform F(4,3) +void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step) { + float16x4_t d00 = vld1_f16(tmp_data); + float16x4_t d01 = vld1_f16(tmp_data + 4); + float16x4_t d02 = vld1_f16(tmp_data + 2 * 4); + float16x4_t d03 = vld1_f16(tmp_data + 3 * 4); + float16x4_t d04 = vld1_f16(tmp_data + 4 * 4); + float16x4_t d05 = vld1_f16(tmp_data + 5 * 4); + + float16x4_t d10 = vld1_f16(tmp_data + 6 * 4); + float16x4_t d11 = vld1_f16(tmp_data + 7 * 4); + float16x4_t d12 = vld1_f16(tmp_data + 8 * 4); + float16x4_t d13 = vld1_f16(tmp_data + 9 * 4); + float16x4_t d14 = vld1_f16(tmp_data + 10 * 4); + float16x4_t d15 = vld1_f16(tmp_data + 11 * 4); + + float16x4_t d20 = vld1_f16(tmp_data + 12 * 4); + float16x4_t d21 = vld1_f16(tmp_data + 13 * 4); + float16x4_t d22 = vld1_f16(tmp_data + 14 * 4); + float16x4_t d23 = vld1_f16(tmp_data + 15 * 4); + float16x4_t d24 = vld1_f16(tmp_data + 16 * 4); + float16x4_t d25 = vld1_f16(tmp_data + 17 * 4); + + float16x4_t d30 = vld1_f16(tmp_data + 18 * 4); + float16x4_t d31 = vld1_f16(tmp_data + 19 * 4); + float16x4_t d32 = vld1_f16(tmp_data + 20 * 4); + float16x4_t d33 = vld1_f16(tmp_data + 21 * 4); + float16x4_t d34 = vld1_f16(tmp_data + 22 * 4); + float16x4_t d35 = vld1_f16(tmp_data + 23 * 4); + + float16x4_t d40 = vld1_f16(tmp_data + 24 * 4); + float16x4_t d41 = vld1_f16(tmp_data + 25 * 4); + float16x4_t d42 = vld1_f16(tmp_data + 26 * 4); + float16x4_t d43 = vld1_f16(tmp_data + 27 * 4); + float16x4_t d44 = vld1_f16(tmp_data + 28 * 4); + float16x4_t d45 = vld1_f16(tmp_data + 29 * 4); + + float16x4_t d50 = vld1_f16(tmp_data + 30 * 4); + float16x4_t d51 = vld1_f16(tmp_data + 31 * 4); + float16x4_t d52 = vld1_f16(tmp_data + 32 * 4); + float16x4_t d53 = vld1_f16(tmp_data + 33 * 4); + float16x4_t d54 = vld1_f16(tmp_data + 34 * 4); + float16x4_t d55 = vld1_f16(tmp_data + 35 * 4); + + float16x4_t t00 = vadd_f16(vsub_f16(vmul_n_f16(d00, 4), vmul_n_f16(d20, 5)), d40); + float16x4_t t01 = vadd_f16(vsub_f16(vmul_n_f16(d01, 4), vmul_n_f16(d21, 5)), d41); + float16x4_t t02 = vadd_f16(vsub_f16(vmul_n_f16(d02, 4), vmul_n_f16(d22, 5)), d42); + float16x4_t t03 = vadd_f16(vsub_f16(vmul_n_f16(d03, 4), vmul_n_f16(d23, 5)), d43); + float16x4_t t04 = vadd_f16(vsub_f16(vmul_n_f16(d04, 4), vmul_n_f16(d24, 5)), d44); + float16x4_t t05 = vadd_f16(vsub_f16(vmul_n_f16(d05, 4), vmul_n_f16(d25, 5)), d45); + + float16x4_t t10 = vadd_f16(vadd_f16(d30, d40), vmul_n_f16(vadd_f16(d10, d20), -4)); + float16x4_t t11 = vadd_f16(vadd_f16(d31, d41), vmul_n_f16(vadd_f16(d11, d21), -4)); + float16x4_t t12 = vadd_f16(vadd_f16(d32, d42), vmul_n_f16(vadd_f16(d12, d22), -4)); + float16x4_t t13 = vadd_f16(vadd_f16(d33, d43), vmul_n_f16(vadd_f16(d13, d23), -4)); + float16x4_t t14 = vadd_f16(vadd_f16(d34, d44), vmul_n_f16(vadd_f16(d14, d24), -4)); + float16x4_t t15 = vadd_f16(vadd_f16(d35, d45), vmul_n_f16(vadd_f16(d15, d25), -4)); + + float16x4_t t20 = vadd_f16(vsub_f16(d40, d30), vmul_n_f16(vsub_f16(d10, d20), 4)); + float16x4_t t21 = vadd_f16(vsub_f16(d41, d31), vmul_n_f16(vsub_f16(d11, d21), 4)); + float16x4_t t22 = vadd_f16(vsub_f16(d42, d32), vmul_n_f16(vsub_f16(d12, d22), 4)); + float16x4_t t23 = vadd_f16(vsub_f16(d43, d33), vmul_n_f16(vsub_f16(d13, d23), 4)); + float16x4_t t24 = vadd_f16(vsub_f16(d44, d34), vmul_n_f16(vsub_f16(d14, d24), 4)); + float16x4_t t25 = vadd_f16(vsub_f16(d45, d35), vmul_n_f16(vsub_f16(d15, d25), 4)); + + float16x4_t t30 
= vadd_f16(vsub_f16(d40, d20), vmul_n_f16(vsub_f16(d30, d10), 2)); + float16x4_t t31 = vadd_f16(vsub_f16(d41, d21), vmul_n_f16(vsub_f16(d31, d11), 2)); + float16x4_t t32 = vadd_f16(vsub_f16(d42, d22), vmul_n_f16(vsub_f16(d32, d12), 2)); + float16x4_t t33 = vadd_f16(vsub_f16(d43, d23), vmul_n_f16(vsub_f16(d33, d13), 2)); + float16x4_t t34 = vadd_f16(vsub_f16(d44, d24), vmul_n_f16(vsub_f16(d34, d14), 2)); + float16x4_t t35 = vadd_f16(vsub_f16(d45, d25), vmul_n_f16(vsub_f16(d35, d15), 2)); + + float16x4_t t40 = vadd_f16(vsub_f16(d40, d20), vmul_n_f16(vsub_f16(d10, d30), 2)); + float16x4_t t41 = vadd_f16(vsub_f16(d41, d21), vmul_n_f16(vsub_f16(d11, d31), 2)); + float16x4_t t42 = vadd_f16(vsub_f16(d42, d22), vmul_n_f16(vsub_f16(d12, d32), 2)); + float16x4_t t43 = vadd_f16(vsub_f16(d43, d23), vmul_n_f16(vsub_f16(d13, d33), 2)); + float16x4_t t44 = vadd_f16(vsub_f16(d44, d24), vmul_n_f16(vsub_f16(d14, d34), 2)); + float16x4_t t45 = vadd_f16(vsub_f16(d45, d25), vmul_n_f16(vsub_f16(d15, d35), 2)); + + float16x4_t t50 = vadd_f16(vsub_f16(vmul_n_f16(d10, 4), vmul_n_f16(d30, 5)), d50); + float16x4_t t51 = vadd_f16(vsub_f16(vmul_n_f16(d11, 4), vmul_n_f16(d31, 5)), d51); + float16x4_t t52 = vadd_f16(vsub_f16(vmul_n_f16(d12, 4), vmul_n_f16(d32, 5)), d52); + float16x4_t t53 = vadd_f16(vsub_f16(vmul_n_f16(d13, 4), vmul_n_f16(d33, 5)), d53); + float16x4_t t54 = vadd_f16(vsub_f16(vmul_n_f16(d14, 4), vmul_n_f16(d34, 5)), d54); + float16x4_t t55 = vadd_f16(vsub_f16(vmul_n_f16(d15, 4), vmul_n_f16(d35, 5)), d55); + + float16x4_t m00 = vadd_f16(vsub_f16(vmul_n_f16(t00, 4), vmul_n_f16(t02, 5)), t04); + float16x4_t m01 = vadd_f16(vadd_f16(t03, t04), vmul_n_f16(vadd_f16(t01, t02), -4)); + float16x4_t m02 = vadd_f16(vsub_f16(t04, t03), vmul_n_f16(vsub_f16(t01, t02), 4)); + float16x4_t m03 = vadd_f16(vsub_f16(t04, t02), vmul_n_f16(vsub_f16(t03, t01), 2)); + float16x4_t m04 = vadd_f16(vsub_f16(t04, t02), vmul_n_f16(vsub_f16(t01, t03), 2)); + float16x4_t m05 = vadd_f16(vsub_f16(vmul_n_f16(t01, 4), vmul_n_f16(t03, 5)), t05); + + float16x4_t m10 = vadd_f16(vsub_f16(vmul_n_f16(t10, 4), vmul_n_f16(t12, 5)), t14); + float16x4_t m11 = vadd_f16(vadd_f16(t13, t14), vmul_n_f16(vadd_f16(t11, t12), -4)); + float16x4_t m12 = vadd_f16(vsub_f16(t14, t13), vmul_n_f16(vsub_f16(t11, t12), 4)); + float16x4_t m13 = vadd_f16(vsub_f16(t14, t12), vmul_n_f16(vsub_f16(t13, t11), 2)); + float16x4_t m14 = vadd_f16(vsub_f16(t14, t12), vmul_n_f16(vsub_f16(t11, t13), 2)); + float16x4_t m15 = vadd_f16(vsub_f16(vmul_n_f16(t11, 4), vmul_n_f16(t13, 5)), t15); + + float16x4_t m20 = vadd_f16(vsub_f16(vmul_n_f16(t20, 4), vmul_n_f16(t22, 5)), t24); + float16x4_t m21 = vadd_f16(vadd_f16(t23, t24), vmul_n_f16(vadd_f16(t21, t22), -4)); + float16x4_t m22 = vadd_f16(vsub_f16(t24, t23), vmul_n_f16(vsub_f16(t21, t22), 4)); + float16x4_t m23 = vadd_f16(vsub_f16(t24, t22), vmul_n_f16(vsub_f16(t23, t21), 2)); + float16x4_t m24 = vadd_f16(vsub_f16(t24, t22), vmul_n_f16(vsub_f16(t21, t23), 2)); + float16x4_t m25 = vadd_f16(vsub_f16(vmul_n_f16(t21, 4), vmul_n_f16(t23, 5)), t25); + + float16x4_t m30 = vadd_f16(vsub_f16(vmul_n_f16(t30, 4), vmul_n_f16(t32, 5)), t34); + float16x4_t m31 = vadd_f16(vadd_f16(t33, t34), vmul_n_f16(vadd_f16(t31, t32), -4)); + float16x4_t m32 = vadd_f16(vsub_f16(t34, t33), vmul_n_f16(vsub_f16(t31, t32), 4)); + float16x4_t m33 = vadd_f16(vsub_f16(t34, t32), vmul_n_f16(vsub_f16(t33, t31), 2)); + float16x4_t m34 = vadd_f16(vsub_f16(t34, t32), vmul_n_f16(vsub_f16(t31, t33), 2)); + float16x4_t m35 = vadd_f16(vsub_f16(vmul_n_f16(t31, 4), 
vmul_n_f16(t33, 5)), t35); + + float16x4_t m40 = vadd_f16(vsub_f16(vmul_n_f16(t40, 4), vmul_n_f16(t42, 5)), t44); + float16x4_t m41 = vadd_f16(vadd_f16(t43, t44), vmul_n_f16(vadd_f16(t41, t42), -4)); + float16x4_t m42 = vadd_f16(vsub_f16(t44, t43), vmul_n_f16(vsub_f16(t41, t42), 4)); + float16x4_t m43 = vadd_f16(vsub_f16(t44, t42), vmul_n_f16(vsub_f16(t43, t41), 2)); + float16x4_t m44 = vadd_f16(vsub_f16(t44, t42), vmul_n_f16(vsub_f16(t41, t43), 2)); + float16x4_t m45 = vadd_f16(vsub_f16(vmul_n_f16(t41, 4), vmul_n_f16(t43, 5)), t45); + + float16x4_t m50 = vadd_f16(vsub_f16(vmul_n_f16(t50, 4), vmul_n_f16(t52, 5)), t54); + float16x4_t m51 = vadd_f16(vadd_f16(t53, t54), vmul_n_f16(vadd_f16(t51, t52), -4)); + float16x4_t m52 = vadd_f16(vsub_f16(t54, t53), vmul_n_f16(vsub_f16(t51, t52), 4)); + float16x4_t m53 = vadd_f16(vsub_f16(t54, t52), vmul_n_f16(vsub_f16(t53, t51), 2)); + float16x4_t m54 = vadd_f16(vsub_f16(t54, t52), vmul_n_f16(vsub_f16(t51, t53), 2)); + float16x4_t m55 = vadd_f16(vsub_f16(vmul_n_f16(t51, 4), vmul_n_f16(t53, 5)), t55); + + vst1_f16(trans_input_data, m00); + vst1_f16(trans_input_data + step, m01); + vst1_f16(trans_input_data + 2 * step, m02); + vst1_f16(trans_input_data + 3 * step, m03); + vst1_f16(trans_input_data + 4 * step, m04); + vst1_f16(trans_input_data + 5 * step, m05); + + vst1_f16(trans_input_data + 6 * step, m10); + vst1_f16(trans_input_data + 7 * step, m11); + vst1_f16(trans_input_data + 8 * step, m12); + vst1_f16(trans_input_data + 9 * step, m13); + vst1_f16(trans_input_data + 10 * step, m14); + vst1_f16(trans_input_data + 11 * step, m15); + + vst1_f16(trans_input_data + 12 * step, m20); + vst1_f16(trans_input_data + 13 * step, m21); + vst1_f16(trans_input_data + 14 * step, m22); + vst1_f16(trans_input_data + 15 * step, m23); + vst1_f16(trans_input_data + 16 * step, m24); + vst1_f16(trans_input_data + 17 * step, m25); + + vst1_f16(trans_input_data + 18 * step, m30); + vst1_f16(trans_input_data + 19 * step, m31); + vst1_f16(trans_input_data + 20 * step, m32); + vst1_f16(trans_input_data + 21 * step, m33); + vst1_f16(trans_input_data + 22 * step, m34); + vst1_f16(trans_input_data + 23 * step, m35); + + vst1_f16(trans_input_data + 24 * step, m40); + vst1_f16(trans_input_data + 25 * step, m41); + vst1_f16(trans_input_data + 26 * step, m42); + vst1_f16(trans_input_data + 27 * step, m43); + vst1_f16(trans_input_data + 28 * step, m44); + vst1_f16(trans_input_data + 29 * step, m45); + + vst1_f16(trans_input_data + 30 * step, m50); + vst1_f16(trans_input_data + 31 * step, m51); + vst1_f16(trans_input_data + 32 * step, m52); + vst1_f16(trans_input_data + 33 * step, m53); + vst1_f16(trans_input_data + 34 * step, m54); + vst1_f16(trans_input_data + 35 * step, m55); +} + +void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, + int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) { + // input data format : nhwc + int output_unit = 4; + int input_channel = conv_param->input_channel_; + int input_width = conv_param->input_w_; + int input_height = conv_param->input_h_; + int pad_w = conv_param->pad_w_; + int pad_h = conv_param->pad_h_; + int ic4 = UP_DIV(input_channel, C4NUM); + + for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { + int x_id = start_index + cal_id; + int origin_x = (x_id % out_w_block) * output_unit - pad_w; + int origin_y = (x_id / out_w_block) * output_unit - pad_h; + int real_x_start = origin_x > 0 ? 0 : -origin_x; + int real_x_end = (origin_x + 6) < input_width ? 
6 : (input_width - origin_x); + int real_y_start = origin_y > 0 ? 0 : -origin_y; + int real_y_end = (origin_y + 6) < input_height ? 6 : (input_height - origin_y); + + int src_plane_offset = input_channel * (origin_y * input_width + origin_x); + int dst_plane_offset = cal_id * C4NUM; + for (int ic = 0; ic < ic4; ic++) { + // clear tmp buffer + memset(tmp_data, 0, 6 * 6 * C4NUM * sizeof(float16_t)); + + // get real input block with padding + int src_ic4_offset = src_plane_offset + ic * C4NUM; + for (int interval = real_y_start; interval < real_y_end; interval++) { + int src_y_offset = src_ic4_offset + interval * input_width * input_channel + real_x_start * input_channel; + int dst_y_offset = interval * 6 * C4NUM + real_x_start * C4NUM; + for (int j = 0; j < (real_x_end - real_x_start); j++) { + int src_x_offset = src_y_offset + j * input_channel; + int dst_x_offset = dst_y_offset + j * C4NUM; + float16_t *src_addr = (float16_t *)(input_data) + src_x_offset; + float16_t *dst_addr = tmp_data + dst_x_offset; + dst_addr[0] = src_addr[0]; + dst_addr[1] = src_addr[1]; + dst_addr[2] = src_addr[2]; + dst_addr[3] = src_addr[3]; + } + } + + // todo + // input transform + int dst_ic4_offset = dst_plane_offset + ic * 16 * C4NUM; + size_t dst_step = ic4 * C4NUM * 16; + float16_t *trans_input_ptr = trans_input + dst_ic4_offset; + Conv3x3Fp16InputUnit(tmp_data, trans_input_ptr, dst_step); + } + } +} + +void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC4, int output_channel, + int kernel_plane) { + int dst_step = iC4 * C4NUM * 8; + for (int o = 0; o < output_channel; o++) { + int oc8_block_num = o / C8NUM; + int oc8_block_rem = o % C8NUM; + int src_oc_offset = o * iC4 * C4NUM * kernel_plane; + int dst_oc_offset = oc8_block_num * C8NUM * iC4 * C4NUM * 36 + oc8_block_rem; + for (int i = 0; i < iC4; i++) { + const float16_t *src_ic4_ptr = weight_data + src_oc_offset + i * kernel_plane * C4NUM; + float16_t *dst_ic4_ptr = trans_weight + dst_oc_offset + i * 8 * C4NUM; + float16x4_t g00 = vld1_f16(src_ic4_ptr); + float16x4_t g01 = vld1_f16(src_ic4_ptr + 4); + float16x4_t g02 = vld1_f16(src_ic4_ptr + 2 * 4); + float16x4_t g10 = vld1_f16(src_ic4_ptr + 3 * 4); + float16x4_t g11 = vld1_f16(src_ic4_ptr + 4 * 4); + float16x4_t g12 = vld1_f16(src_ic4_ptr + 5 * 4); + float16x4_t g20 = vld1_f16(src_ic4_ptr + 6 * 4); + float16x4_t g21 = vld1_f16(src_ic4_ptr + 7 * 4); + float16x4_t g22 = vld1_f16(src_ic4_ptr + 8 * 4); + + float16x4_t dst00 = vmul_n_f16(g00, 0.25); + float16x4_t dst01 = vmul_n_f16(g01, 0.25); + float16x4_t dst02 = vmul_n_f16(g02, 0.25); + + float16x4_t dst10 = vmul_n_f16(vadd_f16(g00, vadd_f16(g10, g20)), -0.1666666666667); + float16x4_t dst11 = vmul_n_f16(vadd_f16(g01, vadd_f16(g11, g21)), -0.1666666666667); + float16x4_t dst12 = vmul_n_f16(vadd_f16(g02, vadd_f16(g12, g22)), -0.1666666666667); + + float16x4_t dst20 = vmul_n_f16(vsub_f16(vadd_f16(g00, g20), g10), -0.1666666666667); + float16x4_t dst21 = vmul_n_f16(vsub_f16(vadd_f16(g01, g21), g11), -0.1666666666667); + float16x4_t dst22 = vmul_n_f16(vsub_f16(vadd_f16(g02, g22), g12), -0.1666666666667); + + float16x4_t dst30 = vadd_f16(vmul_n_f16(g10, 0.08333333333333), + vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667))); + float16x4_t dst31 = vadd_f16(vmul_n_f16(g11, 0.08333333333333), + vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667))); + float16x4_t dst32 = vadd_f16(vmul_n_f16(g12, 0.08333333333333), + vadd_f16(vmul_n_f16(g02, 0.04166666666667), 
vmul_n_f16(g22, 0.1666666666667))); + + float16x4_t dst40 = vsub_f16(vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667)), + vmul_n_f16(g10, 0.08333333333333)); + float16x4_t dst41 = vsub_f16(vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667)), + vmul_n_f16(g11, 0.08333333333333)); + float16x4_t dst42 = vsub_f16(vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667)), + vmul_n_f16(g12, 0.08333333333333)); + + float16x4_t dst50 = g20; + float16x4_t dst51 = g21; + float16x4_t dst52 = g22; + + float16x4_t m00 = vmul_n_f16(dst00, 0.25); + float16x4_t m01 = vmul_n_f16(vadd_f16(dst00, vadd_f16(dst01, dst02)), -0.1666666666667); + float16x4_t m02 = vmul_n_f16(vsub_f16(vadd_f16(dst00, dst02), dst01), -0.1666666666667); + float16x4_t m03 = vadd_f16(vmul_n_f16(dst01, 0.08333333333333), + vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667))); + float16x4_t m04 = vsub_f16(vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667)), + vmul_n_f16(dst01, 0.08333333333333)); + float16x4_t m05 = dst02; + + float16x4_t m10 = vmul_n_f16(dst10, 0.25); + float16x4_t m11 = vmul_n_f16(vadd_f16(dst10, vadd_f16(dst11, dst12)), -0.1666666666667); + float16x4_t m12 = vmul_n_f16(vsub_f16(vadd_f16(dst10, dst12), dst11), -0.1666666666667); + float16x4_t m13 = vadd_f16(vmul_n_f16(dst11, 0.08333333333333), + vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667))); + float16x4_t m14 = vsub_f16(vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667)), + vmul_n_f16(dst11, 0.08333333333333)); + float16x4_t m15 = dst12; + + float16x4_t m20 = vmul_n_f16(dst20, 0.25); + float16x4_t m21 = vmul_n_f16(vadd_f16(dst20, vadd_f16(dst21, dst22)), -0.1666666666667); + float16x4_t m22 = vmul_n_f16(vsub_f16(vadd_f16(dst20, dst22), dst21), -0.1666666666667); + float16x4_t m23 = vadd_f16(vmul_n_f16(dst21, 0.08333333333333), + vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667))); + float16x4_t m24 = vsub_f16(vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667)), + vmul_n_f16(dst21, 0.08333333333333)); + float16x4_t m25 = dst22; + + float16x4_t m30 = vmul_n_f16(dst30, 0.25); + float16x4_t m31 = vmul_n_f16(vadd_f16(dst30, vadd_f16(dst31, dst32)), -0.1666666666667); + float16x4_t m32 = vmul_n_f16(vsub_f16(vadd_f16(dst30, dst32), dst31), -0.1666666666667); + float16x4_t m33 = vadd_f16(vmul_n_f16(dst31, 0.08333333333333), + vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667))); + float16x4_t m34 = vsub_f16(vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667)), + vmul_n_f16(dst31, 0.08333333333333)); + float16x4_t m35 = dst32; + + float16x4_t m40 = vmul_n_f16(dst40, 0.25); + float16x4_t m41 = vmul_n_f16(vadd_f16(dst40, vadd_f16(dst41, dst42)), -0.1666666666667); + float16x4_t m42 = vmul_n_f16(vsub_f16(vadd_f16(dst40, dst42), dst41), -0.1666666666667); + float16x4_t m43 = vadd_f16(vmul_n_f16(dst41, 0.08333333333333), + vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667))); + float16x4_t m44 = vsub_f16(vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667)), + vmul_n_f16(dst41, 0.08333333333333)); + float16x4_t m45 = dst42; + + float16x4_t m50 = vmul_n_f16(dst50, 0.25); + float16x4_t m51 = vmul_n_f16(vadd_f16(dst50, vadd_f16(dst51, dst52)), -0.1666666666667); + float16x4_t m52 = 
vmul_n_f16(vsub_f16(vadd_f16(dst50, dst52), dst51), -0.1666666666667); + float16x4_t m53 = vadd_f16(vmul_n_f16(dst51, 0.08333333333333), + vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667))); + float16x4_t m54 = vsub_f16(vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667)), + vmul_n_f16(dst51, 0.08333333333333)); + float16x4_t m55 = dst52; + + for (int j = 0; j < 4; j++) { + dst_ic4_ptr[j * 8] = m00[j]; + dst_ic4_ptr[j * 8 + dst_step] = m01[j]; + dst_ic4_ptr[j * 8 + 2 * dst_step] = m02[j]; + dst_ic4_ptr[j * 8 + 3 * dst_step] = m03[j]; + dst_ic4_ptr[j * 8 + 4 * dst_step] = m04[j]; + dst_ic4_ptr[j * 8 + 5 * dst_step] = m05[j]; + dst_ic4_ptr[j * 8 + 6 * dst_step] = m10[j]; + dst_ic4_ptr[j * 8 + 7 * dst_step] = m11[j]; + dst_ic4_ptr[j * 8 + 8 * dst_step] = m12[j]; + dst_ic4_ptr[j * 8 + 9 * dst_step] = m13[j]; + dst_ic4_ptr[j * 8 + 10 * dst_step] = m14[j]; + dst_ic4_ptr[j * 8 + 11 * dst_step] = m15[j]; + dst_ic4_ptr[j * 8 + 12 * dst_step] = m20[j]; + dst_ic4_ptr[j * 8 + 13 * dst_step] = m21[j]; + dst_ic4_ptr[j * 8 + 14 * dst_step] = m22[j]; + dst_ic4_ptr[j * 8 + 15 * dst_step] = m23[j]; + dst_ic4_ptr[j * 8 + 16 * dst_step] = m24[j]; + dst_ic4_ptr[j * 8 + 17 * dst_step] = m25[j]; + dst_ic4_ptr[j * 8 + 18 * dst_step] = m30[j]; + dst_ic4_ptr[j * 8 + 19 * dst_step] = m31[j]; + dst_ic4_ptr[j * 8 + 20 * dst_step] = m32[j]; + dst_ic4_ptr[j * 8 + 21 * dst_step] = m33[j]; + dst_ic4_ptr[j * 8 + 22 * dst_step] = m34[j]; + dst_ic4_ptr[j * 8 + 23 * dst_step] = m35[j]; + dst_ic4_ptr[j * 8 + 24 * dst_step] = m40[j]; + dst_ic4_ptr[j * 8 + 25 * dst_step] = m41[j]; + dst_ic4_ptr[j * 8 + 26 * dst_step] = m42[j]; + dst_ic4_ptr[j * 8 + 27 * dst_step] = m43[j]; + dst_ic4_ptr[j * 8 + 28 * dst_step] = m44[j]; + dst_ic4_ptr[j * 8 + 29 * dst_step] = m45[j]; + dst_ic4_ptr[j * 8 + 30 * dst_step] = m50[j]; + dst_ic4_ptr[j * 8 + 31 * dst_step] = m51[j]; + dst_ic4_ptr[j * 8 + 32 * dst_step] = m52[j]; + dst_ic4_ptr[j * 8 + 33 * dst_step] = m53[j]; + dst_ic4_ptr[j * 8 + 34 * dst_step] = m54[j]; + dst_ic4_ptr[j * 8 + 35 * dst_step] = m55[j]; + } + } + } +} + +void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, + int output_w) { + float16x8_t s00 = vld1q_f16(gemm_out); + float16x8_t s01 = vld1q_f16(gemm_out + 8); + float16x8_t s02 = vld1q_f16(gemm_out + 16); + float16x8_t s03 = vld1q_f16(gemm_out + 24); + float16x8_t s04 = vld1q_f16(gemm_out + 32); + float16x8_t s05 = vld1q_f16(gemm_out + 40); + + float16x8_t s10 = vld1q_f16(gemm_out + 48); + float16x8_t s11 = vld1q_f16(gemm_out + 56); + float16x8_t s12 = vld1q_f16(gemm_out + 64); + float16x8_t s13 = vld1q_f16(gemm_out + 72); + float16x8_t s14 = vld1q_f16(gemm_out + 80); + float16x8_t s15 = vld1q_f16(gemm_out + 88); + + float16x8_t s20 = vld1q_f16(gemm_out + 96); + float16x8_t s21 = vld1q_f16(gemm_out + 104); + float16x8_t s22 = vld1q_f16(gemm_out + 112); + float16x8_t s23 = vld1q_f16(gemm_out + 120); + float16x8_t s24 = vld1q_f16(gemm_out + 128); + float16x8_t s25 = vld1q_f16(gemm_out + 136); + + float16x8_t s30 = vld1q_f16(gemm_out + 144); + float16x8_t s31 = vld1q_f16(gemm_out + 152); + float16x8_t s32 = vld1q_f16(gemm_out + 160); + float16x8_t s33 = vld1q_f16(gemm_out + 168); + float16x8_t s34 = vld1q_f16(gemm_out + 176); + float16x8_t s35 = vld1q_f16(gemm_out + 184); + + float16x8_t s40 = vld1q_f16(gemm_out + 192); + float16x8_t s41 = vld1q_f16(gemm_out + 200); + float16x8_t s42 = vld1q_f16(gemm_out + 208); + float16x8_t s43 = vld1q_f16(gemm_out + 216); + 
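// Annotation, not in the original patch: s00..s55 hold one 6x6 Winograd-domain
// GEMM output tile, one float16x8 vector (C8NUM output channels) per tile point,
// loaded at a stride of 8 fp16 lanes from gemm_out. +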
float16x8_t s44 = vld1q_f16(gemm_out + 224); + float16x8_t s45 = vld1q_f16(gemm_out + 232); + + float16x8_t s50 = vld1q_f16(gemm_out + 240); + float16x8_t s51 = vld1q_f16(gemm_out + 248); + float16x8_t s52 = vld1q_f16(gemm_out + 256); + float16x8_t s53 = vld1q_f16(gemm_out + 264); + float16x8_t s54 = vld1q_f16(gemm_out + 272); + float16x8_t s55 = vld1q_f16(gemm_out + 280); + + float16x8_t t00 = vaddq_f16(vaddq_f16(vaddq_f16(s00, s10), vaddq_f16(s20, s30)), s40); + float16x8_t t01 = vaddq_f16(vaddq_f16(vaddq_f16(s01, s11), vaddq_f16(s21, s31)), s41); + float16x8_t t02 = vaddq_f16(vaddq_f16(vaddq_f16(s02, s12), vaddq_f16(s22, s32)), s42); + float16x8_t t03 = vaddq_f16(vaddq_f16(vaddq_f16(s03, s13), vaddq_f16(s23, s33)), s43); + float16x8_t t04 = vaddq_f16(vaddq_f16(vaddq_f16(s04, s14), vaddq_f16(s24, s34)), s44); + float16x8_t t05 = vaddq_f16(vaddq_f16(vaddq_f16(s05, s15), vaddq_f16(s25, s35)), s45); + + float16x8_t t10 = vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 2)); + float16x8_t t11 = vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 2)); + float16x8_t t12 = vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 2)); + float16x8_t t13 = vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 2)); + float16x8_t t14 = vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 2)); + float16x8_t t15 = vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 2)); + + float16x8_t t20 = vaddq_f16(vaddq_f16(s10, s20), vmulq_n_f16(vaddq_f16(s30, s40), 4)); + float16x8_t t21 = vaddq_f16(vaddq_f16(s11, s21), vmulq_n_f16(vaddq_f16(s31, s41), 4)); + float16x8_t t22 = vaddq_f16(vaddq_f16(s12, s22), vmulq_n_f16(vaddq_f16(s32, s42), 4)); + float16x8_t t23 = vaddq_f16(vaddq_f16(s13, s23), vmulq_n_f16(vaddq_f16(s33, s43), 4)); + float16x8_t t24 = vaddq_f16(vaddq_f16(s14, s24), vmulq_n_f16(vaddq_f16(s34, s44), 4)); + float16x8_t t25 = vaddq_f16(vaddq_f16(s15, s25), vmulq_n_f16(vaddq_f16(s35, s45), 4)); + + float16x8_t t30 = vaddq_f16(vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 8)), s50); + float16x8_t t31 = vaddq_f16(vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 8)), s51); + float16x8_t t32 = vaddq_f16(vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 8)), s52); + float16x8_t t33 = vaddq_f16(vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 8)), s53); + float16x8_t t34 = vaddq_f16(vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 8)), s54); + float16x8_t t35 = vaddq_f16(vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 8)), s55); + + float16x8_t d00 = vaddq_f16(vaddq_f16(vaddq_f16(t00, t01), vaddq_f16(t02, t03)), t04); + float16x8_t d01 = vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 2)); + float16x8_t d02 = vaddq_f16(vaddq_f16(t01, t02), vmulq_n_f16(vaddq_f16(t03, t04), 4)); + float16x8_t d03 = vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 8)), t05); + + float16x8_t d10 = vaddq_f16(vaddq_f16(vaddq_f16(t10, t11), vaddq_f16(t12, t13)), t14); + float16x8_t d11 = vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 2)); + float16x8_t d12 = vaddq_f16(vaddq_f16(t11, t12), vmulq_n_f16(vaddq_f16(t13, t14), 4)); + float16x8_t d13 = vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 8)), t15); + + float16x8_t d20 = vaddq_f16(vaddq_f16(vaddq_f16(t20, t21), vaddq_f16(t22, t23)), t24); + float16x8_t d21 = vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 2)); + 
float16x8_t d22 = vaddq_f16(vaddq_f16(t21, t22), vmulq_n_f16(vaddq_f16(t23, t24), 4)); + float16x8_t d23 = vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 8)), t25); + + float16x8_t d30 = vaddq_f16(vaddq_f16(vaddq_f16(t30, t31), vaddq_f16(t32, t33)), t34); + float16x8_t d31 = vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 2)); + float16x8_t d32 = vaddq_f16(vaddq_f16(t31, t32), vmulq_n_f16(vaddq_f16(t33, t34), 4)); + float16x8_t d33 = vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 8)), t35); + + vst1q_f16(output_data, d00); + vst1q_f16(output_data + 8, d01); + vst1q_f16(output_data + 16, d02); + vst1q_f16(output_data + 24, d03); + + vst1q_f16(output_data + output_w * 8, d10); + vst1q_f16(output_data + output_w * 8 + 8, d11); + vst1q_f16(output_data + output_w * 8 + 16, d12); + vst1q_f16(output_data + output_w * 8 + 24, d13); + + vst1q_f16(output_data + 2 * output_w * 8, d20); + vst1q_f16(output_data + 2 * output_w * 8 + 8, d21); + vst1q_f16(output_data + 2 * output_w * 8 + 16, d22); + vst1q_f16(output_data + 2 * output_w * 8 + 24, d23); + + vst1q_f16(output_data + 3 * output_w * 8, d30); + vst1q_f16(output_data + 3 * output_w * 8 + 8, d31); + vst1q_f16(output_data + 3 * output_w * 8 + 16, d32); + vst1q_f16(output_data + 3 * output_w * 8 + 24, d33); +} + +void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data, + int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) { + int output_channel = conv_param->output_channel_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int oc8 = UP_DIV(output_channel, C8NUM); + + for (int i = 0; i < real_cal_num; i++) { + int out_w_index = (start_index + i) % out_w_block; + int out_h_index = (start_index + i) / out_w_block; + int src_tile_offset = i * oc8 * C8NUM * 36; + int dst_tile_offset = 8 * (out_w_index * 4 + out_h_index * 4 * output_w); + + for (int j = 0; j < oc8; j++) { + int src_oc8_offset = src_tile_offset + j * 36 * C8NUM; + int dst_oc8_offset = dst_tile_offset + j * C8NUM * output_h * output_w; + const float16_t *src_ptr = gemm_out + src_oc8_offset; + const float16_t *bias_ptr = bias_data + j * C8NUM; + float16_t *dst_ptr = out_data + dst_oc8_offset; + + // output transform + Conv3x3Fp16OutputUnit(src_ptr, bias_ptr, dst_ptr, output_w); + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/winograd_transform_fp16.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/winograd_transform_fp16.h new file mode 100644 index 00000000000..3e94d46e3fa --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/winograd_transform_fp16.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_WINOGRAD_TRANSFORM_FP16_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_WINOGRAD_TRANSFORM_FP16_H_
+
+#include <arm_neon.h>
+#include <string.h>
+#include "src/runtime/kernel/arm/opclib/fp16/pack_fp16.h"
+#include "src/runtime/kernel/arm/opclib/fp16/conv_fp16.h"
+
+// for fp16 convolution 3x3 filter/input/output transform
+void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step);
+
+void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data,
+                               int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param);
+
+void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC8, int output_channel,
+                                int kernel_plane);
+
+void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, int output_w);
+
+void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data,
+                                int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param);
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_WINOGRAD_TRANSFORM_FP16_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/optimized_kernel.h b/mindspore/lite/src/runtime/kernel/arm/opclib/optimized_kernel.h
index 756056d6a60..150369b110d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/opclib/optimized_kernel.h
+++ b/mindspore/lite/src/runtime/kernel/arm/opclib/optimized_kernel.h
@@ -29,11 +29,24 @@ class OptimizeModule {
  public:
   OptimizeModule() {
     bool support_optimize_ops = false;
-
+    bool support_fp16 = false;
 #ifdef __ANDROID__
     int hwcap_type = 16;
     uint32_t hwcap = getHwCap(hwcap_type);
-#if defined(__aarch64__)
+#ifdef ENABLE_ARM64
+    if (hwcap & HWCAP_FPHP) {
+#elif defined(__arm__)
+    if (hwcap & HWCAP_HALF) {
+#endif
+      MS_LOG(INFO) << "Hw cap support FP16, hwcap: 0x" << hwcap;
+      support_fp16 = true;
+#ifdef ENABLE_ARM64
+    }
+#elif defined(__arm__)
+    }
+#endif
+
+#ifdef ENABLE_ARM64
     if (hwcap & HWCAP_ASIMDDP) {
       printf("Hw cap support SMID Dot Product, hwcap: 0x%x \n", hwcap);
       support_optimize_ops = true;
@@ -42,7 +55,7 @@ class OptimizeModule {
     }
 #endif
 #endif
-    if (!support_optimize_ops) {
+    if ((!support_optimize_ops) && (!support_fp16)) {
       return;
     }
     optimized_op_handler_ = dlopen(OPTIMIZE_SHARED_LIBRARY_PATH, RTLD_LAZY);
@@ -61,4 +74,3 @@ class OptimizeModule {
 };
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_OPTIMIZED_KERNEL_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/pack.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/pack.cc
index 7c89e05c198..bffcb84616f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/opclib/pack.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/opclib/pack.cc
@@ -18,331 +18,6 @@
 #include
 #include
-#ifdef ENABLE_FP16
-void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num,
-                        int block_index) {
-  // input format : nhwc
-  int kernel_h = conv_param->kernel_h_;
-  int kernel_w = conv_param->kernel_w_;
-  int stride_h = conv_param->stride_h_;
-  int stride_w = conv_param->stride_w_;
-  int pad_h = conv_param->pad_h_;
-  int pad_w = conv_param->pad_w_;
-  int dilation_h = conv_param->dilation_h_;
-  int dilation_w = conv_param->dilation_w_;
-  int in_channel = conv_param->input_channel_;
-  int in_h = conv_param->input_h_;
-  int in_w = conv_param->input_w_;
-  int out_w = conv_param->output_w_;
-  int channel_block = UP_DIV(in_channel, 4);
-  int kernel_plane =
kernel_h * kernel_w; - - for (int i = 0; i < real_cal_num; i++) { - int block_start = block_index + i; - int input_h = block_start / out_w * stride_h - pad_h; - int input_w = block_start % out_w * stride_w - pad_w; - for (int j = 0; j < kernel_h; j++) { - int input_y = input_h + j * dilation_h; - if (input_y < 0 || input_y >= in_h) { - continue; - } - int input_y_stride = input_y * in_w * channel_block * C4NUM; - for (int n = 0; n < kernel_w; n++) { - int input_x = input_w + n * dilation_w; - if (input_x < 0 || input_x >= in_w) { - continue; - } - int input_x_stride = input_y_stride + input_x * channel_block * C4NUM; - int input_plane_offset = (j * kernel_w + n) * 16 * C4NUM * channel_block + i * C4NUM; - for (int m = 0; m < channel_block; m++) { - int channel_block_stride = input_x_stride + m * C4NUM; - int channel_block_offset = input_plane_offset + m * 16 * C4NUM; - (packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0]; - (packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1]; - (packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2]; - (packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3]; - } // channel_block loop - } // kernel_w loop - } // kernel_h loop - } // tile num loop -} - -void PackWeightFp16(float16_t *weight_data, ConvParameter *conv_param, float16_t *packed_weight) { - // original weight format : ohwi - int tile_num = 8; - int inchannel_block = 4; - int kernel_h = conv_param->kernel_h_; - int kernel_w = conv_param->kernel_w_; - int in_channel = conv_param->input_channel_; - int out_channel = conv_param->output_channel_; - int kernel_block = UP_DIV(out_channel, tile_num); - int channel_block = UP_DIV(in_channel, inchannel_block); - int kernel_plane = kernel_h * kernel_w; - int pack_weight_size = kernel_block * channel_block * tile_num * inchannel_block * kernel_plane; - - int unit_size = tile_num * inchannel_block; - int block_size = pack_weight_size / kernel_block; - - for (int m = 0; m < kernel_plane; m++) { - int kernel_plane_stride = m * in_channel; - int packed_kernel_plane_stride = m * unit_size * channel_block; - for (int i = 0; i < channel_block; i++) { - int channel_block_stride = kernel_plane_stride + i * inchannel_block; - int packed_channel_block_size = packed_kernel_plane_stride + i * unit_size; - int ic_remainder = in_channel - i * inchannel_block; - int real_ic_num = ic_remainder < inchannel_block ? ic_remainder : inchannel_block; - for (int h = 0; h < real_ic_num; h++) { - int block_stride = channel_block_stride + h; - int packed_block_stride = packed_channel_block_size + h * tile_num; - for (int j = 0; j < kernel_block; j++) { - int kernel_block_stride = block_stride + j * tile_num * kernel_plane * in_channel; - int packed_kernel_block_size = packed_block_stride + j * block_size; - int oc_remainder = out_channel - j * tile_num; - int real_oc_num = oc_remainder < tile_num ? 
oc_remainder : tile_num; - for (int k = 0; k < real_oc_num; k++) { - float16_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel; - float16_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k; - *packed_data_ptr = *origin_data_ptr; - } - } // kernel block loop - } // inchannel block loop - } // channel block loop - } // kernel plane loop -} - -void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) { - // origin weight format : ohwi - int input_channel = conv_param->input_channel_; - int ic8 = UP_DIV(input_channel, C8NUM); - int output_channel = conv_param->output_channel_; - int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; - - for (int k = 0; k < kernel_plane; k++) { - int src_kernel_offset = k * input_channel; - int dst_kernel_offset = k * C8NUM; - for (int o = 0; o < output_channel; o++) { - int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; - int dst_oc_offset = dst_kernel_offset + o * ic8 * kernel_plane * C8NUM; - for (int i = 0; i < input_channel; i++) { - int c8_block_num = i / C8NUM; - int c8_block_rem = i % C8NUM; - int src_ic_offset = src_oc_offset + i; - int dst_ic_offset = dst_oc_offset + c8_block_num * kernel_plane * C8NUM + c8_block_rem; - (packed_weight_data + dst_ic_offset)[0] = (origin_weight_data + src_ic_offset)[0]; - } - } - } -} - -void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) { - // origin weight format : ohwi - int input_channel = conv_param->input_channel_; - int ic4 = UP_DIV(input_channel, C4NUM); - int output_channel = conv_param->output_channel_; - int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; - - for (int k = 0; k < kernel_plane; k++) { - int src_kernel_offset = k * input_channel; - int dst_kernel_offset = k * C4NUM; - for (int o = 0; o < output_channel; o++) { - int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; - int dst_oc_offset = dst_kernel_offset + o * ic4 * kernel_plane * C4NUM; - for (int i = 0; i < input_channel; i++) { - int c4_block_num = i / C4NUM; - int c4_block_rem = i % C4NUM; - int src_ic_offset = src_oc_offset + i; - int dst_ic_offset = dst_oc_offset + c4_block_num * kernel_plane * C4NUM + c4_block_rem; - (packed_weight_data + dst_ic_offset)[0] = (origin_weight_data + src_ic_offset)[0]; - } - } - } -} - -void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - for (int b = 0; b < batch; b++) { - int src_oc_offset = b * plane * channel; - int dst_oc_offset = b * plane * c4 * C4NUM; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_oc_offset + k * channel; - int dst_kernel_offset = dst_oc_offset + k * C4NUM; - for (int i = 0; i < channel; i++) { - int c4_block_num = i / C4NUM; - int c4_block_rem = i % C4NUM; - int src_ic_offset = src_kernel_offset + i; - int dst_ic_offset = dst_kernel_offset + c4_block_num * plane * C4NUM + c4_block_rem; - ((float16_t *)dst + dst_ic_offset)[0] = ((float16_t *)src + src_ic_offset)[0]; - } - } - } -} - -void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - for (int b = 0; b < batch; b++) { - int src_offset = b * plane * channel; - int dst_offset = b * plane * c4 * C4NUM; - for (int c = 0; c < channel; c++) { - int c4_block_num = c / C4NUM; - int c4_block_rem = c % C4NUM; - int src_c_offset = 
src_offset + c * plane; - int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_c_offset + k; - int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem; - ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; - } - } - } -} - -void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { - int ic4 = UP_DIV(channel, C4NUM); - int nhwc4_batch_unit_offset = ic4 * C4NUM * plane; - int ic_remainder_ = channel % C4NUM; - if (ic_remainder_ != 0) { - int nhwc4_batch_offset = 0; - for (int b = 0; b < batch; b++) { - int batch_offset = b * channel * plane; - for (int i = 0; i < plane; i++) { - memcpy((float16_t *)dst + nhwc4_batch_offset + i * ic4 * C4NUM, (float16_t *)src + batch_offset + i * channel, - channel * sizeof(float16_t)); - } - nhwc4_batch_offset += nhwc4_batch_unit_offset; - } - } else { - size_t ori_input_size = batch * plane * channel * sizeof(float16_t); - memcpy(dst, src, ori_input_size); - } -} - -void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { - int nhwc4_batch_offset = 0; - int ic4 = UP_DIV(channel, C4NUM); - int nhwc4_batch_unit_offset = ic4 * C4NUM * plane; - - for (int b = 0; b < batch; b++) { - int batch_offset = b * channel * plane; - for (int c = 0; c < channel; c++) { - int src_c_offset = batch_offset + c * plane; - int dst_c_offset = nhwc4_batch_offset + c; - for (int i = 0; i < plane; i++) { - int src_plane_offset = src_c_offset + i; - int dst_plane_offset = dst_c_offset + i * ic4 * C4NUM; - ((float16_t *)dst)[dst_plane_offset] = ((float16_t *)src)[src_plane_offset]; - } - } - nhwc4_batch_offset += nhwc4_batch_unit_offset; - } -} - -void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - for (int b = 0; b < batch; b++) { - int src_offset = b * plane * c4 * C4NUM; - int dst_offset = b * plane * channel; - for (int c = 0; c < channel; c++) { - int c4_block_num = c / C4NUM; - int c4_block_res = c % C4NUM; - int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; - int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_c_offset + k * C4NUM; - int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM; - ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; - } - } - } -} - -void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - for (int b = 0; b < batch; b++) { - int src_offset = b * plane * c4 * C4NUM; - int dst_offset = b * plane * channel; - for (int c = 0; c < channel; c++) { - int c4_block_num = c / C4NUM; - int c4_block_res = c % C4NUM; - int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; - int dst_c_offset = dst_offset + c; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_c_offset + k * C4NUM; - int dst_kernel_offset = dst_c_offset + k * channel; - ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; - } - } - } -} - -void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - for (int b = 0; b < batch; b++) { - int src_offset = b * plane * c4 * C4NUM; - int dst_offset = b * plane * channel; - for (int c = 0; c < channel; c++) { - int c4_block_num = c / C4NUM; - int 
c4_block_res = c % C4NUM; - int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; - int dst_c_offset = dst_offset + c * plane; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_c_offset + k * C4NUM; - int dst_kernel_offset = dst_c_offset + k; - ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; - } - } - } -} - -void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane, int channel) { - int c8 = UP_DIV(channel, C8NUM); - for (int b = 0; b < batch; b++) { - int src_offset = b * plane * channel; - int dst_offset = b * plane * c8 * C8NUM; - for (int c = 0; c < channel; c++) { - int c8_block_num = c / C8NUM; - int c8_block_rem = c % C8NUM; - int src_c_offset = src_offset + c * plane; - int dst_c_offset = dst_offset + c8_block_num * plane * C8NUM; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_c_offset + k; - int dst_kernel_offset = dst_c_offset + C8NUM * k + c8_block_rem; - (dst + dst_kernel_offset)[0] = (float16_t)(src + src_kernel_offset)[0]; - } - } - } -} - -void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel) { - int c8 = UP_DIV(channel, C8NUM); - int nhwc8_batch_unit_offset = c8 * C8NUM * plane; - int nhwc8_batch_offset = 0; - for (int b = 0; b < batch; b++) { - int batch_offset = b * channel * plane; - for (int i = 0; i < plane; i++) { - for (int c = 0; c < channel; c++) { - (dst + nhwc8_batch_offset + i * c8 * C8NUM)[c] = (float16_t)(src + batch_offset + i * channel)[c]; - } - } - nhwc8_batch_offset += nhwc8_batch_unit_offset; - } -} - -void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel) { - int c8 = UP_DIV(channel, C8NUM); - int nhwc_batch_unit_offset = channel * plane; - int nhwc_batch_offset = 0; - for (int b = 0; b < batch; b++) { - int batch_offset = b * c8 * C8NUM * plane; - for (int i = 0; i < plane; i++) { - for (int c = 0; c < channel; c++) { - (dst + nhwc_batch_offset + i * channel)[c] = (float)(src + batch_offset + i * c8 * C8NUM)[c]; - } - } - nhwc_batch_offset += nhwc_batch_unit_offset; - } -} -#endif - void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed_weight) { // original weight format : ohwi // todo pack weight for arm32 platform diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/pack.h b/mindspore/lite/src/runtime/kernel/arm/opclib/pack.h index 66438103ebf..c486b75e38e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/pack.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/pack.h @@ -23,38 +23,6 @@ #include "src/runtime/kernel/arm/opclib/conv_parameter.h" #include "src/runtime/kernel/arm/opclib/op_base.h" -#ifdef ENABLE_FP16 -void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num, - int block_index); - -void PackWeightFp16(float16_t *weight_data, ConvParameter *conv_param, float16_t *packed_weight); - -void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param); - -void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param); - -void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel); - -void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel); - -void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); - -void PackNCHWToNHWC4Fp16(const void *src, void *dst, int 
batch, int plane, int channel); - -void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); - -void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel); - -void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel); - -void PackNC8HW8ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel); - -void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane, int channel); - -void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel); - -void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel); -#endif void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, float *packed_input, int real_cal_num, int block_index); diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_transform.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_transform.cc index 162b58f7942..571501b1426 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_transform.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_transform.cc @@ -675,518 +675,6 @@ void Conv3x3Fp32OutputTransform(const float *gemm_out, float *out_data, const fl } } -#ifdef ENABLE_FP16 -// for fp16 convolution 3x3 filter/input/output transform F(4,3) -void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step) { - float16x4_t d00 = vld1_f16(tmp_data); - float16x4_t d01 = vld1_f16(tmp_data + 4); - float16x4_t d02 = vld1_f16(tmp_data + 2 * 4); - float16x4_t d03 = vld1_f16(tmp_data + 3 * 4); - float16x4_t d04 = vld1_f16(tmp_data + 4 * 4); - float16x4_t d05 = vld1_f16(tmp_data + 5 * 4); - - float16x4_t d10 = vld1_f16(tmp_data + 6 * 4); - float16x4_t d11 = vld1_f16(tmp_data + 7 * 4); - float16x4_t d12 = vld1_f16(tmp_data + 8 * 4); - float16x4_t d13 = vld1_f16(tmp_data + 9 * 4); - float16x4_t d14 = vld1_f16(tmp_data + 10 * 4); - float16x4_t d15 = vld1_f16(tmp_data + 11 * 4); - - float16x4_t d20 = vld1_f16(tmp_data + 12 * 4); - float16x4_t d21 = vld1_f16(tmp_data + 13 * 4); - float16x4_t d22 = vld1_f16(tmp_data + 14 * 4); - float16x4_t d23 = vld1_f16(tmp_data + 15 * 4); - float16x4_t d24 = vld1_f16(tmp_data + 16 * 4); - float16x4_t d25 = vld1_f16(tmp_data + 17 * 4); - - float16x4_t d30 = vld1_f16(tmp_data + 18 * 4); - float16x4_t d31 = vld1_f16(tmp_data + 19 * 4); - float16x4_t d32 = vld1_f16(tmp_data + 20 * 4); - float16x4_t d33 = vld1_f16(tmp_data + 21 * 4); - float16x4_t d34 = vld1_f16(tmp_data + 22 * 4); - float16x4_t d35 = vld1_f16(tmp_data + 23 * 4); - - float16x4_t d40 = vld1_f16(tmp_data + 24 * 4); - float16x4_t d41 = vld1_f16(tmp_data + 25 * 4); - float16x4_t d42 = vld1_f16(tmp_data + 26 * 4); - float16x4_t d43 = vld1_f16(tmp_data + 27 * 4); - float16x4_t d44 = vld1_f16(tmp_data + 28 * 4); - float16x4_t d45 = vld1_f16(tmp_data + 29 * 4); - - float16x4_t d50 = vld1_f16(tmp_data + 30 * 4); - float16x4_t d51 = vld1_f16(tmp_data + 31 * 4); - float16x4_t d52 = vld1_f16(tmp_data + 32 * 4); - float16x4_t d53 = vld1_f16(tmp_data + 33 * 4); - float16x4_t d54 = vld1_f16(tmp_data + 34 * 4); - float16x4_t d55 = vld1_f16(tmp_data + 35 * 4); - - float16x4_t t00 = vadd_f16(vsub_f16(vmul_n_f16(d00, 4), vmul_n_f16(d20, 5)), d40); - float16x4_t t01 = vadd_f16(vsub_f16(vmul_n_f16(d01, 4), vmul_n_f16(d21, 5)), d41); - float16x4_t t02 = vadd_f16(vsub_f16(vmul_n_f16(d02, 4), vmul_n_f16(d22, 5)), d42); - float16x4_t t03 = vadd_f16(vsub_f16(vmul_n_f16(d03, 4), vmul_n_f16(d23, 5)), d43); - 
float16x4_t t04 = vadd_f16(vsub_f16(vmul_n_f16(d04, 4), vmul_n_f16(d24, 5)), d44); - float16x4_t t05 = vadd_f16(vsub_f16(vmul_n_f16(d05, 4), vmul_n_f16(d25, 5)), d45); - - float16x4_t t10 = vadd_f16(vadd_f16(d30, d40), vmul_n_f16(vadd_f16(d10, d20), -4)); - float16x4_t t11 = vadd_f16(vadd_f16(d31, d41), vmul_n_f16(vadd_f16(d11, d21), -4)); - float16x4_t t12 = vadd_f16(vadd_f16(d32, d42), vmul_n_f16(vadd_f16(d12, d22), -4)); - float16x4_t t13 = vadd_f16(vadd_f16(d33, d43), vmul_n_f16(vadd_f16(d13, d23), -4)); - float16x4_t t14 = vadd_f16(vadd_f16(d34, d44), vmul_n_f16(vadd_f16(d14, d24), -4)); - float16x4_t t15 = vadd_f16(vadd_f16(d35, d45), vmul_n_f16(vadd_f16(d15, d25), -4)); - - float16x4_t t20 = vadd_f16(vsub_f16(d40, d30), vmul_n_f16(vsub_f16(d10, d20), 4)); - float16x4_t t21 = vadd_f16(vsub_f16(d41, d31), vmul_n_f16(vsub_f16(d11, d21), 4)); - float16x4_t t22 = vadd_f16(vsub_f16(d42, d32), vmul_n_f16(vsub_f16(d12, d22), 4)); - float16x4_t t23 = vadd_f16(vsub_f16(d43, d33), vmul_n_f16(vsub_f16(d13, d23), 4)); - float16x4_t t24 = vadd_f16(vsub_f16(d44, d34), vmul_n_f16(vsub_f16(d14, d24), 4)); - float16x4_t t25 = vadd_f16(vsub_f16(d45, d35), vmul_n_f16(vsub_f16(d15, d25), 4)); - - float16x4_t t30 = vadd_f16(vsub_f16(d40, d20), vmul_n_f16(vsub_f16(d30, d10), 2)); - float16x4_t t31 = vadd_f16(vsub_f16(d41, d21), vmul_n_f16(vsub_f16(d31, d11), 2)); - float16x4_t t32 = vadd_f16(vsub_f16(d42, d22), vmul_n_f16(vsub_f16(d32, d12), 2)); - float16x4_t t33 = vadd_f16(vsub_f16(d43, d23), vmul_n_f16(vsub_f16(d33, d13), 2)); - float16x4_t t34 = vadd_f16(vsub_f16(d44, d24), vmul_n_f16(vsub_f16(d34, d14), 2)); - float16x4_t t35 = vadd_f16(vsub_f16(d45, d25), vmul_n_f16(vsub_f16(d35, d15), 2)); - - float16x4_t t40 = vadd_f16(vsub_f16(d40, d20), vmul_n_f16(vsub_f16(d10, d30), 2)); - float16x4_t t41 = vadd_f16(vsub_f16(d41, d21), vmul_n_f16(vsub_f16(d11, d31), 2)); - float16x4_t t42 = vadd_f16(vsub_f16(d42, d22), vmul_n_f16(vsub_f16(d12, d32), 2)); - float16x4_t t43 = vadd_f16(vsub_f16(d43, d23), vmul_n_f16(vsub_f16(d13, d33), 2)); - float16x4_t t44 = vadd_f16(vsub_f16(d44, d24), vmul_n_f16(vsub_f16(d14, d34), 2)); - float16x4_t t45 = vadd_f16(vsub_f16(d45, d25), vmul_n_f16(vsub_f16(d15, d35), 2)); - - float16x4_t t50 = vadd_f16(vsub_f16(vmul_n_f16(d10, 4), vmul_n_f16(d30, 5)), d50); - float16x4_t t51 = vadd_f16(vsub_f16(vmul_n_f16(d11, 4), vmul_n_f16(d31, 5)), d51); - float16x4_t t52 = vadd_f16(vsub_f16(vmul_n_f16(d12, 4), vmul_n_f16(d32, 5)), d52); - float16x4_t t53 = vadd_f16(vsub_f16(vmul_n_f16(d13, 4), vmul_n_f16(d33, 5)), d53); - float16x4_t t54 = vadd_f16(vsub_f16(vmul_n_f16(d14, 4), vmul_n_f16(d34, 5)), d54); - float16x4_t t55 = vadd_f16(vsub_f16(vmul_n_f16(d15, 4), vmul_n_f16(d35, 5)), d55); - - float16x4_t m00 = vadd_f16(vsub_f16(vmul_n_f16(t00, 4), vmul_n_f16(t02, 5)), t04); - float16x4_t m01 = vadd_f16(vadd_f16(t03, t04), vmul_n_f16(vadd_f16(t01, t02), -4)); - float16x4_t m02 = vadd_f16(vsub_f16(t04, t03), vmul_n_f16(vsub_f16(t01, t02), 4)); - float16x4_t m03 = vadd_f16(vsub_f16(t04, t02), vmul_n_f16(vsub_f16(t03, t01), 2)); - float16x4_t m04 = vadd_f16(vsub_f16(t04, t02), vmul_n_f16(vsub_f16(t01, t03), 2)); - float16x4_t m05 = vadd_f16(vsub_f16(vmul_n_f16(t01, 4), vmul_n_f16(t03, 5)), t05); - - float16x4_t m10 = vadd_f16(vsub_f16(vmul_n_f16(t10, 4), vmul_n_f16(t12, 5)), t14); - float16x4_t m11 = vadd_f16(vadd_f16(t13, t14), vmul_n_f16(vadd_f16(t11, t12), -4)); - float16x4_t m12 = vadd_f16(vsub_f16(t14, t13), vmul_n_f16(vsub_f16(t11, t12), 4)); - float16x4_t m13 = vadd_f16(vsub_f16(t14, t12), 
vmul_n_f16(vsub_f16(t13, t11), 2)); - float16x4_t m14 = vadd_f16(vsub_f16(t14, t12), vmul_n_f16(vsub_f16(t11, t13), 2)); - float16x4_t m15 = vadd_f16(vsub_f16(vmul_n_f16(t11, 4), vmul_n_f16(t13, 5)), t15); - - float16x4_t m20 = vadd_f16(vsub_f16(vmul_n_f16(t20, 4), vmul_n_f16(t22, 5)), t24); - float16x4_t m21 = vadd_f16(vadd_f16(t23, t24), vmul_n_f16(vadd_f16(t21, t22), -4)); - float16x4_t m22 = vadd_f16(vsub_f16(t24, t23), vmul_n_f16(vsub_f16(t21, t22), 4)); - float16x4_t m23 = vadd_f16(vsub_f16(t24, t22), vmul_n_f16(vsub_f16(t23, t21), 2)); - float16x4_t m24 = vadd_f16(vsub_f16(t24, t22), vmul_n_f16(vsub_f16(t21, t23), 2)); - float16x4_t m25 = vadd_f16(vsub_f16(vmul_n_f16(t21, 4), vmul_n_f16(t23, 5)), t25); - - float16x4_t m30 = vadd_f16(vsub_f16(vmul_n_f16(t30, 4), vmul_n_f16(t32, 5)), t34); - float16x4_t m31 = vadd_f16(vadd_f16(t33, t34), vmul_n_f16(vadd_f16(t31, t32), -4)); - float16x4_t m32 = vadd_f16(vsub_f16(t34, t33), vmul_n_f16(vsub_f16(t31, t32), 4)); - float16x4_t m33 = vadd_f16(vsub_f16(t34, t32), vmul_n_f16(vsub_f16(t33, t31), 2)); - float16x4_t m34 = vadd_f16(vsub_f16(t34, t32), vmul_n_f16(vsub_f16(t31, t33), 2)); - float16x4_t m35 = vadd_f16(vsub_f16(vmul_n_f16(t31, 4), vmul_n_f16(t33, 5)), t35); - - float16x4_t m40 = vadd_f16(vsub_f16(vmul_n_f16(t40, 4), vmul_n_f16(t42, 5)), t44); - float16x4_t m41 = vadd_f16(vadd_f16(t43, t44), vmul_n_f16(vadd_f16(t41, t42), -4)); - float16x4_t m42 = vadd_f16(vsub_f16(t44, t43), vmul_n_f16(vsub_f16(t41, t42), 4)); - float16x4_t m43 = vadd_f16(vsub_f16(t44, t42), vmul_n_f16(vsub_f16(t43, t41), 2)); - float16x4_t m44 = vadd_f16(vsub_f16(t44, t42), vmul_n_f16(vsub_f16(t41, t43), 2)); - float16x4_t m45 = vadd_f16(vsub_f16(vmul_n_f16(t41, 4), vmul_n_f16(t43, 5)), t45); - - float16x4_t m50 = vadd_f16(vsub_f16(vmul_n_f16(t50, 4), vmul_n_f16(t52, 5)), t54); - float16x4_t m51 = vadd_f16(vadd_f16(t53, t54), vmul_n_f16(vadd_f16(t51, t52), -4)); - float16x4_t m52 = vadd_f16(vsub_f16(t54, t53), vmul_n_f16(vsub_f16(t51, t52), 4)); - float16x4_t m53 = vadd_f16(vsub_f16(t54, t52), vmul_n_f16(vsub_f16(t53, t51), 2)); - float16x4_t m54 = vadd_f16(vsub_f16(t54, t52), vmul_n_f16(vsub_f16(t51, t53), 2)); - float16x4_t m55 = vadd_f16(vsub_f16(vmul_n_f16(t51, 4), vmul_n_f16(t53, 5)), t55); - - vst1_f16(trans_input_data, m00); - vst1_f16(trans_input_data + step, m01); - vst1_f16(trans_input_data + 2 * step, m02); - vst1_f16(trans_input_data + 3 * step, m03); - vst1_f16(trans_input_data + 4 * step, m04); - vst1_f16(trans_input_data + 5 * step, m05); - - vst1_f16(trans_input_data + 6 * step, m10); - vst1_f16(trans_input_data + 7 * step, m11); - vst1_f16(trans_input_data + 8 * step, m12); - vst1_f16(trans_input_data + 9 * step, m13); - vst1_f16(trans_input_data + 10 * step, m14); - vst1_f16(trans_input_data + 11 * step, m15); - - vst1_f16(trans_input_data + 12 * step, m20); - vst1_f16(trans_input_data + 13 * step, m21); - vst1_f16(trans_input_data + 14 * step, m22); - vst1_f16(trans_input_data + 15 * step, m23); - vst1_f16(trans_input_data + 16 * step, m24); - vst1_f16(trans_input_data + 17 * step, m25); - - vst1_f16(trans_input_data + 18 * step, m30); - vst1_f16(trans_input_data + 19 * step, m31); - vst1_f16(trans_input_data + 20 * step, m32); - vst1_f16(trans_input_data + 21 * step, m33); - vst1_f16(trans_input_data + 22 * step, m34); - vst1_f16(trans_input_data + 23 * step, m35); - - vst1_f16(trans_input_data + 24 * step, m40); - vst1_f16(trans_input_data + 25 * step, m41); - vst1_f16(trans_input_data + 26 * step, m42); - vst1_f16(trans_input_data + 27 * step, 
m43); - vst1_f16(trans_input_data + 28 * step, m44); - vst1_f16(trans_input_data + 29 * step, m45); - - vst1_f16(trans_input_data + 30 * step, m50); - vst1_f16(trans_input_data + 31 * step, m51); - vst1_f16(trans_input_data + 32 * step, m52); - vst1_f16(trans_input_data + 33 * step, m53); - vst1_f16(trans_input_data + 34 * step, m54); - vst1_f16(trans_input_data + 35 * step, m55); -} - -void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, - int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) { - // input data format : nhwc - int output_unit = 4; - int input_channel = conv_param->input_channel_; - int input_width = conv_param->input_w_; - int input_height = conv_param->input_h_; - int pad_w = conv_param->pad_w_; - int pad_h = conv_param->pad_h_; - int ic4 = UP_DIV(input_channel, C4NUM); - - for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { - int x_id = start_index + cal_id; - int origin_x = (x_id % out_w_block) * output_unit - pad_w; - int origin_y = (x_id / out_w_block) * output_unit - pad_h; - int real_x_start = origin_x > 0 ? 0 : -origin_x; - int real_x_end = (origin_x + 6) < input_width ? 6 : (input_width - origin_x); - int real_y_start = origin_y > 0 ? 0 : -origin_y; - int real_y_end = (origin_y + 6) < input_height ? 6 : (input_height - origin_y); - - int src_plane_offset = input_channel * (origin_y * input_width + origin_x); - int dst_plane_offset = cal_id * C4NUM; - for (int ic = 0; ic < ic4; ic++) { - // clear tmp buffer - memset(tmp_data, 0, 6 * 6 * C4NUM * sizeof(float16_t)); - - // get real input block with padding - int src_ic4_offset = src_plane_offset + ic * C4NUM; - for (int interval = real_y_start; interval < real_y_end; interval++) { - int src_y_offset = src_ic4_offset + interval * input_width * input_channel + real_x_start * input_channel; - int dst_y_offset = interval * 6 * C4NUM + real_x_start * C4NUM; - for (int j = 0; j < (real_x_end - real_x_start); j++) { - int src_x_offset = src_y_offset + j * input_channel; - int dst_x_offset = dst_y_offset + j * C4NUM; - float16_t *src_addr = (float16_t *)(input_data) + src_x_offset; - float16_t *dst_addr = tmp_data + dst_x_offset; - dst_addr[0] = src_addr[0]; - dst_addr[1] = src_addr[1]; - dst_addr[2] = src_addr[2]; - dst_addr[3] = src_addr[3]; - } - } - - // todo - // input transform - int dst_ic4_offset = dst_plane_offset + ic * 16 * C4NUM; - size_t dst_step = ic4 * C4NUM * 16; - float16_t *trans_input_ptr = trans_input + dst_ic4_offset; - Conv3x3Fp16InputUnit(tmp_data, trans_input_ptr, dst_step); - } - } -} - -void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC4, int output_channel, - int kernel_plane) { - int dst_step = iC4 * C4NUM * 8; - for (int o = 0; o < output_channel; o++) { - int oc8_block_num = o / C8NUM; - int oc8_block_rem = o % C8NUM; - int src_oc_offset = o * iC4 * C4NUM * kernel_plane; - int dst_oc_offset = oc8_block_num * C8NUM * iC4 * C4NUM * 36 + oc8_block_rem; - for (int i = 0; i < iC4; i++) { - const float16_t *src_ic4_ptr = weight_data + src_oc_offset + i * kernel_plane * C4NUM; - float16_t *dst_ic4_ptr = trans_weight + dst_oc_offset + i * 8 * C4NUM; - float16x4_t g00 = vld1_f16(src_ic4_ptr); - float16x4_t g01 = vld1_f16(src_ic4_ptr + 4); - float16x4_t g02 = vld1_f16(src_ic4_ptr + 2 * 4); - float16x4_t g10 = vld1_f16(src_ic4_ptr + 3 * 4); - float16x4_t g11 = vld1_f16(src_ic4_ptr + 4 * 4); - float16x4_t g12 = vld1_f16(src_ic4_ptr + 5 * 4); - float16x4_t g20 = vld1_f16(src_ic4_ptr + 6 
* 4); - float16x4_t g21 = vld1_f16(src_ic4_ptr + 7 * 4); - float16x4_t g22 = vld1_f16(src_ic4_ptr + 8 * 4); - - float16x4_t dst00 = vmul_n_f16(g00, 0.25); - float16x4_t dst01 = vmul_n_f16(g01, 0.25); - float16x4_t dst02 = vmul_n_f16(g02, 0.25); - - float16x4_t dst10 = vmul_n_f16(vadd_f16(g00, vadd_f16(g10, g20)), -0.1666666666667); - float16x4_t dst11 = vmul_n_f16(vadd_f16(g01, vadd_f16(g11, g21)), -0.1666666666667); - float16x4_t dst12 = vmul_n_f16(vadd_f16(g02, vadd_f16(g12, g22)), -0.1666666666667); - - float16x4_t dst20 = vmul_n_f16(vsub_f16(vadd_f16(g00, g20), g10), -0.1666666666667); - float16x4_t dst21 = vmul_n_f16(vsub_f16(vadd_f16(g01, g21), g11), -0.1666666666667); - float16x4_t dst22 = vmul_n_f16(vsub_f16(vadd_f16(g02, g22), g12), -0.1666666666667); - - float16x4_t dst30 = vadd_f16(vmul_n_f16(g10, 0.08333333333333), - vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667))); - float16x4_t dst31 = vadd_f16(vmul_n_f16(g11, 0.08333333333333), - vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667))); - float16x4_t dst32 = vadd_f16(vmul_n_f16(g12, 0.08333333333333), - vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667))); - - float16x4_t dst40 = vsub_f16(vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667)), - vmul_n_f16(g10, 0.08333333333333)); - float16x4_t dst41 = vsub_f16(vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667)), - vmul_n_f16(g11, 0.08333333333333)); - float16x4_t dst42 = vsub_f16(vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667)), - vmul_n_f16(g12, 0.08333333333333)); - - float16x4_t dst50 = g20; - float16x4_t dst51 = g21; - float16x4_t dst52 = g22; - - float16x4_t m00 = vmul_n_f16(dst00, 0.25); - float16x4_t m01 = vmul_n_f16(vadd_f16(dst00, vadd_f16(dst01, dst02)), -0.1666666666667); - float16x4_t m02 = vmul_n_f16(vsub_f16(vadd_f16(dst00, dst02), dst01), -0.1666666666667); - float16x4_t m03 = vadd_f16(vmul_n_f16(dst01, 0.08333333333333), - vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667))); - float16x4_t m04 = vsub_f16(vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667)), - vmul_n_f16(dst01, 0.08333333333333)); - float16x4_t m05 = dst02; - - float16x4_t m10 = vmul_n_f16(dst10, 0.25); - float16x4_t m11 = vmul_n_f16(vadd_f16(dst10, vadd_f16(dst11, dst12)), -0.1666666666667); - float16x4_t m12 = vmul_n_f16(vsub_f16(vadd_f16(dst10, dst12), dst11), -0.1666666666667); - float16x4_t m13 = vadd_f16(vmul_n_f16(dst11, 0.08333333333333), - vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667))); - float16x4_t m14 = vsub_f16(vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667)), - vmul_n_f16(dst11, 0.08333333333333)); - float16x4_t m15 = dst12; - - float16x4_t m20 = vmul_n_f16(dst20, 0.25); - float16x4_t m21 = vmul_n_f16(vadd_f16(dst20, vadd_f16(dst21, dst22)), -0.1666666666667); - float16x4_t m22 = vmul_n_f16(vsub_f16(vadd_f16(dst20, dst22), dst21), -0.1666666666667); - float16x4_t m23 = vadd_f16(vmul_n_f16(dst21, 0.08333333333333), - vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667))); - float16x4_t m24 = vsub_f16(vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667)), - vmul_n_f16(dst21, 0.08333333333333)); - float16x4_t m25 = dst22; - - float16x4_t m30 = vmul_n_f16(dst30, 0.25); - float16x4_t m31 = vmul_n_f16(vadd_f16(dst30, vadd_f16(dst31, dst32)), 
-0.1666666666667); - float16x4_t m32 = vmul_n_f16(vsub_f16(vadd_f16(dst30, dst32), dst31), -0.1666666666667); - float16x4_t m33 = vadd_f16(vmul_n_f16(dst31, 0.08333333333333), - vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667))); - float16x4_t m34 = vsub_f16(vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667)), - vmul_n_f16(dst31, 0.08333333333333)); - float16x4_t m35 = dst32; - - float16x4_t m40 = vmul_n_f16(dst40, 0.25); - float16x4_t m41 = vmul_n_f16(vadd_f16(dst40, vadd_f16(dst41, dst42)), -0.1666666666667); - float16x4_t m42 = vmul_n_f16(vsub_f16(vadd_f16(dst40, dst42), dst41), -0.1666666666667); - float16x4_t m43 = vadd_f16(vmul_n_f16(dst41, 0.08333333333333), - vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667))); - float16x4_t m44 = vsub_f16(vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667)), - vmul_n_f16(dst41, 0.08333333333333)); - float16x4_t m45 = dst42; - - float16x4_t m50 = vmul_n_f16(dst50, 0.25); - float16x4_t m51 = vmul_n_f16(vadd_f16(dst50, vadd_f16(dst51, dst52)), -0.1666666666667); - float16x4_t m52 = vmul_n_f16(vsub_f16(vadd_f16(dst50, dst52), dst51), -0.1666666666667); - float16x4_t m53 = vadd_f16(vmul_n_f16(dst51, 0.08333333333333), - vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667))); - float16x4_t m54 = vsub_f16(vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667)), - vmul_n_f16(dst51, 0.08333333333333)); - float16x4_t m55 = dst52; - - for (int j = 0; j < 4; j++) { - dst_ic4_ptr[j * 8] = m00[j]; - dst_ic4_ptr[j * 8 + dst_step] = m01[j]; - dst_ic4_ptr[j * 8 + 2 * dst_step] = m02[j]; - dst_ic4_ptr[j * 8 + 3 * dst_step] = m03[j]; - dst_ic4_ptr[j * 8 + 4 * dst_step] = m04[j]; - dst_ic4_ptr[j * 8 + 5 * dst_step] = m05[j]; - dst_ic4_ptr[j * 8 + 6 * dst_step] = m10[j]; - dst_ic4_ptr[j * 8 + 7 * dst_step] = m11[j]; - dst_ic4_ptr[j * 8 + 8 * dst_step] = m12[j]; - dst_ic4_ptr[j * 8 + 9 * dst_step] = m13[j]; - dst_ic4_ptr[j * 8 + 10 * dst_step] = m14[j]; - dst_ic4_ptr[j * 8 + 11 * dst_step] = m15[j]; - dst_ic4_ptr[j * 8 + 12 * dst_step] = m20[j]; - dst_ic4_ptr[j * 8 + 13 * dst_step] = m21[j]; - dst_ic4_ptr[j * 8 + 14 * dst_step] = m22[j]; - dst_ic4_ptr[j * 8 + 15 * dst_step] = m23[j]; - dst_ic4_ptr[j * 8 + 16 * dst_step] = m24[j]; - dst_ic4_ptr[j * 8 + 17 * dst_step] = m25[j]; - dst_ic4_ptr[j * 8 + 18 * dst_step] = m30[j]; - dst_ic4_ptr[j * 8 + 19 * dst_step] = m31[j]; - dst_ic4_ptr[j * 8 + 20 * dst_step] = m32[j]; - dst_ic4_ptr[j * 8 + 21 * dst_step] = m33[j]; - dst_ic4_ptr[j * 8 + 22 * dst_step] = m34[j]; - dst_ic4_ptr[j * 8 + 23 * dst_step] = m35[j]; - dst_ic4_ptr[j * 8 + 24 * dst_step] = m40[j]; - dst_ic4_ptr[j * 8 + 25 * dst_step] = m41[j]; - dst_ic4_ptr[j * 8 + 26 * dst_step] = m42[j]; - dst_ic4_ptr[j * 8 + 27 * dst_step] = m43[j]; - dst_ic4_ptr[j * 8 + 28 * dst_step] = m44[j]; - dst_ic4_ptr[j * 8 + 29 * dst_step] = m45[j]; - dst_ic4_ptr[j * 8 + 30 * dst_step] = m50[j]; - dst_ic4_ptr[j * 8 + 31 * dst_step] = m51[j]; - dst_ic4_ptr[j * 8 + 32 * dst_step] = m52[j]; - dst_ic4_ptr[j * 8 + 33 * dst_step] = m53[j]; - dst_ic4_ptr[j * 8 + 34 * dst_step] = m54[j]; - dst_ic4_ptr[j * 8 + 35 * dst_step] = m55[j]; - } - } - } -} - -void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, - int output_w) { - float16x8_t s00 = vld1q_f16(gemm_out); - float16x8_t s01 = vld1q_f16(gemm_out + 8); - float16x8_t s02 = vld1q_f16(gemm_out + 16); - float16x8_t s03 = 
vld1q_f16(gemm_out + 24); - float16x8_t s04 = vld1q_f16(gemm_out + 32); - float16x8_t s05 = vld1q_f16(gemm_out + 40); - - float16x8_t s10 = vld1q_f16(gemm_out + 48); - float16x8_t s11 = vld1q_f16(gemm_out + 56); - float16x8_t s12 = vld1q_f16(gemm_out + 64); - float16x8_t s13 = vld1q_f16(gemm_out + 72); - float16x8_t s14 = vld1q_f16(gemm_out + 80); - float16x8_t s15 = vld1q_f16(gemm_out + 88); - - float16x8_t s20 = vld1q_f16(gemm_out + 96); - float16x8_t s21 = vld1q_f16(gemm_out + 104); - float16x8_t s22 = vld1q_f16(gemm_out + 112); - float16x8_t s23 = vld1q_f16(gemm_out + 120); - float16x8_t s24 = vld1q_f16(gemm_out + 128); - float16x8_t s25 = vld1q_f16(gemm_out + 136); - - float16x8_t s30 = vld1q_f16(gemm_out + 144); - float16x8_t s31 = vld1q_f16(gemm_out + 152); - float16x8_t s32 = vld1q_f16(gemm_out + 160); - float16x8_t s33 = vld1q_f16(gemm_out + 168); - float16x8_t s34 = vld1q_f16(gemm_out + 176); - float16x8_t s35 = vld1q_f16(gemm_out + 184); - - float16x8_t s40 = vld1q_f16(gemm_out + 192); - float16x8_t s41 = vld1q_f16(gemm_out + 200); - float16x8_t s42 = vld1q_f16(gemm_out + 208); - float16x8_t s43 = vld1q_f16(gemm_out + 216); - float16x8_t s44 = vld1q_f16(gemm_out + 224); - float16x8_t s45 = vld1q_f16(gemm_out + 232); - - float16x8_t s50 = vld1q_f16(gemm_out + 240); - float16x8_t s51 = vld1q_f16(gemm_out + 248); - float16x8_t s52 = vld1q_f16(gemm_out + 256); - float16x8_t s53 = vld1q_f16(gemm_out + 264); - float16x8_t s54 = vld1q_f16(gemm_out + 272); - float16x8_t s55 = vld1q_f16(gemm_out + 280); - - float16x8_t t00 = vaddq_f16(vaddq_f16(vaddq_f16(s00, s10), vaddq_f16(s20, s30)), s40); - float16x8_t t01 = vaddq_f16(vaddq_f16(vaddq_f16(s01, s11), vaddq_f16(s21, s31)), s41); - float16x8_t t02 = vaddq_f16(vaddq_f16(vaddq_f16(s02, s12), vaddq_f16(s22, s32)), s42); - float16x8_t t03 = vaddq_f16(vaddq_f16(vaddq_f16(s03, s13), vaddq_f16(s23, s33)), s43); - float16x8_t t04 = vaddq_f16(vaddq_f16(vaddq_f16(s04, s14), vaddq_f16(s24, s34)), s44); - float16x8_t t05 = vaddq_f16(vaddq_f16(vaddq_f16(s05, s15), vaddq_f16(s25, s35)), s45); - - float16x8_t t10 = vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 2)); - float16x8_t t11 = vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 2)); - float16x8_t t12 = vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 2)); - float16x8_t t13 = vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 2)); - float16x8_t t14 = vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 2)); - float16x8_t t15 = vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 2)); - - float16x8_t t20 = vaddq_f16(vaddq_f16(s10, s20), vmulq_n_f16(vaddq_f16(s30, s40), 4)); - float16x8_t t21 = vaddq_f16(vaddq_f16(s11, s21), vmulq_n_f16(vaddq_f16(s31, s41), 4)); - float16x8_t t22 = vaddq_f16(vaddq_f16(s12, s22), vmulq_n_f16(vaddq_f16(s32, s42), 4)); - float16x8_t t23 = vaddq_f16(vaddq_f16(s13, s23), vmulq_n_f16(vaddq_f16(s33, s43), 4)); - float16x8_t t24 = vaddq_f16(vaddq_f16(s14, s24), vmulq_n_f16(vaddq_f16(s34, s44), 4)); - float16x8_t t25 = vaddq_f16(vaddq_f16(s15, s25), vmulq_n_f16(vaddq_f16(s35, s45), 4)); - - float16x8_t t30 = vaddq_f16(vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 8)), s50); - float16x8_t t31 = vaddq_f16(vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 8)), s51); - float16x8_t t32 = vaddq_f16(vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 8)), s52); - float16x8_t t33 = vaddq_f16(vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 8)), 
s53); - float16x8_t t34 = vaddq_f16(vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 8)), s54); - float16x8_t t35 = vaddq_f16(vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 8)), s55); - - float16x8_t d00 = vaddq_f16(vaddq_f16(vaddq_f16(t00, t01), vaddq_f16(t02, t03)), t04); - float16x8_t d01 = vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 2)); - float16x8_t d02 = vaddq_f16(vaddq_f16(t01, t02), vmulq_n_f16(vaddq_f16(t03, t04), 4)); - float16x8_t d03 = vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 8)), t05); - - float16x8_t d10 = vaddq_f16(vaddq_f16(vaddq_f16(t10, t11), vaddq_f16(t12, t13)), t14); - float16x8_t d11 = vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 2)); - float16x8_t d12 = vaddq_f16(vaddq_f16(t11, t12), vmulq_n_f16(vaddq_f16(t13, t14), 4)); - float16x8_t d13 = vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 8)), t15); - - float16x8_t d20 = vaddq_f16(vaddq_f16(vaddq_f16(t20, t21), vaddq_f16(t22, t23)), t24); - float16x8_t d21 = vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 2)); - float16x8_t d22 = vaddq_f16(vaddq_f16(t21, t22), vmulq_n_f16(vaddq_f16(t23, t24), 4)); - float16x8_t d23 = vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 8)), t25); - - float16x8_t d30 = vaddq_f16(vaddq_f16(vaddq_f16(t30, t31), vaddq_f16(t32, t33)), t34); - float16x8_t d31 = vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 2)); - float16x8_t d32 = vaddq_f16(vaddq_f16(t31, t32), vmulq_n_f16(vaddq_f16(t33, t34), 4)); - float16x8_t d33 = vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 8)), t35); - - vst1q_f16(output_data, d00); - vst1q_f16(output_data + 8, d01); - vst1q_f16(output_data + 16, d02); - vst1q_f16(output_data + 24, d03); - - vst1q_f16(output_data + output_w * 8, d10); - vst1q_f16(output_data + output_w * 8 + 8, d11); - vst1q_f16(output_data + output_w * 8 + 16, d12); - vst1q_f16(output_data + output_w * 8 + 24, d13); - - vst1q_f16(output_data + 2 * output_w * 8, d20); - vst1q_f16(output_data + 2 * output_w * 8 + 8, d21); - vst1q_f16(output_data + 2 * output_w * 8 + 16, d22); - vst1q_f16(output_data + 2 * output_w * 8 + 24, d23); - - vst1q_f16(output_data + 3 * output_w * 8, d30); - vst1q_f16(output_data + 3 * output_w * 8 + 8, d31); - vst1q_f16(output_data + 3 * output_w * 8 + 16, d32); - vst1q_f16(output_data + 3 * output_w * 8 + 24, d33); -} - -void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data, - int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) { - int output_channel = conv_param->output_channel_; - int output_w = conv_param->output_w_; - int output_h = conv_param->output_h_; - int oc8 = UP_DIV(output_channel, C8NUM); - - for (int i = 0; i < real_cal_num; i++) { - int out_w_index = (start_index + i) % out_w_block; - int out_h_index = (start_index + i) / out_w_block; - int src_tile_offset = i * oc8 * C8NUM * 36; - int dst_tile_offset = 8 * (out_w_index * 4 + out_h_index * 4 * output_w); - - for (int j = 0; j < oc8; j++) { - int src_oc8_offset = src_tile_offset + j * 36 * C8NUM; - int dst_oc8_offset = dst_tile_offset + j * C8NUM * output_h * output_w; - const float16_t *src_ptr = gemm_out + src_oc8_offset; - const float16_t *bias_ptr = bias_data + j * C8NUM; - float16_t *dst_ptr = out_data + dst_oc8_offset; - - // output transform - Conv3x3Fp16OutputUnit(src_ptr, bias_ptr, dst_ptr, output_w); - } - } -} 
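For reference, the fp16 routines deleted above (and re-added under opclib/fp16/ earlier in this patch) implement Winograd F(4,3) over the usual interpolation points {0, 1, -1, 2, -2}: the filter transform applies G with rows [1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1] (hence the 0.25, -0.1666..., 0.0833..., 0.0416... constants), the input transform applies B^T, and the output transform applies A^T. A minimal scalar sketch of the output transform follows; it mirrors the NEON code in Conv3x3Fp16OutputUnit, but the function name and plain-array signature are illustrative only, not part of the patch:

// Scalar reference for the F(4,3) output transform: out = A^T * M * A, with
//   A^T = [ 1  1  1  1  1  0 ]
//         [ 0  1 -1  2 -2  0 ]
//         [ 0  1  1  4  4  0 ]
//         [ 0  1 -1  8 -8  1 ]
// "m" is one channel's 6x6 tile of GEMM results; "out" is the 4x4 output tile.
static void WinogradF43OutputTileRef(const float m[6][6], float out[4][4]) {
  float t[4][6];
  for (int c = 0; c < 6; ++c) {  // apply A^T down the rows of each column
    t[0][c] = m[0][c] + m[1][c] + m[2][c] + m[3][c] + m[4][c];
    t[1][c] = m[1][c] - m[2][c] + 2 * (m[3][c] - m[4][c]);
    t[2][c] = m[1][c] + m[2][c] + 4 * (m[3][c] + m[4][c]);
    t[3][c] = m[1][c] - m[2][c] + 8 * (m[3][c] - m[4][c]) + m[5][c];
  }
  for (int r = 0; r < 4; ++r) {  // apply A across the columns of each row
    out[r][0] = t[r][0] + t[r][1] + t[r][2] + t[r][3] + t[r][4];
    out[r][1] = t[r][1] - t[r][2] + 2 * (t[r][3] - t[r][4]);
    out[r][2] = t[r][1] + t[r][2] + 4 * (t[r][3] + t[r][4]);
    out[r][3] = t[r][1] - t[r][2] + 8 * (t[r][3] - t[r][4]) + t[r][5];
  }
}

Note that both the removed and the re-added Conv3x3Fp16OutputUnit accept a bias_data pointer without reading it, so any bias addition has to happen outside these routines.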
-#endif - // int8 conv3x3 void Conv3x3Uint8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) { #ifdef ENABLE_ARM diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_transform.h b/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_transform.h index d0cc7e1b330..42e8f4a366c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_transform.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_transform.h @@ -51,22 +51,6 @@ void Conv3x3Fp32OutputUnit(const float *gemm_out, const float *bias_data, float void Conv3x3Fp32OutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param); -#ifdef ENABLE_FP16 -// for fp16 convolution 3x3 filter/input/output transform -void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step); - -void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, - int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param); - -void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC8, int output_channel, - int kernel_plane); - -void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, int output_w); - -void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data, - int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param); -#endif - // for int8 convolution 3x3 filter/input/output transform void Conv3x3Uint8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp); diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc index 65b2c256ea4..58f594d539a 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc @@ -127,9 +127,9 @@ kernel::LiteKernel *OpenCLArithmeticKernelCreator(const std::vector *kernels) { return kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector &inputs, const std::vector &outputs, const lite::Primitive *primitive) { - // todo: support CPU, NPU, APU + // todo: support NPU, APU MS_ASSERT(nullptr != primitive); - kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, primitive->Type()}; + auto data_type = inputs.front()->data_type(); + kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, primitive->Type()}; if (context->deviceCtx.type == DT_GPU) { desc.arch = kernel::KERNEL_ARCH::kGPU; auto *kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context, desc); @@ -161,13 +162,25 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector return kernel; } } + desc.arch = kernel::KERNEL_ARCH::kCPU; - auto *kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context, desc); - if (nullptr != kernel) { + kernel::LiteKernel *kernel; + if (data_type == kNumberTypeFloat32) { + // check if support fp16 + kernel::KernelKey key{desc.arch, kNumberTypeFloat16, desc.type}; + kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context, key); + if (kernel != nullptr) { + kernel->set_desc(desc); + return kernel; + } + kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context, desc); + } else { + kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, 
primitive, context, desc); + } + if (kernel != nullptr) { kernel->set_desc(desc); return kernel; } return nullptr; } } // namespace mindspore::lite - diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 1dbde6806d5..e01e4079ade 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -83,7 +83,7 @@ file(GLOB KERNEL_OP_SRC ) if (PLATFORM_ARM64) # assembly - file(GLOB_RECURSE TEST_ASSEMBLY_SRC ${LITE_DIR}/src/runtime/kernel/arm/opclib/assembly/arm64/*.s + file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/src/runtime/kernel/arm/opclib/assembly/arm64/*.s ${LITE_DIR}/src/runtime/kernel/arm/opclib/assembly/arm64/*.S) set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C) @@ -94,8 +94,9 @@ if (PLATFORM_ARM64) endif() if (PLATFORM_ARM32) # assembly - set(GLOB_RECURSE TEST_ASSEMBLY_SRC - ${LITE_DIR}/src/runtime/kernel/arm/opclib/assembly/arm32/*.S) + file(GLOB TEST_ASSEMBLY_SRC + ${LITE_DIR}/src/runtime/kernel/arm/opclib/assembly/arm32/*.S + ${LITE_DIR}/src/runtime/kernel/arm/opclib/assembly/arm32/*.s) set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C) set(KERNEL_OP_SRC ${KERNEL_OP_SRC} diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc index 0b2034ec3b8..3a4632809f8 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc @@ -20,7 +20,6 @@ #include "common/common_test.h" #include "mindspore/lite/src/common/file_utils.h" #include "mindspore/lite/src/runtime/kernel/arm/opclib/pack.h" -#include "mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h" namespace mindspore { class TestPack : public mindspore::Common { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc index 27dc434a64e..3595010287f 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc @@ -108,8 +108,8 @@ TEST_F(TestActivationFp32, HSwishFp32) { outputs_tensor.push_back(&output0_tensor); output0_tensor.SetData(output.data()); - kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, schema::PrimitiveType_Activation}; - auto creator = lite::KernelRegistry::GetInstance()->GetKernelCreator(desc); + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Activation}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); lite::Context ctx; ctx.threadNum = 7; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/add_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/add_int8_tests.cc index c27c34f218b..a8d738ddaab 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/add_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/add_int8_tests.cc @@ -33,7 +33,7 @@ TEST_F(TestQuantizedAdd, Add) { lite::tensor::Tensor out_tensor(kNumberTypeInt8, {1, 1, 2, 5}); int8_t input_data0[] = {-102, 25, -51, 89, -102, 25, -51, 89, -102, 25}; // -0.8 0.2 -0.4 0.7 - int8_t input_data1[] = {38, 51, 64, -102, 38, 51, 64, -102, 38, 51}; // 0.3 0.4 0.5 -0.8 + int8_t input_data1[] = {38, 51, 64, -102, 38, 51, 64, -102, 38, 51}; // 0.3 0.4 0.5 -0.8 int8_t output_data[10] = {0}; in_tensor0.SetData(input_data0); 
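  // A rough check of the quantization used in this test (assuming scale 2/255 ≈ 0.00784 and
  // zero point 0, which reproduces the inline float comments above): -102 → -0.8, 25 → 0.2,
  // 38 → 0.3, 51 → 0.4, 64 → 0.5, 89 → 0.7.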
 in_tensor1.SetData(input_data1);
@@ -50,9 +50,9 @@ TEST_F(TestQuantizedAdd, Add) {
   std::vector<lite::tensor::Tensor *> outputs = {&out_tensor};
   OpParameter parameter = {};
-  kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, schema::PrimitiveType_Add};
+  kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Add};
-  auto creator = lite::KernelRegistry::GetInstance()->GetKernelCreator(desc);
+  auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
   ASSERT_NE(creator, nullptr);
   auto ctx = std::make_shared<lite::Context>();
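Taken together with the scheduler.cc hunk above, kernel lookup now keys on {arch, data_type, op_type}, and for fp32 graphs the CPU path probes for a registered fp16 kernel before falling back to fp32. A condensed sketch of that selection flow (a readability-oriented rearrangement of the patched ScheduleNode, not new behavior):

// CPU path of ScheduleNode after this patch (condensed):
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, primitive->Type()};
kernel::LiteKernel *kernel = nullptr;
if (data_type == kNumberTypeFloat32) {
  // Prefer an fp16 implementation when one is registered for this op.
  kernel::KernelKey fp16_key{desc.arch, kNumberTypeFloat16, desc.type};
  kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context, fp16_key);
}
if (kernel == nullptr) {
  // No fp16 creator registered (or a non-fp32 graph): look up the native type.
  kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context, desc);
}
if (kernel != nullptr) {
  kernel->set_desc(desc);  // the original desc is kept even when an fp16 kernel was chosen
}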