fix bug && optimize relu/relu6/winograd unpack func
This commit is contained in:
parent 31e889881b
commit 0e5d30f917
@@ -113,6 +113,9 @@ if (BUILD_DEVICE)
if (PLATFORM_ARM64)
    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
    add_compile_definitions(ENABLE_ARM64)
    if (ENABLE_FP16)
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")
    endif ()
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/benchmark)
@@ -1,15 +1,16 @@
-file(GLOB_RECURSE KERNEL_SRC
+file(GLOB KERNEL_SRC
    ${CMAKE_CURRENT_SOURCE_DIR}/base/*.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/opclib/*.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/opclib/fp32/*.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/opclib/int8/*.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/opclib/quantization/*.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/fp32/*.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc
)

if (PLATFORM_ARM64)
    # assembly
-   file(GLOB_RECURSE ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm64/*.s
+   file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm64/*.s
+       ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm64/*.S)
    set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
    set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})

@@ -17,13 +18,15 @@ endif()

if (PLATFORM_ARM32)
    # assembly
-   file(GLOB_RECURSE ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm32/*.s)
+   file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm32/*.s
+       ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm32/*.S
+   )
    set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
    set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})
endif()

if (ENABLE_FP16)
-   file(GLOB_RECURSE FP6_SRC
+   file(GLOB FP6_SRC
        ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/opclib/fp16/*.cc
    )
@@ -101,7 +101,6 @@ void ConvolutionBaseCPUKernel::FreeQuantParam() {
int ConvolutionBaseCPUKernel::Init() {
  auto input = this->inputs_.front();
  auto output = this->outputs_.front();

  conv_param_->input_batch_ = input->Batch();
  conv_param_->input_h_ = input->Height();
  conv_param_->input_w_ = input->Width();

@@ -111,7 +110,6 @@ int ConvolutionBaseCPUKernel::Init() {
  conv_param_->output_w_ = output->Width();
  conv_param_->output_channel_ = output->Channel();
  conv_param_->thread_num_ = ctx_->threadNum;

  return RET_OK;
}

@@ -221,9 +219,24 @@ void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *con
  }
}

-kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
-                                             const std::vector<lite::tensor::Tensor *> &outputs,
-                                             OpParameter *opParameter, const Context *ctx) {
+bool CheckSupportFP16() {
+  bool support_fp16 = false;
+#ifdef ENABLE_ARM64
+  void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
+  if (optimize_op_handler != nullptr) {
+    support_fp16 = true;
+    MS_LOG(INFO) << "Support FP16.";
+  } else {
+    support_fp16 = false;
+    MS_LOG(INFO) << "Your machine doesn't support fp16; falling back to the float32 kernel.";
+  }
+#endif
+  return support_fp16;
+}
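For context, this is the capability-probe-then-dispatch pattern the creators below build on. A minimal self-contained sketch with stand-in factories (no name here besides CheckSupportFP16 comes from this commit; the probe body is stubbed):

#include <cstdio>
#include <memory>

struct Kernel { virtual ~Kernel() = default; };  // stand-in kernel type
struct Fp16Kernel : Kernel {};                   // hypothetical
struct Fp32Kernel : Kernel {};                   // hypothetical

bool CheckSupportFP16() { return false; }        // stub for the runtime probe above

std::unique_ptr<Kernel> CreateConvKernel() {
  // Prefer fp16 only when both the build flag and the runtime probe agree,
  // otherwise fall back to the fp32 kernel -- mirroring the creators below.
#ifdef ENABLE_FP16
  if (CheckSupportFP16()) {
    return std::make_unique<Fp16Kernel>();
  }
#endif
  return std::make_unique<Fp32Kernel>();
}

int main() {
  auto k = CreateConvKernel();
  std::printf("kernel created: %p\n", static_cast<void *>(k.get()));
  return 0;
}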
+kernel::LiteKernel *CpuConvFloatKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
+                                              const std::vector<lite::tensor::Tensor *> &outputs,
+                                              OpParameter *opParameter, const Context *ctx) {
  auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;

@@ -240,44 +253,35 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
  InputTransformUnitFunc input_trans_func = nullptr;
  OutputTransformUnitFunc output_trans_func = nullptr;
  CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
+  bool support_fp16 = CheckSupportFP16();

  if (kernel_h == 1 && kernel_w == 1) {
    auto kernel = new (std::nothrow) Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx);
    return kernel;
  } else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
    if (support_fp16) {
#ifdef ENABLE_FP16
      auto kernel = new (std::nothrow) Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx);
      return kernel;
#endif
    }
    auto kernel = new (std::nothrow) Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx);
    return kernel;
  } else if (use_winograd) {
    auto kernel = new (std::nothrow) ConvolutionWinogradCPUKernel(opParameter, inputs, outputs, ctx, out_unit);
    return kernel;
  } else {
    if (support_fp16) {
#ifdef ENABLE_FP16
      auto kernel = new (std::nothrow) ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx);
      return kernel;
#endif
    }
    auto kernel = new (std::nothrow) ConvolutionCPUKernel(opParameter, inputs, outputs, ctx);
    return kernel;
  }
}

#ifdef ENABLE_FP16
kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                             const std::vector<lite::tensor::Tensor *> &outputs,
                                             OpParameter *opParameter, const Context *ctx) {
  auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  int stride_h = conv_param->stride_h_;
  int stride_w = conv_param->stride_w_;
  int dilation_h = conv_param->dilation_h_;
  int dilation_w = conv_param->dilation_w_;

  if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
    auto kernel = new (std::nothrow) Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx);
    return kernel;
  } else {
    auto kernel = new (std::nothrow) ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx);
    return kernel;
  }
}
#endif
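One caveat worth noting on the creators above: `new (std::nothrow)` yields nullptr instead of throwing, and the result is returned unchecked. A hedged sketch of the check a caller would typically add (DummyKernel is a stand-in, not a class from this commit):

#include <cstdio>
#include <new>

struct DummyKernel { int id = 0; };  // stand-in for the kernel classes above

DummyKernel *CreateChecked() {
  auto *kernel = new (std::nothrow) DummyKernel();
  if (kernel == nullptr) {  // nothrow new reports failure via nullptr, not an exception
    std::fprintf(stderr, "kernel allocation failed\n");
    return nullptr;
  }
  return kernel;
}

int main() {
  DummyKernel *k = CreateChecked();
  if (k != nullptr) {
    std::printf("ok\n");
    delete k;
  }
  return 0;
}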
kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                             const std::vector<lite::tensor::Tensor *> &outputs,
                                             OpParameter *opParameter, const Context *ctx) {

@@ -308,17 +312,10 @@ kernel::LiteKernel *CpuConvKernelCreator(const std::vector<lite::tensor::Tensor
  kernel::LiteKernel *kernel = nullptr;
  switch (data_type) {
    case kNumberTypeInt8:
      break;
    case kNumberTypeUInt8:
      kernel = CpuConvInt8KernelCreator(inputs, outputs, opParameter, ctx);
      break;
#ifdef ENABLE_FP16
    case kNumberTypeFloat16:
      kernel = CpuConvFp16KernelCreator(inputs, outputs, opParameter, ctx);
      break;
#endif
    case kNumberTypeFloat32:
-     kernel = CpuConvFp32KernelCreator(inputs, outputs, opParameter, ctx);
+     kernel = CpuConvFloatKernelCreator(inputs, outputs, opParameter, ctx);
      break;
    default:
      break;

@@ -385,8 +382,6 @@ kernel::LiteKernel *CpuConvDwKernelCreator(const std::vector<lite::tensor::Tenso
    case kNumberTypeInt8:
      kernel = CpuConvDwInt8KernelCreator(inputs, outputs, opParameter, ctx);
      break;
    case kNumberTypeUInt8:
      break;
    case kNumberTypeFloat32:
#ifdef ENABLE_FP16
      kernel = CpuConvDwFp16KernelCreator(inputs, outputs, opParameter, ctx);

@@ -515,8 +510,6 @@ kernel::LiteKernel *CpuDeConvKernelCreator(const std::vector<lite::tensor::Tenso
  kernel::LiteKernel *kernel = nullptr;
  switch (data_type) {
    case kNumberTypeInt8:
      break;
    case kNumberTypeUInt8:
      kernel = CpuDeConvInt8KernelCreator(inputs, outputs, opParameter, ctx);
      break;
#ifdef ENABLE_FP16
@@ -26,10 +26,9 @@
#include <android/log.h>
#endif
#include "src/lite_kernel.h"
#include "include/context.h"
#include "src/runtime/kernel/arm/base/layout_transform.h"
#include "src/runtime/kernel/arm/opclib/optimized_kernel.h"

using mindspore::lite::Context;
using mindspore::schema::PadMode;

@@ -40,7 +39,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
 public:
  ConvolutionBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                           const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) {
    opParameter->thread_num_ = ctx->threadNum;
    conv_param_ = reinterpret_cast<ConvParameter *>(opParameter);
  }

@@ -63,6 +62,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
  LayoutConvertor convert_func_;
};
void ComputeQuantOutRange(ConvParameter *conv_param);
+bool CheckSupportFP16();
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONVOLUTION_BASE_H_
@@ -49,6 +49,8 @@ void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvPara
int Convolution3x3FP16CPUKernel::InitWeightBias() {
  auto input_channel = conv_param_->input_channel_;
  int output_channel = conv_param_->output_channel_;
  int kernel_h = conv_param_->kernel_h_;
  int kernel_w = conv_param_->kernel_w_;
  int iC4 = UP_DIV(input_channel, C4NUM);
  int oC8 = UP_DIV(output_channel, C8NUM);
  // init weight

@@ -60,8 +62,8 @@ int Convolution3x3FP16CPUKernel::InitWeightBias() {
  }
  memset(transformed_filter_addr_, 0, transformed_size);
  float *origin_weight = reinterpret_cast<float *>(inputs_.at(kWeightIndex)->Data());
-  size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t);
-  fp16_weight_ = malloc(fp16_weight_size);
+  size_t fp16_weight_size = input_channel * output_channel * kernel_h * kernel_w * sizeof(float16_t);
+  fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
  if (fp16_weight_ == nullptr) {
    MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
    return RET_ERROR;

@@ -74,16 +76,17 @@ int Convolution3x3FP16CPUKernel::InitWeightBias() {

  // init bias
  size_t new_bias_size = oC8 * C8NUM * sizeof(float16_t);
-  bias_data_ = reinterpret_cast<float16_t *>(malloc(new_bias_size));
+  bias_data_ = malloc(new_bias_size);
  if (bias_data_ == nullptr) {
    MS_LOG(ERROR) << "malloc bias_data_ failed.";
    return RET_ERROR;
  }
  memset(bias_data_, 0, new_bias_size);
+  auto fp16_bias_data = reinterpret_cast<float16_t *>(bias_data_);
  if (inputs_.size() == kInputSize2) {
    auto ori_bias_addr = reinterpret_cast<float *>(inputs_.at(kBiasIndex)->Data());
-    for (int i = 0; i < out_channel; ++i) {
-      bias_data_[i] = (float16_t)ori_bias_addr[i];
+    for (int i = 0; i < output_channel; ++i) {
+      fp16_bias_data[i] = (float16_t)ori_bias_addr[i];
    }
  } else {
    MS_ASSERT(inputs_.size() == kInputSize1);

@@ -129,16 +132,15 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() {
  }
  memset(tmp_out_, 0, tmp_out_size);

-  size_t fp16_input_size =
-      in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
-  fp16_input_ = malloc(fp16_input_size);
+  size_t fp16_input_size = conv_param_->input_channel_ * conv_param_->input_batch_ * conv_param_->input_h_ *
+                           conv_param_->input_w_ * sizeof(float16_t);
+  fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size));
  if (fp16_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc fp16_input_ failed.";
    return RET_ERROR;
  }
  memset(fp16_input_, 0, fp16_input_size);

  // init nhwc4 input
  size_t nhwc4_input_size =
      iC4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);

@@ -249,4 +251,3 @@ int Convolution3x3FP16CPUKernel::Run() {
  return RET_OK;
}
}  // namespace mindspore::kernel
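The recurring fix in these fp16 kernels: `bias_data_` stays an untyped `void *` member, is allocated without a cast, and all fp16 stores go through a locally cast alias. A minimal self-contained sketch of that pattern (names are illustrative; a stub stands in for float16_t so it compiles off-ARM):

#include <cstdio>
#include <cstdlib>
#include <cstring>

// Stand-in for float16_t so the sketch builds everywhere; on ARM targets this
// would be the real 16-bit float type.
using fp16_stub = unsigned short;

struct KernelLike {
  void *bias_data_ = nullptr;  // untyped member shared across fp32/fp16/int8 kernels

  bool InitBias(const float *ori_bias, int channels, int aligned) {
    size_t size = aligned * sizeof(fp16_stub);
    bias_data_ = malloc(size);  // allocate untyped ...
    if (bias_data_ == nullptr) return false;
    memset(bias_data_, 0, size);
    auto *fp16_bias = reinterpret_cast<fp16_stub *>(bias_data_);  // ... cast once
    for (int i = 0; i < channels; ++i) {
      fp16_bias[i] = static_cast<fp16_stub>(ori_bias[i]);  // typed stores via the alias
    }
    return true;
  }
};

int main() {
  float bias[3] = {1.f, 2.f, 3.f};
  KernelLike k;
  std::printf("init %s\n", k.InitBias(bias, 3, 8) ? "ok" : "failed");
  free(k.bias_data_);
  return 0;
}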
@@ -42,7 +42,7 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
  // init weight
  float *origin_weight = reinterpret_cast<float *>(inputs_.at(kWeightIndex)->Data());
  size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t);
-  fp16_weight_ = malloc(fp16_weight_size);
+  fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
  if (fp16_weight_ == nullptr) {
    MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
    return RET_ERROR;

@@ -60,16 +60,17 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
  PackWeightFp16(fp16_weight_, conv_param_, packed_weight_);

  // init bias
-  bias_data_ = reinterpret_cast<float16_t *>(malloc(oc8 * C8NUM * sizeof(float16_t)));
+  bias_data_ = malloc(oc8 * C8NUM * sizeof(float16_t));
  if (bias_data_ == nullptr) {
    MS_LOG(ERROR) << "malloc bias_data_ failed.";
    return RET_ERROR;
  }
  memset(bias_data_, 0, oc8 * C8NUM * sizeof(float16_t));
+  auto fp16_bias_data = reinterpret_cast<float16_t *>(bias_data_);
  if (inputs_.size() == kInputSize2) {
    auto ori_bias = reinterpret_cast<float *>(inputs_.at(kBiasIndex)->Data());
    for (int i = 0; i < out_channel; ++i) {
-      bias_data_[i] = (float16_t)ori_bias[i];
+      fp16_bias_data[i] = (float16_t)ori_bias[i];
    }
  } else {
    MS_ASSERT(inputs_.size() == kInputSize1);

@@ -101,7 +102,7 @@ int ConvolutionFP16CPUKernel::InitTmpBuffer() {

  size_t fp16_input_size =
      in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
-  fp16_input_ = malloc(fp16_input_size);
+  fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size));
  if (fp16_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc fp16_input_ failed.";
    return RET_ERROR;
@@ -20,9 +20,7 @@
#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "src/runtime/kernel/arm/opclib/optimized_kernel.h"

namespace mindspore::kernel {
typedef void (*FP16_GEMM_FUNC)(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,

@@ -33,7 +31,7 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseCPUKernel {
 public:
  ConvolutionFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                           const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
  ~ConvolutionFP16CPUKernel() override {
    if (fp16_input_ != nullptr) {
      free(fp16_input_);
@@ -33,10 +33,17 @@ int ConvolutionCPUKernel::InitWeightBias() {
  int kernel_w = conv_param_->kernel_w_;
  int in_channel = conv_param_->input_channel_;
  int out_channel = conv_param_->output_channel_;
  int oc8 = UP_DIV(out_channel, C8NUM);
  int ic4 = UP_DIV(in_channel, C4NUM);
  int kernel_plane = kernel_h * kernel_w;
-  int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane;
+  int oc_block, oc_block_num;
+#ifdef ENABLE_ARM32
+  oc_block = C4NUM;
+  oc_block_num = UP_DIV(out_channel, C4NUM);
+#else
+  oc_block = C8NUM;
+  oc_block_num = UP_DIV(out_channel, C8NUM);
+#endif
+  int pack_weight_size = oc_block_num * oc_block * ic4 * C4NUM * kernel_plane;

  // init weight
  auto origin_weight = reinterpret_cast<float *>(inputs_.at(kWeightIndex)->Data());

@@ -49,12 +56,12 @@ int ConvolutionCPUKernel::InitWeightBias() {
  PackWeightFp32(origin_weight, conv_param_, packed_weight_);

  // init bias
-  bias_data_ = reinterpret_cast<float *>(malloc(oc8 * C8NUM * sizeof(float)));
+  bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * oc_block * sizeof(float)));
  if (bias_data_ == nullptr) {
    MS_LOG(ERROR) << "malloc bias failed.";
    return RET_ERROR;
  }
-  memset(bias_data_, 0, oc8 * C8NUM * sizeof(float));
+  memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float));
  if (inputs_.size() == kInputSize2) {
    auto ori_bias = reinterpret_cast<float *>(inputs_.at(kBiasIndex)->Data());
    memcpy(bias_data_, ori_bias, out_channel * sizeof(float));

@@ -198,4 +205,3 @@ int ConvolutionCPUKernel::Run() {
  return RET_OK;
}
}  // namespace mindspore::kernel
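A worked example of the block-aligned weight-buffer arithmetic above, as a compilable sketch (UP_DIV and the C4NUM/C8NUM constants are redefined locally; shapes are illustrative):

#include <cstdio>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
constexpr int C4NUM = 4;
constexpr int C8NUM = 8;

int main() {
  // Example shapes: 3x3 kernel, 17 input channels, 21 output channels.
  int kernel_plane = 3 * 3;
  int ic4 = UP_DIV(17, C4NUM);           // 5 blocks of 4 input channels
  // ARM64 packs output channels in blocks of 8, ARM32 in blocks of 4.
  int oc_block = C8NUM;
  int oc_block_num = UP_DIV(21, C8NUM);  // 3 blocks of 8 output channels
  int pack_weight_size = oc_block_num * oc_block * ic4 * C4NUM * kernel_plane;
  std::printf("packed weight elements: %d\n", pack_weight_size);  // 3*8*5*4*9 = 4320
  return 0;
}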
@@ -55,6 +55,7 @@ void ConvolutionInt8CPUKernel::CheckSupportOptimize() {
    support_optimize_ = false;
  }
#endif
+  conv_param_->tile_num_ = tile_num_;
}

int ConvolutionInt8CPUKernel::InitWeightBias() {

@@ -78,7 +79,7 @@ int ConvolutionInt8CPUKernel::InitWeightBias() {
    return RET_ERROR;
  }
  memset(packed_weight_, 0, pack_weight_size);
-  int32_t *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * out_channel));
+  auto *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * out_channel));
  for (int i = 0; i < out_channel; i++) weight_sum[i] = 0;
  PackWeightInt8(origin_weight, conv_param_, packed_weight_, weight_sum);

@@ -105,15 +106,14 @@ int ConvolutionInt8CPUKernel::InitWeightBias() {
}

int ConvolutionInt8CPUKernel::InitTmpBuffer() {
-  int tile_n = 4;
  int output_count = conv_param_->output_h_ * conv_param_->output_w_;
-  int output_tile_count = UP_DIV(output_count, tile_n);
+  int output_tile_count = UP_DIV(output_count, tile_num_);
  int in_channel = conv_param_->input_channel_;
  int ic4 = UP_DIV(in_channel, C4NUM);
  int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_;
  int plane_c4 = UP_DIV(kernel_plane, C4NUM);
  int unit_size = plane_c4 * C4NUM * ic4 * C4NUM;
-  int packed_input_size = output_tile_count * tile_n * unit_size;
+  int packed_input_size = output_tile_count * tile_num_ * unit_size;

  packed_input_ = reinterpret_cast<int8_t *>(malloc(conv_param_->input_batch_ * packed_input_size));
  if (packed_input_ == nullptr) {

@@ -122,14 +122,14 @@ int ConvolutionInt8CPUKernel::InitTmpBuffer() {
  }
  memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size);

-  input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_n * thread_count_ * sizeof(int32_t)));
+  input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_num_ * thread_count_ * sizeof(int32_t)));
  if (input_sum_ == nullptr) {
    MS_LOG(ERROR) << "malloc input_sum_ failed.";
    return RET_ERROR;
  }
-  memset(input_sum_, 0, tile_n * thread_count_ * sizeof(int32_t));
+  memset(input_sum_, 0, tile_num_ * thread_count_ * sizeof(int32_t));

-  size_t tmp_dst_size = thread_count_ * tile_n * conv_param_->output_channel_ * sizeof(int32_t);
+  size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t);
  tmp_dst_ = reinterpret_cast<int32_t *>(malloc(tmp_dst_size));
  if (tmp_dst_ == nullptr) {
    MS_LOG(ERROR) << "malloc tmp_dst_ failed.";

@@ -137,7 +137,7 @@ int ConvolutionInt8CPUKernel::InitTmpBuffer() {
  }
  memset(tmp_dst_, 0, tmp_dst_size);

-  tmp_out_ = reinterpret_cast<int8_t *>(malloc(thread_count_ * tile_n * conv_param_->output_channel_));
+  tmp_out_ = reinterpret_cast<int8_t *>(malloc(thread_count_ * tile_num_ * conv_param_->output_channel_));
  if (tmp_out_ == nullptr) {
    MS_LOG(ERROR) << "malloc tmp_out_ failed.";
    return RET_ERROR;

@@ -173,7 +173,7 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
    return RET_ERROR;
  }
  memset(packed_weight_, filter_zp, pack_weight_size);
-  int32_t *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * out_channel));
+  auto *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * out_channel));
  for (int i = 0; i < out_channel; i++) weight_sum[i] = filter_zp * ic4 * C4NUM * kernel_plane;
  PackWeightInt8Opt(origin_weight, conv_param_, packed_weight_, weight_sum);

@@ -200,15 +200,13 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
}

int ConvolutionInt8CPUKernel::InitTmpBufferOpt() {
-  // todo
-  int tile_n = 24;
  int output_count = conv_param_->output_h_ * conv_param_->output_w_;
-  int output_tile_count = UP_DIV(output_count, tile_n);
+  int output_tile_count = UP_DIV(output_count, tile_num_);
  int in_channel = conv_param_->input_channel_;
  int ic4 = UP_DIV(in_channel, C4NUM);
  int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_;
  int unit_size = kernel_plane * ic4 * C4NUM;
-  int packed_input_size = output_tile_count * tile_n * unit_size;
+  int packed_input_size = output_tile_count * tile_num_ * unit_size;

  packed_input_ = reinterpret_cast<int8_t *>(malloc(conv_param_->input_batch_ * packed_input_size));
  if (packed_input_ == nullptr) {

@@ -217,14 +215,14 @@ int ConvolutionInt8CPUKernel::InitTmpBufferOpt() {
  }
  memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size);

-  input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_n * thread_count_ * sizeof(int32_t)));
+  input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_num_ * thread_count_ * sizeof(int32_t)));
  if (input_sum_ == nullptr) {
    MS_LOG(ERROR) << "malloc input_sum_ failed.";
    return RET_ERROR;
  }
-  memset(input_sum_, 0, tile_n * thread_count_ * sizeof(int32_t));
+  memset(input_sum_, 0, tile_num_ * thread_count_ * sizeof(int32_t));

-  size_t tmp_dst_size = thread_count_ * tile_n * conv_param_->output_channel_ * sizeof(int32_t);
+  size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t);
  tmp_dst_ = reinterpret_cast<int32_t *>(malloc(tmp_dst_size));
  if (tmp_dst_ == nullptr) {
    MS_LOG(ERROR) << "malloc tmp_dst_ failed.";

@@ -232,7 +230,7 @@ int ConvolutionInt8CPUKernel::InitTmpBufferOpt() {
  }
  memset(tmp_dst_, 0, tmp_dst_size);

-  tmp_out_ = reinterpret_cast<int8_t *>(malloc(thread_count_ * tile_n * conv_param_->output_channel_));
+  tmp_out_ = reinterpret_cast<int8_t *>(malloc(thread_count_ * tile_num_ * conv_param_->output_channel_));
  if (tmp_out_ == nullptr) {
    MS_LOG(ERROR) << "malloc tmp_out_ failed.";
    return RET_ERROR;
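The refactor replaces the hard-coded tile widths (4 for the plain int8 path, 24 for the dotprod-optimized one) with the single `tile_num_` chosen in CheckSupportOptimize, so buffer sizing and the GEMM agree by construction. A small sketch of why the sizing must follow the tile width (constants redefined locally; numbers are illustrative):

#include <cstdio>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
constexpr int C4NUM = 4;

int main() {
  int out_h = 14, out_w = 14, in_channel = 16, kernel_plane = 9;
  int ic4 = UP_DIV(in_channel, C4NUM);
  int unit_size = kernel_plane * ic4 * C4NUM;  // per-output-point im2col row
  for (int tile_num : {4, 24}) {               // plain vs. optimized tile width
    int output_count = out_h * out_w;          // 196 output points
    int output_tile_count = UP_DIV(output_count, tile_num);
    // The packed buffer must round up to a whole number of tiles.
    int packed_input_size = output_tile_count * tile_num * unit_size;
    std::printf("tile=%d -> tiles=%d, packed elements=%d\n", tile_num,
                output_tile_count, packed_input_size);
  }
  return 0;
}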
@@ -5,19 +5,22 @@ set(LITE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../)
include_directories(OPTIMIZED_OP_DIR)

########################### optimized files ###########################
-set(FP16_ASSEMBLY
-    # ${OPTIMIZED_OP_DIR}/assembly/arm64/IndirectGemmFp16_16x8.s
+file(GLOB OPTIMIZED_ASSEMBLY
+    ${OPTIMIZED_OP_DIR}/assembly/opt/*.s
+    ${OPTIMIZED_OP_DIR}/assembly/opt/*.S
+)

-file(GLOB_RECURSE OPTIMIZED_INT8_ASSEMBLY
-    ${OPTIMIZED_OP_DIR}/assembly/opt/*.S
+file(GLOB FP16_SRC
+    # ${OPTIMIZED_OP_DIR}/fp16/*.cc
+    # ${OPTIMIZED_OP_DIR}/../fp16/*.cc
+)

########################### share library build ########################
set(OPTIMIZED_OPS "opt_op_handler.c")

-set_property(SOURCE ${OPTIMIZED_INT8_ASSEMBLY} PROPERTY LANGUAGE C)
-list(APPEND OPTIMIZED_OPS ${OPTIMIZED_INT8_ASSEMBLY} ${FP16_ASSEMBLY})
+set_property(SOURCE ${OPTIMIZED_ASSEMBLY} PROPERTY LANGUAGE C)
+list(APPEND OPTIMIZED_OPS ${OPTIMIZED_ASSEMBLY} ${FP16_SRC})

if (PLATFORM_ARM64)
    string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
@@ -17,7 +17,7 @@
#include "src/runtime/kernel/arm/opclib/common_func.h"
#include "src/runtime/kernel/arm/opclib/quantization/fixed_point.h"

-#ifndef ENABLE_ARM
+#ifndef __aarch64__
void IndirectGemmFp32(float *output, const float *input, const float *weight, const float *bias, size_t step, int ic4,
                      int output_channel, size_t offset, size_t relu, size_t relu6) {
  for (int i = 0; i < TILE_NUM; i++) {

@@ -108,24 +108,49 @@ int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); }
int8_t MaxInt8(int8_t a, int8_t b) { return a ^ ((a ^ b) & -(a < b)); }

void ReluFp32(float *data, int ele_num) {
-  for (int i = 0; i < ele_num; i++) {
-    if (data[i] < 0) {
-      data[i] = 0;
-    } else {
-      // do nothing
-    }
-  }
+  int four_block = UP_DIV(ele_num, C4NUM);
+  for (int i = 0; i < four_block - 1; i++) {
+    int index = i * C4NUM;
+#ifdef ENABLE_NEON
+    float32x4_t relu_data = vld1q_f32(data + index);
+    float32x4_t zero_data = vdupq_n_f32(0);
+    relu_data = vmaxq_f32(relu_data, zero_data);
+    vst1q_f32(data + index, relu_data);
+#else
+    data[index] = data[index] < 0 ? 0 : data[index];
+    data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
+    data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2];
+    data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3];
+#endif
+  }
+  for (int j = (four_block - 1) * C4NUM; j < ele_num; ++j) {
+    data[j] = data[j] < 0 ? 0 : data[j];
+  }
}

void Relu6Fp32(float *data, int ele_num) {
-  for (int i = 0; i < ele_num; i++) {
-    if (data[i] < 0) {
-      data[i] = 0;
-    } else if (data[i] > 6) {
-      data[i] = 6;
-    } else {
-      // do nothing
-    }
-  }
+  int four_block = UP_DIV(ele_num, C4NUM);
+  for (int i = 0; i < four_block - 1; i++) {
+    int index = i * C4NUM;
+#ifdef ENABLE_NEON
+    float32x4_t relu6_data = vld1q_f32(data + index);
+    float32x4_t zero_data = vdupq_n_f32(0);
+    float32x4_t six_data = vdupq_n_f32(6);
+    relu6_data = vmaxq_f32(relu6_data, zero_data);
+    relu6_data = vminq_f32(relu6_data, six_data);
+    vst1q_f32(data + index, relu6_data);
+#else
+    data[index] = data[index] < 0 ? 0 : data[index];
+    data[index] = data[index] > 6 ? 6 : data[index];
+    data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
+    data[index + 1] = data[index + 1] > 6 ? 6 : data[index + 1];
+    data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2];
+    data[index + 2] = data[index + 2] > 6 ? 6 : data[index + 2];
+    data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3];
+    data[index + 3] = data[index + 3] > 6 ? 6 : data[index + 3];
+#endif
+  }
+  for (int j = (four_block - 1) * C4NUM; j < ele_num; ++j) {
+    data[j] = data[j] < 0 ? 0 : data[j];
+    data[j] = data[j] > 6 ? 6 : data[j];
+  }
}
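A quick self-contained check of the block-plus-tail logic above: the loop handles full 4-float blocks (NEON on ARM) and a scalar tail for the remainder. Scalar restatement only, with UP_DIV/C4NUM redefined locally:

#include <cstdio>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
constexpr int C4NUM = 4;

// Scalar re-statement of the vectorized ReluFp32 above.
void ReluFp32Ref(float *data, int ele_num) {
  int four_block = UP_DIV(ele_num, C4NUM);
  for (int i = 0; i < four_block - 1; i++) {  // full 4-float blocks
    int index = i * C4NUM;
    for (int k = 0; k < C4NUM; ++k) {
      data[index + k] = data[index + k] < 0 ? 0 : data[index + k];
    }
  }
  for (int j = (four_block - 1) * C4NUM; j < ele_num; ++j) {  // scalar tail
    data[j] = data[j] < 0 ? 0 : data[j];
  }
}

int main() {
  float v[7] = {-1.5f, 2.f, -0.25f, 3.f, -4.f, 5.f, -6.f};
  ReluFp32Ref(v, 7);  // ele_num = 7: one full block of 4, tail of 3
  for (float x : v) std::printf("%g ", x);  // 0 2 0 3 0 5 0
  std::printf("\n");
  return 0;
}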
@@ -39,7 +39,7 @@ struct ConvParameter {
  int pad_l_;
  int pad_r_;
  int group_;
  int n_dim_;
+  int tile_num_;
  int input_batch_;
  int input_h_;
  int input_w_;
@@ -141,7 +141,7 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
      int start_index = thread_id * tile_n;
      int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
      float16_t *gemm_input =
-          (float *)(packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset);
+          (float16_t *)(packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset);
      Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);

      int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
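Context on the one-word fix above: the pointer arithmetic is done on `packed_input` (a `float16_t *`), so the offset itself was already correct, but the result was cast to `float *`; any further arithmetic or access through a `float *` then uses 4-byte strides over 2-byte elements. A compilable illustration with a stand-in 16-bit type:

#include <cstdint>
#include <cstdio>

using fp16_stub = uint16_t;  // stand-in for float16_t off-ARM

int main() {
  fp16_stub buf[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  // Correct: same element size, buf + 4 skips four 2-byte elements.
  fp16_stub *ok = (fp16_stub *)(buf + 4);
  // Wrong element type: float has a 4-byte stride, so wrong + 1 advances
  // 4 bytes and lands mid-element in the fp16 buffer (never dereference it).
  float *wrong = (float *)(buf + 4);
  std::printf("ok[0]=%u\n", (unsigned)ok[0]);  // 4
  std::printf("stride ok=%zu wrong=%zu\n", sizeof(*ok), sizeof(*wrong));  // 2 vs 4
  return 0;
}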
@@ -16,7 +16,7 @@

#include "src/runtime/kernel/arm/opclib/fp32/common_func.h"

-#ifndef ENABLE_ARM
+#ifndef __aarch64__
void MatrixAdd(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride,
               size_t row, size_t col) {
  for (int r = 0; r < row; r++) {
@@ -31,13 +31,12 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
  int out_w = conv_param->output_w_;
  int out_channel = conv_param->output_channel_;
  int thread_count = conv_param->thread_num_;
-  int tile_n = 8;
  int output_count = out_h * out_w;
-  int output_tile_count = UP_DIV(output_count, tile_n);
+  int output_tile_count = UP_DIV(output_count, TILE_NUM);
  int ic4 = UP_DIV(in_channel, C4NUM);
  int kernel_plane = kernel_h * kernel_w;
  int unit_size = kernel_plane * ic4 * C4NUM;
-  int packed_input_size = output_tile_count * tile_n * unit_size;
+  int packed_input_size = output_tile_count * TILE_NUM * unit_size;

  // we accumulate 4 channels at a time for input blocks
  int conv_depth = kernel_h * kernel_w;

@@ -50,13 +49,13 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
    int out_batch_offset = b * out_channel * out_h * out_w;
    int gemm_in_batch_offset = b * packed_input_size;
    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
-      int start_index = thread_id * tile_n;
-      int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
-      float *gemm_input = packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset;
+      int start_index = thread_id * TILE_NUM;
+      int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
+      float *gemm_input = packed_input + thread_id * unit_size * TILE_NUM + gemm_in_batch_offset;
      Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);

-      int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
-      if (real_cal_num == tile_n) {
+      int out_offset = thread_id * TILE_NUM * out_channel + out_batch_offset;
+      if (real_cal_num == TILE_NUM) {
        float *gemm_output = output_data + out_offset;
        IndirectGemmFp32_8x8(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel,
                             output_offset, 0, 0, conv_param->is_relu_, conv_param->is_relu6_);

@@ -121,22 +120,8 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
    }
  }
  // get real output
-  for (int batch = 0; batch < out_batch; batch++) {
-    int batch_size = batch * out_channel * conv_param->output_h_ * conv_param->output_w_;
-    for (int h = 0; h < conv_param->output_h_; h++) {
-      for (int w = 0; w < conv_param->output_w_; w++) {
-        for (int c = 0; c < out_channel; c++) {
-          int oc4_block = c / C4NUM;
-          int oc4_res = c % C4NUM;
-          int src_offset = oc4_block * C4NUM * out_w_block * out_h_block * out_unit * out_unit +
-                           C4NUM * (h * out_w_block * out_unit + w) + oc4_res;
-          int dst_offset = (h * conv_param->output_w_ + w) * out_channel + c;
-          (output_data + dst_offset)[0] = (tmp_out_data + src_offset)[0];
-        }
-      }
-    }
-  }
+  UnPackWinogradOutput(tmp_out_data, output_data, out_batch, conv_param->output_h_, conv_param->output_w_,
+                       out_channel, out_unit);
  int output_num = out_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_;
  if (is_relu) {
    ReluFp32(output_data, output_num);

@@ -147,6 +132,45 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
  }
}
+void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel,
+                          int output_unit) {
+  int out_h_block_num = UP_DIV(height, output_unit);
+  int out_w_block_num = UP_DIV(width, output_unit);
+  int c4 = UP_DIV(channel, C4NUM);
+  for (int b = 0; b < batch; b++) {
+    int src_batch_offset = b * c4 * C4NUM * out_h_block_num * output_unit * out_w_block_num * output_unit;
+    int dst_batch_offset = b * height * width * channel;
+    for (int h = 0; h < height; h++) {
+      int src_h_offset = src_batch_offset + C4NUM * (h * out_w_block_num * output_unit);
+      int dst_h_offset = dst_batch_offset + h * width * channel;
+      for (int w = 0; w < width; w++) {
+        int src_w_offset = src_h_offset + w * C4NUM;
+        int dst_w_offset = dst_h_offset + w * channel;
+        for (int c = 0; c < c4 - 1; c++) {
+          int src_c4_offset = src_w_offset + c * C4NUM * out_w_block_num * out_h_block_num * output_unit * output_unit;
+          int dst_c4_offset = dst_w_offset + c * C4NUM;
+#ifdef ENABLE_NEON
+          vst1q_f32(dst + dst_c4_offset, vld1q_f32(src + src_c4_offset));
+#else
+          dst[dst_c4_offset] = src[src_c4_offset];
+          dst[dst_c4_offset + 1] = src[src_c4_offset + 1];
+          dst[dst_c4_offset + 2] = src[src_c4_offset + 2];
+          dst[dst_c4_offset + 3] = src[src_c4_offset + 3];
+#endif
+        }
+        int c_res = channel - (c4 - 1) * C4NUM;
+        int src_c_res_offset = (c4 - 1) * C4NUM * out_w_block_num * out_h_block_num * output_unit * output_unit;
+        int dst_c_res_offset = (c4 - 1) * C4NUM;
+        for (int c = 0; c < c_res; c++) {
+          int src_c4_res_offset = src_w_offset + src_c_res_offset + c;
+          int dst_c4_res_offset = dst_w_offset + dst_c_res_offset + c;
+          dst[dst_c4_res_offset] = src[src_c4_res_offset];
+        }
+      }
+    }
+  }
+}
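To make the NC4HW4 indexing above concrete, a small standalone round-trip: pack a 2x2 plane with 6 channels into C4 blocks (the second block half empty) and read it back in NHWC order. Layout math only, no Winograd tiles:

#include <cstdio>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
constexpr int C4NUM = 4;

int main() {
  // 2x2 spatial plane, 6 channels -> c4 = 2 blocks, block 1 half empty.
  const int H = 2, W = 2, C = 6, c4 = UP_DIV(C, C4NUM);
  float src[c4 * C4NUM * H * W] = {};  // NC4HW4: [block][pixel][lane]
  for (int c = 0; c < C; ++c)
    for (int p = 0; p < H * W; ++p)
      src[(c / C4NUM) * H * W * C4NUM + p * C4NUM + (c % C4NUM)] = c * 10 + p;

  float dst[H * W * C];                // NHWC: [pixel][channel]
  for (int p = 0; p < H * W; ++p)
    for (int c = 0; c < C; ++c)
      dst[p * C + c] = src[(c / C4NUM) * H * W * C4NUM + p * C4NUM + (c % C4NUM)];

  std::printf("pixel 3, channel 5 -> %g (expect 53)\n", dst[3 * C + 5]);
  return 0;
}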
// fp32 conv3x3
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data,
                 TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param) {

@@ -182,7 +206,7 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
  }
  PackNC4HW4ToNHWCFp32(nc4hw4_out, output_data, 1, conv_param->output_h_ * conv_param->output_w_, output_channel);
  }
-  int output_num = oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_;
+  int output_num = output_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_;
  if (is_relu) {
    ReluFp32(output_data, output_num);
  } else if (is_relu6) {

@@ -191,4 +215,3 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
    // do nothing
  }
}
@@ -42,10 +42,10 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
                      TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param,
                      InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func);

+void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, int output_unit);

// fp32 conv3x3
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data,
-                 TmpBufferAddress *buffer_list, int task_id,
-                 ConvParameter *conv_param);
+                 TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param);

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_CONV_H_
@@ -81,8 +81,9 @@ void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
            }  // win_w loop
          }  // win_h loop
#ifdef ENABLE_NEON
-          float32x4_t dup_count = vdupq_n_f32(real_count);
-          vst1q_f32(output_ptr + out_channel_offset, vdivq_f32(tmp_avg, dup_count));
+          float reverse_count = 1.0f / real_count;
+          float32x4_t dup_count = vdupq_n_f32(reverse_count);
+          vst1q_f32(output_ptr + out_channel_offset, vmulq_f32(tmp_avg, dup_count));
#else
          *(output_ptr + out_channel_offset) = tmp_avg1 / (float)real_count;
          *(output_ptr + out_channel_offset + 1) = tmp_avg2 / (float)real_count;
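The change trades one NEON divide (vdivq_f32, which is also AArch64-only) for a scalar reciprocal plus vmulq_f32, typically cheaper per output pixel; note the reciprocal must be computed in floating point, since 1 / count in integer math would be 0 for any count above 1. A scalar sketch of the same idea:

#include <cstdio>

int main() {
  float window_sum[4] = {3.f, 6.f, 9.f, 12.f};
  int real_count = 3;                    // pooling-window element count
  float inv = 1.0f / real_count;         // must be float math: 1 / 3 in int is 0
  for (float &v : window_sum) v *= inv;  // one multiply per lane vs one divide
  for (float v : window_sum) std::printf("%g ", v);  // 1 2 3 4
  std::printf("\n");
  return 0;
}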
@@ -24,11 +24,6 @@ void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *
                               size_t oc4, size_t offset);

#ifdef ENABLE_ARM64
-// void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
-//                               size_t ksize, size_t ic4, size_t output_channel, size_t offset,
-//                               const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
-//                               size_t out_multiplier, size_t shift_before, size_t shift_after);
void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize,
                          size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min,
                          size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before,

@@ -54,8 +49,9 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
#ifdef __aarch64__
  IndirectGemmInt8_4x4(dst, src, weight, bias, kernel_plane, ic4, output_channel, output_channel * sizeof(int8_t),
                       input_sum, act_min, act_max, out_zp, out_multiplier, shift_before, shift_after);
+  // todo arm32
#else
-  int tile_num = 4;
+  int tile_num = conv_param->tile_num_;
  int plane_c4 = UP_DIV(kernel_plane, C4NUM);
  for (int oc = 0; oc < output_channel; oc++) {
    int oc4_block = oc / C4NUM;

@@ -109,7 +105,7 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const
                         act_min, act_max, out_zp, out_multiplier, shift_before, shift_after);
#endif
  } else {
-    int tile_num = 24;
+    int tile_num = conv_param->tile_num_;
    for (int oc = 0; oc < output_channel; oc++) {
      int oc4_block = oc / C4NUM;
      int oc4_res = oc % C4NUM;

@@ -202,7 +198,7 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
  int out_channel = conv_param->output_channel_;
  int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_;

-  int tile_n = 4;
+  int tile_n = conv_param->tile_num_;
  int thread_count = conv_param->thread_num_;
  int output_count = out_h * out_w;
  int output_tile_count = UP_DIV(output_count, tile_n);

@@ -255,9 +251,7 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
  int out_w = conv_param->output_w_;
  int out_channel = conv_param->output_channel_;
  int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_;

-  // todo
-  int tile_n = 24;
+  int tile_n = conv_param->tile_num_;
  int thread_count = conv_param->thread_num_;
  int output_count = out_h * out_w;
  int output_tile_count = UP_DIV(output_count, tile_n);

@@ -302,7 +296,6 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
                 int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
                 int task_id, ConvParameter *conv_param) {
-  // todo
  int thread_count = conv_param->thread_num_;
  int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);
  int output_batch = conv_param->output_batch_;

@@ -331,8 +324,5 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi
  }

  // get real output
-  for (int batch = 0; batch < output_batch; batch++) {
-    // int batch_size = batch * output_channel * output_h * output_w;
-    C4UnpackToHwcInt8(tmp_out, output_data, output_channel, output_h, output_w);
-  }
+  PackNC4HW4ToNHWCInt8(tmp_out, output_data, output_batch, output_h * output_w, output_channel);
}
@@ -16,7 +16,6 @@

#include <stdlib.h>

-// todo
extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
                                     size_t ksize, size_t ic4, size_t output_channel, size_t offset,
                                     const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
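For orientation, CheckSupportFP16 earlier in this commit keys off OptimizeModule's optimized_op_handler_, i.e. whether the optional optimized-ops shared library could be loaded. A plausible, hedged sketch of how such a handle is usually probed (the library name is an assumption, not taken from this commit):

#include <dlfcn.h>
#include <cstdio>

int main() {
  // Illustrative only: probe for an optional optimized-ops shared library and
  // keep the handle; a null handle means "fall back to the generic kernels".
  void *handler = dlopen("liboptimize.so", RTLD_LAZY);  // name is an assumption
  if (handler == nullptr) {
    std::printf("optimized ops unavailable: %s\n", dlerror());
  } else {
    std::printf("optimized ops loaded\n");
    dlclose(handler);
  }
  return 0;
}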
@@ -345,6 +345,7 @@ void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, i

void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed_weight) {
  // original weight format : ohwi
+  // todo pack weight for arm32 platform
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  int in_channel = conv_param->input_channel_;

@@ -352,7 +353,7 @@ void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed
  int oc8 = UP_DIV(out_channel, C8NUM);
  int ic4 = UP_DIV(in_channel, C4NUM);
  int kernel_plane = kernel_h * kernel_w;
-  int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane;
+  int pack_weight_size = oc8 * C8NUM * ic4 * C4NUM * kernel_plane;

  int unit_size = C8NUM * C4NUM;
  int block_size = pack_weight_size / oc8;

@@ -565,7 +566,7 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,
                        int32_t *input_sum, ConvParameter *conv_param) {
  // input format : nhwc
-  int tile_num = 4;
+  int tile_num = conv_param->tile_num_;
  int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_;
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;

@@ -624,7 +625,7 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,
                           int32_t *input_sum, ConvParameter *conv_param) {
  // input format : nhwc
-  int tile_num = 24;
+  int tile_num = conv_param->tile_num_;
  int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_;
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;

@@ -980,15 +981,23 @@ void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane * c4 * C4NUM;
    int dst_offset = b * plane * channel;
-    for (int c = 0; c < channel; c++) {
-      int c4_block_num = c / C4NUM;
-      int c4_block_res = c % C4NUM;
-      int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
-      int dst_c_offset = dst_offset + c;
-      for (int k = 0; k < plane; k++) {
-        int src_kernel_offset = src_c_offset + k * C4NUM;
-        int dst_kernel_offset = dst_c_offset + k * channel;
-        ((uint8_t *)dst + dst_kernel_offset)[0] = ((uint8_t *)src + src_kernel_offset)[0];
+    for (int k = 0; k < plane; k++) {
+      int src_kernel_offset = src_offset + k * C4NUM;
+      int dst_kernel_offset = dst_offset + k * channel;
+      for (int c = 0; c < c4 - 1; c++) {
+        int src_c_offset = src_kernel_offset + c * plane * C4NUM;
+        int dst_c_offset = dst_kernel_offset + c * C4NUM;
+        ((int8_t *)dst + dst_c_offset)[0] = ((int8_t *)src + src_c_offset)[0];
+        ((int8_t *)dst + dst_c_offset)[1] = ((int8_t *)src + src_c_offset)[1];
+        ((int8_t *)dst + dst_c_offset)[2] = ((int8_t *)src + src_c_offset)[2];
+        ((int8_t *)dst + dst_c_offset)[3] = ((int8_t *)src + src_c_offset)[3];
+      }
+      // res part
+      int res_c = channel - (c4 - 1) * C4NUM;
+      for (int i = 0; i < res_c; i++) {
+        int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i;
+        int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i;
+        ((int8_t *)dst + dst_res_c_offset)[0] = ((int8_t *)src + src_res_c_offset)[0];
+      }
    }
  }
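The PackNC4HW4ToNHWCInt8 rewrite flips the loop nest from channel-major to plane-major: within each C4 block the four channel lanes of a pixel sit in adjacent bytes, so the new inner loop copies four sequential bytes at a time instead of single bytes with stride C4NUM. A tiny sketch of the access pattern:

#include <cstdio>

constexpr int C4NUM = 4;

int main() {
  // For a plane of P pixels, NC4HW4 stores each block's 4 channel lanes
  // contiguously per pixel, so the plane-major loop reads 4 sequential bytes:
  int P = 3, block = 0;
  for (int k = 0; k < P; ++k) {
    int src = block * P * C4NUM + k * C4NUM;  // lanes src..src+3 are adjacent
    std::printf("pixel %d reads bytes [%d..%d]\n", k, src, src + C4NUM - 1);
  }
  // The old channel-major loop visited src = block*P*C4NUM + k*C4NUM + lane
  // one byte per iteration, with stride C4NUM between successive pixels.
  return 0;
}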
@@ -122,52 +122,4 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter

void PackDepthwiseInt8Weight(const int8_t *src, int16_t *dst, const ConvParameter *conv_param);

-inline void UnpackHwcToChwFp32(float *src_ptr, float *dst_ptr, int channel, int h, int w) {
-  int cur = 0;
-  for (int i = 0; i < channel; i++) {
-    auto plane = i / BLOCK;
-    auto offset = i % BLOCK;
-    auto src_plane = plane * h * w * BLOCK + src_ptr;
-    for (int j = 0; j < h * w; j++) {
-      dst_ptr[cur++] = src_plane[j * BLOCK + offset];
-    }
-  }
-}
-
-inline void C8UnpackToHwcFp32(float *src_ptr, float *dst_ptr, int channel, int h, int w) {
-  int cur = 0;
-  for (int j = 0; j < h * w; j++) {
-    for (int i = 0; i < channel; i++) {
-      auto plane = i / 8;
-      auto offset = i % 8;
-      auto src_plane = plane * h * w * 8 + src_ptr;
-      dst_ptr[cur++] = src_plane[j * 8 + offset];
-    }
-  }
-}
-
-inline void C4UnpackToHwcFp32(float *src_ptr, float *dst_ptr, int channel, int h, int w) {
-  int cur = 0;
-  for (int j = 0; j < h * w; j++) {
-    for (int i = 0; i < channel; i++) {
-      auto plane = i / 4;
-      auto offset = i % 4;
-      auto src_plane = plane * h * w * 4 + src_ptr;
-      dst_ptr[cur++] = src_plane[j * 4 + offset];
-    }
-  }
-}
-
-inline void C4UnpackToHwcInt8(int8_t *src_ptr, int8_t *dst_ptr, int channel, int h, int w) {
-  int cur = 0;
-  for (int j = 0; j < h * w; j++) {
-    for (int i = 0; i < channel; i++) {
-      auto plane = i / 4;
-      auto offset = i % 4;
-      auto src_plane = plane * h * w * 4 + src_ptr;
-      dst_ptr[cur++] = src_plane[j * 4 + offset];
-    }
-  }
-}
-
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PACK_H_
@@ -17,659 +17,43 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_QUANTIZATION_FIXED_POINT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_QUANTIZATION_FIXED_POINT_H_

-#include <algorithm>
-#include <cassert>
-#include <cmath>
-#include <cstdint>
-#include <limits>
+#include <limits.h>
+#include "include/infer_log.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif

-// Part 1: Low-level integer-arithmetic primitives.
-// The implementations here are generic implementations valid for
-// scalar types (e.g. std::int32_t). Architecture-specific SIMD types
-// (e.g. NEON int32x4_t) may be supported by providing
-// specializations for them in separate files.
-//
-// The purpose of these primitives is two-fold:
-// - They will be used to implement higher-level fixed-point
-//   abstractions, namely the FixedPoint class and its arithmetic
-//   operators.
-// - They will be directly used to implement some more involved
-//   fixed-point computations, e.g. the fixed-point implementation
-//   of math functions such as tanh.
-
-// Some compile-time traits around raw types to handle SIMD aspects:
-// number of lanes, underlying scalar type.
-template <typename tIntegerType>
-struct FixedPointRawTypeTraits {};
-
-template <>
-struct FixedPointRawTypeTraits<std::int32_t> {
-  typedef std::int32_t ScalarRawType;
-  static constexpr int kLanes = 1;
-};
-
-template <>
-struct FixedPointRawTypeTraits<std::int16_t> {
-  typedef std::int16_t ScalarRawType;
-  static constexpr int kLanes = 1;
-};
-
-// Returns a SIMD value duplicating a scalar value across all lanes.
-template <typename tRawType>
-tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
-  return x;
-}
+// returns the high-32 bits of a * b with rounding
+// assume that a and b are Q31 fixed-point values (divided by 2^31), falling into [-1, 1]
+// so the mantissa of a * b is (a / 2^31) * (b / 2^31) * 2^31 = (a * b) / 2^31
+// actually we compute 2 * a * b / 2^32
+// and take 32 bits of mantissa for rounding
+inline int SaturatingRoundingDoublingHighMul(int a, int b) {
+  if (a == INT_MIN && b == INT_MIN) {
+    return INT_MAX;
+  }
+  int64_t ab = ((int64_t)a) * ((int64_t)b);
+  int64_t rounding = ab >= 0 ? (1ll << 30) : (1ll - (1ll << 30));
+  // do not apply right shift to potentially negative values
+  int ab_mantissa = (int)((ab + rounding) / (1ll << 31));
+  return ab_mantissa;
+}
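A quick standalone check of the rewritten primitive above: 0.5 * 0.5 in Q31 should give 0.25, and the INT_MIN * INT_MIN corner must saturate. The function body is duplicated here so the sketch compiles on its own:

#include <climits>
#include <cstdint>
#include <cstdio>

// Copy of the Q31 rounding-doubling-high-multiply above, for a standalone test.
inline int SaturatingRoundingDoublingHighMul(int a, int b) {
  if (a == INT_MIN && b == INT_MIN) {
    return INT_MAX;  // -1 * -1 saturates: +1 is not representable in Q31
  }
  int64_t ab = ((int64_t)a) * ((int64_t)b);
  int64_t rounding = ab >= 0 ? (1ll << 30) : (1ll - (1ll << 30));
  return (int)((ab + rounding) / (1ll << 31));
}

int main() {
  int half = 1 << 30;  // 0.5 in Q31
  std::printf("0.5*0.5 -> %d (expect %d = 0.25)\n",
              SaturatingRoundingDoublingHighMul(half, half), 1 << 29);
  std::printf("INT_MIN*INT_MIN -> %d (expect INT_MAX)\n",
              SaturatingRoundingDoublingHighMul(INT_MIN, INT_MIN));
  return 0;
}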
-// Plain bit-wise AND
-template <typename tIntegerType>
-tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
-  return a & b;
-}
-
-// Plain bit-wise OR
-template <typename tIntegerType>
-tIntegerType BitOr(tIntegerType a, tIntegerType b) {
-  return a | b;
-}
-
-// Plain bit-wise XOR
-template <typename tIntegerType>
-tIntegerType BitXor(tIntegerType a, tIntegerType b) {
-  return a ^ b;
-}
-
-// Plain bit-wise NOT
-template <typename tIntegerType>
-tIntegerType BitNot(tIntegerType a) {
-  return ~a;
-}
-
-// Integer addition. Not saturating. Overflow is undefined behavior.
-template <typename tIntegerType>
-tIntegerType Add(tIntegerType a, tIntegerType b) {
-  return a + b;
-}
-
-// Integer multiplication. Not saturating. Overflow is undefined behavior.
-template <typename tIntegerType>
-tIntegerType Mul(tIntegerType a, tIntegerType b) {
-  return a * b;
-}
-
-// Integer subtraction. Not saturating. Overflow is undefined behavior.
-template <typename tIntegerType>
-tIntegerType Sub(tIntegerType a, tIntegerType b) {
-  return a - b;
-}
-
-// Integer unary negative. Not saturating. Overflow is undefined behavior.
-template <typename tIntegerType>
-tIntegerType Neg(tIntegerType a) {
-  return -a;
-}
-
-// Integer arithmetic left-shift, equivalent to multiplying with a power of two.
-// Negative values are OK. In case of overflow, no Undefined
-// Behavior, but the results are implementation-defined (in practice,
-// they currently are saturated, but we make no commitment to that). The idea
-// is that the caller will want to implement the overflowing cases with
-// saturation with compare-and-mask, so we don't care about the results
-// in the overflow case, we just want to avoid undefined behavior.
-//
-// tIntegerType may be int32 or any narrower signed type.
-template <typename tIntegerType, typename OffsetType>
-tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) {
-  const std::int64_t wide_a = (std::int64_t)(a);
-  const std::int64_t wide_shifted = wide_a * (1 << offset);
-  const auto min = std::numeric_limits<tIntegerType>::min();
-  const auto max = std::numeric_limits<tIntegerType>::max();
-  return wide_shifted < min ? min : wide_shifted > max ? max : (tIntegerType)(wide_shifted);
-}
-
-// Integer arithmetic right-shift. Not rounding.
-// Relying on implementation-defined, but in-practice-consistent,
-// C++ compiler behavior.
-template <typename tIntegerType>
-tIntegerType ShiftRight(tIntegerType a, int offset) {
-  return a >> offset;
-}
-
-// Each bit of the result is set to the corresponding bit of either then_val or
-// else_val depending on whether the corresponding bit of if_mask is set.
-// Equivalent to the VBSL instruction in ARM NEON.
-template <typename tIntegerType>
-tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val, tIntegerType else_val) {
-  return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val));
-}
-
-// For each input scalar, the corresponding bits of the result are set if the
-// input scalar is non-zero.
-template <typename tIntegerType>
-tIntegerType MaskIfNonZero(tIntegerType a) {
-  static constexpr tIntegerType zero = 0;
-  return a ? BitNot(zero) : zero;
-}
-
-// For each input scalar, the corresponding bits of the result are set if the
-// input scalar is zero.
-template <typename tIntegerType>
-tIntegerType MaskIfZero(tIntegerType a) {
-  return MaskIfNonZero<tIntegerType>(!a);
-}
-
-// For each pair of input scalars, the corresponding bits of the result are
-// set if the input scalars are equal.
-template <typename tIntegerType>
-tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
-  return MaskIfNonZero<tIntegerType>(a == b);
-}
-
-// For each pair of input scalars, the corresponding bits of the result are
-// set if the input scalars are not equal.
-template <typename tIntegerType>
-tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
-  return MaskIfNonZero<tIntegerType>(a != b);
-}
-
-// For each pair of input scalars, the corresponding bits of the result are
-// set if the input scalars a, b satisfy a > b.
-template <typename tIntegerType>
-tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
-  return MaskIfNonZero<tIntegerType>(a > b);
-}
-
-// For each pair of input scalars, the corresponding bits of the result are
-// set if the input scalars a, b satisfy a >= b.
-template <typename tIntegerType>
-tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
-  return MaskIfNonZero<tIntegerType>(a >= b);
-}
-
-// For each pair of input scalars, the corresponding bits of the result are
-// set if the input scalars a, b satisfy a < b.
-template <typename tIntegerType>
-tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
-  return MaskIfNonZero<tIntegerType>(a < b);
-}
-
-// For each pair of input scalars, the corresponding bits of the result are
-// set if the input scalars a, b satisfy a <= b.
-template <typename tIntegerType>
-tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
-  return MaskIfNonZero<tIntegerType>(a <= b);
-}
-
-// Returns true if all of the input scalars are nonzero.
-// This function may currently assume that each of the input scalars has either
-// all or none of its bits set. Otherwise, its behavior is currently undefined.
-template <typename tIntegerType>
-bool All(tIntegerType a) {
-  return a;
-}
// Returns true if any of the input scalars are nonzero.
|
||||
// This function may currently assume that each of the input scalars has either
|
||||
// all or none of its bits set. Otherwise, its behavior is currently undefined.
|
||||
template <typename tIntegerType>
|
||||
bool Any(tIntegerType a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
// Returns (a+b)/2, rounded to the nearest integer.
|
||||
// Equivalent to VRHADD in the ARM NEON instruction set.
|
||||
template <typename IntegerType>
|
||||
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
|
||||
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
|
||||
(void)b;
|
||||
return a;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
|
||||
std::int64_t a64 = a;
|
||||
std::int64_t b64 = b;
|
||||
std::int64_t sum = a64 + b64;
|
||||
std::int64_t sign = sum >= 0 ? 1 : -1;
|
||||
return (std::int32_t)((sum + sign) / 2);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
|
||||
std::int32_t a32 = a;
|
||||
std::int32_t b32 = b;
|
||||
std::int32_t sum = a32 + b32;
|
||||
std::int32_t sign = sum >= 0 ? 1 : -1;
|
||||
return (std::int16_t)((sum + sign) / 2);
|
||||
}
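
// Editor's illustrative sketch (not part of the original source): a quick
// sanity check of the rounding behavior above. The +sign nudge rounds ties
// away from zero.
inline void ExampleRoundingHalfSum() {
  assert(RoundingHalfSum<std::int32_t>(3, 4) == 4);     // (3 + 4) / 2 = 3.5  -> 4
  assert(RoundingHalfSum<std::int32_t>(-3, -4) == -4);  // -3.5 -> -4
}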

template <typename IntegerType>
IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  (void)b;
  return a;
}

// Specializations for the cases needed so far: int16 and int8.
template <>
inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
  std::int32_t a32 = a;
  std::int32_t b32 = b;
  std::int32_t sum = a32 + b32;
  return (std::int16_t)(std::min((std::int32_t)(32767), std::max((std::int32_t)(-32768), sum)));
}

template <>
inline std::int8_t SaturatingAdd(std::int8_t a, std::int8_t b) {
  std::int16_t a16 = a;
  std::int16_t b16 = b;
  std::int16_t sum = a16 + b16;
  return (std::int8_t)(std::min((int16_t)(std::numeric_limits<int8_t>::max()),
                                std::max((int16_t)(std::numeric_limits<int8_t>::min()), sum)));
}
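
// Editor's illustrative sketch (not part of the original source): saturation
// clamps to the type's limits instead of wrapping.
inline void ExampleSaturatingAdd() {
  assert(SaturatingAdd<std::int8_t>(100, 100) == 127);     // clamped to int8 max
  assert(SaturatingAdd<std::int8_t>(-100, -100) == -128);  // clamped to int8 min
}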

// Returns a+b, saturating if the integers are 16bit or narrower,
// otherwise just a plain addition.
template <typename IntegerType, bool Is16Bit>
struct AddSaturatingIf16BitImpl {
  static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }
};
template <typename IntegerType>
struct AddSaturatingIf16BitImpl<IntegerType, true> {
  static IntegerType Run(IntegerType a, IntegerType b) { return SaturatingAdd(a, b); }
};
template <typename IntegerType>
IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
  using ScalarType = typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
  return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a, b);
}
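
// Editor's illustrative sketch (not part of the original source): the trait
// dispatch picks SaturatingAdd only for 16-bit element types. This assumes
// the scalar Add helper and FixedPointRawTypeTraits are defined earlier in
// this header.
inline void ExampleAddSaturatingIf16Bit() {
  assert(AddSaturatingIf16Bit<std::int16_t>(30000, 10000) == 32767);  // saturates
  assert(AddSaturatingIf16Bit<std::int32_t>(30000, 10000) == 40000);  // plain add
}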

// Returns the integer that represents the product of two fixed-point
// numbers, interpreting all integers as fixed-point values in the
// interval [-1, 1), rounding to the nearest value, and saturating
// -1 * -1 to the maximum value (since 1 is not in the half-open
// interval [-1, 1)).
//
// [The explanation below specializes to std::int32_t for example purpose.]
//
// The mapping between IntegerType and the interval [-1, 1) is unique and
// implied by IntegerType, which is assumed to be signed. For example,
// for IntegerType==std::int32_t, the mapping is
//   real_value = integer_value / 2^31.
// So in this case, and leaving aside rounding and saturating, this
// function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to
//   (a * b) / 2^31.
//
// The 'doubling' part in the name of this function comes from the fact that
// this operation is very close to a "multiply-high" operation, keeping only
// the top half bits, except that that would be effectively computing
//   (a * b) / 2^32,
// so here we are computing 2x that, since
//   1/2^31 = 2 * 1/2^32.
// The idea is to use all of the available 32 bits in the destination int32
// value.
//
// [End of the explanation specializing to int32.]
//
// This is equivalent to the VQRDMULH instruction in ARM NEON.
template <typename IntegerType>
IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  (void)b;
  return a;
}

// This function implements the same computation as the ARMv7 NEON VQRDMULH
// instruction.
template <>
inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, std::int32_t b) {
  bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
  std::int64_t a_64(a);
  std::int64_t b_64(b);
  std::int64_t ab_64 = a_64 * b_64;
  std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
  std::int32_t ab_x2_high32 = (std::int32_t)((ab_64 + nudge) / (1ll << 31));
  return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
}
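
// Editor's illustrative sketch (not part of the original source): in Q0.31,
// 0.5 is stored as 1 << 30, so 0.5 * 0.5 = 0.25 comes out as 1 << 29, and
// the single overflowing input pair (-1 * -1) saturates to the maximum.
inline void ExampleSaturatingRoundingDoublingHighMul() {
  assert(SaturatingRoundingDoublingHighMul<std::int32_t>(1 << 30, 1 << 30) == (1 << 29));
  assert(SaturatingRoundingDoublingHighMul<std::int32_t>(std::numeric_limits<std::int32_t>::min(),
                                                         std::numeric_limits<std::int32_t>::min()) ==
         std::numeric_limits<std::int32_t>::max());
}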

template <>
inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a, std::int16_t b) {
  bool overflow = a == b && a == std::numeric_limits<std::int16_t>::min();
  std::int32_t a_32(a);
  std::int32_t b_32(b);
  std::int32_t ab_32 = a_32 * b_32;
  std::int16_t nudge = ab_32 >= 0 ? (1 << 14) : (1 - (1 << 14));
  std::int16_t ab_x2_high16 = (std::int16_t)((ab_32 + nudge) / (1 << 15));
  return overflow ? std::numeric_limits<std::int16_t>::max() : ab_x2_high16;
}

// Correctly-rounded-to-nearest division by a power-of-two.
// Also known as a rounding arithmetic right shift.
template <typename IntegerType, typename ExponentType>
inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) {
  assert(exponent >= 0);
  assert(exponent <= 31);
  const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
  const IntegerType zero = Dup<IntegerType>(0);
  const IntegerType one = Dup<IntegerType>(1);
  const IntegerType remainder = BitAnd(x, mask);
  const IntegerType threshold = Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one));
  return Add(ShiftRight(x, exponent), BitAnd(MaskIfGreaterThan(remainder, threshold), one));
}

// Division by 2^exponent with rounding,
// i.e. an arithmetic right shift with rounding (plain-int overload).
inline int RoundingDivideByPOT(int x, int exponent) {
  MS_ASSERT(exponent >= 0);
  MS_ASSERT(exponent <= 31);
  const int mask = (1ll << exponent) - 1;
  const int remainder = x & mask;
  const int threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}
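
// Editor's illustrative sketch (not part of the original source): the
// threshold logic rounds ties away from zero.
inline void ExampleRoundingDivideByPOT() {
  assert(RoundingDivideByPOT(13, 2) == 3);    // 13 / 4 = 3.25   -> 3
  assert(RoundingDivideByPOT(14, 2) == 4);    // 14 / 4 = 3.5    -> 4
  assert(RoundingDivideByPOT(-14, 2) == -4);  // -14 / 4 = -3.5  -> -4
}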

inline int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) {
  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift);
}
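
// Editor's illustrative sketch (not part of the original source): a typical
// requantization step. With multiplier 1 << 30 (0.5 in Q0.31) and no shifts,
// an accumulator of 100 scales down to 50. The values here are hypothetical.
inline void ExampleMultiplyByQuantizedMultiplier() {
  assert(MultiplyByQuantizedMultiplier(100, 1 << 30, 0, 0) == 50);
}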

// Returns the product of a run-time integer value by a compile-time power
// of two, with either a positive exponent (equivalent to an arithmetic
// left shift, saturating) or a negative exponent (equivalent to an arithmetic
// right shift, rounding to nearest).
template <int Exponent, typename IntegerType, int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
struct ImplSaturatingRoundingMultiplyByPOT {};

template <int Exponent, typename IntegerType>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
  static IntegerType eval(IntegerType x) { return x; }
};

template <int Exponent, typename IntegerType>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
  static IntegerType eval(IntegerType x) {
    using ScalarIntegerType = typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
    const IntegerType min = Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
    const IntegerType max = Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
    const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);

    const std::int32_t threshold = ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1);
    const IntegerType positive_mask = MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
    const IntegerType negative_mask = MaskIfLessThan(x, Dup<IntegerType>(-threshold));

    IntegerType result = ShiftLeft(x, Exponent);
    result = SelectUsingMask(positive_mask, max, result);
    result = SelectUsingMask(negative_mask, min, result);
    return result;
  }
};

template <int Exponent, typename IntegerType>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> {
  static IntegerType eval(IntegerType x) { return RoundingDivideByPOT<IntegerType>(x, -Exponent); }
};

template <int Exponent, typename IntegerType>
IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
  return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
}
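
// Editor's illustrative sketch (not part of the original source): a positive
// exponent saturates on overflow, a negative exponent rounds to nearest.
// This assumes the scalar Dup/ShiftLeft/SelectUsingMask helpers are defined
// earlier in this header.
inline void ExampleSaturatingRoundingMultiplyByPOT() {
  assert(SaturatingRoundingMultiplyByPOT<1>(std::int32_t(100)) == 200);
  assert(SaturatingRoundingMultiplyByPOT<1>(std::int32_t(1) << 30) ==
         std::numeric_limits<std::int32_t>::max());                    // would overflow, saturates
  assert(SaturatingRoundingMultiplyByPOT<-2>(std::int32_t(14)) == 4);  // 14 / 4 = 3.5 -> 4
}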

// Part 2: the FixedPoint class.

// A FixedPoint object represents a fixed-point value stored in the underlying
// integer type tRawType, if tRawType is a plain scalar integer type.
// Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which
// case a FixedPoint object represents a corresponding SIMD vector of fixed
// point values.
//
// tIntegerBits describes the range of the fixed-point format: if
// tIntegerBits == m then the range of representable values is the half-open
// interval [-2^m; 2^m) where the open boundary on the right side means that
// 2^m is not representable (how close the maximum representable value is to
// it, depends on bit-depth of tRawType).
//
// In "Q format notation",
//   https://en.wikipedia.org/wiki/Q_(number_format)
// we are describing the format
//   Qm.n
// where
//   m = tIntegerBits
// and
//   n = NumberOfBits(tRawType) - (m + 1)
// Note that the (m + 1) in the above line is because we adopt the convention
// that we count the integer bits exclusively of the sign bit; so (m + 1) is
// the total number of integer bits inclusive of the sign bit.
//
// Accordingly, the number of integral representable values in our range
//   [-2^m ; 2^m)
// is equal to 2^(m+1).
template <typename tRawType, int tIntegerBits>
class FixedPoint {
 public:
  typedef tRawType RawType;

  typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
  typedef typename RawTypeTraits::ScalarRawType ScalarRawType;

  static constexpr int kTotalBits = 8 * sizeof(ScalarRawType);
  static constexpr int kIntegerBits = tIntegerBits;
  static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits;
  static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, "bad IntegerBits");

  typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;

  static const ScalarRawType ScalarRawMin() { return std::numeric_limits<ScalarRawType>::min(); }

  static const ScalarRawType ScalarRawMax() { return std::numeric_limits<ScalarRawType>::max(); }

  static const RawType RawMin() { return VectorFromScalar(ScalarRawMin()); }

  static const RawType RawMax() { return VectorFromScalar(ScalarRawMax()); }

  static FixedPoint FromRaw(RawType x) {
    FixedPoint retval;
    retval.raw() = x;
    return retval;
  }

  static FixedPoint FromScalarRaw(ScalarRawType x) {
    FixedPoint retval;
    retval.raw() = Dup<RawType>(x);
    return retval;
  }

  static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) { return FromScalarRaw(x.raw()); }

  template <int Exponent>
  static FixedPoint ConstantPOT() {
    static constexpr int kOffset = kFractionalBits + Exponent;
    static_assert(kOffset < 31, "Constant not exactly representable in this fixed-point format");
    return FromScalarRaw(ScalarRawType(1) << kOffset);
  }

  static FixedPoint Zero() { return FromScalarRaw(0); }

  static FixedPoint One() {
    return FromScalarRaw(kIntegerBits == 0 ? ScalarRawMax()
                                           : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));
  }

  static FixedPoint FromDouble(double x) {
    const double min_bound = (double)(ScalarRawMin());
    const double max_bound = (double)(ScalarRawMax());
    return FromScalarRaw(
      (ScalarRawType)(std::min(std::max(round(x * (double)(1ll << kFractionalBits)), min_bound), max_bound)));
  }

  RawType raw() const { return i_; }
  RawType &raw() { return i_; }

 private:
  RawType i_;
};
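
// Editor's illustrative sketch (not part of the original source): with a
// scalar raw type, FixedPoint<std::int32_t, 0> is the Q0.31 format described
// above, so 0.5 is stored as 2^30.
inline void ExampleFixedPointQFormat() {
  using F = FixedPoint<std::int32_t, 0>;  // values in [-1, 1)
  static_assert(F::kFractionalBits == 31, "Q0.31");
  assert(F::FromDouble(0.5).raw() == (1 << 30));
}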

// Part 3: implementation of arithmetic operators for the
// FixedPoint class, and a few related functions.

// A FixedPoint multiplication is just a
// SaturatingRoundingDoublingHighMul operation on the underlying
// raw integer values. The IntegerBits simply add up, as is obvious
// from the fact that the range is [-2^IntegerBits, 2^IntegerBits).
template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(FixedPoint<tRawType, tIntegerBits_a> a,
                                                                FixedPoint<tRawType, tIntegerBits_b> b) {
  FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
  c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
  return c;
}
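
// Editor's illustrative sketch (not part of the original source): multiplying
// two Q0.31 values stays in Q0.31 (0 + 0 IntegerBits), e.g. 0.5 * 0.5 == 0.25.
inline void ExampleFixedPointMul() {
  using F = FixedPoint<std::int32_t, 0>;
  F product = F::FromDouble(0.5) * F::FromDouble(0.5);
  assert(product.raw() == (1 << 29));  // 0.25 in Q0.31
}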

// Tweaking IntegerBits gives exact multiplication by a power of two.
template <int tExponent, typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(FixedPoint<tRawType, tIntegerBits> a) {
  FixedPoint<tRawType, tExponent + tIntegerBits> c;
  c.raw() = a.raw();
  return c;
}

// If we want to leave IntegerBits fixed, then multiplication
// by a power of two has to be saturating/rounding, not exact anymore.
template <int tExponent, typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(FixedPoint<tRawType, tIntegerBits> a) {
  return FixedPoint<tRawType, tIntegerBits>::FromRaw(SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
}

// Generic arithmetic operators.

#define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName)                            \
  template <typename tRawType, int tIntegerBits>                                      \
  FixedPoint<tRawType, tIntegerBits> FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
    return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw()));        \
  }

#define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName)                             \
  template <typename tRawType, int tIntegerBits>                                        \
  FixedPoint<tRawType, tIntegerBits> FuncName(FixedPoint<tRawType, tIntegerBits> a,     \
                                              FixedPoint<tRawType, tIntegerBits> b) {   \
    return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw(), b.raw())); \
  }

MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)

#undef MAKE_FIXEDPOINT_UNARY_FUNC
#undef MAKE_FIXEDPOINT_BINARY_FUNC

#define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName)  \
  template <typename tRawType, int tIntegerBits>            \
  tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
    return FuncName(a.raw());                               \
  }

#define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName)                                       \
  template <typename tRawType, int tIntegerBits>                                                  \
  tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a, FixedPoint<tRawType, tIntegerBits> b) { \
    return FuncName(a.raw(), b.raw());                                                            \
  }

MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)

#undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
#undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW

template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> SelectUsingMask(tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
                                                   FixedPoint<tRawType, tIntegerBits> else_val) {
  return FixedPoint<tRawType, tIntegerBits>::FromRaw(SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
}

template <typename tRawType, int tIntegerBits>
bool operator==(FixedPoint<tRawType, tIntegerBits> a, FixedPoint<tRawType, tIntegerBits> b) {
  return All(MaskIfEqual(a.raw(), b.raw()));
}

template <typename tRawType, int tIntegerBits>
bool operator!=(FixedPoint<tRawType, tIntegerBits> a, FixedPoint<tRawType, tIntegerBits> b) {
  return !(a == b);
}

template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> SaturatingAdd(FixedPoint<tRawType, tIntegerBits> a,
                                                 FixedPoint<tRawType, tIntegerBits> b) {
  return FixedPoint<tRawType, tIntegerBits>::FromRaw(SaturatingAdd(a.raw(), b.raw()));
}

template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(FixedPoint<tRawType, tIntegerBits> a,
                                                        FixedPoint<tRawType, tIntegerBits> b) {
  return FixedPoint<tRawType, tIntegerBits>::FromRaw(AddSaturatingIf16Bit(a.raw(), b.raw()));
}

// Conversion to floating-point.
template <typename tRawType, int tIntegerBits>
double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
  static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1, "not applicable to SIMD types");
  typedef FixedPoint<tRawType, tIntegerBits> F;
  return x.raw() / (double)(1ll << F::kFractionalBits);
}

// Rescale changes the number of IntegerBits and updates the underlying
// raw integer value accordingly.
template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
FixedPoint<tRawType, tIntegerBitsDst> Rescale(FixedPoint<tRawType, tIntegerBitsSrc> x) {
  static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
  FixedPoint<tRawType, tIntegerBitsDst> result;
  result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
  return result;
}
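
// Editor's illustrative sketch (not part of the original source): rescaling
// 0.25 from Q2.29 to Q0.31 shifts the raw value left by 2 while preserving
// the represented real value.
inline void ExampleRescale() {
  FixedPoint<std::int32_t, 2> x = FixedPoint<std::int32_t, 2>::FromDouble(0.25);  // raw == 1 << 27
  FixedPoint<std::int32_t, 0> y = Rescale<0>(x);                                  // raw == 1 << 29
  assert(y.raw() == (1 << 29));
  assert(ToDouble(y) == 0.25);
}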

// CheckedFixedPointConstant allows specifying fixed-point constants
// initialized as real numbers, in a way that does not compile floating-point
// arithmetic in production code, yet still checks agreement with the
// floating-point expressions when asserts are enabled.
//
// The raw integer value provided is always a int32, encoding a 32-bit
// fixed-point value, regardless of the actual Scalar type. This allows
// writing generic code that applies just as well to the 32-bit and 16-bit
// cases. In the 16-bit case, the raw integer value is internally
// rounding-shifted by 16 bits to the right.
template <typename FixedPointType>
inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(std::int32_t int32_value) {
  typedef typename FixedPointType::ScalarRawType ScalarRawType;
  static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType);
  return (ScalarRawType)(RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits));
}

// Implementation of exponential function.

// Returns -tanh(x) for x < 0.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> neg_tanh_on_negative_values(FixedPoint<tRawType, tIntegerBits> a) {
  return one_minus_x_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(ExactMulByPot<1>(a)));
}

// Returns tanh(x) for any x.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
  typedef FixedPoint<tRawType, tIntegerBits> InputF;
  typedef FixedPoint<tRawType, 0> ResultF;
  tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
  tRawType mask_if_zero = MaskIfZero(a);
  InputF n = SelectUsingMask(mask_if_negative, a, -a);
  ResultF t = neg_tanh_on_negative_values(n);
  return SelectUsingMask(mask_if_zero, ResultF::Zero(), SelectUsingMask(mask_if_negative, -t, t));
}
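
// Editor's illustrative usage sketch (not part of the original source): tanh
// maps any Q3.28 input to a Q0.31 result in (-1, 1). This assumes the exp and
// one-minus-over-one-plus helpers referenced above are defined earlier in
// this header.
inline FixedPoint<std::int32_t, 0> ExampleTanh(FixedPoint<std::int32_t, 3> a) {
  return tanh(a);  // e.g. a == Zero() yields a result with raw() == 0
}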

// Implementation of logistic function.

// Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> logistic_on_positive_values(FixedPoint<tRawType, tIntegerBits> a) {
  return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a));
}

#ifdef ENABLE_NEON
inline int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent) {
  const int32x4_t shift_vec = vdupq_n_s32(-exponent);

File diff suppressed because it is too large