fix bug && optimize relu/relu6/winograd unpack func

This commit is contained in:
fuzhiye 2020-07-29 16:03:38 +08:00
parent 31e889881b
commit 0e5d30f917
23 changed files with 1193 additions and 898 deletions

View File

@ -113,6 +113,9 @@ if (BUILD_DEVICE)
if (PLATFORM_ARM64)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
add_compile_definitions(ENABLE_ARM64)
if (ENABLE_FP16)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")
endif ()
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/benchmark)

View File

@ -1,15 +1,16 @@
file(GLOB_RECURSE KERNEL_SRC
file(GLOB KERNEL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/base/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/opclib/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/opclib/fp32/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/opclib/int8/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/opclib/quantization/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/fp32/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc
)
if (PLATFORM_ARM64)
# assembly
file(GLOB_RECURSE ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm64/*.s
file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm64/*.s
${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm64/*.S)
set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})
@ -17,13 +18,15 @@ endif()
if (PLATFORM_ARM32)
# assembly
file(GLOB_RECURSE ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm32/*.s)
file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm32/*.s
${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm32/*.S
)
set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})
endif()
if (ENABLE_FP16)
file(GLOB_RECURSE FP6_SRC
file(GLOB FP6_SRC
${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/opclib/fp16/*.cc
)

View File

@ -101,7 +101,6 @@ void ConvolutionBaseCPUKernel::FreeQuantParam() {
int ConvolutionBaseCPUKernel::Init() {
auto input = this->inputs_.front();
auto output = this->outputs_.front();
conv_param_->input_batch_ = input->Batch();
conv_param_->input_h_ = input->Height();
conv_param_->input_w_ = input->Width();
@ -111,7 +110,6 @@ int ConvolutionBaseCPUKernel::Init() {
conv_param_->output_w_ = output->Width();
conv_param_->output_channel_ = output->Channel();
conv_param_->thread_num_ = ctx_->threadNum;
return RET_OK;
}
@ -221,9 +219,24 @@ void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *con
}
}
kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
bool CheckSupportFP16() {
bool support_fp16 = false;
#ifdef ENABLE_ARM64
void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
if (optimize_op_handler != nullptr) {
support_fp16 = true;
MS_LOG(INFO) << "Support FP16.";
} else {
support_fp16 = false;
MS_LOG(INFO) << "Your machine doesn't support fp16, return back to float32 kernel.";
}
#endif
return support_fp16;
}
kernel::LiteKernel *CpuConvFloatKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
@ -240,44 +253,35 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
InputTransformUnitFunc input_trans_func = nullptr;
OutputTransformUnitFunc output_trans_func = nullptr;
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
bool support_fp16 = CheckSupportFP16();
if (kernel_h == 1 && kernel_w == 1) {
auto kernel = new (std::nothrow) Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
} else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
if (support_fp16) {
#ifdef ENABLE_FP16
auto kernel = new (std::nothrow) Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
#endif
}
auto kernel = new (std::nothrow) Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
} else if (use_winograd) {
auto kernel = new (std::nothrow) ConvolutionWinogradCPUKernel(opParameter, inputs, outputs, ctx, out_unit);
return kernel;
} else {
if (support_fp16) {
#ifdef ENABLE_FP16
auto kernel = new (std::nothrow) ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
#endif
}
auto kernel = new (std::nothrow) ConvolutionCPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
}
}
#ifdef ENABLE_FP16
kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
auto kernel = new (std::nothrow) Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
} else {
auto kernel = new (std::nothrow) ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx);
return kernel;
}
}
#endif
kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx) {
@ -308,17 +312,10 @@ kernel::LiteKernel *CpuConvKernelCreator(const std::vector<lite::tensor::Tensor
kernel::LiteKernel *kernel = nullptr;
switch (data_type) {
case kNumberTypeInt8:
break;
case kNumberTypeUInt8:
kernel = CpuConvInt8KernelCreator(inputs, outputs, opParameter, ctx);
break;
#ifdef ENABLE_FP16
case kNumberTypeFloat16:
kernel = CpuConvFp16KernelCreator(inputs, outputs, opParameter, ctx);
break;
#endif
case kNumberTypeFloat32:
kernel = CpuConvFp32KernelCreator(inputs, outputs, opParameter, ctx);
kernel = CpuConvFloatKernelCreator(inputs, outputs, opParameter, ctx);
break;
default:
break;
@ -385,8 +382,6 @@ kernel::LiteKernel *CpuConvDwKernelCreator(const std::vector<lite::tensor::Tenso
case kNumberTypeInt8:
kernel = CpuConvDwInt8KernelCreator(inputs, outputs, opParameter, ctx);
break;
case kNumberTypeUInt8:
break;
case kNumberTypeFloat32:
#ifdef ENABLE_FP16
kernel = CpuConvDwFp16KernelCreator(inputs, outputs, opParameter, ctx);
@ -515,8 +510,6 @@ kernel::LiteKernel *CpuDeConvKernelCreator(const std::vector<lite::tensor::Tenso
kernel::LiteKernel *kernel = nullptr;
switch (data_type) {
case kNumberTypeInt8:
break;
case kNumberTypeUInt8:
kernel = CpuDeConvInt8KernelCreator(inputs, outputs, opParameter, ctx);
break;
#ifdef ENABLE_FP16

View File

@ -26,10 +26,9 @@
#include <android/log.h>
#endif
#include "src/lite_kernel.h"
#include "include/context.h"
#include "src/runtime/kernel/arm/base/layout_transform.h"
#include "src/runtime/kernel/arm/opclib/optimized_kernel.h"
using mindspore::lite::Context;
using mindspore::schema::PadMode;
@ -40,7 +39,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
public:
ConvolutionBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
: LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) {
: LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) {
opParameter->thread_num_ = ctx->threadNum;
conv_param_ = reinterpret_cast<ConvParameter *>(opParameter);
}
@ -63,6 +62,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
LayoutConvertor convert_func_;
};
void ComputeQuantOutRange(ConvParameter *conv_param);
bool CheckSupportFP16();
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONVOLUTION_BASE_H_

View File

@ -49,6 +49,8 @@ void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvPara
int Convolution3x3FP16CPUKernel::InitWeightBias() {
auto input_channel = conv_param_->input_channel_;
int output_channel = conv_param_->output_channel_;
int kernel_h = conv_param_->kernel_h_;
int kernel_w = conv_param_->kernel_w_;
int iC4 = UP_DIV(input_channel, C4NUM);
int oC8 = UP_DIV(output_channel, C8NUM);
// init weight
@ -60,8 +62,8 @@ int Convolution3x3FP16CPUKernel::InitWeightBias() {
}
memset(transformed_filter_addr_, 0, transformed_size);
float *origin_weight = reinterpret_cast<float *>(inputs_.at(kWeightIndex)->Data());
size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t);
fp16_weight_ = malloc(fp16_weight_size);
size_t fp16_weight_size = input_channel * output_channel * kernel_h * kernel_w * sizeof(float16_t);
fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
if (fp16_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
return RET_ERROR;
@ -74,16 +76,17 @@ int Convolution3x3FP16CPUKernel::InitWeightBias() {
// init bias
size_t new_bias_size = oC8 * C8NUM * sizeof(float16_t);
bias_data_ = reinterpret_cast<float16_t *>(malloc(new_bias_size));
bias_data_ = malloc(new_bias_size);
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_data_ failed.";
return RET_ERROR;
}
memset(bias_data_, 0, new_bias_size);
auto fp16_bias_data = reinterpret_cast<float16_t *>(bias_data_);
if (inputs_.size() == kInputSize2) {
auto ori_bias_addr = reinterpret_cast<float *>(inputs_.at(kBiasIndex)->Data());
for (int i = 0; i < out_channel; ++i) {
bias_data_[i] = (float16_t)ori_bias_addr[i];
for (int i = 0; i < output_channel; ++i) {
fp16_bias_data[i] = (float16_t)ori_bias_addr[i];
}
} else {
MS_ASSERT(inputs_.size() == kInputSize1);
@ -129,16 +132,15 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() {
}
memset(tmp_out_, 0, tmp_out_size);
size_t fp16_input_size =
in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
fp16_input_ = malloc(fp16_input_size);
size_t fp16_input_size = conv_param_->input_channel_ * conv_param_->input_batch_ * conv_param_->input_h_ *
conv_param_->input_w_ * sizeof(float16_t);
fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size));
if (fp16_input_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_input_ failed.";
return RET_ERROR;
}
memset(fp16_input_, 0, fp16_input_size);
// init nhwc4 input
size_t nhwc4_input_size =
iC4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
@ -249,4 +251,3 @@ int Convolution3x3FP16CPUKernel::Run() {
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -42,7 +42,7 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
// init weight
float *origin_weight = reinterpret_cast<float *>(inputs_.at(kWeightIndex)->Data());
size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t);
fp16_weight_ = malloc(fp16_weight_size);
fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
if (fp16_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
return RET_ERROR;
@ -60,16 +60,17 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
PackWeightFp16(fp16_weight_, conv_param_, packed_weight_);
// init bias
bias_data_ = reinterpret_cast<float16_t *>(malloc(oc8 * C8NUM * sizeof(float16_t)));
bias_data_ = malloc(oc8 * C8NUM * sizeof(float16_t));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_data_ failed.";
return RET_ERROR;
}
memset(bias_data_, 0, oc8 * C8NUM * sizeof(float16_t));
auto fp16_bias_data = reinterpret_cast<float16_t *>(bias_data_);
if (inputs_.size() == kInputSize2) {
auto ori_bias = reinterpret_cast<float *>(inputs_.at(kBiasIndex)->Data());
for (int i = 0; i < out_channel; ++i) {
bias_data_[i] = (float16_t)ori_bias[i];
fp16_bias_data[i] = (float16_t)ori_bias[i];
}
} else {
MS_ASSERT(inputs_.size() == kInputSize1);
@ -101,7 +102,7 @@ int ConvolutionFP16CPUKernel::InitTmpBuffer() {
size_t fp16_input_size =
in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
fp16_input_ = malloc(fp16_input_size);
fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size));
if (fp16_input_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_input_ failed.";
return RET_ERROR;

View File

@ -20,9 +20,7 @@
#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "src/runtime/kernel/arm/opclib/optimized_kernel.h"
namespace mindspore::kernel {
typedef void (*FP16_GEMM_FUNC)(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
@ -33,7 +31,7 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseCPUKernel {
public:
ConvolutionFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
~ConvolutionFP16CPUKernel() override {
if (fp16_input_ != nullptr) {
free(fp16_input_);

View File

@ -33,10 +33,17 @@ int ConvolutionCPUKernel::InitWeightBias() {
int kernel_w = conv_param_->kernel_w_;
int in_channel = conv_param_->input_channel_;
int out_channel = conv_param_->output_channel_;
int oc8 = UP_DIV(out_channel, C8NUM);
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane;
int oc_block, oc_block_num;
#ifdef ENABLE_ARM32
oc_block = C4NUM;
oc_block_num = UP_DIV(out_channel, C4NUM);
#else
oc_block = C8NUM;
oc_block_num = UP_DIV(out_channel, C8NUM);
#endif
int pack_weight_size = oc_block_num * oc_block * ic4 * C4NUM * kernel_plane;
// init weight
auto origin_weight = reinterpret_cast<float *>(inputs_.at(kWeightIndex)->Data());
@ -49,12 +56,12 @@ int ConvolutionCPUKernel::InitWeightBias() {
PackWeightFp32(origin_weight, conv_param_, packed_weight_);
// init bias
bias_data_ = reinterpret_cast<float *>(malloc(oc8 * C8NUM * sizeof(float)));
bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * oc_block * sizeof(float)));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias failed.";
return RET_ERROR;
}
memset(bias_data_, 0, oc8 * C8NUM * sizeof(float));
memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float));
if (inputs_.size() == kInputSize2) {
auto ori_bias = reinterpret_cast<float *>(inputs_.at(kBiasIndex)->Data());
memcpy(bias_data_, ori_bias, out_channel * sizeof(float));
@ -198,4 +205,3 @@ int ConvolutionCPUKernel::Run() {
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -55,6 +55,7 @@ void ConvolutionInt8CPUKernel::CheckSupportOptimize() {
support_optimize_ = false;
}
#endif
conv_param_->tile_num_ = tile_num_;
}
int ConvolutionInt8CPUKernel::InitWeightBias() {
@ -78,7 +79,7 @@ int ConvolutionInt8CPUKernel::InitWeightBias() {
return RET_ERROR;
}
memset(packed_weight_, 0, pack_weight_size);
int32_t *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * out_channel));
auto *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * out_channel));
for (int i = 0; i < out_channel; i++) weight_sum[i] = 0;
PackWeightInt8(origin_weight, conv_param_, packed_weight_, weight_sum);
@ -105,15 +106,14 @@ int ConvolutionInt8CPUKernel::InitWeightBias() {
}
int ConvolutionInt8CPUKernel::InitTmpBuffer() {
int tile_n = 4;
int output_count = conv_param_->output_h_ * conv_param_->output_w_;
int output_tile_count = UP_DIV(output_count, tile_n);
int output_tile_count = UP_DIV(output_count, tile_num_);
int in_channel = conv_param_->input_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_;
int plane_c4 = UP_DIV(kernel_plane, C4NUM);
int unit_size = plane_c4 * C4NUM * ic4 * C4NUM;
int packed_input_size = output_tile_count * tile_n * unit_size;
int packed_input_size = output_tile_count * tile_num_ * unit_size;
packed_input_ = reinterpret_cast<int8_t *>(malloc(conv_param_->input_batch_ * packed_input_size));
if (packed_input_ == nullptr) {
@ -122,14 +122,14 @@ int ConvolutionInt8CPUKernel::InitTmpBuffer() {
}
memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size);
input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_n * thread_count_ * sizeof(int32_t)));
input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_num_ * thread_count_ * sizeof(int32_t)));
if (input_sum_ == nullptr) {
MS_LOG(ERROR) << "malloc input_sum_ failed.";
return RET_ERROR;
}
memset(input_sum_, 0, tile_n * thread_count_ * sizeof(int32_t));
memset(input_sum_, 0, tile_num_ * thread_count_ * sizeof(int32_t));
size_t tmp_dst_size = thread_count_ * tile_n * conv_param_->output_channel_ * sizeof(int32_t);
size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t);
tmp_dst_ = reinterpret_cast<int32_t *>(malloc(tmp_dst_size));
if (tmp_dst_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_dst_ failed.";
@ -137,7 +137,7 @@ int ConvolutionInt8CPUKernel::InitTmpBuffer() {
}
memset(tmp_dst_, 0, tmp_dst_size);
tmp_out_ = reinterpret_cast<int8_t *>(malloc(thread_count_ * tile_n * conv_param_->output_channel_));
tmp_out_ = reinterpret_cast<int8_t *>(malloc(thread_count_ * tile_num_ * conv_param_->output_channel_));
if (tmp_out_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_out_ failed.";
return RET_ERROR;
@ -173,7 +173,7 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
return RET_ERROR;
}
memset(packed_weight_, filter_zp, pack_weight_size);
int32_t *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * out_channel));
auto *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * out_channel));
for (int i = 0; i < out_channel; i++) weight_sum[i] = filter_zp * ic4 * C4NUM * kernel_plane;
PackWeightInt8Opt(origin_weight, conv_param_, packed_weight_, weight_sum);
@ -200,15 +200,13 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
}
int ConvolutionInt8CPUKernel::InitTmpBufferOpt() {
// todo
int tile_n = 24;
int output_count = conv_param_->output_h_ * conv_param_->output_w_;
int output_tile_count = UP_DIV(output_count, tile_n);
int output_tile_count = UP_DIV(output_count, tile_num_);
int in_channel = conv_param_->input_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_;
int unit_size = kernel_plane * ic4 * C4NUM;
int packed_input_size = output_tile_count * tile_n * unit_size;
int packed_input_size = output_tile_count * tile_num_ * unit_size;
packed_input_ = reinterpret_cast<int8_t *>(malloc(conv_param_->input_batch_ * packed_input_size));
if (packed_input_ == nullptr) {
@ -217,14 +215,14 @@ int ConvolutionInt8CPUKernel::InitTmpBufferOpt() {
}
memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size);
input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_n * thread_count_ * sizeof(int32_t)));
input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_num_ * thread_count_ * sizeof(int32_t)));
if (input_sum_ == nullptr) {
MS_LOG(ERROR) << "malloc input_sum_ failed.";
return RET_ERROR;
}
memset(input_sum_, 0, tile_n * thread_count_ * sizeof(int32_t));
memset(input_sum_, 0, tile_num_ * thread_count_ * sizeof(int32_t));
size_t tmp_dst_size = thread_count_ * tile_n * conv_param_->output_channel_ * sizeof(int32_t);
size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t);
tmp_dst_ = reinterpret_cast<int32_t *>(malloc(tmp_dst_size));
if (tmp_dst_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_dst_ failed.";
@ -232,7 +230,7 @@ int ConvolutionInt8CPUKernel::InitTmpBufferOpt() {
}
memset(tmp_dst_, 0, tmp_dst_size);
tmp_out_ = reinterpret_cast<int8_t *>(malloc(thread_count_ * tile_n * conv_param_->output_channel_));
tmp_out_ = reinterpret_cast<int8_t *>(malloc(thread_count_ * tile_num_ * conv_param_->output_channel_));
if (tmp_out_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_out_ failed.";
return RET_ERROR;

View File

@ -5,19 +5,22 @@ set(LITE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../)
include_directories(OPTIMIZED_OP_DIR)
########################### optimized files ###########################
set(FP16_ASSEMBLY
# ${OPTIMIZED_OP_DIR}/assembly/arm64/IndirectGemmFp16_16x8.s
file(GLOB OPTIMIZED_ASSEMBLY
${OPTIMIZED_OP_DIR}/assembly/opt/*.s
${OPTIMIZED_OP_DIR}/assembly/opt/*.S
)
file(GLOB_RECURSE OPTIMIZED_INT8_ASSEMBLY
${OPTIMIZED_OP_DIR}/assembly/opt/*.S
file(GLOB FP16_SRC
# ${OPTIMIZED_OP_DIR}/fp16/*.cc
# ${OPTIMIZED_OP_DIR}/../fp16/*.cc
)
########################### share library build ########################
set(OPTIMIZED_OPS "opt_op_handler.c")
set_property(SOURCE ${OPTIMIZED_INT8_ASSEMBLY} PROPERTY LANGUAGE C)
list(APPEND OPTIMIZED_OPS ${OPTIMIZED_INT8_ASSEMBLY} ${FP16_ASSEMBLY})
set_property(SOURCE ${OPTIMIZED_ASSEMBLY} PROPERTY LANGUAGE C)
list(APPEND OPTIMIZED_OPS ${OPTIMIZED_ASSEMBLY} ${FP16_SRC})
if (PLATFORM_ARM64)
string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")

View File

@ -17,7 +17,7 @@
#include "src/runtime/kernel/arm/opclib/common_func.h"
#include "src/runtime/kernel/arm/opclib/quantization/fixed_point.h"
#ifndef ENABLE_ARM
#ifndef __aarch64__
void IndirectGemmFp32(float *output, const float *input, const float *weight, const float *bias, size_t step, int ic4,
int output_channel, size_t offset, size_t relu, size_t relu6) {
for (int i = 0; i < TILE_NUM; i++) {
@ -108,24 +108,49 @@ int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); }
int8_t MaxInt8(int8_t a, int8_t b) { return a ^ ((a ^ b) & -(a < b)); }
void ReluFp32(float *data, int ele_num) {
for (int i = 0; i < ele_num; i++) {
if (data[i] < 0) {
data[i] = 0;
} else {
// do nothing
}
int four_block = UP_DIV(ele_num, C4NUM);
for (int i = 0; i < four_block - 1; i++) {
int index = i * C4NUM;
#ifdef ENABLE_NEON
float32x4_t relu_data = vld1q_f32(data + index);
float32x4_t zero_data = vdupq_n_f32(0);
relu_data = vmaxq_f32(relu_data, zero_data);
#else
data[index] = data[index] < 0 ? 0 : data[index];
data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2];
data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3];
#endif
}
for (int j = (four_block - 1) * C4NUM; j < ele_num; ++j) {
data[j] = data[j] < 0 ? 0 : data[j];
}
}
void Relu6Fp32(float *data, int ele_num) {
for (int i = 0; i < ele_num; i++) {
if (data[i] < 0) {
data[i] = 0;
} else if (data[i] > 6) {
data[i] = 6;
} else {
// do nothing
}
int four_block = UP_DIV(ele_num, C4NUM);
for (int i = 0; i < four_block - 1; i++) {
int index = i * C4NUM;
#ifdef ENABLE_NEON
float32x4_t relu6_data = vld1q_f32(data + index);
float32x4_t zero_data = vdupq_n_f32(0);
float32x4_t six_data = vdupq_n_f32(6);
relu6_data = vmaxq_f32(relu6_data, zero_data);
relu6_data = vminq_f32(relu6_data, six_data);
#else
data[index] = data[index] < 0 ? 0 : data[index];
data[index] = data[index] > 6 ? 6 : data[index];
data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
data[index + 1] = data[index + 1] > 6 ? 6 : data[index + 1];
data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2];
data[index + 2] = data[index + 2] > 6 ? 6 : data[index + 2];
data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3];
data[index + 3] = data[index + 3] > 6 ? 6 : data[index + 3];
#endif
}
for (int j = (four_block - 1) * C4NUM; j < ele_num; ++j) {
data[j] = data[j] < 0 ? 0 : data[j];
data[j] = data[j] > 6 ? 6 : data[j];
}
}

View File

@ -39,7 +39,7 @@ struct ConvParameter {
int pad_l_;
int pad_r_;
int group_;
int n_dim_;
int tile_num_;
int input_batch_;
int input_h_;
int input_w_;

View File

@ -141,7 +141,7 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
int start_index = thread_id * tile_n;
int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
float16_t *gemm_input =
(float *)(packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset);
(float16_t *)(packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset);
Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);
int out_offset = thread_id * tile_n * out_channel + out_batch_offset;

View File

@ -16,7 +16,7 @@
#include "src/runtime/kernel/arm/opclib/fp32/common_func.h"
#ifndef ENABLE_ARM
#ifndef __aarch64__
void MatrixAdd(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride,
size_t row, size_t col) {
for (int r = 0; r < row; r++) {

View File

@ -31,13 +31,12 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
int thread_count = conv_param->thread_num_;
int tile_n = 8;
int output_count = out_h * out_w;
int output_tile_count = UP_DIV(output_count, tile_n);
int output_tile_count = UP_DIV(output_count, TILE_NUM);
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int unit_size = kernel_plane * ic4 * C4NUM;
int packed_input_size = output_tile_count * tile_n * unit_size;
int packed_input_size = output_tile_count * TILE_NUM * unit_size;
// we accumulate 4 channels per time for input blocks
int conv_depth = kernel_h * kernel_w;
@ -50,13 +49,13 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
int out_batch_offset = b * out_channel * out_h * out_w;
int gemm_in_batch_offset = b * packed_input_size;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
int start_index = thread_id * tile_n;
int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
float *gemm_input = packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset;
int start_index = thread_id * TILE_NUM;
int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
float *gemm_input = packed_input + thread_id * unit_size * TILE_NUM + gemm_in_batch_offset;
Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);
int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
if (real_cal_num == tile_n) {
int out_offset = thread_id * TILE_NUM * out_channel + out_batch_offset;
if (real_cal_num == TILE_NUM) {
float *gemm_output = output_data + out_offset;
IndirectGemmFp32_8x8(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel,
output_offset, 0, 0, conv_param->is_relu_, conv_param->is_relu6_);
@ -121,22 +120,8 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
}
}
// get real output
for (int batch = 0; batch < out_batch; batch++) {
int batch_size = batch * out_channel * conv_param->output_h_ * conv_param->output_w_;
for (int h = 0; h < conv_param->output_h_; h++) {
for (int w = 0; w < conv_param->output_w_; w++) {
for (int c = 0; c < out_channel; c++) {
int oc4_block = c / C4NUM;
int oc4_res = c % C4NUM;
int src_offset = oc4_block * C4NUM * out_w_block * out_h_block * out_unit * out_unit +
C4NUM * (h * out_w_block * out_unit + w) + oc4_res;
int dst_offset = (h * conv_param->output_w_ + w) * out_channel + c;
(output_data + dst_offset)[0] = (tmp_out_data + src_offset)[0];
}
}
}
}
UnPackWinogradOutput(tmp_out_data, output_data, out_batch, conv_param->output_h_, conv_param->output_w_, out_channel,
out_unit);
int output_num = out_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_;
if (is_relu) {
ReluFp32(output_data, output_num);
@ -147,6 +132,45 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
}
}
void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel,
int output_unit) {
int out_h_block_num = UP_DIV(height, output_unit);
int out_w_block_num = UP_DIV(width, output_unit);
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_batch_offset = b * c4 * C4NUM * out_h_block_num * output_unit * out_w_block_num * output_unit;
int dst_batch_offset = b * height * width * channel;
for (int h = 0; h < height; h++) {
int src_h_offset = src_batch_offset + C4NUM * (h * out_w_block_num * output_unit);
int dst_h_offset = dst_batch_offset + h * width * channel;
for (int w = 0; w < width; w++) {
int src_w_offset = src_h_offset + w * C4NUM;
int dst_w_offset = dst_h_offset + w * channel;
for (int c = 0; c < c4 - 1; c++) {
int src_c4_offset = src_w_offset + c * C4NUM * out_w_block_num * out_h_block_num * output_unit * output_unit;
int dst_c4_offset = dst_w_offset + c * C4NUM;
#ifdef ENABLE_NEON
vst1q_f32(dst + dst_c4_offset, vld1q_f32(src + src_c4_offset));
#else
dst[dst_c4_offset] = src[src_c4_offset];
dst[dst_c4_offset + 1] = src[src_c4_offset + 1];
dst[dst_c4_offset + 2] = src[src_c4_offset + 2];
dst[dst_c4_offset + 3] = src[src_c4_offset + 3];
#endif
}
int c_res = channel - (c4 - 1) * C4NUM;
int src_c_res_offset = (c4 - 1) * C4NUM * out_w_block_num * out_h_block_num * output_unit * output_unit;
int dst_c_res_offset = (c4 - 1) * C4NUM;
for (int c = 0; c < c_res; c++) {
int src_c4_res_offset = src_w_offset + src_c_res_offset + c;
int dst_c4_res_offset = dst_w_offset + dst_c_res_offset + c;
dst[dst_c4_res_offset] = src[src_c4_res_offset];
}
}
}
}
}
// fp32 conv3x3
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param) {
@ -182,7 +206,7 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
}
PackNC4HW4ToNHWCFp32(nc4hw4_out, output_data, 1, conv_param->output_h_ * conv_param->output_w_, output_channel);
}
int output_num = oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_;
int output_num = output_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_;
if (is_relu) {
ReluFp32(output_data, output_num);
} else if (is_relu6) {
@ -191,4 +215,3 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
// do nothing
}
}

View File

@ -42,10 +42,10 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param,
InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func);
void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, int output_unit);
// fp32 conv3x3
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id,
ConvParameter *conv_param);
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_CONV_H_

View File

@ -81,8 +81,9 @@ void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
} // win_w loop
} // win_h loop
#ifdef ENABLE_NEON
float32x4_t dup_count = vdupq_n_f32(real_count);
vst1q_f32(output_ptr + out_channel_offset, vdivq_f32(tmp_avg, dup_count));
float reverse_count = 1 / real_count;
float32x4_t dup_count = vdupq_n_f32(reverse_count);
vst1q_f32(output_ptr + out_channel_offset, vmulq_f32(tmp_avg, dup_count));
#else
*(output_ptr + out_channel_offset) = tmp_avg1 / (float)real_count;
*(output_ptr + out_channel_offset + 1) = tmp_avg2 / (float)real_count;

View File

@ -24,11 +24,6 @@ void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *
size_t oc4, size_t offset);
#ifdef ENABLE_ARM64
// void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, size_t
// ksize,
// size_t ic4, size_t output_channel, size_t offset, const int32_t *input_sum,
// size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier, size_t
// shift_before, size_t shift_after);
void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize,
size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min,
size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before,
@ -54,8 +49,9 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
#ifdef __aarch64__
IndirectGemmInt8_4x4(dst, src, weight, bias, kernel_plane, ic4, output_channel, output_channel * sizeof(int8_t),
input_sum, act_min, act_max, out_zp, out_multiplier, shift_before, shift_after);
// todo arm32
#else
int tile_num = 4;
int tile_num = conv_param->tile_num_;
int plane_c4 = UP_DIV(kernel_plane, C4NUM);
for (int oc = 0; oc < output_channel; oc++) {
int oc4_block = oc / C4NUM;
@ -109,7 +105,7 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const
act_min, act_max, out_zp, out_multiplier, shift_before, shift_after);
#endif
} else {
int tile_num = 24;
int tile_num = conv_param->tile_num_;
for (int oc = 0; oc < output_channel; oc++) {
int oc4_block = oc / C4NUM;
int oc4_res = oc % C4NUM;
@ -202,7 +198,7 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
int out_channel = conv_param->output_channel_;
int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_;
int tile_n = 4;
int tile_n = conv_param->tile_num_;
int thread_count = conv_param->thread_num_;
int output_count = out_h * out_w;
int output_tile_count = UP_DIV(output_count, tile_n);
@ -255,9 +251,7 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_;
// todo
int tile_n = 24;
int tile_n = conv_param->tile_num_;
int thread_count = conv_param->thread_num_;
int output_count = out_h * out_w;
int output_tile_count = UP_DIV(output_count, tile_n);
@ -302,7 +296,6 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
int task_id, ConvParameter *conv_param) {
// todo
int thread_count = conv_param->thread_num_;
int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);
int output_batch = conv_param->output_batch_;
@ -331,8 +324,5 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi
}
// get real output
for (int batch = 0; batch < output_batch; batch++) {
// int batch_size = batch * output_channel * output_h * output_w;
C4UnpackToHwcInt8(tmp_out, output_data, output_channel, output_h, output_w);
}
PackNC4HW4ToNHWCInt8(tmp_out, output_data, output_batch, output_h * output_w, output_channel);
}

View File

@ -16,7 +16,6 @@
#include <stdlib.h>
// todo
extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
size_t ksize, size_t ic4, size_t output_channel, size_t offset,
const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,

View File

@ -345,6 +345,7 @@ void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, i
void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed_weight) {
// original weight format : ohwi
// todo pack weight for arm32 platform
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int in_channel = conv_param->input_channel_;
@ -352,7 +353,7 @@ void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed
int oc8 = UP_DIV(out_channel, C8NUM);
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane;
int pack_weight_size = oc8 * C8NUM * ic4 * C4NUM * kernel_plane;
int unit_size = C8NUM * C4NUM;
int block_size = pack_weight_size / oc8;
@ -565,7 +566,7 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,
int32_t *input_sum, ConvParameter *conv_param) {
// input format : nhwc
int tile_num = 4;
int tile_num = conv_param->tile_num_;
int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_;
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
@ -624,7 +625,7 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,
int32_t *input_sum, ConvParameter *conv_param) {
// input format : nhwc
int tile_num = 24;
int tile_num = conv_param->tile_num_;
int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_;
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
@ -980,15 +981,23 @@ void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * c4 * C4NUM;
int dst_offset = b * plane * channel;
for (int c = 0; c < channel; c++) {
int c4_block_num = c / C4NUM;
int c4_block_res = c % C4NUM;
int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
int dst_c_offset = dst_offset + c;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k * C4NUM;
int dst_kernel_offset = dst_c_offset + k * channel;
((uint8_t *)dst + dst_kernel_offset)[0] = ((uint8_t *)src + src_kernel_offset)[0];
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_offset + k * C4NUM;
int dst_kernel_offset = dst_offset + k * channel;
for (int c = 0; c < c4 - 1; c++) {
int src_c_offset = src_kernel_offset + c * plane * C4NUM;
int dst_c_offset = dst_kernel_offset + c * C4NUM;
((int8_t *)dst + dst_c_offset)[0] = ((int8_t *)src + src_c_offset)[0];
((int8_t *)dst + dst_c_offset)[1] = ((int8_t *)src + src_c_offset)[1];
((int8_t *)dst + dst_c_offset)[2] = ((int8_t *)src + src_c_offset)[2];
((int8_t *)dst + dst_c_offset)[3] = ((int8_t *)src + src_c_offset)[3];
}
// res part
int res_c = channel - (c4 - 1) * C4NUM;
for (int i = 0; i < res_c; i++) {
int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i;
int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i;
((int8_t *)dst + dst_res_c_offset)[0] = ((int8_t *)src + src_res_c_offset)[0];
}
}
}

View File

@ -122,52 +122,4 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter
void PackDepthwiseInt8Weight(const int8_t *src, int16_t *dst, const ConvParameter *conv_param);
inline void UnpackHwcToChwFp32(float *src_ptr, float *dst_ptr, int channel, int h, int w) {
int cur = 0;
for (int i = 0; i < channel; i++) {
auto plane = i / BLOCK;
auto offset = i % BLOCK;
auto src_plane = plane * h * w * BLOCK + src_ptr;
for (int j = 0; j < h * w; j++) {
dst_ptr[cur++] = src_plane[j * BLOCK + offset];
}
}
}
inline void C8UnpackToHwcFp32(float *src_ptr, float *dst_ptr, int channel, int h, int w) {
int cur = 0;
for (int j = 0; j < h * w; j++) {
for (int i = 0; i < channel; i++) {
auto plane = i / 8;
auto offset = i % 8;
auto src_plane = plane * h * w * 8 + src_ptr;
dst_ptr[cur++] = src_plane[j * 8 + offset];
}
}
}
inline void C4UnpackToHwcFp32(float *src_ptr, float *dst_ptr, int channel, int h, int w) {
int cur = 0;
for (int j = 0; j < h * w; j++) {
for (int i = 0; i < channel; i++) {
auto plane = i / 4;
auto offset = i % 4;
auto src_plane = plane * h * w * 4 + src_ptr;
dst_ptr[cur++] = src_plane[j * 4 + offset];
}
}
}
inline void C4UnpackToHwcInt8(int8_t *src_ptr, int8_t *dst_ptr, int channel, int h, int w) {
int cur = 0;
for (int j = 0; j < h * w; j++) {
for (int i = 0; i < channel; i++) {
auto plane = i / 4;
auto offset = i % 4;
auto src_plane = plane * h * w * 4 + src_ptr;
dst_ptr[cur++] = src_plane[j * 4 + offset];
}
}
}
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PACK_H_

View File

@ -17,659 +17,43 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_QUANTIZATION_FIXED_POINT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_QUANTIZATION_FIXED_POINT_H_
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <limits.h>
#include "include/infer_log.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
// Part 1: Low-level integer-arithmetic primitives.
// The implementations here are generic implementations valid for
// scalar types (e.g. std::int32_t). Architecture-specific SIMD types
// (e.g. NEON int32x4_t) may be supported by providing
// specializations for them in separate files.
//
// The purpose of these primitives is two-fold:
// - They will be used to implement higher-level fixed-point
// abstractions, namely the FixedPoint class and its arithmetic
// operators.
// - They will be directly used to implement some more involved
// fixed-point computations, e.g. the fixed-point implementation
// of math functions such as tanh.
// Some compile-time traits around raw types to handle SIMD aspects:
// number of lanes, underlying scalar type.
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};
template <>
struct FixedPointRawTypeTraits<std::int32_t> {
typedef std::int32_t ScalarRawType;
static constexpr int kLanes = 1;
};
template <>
struct FixedPointRawTypeTraits<std::int16_t> {
typedef std::int16_t ScalarRawType;
static constexpr int kLanes = 1;
};
// Returns a SIMD value duplicating a scalar value across all lanes.
template <typename tRawType>
tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
return x;
// returns the high-32 bits of a * b with rounding
// assume that a and b is divided by 2^31, who fall into [-1, 1]
// so the mantissa of a * b is (a / 2^31) * (b / 2^31) * 2^31= (a * b) / 2^31
// actually we compute 2 * a * b / 2^32
// and take 32 bits of mantissa for rounding
inline int SaturatingRoundingDoublingHighMul(int a, int b) {
if (a == INT_MIN && b == INT_MIN) {
return INT_MAX;
}
int64_t ab = ((int64_t)a) * ((int64_t)b);
int64_t rounding = ab >= 0 ? (1ll << 30) : (1ll - (1ll << 30));
// do not apply right shift to potential negetive values
int ab_mantissa = (int) ((ab + rounding) / (1ll << 31));
return ab_mantissa;
}
// Plain bit-wise AND
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
return a & b;
}
// Plain bit-wise OR
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
return a | b;
}
// Plain bit-wise XOR
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
return a ^ b;
}
// Plain bit-wise NOT
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
return ~a;
}
// Integer addition. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
return a + b;
}
// Integer multiplication. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
return a * b;
}
// Integer subtraction. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
return a - b;
}
// Integer unary negative. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
return -a;
}
// Integer arithmetic left-shift, equivalent to multiplying with a power of two.
// Negative values are OK. In case of overflow, no Undefined
// Behavior, but the results are implementation-defined (in practice,
// they currently are saturated, but we make no commitment to that). The idea
// is that the caller will want to implement the overflowing cases with
// saturation with compare-and-mask, so we don't care about the results
// in the overflow case, we just want to avoid undefined behavior.
//
// tIntegerType may be int32 or any narrower signed type.
template <typename tIntegerType, typename OffsetType>
tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) {
const std::int64_t wide_a = (std::int64_t)(a);
const std::int64_t wide_shifted = wide_a * (1 << offset);
const auto min = std::numeric_limits<tIntegerType>::min();
const auto max = std::numeric_limits<tIntegerType>::max();
return wide_shifted < min ? min : wide_shifted > max ? max : (tIntegerType)(wide_shifted);
}
// Integer arithmetic right-shift. Not rounding.
// Relying on implementation-defined, but in-practice-consistent,
// C++ compiler behavior.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
return a >> offset;
}
// Each bit of the result is set to the corresponding bit of either then_val or
// else_val depending on whether the corresponding bit of if_mask is set.
// Equivalent to the VBSL instruction in ARM NEON.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val, tIntegerType else_val) {
return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val));
}
// For each input scalar, the corresponding bits of the result are set if the
// input scalar is non-zero.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
static constexpr tIntegerType zero = 0;
return a ? BitNot(zero) : zero;
}
// For each input scalar, the corresponding bits of the result are set if the
// input scalar is zero.
template <typename tIntegerType>
tIntegerType MaskIfZero(tIntegerType a) {
return MaskIfNonZero<tIntegerType>(!a);
}
// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars are equal.
template <typename tIntegerType>
tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
return MaskIfNonZero<tIntegerType>(a == b);
}
// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars are not equal.
template <typename tIntegerType>
tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
return MaskIfNonZero<tIntegerType>(a != b);
}
// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a > b.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
return MaskIfNonZero<tIntegerType>(a > b);
}
// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a >= b.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
return MaskIfNonZero<tIntegerType>(a >= b);
}
// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a < b.
template <typename tIntegerType>
tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
return MaskIfNonZero<tIntegerType>(a < b);
}
// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a <= b.
template <typename tIntegerType>
tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
return MaskIfNonZero<tIntegerType>(a <= b);
}
// Returns true if all of the input scalars are nonzero.
// This function may currently assume that each of the input scalars has either
// all or none of its bits set. Otherwise, its behavior is currently undefined.
template <typename tIntegerType>
bool All(tIntegerType a) {
return a;
}
// Returns true if any of the input scalars are nonzero.
// This function may currently assume that each of the input scalars has either
// all or none of its bits set. Otherwise, its behavior is currently undefined.
template <typename tIntegerType>
bool Any(tIntegerType a) {
return a;
}
// Returns (a+b)/2, rounded to the nearest integer.
// Equivalent to VRHADD in the ARM NEON instruction set.
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
(void)b;
return a;
}
template <>
inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
std::int64_t a64 = a;
std::int64_t b64 = b;
std::int64_t sum = a64 + b64;
std::int64_t sign = sum >= 0 ? 1 : -1;
return (std::int32_t)((sum + sign) / 2);
}
template <>
inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
std::int32_t a32 = a;
std::int32_t b32 = b;
std::int32_t sum = a32 + b32;
std::int32_t sign = sum >= 0 ? 1 : -1;
return (std::int16_t)((sum + sign) / 2);
}
template <typename IntegerType>
IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
(void)b;
return a;
}
// So far this is only needed for int16.
template <>
inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
std::int32_t a32 = a;
std::int32_t b32 = b;
std::int32_t sum = a32 + b32;
return (std::int16_t)(std::min((std::int32_t)(32767), std::max((std::int32_t)(-32768), sum)));
}
template <>
inline std::int8_t SaturatingAdd(std::int8_t a, std::int8_t b) {
std::int16_t a16 = a;
std::int16_t b16 = b;
std::int16_t sum = a16 + b16;
return (std::int8_t)(std::min((int16_t)(std::numeric_limits<int8_t>::max()),
std::max((int16_t)(std::numeric_limits<int8_t>::min()), sum)));
}
// Returns a+b, saturating if the integers are 16bit or narrower,
// otherwise just a plain addition.
template <typename IntegerType, bool Is16Bit>
struct AddSaturatingIf16BitImpl {
static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }
};
template <typename IntegerType>
struct AddSaturatingIf16BitImpl<IntegerType, true> {
static IntegerType Run(IntegerType a, IntegerType b) { return SaturatingAdd(a, b); }
};
template <typename IntegerType>
IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
using ScalarType = typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a, b);
}
// Returns the integer that represents the product of two fixed-point
// numbers, interpreting all integers as fixed-point values in the
// interval [-1, 1), rounding to the nearest value, and saturating
// -1 * -1 to the maximum value (since 1 is not in the half-open
// interval [-1, 1)).
//
// [The explanation below specializes to std::int32_t for example purpose.]
//
// The mapping between IntegerType and the interval [-1, 1) is unique and
// implied by IntegerType, which is assumed to be signed. For example,
// for IntegerType==std::int32_t, the mapping is
// real_value = integer_value / 2^31.
// So in this case, and leaving aside rounding and saturating, this
// function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to
// (a * b) / 2^31.
//
// The 'doubling' part in the name of this function comes from the fact that
// this operation is very close to a "multiply-high" operation, keeping only
// the top half bits, except that that would be effectively computing
// (a * b) / 2^32,
// so here we are computing 2x that, since
// 1/2^31 = 2 * 1/2^32.
// The idea is to use all of the available 32 bits in the destination int32
// value.
//
// [End of the explanation specializing to int32.]
//
// This is equivalent to the VQRDMULH instruction in ARM NEON.
template <typename IntegerType>
IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
(void)b;
return a;
}
// This function implements the same computation as the ARMv7 NEON VQRDMULH
// instruction.
template <>
inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, std::int32_t b) {
bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
std::int64_t a_64(a);
std::int64_t b_64(b);
std::int64_t ab_64 = a_64 * b_64;
std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
std::int32_t ab_x2_high32 = (std::int32_t)((ab_64 + nudge) / (1ll << 31));
return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
}
template <>
inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a, std::int16_t b) {
bool overflow = a == b && a == std::numeric_limits<std::int16_t>::min();
std::int32_t a_32(a);
std::int32_t b_32(b);
std::int32_t ab_32 = a_32 * b_32;
std::int16_t nudge = ab_32 >= 0 ? (1 << 14) : (1 - (1 << 14));
std::int16_t ab_x2_high16 = (std::int16_t)((ab_32 + nudge) / (1 << 15));
return overflow ? std::numeric_limits<std::int16_t>::max() : ab_x2_high16;
}
// Correctly-rounded-to-nearest division by a power-of-two.
// Also known as a rounding arithmetic right shift.
template <typename IntegerType, typename ExponentType>
inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) {
assert(exponent >= 0);
assert(exponent <= 31);
const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
const IntegerType zero = Dup<IntegerType>(0);
const IntegerType one = Dup<IntegerType>(1);
const IntegerType remainder = BitAnd(x, mask);
const IntegerType threshold = Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one));
return Add(ShiftRight(x, exponent), BitAnd(MaskIfGreaterThan(remainder, threshold), one));
// division by a 2^exponent with rounding
// or arithmetic right shift with rouding
inline int RoundingDivideByPOT(int x, int exponent) {
MS_ASSERT(exponent >= 0);
MS_ASSERT(exponent <= 31);
const int mask = (1ll << exponent) - 1;
const int remainder = x & mask;
const int threshold = (mask >> 1) + (x < 0 ? 1 : 0);
return (x >> exponent) + (remainder > threshold ? 1 : 0);
}
inline int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) {
return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift);
}
// Returns the product of a run-time integer value by a compile-time power
// of two, with either a positive exponent (equivalent to an arithmetic
// left shift, saturating) or a negative exponent (equivalent to an arithmetic
// right shift, rounding to nearest).
template <int Exponent, typename IntegerType, int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
struct ImplSaturatingRoundingMultiplyByPOT {};
template <int Exponent, typename IntegerType>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
static IntegerType eval(IntegerType x) { return x; }
};
template <int Exponent, typename IntegerType>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
static IntegerType eval(IntegerType x) {
using ScalarIntegerType = typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
const IntegerType min = Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
const IntegerType max = Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
const std::int32_t threshold = ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1);
const IntegerType positive_mask = MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
const IntegerType negative_mask = MaskIfLessThan(x, Dup<IntegerType>(-threshold));
IntegerType result = ShiftLeft(x, Exponent);
result = SelectUsingMask(positive_mask, max, result);
result = SelectUsingMask(negative_mask, min, result);
return result;
}
};
template <int Exponent, typename IntegerType>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> {
static IntegerType eval(IntegerType x) { return RoundingDivideByPOT<IntegerType>(x, -Exponent); }
};
template <int Exponent, typename IntegerType>
IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
}
// Part 2: the FixedPoint class.
// A FixedPoint object represents a fixed-point value stored in the underlying
// integer type tRawType, if tRawType is a plain scalar integer type.
// Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which
// case a FixedPoint object represents a corresponding SIMD vector of fixed
// point values.
//
// tIntegerBits describes the range of the fixed-point format: if
// tIntegerBits == m then the range of representable values is the half-open
// interval [-2^m; 2^m) where the open boundary on the right side means that
// 2^m is not representable (how close the maximum representable value is to
// it, depends on bit-depth of tRawType).
//
// In "Q format notation",
// https://en.wikipedia.org/wiki/Q_(number_format)
// we are describing the format
// Qm.n
// where
// m = tIntegerBits
// and
// n = NumberOfBits(tRawType) - (m + 1)
// Note that the (m + 1) in the above line is because we adopt the convention
// that we count the integer bits exclusively of the sign bit; so (m + 1) is
// the total number of integer bits inclusive of the sign bit.
//
// Accordingly, the number of integral representable values in our range
// [-2^m ; 2^m)
// is equal to 2^(m+1).
template <typename tRawType, int tIntegerBits>
class FixedPoint {
public:
typedef tRawType RawType;
typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
static constexpr int kTotalBits = 8 * sizeof(ScalarRawType);
static constexpr int kIntegerBits = tIntegerBits;
static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits;
static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, "bad IntegerBits");
typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;
static const ScalarRawType ScalarRawMin() { return std::numeric_limits<ScalarRawType>::min(); }
static const ScalarRawType ScalarRawMax() { return std::numeric_limits<ScalarRawType>::max(); }
static const ScalarRawType RawMin() { return VectorFromScalar(ScalarRawMin()); }
static const ScalarRawType RawMax() { return VectorFromScalar(ScalarRawMax()); }
static FixedPoint FromRaw(RawType x) {
FixedPoint retval;
retval.raw() = x;
return retval;
}
static FixedPoint FromScalarRaw(ScalarRawType x) {
FixedPoint retval;
retval.raw() = Dup<RawType>(x);
return retval;
}
static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) { return FromScalarRaw(x.raw()); }
template <int Exponent>
static FixedPoint ConstantPOT() {
static constexpr int kOffset = kFractionalBits + Exponent;
static_assert(kOffset < 31, "Constant not exactly representable in this fixed-point format");
return FromScalarRaw(ScalarRawType(1) << kOffset);
}
static FixedPoint Zero() { return FromScalarRaw(0); }
static FixedPoint One() {
return FromScalarRaw(kIntegerBits == 0 ? ScalarRawMax()
: (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));
}
static FixedPoint FromDouble(double x) {
const double min_bound = (double)(ScalarRawMin());
const double max_bound = (double)(ScalarRawMax());
return FromScalarRaw(
(ScalarRawType)(std::min(std::max(round(x * (double)(1ll << kFractionalBits)), min_bound), max_bound)));
}
RawType raw() const { return i_; }
RawType &raw() { return i_; }
private:
RawType i_;
};
// Part 3: implementation of arithmetic operators for the
// FixedPoint class, and a few related functions.
// A FixedPoint multiplication is just a
// SaturatingRoundingDoublingHighMul operation on the underlying
// raw integer values. The IntegerBits simply add up, as is obvious
// from the fact that the range is [-2^IntegerBits, 2^IntegerBits).
template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(FixedPoint<tRawType, tIntegerBits_a> a,
FixedPoint<tRawType, tIntegerBits_b> b) {
FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
return c;
}
// Tweaking IntegerBits gives exact multiplication by a power of two.
template <int tExponent, typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(FixedPoint<tRawType, tIntegerBits> a) {
FixedPoint<tRawType, tExponent + tIntegerBits> c;
c.raw() = a.raw();
return c;
}
// If we want to leave IntegerBits fixed, then multiplication
// by a power of two has to be saturating/rounding, not exact anymore.
template <int tExponent, typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(FixedPoint<tRawType, tIntegerBits> a) {
return FixedPoint<tRawType, tIntegerBits>::FromRaw(SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
}
// Generic arithmetic operators.
#define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName) \
template <typename tRawType, int tIntegerBits> \
FixedPoint<tRawType, tIntegerBits> FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
}
#define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
template <typename tRawType, int tIntegerBits> \
FixedPoint<tRawType, tIntegerBits> FuncName(FixedPoint<tRawType, tIntegerBits> a, \
FixedPoint<tRawType, tIntegerBits> b) { \
return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw(), b.raw())); \
}
MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)
#undef MAKE_FIXEDPOINT_UNARY_FUNC
#undef MAKE_FIXEDPOINT_BINARY_FUNC
#define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName) \
template <typename tRawType, int tIntegerBits> \
tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
return FuncName(a.raw()); \
}
#define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
template <typename tRawType, int tIntegerBits> \
tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a, FixedPoint<tRawType, tIntegerBits> b) { \
return FuncName(a.raw(), b.raw()); \
}
MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)
#undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
#undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> SelectUsingMask(tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
FixedPoint<tRawType, tIntegerBits> else_val) {
return FixedPoint<tRawType, tIntegerBits>::FromRaw(SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
}
template <typename tRawType, int tIntegerBits>
bool operator==(FixedPoint<tRawType, tIntegerBits> a, FixedPoint<tRawType, tIntegerBits> b) {
return All(MaskIfEqual(a.raw(), b.raw()));
}
template <typename tRawType, int tIntegerBits>
bool operator!=(FixedPoint<tRawType, tIntegerBits> a, FixedPoint<tRawType, tIntegerBits> b) {
return !(a == b);
}
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> SaturatingAdd(FixedPoint<tRawType, tIntegerBits> a,
FixedPoint<tRawType, tIntegerBits> b) {
return FixedPoint<tRawType, tIntegerBits>::FromRaw(SaturatingAdd(a.raw(), b.raw()));
}
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(FixedPoint<tRawType, tIntegerBits> a,
FixedPoint<tRawType, tIntegerBits> b) {
return FixedPoint<tRawType, tIntegerBits>::FromRaw(AddSaturatingIf16Bit(a.raw(), b.raw()));
}
// Conversion to floating-point.
template <typename tRawType, int tIntegerBits>
double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1, "not applicable to SIMD types");
typedef FixedPoint<tRawType, tIntegerBits> F;
return x.raw() / (double)(1ll << F::kFractionalBits);
}
// Rescale changes the number of IntegerBits and updates the underlying
// raw integer value accordingly.
template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
FixedPoint<tRawType, tIntegerBitsDst> Rescale(FixedPoint<tRawType, tIntegerBitsSrc> x) {
static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
FixedPoint<tRawType, tIntegerBitsDst> result;
result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
return result;
}
// CheckedFixedPointConstant allows to specify fixed-point constants
// initialized as real numbers, in a way that does not compile floating-point
// arithmetic in production code, yet still checks agreement with the
// floating-point expressions when asserts are enabled.
//
// The raw integer value provided is always a int32, encoding a 32-bit
// fixed-point value, regardless of the actual Scalar type. This allows
// writing generic code that applies just as well to the 32-bit and 16-bit
// cases. In the 16-bit case, the raw integer value is internally
// rounding-shifted by 16 bits to the right.
template <typename FixedPointType>
inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(std::int32_t int32_value) {
typedef typename FixedPointType::ScalarRawType ScalarRawType;
static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType);
return (ScalarRawType)(RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits));
}
// Implementation of exponential function.
// Returns -tanh(x) for x < 0.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> neg_tanh_on_negative_values(FixedPoint<tRawType, tIntegerBits> a) {
return one_minus_x_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(ExactMulByPot<1>(a)));
}
// Returns tanh(x) for any x.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
typedef FixedPoint<tRawType, tIntegerBits> InputF;
typedef FixedPoint<tRawType, 0> ResultF;
tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
tRawType mask_if_zero = MaskIfZero(a);
InputF n = SelectUsingMask(mask_if_negative, a, -a);
ResultF t = neg_tanh_on_negative_values(n);
return SelectUsingMask(mask_if_zero, ResultF::Zero(), SelectUsingMask(mask_if_negative, -t, t));
}
// Implementation of logistic function.
// Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> logistic_on_positive_values(FixedPoint<tRawType, tIntegerBits> a) {
return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a));
}
#ifdef ENABLE_NEON
inline int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent) {
const int32x4_t shift_vec = vdupq_n_s32(-exponent);