rewrite op register func

fuzhiye 2020-08-01 10:33:28 +08:00
parent 11f786c9bd
commit 3a70fd23ce
110 changed files with 1855 additions and 1724 deletions

View File

@ -43,19 +43,11 @@ LiteKernel *KernelFactory::GetKernel(const std::vector<tensor::Tensor *> &inputs
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(primitive->Type());
return nullptr;
}
auto creator = KernelRegistry::GetInstance()->GetKernelCreator(key);
auto creator = KernelRegistry::GetInstance()->GetCreator(key);
if (creator != nullptr) {
auto *kernel = creator(inputs, outputs, parameter, ctx, key);
if (kernel != nullptr) {
return kernel;
} else {
MS_LOG(ERROR) << "Creator kernel failed for " << schema::EnumNamePrimitiveType(key.type);
return nullptr;
}
} else {
MS_LOG(ERROR) << "Can not find OpCreator for " << schema::EnumNamePrimitiveType(key.type);
return nullptr;
auto kernel = creator(inputs, outputs, parameter, ctx, key);
return kernel;
}
return nullptr;
}
} // namespace mindspore::lite

View File

@ -38,4 +38,3 @@ class KernelFactory {
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_

View File

@ -13,47 +13,105 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "ir/dtype/type_id.h"
#ifdef ENABLE_ARM64
#include <asm/hwcap.h>
#include "common/utils.h"
#include "utils/log_adapter.h"
#include "src/runtime/kernel/arm/opclib/optimized_kernel.h"
#endif
using mindspore::kernel::kCPU;
using mindspore::kernel::KERNEL_ARCH;
using mindspore::kernel::KernelCreator;
using mindspore::kernel::KernelKey;
using mindspore::kernel::KERNEL_ARCH;
using mindspore::kernel::kKernelArch_MAX;
using mindspore::kernel::kKernelArch_MIN;
using mindspore::schema::PrimitiveType_MAX;
using mindspore::schema::PrimitiveType_MIN;
namespace mindspore::lite {
KernelRegistry::KernelRegistry() {}
KernelRegistry::~KernelRegistry() {}
KernelRegistry::~KernelRegistry() { FreeCreatorArray(); }
KernelRegistry *KernelRegistry::GetInstance() {
static KernelRegistry instance;
return &instance;
}
KernelCreator KernelRegistry::GetKernelCreator(const KernelKey &desc) {
auto it = creators.find(desc);
if (it != creators.end()) {
return it->second;
int KernelRegistry::Init() {
lock_.lock();
if (creator_arrays_ != nullptr) {
lock_.unlock();
return RET_OK;
}
device_type_length_ = kKernelArch_MAX - kKernelArch_MIN;
data_type_length_ = kNumberTypeEnd - kNumberTypeBegin;
op_type_length_ = PrimitiveType_MAX - PrimitiveType_MIN;
// malloc an array that holds the creator function of every kernel
auto total_len = device_type_length_ * data_type_length_ * op_type_length_;
creator_arrays_ = (kernel::KernelCreator *)malloc(total_len * sizeof(kernel::KernelCreator));
if (creator_arrays_ == nullptr) {
MS_LOG(ERROR) << "malloc creator_arrays_ failed.";
lock_.unlock();
return RET_ERROR;
}
for (int i = 0; i < total_len; ++i) {
creator_arrays_[i] = nullptr;
}
#ifdef ENABLE_ARM64
void *optimized_lib_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
if (optimized_lib_handler != nullptr) {
MS_LOG(INFO) << "load optimize lib success.";
} else {
MS_LOG(INFO) << "load optimize lib failed.";
}
#endif
lock_.unlock();
return RET_OK;
}
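A note on the locking pattern above: Init() must call lock_.unlock() by hand on every return path. A minimal RAII sketch of the same logic with std::lock_guard (an equivalent rewrite for illustration, not part of this commit; Registry is a stand-in class, not the real KernelRegistry):
#include <mutex>
class Registry {  // illustrative stand-in for KernelRegistry
 public:
  int Init() {
    std::lock_guard<std::mutex> guard(lock_);  // unlocked automatically on every return path
    if (initialized_) {
      return 0;  // already initialized by another caller, nothing to do
    }
    initialized_ = true;  // real code would allocate and zero creator_arrays_ here
    return 0;
  }
 private:
  bool initialized_ = false;
  std::mutex lock_;
};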
// if not find, use cpu kernel
KernelKey cpuDesc {kernel::KERNEL_ARCH::kCPU, desc.type};
it = creators.find(cpuDesc);
if (it != creators.end()) {
return it->second;
void KernelRegistry::FreeCreatorArray() {
if (creator_arrays_ != nullptr) {
free(creator_arrays_);
creator_arrays_ = nullptr;
}
}
kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
int index = GetCreatorFuncIndex(desc);
auto it = creator_arrays_[index];
if (it != nullptr) {
return it;
}
return nullptr;
}
void KernelRegistry::RegKernel(const KernelKey desc, KernelCreator creator) { creators[desc] = creator; }
int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) {
int index;
int device_index = static_cast<int>(desc.arch);
int dType_index = static_cast<int>(desc.data_type);
int op_index = static_cast<int>(desc.type);
index = device_index * data_type_length_ * op_type_length_ + dType_index * op_type_length_ + op_index;
return index;
}
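The index above is plain row-major flattening of a 3-D (arch, data type, op type) grid into the one-dimensional creator array. Two assumptions are baked in: the lengths computed as End - Begin and MAX - MIN only cover the upper bound if that bound is exclusive, and the raw enum values are used directly, which presumes each range starts at zero (otherwise the Begin/MIN offset would need to be subtracted first). A self-contained sketch of the same arithmetic, with illustrative lengths rather than the real enum ranges:
#include <cstdio>
constexpr int kDataTypeLen = 4;   // stands in for data_type_length_
constexpr int kOpTypeLen = 10;    // stands in for op_type_length_
// Row-major flattening of (arch, dtype, op) into one array index.
constexpr int CreatorFuncIndex(int arch, int dtype, int op) {
  return arch * kDataTypeLen * kOpTypeLen + dtype * kOpTypeLen + op;
}
int main() {
  // arch 1, dtype 2, op 3  ->  1*4*10 + 2*10 + 3 = 63
  std::printf("%d\n", CreatorFuncIndex(1, 2, 3));
  return 0;
}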
void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const schema::PrimitiveType type, KernelCreator creator) {
KernelKey desc = {arch, type};
creators[desc] = creator;
void KernelRegistry::RegKernel(const KernelKey desc, kernel::KernelCreator creator) {
int index = GetCreatorFuncIndex(desc);
creator_arrays_[index] = creator;
}
void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType op_type,
kernel::KernelCreator creator) {
KernelKey desc = {arch, data_type, op_type};
int index = GetCreatorFuncIndex(desc);
creator_arrays_[index] = creator;
}
bool KernelRegistry::Merge(const std::unordered_map<KernelKey, KernelCreator> &newCreators) { return false; }
const std::map<KernelKey, KernelCreator> &KernelRegistry::GetKernelCreators() { return creators; }
const kernel::KernelCreator *KernelRegistry::GetCreatorArrays() { return creator_arrays_; }
} // namespace mindspore::lite

View File

@ -30,16 +30,22 @@ class KernelRegistry {
virtual ~KernelRegistry();
static KernelRegistry *GetInstance();
virtual kernel::KernelCreator GetKernelCreator(const kernel::KernelKey &desc);
const std::map<kernel::KernelKey, kernel::KernelCreator> &GetKernelCreators();
int Init();
void FreeCreatorArray();
virtual kernel::KernelCreator GetCreator(const kernel::KernelKey &desc);
const kernel::KernelCreator *GetCreatorArrays();
int GetCreatorFuncIndex(const kernel::KernelKey desc);
void RegKernel(const kernel::KernelKey desc, kernel::KernelCreator creator);
void RegKernel(const kernel::KERNEL_ARCH arch, const schema::PrimitiveType type, kernel::KernelCreator creator);
void RegKernel(const kernel::KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType type,
kernel::KernelCreator creator);
bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators);
protected:
std::map<kernel::KernelKey, kernel::KernelCreator> creators;
kernel::KernelCreator *creator_arrays_ = nullptr;
int device_type_length_;
int data_type_length_;
int op_type_length_;
std::mutex lock_;
};
class KernelRegistrar {
@ -48,14 +54,14 @@ class KernelRegistrar {
KernelRegistry::GetInstance()->RegKernel(desc, creator);
}
KernelRegistrar(const kernel::KERNEL_ARCH arch, const schema::PrimitiveType type, kernel::KernelCreator creator) {
KernelRegistry::GetInstance()->RegKernel(arch, type, creator);
KernelRegistrar(const kernel::KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType op_type,
kernel::KernelCreator creator) {
KernelRegistry::GetInstance()->RegKernel(arch, data_type, op_type, creator);
}
};
#define REG_KERNEL(arch, type, kernelCreater) \
static KernelRegistrar g_##arch##type##kernelReg(arch, type, kernelCreater);
#define REG_KERNEL(arch, data_type, op_type, kernelCreater) \
static KernelRegistrar g_##arch##data_type##op_type##kernelReg(arch, data_type, op_type, kernelCreater);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_
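For reference, the new three-argument macro is used like this (PrimitiveType_Foo and the creator name are hypothetical; real registrations appear in the per-op files later in this commit):
// Hypothetical op registration with the new per-data-type macro:
kernel::LiteKernel *CpuFooFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                            const std::vector<lite::tensor::Tensor *> &outputs,
                                            OpParameter *opParameter, const lite::Context *ctx,
                                            const kernel::KernelKey &desc);
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Foo, CpuFooFp32KernelCreator)
// which expands to a file-local registrar object:
// static KernelRegistrar g_kCPUkNumberTypeFloat32PrimitiveType_FookernelReg(
//     kCPU, kNumberTypeFloat32, PrimitiveType_Foo, CpuFooFp32KernelCreator);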

View File

@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_LITE_KERNEL_H_
#include <vector>
#include <string>
#ifdef ENABLE_FP16
#ifdef ENABLE_ARM
#include <arm_neon.h>
#endif
#include "src/runtime/kernel/arm/opclib/op_base.h"
@ -35,14 +35,17 @@ using FLOAT_t = float;
// using mindspore::kernel::AddressPtr;
namespace mindspore::kernel {
enum KERNEL_ARCH { kCPU, kGPU, kNPU, kInferShape };
enum KERNEL_ARCH { kCPU, kGPU, kNPU, kKernelArch_MIN = kCPU, kKernelArch_MAX = kNPU };
struct KernelKey {
KERNEL_ARCH arch;
TypeId data_type;
schema::PrimitiveType type;
bool operator<(const KernelKey &dst) const {
if (arch != dst.arch) {
return arch < dst.arch;
} else if (data_type != dst.data_type) {
return data_type < dst.data_type;
} else {
return type < dst.type;
}
@ -179,4 +182,3 @@ class LiteKernelUtil {
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_LITE_KERNEL_H_
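The comparator above is a hand-written lexicographic ordering over (arch, data_type, type). An equivalent formulation with std::tie, purely for illustration (Key is a stand-in for KernelKey):
#include <tuple>
struct Key { int arch; int data_type; int type; };  // stand-in for KernelKey
inline bool Less(const Key &a, const Key &b) {
  // same ordering as KernelKey::operator<: arch first, then data_type, then type
  return std::tie(a.arch, a.data_type, a.type) < std::tie(b.arch, b.data_type, b.type);
}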

View File

@ -24,6 +24,7 @@
#include "src/executor.h"
#include "src/common/utils.h"
#include "src/common/graph_util.h"
#include "src/kernel_registry.h"
#if SUPPORT_GPU
#include "src/runtime/opencl/opencl_runtime.h"
#endif
@ -197,7 +198,11 @@ void LiteSession::Init(Context *context) {
this->context->deviceCtx.type = context->deviceCtx.type;
this->context->allocator = std::make_shared<DefaultAllocator>();
ConfigThreadPool(context->cpuBindMode, context->threadNum);
auto ret = KernelRegistry::GetInstance()->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "KernelRegistry Init Failed.";
return;
}
#if SUPPORT_GPU
if (context->deviceCtx.type == DT_GPU) {
auto opencl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
@ -228,6 +233,7 @@ LiteSession::~LiteSession() {
delete kernel;
}
}
std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputsByName(std::string name) {
return input_map[name];
}

View File

@ -25,13 +25,5 @@ if (PLATFORM_ARM32)
set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})
endif()
if (ENABLE_FP16)
file(GLOB FP6_SRC
${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/opclib/fp16/*.cc
)
set(KERNEL_SRC ${KERNEL_SRC} ${FP6_SRC})
endif ()
add_library(cpu_kernel_mid_ OBJECT ${KERNEL_SRC})
add_subdirectory(opclib)

View File

@ -36,7 +36,8 @@ int ConcatBaseCPUKernel::Init() {
kernel::LiteKernel *CpuConcatInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
OpParameter *opParameter, const Context *ctx,
const kernel::KernelKey &desc) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
@ -47,51 +48,6 @@ kernel::LiteKernel *CpuConcatInt8KernelCreator(const std::vector<lite::tensor::T
MS_LOG(ERROR) << "new ConcatCPUKernel fail!";
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuConcatFp32OrInt32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
auto *kernel = new(std::nothrow) ConcatCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new ConcatCPUKernel fail!";
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuConcatKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter,
const lite::Context *ctx, const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
auto input_tensor = inputs.at(kInputIndex);
auto data_type = input_tensor->data_type();
kernel::LiteKernel *kernel = nullptr;
switch (data_type) {
case kNumberTypeInt8:
case kNumberTypeUInt8:
kernel = CpuConcatInt8KernelCreator(inputs, outputs, opParameter, ctx);
break;
case kNumberTypeInt32:
case kNumberTypeFloat32:
kernel = CpuConcatFp32OrInt32KernelCreator(inputs, outputs, opParameter, ctx);
break;
default:
break;
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
@ -102,6 +58,56 @@ kernel::LiteKernel *CpuConcatKernelCreator(const std::vector<lite::tensor::Tenso
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Concat, CpuConcatKernelCreator)
kernel::LiteKernel *CpuConcatInt32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx,
const kernel::KernelKey &desc) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
auto *kernel = new(std::nothrow) ConcatCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new ConcatCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuConcatFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx,
const kernel::KernelKey &desc) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
auto *kernel = new(std::nothrow) ConcatCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new ConcatCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Concat, CpuConcatInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Concat, CpuConcatInt32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Concat, CpuConcatFp32KernelCreator)
} // namespace mindspore::kernel
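With the three per-data-type registrations above, the old in-creator switch on the input tensor's data type disappears: the caller encodes the data type in the KernelKey and the registry hands back the matching creator directly. A hedged sketch of that lookup (the surrounding scheduler code is assumed; it is not part of this file):
// Illustrative lookup, assuming a populated registry:
kernel::KernelKey key{kernel::KERNEL_ARCH::kCPU, inputs.front()->data_type(),
                      schema::PrimitiveType_Concat};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(key);
if (creator != nullptr) {
  auto *kernel = creator(inputs, outputs, opParameter, ctx, key);
  // kernel is a ConcatCPUKernel or ConcatInt8CPUKernel, depending on key.data_type
}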

View File

@ -15,24 +15,6 @@
*/
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "src/runtime/kernel/arm/fp32/convolution.h"
#include "src/runtime/kernel/arm/fp32/convolution_winograd.h"
#include "src/runtime/kernel/arm/fp32/deconvolution.h"
#include "src/runtime/kernel/arm/fp32/convolution_1x1.h"
#include "src/runtime/kernel/arm/fp32/convolution_3x3.h"
#include "src/runtime/kernel/arm/fp32/convolution_depthwise.h"
#include "src/runtime/kernel/arm/fp32/deconvolution_depthwise.h"
#ifdef ENABLE_FP16
#include "src/runtime/kernel/arm/fp16/convolution_fp16.h"
#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h"
#include "src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.h"
#include "src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.h"
#endif
#include "src/runtime/kernel/arm/int8/deconvolution_int8.h"
#include "src/runtime/kernel/arm/int8/convolution_int8.h"
#include "src/runtime/kernel/arm/int8/convolution_3x3_int8.h"
#include "src/runtime/kernel/arm/int8/convolution_depthwise_int8.h"
#include "src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "include/errorcode.h"
@ -42,10 +24,6 @@ using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::ActivationType;
using mindspore::schema::PadMode;
using mindspore::schema::PrimitiveType_Conv2D;
using mindspore::schema::PrimitiveType_DeConv2D;
using mindspore::schema::PrimitiveType_DeDepthwiseConv2D;
using mindspore::schema::PrimitiveType_DepthwiseConv2D;
namespace mindspore::kernel {
ConvolutionBaseCPUKernel::~ConvolutionBaseCPUKernel() {
@ -192,352 +170,4 @@ void ComputeQuantOutRange(ConvParameter *conv_param) {
conv_param->conv_quant_arg_.out_act_min_[0] = min;
conv_param->conv_quant_arg_.out_act_max_[0] = max;
}
void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param,
InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func) {
if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 &&
conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) {
*output_unit = SelectOutputUnit(conv_param);
if (*output_unit > 1) {
*use_winograd = true;
int input_unit = conv_param->kernel_h_ + *output_unit - 1;
input_trans_func = GetInputTransFunc(input_unit);
if (input_trans_func == nullptr) {
MS_LOG(INFO) << "No matching input trans func. Turn back to common conv.";
*use_winograd = false;
}
output_trans_func = GetOutputTransFunc(input_unit, *output_unit);
if (output_trans_func == nullptr) {
MS_LOG(INFO) << "No matching output trans func. Turn back to common conv.";
*use_winograd = false;
}
} else {
*use_winograd = false;
}
} else {
*use_winograd = false;
}
}
bool CheckSupportFP16() {
bool support_fp16 = false;
#ifdef ENABLE_ARM64
void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
if (optimize_op_handler != nullptr) {
support_fp16 = true;
MS_LOG(INFO) << "Support FP16.";
} else {
support_fp16 = false;
MS_LOG(INFO) << "Your machine doesn't support fp16, return back to float32 kernel.";
}
#endif
return support_fp16;
}
kernel::LiteKernel *CpuConvFloatKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
conv_param->input_h_ = inputs.front()->Height();
conv_param->input_w_ = inputs.front()->Width();
conv_param->output_h_ = outputs.front()->Height();
conv_param->output_w_ = outputs.front()->Width();
bool use_winograd;
int out_unit;
InputTransformUnitFunc input_trans_func = nullptr;
OutputTransformUnitFunc output_trans_func = nullptr;
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
bool support_fp16 = CheckSupportFP16();
if (kernel_h == 1 && kernel_w == 1) {
auto kernel = new (std::nothrow) Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
} else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
if (support_fp16) {
#ifdef ENABLE_FP16
auto kernel = new (std::nothrow) Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
#endif
}
auto kernel = new (std::nothrow) Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
} else if (use_winograd) {
auto kernel = new (std::nothrow) ConvolutionWinogradCPUKernel(opParameter, inputs, outputs, ctx, out_unit);
return kernel;
} else {
if (support_fp16) {
#ifdef ENABLE_FP16
auto kernel = new (std::nothrow) ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
#endif
}
auto kernel = new (std::nothrow) ConvolutionCPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
}
}
kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
auto kernel = new (std::nothrow) Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
} else {
auto kernel = new (std::nothrow) ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
}
}
kernel::LiteKernel *CpuConvKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *opParameter,
const lite::Context *ctx, const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
auto input_tensor = inputs.at(kInputIndex);
auto data_type = input_tensor->data_type();
kernel::LiteKernel *kernel = nullptr;
switch (data_type) {
case kNumberTypeInt8:
kernel = CpuConvInt8KernelCreator(inputs, outputs, opParameter, ctx);
break;
case kNumberTypeFloat32:
kernel = CpuConvFloatKernelCreator(inputs, outputs, opParameter, ctx);
break;
default:
break;
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
auto kernel = new (std::nothrow) ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
return kernel;
}
#ifdef ENABLE_FP16
kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
auto kernel = new (std::nothrow) ConvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
return kernel;
}
#endif
kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
auto kernel = new (std::nothrow) ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuConvDwKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *opParameter,
const lite::Context *ctx, const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
auto input_tensor = inputs.at(kInputIndex);
auto data_type = input_tensor->data_type();
kernel::LiteKernel *kernel = nullptr;
switch (data_type) {
case kNumberTypeInt8:
kernel = CpuConvDwInt8KernelCreator(inputs, outputs, opParameter, ctx);
break;
case kNumberTypeFloat32:
#ifdef ENABLE_FP16
kernel = CpuConvDwFp16KernelCreator(inputs, outputs, opParameter, ctx);
#else
kernel = CpuConvDwFp32KernelCreator(inputs, outputs, opParameter, ctx);
#endif
break;
default:
break;
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx) {
auto kernel = new (std::nothrow) DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
return kernel;
}
#ifdef ENABLE_FP16
kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx) {
auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
return kernel;
}
#endif
kernel::LiteKernel *CpuDeconvDwInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx) {
auto kernel = new (std::nothrow) DeconvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuDeconvDwKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
auto input_tensor = inputs.at(kInputIndex);
auto data_type = input_tensor->data_type();
kernel::LiteKernel *kernel = nullptr;
switch (data_type) {
case kNumberTypeInt8:
kernel = CpuDeconvDwInt8KernelCreator(inputs, outputs, opParameter, ctx);
break;
case kNumberTypeFloat32:
#ifdef ENABLE_FP16
kernel = CpuDeconvDwFp16KernelCreator(inputs, outputs, opParameter, ctx);
#else
kernel = CpuDeconvDwFp32KernelCreator(inputs, outputs, opParameter, ctx);
#endif
break;
default:
break;
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx) {
auto kernel = new (std::nothrow) DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuDeConvInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx) {
auto kernel = new (std::nothrow) DeConvInt8CPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuDeConvKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *opParameter,
const lite::Context *ctx, const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
auto input_tensor = inputs.at(kInputIndex);
auto data_type = input_tensor->data_type();
kernel::LiteKernel *kernel = nullptr;
switch (data_type) {
case kNumberTypeInt8:
kernel = CpuDeConvInt8KernelCreator(inputs, outputs, opParameter, ctx);
break;
#ifdef ENABLE_FP16
case kNumberTypeFloat16:
break;
#endif
case kNumberTypeFloat32:
kernel = CpuDeConvFp32KernelCreator(inputs, outputs, opParameter, ctx);
break;
default:
break;
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Conv2D, CpuConvKernelCreator)
REG_KERNEL(kCPU, PrimitiveType_DeConv2D, CpuDeConvKernelCreator)
REG_KERNEL(kCPU, PrimitiveType_DepthwiseConv2D, CpuConvDwKernelCreator)
REG_KERNEL(kCPU, PrimitiveType_DeDepthwiseConv2D, CpuDeconvDwKernelCreator)
} // namespace mindspore::kernel

View File

@ -28,7 +28,6 @@
#include "src/lite_kernel.h"
#include "include/context.h"
#include "src/runtime/kernel/arm/base/layout_transform.h"
#include "src/runtime/kernel/arm/opclib/optimized_kernel.h"
using mindspore::lite::Context;
using mindspore::schema::PadMode;

View File

@ -32,39 +32,17 @@ int FullconnectionBaseCPUKernel::Init() {
return RET_OK;
}
kernel::LiteKernel *CpuFullConnectionKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
kernel::LiteKernel *CpuFullConnectionInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_FullConnection);
auto input_tensor = inputs.at(kInputIndex);
auto data_type = input_tensor->data_type();
kernel::LiteKernel *kernel = nullptr;
switch (data_type) {
case kNumberTypeInt8:
case kNumberTypeUInt8: {
kernel = new (std::nothrow) FullconnectionInt8CPUKernel(opParameter, inputs, outputs, ctx);
if (!kernel) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
break;
}
case kNumberTypeFloat32: {
kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx);
if (!kernel) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
break;
}
default:
break;
auto kernel = new (std::nothrow) FullconnectionInt8CPUKernel(opParameter, inputs, outputs, ctx);
if (!kernel) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
@ -75,5 +53,27 @@ kernel::LiteKernel *CpuFullConnectionKernelCreator(const std::vector<lite::tenso
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_FullConnection, CpuFullConnectionKernelCreator)
kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_FullConnection);
auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx);
if (!kernel) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_FullConnection, CpuFullConnectionInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FullConnection, CpuFullConnectionFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -15,6 +15,7 @@
*/
#include "src/runtime/kernel/arm/base/layout_transform.h"
#include "mindspore/core/utils/log_adapter.h"
using mindspore::schema::Format;
namespace mindspore::kernel {

View File

@ -21,7 +21,8 @@
#include <arm_neon.h>
#endif
#include "src/runtime/kernel/arm/opclib/pack.h"
#include "src/ir/tensor.h"
#include "ir/dtype/type_id.h"
#include "schema/ops_generated.h"
namespace mindspore::kernel {
typedef void (*LayoutConvertor)(const void *src, void *dst, int batch, int plane, int channel);

View File

@ -32,50 +32,13 @@ kernel::LiteKernel *CpuPadInt8KernelCreator(const std::vector<lite::tensor::Tens
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Pad);
auto *kernel = new (std::nothrow) PadInt8CPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new PadCPUKernel failed.";
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuPadFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
auto *kernel = new (std::nothrow) PadCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new PadCPUKernel failed.";
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuPadKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *opParameter,
const lite::Context *ctx, const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
auto input_tensor = inputs.at(kInputIndex);
auto data_type = input_tensor->data_type();
kernel::LiteKernel *kernel = nullptr;
switch (data_type) {
case kNumberTypeInt8:
kernel = CpuPadInt8KernelCreator(inputs, outputs, opParameter, ctx, desc);
break;
case kNumberTypeFloat32:
kernel = CpuPadFp32KernelCreator(inputs, outputs, opParameter, ctx, desc);
break;
default:
break;
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
@ -86,5 +49,27 @@ kernel::LiteKernel *CpuPadKernelCreator(const std::vector<lite::tensor::Tensor *
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Pad, CpuPadKernelCreator)
kernel::LiteKernel *CpuPadFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Pad);
auto *kernel = new (std::nothrow) PadCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new PadCPUKernel failed.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Pad, CpuPadInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Pad, CpuPadFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -81,7 +81,8 @@ int PoolingBaseCPUKernel::Init() {
kernel::LiteKernel *CpuPoolingInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
OpParameter *opParameter, const Context *ctx,
const kernel::KernelKey &desc) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
@ -92,50 +93,6 @@ kernel::LiteKernel *CpuPoolingInt8KernelCreator(const std::vector<lite::tensor::
MS_LOG(ERROR) << "new PoolingInt8CPUKernel fail!";
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuPoolingFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Pooling);
auto *kernel = new (std::nothrow) PoolingCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new PoolingCPUKernel fail!";
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuPoolingKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Pooing);
auto input_tensor = inputs.at(kInputIndex);
auto data_type = input_tensor->data_type();
kernel::LiteKernel *kernel = nullptr;
switch (data_type) {
case kNumberTypeInt8:
case kNumberTypeUInt8:
kernel = CpuPoolingInt8KernelCreator(inputs, outputs, opParameter, ctx);
break;
case kNumberTypeFloat32:
kernel = CpuPoolingFp32KernelCreator(inputs, outputs, opParameter, ctx);
break;
default:
break;
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
@ -146,5 +103,30 @@ kernel::LiteKernel *CpuPoolingKernelCreator(const std::vector<lite::tensor::Tens
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Pooling, CpuPoolingKernelCreator)
kernel::LiteKernel *CpuPoolingFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx,
const kernel::KernelKey &desc) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Pooling);
auto *kernel = new (std::nothrow) PoolingCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new PoolingCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Pooling, CpuPoolingInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Pooling, CpuPoolingFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -35,62 +35,18 @@ int ReshapeBaseCPUKernel::Init() {
kernel::LiteKernel *CpuReshapeInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
OpParameter *opParameter, const Context *ctx,
const kernel::KernelKey &desc) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
auto *kernel = new(std::nothrow) ReshapeInt8CPUKernel(opParameter, inputs, outputs, ctx);
MS_ASSERT(desc.type == schema::PrimitiveType_Reshape);
auto *kernel = new (std::nothrow) ReshapeInt8CPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new ConcatCPUKernel fail!";
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuReshapeFp32OrInt32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
auto *kernel = new(std::nothrow) ReshapeCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new ConcatCPUKernel fail!";
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuReshapeKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
auto input_tensor = inputs.at(kInputIndex);
auto data_type = input_tensor->data_type();
kernel::LiteKernel *kernel = nullptr;
switch (data_type) {
case kNumberTypeInt8:
case kNumberTypeUInt8:
kernel = CpuReshapeInt8KernelCreator(inputs, outputs, opParameter, ctx);
break;
case kNumberTypeInt32:
case kNumberTypeFloat32:
kernel = CpuReshapeFp32OrInt32KernelCreator(inputs, outputs, opParameter, ctx);
break;
default:
break;
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
@ -101,6 +57,55 @@ kernel::LiteKernel *CpuReshapeKernelCreator(const std::vector<lite::tensor::Tens
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Reshape, CpuReshapeKernelCreator)
} // namespace mindspore::kernel
kernel::LiteKernel *CpuReshapeInt32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx,
const kernel::KernelKey &desc) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Reshape);
auto *kernel = new (std::nothrow) ReshapeCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new ConcatCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuReshapeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx,
const kernel::KernelKey &desc) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Reshape);
auto *kernel = new (std::nothrow) ReshapeCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new ConcatCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Reshape, CpuReshapeInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Reshape, CpuReshapeInt32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reshape, CpuReshapeFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -16,6 +16,8 @@
#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h"
#include "src/runtime/kernel/arm/opclib/fp16/conv_fp16.h"
#include "src/runtime/kernel/arm/opclib/fp16/winograd_transform_fp16.h"
#include "src/runtime/kernel/arm/opclib/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/base/layout_transform.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
@ -165,7 +167,6 @@ void Convolution3x3FP16CPUKernel::ConfigInputOutput() {
}
int Convolution3x3FP16CPUKernel::Init() {
ConvolutionBaseCPUKernel::Init();
auto ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init failed.";

View File

@ -20,8 +20,6 @@
#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/opclib/winograd_transform.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "src/runtime/kernel/arm/opclib/optimized_kernel.h"

View File

@ -15,6 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.h"
#include "src/runtime/kernel/arm/opclib/fp16/pack_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
@ -161,4 +162,27 @@ int ConvolutionDepthwiseFp16CPUKernel::Run() {
return RET_OK;
}
kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
auto kernel = new (std::nothrow) ConvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DepthwiseConv2D, CpuConvDwFp16KernelCreator)
} // namespace mindspore::kernel

View File

@ -15,7 +15,9 @@
*/
#include "src/runtime/kernel/arm/fp16/convolution_fp16.h"
#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h"
#include "src/runtime/kernel/arm/opclib/fp16/conv_fp16.h"
#include "src/runtime/kernel/arm/opclib/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/base/layout_transform.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
@ -218,5 +220,42 @@ int ConvolutionFP16CPUKernel::Run() {
}
return RET_OK;
}
} // namespace mindspore::kernel
kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
conv_param->input_h_ = inputs.front()->Height();
conv_param->input_w_ = inputs.front()->Width();
conv_param->output_h_ = outputs.front()->Height();
conv_param->output_w_ = outputs.front()->Width();
kernel::LiteKernel *kernel;
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx);
} else {
kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx);
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "Create conv fp16 kernel failed.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Conv2D, CpuConvFp16KernelCreator)
} // namespace mindspore::kernel

View File

@ -15,6 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.h"
#include "src/runtime/kernel/arm/opclib/fp16/pack_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
@ -24,7 +25,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DepthwiseConv2D;
using mindspore::schema::PrimitiveType_DeDepthwiseConv2D;
namespace mindspore::kernel {
int DeconvolutionDepthwiseFp16CPUKernel::InitSlideParam() {
@ -171,4 +172,27 @@ int DeconvolutionDepthwiseFp16CPUKernel::Run() {
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
return RET_OK;
}
kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DeDepthwiseConv2D, CpuDeconvDwFp16KernelCreator)
} // namespace mindspore::kernel

View File

@ -93,6 +93,10 @@ kernel::LiteKernel *CpuActivationFp32KernelCreator(const std::vector<lite::tenso
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Activation);
auto *kernel = new (std::nothrow) ActivationCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
@ -101,6 +105,5 @@ kernel::LiteKernel *CpuActivationFp32KernelCreator(const std::vector<lite::tenso
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Activation, CpuActivationFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Activation, CpuActivationFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -100,7 +100,7 @@ kernel::LiteKernel *CpuAddNFp32KernelCreator(const std::vector<lite::tensor::Ten
MS_LOG(ERROR) << "Input context is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_AddN);
op_parameter->thread_num_ = ctx->threadNum;
auto *kernel = new (std::nothrow) AddNCPUKernel(op_parameter, inputs, outputs);
if (kernel == nullptr) {
@ -117,5 +117,5 @@ kernel::LiteKernel *CpuAddNFp32KernelCreator(const std::vector<lite::tensor::Ten
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_AddN, CpuAddNFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_AddN, CpuAddNFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -92,9 +92,6 @@ kernel::LiteKernel *CpuArgMinMaxFp32KernelCreator(const std::vector<lite::tensor
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_ArgMax, CpuArgMinMaxFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_ArgMin, CpuArgMinMaxFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ArgMax, CpuArgMinMaxFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ArgMin, CpuArgMinMaxFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -112,25 +112,8 @@ kernel::LiteKernel *CpuArithmeticFp32KernelCreator(const std::vector<lite::tenso
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *parameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(parameter);
MS_ASSERT(inputs.at(0));
auto data_type = inputs.at(0)->data_type();
kernel::LiteKernel *kernel = nullptr;
switch (data_type) {
case kNumberTypeFloat32:
kernel = new (std::nothrow) ArithmeticCPUKernel(parameter, inputs, outputs, ctx);
break;
case kNumberTypeInt8:
if (desc.type == schema::PrimitiveType_Add) {
kernel = new (std::nothrow) QuantizedAddCPUKernel(parameter, inputs, outputs, ctx);
} else if (desc.type == schema::PrimitiveType_Mul) {
kernel = new (std::nothrow) MulInt8CPUKernel(parameter, inputs, outputs, ctx);
} else {
}
break;
default:
break;
}
MS_ASSERT(parameter != nullptr);
auto kernel = new (std::nothrow) ArithmeticCPUKernel(parameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_;
return nullptr;
@ -145,24 +128,23 @@ kernel::LiteKernel *CpuArithmeticFp32KernelCreator(const std::vector<lite::tenso
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Mul, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Add, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Div, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_LogicalOr, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Maximum, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Minimum, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_FloorDiv, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_FloorMod, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_SquaredDifference, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Equal, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_NotEqual, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Less, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_LessEqual, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Greater, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_GreaterEqual, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Eltwise, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Mul, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Add, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Div, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalOr, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Maximum, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Minimum, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FloorDiv, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FloorMod, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SquaredDifference, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Equal, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NotEqual, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Less, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LessEqual, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Greater, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GreaterEqual, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Eltwise, CpuArithmeticFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -101,16 +101,15 @@ kernel::LiteKernel *CpuArithmeticSelfFp32KernelCreator(const std::vector<lite::t
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Abs, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Cos, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Exp, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Log, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Square, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Sqrt, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Rsqrt, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Sin, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Floor, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, PrimitiveType_Ceil, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Abs, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Cos, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Exp, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Log, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Square, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sqrt, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Rsqrt, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sin, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Floor, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Ceil, CpuArithmeticSelfFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -69,6 +69,7 @@ kernel::LiteKernel *CpuBatchToSpaceFp32KernelCreator(const std::vector<lite::ten
MS_LOG(ERROR) << "Input op_parameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_BatchToSpace);
auto *kernel = new (std::nothrow) BatchToSpaceCPUKernel(op_parameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new BatchToSpaceCPUKernel fail!";
@ -85,7 +86,5 @@ kernel::LiteKernel *CpuBatchToSpaceFp32KernelCreator(const std::vector<lite::ten
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_BatchToSpace, CpuBatchToSpaceFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BatchToSpace, CpuBatchToSpaceFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -52,7 +52,7 @@ int BiasCPUKernel::Init() {
bias_param_->in_shape1_[i] = 1;
bias_param_->out_shape_[i] = dims[i];
}
bias_param_->in_shape1_[bias_param_->ndim_ - 1] = dims[bias_param_->ndim_ - 1];
bias_param_->in_shape1_[bias_param_->ndim_ - 1] = dims[bias_param_->ndim_ - 1];
return RET_OK;
}
@ -61,19 +61,7 @@ kernel::LiteKernel *CpuBiasFp32KernelCreator(const std::vector<lite::tensor::Ten
const lite::Context *ctx, const kernel::KernelKey &desc) {
MS_ASSERT(parameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_BiasAdd);
MS_ASSERT(inputs.at(0));
auto data_type = inputs.at(0)->data_type();
kernel::LiteKernel *kernel = nullptr;
switch (data_type) {
case kNumberTypeFloat32:
kernel = new (std::nothrow) BiasCPUKernel(parameter, inputs, outputs);
break;
case kNumberTypeInt8:
kernel = new (std::nothrow) BiasAddInt8CPUKernel(parameter, inputs, outputs, ctx);
break;
default:
break;
}
auto kernel = new (std::nothrow) BiasCPUKernel(parameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_;
return nullptr;
@ -89,6 +77,5 @@ kernel::LiteKernel *CpuBiasFp32KernelCreator(const std::vector<lite::tensor::Ten
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_BiasAdd, CpuBiasFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BiasAdd, CpuBiasFp32KernelCreator)
} // namespace mindspore::kernel
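With the data-type switch removed from the creator body, the fp32 and int8 BiasAdd paths now register independently; the int8 registration appears in a later hunk of this same commit:

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BiasAdd, CpuBiasFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_BiasAdd, CpuBiasAddInt8KernelCreator)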

View File

@ -56,6 +56,7 @@ kernel::LiteKernel *CpuBroadcastToFp32KernelCreator(const std::vector<lite::tens
MS_LOG(ERROR) << "Input op_parameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_BroadcastTo);
auto *kernel = new (std::nothrow) BroadcastToCPUKernel(op_parameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new BroadcastToCPUKernel fail!";
@ -72,7 +73,5 @@ kernel::LiteKernel *CpuBroadcastToFp32KernelCreator(const std::vector<lite::tens
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_BroadcastTo, CpuBroadcastToFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BroadcastTo, CpuBroadcastToFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -114,5 +114,5 @@ kernel::LiteKernel *CpuCastFp32KernelCreator(const std::vector<lite::tensor::Ten
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Cast, CpuCastFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Cast, CpuCastFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -15,6 +15,9 @@
*/
#include "src/runtime/kernel/arm/fp32/convolution.h"
#include "src/runtime/kernel/arm/fp32/convolution_1x1.h"
#include "src/runtime/kernel/arm/fp32/convolution_3x3.h"
#include "src/runtime/kernel/arm/fp32/convolution_winograd.h"
#include "src/runtime/kernel/arm/opclib/fp32/conv.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
@ -204,4 +207,78 @@ int ConvolutionCPUKernel::Run() {
}
return RET_OK;
}
void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param,
                        InputTransformUnitFunc *input_trans_func, OutputTransformUnitFunc *output_trans_func) {
  if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 &&
      conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) {
    *output_unit = SelectOutputUnit(conv_param);
    if (*output_unit > 1) {
      *use_winograd = true;
      int input_unit = conv_param->kernel_h_ + *output_unit - 1;
      // Write through the out-parameters; taking the function pointers by value would
      // discard the selected transforms before the caller could see them.
      *input_trans_func = GetInputTransFunc(input_unit);
      if (*input_trans_func == nullptr) {
        MS_LOG(INFO) << "No matching input trans func. Fall back to common conv.";
        *use_winograd = false;
      }
      *output_trans_func = GetOutputTransFunc(input_unit, *output_unit);
      if (*output_trans_func == nullptr) {
        MS_LOG(INFO) << "No matching output trans func. Fall back to common conv.";
        *use_winograd = false;
      }
    } else {
      *use_winograd = false;
    }
  } else {
    *use_winograd = false;
  }
}
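// For Winograd F(m, r) the relation is input_unit = output_unit + kernel - 1; with
// illustrative values kernel = 4 and output_unit = 2, each 5x5 input tile yields a
// 2x2 output tile. (The numbers here are examples, not taken from this diff.)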
kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
conv_param->input_h_ = inputs.front()->Height();
conv_param->input_w_ = inputs.front()->Width();
conv_param->output_h_ = outputs.front()->Height();
conv_param->output_w_ = outputs.front()->Width();
bool use_winograd;
int out_unit;
InputTransformUnitFunc input_trans_func = nullptr;
OutputTransformUnitFunc output_trans_func = nullptr;
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, &input_trans_func, &output_trans_func);
kernel::LiteKernel *kernel;
if (kernel_h == 1 && kernel_w == 1) {
kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx);
} else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx);
} else if (use_winograd) {
kernel = new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(opParameter, inputs, outputs, ctx, out_unit);
} else {
kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx);
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Conv2D, CpuConvFp32KernelCreator)
} // namespace mindspore::kernel
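For quick reference, the dispatch implemented by CpuConvFp32KernelCreator above (a summary of the code, not additional code in the commit):

// 1x1 kernel                                -> Convolution1x1CPUKernel
// 3x3 kernel, stride 1, dilation 1          -> Convolution3x3CPUKernel
// CheckIfUseWinograd() sets use_winograd    -> ConvolutionWinogradCPUKernel
// otherwise                                 -> ConvolutionCPUKernel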

View File

@ -145,5 +145,29 @@ int ConvolutionDepthwiseCPUKernel::Run() {
}
return RET_OK;
}
kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
auto kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DepthwiseConv2D, CpuConvDwFp32KernelCreator)
} // namespace mindspore::kernel
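Every creator added in this commit follows the same shape; a condensed skeleton for reference, where SomeCPUKernel is a placeholder class. Note that several of the simpler fp32 creators in this diff omit the delete on a failed Init(), which leaks the kernel; the convolution-family creators, as here, release it:

kernel::LiteKernel *CreatorSkeleton(const std::vector<lite::tensor::Tensor *> &inputs,
                                    const std::vector<lite::tensor::Tensor *> &outputs,
                                    OpParameter *op_parameter, const lite::Context *ctx,
                                    const kernel::KernelKey &desc) {
  MS_ASSERT(op_parameter != nullptr);
  auto kernel = new (std::nothrow) SomeCPUKernel(op_parameter, inputs, outputs, ctx);  // placeholder class
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "kernel is nullptr.";
    return nullptr;
  }
  if (kernel->Init() != RET_OK) {
    delete kernel;  // release on failed Init to avoid a leak
    return nullptr;
  }
  return kernel;
}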

View File

@ -37,12 +37,12 @@ int CropLaunch(int thread_id, LiteParallelGroupEnv *penv, void *cdata) {
auto kernel = reinterpret_cast<CropCPUKernel *>(cdata);
return kernel->CropParallelRun(thread_id);
}
}
} // namespace
int CropCPUKernel::Init() {
schema::Format input0_format = inputs_[0]->GetFormat();
if (input0_format != schema::Format_NCHW && input0_format != schema::Format_NHWC) {
MS_LOG(ERROR) << "Unsupport format " << input0_format;
MS_LOG(ERROR) << "Unsupport format " << input0_format;
return RET_FORMAT_ERR;
}
outputs_[0]->SetFormat(input0_format);
@ -90,7 +90,7 @@ kernel::LiteKernel *CpuCropFp32KernelCreator(const std::vector<lite::tensor::Ten
MS_LOG(ERROR) << "Input context is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Crop);
op_parameter->thread_num_ = ctx->threadNum;
auto *kernel = new (std::nothrow) CropCPUKernel(op_parameter, inputs, outputs);
if (kernel == nullptr) {
@ -108,5 +108,5 @@ kernel::LiteKernel *CpuCropFp32KernelCreator(const std::vector<lite::tensor::Ten
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Crop, CpuCropFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Crop, CpuCropFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -225,4 +225,28 @@ int DeConvolutionCPUKernel::Run() {
}
return RET_OK;
}
kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
auto kernel = new (std::nothrow) kernel::DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DeConv2D, CpuDeConvFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -24,7 +24,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DepthwiseConv2D;
using mindspore::schema::PrimitiveType_DeDepthwiseConv2D;
namespace mindspore::kernel {
int DeconvolutionDepthwiseCPUKernel::InitSlideParam() {
@ -158,5 +158,28 @@ int DeconvolutionDepthwiseCPUKernel::Run() {
}
return RET_OK;
}
kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
auto kernel = new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DeDepthwiseConv2D, CpuDeconvDwFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -67,6 +67,7 @@ kernel::LiteKernel *CpuDepthToSpaceFp32KernelCreator(const std::vector<lite::ten
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_DepthToSpace);
auto *kernel = new (std::nothrow) DepthToSpaceCPUKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new DepthToSpaceCPUKernel fail!";
@ -83,6 +84,6 @@ kernel::LiteKernel *CpuDepthToSpaceFp32KernelCreator(const std::vector<lite::ten
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_DepthToSpace, CpuDepthToSpaceFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DepthToSpace, CpuDepthToSpaceFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -82,6 +82,10 @@ kernel::LiteKernel *CpuExpandsDimsFp32KernelCreator(const std::vector<lite::tens
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_ExpandDims);
auto *kernel = new (std::nothrow) ExpandDimsCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new ExpandDimsCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
@ -92,6 +96,6 @@ kernel::LiteKernel *CpuExpandsDimsFp32KernelCreator(const std::vector<lite::tens
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_ExpandDims, CpuExpandsDimsFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ExpandDims, CpuExpandsDimsFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -92,6 +92,10 @@ kernel::LiteKernel *CpuFillFp32KernelCreator(const std::vector<lite::tensor::Ten
}
MS_ASSERT(desc.type == schema::PrimitiveType_Fill);
auto *kernel = new (std::nothrow) FillCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new FillCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
@ -102,6 +106,6 @@ kernel::LiteKernel *CpuFillFp32KernelCreator(const std::vector<lite::tensor::Ten
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Fill, CpuFillFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Fill, CpuFillFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -56,6 +56,10 @@ kernel::LiteKernel *CpuFlattenFp32KernelCreator(const std::vector<lite::tensor::
}
MS_ASSERT(desc.type == schema::PrimitiveType_Flatten);
auto *kernel = new (std::nothrow) FlattenCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new FlattenCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
@ -66,6 +70,6 @@ kernel::LiteKernel *CpuFlattenFp32KernelCreator(const std::vector<lite::tensor::
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Flatten, CpuFlattenFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Flatten, CpuFlattenFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -55,14 +55,19 @@ kernel::LiteKernel *CpuFusedBatchnormKernelCreator(const std::vector<lite::tenso
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_FusedBatchNorm);
auto *kernel = new (std::nothrow) FusedBatchnormCPUKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new FusedBatchnormCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_FusedBatchNorm, CpuFusedBatchnormKernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FusedBatchNorm, CpuFusedBatchnormKernelCreator)
} // namespace mindspore::kernel

View File

@ -121,6 +121,6 @@ kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vector<lite::tensor::T
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Gather, CpuGatherFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Gather, CpuGatherFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -143,6 +143,6 @@ kernel::LiteKernel *CpuGatherNdFp32KernelCreator(const std::vector<lite::tensor:
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_GatherNd, CpuGatherNdFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GatherNd, CpuGatherNdFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -95,6 +95,10 @@ kernel::LiteKernel *CpuLocalResponseNormFp32KernelCreator(const std::vector<lite
MS_ASSERT(desc.type == schema::PrimitiveType_LocalResponseNormalization);
auto *kernel = new (std::nothrow) LocalResponseNormCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new LocalResponseNormCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
@ -105,6 +109,5 @@ kernel::LiteKernel *CpuLocalResponseNormFp32KernelCreator(const std::vector<lite
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_LocalResponseNormalization, CpuLocalResponseNormFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LocalResponseNormalization, CpuLocalResponseNormFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -40,6 +40,10 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::tensor::T
const kernel::KernelKey &desc) {
MS_ASSERT(desc.type == schema::PrimitiveType_MatMul);
auto *kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new MatmulCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
@ -48,6 +52,6 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::tensor::T
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_MatMul, CpuMatmulFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MatMul, CpuMatmulFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -43,6 +43,10 @@ kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vector<lite::tensor
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Nchw2Nhwc);
auto *kernel = new (std::nothrow) Nchw2NhwcCPUKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new Nchw2NhwcCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
@ -53,6 +57,6 @@ kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vector<lite::tensor
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -43,6 +43,10 @@ kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector<lite::tensor
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Nhwc2Nchw);
auto *kernel = new (std::nothrow) Nhwc2NchwCPUKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new Nhwc2NchwCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
@ -53,6 +57,6 @@ kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector<lite::tensor
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -183,5 +183,5 @@ kernel::LiteKernel *CpuOneHotFp32KernelCreator(const std::vector<lite::tensor::T
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_OneHot, CpuOneHotFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_OneHot, CpuOneHotFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -68,6 +68,10 @@ kernel::LiteKernel *CpuPowerFp32KernelCreator(const std::vector<lite::tensor::Te
MS_ASSERT(desc.type == schema::PrimitiveType_Power);
auto *kernel =
new (std::nothrow) PowerCPUKernel(reinterpret_cast<PowerParameter *>(opParameter), inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new PowerCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
@ -76,5 +80,5 @@ kernel::LiteKernel *CpuPowerFp32KernelCreator(const std::vector<lite::tensor::Te
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Power, CpuPowerFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Power, CpuPowerFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -70,7 +70,7 @@ kernel::LiteKernel *CpuPReluFp32KernelCreator(const std::vector<lite::tensor::Te
MS_LOG(ERROR) << "input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Prelu);
auto *kernel = new (std::nothrow) PReluCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new PReluCPUKernel fail!";
@ -86,6 +86,6 @@ kernel::LiteKernel *CpuPReluFp32KernelCreator(const std::vector<lite::tensor::Te
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Prelu, CpuPReluFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Prelu, CpuPReluFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -54,6 +54,10 @@ kernel::LiteKernel *CpuRangeFp32KernelCreator(const std::vector<lite::tensor::Te
MS_ASSERT(desc.type == schema::PrimitiveType_Range);
auto *kernel = new (std::nothrow) RangeCPUKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new RangeCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
@ -64,7 +68,7 @@ kernel::LiteKernel *CpuRangeFp32KernelCreator(const std::vector<lite::tensor::Te
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Range, CpuRangeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Range, CpuRangeFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -53,6 +53,10 @@ kernel::LiteKernel *CpuRankFp32KernelCreator(const std::vector<lite::tensor::Ten
MS_ASSERT(desc.type == schema::PrimitiveType_Rank);
auto *kernel = new (std::nothrow) RankCPUKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new RankCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
@ -63,6 +67,6 @@ kernel::LiteKernel *CpuRankFp32KernelCreator(const std::vector<lite::tensor::Ten
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Rank, CpuRankFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Rank, CpuRankFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -241,7 +241,8 @@ int ReduceCPUKernel::MallocTmpBuffer() {
data_buffers_.emplace_back(buffer);
input_shape[axis] = 1;
}
return RET_OK;
}
REG_KERNEL(kCPU, PrimitiveType_Reduce, CpuReduceFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reduce, CpuReduceFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -221,6 +221,7 @@ kernel::LiteKernel *CpuResizeFp32KernelCreator(const std::vector<lite::tensor::T
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Resize);
auto *kernel = new (std::nothrow) ResizeCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new ResizeCPUKernel fail!";
@ -237,6 +238,6 @@ kernel::LiteKernel *CpuResizeFp32KernelCreator(const std::vector<lite::tensor::T
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Resize, CpuResizeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Resize, CpuResizeFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -139,6 +139,7 @@ kernel::LiteKernel *CpuReverseFp32KernelCreator(const std::vector<lite::tensor::
MS_LOG(ERROR) << "opParameter is NULL! ";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Reverse);
auto *kernel = new (std::nothrow) ReverseCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "Kernel is NULL! name: " << opParameter->name_ << ", type: "
@ -156,6 +157,6 @@ kernel::LiteKernel *CpuReverseFp32KernelCreator(const std::vector<lite::tensor::
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Reverse, CpuReverseFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reverse, CpuReverseFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -96,6 +96,7 @@ kernel::LiteKernel *CpuReverseSequenceFp32KernelCreator(const std::vector<lite::
OpParameter *parameter, const lite::Context *ctx,
const KernelKey &desc) {
MS_ASSERT(parameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_ReverseSequence);
auto *kernel = new (std::nothrow) ReverseSequenceCPUKernel(parameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_;
@ -111,6 +112,6 @@ kernel::LiteKernel *CpuReverseSequenceFp32KernelCreator(const std::vector<lite::
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_ReverseSequence, CpuReverseSequenceFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ReverseSequence, CpuReverseSequenceFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -164,5 +164,5 @@ kernel::LiteKernel *CpuScaleFp32KernelCreator(const std::vector<lite::tensor::Te
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Scale, CpuScaleFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Scale, CpuScaleFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -181,6 +181,6 @@ kernel::LiteKernel *CpuScatterNDFp32KernelCreator(const std::vector<lite::tensor
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_ScatterND, CpuScatterNDFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ScatterND, CpuScatterNDFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -79,6 +79,6 @@ kernel::LiteKernel *CpuShapeFp32KernelCreator(const std::vector<lite::tensor::Te
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Shape, CpuShapeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Shape, CpuShapeFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -23,8 +23,8 @@
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::RET_NULL_PTR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Slice;
namespace mindspore::kernel {
@ -37,7 +37,7 @@ int SliceLaunch(int thread_id, LiteParallelGroupEnv *penv, void *cdata) {
auto kernel = reinterpret_cast<SliceCPUKernel *>(cdata);
return kernel->SliceParallelRun(thread_id);
}
}
} // namespace
int SliceCPUKernel::Init() {
auto *param = reinterpret_cast<SliceParameter *>(opParameter);
@ -106,7 +106,7 @@ kernel::LiteKernel *CpuSliceFp32KernelCreator(const std::vector<lite::tensor::Te
MS_LOG(ERROR) << "Input context is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Slice);
op_parameter->thread_num_ = ctx->threadNum;
auto *kernel = new (std::nothrow) SliceCPUKernel(op_parameter, inputs, outputs);
if (kernel == nullptr) {
@ -124,5 +124,5 @@ kernel::LiteKernel *CpuSliceFp32KernelCreator(const std::vector<lite::tensor::Te
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Slice, CpuSliceFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Slice, CpuSliceFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -64,6 +64,10 @@ kernel::LiteKernel *CpuSoftmaxFp32KernelCreator(const std::vector<lite::tensor::
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_SoftMax);
auto *kernel = new (std::nothrow) SoftmaxCPUKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new SoftmaxCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
@ -74,6 +78,6 @@ kernel::LiteKernel *CpuSoftmaxFp32KernelCreator(const std::vector<lite::tensor::
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_SoftMax, CpuSoftmaxFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SoftMax, CpuSoftmaxFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -81,7 +81,7 @@ kernel::LiteKernel *CpuSparseToDenseFp32KernelCreator(const std::vector<lite::te
MS_LOG(ERROR) << "input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_SparseToDense);
auto *kernel = new (std::nothrow) SparseToDenseCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new SparseToDenseCPUKernel fail!";
@ -97,6 +97,6 @@ kernel::LiteKernel *CpuSparseToDenseFp32KernelCreator(const std::vector<lite::te
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_SparseToDense, CpuSparseToDenseFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SparseToDense, CpuSparseToDenseFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -126,5 +126,5 @@ kernel::LiteKernel *CpuSplitFp32KernelCreator(const std::vector<lite::tensor::Te
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Split, CpuSplitFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Split, CpuSplitFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -74,6 +74,6 @@ kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vector<lite::tensor::
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Squeeze, CpuSqueezeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Squeeze, CpuSqueezeFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -92,6 +92,7 @@ kernel::LiteKernel *CpuStackFp32KernelCreator(const std::vector<lite::tensor::Te
MS_LOG(ERROR) << "Input op_parameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Stack);
auto *kernel = new (std::nothrow) StackCPUKernel(op_parameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new StackCPUKernel fail!";
@ -108,6 +109,5 @@ kernel::LiteKernel *CpuStackFp32KernelCreator(const std::vector<lite::tensor::Te
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Stack, CpuStackFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Stack, CpuStackFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -82,5 +82,5 @@ kernel::LiteKernel *CpuStridedSliceFp32KernelCreator(const std::vector<lite::ten
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_StridedSlice, CpuStridedSliceFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_StridedSlice, CpuStridedSliceFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -77,6 +77,6 @@ kernel::LiteKernel *CpuTileFp32KernelCreator(const std::vector<lite::tensor::Ten
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Tile, CpuTileFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Tile, CpuTileFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -53,10 +53,13 @@ int TopKCPUKernel::Run() {
kernel::LiteKernel *CpuTopKFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *parameter,
const lite::Context *ctx, const KernelKey &desc) {
MS_EXCEPTION_IF_NULL(parameter);
MS_ASSERT(parameter != nullptr);
MS_ASSERT(desc.type == PrimitiveType_TopK);
auto *kernel = new (std::nothrow) TopKCPUKernel(parameter, inputs, outputs);
MS_EXCEPTION_IF_NULL(kernel);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new TopKCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
@ -68,6 +71,6 @@ kernel::LiteKernel *CpuTopKFp32KernelCreator(const std::vector<lite::tensor::Ten
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_TopK, CpuTopKFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TopK, CpuTopKFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -95,6 +95,6 @@ kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector<lite::tensor
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Transpose, CpuTransposeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Transpose, CpuTransposeFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -62,6 +62,6 @@ kernel::LiteKernel *CpuUniqueFp32KernelCreator(const std::vector<lite::tensor::T
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Unique, CpuUniqueFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Unique, CpuUniqueFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -81,6 +81,10 @@ kernel::LiteKernel *CpuUnsqueezeFp32KernelCreator(const std::vector<lite::tensor
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Unsqueeze);
auto *kernel = new (std::nothrow) UnsqueezeCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new AddNCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
@ -91,6 +95,6 @@ kernel::LiteKernel *CpuUnsqueezeFp32KernelCreator(const std::vector<lite::tensor
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Unsqueeze, CpuUnsqueezeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Unsqueeze, CpuUnsqueezeFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -86,6 +86,5 @@ kernel::LiteKernel *CpuUnstackFp32KernelCreator(const std::vector<lite::tensor::
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Unstack, CpuUnstackFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Unstack, CpuUnstackFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -90,7 +90,7 @@ kernel::LiteKernel *CpuWhereFp32KernelCreator(const std::vector<lite::tensor::Te
MS_LOG(ERROR) << "input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Where);
auto *kernel = new (std::nothrow) WhereCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new WhereCPUKernel fail!";
@ -106,5 +106,5 @@ kernel::LiteKernel *CpuWhereFp32KernelCreator(const std::vector<lite::tensor::Te
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_Where, CpuWhereFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Where, CpuWhereFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -48,6 +48,7 @@ kernel::LiteKernel *CpuZerosLikeFp32KernelCreator(const std::vector<lite::tensor
MS_LOG(ERROR) << "input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_ZerosLike);
auto *kernel = new (std::nothrow) ZerosLikeCPUKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new ZerosLikeCPUKernel fail!";
@ -63,6 +64,6 @@ kernel::LiteKernel *CpuZerosLikeFp32KernelCreator(const std::vector<lite::tensor
return kernel;
}
REG_KERNEL(kCPU, PrimitiveType_ZerosLike, CpuZerosLikeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ZerosLike, CpuZerosLikeFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -130,8 +130,10 @@ kernel::LiteKernel *CpuAddInt8KernelCreator(const std::vector<lite::tensor::Tens
}
MS_ASSERT(desc.type == PrimitiveType_Add);
auto *kernel = new (std::nothrow) QuantizedAddCPUKernel(parameter, inputs, outputs, ctx);
MS_EXCEPTION_IF_NULL(kernel);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (0 != ret) {
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
@ -142,5 +144,6 @@ kernel::LiteKernel *CpuAddInt8KernelCreator(const std::vector<lite::tensor::Tens
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Add, CpuAddInt8KernelCreator)
} // namespace mindspore::kernel

View File

@ -80,5 +80,6 @@ kernel::LiteKernel *CpuBiasAddInt8KernelCreator(const std::vector<lite::tensor::
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_BiasAdd, CpuBiasAddInt8KernelCreator)
} // namespace mindspore::kernel

View File

@ -25,6 +25,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DepthwiseConv2D;
namespace mindspore::kernel {
int ConvolutionDepthwiseInt8CPUKernel::InitWeightBias() {
@ -143,4 +144,27 @@ int ConvolutionDepthwiseInt8CPUKernel::Run() {
}
return RET_OK;
}
kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
auto kernel = new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_DepthwiseConv2D, CpuConvDwInt8KernelCreator)
} // namespace mindspore::kernel

View File

@ -15,6 +15,7 @@
*/
#include "src/runtime/kernel/arm/int8/convolution_int8.h"
#include "src/runtime/kernel/arm/int8/convolution_3x3_int8.h"
#include "src/runtime/kernel/arm/opclib/int8/conv_int8.h"
#include "src/runtime/kernel/arm/base/layout_transform.h"
#include "schema/model_generated.h"
@ -36,7 +37,7 @@ void ConvolutionInt8CPUKernel::CheckSupportOptimize() {
support_optimize_ = false;
#endif
#ifdef __aarch64__
#ifdef ENABLE_ARM64
void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
if (optimize_op_handler != nullptr) {
dlerror();
@ -383,4 +384,39 @@ int ConvolutionInt8CPUKernel::Run() {
}
return RET_OK;
}
kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
kernel::LiteKernel *kernel;
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx);
} else {
kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx);
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Conv2D, CpuConvInt8KernelCreator)
} // namespace mindspore::kernel
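The int8 creator above dispatches more narrowly than its fp32 counterpart earlier in this commit: there is no 1x1 or Winograd specialization, only

// 3x3 kernel, stride 1, dilation 1 -> Convolution3x3Int8CPUKernel
// otherwise                        -> ConvolutionInt8CPUKernel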

View File

@ -72,4 +72,3 @@ class ConvolutionInt8CPUKernel : public ConvolutionBaseCPUKernel {
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_INT8_H_

View File

@ -25,6 +25,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DeDepthwiseConv2D;
namespace mindspore::kernel {
int DeconvolutionDepthwiseInt8CPUKernel::InitWeightBias() {
@ -63,9 +64,9 @@ int DeconvolutionDepthwiseInt8CPUKernel::InitSlideParam() {
sliding->in_h_step_ = conv_param_->input_w_ * C4NUM;
sliding->in_sh_step_ = conv_param_->input_w_ * C4NUM * conv_param_->stride_h_; // stride H
sliding->in_sw_step_ = C4NUM * conv_param_->stride_w_;  // stride W
sliding->in_sw_step_ = C4NUM * conv_param_->stride_w_;  // stride W
sliding->in_kh_step_ = conv_param_->input_w_ * C4NUM * conv_param_->dilation_h_; // kernel H
sliding->in_kw_step_ = C4NUM * conv_param_->dilation_w_; // kernel W
sliding->in_kw_step_ = C4NUM * conv_param_->dilation_w_; // kernel W
return RET_OK;
}
@ -171,4 +172,27 @@ int DeconvolutionDepthwiseInt8CPUKernel::Run() {
}
return RET_OK;
}
kernel::LiteKernel *CpuDeconvDwInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
auto kernel = new (std::nothrow) kernel::DeconvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_DeDepthwiseConv2D, CpuDeconvDwInt8KernelCreator)
} // namespace mindspore::kernel

View File

@ -17,8 +17,10 @@
#include "src/runtime/kernel/arm/int8/deconvolution_int8.h"
#include "src/runtime/kernel/arm/opclib/quantization/fixed_point.h"
#include "src/runtime/runtime_api.h"
#include "src/kernel_registry.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;
@ -216,5 +218,27 @@ int DeConvInt8CPUKernel::Run() {
return RET_OK;
}
} // namespace mindspore::kernel
kernel::LiteKernel *CpuDeConvInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
auto kernel = new (std::nothrow) kernel::DeConvInt8CPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_DeConv2D, CpuDeConvInt8KernelCreator)
} // namespace mindspore::kernel

View File

@ -129,4 +129,5 @@ kernel::LiteKernel *CpuMulInt8KernelCreator(const std::vector<lite::tensor::Tens
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Mul, CpuMulInt8KernelCreator)
} // namespace mindspore::kernel

View File

@ -10,14 +10,13 @@ file(GLOB OPTIMIZED_ASSEMBLY
${OPTIMIZED_OP_DIR}/assembly/opt/*.S
)
file(GLOB FP16_SRC
# ${OPTIMIZED_OP_DIR}/fp16/*.cc
# ${OPTIMIZED_OP_DIR}/../fp16/*.cc
${OPTIMIZED_OP_DIR}/fp16/*.cc
${OPTIMIZED_OP_DIR}/../fp16/*.cc
)
########################### share library build ########################
set(OPTIMIZED_OPS "opt_op_handler.c")
set(OPTIMIZED_OPS ${OPTIMIZED_OP_DIR}/opt_op_handler.c)
set_property(SOURCE ${OPTIMIZED_ASSEMBLY} PROPERTY LANGUAGE C)
list(APPEND OPTIMIZED_OPS ${OPTIMIZED_ASSEMBLY} ${FP16_SRC})
@ -27,6 +26,10 @@ if (PLATFORM_ARM64)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")
add_library(optimize SHARED ${OPTIMIZED_OPS})
target_link_libraries(
optimize
mindspore-lite
)
set_target_properties(optimize PROPERTIES CLEAN_DIRECT_OUTPUT 1)
add_custom_command(TARGET optimize POST_BUILD
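The optimize library built here is consumed at runtime via dlopen; the ENABLE_ARM64 hunks elsewhere in this commit fetch optimized_op_handler_ from it. A hedged sketch of that loading step, with a hypothetical symbol name; the real logic lives in OptimizeModule (optimized_kernel.h):

#include <dlfcn.h>
// Illustrative loader sketch only; names are assumptions, not from this diff.
void *LoadOptimizeLib() {
  void *handle = dlopen("liboptimize.so", RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) {
    return nullptr;  // callers fall back to the generic kernels
  }
  dlerror();  // clear stale errors before dlsym, mirroring the int8 conv hunk
  return dlsym(handle, "OptimizedKernelEntry");  // hypothetical symbol name
}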

View File

@ -1,5 +1,4 @@
#ifdef __arm__
#ifndef __aarch64__
#ifdef ENABLE_ARM32
.text
.align 5
@ -236,5 +235,4 @@ IndirectGemmInt16to32_8x4:
pop {r4-r8, r10, pc}
#endif
#endif

View File

@ -15,7 +15,6 @@
*/
#include "src/runtime/kernel/arm/opclib/fp16/conv_depthwise_fp16.h"
#ifdef ENABLE_FP16
#include <arm_neon.h>
/*conv depthwise fp16 begin*/
@ -299,4 +298,3 @@ void DeconvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const f
}
/*deconv depthwise fp16 end*/
#endif

View File

@ -20,7 +20,6 @@
#include "src/runtime/kernel/arm/opclib/conv_parameter.h"
#include "src/runtime/kernel/arm/opclib/fp32/conv_depthwise.h"
#ifdef ENABLE_FP16
void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data,
const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
int task_id);
@ -28,6 +27,5 @@ void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const flo
void DeconvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data,
const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
int task_id);
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_CONV_DEPTHWISE_FP16_H_

View File

@ -15,20 +15,17 @@
*/
#include "src/runtime/kernel/arm/opclib/fp16/conv_fp16.h"
#include <string.h>
#include "src/runtime/kernel/arm/opclib/pack.h"
#include "src/runtime/kernel/arm/opclib/winograd_transform.h"
#include "src/runtime/kernel/arm/opclib/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/opclib/fp16/winograd_transform_fp16.h"
extern "C" {
#ifdef ENABLE_ARM64
#ifdef ENABLE_FP16
void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu,
size_t relu6);
#endif
#endif
}
#ifdef ENABLE_FP16
#ifndef ENABLE_NEON
void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC4, size_t relu,
@ -215,5 +212,5 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
}
}
}
#endif

View File

@ -16,12 +16,9 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_CONV_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_CONV_FP16_H_
#ifdef ENABLE_FP16
#include <arm_neon.h>
#endif
#include "src/runtime/kernel/arm/opclib/conv_parameter.h"
#ifdef ENABLE_FP16
#ifndef ENABLE_NEON
void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu,
@ -36,7 +33,6 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data,
float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out,
int task_id, ConvParameter *conv_param);
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_CONV_FP16_H_

View File

@ -0,0 +1,342 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/opclib/fp16/pack_fp16.h"
#include <cstring>
#include <cstdlib>
void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num,
int block_index) {
// input format : nhwc
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_w_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
int in_channel = conv_param->input_channel_;
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int out_w = conv_param->output_w_;
int channel_block = UP_DIV(in_channel, 4);
int kernel_plane = kernel_h * kernel_w;
for (int i = 0; i < real_cal_num; i++) {
int block_start = block_index + i;
int input_h = block_start / out_w * stride_h - pad_h;
int input_w = block_start % out_w * stride_w - pad_w;
for (int j = 0; j < kernel_h; j++) {
int input_y = input_h + j * dilation_h;
if (input_y < 0 || input_y >= in_h) {
continue;
}
int input_y_stride = input_y * in_w * channel_block * C4NUM;
for (int n = 0; n < kernel_w; n++) {
int input_x = input_w + n * dilation_w;
if (input_x < 0 || input_x >= in_w) {
continue;
}
int input_x_stride = input_y_stride + input_x * channel_block * C4NUM;
int input_plane_offset = (j * kernel_w + n) * 16 * C4NUM * channel_block + i * C4NUM;
for (int m = 0; m < channel_block; m++) {
int channel_block_stride = input_x_stride + m * C4NUM;
int channel_block_offset = input_plane_offset + m * 16 * C4NUM;
(packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0];
(packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1];
(packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2];
(packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3];
} // channel_block loop
} // kernel_w loop
} // kernel_h loop
} // tile num loop
}
void PackWeightFp16(float16_t *weight_data, ConvParameter *conv_param, float16_t *packed_weight) {
// original weight format : ohwi
int tile_num = 8;
int inchannel_block = 4;
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int in_channel = conv_param->input_channel_;
int out_channel = conv_param->output_channel_;
int kernel_block = UP_DIV(out_channel, tile_num);
int channel_block = UP_DIV(in_channel, inchannel_block);
int kernel_plane = kernel_h * kernel_w;
int pack_weight_size = kernel_block * channel_block * tile_num * inchannel_block * kernel_plane;
int unit_size = tile_num * inchannel_block;
int block_size = pack_weight_size / kernel_block;
for (int m = 0; m < kernel_plane; m++) {
int kernel_plane_stride = m * in_channel;
int packed_kernel_plane_stride = m * unit_size * channel_block;
for (int i = 0; i < channel_block; i++) {
int channel_block_stride = kernel_plane_stride + i * inchannel_block;
int packed_channel_block_size = packed_kernel_plane_stride + i * unit_size;
int ic_remainder = in_channel - i * inchannel_block;
int real_ic_num = ic_remainder < inchannel_block ? ic_remainder : inchannel_block;
for (int h = 0; h < real_ic_num; h++) {
int block_stride = channel_block_stride + h;
int packed_block_stride = packed_channel_block_size + h * tile_num;
for (int j = 0; j < kernel_block; j++) {
int kernel_block_stride = block_stride + j * tile_num * kernel_plane * in_channel;
int packed_kernel_block_size = packed_block_stride + j * block_size;
int oc_remainder = out_channel - j * tile_num;
int real_oc_num = oc_remainder < tile_num ? oc_remainder : tile_num;
for (int k = 0; k < real_oc_num; k++) {
float16_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel;
float16_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k;
*packed_data_ptr = *origin_data_ptr;
}
} // kernel block loop
} // inchannel block loop
} // channel block loop
} // kernel plane loop
}
void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) {
// origin weight format : ohwi
int input_channel = conv_param->input_channel_;
int ic8 = UP_DIV(input_channel, C8NUM);
int output_channel = conv_param->output_channel_;
int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_;
for (int k = 0; k < kernel_plane; k++) {
int src_kernel_offset = k * input_channel;
int dst_kernel_offset = k * C8NUM;
for (int o = 0; o < output_channel; o++) {
int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel;
int dst_oc_offset = dst_kernel_offset + o * ic8 * kernel_plane * C8NUM;
for (int i = 0; i < input_channel; i++) {
int c8_block_num = i / C8NUM;
int c8_block_rem = i % C8NUM;
int src_ic_offset = src_oc_offset + i;
int dst_ic_offset = dst_oc_offset + c8_block_num * kernel_plane * C8NUM + c8_block_rem;
(packed_weight_data + dst_ic_offset)[0] = (origin_weight_data + src_ic_offset)[0];
}
}
}
}
void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) {
// origin weight format : ohwi
int input_channel = conv_param->input_channel_;
int ic4 = UP_DIV(input_channel, C4NUM);
int output_channel = conv_param->output_channel_;
int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_;
for (int k = 0; k < kernel_plane; k++) {
int src_kernel_offset = k * input_channel;
int dst_kernel_offset = k * C4NUM;
for (int o = 0; o < output_channel; o++) {
int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel;
int dst_oc_offset = dst_kernel_offset + o * ic4 * kernel_plane * C4NUM;
for (int i = 0; i < input_channel; i++) {
int c4_block_num = i / C4NUM;
int c4_block_rem = i % C4NUM;
int src_ic_offset = src_oc_offset + i;
int dst_ic_offset = dst_oc_offset + c4_block_num * kernel_plane * C4NUM + c4_block_rem;
(packed_weight_data + dst_ic_offset)[0] = (origin_weight_data + src_ic_offset)[0];
}
}
}
}
void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_oc_offset = b * plane * channel;
int dst_oc_offset = b * plane * c4 * C4NUM;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_oc_offset + k * channel;
int dst_kernel_offset = dst_oc_offset + k * C4NUM;
for (int i = 0; i < channel; i++) {
int c4_block_num = i / C4NUM;
int c4_block_rem = i % C4NUM;
int src_ic_offset = src_kernel_offset + i;
int dst_ic_offset = dst_kernel_offset + c4_block_num * plane * C4NUM + c4_block_rem;
((float16_t *)dst + dst_ic_offset)[0] = ((float16_t *)src + src_ic_offset)[0];
}
}
}
}
void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * channel;
int dst_offset = b * plane * c4 * C4NUM;
for (int c = 0; c < channel; c++) {
int c4_block_num = c / C4NUM;
int c4_block_rem = c % C4NUM;
int src_c_offset = src_offset + c * plane;
int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k;
int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem;
((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0];
}
}
}
}
void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) {
int ic4 = UP_DIV(channel, C4NUM);
int nhwc4_batch_unit_offset = ic4 * C4NUM * plane;
int ic_remainder_ = channel % C4NUM;
if (ic_remainder_ != 0) {
int nhwc4_batch_offset = 0;
for (int b = 0; b < batch; b++) {
int batch_offset = b * channel * plane;
for (int i = 0; i < plane; i++) {
memcpy((float16_t *)dst + nhwc4_batch_offset + i * ic4 * C4NUM, (float16_t *)src + batch_offset + i * channel,
channel * sizeof(float16_t));
}
nhwc4_batch_offset += nhwc4_batch_unit_offset;
}
} else {
size_t ori_input_size = batch * plane * channel * sizeof(float16_t);
memcpy(dst, src, ori_input_size);
}
}
void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) {
int nhwc4_batch_offset = 0;
int ic4 = UP_DIV(channel, C4NUM);
int nhwc4_batch_unit_offset = ic4 * C4NUM * plane;
for (int b = 0; b < batch; b++) {
int batch_offset = b * channel * plane;
for (int c = 0; c < channel; c++) {
int src_c_offset = batch_offset + c * plane;
int dst_c_offset = nhwc4_batch_offset + c;
for (int i = 0; i < plane; i++) {
int src_plane_offset = src_c_offset + i;
int dst_plane_offset = dst_c_offset + i * ic4 * C4NUM;
((float16_t *)dst)[dst_plane_offset] = ((float16_t *)src)[src_plane_offset];
}
}
nhwc4_batch_offset += nhwc4_batch_unit_offset;
}
}
void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * c4 * C4NUM;
int dst_offset = b * plane * channel;
for (int c = 0; c < channel; c++) {
int c4_block_num = c / C4NUM;
int c4_block_res = c % C4NUM;
int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k * C4NUM;
int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM;
((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0];
}
}
}
}
void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * c4 * C4NUM;
int dst_offset = b * plane * channel;
for (int c = 0; c < channel; c++) {
int c4_block_num = c / C4NUM;
int c4_block_res = c % C4NUM;
int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
int dst_c_offset = dst_offset + c;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k * C4NUM;
int dst_kernel_offset = dst_c_offset + k * channel;
((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0];
}
}
}
}
void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * c4 * C4NUM;
int dst_offset = b * plane * channel;
for (int c = 0; c < channel; c++) {
int c4_block_num = c / C4NUM;
int c4_block_res = c % C4NUM;
int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
int dst_c_offset = dst_offset + c * plane;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k * C4NUM;
int dst_kernel_offset = dst_c_offset + k;
((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0];
}
}
}
}
void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane, int channel) {
int c8 = UP_DIV(channel, C8NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * channel;
int dst_offset = b * plane * c8 * C8NUM;
for (int c = 0; c < channel; c++) {
int c8_block_num = c / C8NUM;
int c8_block_rem = c % C8NUM;
int src_c_offset = src_offset + c * plane;
int dst_c_offset = dst_offset + c8_block_num * plane * C8NUM;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k;
int dst_kernel_offset = dst_c_offset + C8NUM * k + c8_block_rem;
(dst + dst_kernel_offset)[0] = (float16_t)(src + src_kernel_offset)[0];
}
}
}
}
void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel) {
int c8 = UP_DIV(channel, C8NUM);
int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
int nhwc8_batch_offset = 0;
for (int b = 0; b < batch; b++) {
int batch_offset = b * channel * plane;
for (int i = 0; i < plane; i++) {
for (int c = 0; c < channel; c++) {
(dst + nhwc8_batch_offset + i * c8 * C8NUM)[c] = (float16_t)(src + batch_offset + i * channel)[c];
}
}
nhwc8_batch_offset += nhwc8_batch_unit_offset;
}
}
void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel) {
int c8 = UP_DIV(channel, C8NUM);
int nhwc_batch_unit_offset = channel * plane;
int nhwc_batch_offset = 0;
for (int b = 0; b < batch; b++) {
int batch_offset = b * c8 * C8NUM * plane;
for (int i = 0; i < plane; i++) {
for (int c = 0; c < channel; c++) {
(dst + nhwc_batch_offset + i * channel)[c] = (float)(src + batch_offset + i * c8 * C8NUM)[c];
}
}
nhwc_batch_offset += nhwc_batch_unit_offset;
}
}
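As a quick sanity check of the pack/unpack pairs above, here is a minimal standalone sketch; it assumes C8NUM == 8 and UP_DIV as defined in op_base.h, an ARM toolchain where float16_t is available, and linking against the definitions above:

#include <arm_neon.h>  // float16_t
#include <stdio.h>
#define C8NUM 8
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel);
void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel);

int main(void) {
  // 1 batch, 2x2 plane, 3 channels -> one 8-lane channel block after padding
  float src[1 * 4 * 3], back[1 * 4 * 3];
  float16_t packed[1 * 4 * UP_DIV(3, C8NUM) * C8NUM];  // pad lanes stay uninitialized, never read back
  for (int i = 0; i < 12; i++) src[i] = (float)i;
  PackNHWCFp32ToNHWC8Fp16(src, packed, 1, 4, 3);
  PackNHWC8Fp16ToNHWCFp32(packed, back, 1, 4, 3);
  for (int i = 0; i < 12; i++) printf("%g ", back[i]);  // expect 0 1 2 ... 11 (exact in fp16)
  printf("\n");
  return 0;
}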

View File

@@ -0,0 +1,57 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_PACK_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_PACK_FP16_H_
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "src/runtime/kernel/arm/opclib/conv_parameter.h"
#include "src/runtime/kernel/arm/opclib/op_base.h"
void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num,
int block_index);
void PackWeightFp16(float16_t *weight_data, ConvParameter *conv_param, float16_t *packed_weight);
void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param);
void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param);
void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNC8HW8ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane, int channel);
void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel);
void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_PACK_FP16_H_

View File

@@ -0,0 +1,527 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/opclib/fp16/winograd_transform_fp16.h"
// for fp16 convolution 3x3 filter/input/output transform F(4,3)
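// For reference, the hard-coded coefficients in the three transforms below match the
// standard Winograd F(4,3) matrices (interpolation points 0, ±1, ±2, inf), where the
// output tile is Y = A_T * [(G g G_T) (elementwise*) (B_T d B)] * A:
//
//         | 4   0  -5   0   1   0 |         | 1/4     0     0  |
//         | 0  -4  -4   1   1   0 |         | -1/6  -1/6  -1/6 |
//   B_T = | 0   4  -4  -1   1   0 |     G = | -1/6   1/6  -1/6 |
//         | 0  -2  -1   2   1   0 |         | 1/24   1/12  1/6 |
//         | 0   2  -1  -2   1   0 |         | 1/24  -1/12  1/6 |
//         | 0   4   0  -5   0   1 |         |  0      0     1  |
//
//         | 1  1  1  1  1  0 |
//   A_T = | 0  1 -1  2 -2  0 |
//         | 0  1  1  4  4  0 |
//         | 0  1 -1  8 -8  1 |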
void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step) {
float16x4_t d00 = vld1_f16(tmp_data);
float16x4_t d01 = vld1_f16(tmp_data + 4);
float16x4_t d02 = vld1_f16(tmp_data + 2 * 4);
float16x4_t d03 = vld1_f16(tmp_data + 3 * 4);
float16x4_t d04 = vld1_f16(tmp_data + 4 * 4);
float16x4_t d05 = vld1_f16(tmp_data + 5 * 4);
float16x4_t d10 = vld1_f16(tmp_data + 6 * 4);
float16x4_t d11 = vld1_f16(tmp_data + 7 * 4);
float16x4_t d12 = vld1_f16(tmp_data + 8 * 4);
float16x4_t d13 = vld1_f16(tmp_data + 9 * 4);
float16x4_t d14 = vld1_f16(tmp_data + 10 * 4);
float16x4_t d15 = vld1_f16(tmp_data + 11 * 4);
float16x4_t d20 = vld1_f16(tmp_data + 12 * 4);
float16x4_t d21 = vld1_f16(tmp_data + 13 * 4);
float16x4_t d22 = vld1_f16(tmp_data + 14 * 4);
float16x4_t d23 = vld1_f16(tmp_data + 15 * 4);
float16x4_t d24 = vld1_f16(tmp_data + 16 * 4);
float16x4_t d25 = vld1_f16(tmp_data + 17 * 4);
float16x4_t d30 = vld1_f16(tmp_data + 18 * 4);
float16x4_t d31 = vld1_f16(tmp_data + 19 * 4);
float16x4_t d32 = vld1_f16(tmp_data + 20 * 4);
float16x4_t d33 = vld1_f16(tmp_data + 21 * 4);
float16x4_t d34 = vld1_f16(tmp_data + 22 * 4);
float16x4_t d35 = vld1_f16(tmp_data + 23 * 4);
float16x4_t d40 = vld1_f16(tmp_data + 24 * 4);
float16x4_t d41 = vld1_f16(tmp_data + 25 * 4);
float16x4_t d42 = vld1_f16(tmp_data + 26 * 4);
float16x4_t d43 = vld1_f16(tmp_data + 27 * 4);
float16x4_t d44 = vld1_f16(tmp_data + 28 * 4);
float16x4_t d45 = vld1_f16(tmp_data + 29 * 4);
float16x4_t d50 = vld1_f16(tmp_data + 30 * 4);
float16x4_t d51 = vld1_f16(tmp_data + 31 * 4);
float16x4_t d52 = vld1_f16(tmp_data + 32 * 4);
float16x4_t d53 = vld1_f16(tmp_data + 33 * 4);
float16x4_t d54 = vld1_f16(tmp_data + 34 * 4);
float16x4_t d55 = vld1_f16(tmp_data + 35 * 4);
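// first pass: t = B_T * d over the rows of the 6x6 tile (B_T as in the comment above)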
float16x4_t t00 = vadd_f16(vsub_f16(vmul_n_f16(d00, 4), vmul_n_f16(d20, 5)), d40);
float16x4_t t01 = vadd_f16(vsub_f16(vmul_n_f16(d01, 4), vmul_n_f16(d21, 5)), d41);
float16x4_t t02 = vadd_f16(vsub_f16(vmul_n_f16(d02, 4), vmul_n_f16(d22, 5)), d42);
float16x4_t t03 = vadd_f16(vsub_f16(vmul_n_f16(d03, 4), vmul_n_f16(d23, 5)), d43);
float16x4_t t04 = vadd_f16(vsub_f16(vmul_n_f16(d04, 4), vmul_n_f16(d24, 5)), d44);
float16x4_t t05 = vadd_f16(vsub_f16(vmul_n_f16(d05, 4), vmul_n_f16(d25, 5)), d45);
float16x4_t t10 = vadd_f16(vadd_f16(d30, d40), vmul_n_f16(vadd_f16(d10, d20), -4));
float16x4_t t11 = vadd_f16(vadd_f16(d31, d41), vmul_n_f16(vadd_f16(d11, d21), -4));
float16x4_t t12 = vadd_f16(vadd_f16(d32, d42), vmul_n_f16(vadd_f16(d12, d22), -4));
float16x4_t t13 = vadd_f16(vadd_f16(d33, d43), vmul_n_f16(vadd_f16(d13, d23), -4));
float16x4_t t14 = vadd_f16(vadd_f16(d34, d44), vmul_n_f16(vadd_f16(d14, d24), -4));
float16x4_t t15 = vadd_f16(vadd_f16(d35, d45), vmul_n_f16(vadd_f16(d15, d25), -4));
float16x4_t t20 = vadd_f16(vsub_f16(d40, d30), vmul_n_f16(vsub_f16(d10, d20), 4));
float16x4_t t21 = vadd_f16(vsub_f16(d41, d31), vmul_n_f16(vsub_f16(d11, d21), 4));
float16x4_t t22 = vadd_f16(vsub_f16(d42, d32), vmul_n_f16(vsub_f16(d12, d22), 4));
float16x4_t t23 = vadd_f16(vsub_f16(d43, d33), vmul_n_f16(vsub_f16(d13, d23), 4));
float16x4_t t24 = vadd_f16(vsub_f16(d44, d34), vmul_n_f16(vsub_f16(d14, d24), 4));
float16x4_t t25 = vadd_f16(vsub_f16(d45, d35), vmul_n_f16(vsub_f16(d15, d25), 4));
float16x4_t t30 = vadd_f16(vsub_f16(d40, d20), vmul_n_f16(vsub_f16(d30, d10), 2));
float16x4_t t31 = vadd_f16(vsub_f16(d41, d21), vmul_n_f16(vsub_f16(d31, d11), 2));
float16x4_t t32 = vadd_f16(vsub_f16(d42, d22), vmul_n_f16(vsub_f16(d32, d12), 2));
float16x4_t t33 = vadd_f16(vsub_f16(d43, d23), vmul_n_f16(vsub_f16(d33, d13), 2));
float16x4_t t34 = vadd_f16(vsub_f16(d44, d24), vmul_n_f16(vsub_f16(d34, d14), 2));
float16x4_t t35 = vadd_f16(vsub_f16(d45, d25), vmul_n_f16(vsub_f16(d35, d15), 2));
float16x4_t t40 = vadd_f16(vsub_f16(d40, d20), vmul_n_f16(vsub_f16(d10, d30), 2));
float16x4_t t41 = vadd_f16(vsub_f16(d41, d21), vmul_n_f16(vsub_f16(d11, d31), 2));
float16x4_t t42 = vadd_f16(vsub_f16(d42, d22), vmul_n_f16(vsub_f16(d12, d32), 2));
float16x4_t t43 = vadd_f16(vsub_f16(d43, d23), vmul_n_f16(vsub_f16(d13, d33), 2));
float16x4_t t44 = vadd_f16(vsub_f16(d44, d24), vmul_n_f16(vsub_f16(d14, d34), 2));
float16x4_t t45 = vadd_f16(vsub_f16(d45, d25), vmul_n_f16(vsub_f16(d15, d35), 2));
float16x4_t t50 = vadd_f16(vsub_f16(vmul_n_f16(d10, 4), vmul_n_f16(d30, 5)), d50);
float16x4_t t51 = vadd_f16(vsub_f16(vmul_n_f16(d11, 4), vmul_n_f16(d31, 5)), d51);
float16x4_t t52 = vadd_f16(vsub_f16(vmul_n_f16(d12, 4), vmul_n_f16(d32, 5)), d52);
float16x4_t t53 = vadd_f16(vsub_f16(vmul_n_f16(d13, 4), vmul_n_f16(d33, 5)), d53);
float16x4_t t54 = vadd_f16(vsub_f16(vmul_n_f16(d14, 4), vmul_n_f16(d34, 5)), d54);
float16x4_t t55 = vadd_f16(vsub_f16(vmul_n_f16(d15, 4), vmul_n_f16(d35, 5)), d55);
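// second pass: m = t * B over the columns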
float16x4_t m00 = vadd_f16(vsub_f16(vmul_n_f16(t00, 4), vmul_n_f16(t02, 5)), t04);
float16x4_t m01 = vadd_f16(vadd_f16(t03, t04), vmul_n_f16(vadd_f16(t01, t02), -4));
float16x4_t m02 = vadd_f16(vsub_f16(t04, t03), vmul_n_f16(vsub_f16(t01, t02), 4));
float16x4_t m03 = vadd_f16(vsub_f16(t04, t02), vmul_n_f16(vsub_f16(t03, t01), 2));
float16x4_t m04 = vadd_f16(vsub_f16(t04, t02), vmul_n_f16(vsub_f16(t01, t03), 2));
float16x4_t m05 = vadd_f16(vsub_f16(vmul_n_f16(t01, 4), vmul_n_f16(t03, 5)), t05);
float16x4_t m10 = vadd_f16(vsub_f16(vmul_n_f16(t10, 4), vmul_n_f16(t12, 5)), t14);
float16x4_t m11 = vadd_f16(vadd_f16(t13, t14), vmul_n_f16(vadd_f16(t11, t12), -4));
float16x4_t m12 = vadd_f16(vsub_f16(t14, t13), vmul_n_f16(vsub_f16(t11, t12), 4));
float16x4_t m13 = vadd_f16(vsub_f16(t14, t12), vmul_n_f16(vsub_f16(t13, t11), 2));
float16x4_t m14 = vadd_f16(vsub_f16(t14, t12), vmul_n_f16(vsub_f16(t11, t13), 2));
float16x4_t m15 = vadd_f16(vsub_f16(vmul_n_f16(t11, 4), vmul_n_f16(t13, 5)), t15);
float16x4_t m20 = vadd_f16(vsub_f16(vmul_n_f16(t20, 4), vmul_n_f16(t22, 5)), t24);
float16x4_t m21 = vadd_f16(vadd_f16(t23, t24), vmul_n_f16(vadd_f16(t21, t22), -4));
float16x4_t m22 = vadd_f16(vsub_f16(t24, t23), vmul_n_f16(vsub_f16(t21, t22), 4));
float16x4_t m23 = vadd_f16(vsub_f16(t24, t22), vmul_n_f16(vsub_f16(t23, t21), 2));
float16x4_t m24 = vadd_f16(vsub_f16(t24, t22), vmul_n_f16(vsub_f16(t21, t23), 2));
float16x4_t m25 = vadd_f16(vsub_f16(vmul_n_f16(t21, 4), vmul_n_f16(t23, 5)), t25);
float16x4_t m30 = vadd_f16(vsub_f16(vmul_n_f16(t30, 4), vmul_n_f16(t32, 5)), t34);
float16x4_t m31 = vadd_f16(vadd_f16(t33, t34), vmul_n_f16(vadd_f16(t31, t32), -4));
float16x4_t m32 = vadd_f16(vsub_f16(t34, t33), vmul_n_f16(vsub_f16(t31, t32), 4));
float16x4_t m33 = vadd_f16(vsub_f16(t34, t32), vmul_n_f16(vsub_f16(t33, t31), 2));
float16x4_t m34 = vadd_f16(vsub_f16(t34, t32), vmul_n_f16(vsub_f16(t31, t33), 2));
float16x4_t m35 = vadd_f16(vsub_f16(vmul_n_f16(t31, 4), vmul_n_f16(t33, 5)), t35);
float16x4_t m40 = vadd_f16(vsub_f16(vmul_n_f16(t40, 4), vmul_n_f16(t42, 5)), t44);
float16x4_t m41 = vadd_f16(vadd_f16(t43, t44), vmul_n_f16(vadd_f16(t41, t42), -4));
float16x4_t m42 = vadd_f16(vsub_f16(t44, t43), vmul_n_f16(vsub_f16(t41, t42), 4));
float16x4_t m43 = vadd_f16(vsub_f16(t44, t42), vmul_n_f16(vsub_f16(t43, t41), 2));
float16x4_t m44 = vadd_f16(vsub_f16(t44, t42), vmul_n_f16(vsub_f16(t41, t43), 2));
float16x4_t m45 = vadd_f16(vsub_f16(vmul_n_f16(t41, 4), vmul_n_f16(t43, 5)), t45);
float16x4_t m50 = vadd_f16(vsub_f16(vmul_n_f16(t50, 4), vmul_n_f16(t52, 5)), t54);
float16x4_t m51 = vadd_f16(vadd_f16(t53, t54), vmul_n_f16(vadd_f16(t51, t52), -4));
float16x4_t m52 = vadd_f16(vsub_f16(t54, t53), vmul_n_f16(vsub_f16(t51, t52), 4));
float16x4_t m53 = vadd_f16(vsub_f16(t54, t52), vmul_n_f16(vsub_f16(t53, t51), 2));
float16x4_t m54 = vadd_f16(vsub_f16(t54, t52), vmul_n_f16(vsub_f16(t51, t53), 2));
float16x4_t m55 = vadd_f16(vsub_f16(vmul_n_f16(t51, 4), vmul_n_f16(t53, 5)), t55);
vst1_f16(trans_input_data, m00);
vst1_f16(trans_input_data + step, m01);
vst1_f16(trans_input_data + 2 * step, m02);
vst1_f16(trans_input_data + 3 * step, m03);
vst1_f16(trans_input_data + 4 * step, m04);
vst1_f16(trans_input_data + 5 * step, m05);
vst1_f16(trans_input_data + 6 * step, m10);
vst1_f16(trans_input_data + 7 * step, m11);
vst1_f16(trans_input_data + 8 * step, m12);
vst1_f16(trans_input_data + 9 * step, m13);
vst1_f16(trans_input_data + 10 * step, m14);
vst1_f16(trans_input_data + 11 * step, m15);
vst1_f16(trans_input_data + 12 * step, m20);
vst1_f16(trans_input_data + 13 * step, m21);
vst1_f16(trans_input_data + 14 * step, m22);
vst1_f16(trans_input_data + 15 * step, m23);
vst1_f16(trans_input_data + 16 * step, m24);
vst1_f16(trans_input_data + 17 * step, m25);
vst1_f16(trans_input_data + 18 * step, m30);
vst1_f16(trans_input_data + 19 * step, m31);
vst1_f16(trans_input_data + 20 * step, m32);
vst1_f16(trans_input_data + 21 * step, m33);
vst1_f16(trans_input_data + 22 * step, m34);
vst1_f16(trans_input_data + 23 * step, m35);
vst1_f16(trans_input_data + 24 * step, m40);
vst1_f16(trans_input_data + 25 * step, m41);
vst1_f16(trans_input_data + 26 * step, m42);
vst1_f16(trans_input_data + 27 * step, m43);
vst1_f16(trans_input_data + 28 * step, m44);
vst1_f16(trans_input_data + 29 * step, m45);
vst1_f16(trans_input_data + 30 * step, m50);
vst1_f16(trans_input_data + 31 * step, m51);
vst1_f16(trans_input_data + 32 * step, m52);
vst1_f16(trans_input_data + 33 * step, m53);
vst1_f16(trans_input_data + 34 * step, m54);
vst1_f16(trans_input_data + 35 * step, m55);
}
void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data,
int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) {
// input data format : nhwc
int output_unit = 4;
int input_channel = conv_param->input_channel_;
int input_width = conv_param->input_w_;
int input_height = conv_param->input_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_h_;
int ic4 = UP_DIV(input_channel, C4NUM);
for (int cal_id = 0; cal_id < real_cal_num; cal_id++) {
int x_id = start_index + cal_id;
int origin_x = (x_id % out_w_block) * output_unit - pad_w;
int origin_y = (x_id / out_w_block) * output_unit - pad_h;
int real_x_start = origin_x > 0 ? 0 : -origin_x;
int real_x_end = (origin_x + 6) < input_width ? 6 : (input_width - origin_x);
int real_y_start = origin_y > 0 ? 0 : -origin_y;
int real_y_end = (origin_y + 6) < input_height ? 6 : (input_height - origin_y);
int src_plane_offset = input_channel * (origin_y * input_width + origin_x);
int dst_plane_offset = cal_id * C4NUM;
for (int ic = 0; ic < ic4; ic++) {
// clear tmp buffer
memset(tmp_data, 0, 6 * 6 * C4NUM * sizeof(float16_t));
// get real input block with padding
int src_ic4_offset = src_plane_offset + ic * C4NUM;
for (int interval = real_y_start; interval < real_y_end; interval++) {
int src_y_offset = src_ic4_offset + interval * input_width * input_channel + real_x_start * input_channel;
int dst_y_offset = interval * 6 * C4NUM + real_x_start * C4NUM;
for (int j = 0; j < (real_x_end - real_x_start); j++) {
int src_x_offset = src_y_offset + j * input_channel;
int dst_x_offset = dst_y_offset + j * C4NUM;
float16_t *src_addr = (float16_t *)(input_data) + src_x_offset;
float16_t *dst_addr = tmp_data + dst_x_offset;
dst_addr[0] = src_addr[0];
dst_addr[1] = src_addr[1];
dst_addr[2] = src_addr[2];
dst_addr[3] = src_addr[3];
}
}
// input transform; trans_input layout is [36 units][ic4][16 tiles][C4NUM]
int dst_ic4_offset = dst_plane_offset + ic * 16 * C4NUM;
size_t dst_step = ic4 * C4NUM * 16;
float16_t *trans_input_ptr = trans_input + dst_ic4_offset;
Conv3x3Fp16InputUnit(tmp_data, trans_input_ptr, dst_step);
}
}
}
void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC4, int output_channel,
int kernel_plane) {
int dst_step = iC4 * C4NUM * 8;
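// packed layout written below: [oc8][36 transform units][ic4][C4NUM][C8NUM];
// dst_step is the stride between two of the 36 units within one oc8 block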
for (int o = 0; o < output_channel; o++) {
int oc8_block_num = o / C8NUM;
int oc8_block_rem = o % C8NUM;
int src_oc_offset = o * iC4 * C4NUM * kernel_plane;
int dst_oc_offset = oc8_block_num * C8NUM * iC4 * C4NUM * 36 + oc8_block_rem;
for (int i = 0; i < iC4; i++) {
const float16_t *src_ic4_ptr = weight_data + src_oc_offset + i * kernel_plane * C4NUM;
float16_t *dst_ic4_ptr = trans_weight + dst_oc_offset + i * 8 * C4NUM;
float16x4_t g00 = vld1_f16(src_ic4_ptr);
float16x4_t g01 = vld1_f16(src_ic4_ptr + 4);
float16x4_t g02 = vld1_f16(src_ic4_ptr + 2 * 4);
float16x4_t g10 = vld1_f16(src_ic4_ptr + 3 * 4);
float16x4_t g11 = vld1_f16(src_ic4_ptr + 4 * 4);
float16x4_t g12 = vld1_f16(src_ic4_ptr + 5 * 4);
float16x4_t g20 = vld1_f16(src_ic4_ptr + 6 * 4);
float16x4_t g21 = vld1_f16(src_ic4_ptr + 7 * 4);
float16x4_t g22 = vld1_f16(src_ic4_ptr + 8 * 4);
float16x4_t dst00 = vmul_n_f16(g00, 0.25);
float16x4_t dst01 = vmul_n_f16(g01, 0.25);
float16x4_t dst02 = vmul_n_f16(g02, 0.25);
float16x4_t dst10 = vmul_n_f16(vadd_f16(g00, vadd_f16(g10, g20)), -0.1666666666667);
float16x4_t dst11 = vmul_n_f16(vadd_f16(g01, vadd_f16(g11, g21)), -0.1666666666667);
float16x4_t dst12 = vmul_n_f16(vadd_f16(g02, vadd_f16(g12, g22)), -0.1666666666667);
float16x4_t dst20 = vmul_n_f16(vsub_f16(vadd_f16(g00, g20), g10), -0.1666666666667);
float16x4_t dst21 = vmul_n_f16(vsub_f16(vadd_f16(g01, g21), g11), -0.1666666666667);
float16x4_t dst22 = vmul_n_f16(vsub_f16(vadd_f16(g02, g22), g12), -0.1666666666667);
float16x4_t dst30 = vadd_f16(vmul_n_f16(g10, 0.08333333333333),
vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667)));
float16x4_t dst31 = vadd_f16(vmul_n_f16(g11, 0.08333333333333),
vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667)));
float16x4_t dst32 = vadd_f16(vmul_n_f16(g12, 0.08333333333333),
vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667)));
float16x4_t dst40 = vsub_f16(vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667)),
vmul_n_f16(g10, 0.08333333333333));
float16x4_t dst41 = vsub_f16(vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667)),
vmul_n_f16(g11, 0.08333333333333));
float16x4_t dst42 = vsub_f16(vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667)),
vmul_n_f16(g12, 0.08333333333333));
float16x4_t dst50 = g20;
float16x4_t dst51 = g21;
float16x4_t dst52 = g22;
float16x4_t m00 = vmul_n_f16(dst00, 0.25);
float16x4_t m01 = vmul_n_f16(vadd_f16(dst00, vadd_f16(dst01, dst02)), -0.1666666666667);
float16x4_t m02 = vmul_n_f16(vsub_f16(vadd_f16(dst00, dst02), dst01), -0.1666666666667);
float16x4_t m03 = vadd_f16(vmul_n_f16(dst01, 0.08333333333333),
vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667)));
float16x4_t m04 = vsub_f16(vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667)),
vmul_n_f16(dst01, 0.08333333333333));
float16x4_t m05 = dst02;
float16x4_t m10 = vmul_n_f16(dst10, 0.25);
float16x4_t m11 = vmul_n_f16(vadd_f16(dst10, vadd_f16(dst11, dst12)), -0.1666666666667);
float16x4_t m12 = vmul_n_f16(vsub_f16(vadd_f16(dst10, dst12), dst11), -0.1666666666667);
float16x4_t m13 = vadd_f16(vmul_n_f16(dst11, 0.08333333333333),
vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667)));
float16x4_t m14 = vsub_f16(vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667)),
vmul_n_f16(dst11, 0.08333333333333));
float16x4_t m15 = dst12;
float16x4_t m20 = vmul_n_f16(dst20, 0.25);
float16x4_t m21 = vmul_n_f16(vadd_f16(dst20, vadd_f16(dst21, dst22)), -0.1666666666667);
float16x4_t m22 = vmul_n_f16(vsub_f16(vadd_f16(dst20, dst22), dst21), -0.1666666666667);
float16x4_t m23 = vadd_f16(vmul_n_f16(dst21, 0.08333333333333),
vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667)));
float16x4_t m24 = vsub_f16(vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667)),
vmul_n_f16(dst21, 0.08333333333333));
float16x4_t m25 = dst22;
float16x4_t m30 = vmul_n_f16(dst30, 0.25);
float16x4_t m31 = vmul_n_f16(vadd_f16(dst30, vadd_f16(dst31, dst32)), -0.1666666666667);
float16x4_t m32 = vmul_n_f16(vsub_f16(vadd_f16(dst30, dst32), dst31), -0.1666666666667);
float16x4_t m33 = vadd_f16(vmul_n_f16(dst31, 0.08333333333333),
vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667)));
float16x4_t m34 = vsub_f16(vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667)),
vmul_n_f16(dst31, 0.08333333333333));
float16x4_t m35 = dst32;
float16x4_t m40 = vmul_n_f16(dst40, 0.25);
float16x4_t m41 = vmul_n_f16(vadd_f16(dst40, vadd_f16(dst41, dst42)), -0.1666666666667);
float16x4_t m42 = vmul_n_f16(vsub_f16(vadd_f16(dst40, dst42), dst41), -0.1666666666667);
float16x4_t m43 = vadd_f16(vmul_n_f16(dst41, 0.08333333333333),
vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667)));
float16x4_t m44 = vsub_f16(vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667)),
vmul_n_f16(dst41, 0.08333333333333));
float16x4_t m45 = dst42;
float16x4_t m50 = vmul_n_f16(dst50, 0.25);
float16x4_t m51 = vmul_n_f16(vadd_f16(dst50, vadd_f16(dst51, dst52)), -0.1666666666667);
float16x4_t m52 = vmul_n_f16(vsub_f16(vadd_f16(dst50, dst52), dst51), -0.1666666666667);
float16x4_t m53 = vadd_f16(vmul_n_f16(dst51, 0.08333333333333),
vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667)));
float16x4_t m54 = vsub_f16(vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667)),
vmul_n_f16(dst51, 0.08333333333333));
float16x4_t m55 = dst52;
for (int j = 0; j < 4; j++) {
dst_ic4_ptr[j * 8] = m00[j];
dst_ic4_ptr[j * 8 + dst_step] = m01[j];
dst_ic4_ptr[j * 8 + 2 * dst_step] = m02[j];
dst_ic4_ptr[j * 8 + 3 * dst_step] = m03[j];
dst_ic4_ptr[j * 8 + 4 * dst_step] = m04[j];
dst_ic4_ptr[j * 8 + 5 * dst_step] = m05[j];
dst_ic4_ptr[j * 8 + 6 * dst_step] = m10[j];
dst_ic4_ptr[j * 8 + 7 * dst_step] = m11[j];
dst_ic4_ptr[j * 8 + 8 * dst_step] = m12[j];
dst_ic4_ptr[j * 8 + 9 * dst_step] = m13[j];
dst_ic4_ptr[j * 8 + 10 * dst_step] = m14[j];
dst_ic4_ptr[j * 8 + 11 * dst_step] = m15[j];
dst_ic4_ptr[j * 8 + 12 * dst_step] = m20[j];
dst_ic4_ptr[j * 8 + 13 * dst_step] = m21[j];
dst_ic4_ptr[j * 8 + 14 * dst_step] = m22[j];
dst_ic4_ptr[j * 8 + 15 * dst_step] = m23[j];
dst_ic4_ptr[j * 8 + 16 * dst_step] = m24[j];
dst_ic4_ptr[j * 8 + 17 * dst_step] = m25[j];
dst_ic4_ptr[j * 8 + 18 * dst_step] = m30[j];
dst_ic4_ptr[j * 8 + 19 * dst_step] = m31[j];
dst_ic4_ptr[j * 8 + 20 * dst_step] = m32[j];
dst_ic4_ptr[j * 8 + 21 * dst_step] = m33[j];
dst_ic4_ptr[j * 8 + 22 * dst_step] = m34[j];
dst_ic4_ptr[j * 8 + 23 * dst_step] = m35[j];
dst_ic4_ptr[j * 8 + 24 * dst_step] = m40[j];
dst_ic4_ptr[j * 8 + 25 * dst_step] = m41[j];
dst_ic4_ptr[j * 8 + 26 * dst_step] = m42[j];
dst_ic4_ptr[j * 8 + 27 * dst_step] = m43[j];
dst_ic4_ptr[j * 8 + 28 * dst_step] = m44[j];
dst_ic4_ptr[j * 8 + 29 * dst_step] = m45[j];
dst_ic4_ptr[j * 8 + 30 * dst_step] = m50[j];
dst_ic4_ptr[j * 8 + 31 * dst_step] = m51[j];
dst_ic4_ptr[j * 8 + 32 * dst_step] = m52[j];
dst_ic4_ptr[j * 8 + 33 * dst_step] = m53[j];
dst_ic4_ptr[j * 8 + 34 * dst_step] = m54[j];
dst_ic4_ptr[j * 8 + 35 * dst_step] = m55[j];
}
}
}
}
void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data,
int output_w) {
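// note: bias_data is currently unused; the inverse transform below writes the raw
// 4x4 output tile without adding bias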
float16x8_t s00 = vld1q_f16(gemm_out);
float16x8_t s01 = vld1q_f16(gemm_out + 8);
float16x8_t s02 = vld1q_f16(gemm_out + 16);
float16x8_t s03 = vld1q_f16(gemm_out + 24);
float16x8_t s04 = vld1q_f16(gemm_out + 32);
float16x8_t s05 = vld1q_f16(gemm_out + 40);
float16x8_t s10 = vld1q_f16(gemm_out + 48);
float16x8_t s11 = vld1q_f16(gemm_out + 56);
float16x8_t s12 = vld1q_f16(gemm_out + 64);
float16x8_t s13 = vld1q_f16(gemm_out + 72);
float16x8_t s14 = vld1q_f16(gemm_out + 80);
float16x8_t s15 = vld1q_f16(gemm_out + 88);
float16x8_t s20 = vld1q_f16(gemm_out + 96);
float16x8_t s21 = vld1q_f16(gemm_out + 104);
float16x8_t s22 = vld1q_f16(gemm_out + 112);
float16x8_t s23 = vld1q_f16(gemm_out + 120);
float16x8_t s24 = vld1q_f16(gemm_out + 128);
float16x8_t s25 = vld1q_f16(gemm_out + 136);
float16x8_t s30 = vld1q_f16(gemm_out + 144);
float16x8_t s31 = vld1q_f16(gemm_out + 152);
float16x8_t s32 = vld1q_f16(gemm_out + 160);
float16x8_t s33 = vld1q_f16(gemm_out + 168);
float16x8_t s34 = vld1q_f16(gemm_out + 176);
float16x8_t s35 = vld1q_f16(gemm_out + 184);
float16x8_t s40 = vld1q_f16(gemm_out + 192);
float16x8_t s41 = vld1q_f16(gemm_out + 200);
float16x8_t s42 = vld1q_f16(gemm_out + 208);
float16x8_t s43 = vld1q_f16(gemm_out + 216);
float16x8_t s44 = vld1q_f16(gemm_out + 224);
float16x8_t s45 = vld1q_f16(gemm_out + 232);
float16x8_t s50 = vld1q_f16(gemm_out + 240);
float16x8_t s51 = vld1q_f16(gemm_out + 248);
float16x8_t s52 = vld1q_f16(gemm_out + 256);
float16x8_t s53 = vld1q_f16(gemm_out + 264);
float16x8_t s54 = vld1q_f16(gemm_out + 272);
float16x8_t s55 = vld1q_f16(gemm_out + 280);
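// inverse transform: first t = A_T * s over rows, then d = t * A over columns,
// yielding the 4x4 spatial output tile (A_T as listed in the comment near the top)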
float16x8_t t00 = vaddq_f16(vaddq_f16(vaddq_f16(s00, s10), vaddq_f16(s20, s30)), s40);
float16x8_t t01 = vaddq_f16(vaddq_f16(vaddq_f16(s01, s11), vaddq_f16(s21, s31)), s41);
float16x8_t t02 = vaddq_f16(vaddq_f16(vaddq_f16(s02, s12), vaddq_f16(s22, s32)), s42);
float16x8_t t03 = vaddq_f16(vaddq_f16(vaddq_f16(s03, s13), vaddq_f16(s23, s33)), s43);
float16x8_t t04 = vaddq_f16(vaddq_f16(vaddq_f16(s04, s14), vaddq_f16(s24, s34)), s44);
float16x8_t t05 = vaddq_f16(vaddq_f16(vaddq_f16(s05, s15), vaddq_f16(s25, s35)), s45);
float16x8_t t10 = vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 2));
float16x8_t t11 = vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 2));
float16x8_t t12 = vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 2));
float16x8_t t13 = vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 2));
float16x8_t t14 = vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 2));
float16x8_t t15 = vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 2));
float16x8_t t20 = vaddq_f16(vaddq_f16(s10, s20), vmulq_n_f16(vaddq_f16(s30, s40), 4));
float16x8_t t21 = vaddq_f16(vaddq_f16(s11, s21), vmulq_n_f16(vaddq_f16(s31, s41), 4));
float16x8_t t22 = vaddq_f16(vaddq_f16(s12, s22), vmulq_n_f16(vaddq_f16(s32, s42), 4));
float16x8_t t23 = vaddq_f16(vaddq_f16(s13, s23), vmulq_n_f16(vaddq_f16(s33, s43), 4));
float16x8_t t24 = vaddq_f16(vaddq_f16(s14, s24), vmulq_n_f16(vaddq_f16(s34, s44), 4));
float16x8_t t25 = vaddq_f16(vaddq_f16(s15, s25), vmulq_n_f16(vaddq_f16(s35, s45), 4));
float16x8_t t30 = vaddq_f16(vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 8)), s50);
float16x8_t t31 = vaddq_f16(vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 8)), s51);
float16x8_t t32 = vaddq_f16(vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 8)), s52);
float16x8_t t33 = vaddq_f16(vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 8)), s53);
float16x8_t t34 = vaddq_f16(vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 8)), s54);
float16x8_t t35 = vaddq_f16(vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 8)), s55);
float16x8_t d00 = vaddq_f16(vaddq_f16(vaddq_f16(t00, t01), vaddq_f16(t02, t03)), t04);
float16x8_t d01 = vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 2));
float16x8_t d02 = vaddq_f16(vaddq_f16(t01, t02), vmulq_n_f16(vaddq_f16(t03, t04), 4));
float16x8_t d03 = vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 8)), t05);
float16x8_t d10 = vaddq_f16(vaddq_f16(vaddq_f16(t10, t11), vaddq_f16(t12, t13)), t14);
float16x8_t d11 = vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 2));
float16x8_t d12 = vaddq_f16(vaddq_f16(t11, t12), vmulq_n_f16(vaddq_f16(t13, t14), 4));
float16x8_t d13 = vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 8)), t15);
float16x8_t d20 = vaddq_f16(vaddq_f16(vaddq_f16(t20, t21), vaddq_f16(t22, t23)), t24);
float16x8_t d21 = vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 2));
float16x8_t d22 = vaddq_f16(vaddq_f16(t21, t22), vmulq_n_f16(vaddq_f16(t23, t24), 4));
float16x8_t d23 = vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 8)), t25);
float16x8_t d30 = vaddq_f16(vaddq_f16(vaddq_f16(t30, t31), vaddq_f16(t32, t33)), t34);
float16x8_t d31 = vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 2));
float16x8_t d32 = vaddq_f16(vaddq_f16(t31, t32), vmulq_n_f16(vaddq_f16(t33, t34), 4));
float16x8_t d33 = vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 8)), t35);
vst1q_f16(output_data, d00);
vst1q_f16(output_data + 8, d01);
vst1q_f16(output_data + 16, d02);
vst1q_f16(output_data + 24, d03);
vst1q_f16(output_data + output_w * 8, d10);
vst1q_f16(output_data + output_w * 8 + 8, d11);
vst1q_f16(output_data + output_w * 8 + 16, d12);
vst1q_f16(output_data + output_w * 8 + 24, d13);
vst1q_f16(output_data + 2 * output_w * 8, d20);
vst1q_f16(output_data + 2 * output_w * 8 + 8, d21);
vst1q_f16(output_data + 2 * output_w * 8 + 16, d22);
vst1q_f16(output_data + 2 * output_w * 8 + 24, d23);
vst1q_f16(output_data + 3 * output_w * 8, d30);
vst1q_f16(output_data + 3 * output_w * 8 + 8, d31);
vst1q_f16(output_data + 3 * output_w * 8 + 16, d32);
vst1q_f16(output_data + 3 * output_w * 8 + 24, d33);
}
void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data,
int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) {
int output_channel = conv_param->output_channel_;
int output_w = conv_param->output_w_;
int output_h = conv_param->output_h_;
int oc8 = UP_DIV(output_channel, C8NUM);
for (int i = 0; i < real_cal_num; i++) {
int out_w_index = (start_index + i) % out_w_block;
int out_h_index = (start_index + i) / out_w_block;
int src_tile_offset = i * oc8 * C8NUM * 36;
int dst_tile_offset = 8 * (out_w_index * 4 + out_h_index * 4 * output_w);
for (int j = 0; j < oc8; j++) {
int src_oc8_offset = src_tile_offset + j * 36 * C8NUM;
int dst_oc8_offset = dst_tile_offset + j * C8NUM * output_h * output_w;
const float16_t *src_ptr = gemm_out + src_oc8_offset;
const float16_t *bias_ptr = bias_data + j * C8NUM;
float16_t *dst_ptr = out_data + dst_oc8_offset;
// output transform
Conv3x3Fp16OutputUnit(src_ptr, bias_ptr, dst_ptr, output_w);
}
}
}

View File

@@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_WINOGRAD_TRANSFORM_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_WINOGRAD_TRANSFORM_FP16_H_
#include <arm_neon.h>
#include <string.h>
#include "src/runtime/kernel/arm/opclib/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/opclib/fp16/conv_fp16.h"
// for fp16 convolution 3x3 filter/input/output transform
void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step);
void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data,
int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param);
void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC4, int output_channel,
int kernel_plane);
void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, int output_w);
void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data,
int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_WINOGRAD_TRANSFORM_FP16_H_

View File

@@ -29,11 +29,24 @@ class OptimizeModule {
public:
OptimizeModule() {
bool support_optimize_ops = false;
bool support_fp16 = false;
#ifdef __ANDROID__
int hwcap_type = 16;
uint32_t hwcap = getHwCap(hwcap_type);
#if defined(__aarch64__)
#ifdef ENABLE_ARM64
if (hwcap & HWCAP_FPHP) {
#elif defined(__arm__)
if (hwcap & HWCAP_HALF) {
#endif
MS_LOG(INFO) << "Hw cap support FP16, hwcap: 0x" << hwcap;
support_fp16 = true;
#ifdef ENABLE_ARM64
}
#elif defined(__arm__)
}
#endif
#ifdef ENABLE_ARM64
if (hwcap & HWCAP_ASIMDDP) {
printf("Hw cap support SMID Dot Product, hwcap: 0x%x \n", hwcap);
support_optimize_ops = true;
@@ -42,7 +55,7 @@ class OptimizeModule {
}
#endif
#endif
if (!support_optimize_ops) {
if ((!support_optimize_ops) && (!support_fp16)) {
return;
}
optimized_op_handler_ = dlopen(OPTIMIZE_SHARED_LIBRARY_PATH, RTLD_LAZY);
@@ -61,4 +74,3 @@ class OptimizeModule {
};
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_OPTIMIZED_KERNEL_H_
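The #ifdef ladder in the hunk above amounts to a runtime query of the ELF auxiliary vector. A minimal standalone sketch of the same capability probe (an assumption-laden illustration: it presumes a Linux/Android toolchain with <sys/auxv.h>, and that getHwCap in the patch wraps getauxval):

#include <stdio.h>
#include <sys/auxv.h>
#include <asm/hwcap.h>

int main(void) {
  unsigned long hwcap = getauxval(AT_HWCAP);
#if defined(__aarch64__)
  if (hwcap & HWCAP_FPHP) printf("fp16 arithmetic supported, hwcap: 0x%lx\n", hwcap);
  if (hwcap & HWCAP_ASIMDDP) printf("SIMD dot product supported, hwcap: 0x%lx\n", hwcap);
#elif defined(__arm__)
  if (hwcap & HWCAP_HALF) printf("half-precision support flagged, hwcap: 0x%lx\n", hwcap);
#endif
  return 0;
}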

View File

@@ -18,331 +18,6 @@
#include <cstring>
#include <cstdlib>
#ifdef ENABLE_FP16
void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num,
int block_index) {
// input format : nhwc
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_w_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
int in_channel = conv_param->input_channel_;
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int out_w = conv_param->output_w_;
int channel_block = UP_DIV(in_channel, 4);
int kernel_plane = kernel_h * kernel_w;
for (int i = 0; i < real_cal_num; i++) {
int block_start = block_index + i;
int input_h = block_start / out_w * stride_h - pad_h;
int input_w = block_start % out_w * stride_w - pad_w;
for (int j = 0; j < kernel_h; j++) {
int input_y = input_h + j * dilation_h;
if (input_y < 0 || input_y >= in_h) {
continue;
}
int input_y_stride = input_y * in_w * channel_block * C4NUM;
for (int n = 0; n < kernel_w; n++) {
int input_x = input_w + n * dilation_w;
if (input_x < 0 || input_x >= in_w) {
continue;
}
int input_x_stride = input_y_stride + input_x * channel_block * C4NUM;
int input_plane_offset = (j * kernel_w + n) * 16 * C4NUM * channel_block + i * C4NUM;
for (int m = 0; m < channel_block; m++) {
int channel_block_stride = input_x_stride + m * C4NUM;
int channel_block_offset = input_plane_offset + m * 16 * C4NUM;
(packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0];
(packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1];
(packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2];
(packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3];
} // channel_block loop
} // kernel_w loop
} // kernel_h loop
} // tile num loop
}
void PackWeightFp16(float16_t *weight_data, ConvParameter *conv_param, float16_t *packed_weight) {
// original weight format : ohwi
int tile_num = 8;
int inchannel_block = 4;
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int in_channel = conv_param->input_channel_;
int out_channel = conv_param->output_channel_;
int kernel_block = UP_DIV(out_channel, tile_num);
int channel_block = UP_DIV(in_channel, inchannel_block);
int kernel_plane = kernel_h * kernel_w;
int pack_weight_size = kernel_block * channel_block * tile_num * inchannel_block * kernel_plane;
int unit_size = tile_num * inchannel_block;
int block_size = pack_weight_size / kernel_block;
for (int m = 0; m < kernel_plane; m++) {
int kernel_plane_stride = m * in_channel;
int packed_kernel_plane_stride = m * unit_size * channel_block;
for (int i = 0; i < channel_block; i++) {
int channel_block_stride = kernel_plane_stride + i * inchannel_block;
int packed_channel_block_size = packed_kernel_plane_stride + i * unit_size;
int ic_remainder = in_channel - i * inchannel_block;
int real_ic_num = ic_remainder < inchannel_block ? ic_remainder : inchannel_block;
for (int h = 0; h < real_ic_num; h++) {
int block_stride = channel_block_stride + h;
int packed_block_stride = packed_channel_block_size + h * tile_num;
for (int j = 0; j < kernel_block; j++) {
int kernel_block_stride = block_stride + j * tile_num * kernel_plane * in_channel;
int packed_kernel_block_size = packed_block_stride + j * block_size;
int oc_remainder = out_channel - j * tile_num;
int real_oc_num = oc_remainder < tile_num ? oc_remainder : tile_num;
for (int k = 0; k < real_oc_num; k++) {
float16_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel;
float16_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k;
*packed_data_ptr = *origin_data_ptr;
}
} // kernel block loop
} // inchannel block loop
} // channel block loop
} // kernel plane loop
}
void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) {
// origin weight format : ohwi
int input_channel = conv_param->input_channel_;
int ic8 = UP_DIV(input_channel, C8NUM);
int output_channel = conv_param->output_channel_;
int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_;
for (int k = 0; k < kernel_plane; k++) {
int src_kernel_offset = k * input_channel;
int dst_kernel_offset = k * C8NUM;
for (int o = 0; o < output_channel; o++) {
int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel;
int dst_oc_offset = dst_kernel_offset + o * ic8 * kernel_plane * C8NUM;
for (int i = 0; i < input_channel; i++) {
int c8_block_num = i / C8NUM;
int c8_block_rem = i % C8NUM;
int src_ic_offset = src_oc_offset + i;
int dst_ic_offset = dst_oc_offset + c8_block_num * kernel_plane * C8NUM + c8_block_rem;
(packed_weight_data + dst_ic_offset)[0] = (origin_weight_data + src_ic_offset)[0];
}
}
}
}
void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) {
// origin weight format : ohwi
int input_channel = conv_param->input_channel_;
int ic4 = UP_DIV(input_channel, C4NUM);
int output_channel = conv_param->output_channel_;
int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_;
for (int k = 0; k < kernel_plane; k++) {
int src_kernel_offset = k * input_channel;
int dst_kernel_offset = k * C4NUM;
for (int o = 0; o < output_channel; o++) {
int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel;
int dst_oc_offset = dst_kernel_offset + o * ic4 * kernel_plane * C4NUM;
for (int i = 0; i < input_channel; i++) {
int c4_block_num = i / C4NUM;
int c4_block_rem = i % C4NUM;
int src_ic_offset = src_oc_offset + i;
int dst_ic_offset = dst_oc_offset + c4_block_num * kernel_plane * C4NUM + c4_block_rem;
(packed_weight_data + dst_ic_offset)[0] = (origin_weight_data + src_ic_offset)[0];
}
}
}
}
void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_oc_offset = b * plane * channel;
int dst_oc_offset = b * plane * c4 * C4NUM;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_oc_offset + k * channel;
int dst_kernel_offset = dst_oc_offset + k * C4NUM;
for (int i = 0; i < channel; i++) {
int c4_block_num = i / C4NUM;
int c4_block_rem = i % C4NUM;
int src_ic_offset = src_kernel_offset + i;
int dst_ic_offset = dst_kernel_offset + c4_block_num * plane * C4NUM + c4_block_rem;
((float16_t *)dst + dst_ic_offset)[0] = ((float16_t *)src + src_ic_offset)[0];
}
}
}
}
void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * channel;
int dst_offset = b * plane * c4 * C4NUM;
for (int c = 0; c < channel; c++) {
int c4_block_num = c / C4NUM;
int c4_block_rem = c % C4NUM;
int src_c_offset = src_offset + c * plane;
int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k;
int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem;
((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0];
}
}
}
}
void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) {
int ic4 = UP_DIV(channel, C4NUM);
int nhwc4_batch_unit_offset = ic4 * C4NUM * plane;
int ic_remainder_ = channel % C4NUM;
if (ic_remainder_ != 0) {
int nhwc4_batch_offset = 0;
for (int b = 0; b < batch; b++) {
int batch_offset = b * channel * plane;
for (int i = 0; i < plane; i++) {
memcpy((float16_t *)dst + nhwc4_batch_offset + i * ic4 * C4NUM, (float16_t *)src + batch_offset + i * channel,
channel * sizeof(float16_t));
}
nhwc4_batch_offset += nhwc4_batch_unit_offset;
}
} else {
size_t ori_input_size = batch * plane * channel * sizeof(float16_t);
memcpy(dst, src, ori_input_size);
}
}
void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) {
int nhwc4_batch_offset = 0;
int ic4 = UP_DIV(channel, C4NUM);
int nhwc4_batch_unit_offset = ic4 * C4NUM * plane;
for (int b = 0; b < batch; b++) {
int batch_offset = b * channel * plane;
for (int c = 0; c < channel; c++) {
int src_c_offset = batch_offset + c * plane;
int dst_c_offset = nhwc4_batch_offset + c;
for (int i = 0; i < plane; i++) {
int src_plane_offset = src_c_offset + i;
int dst_plane_offset = dst_c_offset + i * ic4 * C4NUM;
((float16_t *)dst)[dst_plane_offset] = ((float16_t *)src)[src_plane_offset];
}
}
nhwc4_batch_offset += nhwc4_batch_unit_offset;
}
}
void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * c4 * C4NUM;
int dst_offset = b * plane * channel;
for (int c = 0; c < channel; c++) {
int c4_block_num = c / C4NUM;
int c4_block_res = c % C4NUM;
int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k * C4NUM;
int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM;
((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0];
}
}
}
}
void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * c4 * C4NUM;
int dst_offset = b * plane * channel;
for (int c = 0; c < channel; c++) {
int c4_block_num = c / C4NUM;
int c4_block_res = c % C4NUM;
int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
int dst_c_offset = dst_offset + c;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k * C4NUM;
int dst_kernel_offset = dst_c_offset + k * channel;
((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0];
}
}
}
}
void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * c4 * C4NUM;
int dst_offset = b * plane * channel;
for (int c = 0; c < channel; c++) {
int c4_block_num = c / C4NUM;
int c4_block_res = c % C4NUM;
int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
int dst_c_offset = dst_offset + c * plane;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k * C4NUM;
int dst_kernel_offset = dst_c_offset + k;
((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0];
}
}
}
}
void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane, int channel) {
int c8 = UP_DIV(channel, C8NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * channel;
int dst_offset = b * plane * c8 * C8NUM;
for (int c = 0; c < channel; c++) {
int c8_block_num = c / C8NUM;
int c8_block_rem = c % C8NUM;
int src_c_offset = src_offset + c * plane;
int dst_c_offset = dst_offset + c8_block_num * plane * C8NUM;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k;
int dst_kernel_offset = dst_c_offset + C8NUM * k + c8_block_rem;
(dst + dst_kernel_offset)[0] = (float16_t)(src + src_kernel_offset)[0];
}
}
}
}
void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel) {
int c8 = UP_DIV(channel, C8NUM);
int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
int nhwc8_batch_offset = 0;
for (int b = 0; b < batch; b++) {
int batch_offset = b * channel * plane;
for (int i = 0; i < plane; i++) {
for (int c = 0; c < channel; c++) {
(dst + nhwc8_batch_offset + i * c8 * C8NUM)[c] = (float16_t)(src + batch_offset + i * channel)[c];
}
}
nhwc8_batch_offset += nhwc8_batch_unit_offset;
}
}
void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel) {
int c8 = UP_DIV(channel, C8NUM);
int nhwc_batch_unit_offset = channel * plane;
int nhwc_batch_offset = 0;
for (int b = 0; b < batch; b++) {
int batch_offset = b * c8 * C8NUM * plane;
for (int i = 0; i < plane; i++) {
for (int c = 0; c < channel; c++) {
(dst + nhwc_batch_offset + i * channel)[c] = (float)(src + batch_offset + i * c8 * C8NUM)[c];
}
}
nhwc_batch_offset += nhwc_batch_unit_offset;
}
}
#endif
void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed_weight) {
// original weight format : ohwi
// todo pack weight for arm32 platform

View File

@@ -23,38 +23,6 @@
#include "src/runtime/kernel/arm/opclib/conv_parameter.h"
#include "src/runtime/kernel/arm/opclib/op_base.h"
#ifdef ENABLE_FP16
void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num,
int block_index);
void PackWeightFp16(float16_t *weight_data, ConvParameter *conv_param, float16_t *packed_weight);
void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param);
void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param);
void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNC8HW8ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane, int channel);
void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel);
void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel);
#endif
void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, float *packed_input, int real_cal_num,
int block_index);

View File

@@ -675,518 +675,6 @@ void Conv3x3Fp32OutputTransform(const float *gemm_out, float *out_data, const fl
}
}
#ifdef ENABLE_FP16
// for fp16 convolution 3x3 filter/input/output transform F(4,3)
void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step) {
float16x4_t d00 = vld1_f16(tmp_data);
float16x4_t d01 = vld1_f16(tmp_data + 4);
float16x4_t d02 = vld1_f16(tmp_data + 2 * 4);
float16x4_t d03 = vld1_f16(tmp_data + 3 * 4);
float16x4_t d04 = vld1_f16(tmp_data + 4 * 4);
float16x4_t d05 = vld1_f16(tmp_data + 5 * 4);
float16x4_t d10 = vld1_f16(tmp_data + 6 * 4);
float16x4_t d11 = vld1_f16(tmp_data + 7 * 4);
float16x4_t d12 = vld1_f16(tmp_data + 8 * 4);
float16x4_t d13 = vld1_f16(tmp_data + 9 * 4);
float16x4_t d14 = vld1_f16(tmp_data + 10 * 4);
float16x4_t d15 = vld1_f16(tmp_data + 11 * 4);
float16x4_t d20 = vld1_f16(tmp_data + 12 * 4);
float16x4_t d21 = vld1_f16(tmp_data + 13 * 4);
float16x4_t d22 = vld1_f16(tmp_data + 14 * 4);
float16x4_t d23 = vld1_f16(tmp_data + 15 * 4);
float16x4_t d24 = vld1_f16(tmp_data + 16 * 4);
float16x4_t d25 = vld1_f16(tmp_data + 17 * 4);
float16x4_t d30 = vld1_f16(tmp_data + 18 * 4);
float16x4_t d31 = vld1_f16(tmp_data + 19 * 4);
float16x4_t d32 = vld1_f16(tmp_data + 20 * 4);
float16x4_t d33 = vld1_f16(tmp_data + 21 * 4);
float16x4_t d34 = vld1_f16(tmp_data + 22 * 4);
float16x4_t d35 = vld1_f16(tmp_data + 23 * 4);
float16x4_t d40 = vld1_f16(tmp_data + 24 * 4);
float16x4_t d41 = vld1_f16(tmp_data + 25 * 4);
float16x4_t d42 = vld1_f16(tmp_data + 26 * 4);
float16x4_t d43 = vld1_f16(tmp_data + 27 * 4);
float16x4_t d44 = vld1_f16(tmp_data + 28 * 4);
float16x4_t d45 = vld1_f16(tmp_data + 29 * 4);
float16x4_t d50 = vld1_f16(tmp_data + 30 * 4);
float16x4_t d51 = vld1_f16(tmp_data + 31 * 4);
float16x4_t d52 = vld1_f16(tmp_data + 32 * 4);
float16x4_t d53 = vld1_f16(tmp_data + 33 * 4);
float16x4_t d54 = vld1_f16(tmp_data + 34 * 4);
float16x4_t d55 = vld1_f16(tmp_data + 35 * 4);
float16x4_t t00 = vadd_f16(vsub_f16(vmul_n_f16(d00, 4), vmul_n_f16(d20, 5)), d40);
float16x4_t t01 = vadd_f16(vsub_f16(vmul_n_f16(d01, 4), vmul_n_f16(d21, 5)), d41);
float16x4_t t02 = vadd_f16(vsub_f16(vmul_n_f16(d02, 4), vmul_n_f16(d22, 5)), d42);
float16x4_t t03 = vadd_f16(vsub_f16(vmul_n_f16(d03, 4), vmul_n_f16(d23, 5)), d43);
float16x4_t t04 = vadd_f16(vsub_f16(vmul_n_f16(d04, 4), vmul_n_f16(d24, 5)), d44);
float16x4_t t05 = vadd_f16(vsub_f16(vmul_n_f16(d05, 4), vmul_n_f16(d25, 5)), d45);
float16x4_t t10 = vadd_f16(vadd_f16(d30, d40), vmul_n_f16(vadd_f16(d10, d20), -4));
float16x4_t t11 = vadd_f16(vadd_f16(d31, d41), vmul_n_f16(vadd_f16(d11, d21), -4));
float16x4_t t12 = vadd_f16(vadd_f16(d32, d42), vmul_n_f16(vadd_f16(d12, d22), -4));
float16x4_t t13 = vadd_f16(vadd_f16(d33, d43), vmul_n_f16(vadd_f16(d13, d23), -4));
float16x4_t t14 = vadd_f16(vadd_f16(d34, d44), vmul_n_f16(vadd_f16(d14, d24), -4));
float16x4_t t15 = vadd_f16(vadd_f16(d35, d45), vmul_n_f16(vadd_f16(d15, d25), -4));
float16x4_t t20 = vadd_f16(vsub_f16(d40, d30), vmul_n_f16(vsub_f16(d10, d20), 4));
float16x4_t t21 = vadd_f16(vsub_f16(d41, d31), vmul_n_f16(vsub_f16(d11, d21), 4));
float16x4_t t22 = vadd_f16(vsub_f16(d42, d32), vmul_n_f16(vsub_f16(d12, d22), 4));
float16x4_t t23 = vadd_f16(vsub_f16(d43, d33), vmul_n_f16(vsub_f16(d13, d23), 4));
float16x4_t t24 = vadd_f16(vsub_f16(d44, d34), vmul_n_f16(vsub_f16(d14, d24), 4));
float16x4_t t25 = vadd_f16(vsub_f16(d45, d35), vmul_n_f16(vsub_f16(d15, d25), 4));
float16x4_t t30 = vadd_f16(vsub_f16(d40, d20), vmul_n_f16(vsub_f16(d30, d10), 2));
float16x4_t t31 = vadd_f16(vsub_f16(d41, d21), vmul_n_f16(vsub_f16(d31, d11), 2));
float16x4_t t32 = vadd_f16(vsub_f16(d42, d22), vmul_n_f16(vsub_f16(d32, d12), 2));
float16x4_t t33 = vadd_f16(vsub_f16(d43, d23), vmul_n_f16(vsub_f16(d33, d13), 2));
float16x4_t t34 = vadd_f16(vsub_f16(d44, d24), vmul_n_f16(vsub_f16(d34, d14), 2));
float16x4_t t35 = vadd_f16(vsub_f16(d45, d25), vmul_n_f16(vsub_f16(d35, d15), 2));
float16x4_t t40 = vadd_f16(vsub_f16(d40, d20), vmul_n_f16(vsub_f16(d10, d30), 2));
float16x4_t t41 = vadd_f16(vsub_f16(d41, d21), vmul_n_f16(vsub_f16(d11, d31), 2));
float16x4_t t42 = vadd_f16(vsub_f16(d42, d22), vmul_n_f16(vsub_f16(d12, d32), 2));
float16x4_t t43 = vadd_f16(vsub_f16(d43, d23), vmul_n_f16(vsub_f16(d13, d33), 2));
float16x4_t t44 = vadd_f16(vsub_f16(d44, d24), vmul_n_f16(vsub_f16(d14, d34), 2));
float16x4_t t45 = vadd_f16(vsub_f16(d45, d25), vmul_n_f16(vsub_f16(d15, d35), 2));
float16x4_t t50 = vadd_f16(vsub_f16(vmul_n_f16(d10, 4), vmul_n_f16(d30, 5)), d50);
float16x4_t t51 = vadd_f16(vsub_f16(vmul_n_f16(d11, 4), vmul_n_f16(d31, 5)), d51);
float16x4_t t52 = vadd_f16(vsub_f16(vmul_n_f16(d12, 4), vmul_n_f16(d32, 5)), d52);
float16x4_t t53 = vadd_f16(vsub_f16(vmul_n_f16(d13, 4), vmul_n_f16(d33, 5)), d53);
float16x4_t t54 = vadd_f16(vsub_f16(vmul_n_f16(d14, 4), vmul_n_f16(d34, 5)), d54);
float16x4_t t55 = vadd_f16(vsub_f16(vmul_n_f16(d15, 4), vmul_n_f16(d35, 5)), d55);
float16x4_t m00 = vadd_f16(vsub_f16(vmul_n_f16(t00, 4), vmul_n_f16(t02, 5)), t04);
float16x4_t m01 = vadd_f16(vadd_f16(t03, t04), vmul_n_f16(vadd_f16(t01, t02), -4));
float16x4_t m02 = vadd_f16(vsub_f16(t04, t03), vmul_n_f16(vsub_f16(t01, t02), 4));
float16x4_t m03 = vadd_f16(vsub_f16(t04, t02), vmul_n_f16(vsub_f16(t03, t01), 2));
float16x4_t m04 = vadd_f16(vsub_f16(t04, t02), vmul_n_f16(vsub_f16(t01, t03), 2));
float16x4_t m05 = vadd_f16(vsub_f16(vmul_n_f16(t01, 4), vmul_n_f16(t03, 5)), t05);
float16x4_t m10 = vadd_f16(vsub_f16(vmul_n_f16(t10, 4), vmul_n_f16(t12, 5)), t14);
float16x4_t m11 = vadd_f16(vadd_f16(t13, t14), vmul_n_f16(vadd_f16(t11, t12), -4));
float16x4_t m12 = vadd_f16(vsub_f16(t14, t13), vmul_n_f16(vsub_f16(t11, t12), 4));
float16x4_t m13 = vadd_f16(vsub_f16(t14, t12), vmul_n_f16(vsub_f16(t13, t11), 2));
float16x4_t m14 = vadd_f16(vsub_f16(t14, t12), vmul_n_f16(vsub_f16(t11, t13), 2));
float16x4_t m15 = vadd_f16(vsub_f16(vmul_n_f16(t11, 4), vmul_n_f16(t13, 5)), t15);
float16x4_t m20 = vadd_f16(vsub_f16(vmul_n_f16(t20, 4), vmul_n_f16(t22, 5)), t24);
float16x4_t m21 = vadd_f16(vadd_f16(t23, t24), vmul_n_f16(vadd_f16(t21, t22), -4));
float16x4_t m22 = vadd_f16(vsub_f16(t24, t23), vmul_n_f16(vsub_f16(t21, t22), 4));
float16x4_t m23 = vadd_f16(vsub_f16(t24, t22), vmul_n_f16(vsub_f16(t23, t21), 2));
float16x4_t m24 = vadd_f16(vsub_f16(t24, t22), vmul_n_f16(vsub_f16(t21, t23), 2));
float16x4_t m25 = vadd_f16(vsub_f16(vmul_n_f16(t21, 4), vmul_n_f16(t23, 5)), t25);
float16x4_t m30 = vadd_f16(vsub_f16(vmul_n_f16(t30, 4), vmul_n_f16(t32, 5)), t34);
float16x4_t m31 = vadd_f16(vadd_f16(t33, t34), vmul_n_f16(vadd_f16(t31, t32), -4));
float16x4_t m32 = vadd_f16(vsub_f16(t34, t33), vmul_n_f16(vsub_f16(t31, t32), 4));
float16x4_t m33 = vadd_f16(vsub_f16(t34, t32), vmul_n_f16(vsub_f16(t33, t31), 2));
float16x4_t m34 = vadd_f16(vsub_f16(t34, t32), vmul_n_f16(vsub_f16(t31, t33), 2));
float16x4_t m35 = vadd_f16(vsub_f16(vmul_n_f16(t31, 4), vmul_n_f16(t33, 5)), t35);
float16x4_t m40 = vadd_f16(vsub_f16(vmul_n_f16(t40, 4), vmul_n_f16(t42, 5)), t44);
float16x4_t m41 = vadd_f16(vadd_f16(t43, t44), vmul_n_f16(vadd_f16(t41, t42), -4));
float16x4_t m42 = vadd_f16(vsub_f16(t44, t43), vmul_n_f16(vsub_f16(t41, t42), 4));
float16x4_t m43 = vadd_f16(vsub_f16(t44, t42), vmul_n_f16(vsub_f16(t43, t41), 2));
float16x4_t m44 = vadd_f16(vsub_f16(t44, t42), vmul_n_f16(vsub_f16(t41, t43), 2));
float16x4_t m45 = vadd_f16(vsub_f16(vmul_n_f16(t41, 4), vmul_n_f16(t43, 5)), t45);
float16x4_t m50 = vadd_f16(vsub_f16(vmul_n_f16(t50, 4), vmul_n_f16(t52, 5)), t54);
float16x4_t m51 = vadd_f16(vadd_f16(t53, t54), vmul_n_f16(vadd_f16(t51, t52), -4));
float16x4_t m52 = vadd_f16(vsub_f16(t54, t53), vmul_n_f16(vsub_f16(t51, t52), 4));
float16x4_t m53 = vadd_f16(vsub_f16(t54, t52), vmul_n_f16(vsub_f16(t53, t51), 2));
float16x4_t m54 = vadd_f16(vsub_f16(t54, t52), vmul_n_f16(vsub_f16(t51, t53), 2));
float16x4_t m55 = vadd_f16(vsub_f16(vmul_n_f16(t51, 4), vmul_n_f16(t53, 5)), t55);
vst1_f16(trans_input_data, m00);
vst1_f16(trans_input_data + step, m01);
vst1_f16(trans_input_data + 2 * step, m02);
vst1_f16(trans_input_data + 3 * step, m03);
vst1_f16(trans_input_data + 4 * step, m04);
vst1_f16(trans_input_data + 5 * step, m05);
vst1_f16(trans_input_data + 6 * step, m10);
vst1_f16(trans_input_data + 7 * step, m11);
vst1_f16(trans_input_data + 8 * step, m12);
vst1_f16(trans_input_data + 9 * step, m13);
vst1_f16(trans_input_data + 10 * step, m14);
vst1_f16(trans_input_data + 11 * step, m15);
vst1_f16(trans_input_data + 12 * step, m20);
vst1_f16(trans_input_data + 13 * step, m21);
vst1_f16(trans_input_data + 14 * step, m22);
vst1_f16(trans_input_data + 15 * step, m23);
vst1_f16(trans_input_data + 16 * step, m24);
vst1_f16(trans_input_data + 17 * step, m25);
vst1_f16(trans_input_data + 18 * step, m30);
vst1_f16(trans_input_data + 19 * step, m31);
vst1_f16(trans_input_data + 20 * step, m32);
vst1_f16(trans_input_data + 21 * step, m33);
vst1_f16(trans_input_data + 22 * step, m34);
vst1_f16(trans_input_data + 23 * step, m35);
vst1_f16(trans_input_data + 24 * step, m40);
vst1_f16(trans_input_data + 25 * step, m41);
vst1_f16(trans_input_data + 26 * step, m42);
vst1_f16(trans_input_data + 27 * step, m43);
vst1_f16(trans_input_data + 28 * step, m44);
vst1_f16(trans_input_data + 29 * step, m45);
vst1_f16(trans_input_data + 30 * step, m50);
vst1_f16(trans_input_data + 31 * step, m51);
vst1_f16(trans_input_data + 32 * step, m52);
vst1_f16(trans_input_data + 33 * step, m53);
vst1_f16(trans_input_data + 34 * step, m54);
vst1_f16(trans_input_data + 35 * step, m55);
}
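// Gathers the 6x6 input tiles (output_unit 4 + kernel size 3 - 1 = 6) that
// Winograd F(4x4, 3x3) consumes: each tile is copied into tmp_data with zero
// padding for out-of-range pixels, then transformed one ic4 slice at a time.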
void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data,
int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) {
// input data format : nhwc
int output_unit = 4;
int input_channel = conv_param->input_channel_;
int input_width = conv_param->input_w_;
int input_height = conv_param->input_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_h_;
int ic4 = UP_DIV(input_channel, C4NUM);
for (int cal_id = 0; cal_id < real_cal_num; cal_id++) {
int x_id = start_index + cal_id;
int origin_x = (x_id % out_w_block) * output_unit - pad_w;
int origin_y = (x_id / out_w_block) * output_unit - pad_h;
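// clip the 6x6 window to the valid input region; anything outside it keeps
// the zeros written by the memset below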
int real_x_start = origin_x > 0 ? 0 : -origin_x;
int real_x_end = (origin_x + 6) < input_width ? 6 : (input_width - origin_x);
int real_y_start = origin_y > 0 ? 0 : -origin_y;
int real_y_end = (origin_y + 6) < input_height ? 6 : (input_height - origin_y);
int src_plane_offset = input_channel * (origin_y * input_width + origin_x);
int dst_plane_offset = cal_id * C4NUM;
for (int ic = 0; ic < ic4; ic++) {
// clear tmp buffer
memset(tmp_data, 0, 6 * 6 * C4NUM * sizeof(float16_t));
// get real input block with padding
int src_ic4_offset = src_plane_offset + ic * C4NUM;
for (int interval = real_y_start; interval < real_y_end; interval++) {
int src_y_offset = src_ic4_offset + interval * input_width * input_channel + real_x_start * input_channel;
int dst_y_offset = interval * 6 * C4NUM + real_x_start * C4NUM;
for (int j = 0; j < (real_x_end - real_x_start); j++) {
int src_x_offset = src_y_offset + j * input_channel;
int dst_x_offset = dst_y_offset + j * C4NUM;
const float16_t *src_addr = input_data + src_x_offset;
float16_t *dst_addr = tmp_data + dst_x_offset;
dst_addr[0] = src_addr[0];
dst_addr[1] = src_addr[1];
dst_addr[2] = src_addr[2];
dst_addr[3] = src_addr[3];
}
}
// input transform: trans_input layout is [36 units][ic4][tile][C4NUM],
// where the constant 16 below is the per-block tile count
int dst_ic4_offset = dst_plane_offset + ic * 16 * C4NUM;
size_t dst_step = ic4 * C4NUM * 16;
float16_t *trans_input_ptr = trans_input + dst_ic4_offset;
Conv3x3Fp16InputUnit(tmp_data, trans_input_ptr, dst_step);
}
}
}
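// Transforms each 3x3 filter g to the 6x6 Winograd domain, G * g * GT, where
// G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6],
//      [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
// matches the 0.25 / 1/6 / 1/12 / 1/24 constants below; results are stored
// in oc8 blocks, 36 units per block.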
void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC4, int output_channel,
int kernel_plane) {
int dst_step = iC4 * C4NUM * 8;
for (int o = 0; o < output_channel; o++) {
int oc8_block_num = o / C8NUM;
int oc8_block_rem = o % C8NUM;
int src_oc_offset = o * iC4 * C4NUM * kernel_plane;
int dst_oc_offset = oc8_block_num * C8NUM * iC4 * C4NUM * 36 + oc8_block_rem;
for (int i = 0; i < iC4; i++) {
const float16_t *src_ic4_ptr = weight_data + src_oc_offset + i * kernel_plane * C4NUM;
float16_t *dst_ic4_ptr = trans_weight + dst_oc_offset + i * 8 * C4NUM;
float16x4_t g00 = vld1_f16(src_ic4_ptr);
float16x4_t g01 = vld1_f16(src_ic4_ptr + 4);
float16x4_t g02 = vld1_f16(src_ic4_ptr + 2 * 4);
float16x4_t g10 = vld1_f16(src_ic4_ptr + 3 * 4);
float16x4_t g11 = vld1_f16(src_ic4_ptr + 4 * 4);
float16x4_t g12 = vld1_f16(src_ic4_ptr + 5 * 4);
float16x4_t g20 = vld1_f16(src_ic4_ptr + 6 * 4);
float16x4_t g21 = vld1_f16(src_ic4_ptr + 7 * 4);
float16x4_t g22 = vld1_f16(src_ic4_ptr + 8 * 4);
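// first pass: dst = G * g, four input channels per vector lane group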
float16x4_t dst00 = vmul_n_f16(g00, 0.25);
float16x4_t dst01 = vmul_n_f16(g01, 0.25);
float16x4_t dst02 = vmul_n_f16(g02, 0.25);
float16x4_t dst10 = vmul_n_f16(vadd_f16(g00, vadd_f16(g10, g20)), -0.1666666666667);
float16x4_t dst11 = vmul_n_f16(vadd_f16(g01, vadd_f16(g11, g21)), -0.1666666666667);
float16x4_t dst12 = vmul_n_f16(vadd_f16(g02, vadd_f16(g12, g22)), -0.1666666666667);
float16x4_t dst20 = vmul_n_f16(vsub_f16(vadd_f16(g00, g20), g10), -0.1666666666667);
float16x4_t dst21 = vmul_n_f16(vsub_f16(vadd_f16(g01, g21), g11), -0.1666666666667);
float16x4_t dst22 = vmul_n_f16(vsub_f16(vadd_f16(g02, g22), g12), -0.1666666666667);
float16x4_t dst30 = vadd_f16(vmul_n_f16(g10, 0.08333333333333),
vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667)));
float16x4_t dst31 = vadd_f16(vmul_n_f16(g11, 0.08333333333333),
vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667)));
float16x4_t dst32 = vadd_f16(vmul_n_f16(g12, 0.08333333333333),
vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667)));
float16x4_t dst40 = vsub_f16(vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667)),
vmul_n_f16(g10, 0.08333333333333));
float16x4_t dst41 = vsub_f16(vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667)),
vmul_n_f16(g11, 0.08333333333333));
float16x4_t dst42 = vsub_f16(vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667)),
vmul_n_f16(g12, 0.08333333333333));
float16x4_t dst50 = g20;
float16x4_t dst51 = g21;
float16x4_t dst52 = g22;
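// second pass: m = (G * g) * GT, completing the 6x6 transformed filter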
float16x4_t m00 = vmul_n_f16(dst00, 0.25);
float16x4_t m01 = vmul_n_f16(vadd_f16(dst00, vadd_f16(dst01, dst02)), -0.1666666666667);
float16x4_t m02 = vmul_n_f16(vsub_f16(vadd_f16(dst00, dst02), dst01), -0.1666666666667);
float16x4_t m03 = vadd_f16(vmul_n_f16(dst01, 0.08333333333333),
vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667)));
float16x4_t m04 = vsub_f16(vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667)),
vmul_n_f16(dst01, 0.08333333333333));
float16x4_t m05 = dst02;
float16x4_t m10 = vmul_n_f16(dst10, 0.25);
float16x4_t m11 = vmul_n_f16(vadd_f16(dst10, vadd_f16(dst11, dst12)), -0.1666666666667);
float16x4_t m12 = vmul_n_f16(vsub_f16(vadd_f16(dst10, dst12), dst11), -0.1666666666667);
float16x4_t m13 = vadd_f16(vmul_n_f16(dst11, 0.08333333333333),
vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667)));
float16x4_t m14 = vsub_f16(vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667)),
vmul_n_f16(dst11, 0.08333333333333));
float16x4_t m15 = dst12;
float16x4_t m20 = vmul_n_f16(dst20, 0.25);
float16x4_t m21 = vmul_n_f16(vadd_f16(dst20, vadd_f16(dst21, dst22)), -0.1666666666667);
float16x4_t m22 = vmul_n_f16(vsub_f16(vadd_f16(dst20, dst22), dst21), -0.1666666666667);
float16x4_t m23 = vadd_f16(vmul_n_f16(dst21, 0.08333333333333),
vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667)));
float16x4_t m24 = vsub_f16(vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667)),
vmul_n_f16(dst21, 0.08333333333333));
float16x4_t m25 = dst22;
float16x4_t m30 = vmul_n_f16(dst30, 0.25);
float16x4_t m31 = vmul_n_f16(vadd_f16(dst30, vadd_f16(dst31, dst32)), -0.1666666666667);
float16x4_t m32 = vmul_n_f16(vsub_f16(vadd_f16(dst30, dst32), dst31), -0.1666666666667);
float16x4_t m33 = vadd_f16(vmul_n_f16(dst31, 0.08333333333333),
vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667)));
float16x4_t m34 = vsub_f16(vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667)),
vmul_n_f16(dst31, 0.08333333333333));
float16x4_t m35 = dst32;
float16x4_t m40 = vmul_n_f16(dst40, 0.25);
float16x4_t m41 = vmul_n_f16(vadd_f16(dst40, vadd_f16(dst41, dst42)), -0.1666666666667);
float16x4_t m42 = vmul_n_f16(vsub_f16(vadd_f16(dst40, dst42), dst41), -0.1666666666667);
float16x4_t m43 = vadd_f16(vmul_n_f16(dst41, 0.08333333333333),
vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667)));
float16x4_t m44 = vsub_f16(vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667)),
vmul_n_f16(dst41, 0.08333333333333));
float16x4_t m45 = dst42;
float16x4_t m50 = vmul_n_f16(dst50, 0.25);
float16x4_t m51 = vmul_n_f16(vadd_f16(dst50, vadd_f16(dst51, dst52)), -0.1666666666667);
float16x4_t m52 = vmul_n_f16(vsub_f16(vadd_f16(dst50, dst52), dst51), -0.1666666666667);
float16x4_t m53 = vadd_f16(vmul_n_f16(dst51, 0.08333333333333),
vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667)));
float16x4_t m54 = vsub_f16(vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667)),
vmul_n_f16(dst51, 0.08333333333333));
float16x4_t m55 = dst52;
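// scatter the 36 transformed units: the four lanes of this ic4 group sit
// 8 elements apart (oc8 interleave), successive units dst_step apart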
for (int j = 0; j < 4; j++) {
dst_ic4_ptr[j * 8] = m00[j];
dst_ic4_ptr[j * 8 + dst_step] = m01[j];
dst_ic4_ptr[j * 8 + 2 * dst_step] = m02[j];
dst_ic4_ptr[j * 8 + 3 * dst_step] = m03[j];
dst_ic4_ptr[j * 8 + 4 * dst_step] = m04[j];
dst_ic4_ptr[j * 8 + 5 * dst_step] = m05[j];
dst_ic4_ptr[j * 8 + 6 * dst_step] = m10[j];
dst_ic4_ptr[j * 8 + 7 * dst_step] = m11[j];
dst_ic4_ptr[j * 8 + 8 * dst_step] = m12[j];
dst_ic4_ptr[j * 8 + 9 * dst_step] = m13[j];
dst_ic4_ptr[j * 8 + 10 * dst_step] = m14[j];
dst_ic4_ptr[j * 8 + 11 * dst_step] = m15[j];
dst_ic4_ptr[j * 8 + 12 * dst_step] = m20[j];
dst_ic4_ptr[j * 8 + 13 * dst_step] = m21[j];
dst_ic4_ptr[j * 8 + 14 * dst_step] = m22[j];
dst_ic4_ptr[j * 8 + 15 * dst_step] = m23[j];
dst_ic4_ptr[j * 8 + 16 * dst_step] = m24[j];
dst_ic4_ptr[j * 8 + 17 * dst_step] = m25[j];
dst_ic4_ptr[j * 8 + 18 * dst_step] = m30[j];
dst_ic4_ptr[j * 8 + 19 * dst_step] = m31[j];
dst_ic4_ptr[j * 8 + 20 * dst_step] = m32[j];
dst_ic4_ptr[j * 8 + 21 * dst_step] = m33[j];
dst_ic4_ptr[j * 8 + 22 * dst_step] = m34[j];
dst_ic4_ptr[j * 8 + 23 * dst_step] = m35[j];
dst_ic4_ptr[j * 8 + 24 * dst_step] = m40[j];
dst_ic4_ptr[j * 8 + 25 * dst_step] = m41[j];
dst_ic4_ptr[j * 8 + 26 * dst_step] = m42[j];
dst_ic4_ptr[j * 8 + 27 * dst_step] = m43[j];
dst_ic4_ptr[j * 8 + 28 * dst_step] = m44[j];
dst_ic4_ptr[j * 8 + 29 * dst_step] = m45[j];
dst_ic4_ptr[j * 8 + 30 * dst_step] = m50[j];
dst_ic4_ptr[j * 8 + 31 * dst_step] = m51[j];
dst_ic4_ptr[j * 8 + 32 * dst_step] = m52[j];
dst_ic4_ptr[j * 8 + 33 * dst_step] = m53[j];
dst_ic4_ptr[j * 8 + 34 * dst_step] = m54[j];
dst_ic4_ptr[j * 8 + 35 * dst_step] = m55[j];
}
}
}
}
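// Maps one 6x6 GEMM output tile back to a 4x4 spatial block, d = AT * s * A,
// with AT = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0],
//            [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]];
// each float16x8_t carries 8 output channels, and the 8-channel bias is
// added before the block is stored.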
void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data,
int output_w) {
float16x8_t s00 = vld1q_f16(gemm_out);
float16x8_t s01 = vld1q_f16(gemm_out + 8);
float16x8_t s02 = vld1q_f16(gemm_out + 16);
float16x8_t s03 = vld1q_f16(gemm_out + 24);
float16x8_t s04 = vld1q_f16(gemm_out + 32);
float16x8_t s05 = vld1q_f16(gemm_out + 40);
float16x8_t s10 = vld1q_f16(gemm_out + 48);
float16x8_t s11 = vld1q_f16(gemm_out + 56);
float16x8_t s12 = vld1q_f16(gemm_out + 64);
float16x8_t s13 = vld1q_f16(gemm_out + 72);
float16x8_t s14 = vld1q_f16(gemm_out + 80);
float16x8_t s15 = vld1q_f16(gemm_out + 88);
float16x8_t s20 = vld1q_f16(gemm_out + 96);
float16x8_t s21 = vld1q_f16(gemm_out + 104);
float16x8_t s22 = vld1q_f16(gemm_out + 112);
float16x8_t s23 = vld1q_f16(gemm_out + 120);
float16x8_t s24 = vld1q_f16(gemm_out + 128);
float16x8_t s25 = vld1q_f16(gemm_out + 136);
float16x8_t s30 = vld1q_f16(gemm_out + 144);
float16x8_t s31 = vld1q_f16(gemm_out + 152);
float16x8_t s32 = vld1q_f16(gemm_out + 160);
float16x8_t s33 = vld1q_f16(gemm_out + 168);
float16x8_t s34 = vld1q_f16(gemm_out + 176);
float16x8_t s35 = vld1q_f16(gemm_out + 184);
float16x8_t s40 = vld1q_f16(gemm_out + 192);
float16x8_t s41 = vld1q_f16(gemm_out + 200);
float16x8_t s42 = vld1q_f16(gemm_out + 208);
float16x8_t s43 = vld1q_f16(gemm_out + 216);
float16x8_t s44 = vld1q_f16(gemm_out + 224);
float16x8_t s45 = vld1q_f16(gemm_out + 232);
float16x8_t s50 = vld1q_f16(gemm_out + 240);
float16x8_t s51 = vld1q_f16(gemm_out + 248);
float16x8_t s52 = vld1q_f16(gemm_out + 256);
float16x8_t s53 = vld1q_f16(gemm_out + 264);
float16x8_t s54 = vld1q_f16(gemm_out + 272);
float16x8_t s55 = vld1q_f16(gemm_out + 280);
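// first pass: t = AT * s across the rows of the 6x6 tile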
float16x8_t t00 = vaddq_f16(vaddq_f16(vaddq_f16(s00, s10), vaddq_f16(s20, s30)), s40);
float16x8_t t01 = vaddq_f16(vaddq_f16(vaddq_f16(s01, s11), vaddq_f16(s21, s31)), s41);
float16x8_t t02 = vaddq_f16(vaddq_f16(vaddq_f16(s02, s12), vaddq_f16(s22, s32)), s42);
float16x8_t t03 = vaddq_f16(vaddq_f16(vaddq_f16(s03, s13), vaddq_f16(s23, s33)), s43);
float16x8_t t04 = vaddq_f16(vaddq_f16(vaddq_f16(s04, s14), vaddq_f16(s24, s34)), s44);
float16x8_t t05 = vaddq_f16(vaddq_f16(vaddq_f16(s05, s15), vaddq_f16(s25, s35)), s45);
float16x8_t t10 = vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 2));
float16x8_t t11 = vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 2));
float16x8_t t12 = vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 2));
float16x8_t t13 = vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 2));
float16x8_t t14 = vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 2));
float16x8_t t15 = vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 2));
float16x8_t t20 = vaddq_f16(vaddq_f16(s10, s20), vmulq_n_f16(vaddq_f16(s30, s40), 4));
float16x8_t t21 = vaddq_f16(vaddq_f16(s11, s21), vmulq_n_f16(vaddq_f16(s31, s41), 4));
float16x8_t t22 = vaddq_f16(vaddq_f16(s12, s22), vmulq_n_f16(vaddq_f16(s32, s42), 4));
float16x8_t t23 = vaddq_f16(vaddq_f16(s13, s23), vmulq_n_f16(vaddq_f16(s33, s43), 4));
float16x8_t t24 = vaddq_f16(vaddq_f16(s14, s24), vmulq_n_f16(vaddq_f16(s34, s44), 4));
float16x8_t t25 = vaddq_f16(vaddq_f16(s15, s25), vmulq_n_f16(vaddq_f16(s35, s45), 4));
float16x8_t t30 = vaddq_f16(vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 8)), s50);
float16x8_t t31 = vaddq_f16(vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 8)), s51);
float16x8_t t32 = vaddq_f16(vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 8)), s52);
float16x8_t t33 = vaddq_f16(vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 8)), s53);
float16x8_t t34 = vaddq_f16(vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 8)), s54);
float16x8_t t35 = vaddq_f16(vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 8)), s55);
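// second pass: d = t * A, producing the 4x4 output block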
float16x8_t d00 = vaddq_f16(vaddq_f16(vaddq_f16(t00, t01), vaddq_f16(t02, t03)), t04);
float16x8_t d01 = vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 2));
float16x8_t d02 = vaddq_f16(vaddq_f16(t01, t02), vmulq_n_f16(vaddq_f16(t03, t04), 4));
float16x8_t d03 = vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 8)), t05);
float16x8_t d10 = vaddq_f16(vaddq_f16(vaddq_f16(t10, t11), vaddq_f16(t12, t13)), t14);
float16x8_t d11 = vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 2));
float16x8_t d12 = vaddq_f16(vaddq_f16(t11, t12), vmulq_n_f16(vaddq_f16(t13, t14), 4));
float16x8_t d13 = vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 8)), t15);
float16x8_t d20 = vaddq_f16(vaddq_f16(vaddq_f16(t20, t21), vaddq_f16(t22, t23)), t24);
float16x8_t d21 = vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 2));
float16x8_t d22 = vaddq_f16(vaddq_f16(t21, t22), vmulq_n_f16(vaddq_f16(t23, t24), 4));
float16x8_t d23 = vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 8)), t25);
float16x8_t d30 = vaddq_f16(vaddq_f16(vaddq_f16(t30, t31), vaddq_f16(t32, t33)), t34);
float16x8_t d31 = vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 2));
float16x8_t d32 = vaddq_f16(vaddq_f16(t31, t32), vmulq_n_f16(vaddq_f16(t33, t34), 4));
float16x8_t d33 = vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 8)), t35);
// bias_data is passed in but was never applied; assuming it is the 8-channel
// bias for this oc8 block, add it to the 4x4 result before the stores
float16x8_t bias_v = vld1q_f16(bias_data);
d00 = vaddq_f16(d00, bias_v);
d01 = vaddq_f16(d01, bias_v);
d02 = vaddq_f16(d02, bias_v);
d03 = vaddq_f16(d03, bias_v);
d10 = vaddq_f16(d10, bias_v);
d11 = vaddq_f16(d11, bias_v);
d12 = vaddq_f16(d12, bias_v);
d13 = vaddq_f16(d13, bias_v);
d20 = vaddq_f16(d20, bias_v);
d21 = vaddq_f16(d21, bias_v);
d22 = vaddq_f16(d22, bias_v);
d23 = vaddq_f16(d23, bias_v);
d30 = vaddq_f16(d30, bias_v);
d31 = vaddq_f16(d31, bias_v);
d32 = vaddq_f16(d32, bias_v);
d33 = vaddq_f16(d33, bias_v);
vst1q_f16(output_data, d00);
vst1q_f16(output_data + 8, d01);
vst1q_f16(output_data + 16, d02);
vst1q_f16(output_data + 24, d03);
vst1q_f16(output_data + output_w * 8, d10);
vst1q_f16(output_data + output_w * 8 + 8, d11);
vst1q_f16(output_data + output_w * 8 + 16, d12);
vst1q_f16(output_data + output_w * 8 + 24, d13);
vst1q_f16(output_data + 2 * output_w * 8, d20);
vst1q_f16(output_data + 2 * output_w * 8 + 8, d21);
vst1q_f16(output_data + 2 * output_w * 8 + 16, d22);
vst1q_f16(output_data + 2 * output_w * 8 + 24, d23);
vst1q_f16(output_data + 3 * output_w * 8, d30);
vst1q_f16(output_data + 3 * output_w * 8 + 8, d31);
vst1q_f16(output_data + 3 * output_w * 8 + 16, d32);
vst1q_f16(output_data + 3 * output_w * 8 + 24, d33);
}
void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data,
int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) {
int output_channel = conv_param->output_channel_;
int output_w = conv_param->output_w_;
int output_h = conv_param->output_h_;
int oc8 = UP_DIV(output_channel, C8NUM);
for (int i = 0; i < real_cal_num; i++) {
int out_w_index = (start_index + i) % out_w_block;
int out_h_index = (start_index + i) / out_w_block;
int src_tile_offset = i * oc8 * C8NUM * 36;
int dst_tile_offset = 8 * (out_w_index * 4 + out_h_index * 4 * output_w);
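// the destination is NC8HW8: each oc8 block owns a contiguous
// output_h * output_w plane of 8 interleaved channels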
for (int j = 0; j < oc8; j++) {
int src_oc8_offset = src_tile_offset + j * 36 * C8NUM;
int dst_oc8_offset = dst_tile_offset + j * C8NUM * output_h * output_w;
const float16_t *src_ptr = gemm_out + src_oc8_offset;
const float16_t *bias_ptr = bias_data + j * C8NUM;
float16_t *dst_ptr = out_data + dst_oc8_offset;
// output transform
Conv3x3Fp16OutputUnit(src_ptr, bias_ptr, dst_ptr, output_w);
}
}
}
#endif
// int8 conv3x3
void Conv3x3Uint8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) {
#ifdef ENABLE_ARM


@@ -51,22 +51,6 @@ void Conv3x3Fp32OutputUnit(const float *gemm_out, const float *bias_data, float
void Conv3x3Fp32OutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int start_index,
int real_cal_num, int out_w_block, ConvParameter *conv_param);
-#ifdef ENABLE_FP16
-// for fp16 convolution 3x3 filter/input/output transform
-void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step);
-void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data,
-                               int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param);
-void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC8, int output_channel,
-                                int kernel_plane);
-void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, int output_w);
-void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data,
-                                int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param);
-#endif
// for int8 convolution 3x3 filter/input/output transform
void Conv3x3Uint8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp);


@@ -127,9 +127,9 @@ kernel::LiteKernel *OpenCLArithmeticKernelCreator(const std::vector<lite::tensor
return kernel;
}
-REG_KERNEL(kGPU, PrimitiveType_Mul, OpenCLArithmeticKernelCreator)
-REG_KERNEL(kGPU, PrimitiveType_Add, OpenCLArithmeticKernelCreator)
-REG_KERNEL(kGPU, PrimitiveType_Sub, OpenCLArithmeticKernelCreator)
-REG_KERNEL(kGPU, PrimitiveType_Div, OpenCLArithmeticKernelCreator)
+REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Mul, OpenCLArithmeticKernelCreator)
+REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Add, OpenCLArithmeticKernelCreator)
+REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Sub, OpenCLArithmeticKernelCreator)
+REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Div, OpenCLArithmeticKernelCreator)
} // namespace mindspore::kernel


@@ -134,6 +134,6 @@ kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector<lite::tensor::Te
return kernel;
}
-REG_KERNEL(kGPU, PrimitiveType_Concat, OpenCLConcatKernelCreator);
+REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Concat, OpenCLConcatKernelCreator);
} // namespace mindspore::kernel


@@ -178,6 +178,6 @@ kernel::LiteKernel *OpenCLConv2dTransposeKernelCreator(const std::vector<lite::t
return kernel;
}
-REG_KERNEL(kGPU, PrimitiveType_DeConv2D, OpenCLConv2dTransposeKernelCreator)
+REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_DeConv2D, OpenCLConv2dTransposeKernelCreator)
} // namespace mindspore::kernel
