[MSLITE] arithmetic optimize

This commit is contained in:
ling 2020-12-29 19:42:05 +08:00
parent 77adcecedb
commit 8a78b1f362
141 changed files with 3268 additions and 3253 deletions

View File

@ -15,7 +15,7 @@ file(GLOB KERNEL_SRC
${NNACL_DIR}/*.c
${NNACL_DIR}/fp32/*.c
${NNACL_DIR}/int8/*.c
${NNACL_DIR}/quantization/*.c
${NNACL_DIR}/base/*.c
)
if (SUPPORT_TRAIN)

View File

@ -42,12 +42,4 @@ typedef struct ArithmeticParameter {
int multiples1_[10];
} ArithmeticParameter;
#ifdef __cplusplus
extern "C" {
#endif
void CalcMultiplesAndStrides(ArithmeticParameter *param);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_ARTITHMETIC_H_

View File

@ -19,7 +19,7 @@
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/quantize.h"
// For Abs, Cos, Exp, Log, Square, Sqrt, Rsqrt ops.
typedef struct ArithmeticSelfParameter {

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "nnacl/arithmetic.h"
#include "nnacl/base/arithmetic_base.h"
void CalcMultiplesAndStrides(ArithmeticParameter *param) {
NNACL_ASSERT(param->in_shape0_[i] != 0);

View File

@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_BASE_ARITHMETIC_BASE_H_
#define MINDSPORE_LITE_NNACL_BASE_ARITHMETIC_BASE_H_
#include "nnacl/arithmetic.h"
#include "nnacl/nnacl_utils.h"
#include "nnacl/nnacl_common.h"
#ifdef __cplusplus
extern "C" {
#endif
// Precomputes the per-dimension tiling multiples and strides stored in
// `param`, shared by the broadcast arithmetic kernels (fp32/fp16/int8).
void CalcMultiplesAndStrides(ArithmeticParameter *param);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_BASE_ARITHMETIC_BASE_H_

View File

@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/base/conv1x1_base.h"
void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size) {
  /* Packs the input positions a 1x1 convolution actually reads (stride and
   * padding applied) from an NHWC source into a dense NHWC destination.
   * src_ptr:   input feature map, NHWC layout.
   * dst_ptr:   packed output, output_h_ * output_w_ * input_channel_ pixels.
   * data_size: size in bytes of one element (e.g. 4 for fp32).
   * Rows/columns falling outside the input are skipped and the matching dst
   * region is left untouched, so callers must pre-zero dst when padding can
   * occur. */
  const char *src = (const char *)src_ptr;  /* const-correct: input is read-only */
  char *dst = (char *)dst_ptr;
  for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) {
    int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_u_;
    if (src_h < 0 || src_h >= conv_param->input_h_) {
      continue;
    }
    const char *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_ * data_size;
    char *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_ * data_size;
    for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) {
      int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_l_;
      if (src_w < 0 || src_w >= conv_param->input_w_) {
        continue;
      }
      /* Copy one full channel vector for this output pixel. */
      memcpy(dst_h_ptr + dst_w * conv_param->input_channel_ * data_size,
             src_h_ptr + src_w * conv_param->input_channel_ * data_size, conv_param->input_channel_ * data_size);
    }
  }
}

View File

@ -0,0 +1,32 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_BASE_CONV1X1_BASE_H_
#define MINDSPORE_LITE_NNACL_BASE_CONV1X1_BASE_H_
#include "nnacl/conv_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
// Packs the strided/padded input pixels read by a 1x1 convolution into a
// dense NHWC buffer; data_size is the element size in bytes (dtype-agnostic).
void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_BASE_CONV1X1_BASE_H_

View File

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/depth_to_space.h"
#include <string.h>
#include "nnacl/base/depth_to_space_base.h"
void DepthToSpaceForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceParameter *param) {
int32_t block_size = param->block_size_;

View File

@ -15,6 +15,8 @@
*/
#ifndef MINDSPORE_LITE_NNACL_DEPTH_TO_SPACE_H_
#define MINDSPORE_LITE_NNACL_DEPTH_TO_SPACE_H_
#include <string.h>
#include "nnacl/depth_to_space_parameter.h"
#ifdef __cplusplus

View File

@ -18,7 +18,7 @@
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/quantization/fixed_point.h"
#include "mindspore/lite/nnacl/int8/fixed_point.h"
typedef struct ClipParameter {
OpParameter op_parameter_;

View File

@ -17,11 +17,10 @@
#ifndef MINDSPORE_LITE_NNACL_COMMON_FUNC_H_
#define MINDSPORE_LITE_NNACL_COMMON_FUNC_H_
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include "nnacl/op_base.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/nnacl_common.h"
#ifdef __cplusplus
extern "C" {
@ -63,14 +62,6 @@ static inline int GetStride(int *strides, const int *shape, int length) {
return stride;
}
inline void ComputeStrides(const int *shape, int *strides, const int ndim) {
int stride = 1;
for (int i = ndim - 1; i >= 0; i--) {
strides[i] = stride;
stride *= shape[i];
}
}
#ifdef ENABLE_ARM64
void BiasAdd(const float *bias, float *data, size_t oc4, size_t plan_size);
void BiasAddRelu6(const float *bias, float *data, size_t oc4, size_t plan_size);

View File

@ -18,7 +18,7 @@
#define MINDSPORE_LITE_NNACL_CONCAT_PARAMETER_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "mindspore/lite/nnacl/int8/quantize.h"
typedef struct ConcatParameter {
OpParameter op_parameter_;

View File

@ -21,7 +21,7 @@
#include <arm_neon.h>
#endif
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "mindspore/lite/nnacl/int8/quantize.h"
typedef struct ConvParameter {
OpParameter op_parameter_;

View File

@ -18,7 +18,7 @@
#define MINDSPORE_LITE_NNACL_CROP_PARAMETER_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "mindspore/lite/nnacl/int8/quantize.h"
#define CROP_OFFSET_MAX_SIZE 4

View File

@ -21,7 +21,7 @@
#endif
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/quantization/fixed_point.h"
#include "mindspore/lite/nnacl/int8/fixed_point.h"
typedef struct ActivationParameter {
OpParameter op_parameter_;

View File

@ -20,7 +20,7 @@
#include <arm_neon.h>
#endif
#include "nnacl/op_base.h"
#include "nnacl/arithmetic.h"
#include "nnacl/base/arithmetic_base.h"
#include "nnacl/errorcode.h"
#ifdef __cplusplus

View File

@ -16,7 +16,6 @@
#include "nnacl/fp16/pack_fp16.h"
#include <string.h>
#include <stdlib.h>
void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num,
int block_index) {

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_LITE_NNACL_FP16_POOLING_FP16_H_
#define MINDSPORE_LITE_NNACL_FP16_POOLING_FP16_H_
#include <math.h>
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif

View File

@ -18,7 +18,7 @@
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/quantization/fixed_point.h"
#include "mindspore/lite/nnacl/int8/fixed_point.h"
typedef struct ActivationParameter {
OpParameter op_parameter_;

View File

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/arg_min_max_fp32.h"
#include <stdlib.h>
#include <float.h>
int ArgCompareAscFp32(const void *a, const void *b) {

View File

@ -20,7 +20,7 @@
#include <arm_neon.h>
#endif
#include "nnacl/op_base.h"
#include "nnacl/arithmetic.h"
#include "nnacl/base/arithmetic_base.h"
#include "nnacl/errorcode.h"
#ifdef __cplusplus

View File

@ -17,8 +17,6 @@
#ifndef MINDSPORE_LITE_NNACL_FP32_COMMON_FUNC_H_
#define MINDSPORE_LITE_NNACL_FP32_COMMON_FUNC_H_
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include "nnacl/op_base.h"
#include "nnacl/conv_parameter.h"

View File

@ -0,0 +1,479 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/pack_fp32.h"
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) {
  /* KHW -> HWK is a (channel, plane) -> (plane, channel) transpose; reuse the
   * NCHW -> NHWC pack with batch == 1. Note: `return <expr>;` is invalid in a
   * void function in C (C99 6.8.6.4), so the call is a plain statement. */
  PackNCHWToNHWCFp32(src, dst, 1, plane, channel);
}
void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel) {
  /* Transposes the two spatial axes of an HWC tensor: element (h, w) of src
   * becomes element (w, h) of dst; each channel vector is copied intact. */
  for (int w = 0; w < width; ++w) {
    for (int h = 0; h < height; ++h) {
      const float *src_pixel = src + (h * width + w) * channel;
      float *dst_pixel = dst + (w * height + h) * channel;
      memcpy(dst_pixel, src_pixel, channel * sizeof(float));
    }
  }
}
void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num,
                        int block_index) {
  // Im2col: gathers the receptive fields of `real_cal_num` consecutive output
  // pixels (starting at flattened output index `block_index`) into
  // `packed_input`, one kernel_h * kernel_w * in_channel slab per pixel.
  // Kernel taps outside the input are skipped, so the caller must pre-zero
  // `packed_input` when padding is in effect.
  // input format : nhwc
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  int kernel_plane = kernel_h * kernel_w;
  int dilation_h = conv_param->dilation_h_;
  int dilation_w = conv_param->dilation_w_;
  int in_channel = conv_param->input_channel_;
  int in_w = conv_param->input_w_;
  int out_w = conv_param->output_w_;
  for (int i = 0; i < real_cal_num; i++) {
    int block_start = block_index + i;
    // Top-left input coordinate of this output pixel's window (may be negative).
    int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_;
    int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_;
    int input_stride = (input_h * in_w + input_w) * in_channel;
    // Clip the kernel window to the valid input region.
    int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
    int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h));
    int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
    int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
    if (dilation_w == 1 && dilation_h == 1) {
      // Dense kernel: the taps of one kernel row are contiguous -> one memcpy.
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * in_w * in_channel + input_stride;
        int input_x_stride = input_y_stride + kw_s * in_channel;
        int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
        memcpy(packed_input + input_plane_offset, input_data + input_x_stride,
               (kw_e - kw_s) * in_channel * sizeof(float));
      }  // kernel_h loop
    } else {
      // Dilated kernel: taps are strided, copy each channel vector separately.
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
        for (int k = kw_s; k < kw_e; ++k) {
          int input_x_stride = input_y_stride + k * dilation_w * in_channel;
          int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
          memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float));
        }
      }  // kernel_h loop
    }
  }  // tile num loop
}
void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  // Repacks NHWC fp32 data into NC4HW4: channels are grouped into blocks of
  // C4NUM and each block stores its own plane of C4NUM-wide pixel vectors.
  // Full blocks are copied with NEON when available; the tail block is copied
  // scalar-by-scalar. Unused lanes of a partial last block are left untouched
  // (callers zero the buffer if needed).
  int c4 = UP_DIV(channel, C4NUM);
  int c4_minus = c4 - 1;  // number of channel blocks handled by the vector path
  for (int b = 0; b < batch; b++) {
    int src_oc_offset = b * plane * channel;
    int dst_oc_offset = b * plane * c4 * C4NUM;
    for (int k = 0; k < plane; k++) {
      int src_kernel_offset = src_oc_offset + k * channel;
      int dst_kernel_offset = dst_oc_offset + k * C4NUM;
      for (int j = 0; j < c4_minus; ++j) {
        int src_ic_offset = src_kernel_offset + j * C4NUM;
        int dst_ic_offset = dst_kernel_offset + j * plane * C4NUM;
#ifdef ENABLE_ARM
        // 4-lane NEON copy of one channel block for this pixel.
        vst1q_f32((float *)dst + dst_ic_offset, vld1q_f32((float *)src + src_ic_offset));
#else
        for (int i = 0; i < C4NUM; ++i) {
          ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i];
        }
#endif
      }
      // Tail: the last block (C4NUM values when channel is aligned, fewer
      // otherwise) copied one element at a time.
      int tmp_c = c4_minus * C4NUM;
      int tmp_c_offset = tmp_c * plane;
      int res_c = channel - tmp_c;
      for (int l = 0; l < res_c; ++l) {
        int src_ic_offset = src_kernel_offset + tmp_c + l;
        int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l;
        ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0];
      }
    }
  }
}
void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  /* Repacks NCHW fp32 data into NC4HW4: channels are grouped into blocks of
   * C4NUM, and within a block the C4NUM channel values of one spatial
   * position sit contiguously. Pad lanes of a partial last block are left
   * as-is. */
  const float *in = (const float *)src;
  float *out = (float *)dst;
  int c4 = UP_DIV(channel, C4NUM);
  for (int n = 0; n < batch; n++) {
    const float *in_batch = in + n * plane * channel;
    float *out_batch = out + n * plane * c4 * C4NUM;
    for (int c = 0; c < channel; c++) {
      int block = c / C4NUM;
      int lane = c % C4NUM;
      const float *in_chan = in_batch + c * plane;
      float *out_block = out_batch + block * plane * C4NUM;
      for (int p = 0; p < plane; p++) {
        out_block[p * C4NUM + lane] = in_chan[p];
      }
    }
  }
}
void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  /* Pads the channel dimension of an NHWC fp32 tensor up to a multiple of
   * C4NUM, zero-filling the new channel slots. When channel is already
   * aligned the two layouts coincide and the tensor is copied wholesale. */
  if (channel % C4NUM == 0) {
    memcpy((float *)dst, (const float *)src, (size_t)batch * plane * channel * sizeof(float));
    return;
  }
  int c4_channel = UP_DIV(channel, C4NUM) * C4NUM;
  for (int b = 0; b < batch; b++) {
    const float *src_batch = (const float *)src + b * plane * channel;
    float *dst_batch = (float *)dst + b * plane * c4_channel;
    for (int p = 0; p < plane; p++) {
      float *dst_pixel = dst_batch + p * c4_channel;
      memcpy(dst_pixel, src_batch + p * channel, channel * sizeof(float));
      for (int c = channel; c < c4_channel; ++c) {
        dst_pixel[c] = 0;  /* zero the padding lanes */
      }
    }
  }
}
void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  // Pads the channel dimension of an NHWC fp32 tensor up to a multiple of
  // C8NUM, zero-filling the padded channel slots. If channel is already a
  // multiple of C8NUM the tensor is copied in one shot.
  int c8 = UP_DIV(channel, C8NUM);
  int c8_channel = c8 * C8NUM;
  int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
  int ic_remainder_ = channel % C8NUM;
  if (ic_remainder_ != 0) {
    int nhwc8_batch_offset = 0;
    for (int b = 0; b < batch; b++) {
      int batch_offset = b * channel * plane;
      for (int i = 0; i < plane; i++) {
        float *dst_per_plane = (float *)dst + nhwc8_batch_offset + i * c8_channel;
        memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
        // Zero the padding lanes of this pixel.
        for (int j = channel; j < c8_channel; ++j) {
          dst_per_plane[j] = 0;
        }
      }
      nhwc8_batch_offset += nhwc8_batch_unit_offset;
    }
  } else {
    // Aligned channel count: layouts are identical.
    size_t ori_input_size = batch * plane * channel * sizeof(float);
    memcpy((float *)dst, (float *)src, ori_input_size);
  }
}
void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
  /* Inverse of PackNHWCToNHWC4Fp32: strips the C4NUM channel padding from an
   * NHWC4 tensor. When channel is already a multiple of C4NUM the layouts
   * coincide and a single memcpy suffices. */
  if (channel % C4NUM == 0) {
    memcpy((float *)dst, (const float *)src, (size_t)batch * plane * channel * sizeof(float));
    return;
  }
  int c4_channel = UP_DIV(channel, C4NUM) * C4NUM;
  for (int b = 0; b < batch; b++) {
    const float *src_batch = (const float *)src + b * plane * c4_channel;
    float *dst_batch = (float *)dst + b * plane * channel;
    for (int p = 0; p < plane; p++) {
      /* Drop the pad lanes: copy only the first `channel` values per pixel. */
      memcpy(dst_batch + p * channel, src_batch + p * c4_channel, channel * sizeof(float));
    }
  }
}
void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  // Converts NC4HW4 fp32 data to NHWC4 (channel-padded NHWC). Only the first
  // `channel` lanes of each pixel are written; pad lanes are left untouched.
  int c4 = UP_DIV(channel, C4NUM);
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane * c4 * C4NUM;
    // NOTE(review): the dst batch stride uses `channel` while the per-pixel
    // stride below uses the padded c4 * C4NUM — inconsistent for batch > 1
    // with unaligned channel counts; confirm callers use batch == 1 or
    // aligned channels.
    int dst_offset = b * plane * channel;
    for (int c = 0; c < channel; c++) {
      int c4_block_num = c / C4NUM;
      int c4_block_res = c % C4NUM;
      int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
      int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res;
      for (int k = 0; k < plane; k++) {
        int src_kernel_offset = src_c_offset + k * C4NUM;
        int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM;  // padded NHWC4 pixel stride
        ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0];
      }
    }
  }
}
void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
  // Converts NC4HW4 fp32 data back to plain NHWC. Full C4NUM channel blocks
  // are copied four lanes at a time (NEON when available); the final block is
  // copied scalar-by-scalar.
  int c4 = UP_DIV(channel, C4NUM);
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane * c4 * C4NUM;
    int dst_offset = b * plane * channel;
    for (int k = 0; k < plane; k++) {
      int src_kernel_offset = src_offset + k * C4NUM;
      int dst_kernel_offset = dst_offset + k * channel;
      for (int c = 0; c < c4 - 1; c++) {
        int src_c_offset = src_kernel_offset + c * plane * C4NUM;
        int dst_c_offset = dst_kernel_offset + c * C4NUM;
#ifdef ENABLE_NEON
        vst1q_f32((float *)dst + dst_c_offset, vld1q_f32((float *)src + src_c_offset));
#else
        // Scalar fallback: same 4-lane copy as the NEON path.
        ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0];
        ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1];
        ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2];
        ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3];
#endif
      }
      // res part
      // Last block: channel - (c4 - 1) * C4NUM lanes remain (C4NUM itself
      // when channel is exactly aligned).
      int res_c = channel - (c4 - 1) * C4NUM;
      for (int i = 0; i < res_c; i++) {
        int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i;
        int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i;
        ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0];
      }
    }
  }
}
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  /* Reorders NHWC into C8HWN8: channels are split into blocks of C8NUM and
   * the destination is indexed as (c8 block, hw, n, lane). Pad lanes of a
   * partial last block are left untouched. */
  const float *in = (const float *)src;
  float *out = (float *)dst;
  for (int c = 0; c < channel; c++) {
    int block = c / C8NUM;
    int lane = c % C8NUM;
    for (int n = 0; n < batch; n++) {
      for (int hw = 0; hw < plane; hw++) {
        int src_index = (n * plane + hw) * channel + c;
        int dst_index = (block * batch * plane + hw * batch + n) * C8NUM + lane;
        out[dst_index] = in[src_index];
      }
    }
  }
}
void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) {
  // Packs depthwise weights from (channel, height, width) into the layout the
  // indirect C4 kernel expects: per C4NUM-channel block, index order is
  // (kw, kh, lane).
  // NOTE(review): the i-loop reads src plane (c * C4NUM + i) even when that
  // index exceeds `channel` for unaligned channel counts — this assumes the
  // source buffer is padded up to a C4NUM multiple; confirm with callers.
  int c4 = UP_DIV(channel, C4NUM);
  for (int c = 0; c < c4; c++) {
    int dst_off_c = c * C4NUM * height * width;
    for (int i = 0; i < C4NUM; i++) {
      int src_off_c = (c * C4NUM + i) * height * width;
      for (int kh = 0; kh < height; kh++) {
        int src_off_kh = src_off_c + kh * width;
        for (int kw = 0; kw < width; kw++) {
          int dst_off = dst_off_c + kw * height * C4NUM + kh * C4NUM + i;
          ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw];
        }
      }
    }
  }
}
void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel) {
  // C8NUM variant of PackDepthwiseIndirectWeightC4Fp32: packs (channel,
  // height, width) weights into (kw, kh, lane) order per C8NUM-channel block.
  // NOTE(review): as in the C4 variant, src plane (c * C8NUM + i) may exceed
  // `channel` for unaligned counts — assumes a C8NUM-padded source buffer;
  // confirm with callers.
  int c8 = UP_DIV(channel, C8NUM);
  for (int c = 0; c < c8; c++) {
    int dst_off_c = c * C8NUM * height * width;
    for (int i = 0; i < C8NUM; i++) {
      int src_off_c = (c * C8NUM + i) * height * width;
      for (int kh = 0; kh < height; kh++) {
        int src_off_kh = src_off_c + kh * width;
        for (int kw = 0; kw < width; kw++) {
          int dst_off = dst_off_c + kw * height * C8NUM + kh * C8NUM + i;
          ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw];
        }
      }
    }
  }
}
#ifndef ENABLE_SSE
// Transposes each batch of an NHWC fp32 tensor to NCHW (a plane x channel
// transpose per batch). The bulk is processed in 8x8 tiles — hand-written
// NEON assembly on ARM64/ARM32, a scalar 8x8 loop elsewhere — with scalar
// edge loops for the leftover rows and columns. An SSE variant exists in a
// separate translation unit, hence the guard.
void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) {
  int hw8 = plane / C8NUM * C8NUM;   // plane rounded down to a multiple of C8NUM
  int c8 = channel / C8NUM * C8NUM;  // channel rounded down to a multiple of C8NUM
  int batch = plane * channel;       // element count of one batch
  for (int n = 0; n < batches; n++) {
    const float *src_batch = (const float *)src + n * batch;
    float *dst_batch = (float *)dst + n * batch;
    int hw = 0;
    for (; hw < hw8; hw += C8NUM) {
      int c = 0;
      for (; c < c8; c += C8NUM) {
        const float *src_ptr = src_batch + hw * channel + c;
        float *dst_ptr = dst_batch + c * plane + hw;
#ifdef ENABLE_ARM64
        // 8x8 fp32 tile transpose built from zip1/zip2 (interleave) and
        // trn1/trn2 (64-bit transpose) across v0-v31; rows are loaded with a
        // post-incremented channel stride and stored with a plane stride.
        size_t srcStride = channel * sizeof(float);
        size_t dstStride = plane * sizeof(float);
        asm volatile(
          "mov x10, %[src_ptr]\n"
          "mov x11, %[dst_ptr]\n"
          "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
          "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"
          "zip1 v8.4s, v0.4s, v2.4s\n"
          "zip2 v9.4s, v0.4s, v2.4s\n"
          "zip1 v12.4s, v1.4s, v3.4s\n"
          "zip2 v13.4s, v1.4s, v3.4s\n"
          "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
          "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"
          "zip1 v10.4s, v4.4s, v6.4s\n"
          "zip2 v11.4s, v4.4s, v6.4s\n"
          "zip1 v14.4s, v5.4s, v7.4s\n"
          "zip2 v15.4s, v5.4s, v7.4s\n"
          "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
          "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"
          "trn1 v16.2d, v8.2d, v10.2d\n"
          "trn2 v18.2d, v8.2d, v10.2d\n"
          "trn1 v20.2d, v9.2d, v11.2d\n"
          "trn2 v22.2d, v9.2d, v11.2d\n"
          "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
          "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"
          "trn1 v24.2d, v12.2d, v14.2d\n"
          "trn2 v26.2d, v12.2d, v14.2d\n"
          "trn1 v28.2d, v13.2d, v15.2d\n"
          "trn2 v30.2d, v13.2d, v15.2d\n"
          "zip1 v8.4s, v0.4s, v2.4s\n"
          "zip2 v9.4s, v0.4s, v2.4s\n"
          "zip1 v12.4s, v1.4s, v3.4s\n"
          "zip2 v13.4s, v1.4s, v3.4s\n"
          "zip1 v10.4s, v4.4s, v6.4s\n"
          "zip2 v11.4s, v4.4s, v6.4s\n"
          "zip1 v14.4s, v5.4s, v7.4s\n"
          "zip2 v15.4s, v5.4s, v7.4s\n"
          "trn1 v17.2d, v8.2d, v10.2d\n"
          "trn2 v19.2d, v8.2d, v10.2d\n"
          "trn1 v21.2d, v9.2d, v11.2d\n"
          "trn2 v23.2d, v9.2d, v11.2d\n"
          "trn1 v25.2d, v12.2d, v14.2d\n"
          "trn2 v27.2d, v12.2d, v14.2d\n"
          "trn1 v29.2d, v13.2d, v15.2d\n"
          "trn2 v31.2d, v13.2d, v15.2d\n"
          "st1 {v16.4s, v17.4s}, [x11], %[dstStride]\n"
          "st1 {v18.4s, v19.4s}, [x11], %[dstStride]\n"
          "st1 {v20.4s, v21.4s}, [x11], %[dstStride]\n"
          "st1 {v22.4s, v23.4s}, [x11], %[dstStride]\n"
          "st1 {v24.4s, v25.4s}, [x11], %[dstStride]\n"
          "st1 {v26.4s, v27.4s}, [x11], %[dstStride]\n"
          "st1 {v28.4s, v29.4s}, [x11], %[dstStride]\n"
          "st1 {v30.4s, v31.4s}, [x11], %[dstStride]\n"
          :
          :
          [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
          : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
            "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
            "v30", "v31");
#elif ENABLE_ARM32
        // 8x8 tile transpose via vtrn.32 (2x2 transposes) plus vswp (64-bit
        // halves); two interleaved store streams (r12 and r10 = r12 + 16).
        size_t srcStride = channel * sizeof(float);
        size_t dstStride = plane * sizeof(float);
        asm volatile(
          "mov r10, %[src_ptr]\n"
          "mov r12, %[dst_ptr]\n"
          "vld1.32 {q0, q1}, [r10], %[srcStride]\n"
          "vld1.32 {q2, q3}, [r10], %[srcStride]\n"
          "vtrn.32 d0, d4\n"
          "vtrn.32 d1, d5\n"
          "vtrn.32 d2, d6\n"
          "vtrn.32 d3, d7\n"
          "vld1.32 {q4, q5}, [r10], %[srcStride]\n"
          "vld1.32 {q6, q7}, [r10], %[srcStride]\n"
          "vtrn.32 d8, d12\n"
          "vtrn.32 d9, d13\n"
          "vtrn.32 d10, d14\n"
          "vtrn.32 d11, d15\n"
          "vld1.32 {q8, q9}, [r10], %[srcStride]\n"
          "vld1.32 {q10, q11}, [r10], %[srcStride]\n"
          "vswp d1, d8\n"
          "vswp d3, d10\n"
          "vswp d5, d12\n"
          "vswp d7, d14\n"
          "vtrn.32 d16, d20\n"
          "vtrn.32 d17, d21\n"
          "vtrn.32 d18, d22\n"
          "vtrn.32 d19, d23\n"
          "vld1.32 {q12, q13}, [r10], %[srcStride]\n"
          "vld1.32 {q14, q15}, [r10], %[srcStride]\n"
          "vtrn.32 d24, d28\n"
          "vtrn.32 d25, d29\n"
          "vtrn.32 d26, d30\n"
          "vtrn.32 d27, d31\n"
          "vswp d17, d24\n"
          "vswp d19, d26\n"
          "vswp d21, d28\n"
          "vswp d23, d30\n"
          "add r10, r12, #16\n"
          "vst1.32 {q0}, [r12], %[dstStride]\n"
          "vst1.32 {q8}, [r10], %[dstStride]\n"
          "vst1.32 {q2}, [r12], %[dstStride]\n"
          "vst1.32 {q10}, [r10], %[dstStride]\n"
          "vst1.32 {q4}, [r12], %[dstStride]\n"
          "vst1.32 {q12}, [r10], %[dstStride]\n"
          "vst1.32 {q6}, [r12], %[dstStride]\n"
          "vst1.32 {q14}, [r10], %[dstStride]\n"
          "vst1.32 {q1}, [r12], %[dstStride]\n"
          "vst1.32 {q9}, [r10], %[dstStride]\n"
          "vst1.32 {q3}, [r12], %[dstStride]\n"
          "vst1.32 {q11}, [r10], %[dstStride]\n"
          "vst1.32 {q5}, [r12], %[dstStride]\n"
          "vst1.32 {q13}, [r10], %[dstStride]\n"
          "vst1.32 {q7}, [r12], %[dstStride]\n"
          "vst1.32 {q15}, [r10], %[dstStride]\n"
          :
          :
          [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
          : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
            "q15");
#else
        // Portable scalar 8x8 tile transpose.
        for (int tr = 0; tr < C8NUM; tr++) {
          for (int tc = 0; tc < C8NUM; tc++) {
            dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc];
          }
        }
#endif
      }
      // Remaining channels (< C8NUM) of this 8-row stripe, scalar.
      for (; c < channel; c++) {
        const float *src_ptr = src_batch + hw * channel + c;
        float *dst_ptr = dst_batch + c * plane + hw;
        for (size_t i = 0; i < C8NUM; i++) {
          dst_ptr[i] = src_ptr[i * channel];
        }
      }
    }
    // Remaining rows (< C8NUM): full-channel scalar transpose.
    // (channel is int; the size_t comparison is fine since channel >= 0.)
    for (; hw < plane; hw++) {
      const float *src_ptr = src_batch + hw * channel;
      float *dst_ptr = dst_batch + hw;
      for (size_t i = 0; i < channel; i++) {
        dst_ptr[i * plane] = src_ptr[i];
      }
    }
  }
  return;
}
#endif
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
  /* NCHW -> NHWC is the same permutation as NHWC -> NCHW with the plane and
   * channel extents swapped. `return <expr>;` is invalid in a void function
   * in C (C99 6.8.6.4), so the call is a plain statement. */
  PackNHWCToNCHWFp32(src, dst, batch, channel, plane);
}

View File

@ -0,0 +1,50 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_PACK_H_
#define MINDSPORE_LITE_NNACL_FP32_PACK_H_
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/conv_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
// Fp32 tensor-layout packing helpers (NHWC / NCHW / NC4HW4 / NHWC4 / NHWC8 /
// C8HWN8), plus im2col and depthwise-weight packing.
void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel);
void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel);
void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel);
void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel);
void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num,
                        int block_index);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP32_PACK_H_

View File

@ -22,7 +22,7 @@
#endif
#include "nnacl/op_base.h"
#include "nnacl/pooling_parameter.h"
#include "nnacl/quantization/quantize.h"
#include "mindspore/lite/nnacl/int8/quantize.h"
#ifdef __cplusplus
extern "C" {

View File

@ -18,7 +18,7 @@
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/int8/fixed_point.h"
void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params) {
int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_);

View File

@ -17,7 +17,7 @@
#define MINDSPORE_LITE_NNACL_INT8_ARG_MIN_MAX_INT8_H_
#include "nnacl/arg_min_max_parameter.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/quantize.h"
#ifdef __cplusplus
extern "C" {

View File

@ -17,8 +17,8 @@
#define MINDSPORE_LITE_NNACL_INT8_ARITHMETIC_INT8_H_
#include "nnacl/op_base.h"
#include "nnacl/arithmetic.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/base/arithmetic_base.h"
#ifdef __cplusplus
extern "C" {

View File

@ -21,7 +21,7 @@
#include <arm_neon.h>
#include "nnacl/int8/common_func_int8.h"
#endif
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/int8/fixed_point.h"
int Int8ElementFloor(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) {
float in_scale = para.in_args_.scale_;

View File

@ -22,7 +22,7 @@
#endif
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/quantize.h"
#ifdef __cplusplus
extern "C" {

View File

@ -16,7 +16,7 @@
#ifndef MINDSPORE_LITE_NNACL_INT8_BATCH_TO_SPACE_INT8_H_
#define MINDSPORE_LITE_NNACL_INT8_BATCH_TO_SPACE_INT8_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/quantize.h"
#ifdef __cplusplus
extern "C" {

View File

@ -15,7 +15,7 @@
*/
#include "nnacl/int8/common_func_int8.h"
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/int8/fixed_point.h"
void PostConvFuncCommInt8(const int32_t *in, int8_t *out, const int32_t *bias, size_t oc, size_t plane,
size_t out_oc_stride, size_t in_plane_stride, int32_t multiplier, int32_t mini, int32_t maxi,

View File

@ -17,8 +17,6 @@
#ifndef MINDSPORE_LITE_NNACL_INT8_COMMON_FUNC_H_
#define MINDSPORE_LITE_NNACL_INT8_COMMON_FUNC_H_
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#ifdef ENABLE_NEON
#include <arm_neon.h>
@ -29,9 +27,6 @@
#ifdef __cplusplus
extern "C" {
#endif
void PostFuncInt8C8(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, int32_t multiplier,
int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, int32_t maxi);
void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride,
int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini,
int32_t maxi);

View File

@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/int8/conv1x1_int8.h"
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
                    const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
                    int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int *filter_zp) {
  /* Runs a 1x1 int8 convolution as a matmul through the caller-supplied
   * optimized kernel (deep4-packed operands). Per-output-channel quantization
   * is selected when the filter carries more than one quant argument. */
  int per_channel = (conv_param->conv_quant_arg_.filter_arg_num_ == 1) ? 0 : 1;
  matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias,
              left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
              conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], per_channel,
              filter_zp);
}
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
                 const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
                 int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp) {
  /* 1x1 int8 convolution expressed as a quantized matmul over deep16-packed
   * operands. Per-output-channel quantization is used when the filter has
   * more than one quant argument. */
  int per_channel = (conv_param->conv_quant_arg_.filter_arg_num_ == 1) ? 0 : 1;
  MatmulInt8Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
                conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
                conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift,
                conv_param->output_channel_, per_channel, filter_zp);
}

View File

@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_INT8_CONV1X1_INT8_H_
#define MINDSPORE_LITE_NNACL_INT8_CONV1X1_INT8_H_

#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/pack.h"
#include "nnacl/op_base.h"
#include "nnacl/common_func.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/int8/matmul_int8.h"

#ifdef __cplusplus
extern "C" {
#endif
// 1x1 int8 convolution as a packed int8 GEMM (deep16-packed inputs/weights).
// Requantization parameters (shifts, multipliers, zero-points, clamp range)
// are taken from conv_param / the explicit shift arrays.
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
                 const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
                 int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp);
// Variant that dispatches to a caller-supplied optimized matmul kernel
// (deep4-packed data, e.g. dot-product instruction implementations).
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
                    const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
                    int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int32_t *filter_zp);
#ifdef __cplusplus
}
#endif
#endif  // MINDSPORE_LITE_NNACL_INT8_CONV1X1_INT8_H_

View File

@ -0,0 +1,900 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/int8/conv3x3_int8.h"
// Winograd F(2x2,3x3) input transform for one 4x4 tile over C8NUM channels.
// Computes BT * (d - input_zp) * B, where d is the 4x4 patch stored in
// tmp_data (row-major positions, C8NUM channels packed innermost) and
//   BT = | 1  0 -1  0 |
//        | 0  1  1  0 |
//        | 0 -1  1  0 |
//        | 0  1  0 -1 |
// The 16 transformed values are scattered to trans_input_data with `step`
// int16 elements between consecutive transform positions.
void Conv3x3Int8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) {
#ifdef ENABLE_ARM
  // Load the 16 tile positions (8 channels per vector) and remove the
  // input zero-point so zero-point-padded borders contribute exact zeros.
  int16x8_t zp = vdupq_n_s16(input_zp);

  int16x8_t d00 = vsubq_s16(vld1q_s16(tmp_data), zp);
  int16x8_t d01 = vsubq_s16(vld1q_s16(tmp_data + 8), zp);
  int16x8_t d02 = vsubq_s16(vld1q_s16(tmp_data + 2 * 8), zp);
  int16x8_t d03 = vsubq_s16(vld1q_s16(tmp_data + 3 * 8), zp);

  int16x8_t d10 = vsubq_s16(vld1q_s16(tmp_data + 4 * 8), zp);
  int16x8_t d11 = vsubq_s16(vld1q_s16(tmp_data + 5 * 8), zp);
  int16x8_t d12 = vsubq_s16(vld1q_s16(tmp_data + 6 * 8), zp);
  int16x8_t d13 = vsubq_s16(vld1q_s16(tmp_data + 7 * 8), zp);

  int16x8_t d20 = vsubq_s16(vld1q_s16(tmp_data + 8 * 8), zp);
  int16x8_t d21 = vsubq_s16(vld1q_s16(tmp_data + 9 * 8), zp);
  int16x8_t d22 = vsubq_s16(vld1q_s16(tmp_data + 10 * 8), zp);
  int16x8_t d23 = vsubq_s16(vld1q_s16(tmp_data + 11 * 8), zp);

  int16x8_t d30 = vsubq_s16(vld1q_s16(tmp_data + 12 * 8), zp);
  int16x8_t d31 = vsubq_s16(vld1q_s16(tmp_data + 13 * 8), zp);
  int16x8_t d32 = vsubq_s16(vld1q_s16(tmp_data + 14 * 8), zp);
  int16x8_t d33 = vsubq_s16(vld1q_s16(tmp_data + 15 * 8), zp);

  // Row pass: t = BT * d.
  int16x8_t t00 = vsubq_s16(d00, d20);
  int16x8_t t01 = vsubq_s16(d01, d21);
  int16x8_t t02 = vsubq_s16(d02, d22);
  int16x8_t t03 = vsubq_s16(d03, d23);

  int16x8_t t10 = vaddq_s16(d10, d20);
  int16x8_t t11 = vaddq_s16(d11, d21);
  int16x8_t t12 = vaddq_s16(d12, d22);
  int16x8_t t13 = vaddq_s16(d13, d23);

  int16x8_t t20 = vsubq_s16(d20, d10);
  int16x8_t t21 = vsubq_s16(d21, d11);
  int16x8_t t22 = vsubq_s16(d22, d12);
  int16x8_t t23 = vsubq_s16(d23, d13);

  int16x8_t t30 = vsubq_s16(d10, d30);
  int16x8_t t31 = vsubq_s16(d11, d31);
  int16x8_t t32 = vsubq_s16(d12, d32);
  int16x8_t t33 = vsubq_s16(d13, d33);

  // Column pass: m = t * B.
  int16x8_t m00 = vsubq_s16(t00, t02);
  int16x8_t m01 = vaddq_s16(t01, t02);
  int16x8_t m02 = vsubq_s16(t02, t01);
  int16x8_t m03 = vsubq_s16(t01, t03);

  int16x8_t m10 = vsubq_s16(t10, t12);
  int16x8_t m11 = vaddq_s16(t11, t12);
  int16x8_t m12 = vsubq_s16(t12, t11);
  int16x8_t m13 = vsubq_s16(t11, t13);

  int16x8_t m20 = vsubq_s16(t20, t22);
  int16x8_t m21 = vaddq_s16(t21, t22);
  int16x8_t m22 = vsubq_s16(t22, t21);
  int16x8_t m23 = vsubq_s16(t21, t23);

  int16x8_t m30 = vsubq_s16(t30, t32);
  int16x8_t m31 = vaddq_s16(t31, t32);
  int16x8_t m32 = vsubq_s16(t32, t31);
  int16x8_t m33 = vsubq_s16(t31, t33);

  // Scatter the 16 transform outputs, `step` elements apart.
  vst1q_s16(trans_input_data, m00);
  vst1q_s16(trans_input_data + step, m01);
  vst1q_s16(trans_input_data + 2 * step, m02);
  vst1q_s16(trans_input_data + 3 * step, m03);

  vst1q_s16(trans_input_data + 4 * step, m10);
  vst1q_s16(trans_input_data + 5 * step, m11);
  vst1q_s16(trans_input_data + 6 * step, m12);
  vst1q_s16(trans_input_data + 7 * step, m13);

  vst1q_s16(trans_input_data + 8 * step, m20);
  vst1q_s16(trans_input_data + 9 * step, m21);
  vst1q_s16(trans_input_data + 10 * step, m22);
  vst1q_s16(trans_input_data + 11 * step, m23);

  vst1q_s16(trans_input_data + 12 * step, m30);
  vst1q_s16(trans_input_data + 13 * step, m31);
  vst1q_s16(trans_input_data + 14 * step, m32);
  vst1q_s16(trans_input_data + 15 * step, m33);
#else
  // Scalar fallback: identical math, one of the C8NUM channels at a time.
  for (int i = 0; i < C8NUM; i++) {
    int16_t *local_ptr = tmp_data + i;
    int16_t d00 = local_ptr[0] - input_zp;
    int16_t d01 = (local_ptr + C8NUM)[0] - input_zp;
    int16_t d02 = (local_ptr + 2 * C8NUM)[0] - input_zp;
    int16_t d03 = (local_ptr + 3 * C8NUM)[0] - input_zp;

    int16_t d10 = (local_ptr + 4 * C8NUM)[0] - input_zp;
    int16_t d11 = (local_ptr + 5 * C8NUM)[0] - input_zp;
    int16_t d12 = (local_ptr + 6 * C8NUM)[0] - input_zp;
    int16_t d13 = (local_ptr + 7 * C8NUM)[0] - input_zp;

    int16_t d20 = (local_ptr + 8 * C8NUM)[0] - input_zp;
    int16_t d21 = (local_ptr + 9 * C8NUM)[0] - input_zp;
    int16_t d22 = (local_ptr + 10 * C8NUM)[0] - input_zp;
    int16_t d23 = (local_ptr + 11 * C8NUM)[0] - input_zp;

    int16_t d30 = (local_ptr + 12 * C8NUM)[0] - input_zp;
    int16_t d31 = (local_ptr + 13 * C8NUM)[0] - input_zp;
    int16_t d32 = (local_ptr + 14 * C8NUM)[0] - input_zp;
    int16_t d33 = (local_ptr + 15 * C8NUM)[0] - input_zp;

    // Row pass: t = BT * d.
    int16_t t00 = d00 - d20;
    int16_t t01 = d01 - d21;
    int16_t t02 = d02 - d22;
    int16_t t03 = d03 - d23;

    int16_t t10 = d10 + d20;
    int16_t t11 = d11 + d21;
    int16_t t12 = d12 + d22;
    int16_t t13 = d13 + d23;

    int16_t t20 = d20 - d10;
    int16_t t21 = d21 - d11;
    int16_t t22 = d22 - d12;
    int16_t t23 = d23 - d13;

    int16_t t30 = d10 - d30;
    int16_t t31 = d11 - d31;
    int16_t t32 = d12 - d32;
    int16_t t33 = d13 - d33;

    // Column pass: m = t * B.
    int16_t m00 = t00 - t02;
    int16_t m01 = t01 + t02;
    int16_t m02 = t02 - t01;
    int16_t m03 = t01 - t03;

    int16_t m10 = t10 - t12;
    int16_t m11 = t11 + t12;
    int16_t m12 = t12 - t11;
    int16_t m13 = t11 - t13;

    int16_t m20 = t20 - t22;
    int16_t m21 = t21 + t22;
    int16_t m22 = t22 - t21;
    int16_t m23 = t21 - t23;

    int16_t m30 = t30 - t32;
    int16_t m31 = t31 + t32;
    int16_t m32 = t32 - t31;
    int16_t m33 = t31 - t33;

    (trans_input_data + i)[0] = m00;
    (trans_input_data + i + step)[0] = m01;
    (trans_input_data + i + 2 * step)[0] = m02;
    (trans_input_data + i + 3 * step)[0] = m03;

    (trans_input_data + i + 4 * step)[0] = m10;
    (trans_input_data + i + 5 * step)[0] = m11;
    (trans_input_data + i + 6 * step)[0] = m12;
    (trans_input_data + i + 7 * step)[0] = m13;

    (trans_input_data + i + 8 * step)[0] = m20;
    (trans_input_data + i + 9 * step)[0] = m21;
    (trans_input_data + i + 10 * step)[0] = m22;
    (trans_input_data + i + 11 * step)[0] = m23;

    (trans_input_data + i + 12 * step)[0] = m30;
    (trans_input_data + i + 13 * step)[0] = m31;
    (trans_input_data + i + 14 * step)[0] = m32;
    (trans_input_data + i + 15 * step)[0] = m33;
  }
#endif
}
// Winograd F(2x2,3x3) filter transform. Each 3x3 kernel g is mapped to a
// 4x4 tile via (2G) * g * (2G)^T, where
//   G = |  1    0    0  |
//       | 1/2  1/2  1/2 |
//       | 1/2 -1/2  1/2 |
//       |  0    0    1  |
// The x2 scaling on each side keeps the arithmetic integral; it is undone by
// the two >>1 shifts applied in the output transform.
// Layout: weights are regrouped into oc4 blocks of C4NUM output channels;
// within a block the 16 transform positions are dst_step (= iC8*C8NUM*C4NUM)
// elements apart, with C8NUM input channels x C4NUM output channels packed
// per position.
void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel,
                                int kernel_plane) {
  const int input_unit = 4;
  int dst_step = iC8 * C8NUM * C4NUM;
  for (int o = 0; o < output_channel; o++) {
    int oc4_block_num = o / C4NUM;
    int oc4_block_rem = o % C4NUM;
    int src_oc_offset = o * iC8 * C8NUM * kernel_plane;
    int dst_oc_offset = oc4_block_num * C4NUM * iC8 * C8NUM * input_unit * input_unit + oc4_block_rem;
    for (int i = 0; i < iC8; i++) {
      const int16_t *src_ic8_ptr = weight_data + src_oc_offset + i * kernel_plane * C8NUM;
      int16_t *dst_ic8_ptr = trans_weight + dst_oc_offset + i * C4NUM * C8NUM;
#ifdef ENABLE_ARM
      // Load the 3x3 kernel, 8 input channels per vector (g[row][col]).
      int16x8_t g00 = vld1q_s16(src_ic8_ptr);
      int16x8_t g01 = vld1q_s16(src_ic8_ptr + 8);
      int16x8_t g02 = vld1q_s16(src_ic8_ptr + 2 * 8);
      int16x8_t g10 = vld1q_s16(src_ic8_ptr + 3 * 8);
      int16x8_t g11 = vld1q_s16(src_ic8_ptr + 4 * 8);
      int16x8_t g12 = vld1q_s16(src_ic8_ptr + 5 * 8);
      int16x8_t g20 = vld1q_s16(src_ic8_ptr + 6 * 8);
      int16x8_t g21 = vld1q_s16(src_ic8_ptr + 7 * 8);
      int16x8_t g22 = vld1q_s16(src_ic8_ptr + 8 * 8);

      // Row pass: dst = (2G) * g.
      int16x8_t dst00 = vmulq_n_s16(g00, 2);
      int16x8_t dst01 = vmulq_n_s16(g01, 2);
      int16x8_t dst02 = vmulq_n_s16(g02, 2);

      int16x8_t dst10 = vaddq_s16(vaddq_s16(g00, g10), g20);
      int16x8_t dst11 = vaddq_s16(vaddq_s16(g01, g11), g21);
      int16x8_t dst12 = vaddq_s16(vaddq_s16(g02, g12), g22);

      int16x8_t dst20 = vaddq_s16(vsubq_s16(g00, g10), g20);
      int16x8_t dst21 = vaddq_s16(vsubq_s16(g01, g11), g21);
      int16x8_t dst22 = vaddq_s16(vsubq_s16(g02, g12), g22);

      int16x8_t dst30 = vmulq_n_s16(g20, 2);
      int16x8_t dst31 = vmulq_n_s16(g21, 2);
      int16x8_t dst32 = vmulq_n_s16(g22, 2);

      // Column pass: m = dst * (2G)^T.
      int16x8_t m00 = vmulq_n_s16(dst00, 2);
      int16x8_t m01 = vaddq_s16(vaddq_s16(dst00, dst01), dst02);
      int16x8_t m02 = vaddq_s16(vsubq_s16(dst00, dst01), dst02);
      int16x8_t m03 = vmulq_n_s16(dst02, 2);

      int16x8_t m10 = vmulq_n_s16(dst10, 2);
      int16x8_t m11 = vaddq_s16(vaddq_s16(dst10, dst11), dst12);
      int16x8_t m12 = vaddq_s16(vsubq_s16(dst10, dst11), dst12);
      int16x8_t m13 = vmulq_n_s16(dst12, 2);

      int16x8_t m20 = vmulq_n_s16(dst20, 2);
      int16x8_t m21 = vaddq_s16(vaddq_s16(dst20, dst21), dst22);
      int16x8_t m22 = vaddq_s16(vsubq_s16(dst20, dst21), dst22);
      int16x8_t m23 = vmulq_n_s16(dst22, 2);

      int16x8_t m30 = vmulq_n_s16(dst30, 2);
      int16x8_t m31 = vaddq_s16(vaddq_s16(dst30, dst31), dst32);
      int16x8_t m32 = vaddq_s16(vsubq_s16(dst30, dst31), dst32);
      int16x8_t m33 = vmulq_n_s16(dst32, 2);

      // Store lane-by-lane: stride C4NUM between input channels, dst_step
      // between the 16 transform positions.
      dst_ic8_ptr[0] = m00[0];
      dst_ic8_ptr[4] = m00[1];
      dst_ic8_ptr[8] = m00[2];
      dst_ic8_ptr[12] = m00[3];
      dst_ic8_ptr[16] = m00[4];
      dst_ic8_ptr[20] = m00[5];
      dst_ic8_ptr[24] = m00[6];
      dst_ic8_ptr[28] = m00[7];

      dst_ic8_ptr[0 + dst_step] = m01[0];
      dst_ic8_ptr[4 + dst_step] = m01[1];
      dst_ic8_ptr[8 + dst_step] = m01[2];
      dst_ic8_ptr[12 + dst_step] = m01[3];
      dst_ic8_ptr[16 + dst_step] = m01[4];
      dst_ic8_ptr[20 + dst_step] = m01[5];
      dst_ic8_ptr[24 + dst_step] = m01[6];
      dst_ic8_ptr[28 + dst_step] = m01[7];

      dst_ic8_ptr[0 + 2 * dst_step] = m02[0];
      dst_ic8_ptr[4 + 2 * dst_step] = m02[1];
      dst_ic8_ptr[8 + 2 * dst_step] = m02[2];
      dst_ic8_ptr[12 + 2 * dst_step] = m02[3];
      dst_ic8_ptr[16 + 2 * dst_step] = m02[4];
      dst_ic8_ptr[20 + 2 * dst_step] = m02[5];
      dst_ic8_ptr[24 + 2 * dst_step] = m02[6];
      dst_ic8_ptr[28 + 2 * dst_step] = m02[7];

      dst_ic8_ptr[0 + 3 * dst_step] = m03[0];
      dst_ic8_ptr[4 + 3 * dst_step] = m03[1];
      dst_ic8_ptr[8 + 3 * dst_step] = m03[2];
      dst_ic8_ptr[12 + 3 * dst_step] = m03[3];
      dst_ic8_ptr[16 + 3 * dst_step] = m03[4];
      dst_ic8_ptr[20 + 3 * dst_step] = m03[5];
      dst_ic8_ptr[24 + 3 * dst_step] = m03[6];
      dst_ic8_ptr[28 + 3 * dst_step] = m03[7];

      dst_ic8_ptr[0 + 4 * dst_step] = m10[0];
      dst_ic8_ptr[4 + 4 * dst_step] = m10[1];
      dst_ic8_ptr[8 + 4 * dst_step] = m10[2];
      dst_ic8_ptr[12 + 4 * dst_step] = m10[3];
      dst_ic8_ptr[16 + 4 * dst_step] = m10[4];
      dst_ic8_ptr[20 + 4 * dst_step] = m10[5];
      dst_ic8_ptr[24 + 4 * dst_step] = m10[6];
      dst_ic8_ptr[28 + 4 * dst_step] = m10[7];

      dst_ic8_ptr[0 + 5 * dst_step] = m11[0];
      dst_ic8_ptr[4 + 5 * dst_step] = m11[1];
      dst_ic8_ptr[8 + 5 * dst_step] = m11[2];
      dst_ic8_ptr[12 + 5 * dst_step] = m11[3];
      dst_ic8_ptr[16 + 5 * dst_step] = m11[4];
      dst_ic8_ptr[20 + 5 * dst_step] = m11[5];
      dst_ic8_ptr[24 + 5 * dst_step] = m11[6];
      dst_ic8_ptr[28 + 5 * dst_step] = m11[7];

      dst_ic8_ptr[0 + 6 * dst_step] = m12[0];
      dst_ic8_ptr[4 + 6 * dst_step] = m12[1];
      dst_ic8_ptr[8 + 6 * dst_step] = m12[2];
      dst_ic8_ptr[12 + 6 * dst_step] = m12[3];
      dst_ic8_ptr[16 + 6 * dst_step] = m12[4];
      dst_ic8_ptr[20 + 6 * dst_step] = m12[5];
      dst_ic8_ptr[24 + 6 * dst_step] = m12[6];
      dst_ic8_ptr[28 + 6 * dst_step] = m12[7];

      dst_ic8_ptr[0 + 7 * dst_step] = m13[0];
      dst_ic8_ptr[4 + 7 * dst_step] = m13[1];
      dst_ic8_ptr[8 + 7 * dst_step] = m13[2];
      dst_ic8_ptr[12 + 7 * dst_step] = m13[3];
      dst_ic8_ptr[16 + 7 * dst_step] = m13[4];
      dst_ic8_ptr[20 + 7 * dst_step] = m13[5];
      dst_ic8_ptr[24 + 7 * dst_step] = m13[6];
      dst_ic8_ptr[28 + 7 * dst_step] = m13[7];

      dst_ic8_ptr[0 + 8 * dst_step] = m20[0];
      dst_ic8_ptr[4 + 8 * dst_step] = m20[1];
      dst_ic8_ptr[8 + 8 * dst_step] = m20[2];
      dst_ic8_ptr[12 + 8 * dst_step] = m20[3];
      dst_ic8_ptr[16 + 8 * dst_step] = m20[4];
      dst_ic8_ptr[20 + 8 * dst_step] = m20[5];
      dst_ic8_ptr[24 + 8 * dst_step] = m20[6];
      dst_ic8_ptr[28 + 8 * dst_step] = m20[7];

      dst_ic8_ptr[0 + 9 * dst_step] = m21[0];
      dst_ic8_ptr[4 + 9 * dst_step] = m21[1];
      dst_ic8_ptr[8 + 9 * dst_step] = m21[2];
      dst_ic8_ptr[12 + 9 * dst_step] = m21[3];
      dst_ic8_ptr[16 + 9 * dst_step] = m21[4];
      dst_ic8_ptr[20 + 9 * dst_step] = m21[5];
      dst_ic8_ptr[24 + 9 * dst_step] = m21[6];
      dst_ic8_ptr[28 + 9 * dst_step] = m21[7];

      dst_ic8_ptr[0 + 10 * dst_step] = m22[0];
      dst_ic8_ptr[4 + 10 * dst_step] = m22[1];
      dst_ic8_ptr[8 + 10 * dst_step] = m22[2];
      dst_ic8_ptr[12 + 10 * dst_step] = m22[3];
      dst_ic8_ptr[16 + 10 * dst_step] = m22[4];
      dst_ic8_ptr[20 + 10 * dst_step] = m22[5];
      dst_ic8_ptr[24 + 10 * dst_step] = m22[6];
      dst_ic8_ptr[28 + 10 * dst_step] = m22[7];

      dst_ic8_ptr[0 + 11 * dst_step] = m23[0];
      dst_ic8_ptr[4 + 11 * dst_step] = m23[1];
      dst_ic8_ptr[8 + 11 * dst_step] = m23[2];
      dst_ic8_ptr[12 + 11 * dst_step] = m23[3];
      dst_ic8_ptr[16 + 11 * dst_step] = m23[4];
      dst_ic8_ptr[20 + 11 * dst_step] = m23[5];
      dst_ic8_ptr[24 + 11 * dst_step] = m23[6];
      dst_ic8_ptr[28 + 11 * dst_step] = m23[7];

      dst_ic8_ptr[0 + 12 * dst_step] = m30[0];
      dst_ic8_ptr[4 + 12 * dst_step] = m30[1];
      dst_ic8_ptr[8 + 12 * dst_step] = m30[2];
      dst_ic8_ptr[12 + 12 * dst_step] = m30[3];
      dst_ic8_ptr[16 + 12 * dst_step] = m30[4];
      dst_ic8_ptr[20 + 12 * dst_step] = m30[5];
      dst_ic8_ptr[24 + 12 * dst_step] = m30[6];
      dst_ic8_ptr[28 + 12 * dst_step] = m30[7];

      dst_ic8_ptr[0 + 13 * dst_step] = m31[0];
      dst_ic8_ptr[4 + 13 * dst_step] = m31[1];
      dst_ic8_ptr[8 + 13 * dst_step] = m31[2];
      dst_ic8_ptr[12 + 13 * dst_step] = m31[3];
      dst_ic8_ptr[16 + 13 * dst_step] = m31[4];
      dst_ic8_ptr[20 + 13 * dst_step] = m31[5];
      dst_ic8_ptr[24 + 13 * dst_step] = m31[6];
      dst_ic8_ptr[28 + 13 * dst_step] = m31[7];

      dst_ic8_ptr[0 + 14 * dst_step] = m32[0];
      dst_ic8_ptr[4 + 14 * dst_step] = m32[1];
      dst_ic8_ptr[8 + 14 * dst_step] = m32[2];
      dst_ic8_ptr[12 + 14 * dst_step] = m32[3];
      dst_ic8_ptr[16 + 14 * dst_step] = m32[4];
      dst_ic8_ptr[20 + 14 * dst_step] = m32[5];
      dst_ic8_ptr[24 + 14 * dst_step] = m32[6];
      dst_ic8_ptr[28 + 14 * dst_step] = m32[7];

      dst_ic8_ptr[0 + 15 * dst_step] = m33[0];
      dst_ic8_ptr[4 + 15 * dst_step] = m33[1];
      dst_ic8_ptr[8 + 15 * dst_step] = m33[2];
      dst_ic8_ptr[12 + 15 * dst_step] = m33[3];
      dst_ic8_ptr[16 + 15 * dst_step] = m33[4];
      dst_ic8_ptr[20 + 15 * dst_step] = m33[5];
      dst_ic8_ptr[24 + 15 * dst_step] = m33[6];
      dst_ic8_ptr[28 + 15 * dst_step] = m33[7];
#else
      // Scalar fallback: same two-pass transform per input channel; the
      // hard-coded offsets 8/16/24/... index g[row][col] at C8NUM stride.
      for (int j = 0; j < C8NUM; j++) {
        const int16_t *local_ptr = src_ic8_ptr + j;
        int16_t dst00 = local_ptr[0] * 2;
        int16_t dst01 = (local_ptr + 8)[0] * 2;
        int16_t dst02 = (local_ptr + 16)[0] * 2;

        int16_t dst10 = local_ptr[0] + (local_ptr + 24)[0] + (local_ptr + 48)[0];
        int16_t dst11 = (local_ptr + 8)[0] + (local_ptr + 32)[0] + (local_ptr + 56)[0];
        int16_t dst12 = (local_ptr + 16)[0] + (local_ptr + 40)[0] + (local_ptr + 64)[0];

        int16_t dst20 = local_ptr[0] - (local_ptr + 24)[0] + (local_ptr + 48)[0];
        int16_t dst21 = (local_ptr + 8)[0] - (local_ptr + 32)[0] + (local_ptr + 56)[0];
        int16_t dst22 = (local_ptr + 16)[0] - (local_ptr + 40)[0] + (local_ptr + 64)[0];

        int16_t dst30 = (local_ptr + 48)[0] * 2;
        int16_t dst31 = (local_ptr + 56)[0] * 2;
        int16_t dst32 = (local_ptr + 64)[0] * 2;

        int16_t m00 = dst00 * 2;
        int16_t m01 = dst00 + dst01 + dst02;
        int16_t m02 = dst00 - dst01 + dst02;
        int16_t m03 = dst02 * 2;

        int16_t m10 = dst10 * 2;
        int16_t m11 = dst10 + dst11 + dst12;
        int16_t m12 = dst10 - dst11 + dst12;
        int16_t m13 = dst12 * 2;

        int16_t m20 = dst20 * 2;
        int16_t m21 = dst20 + dst21 + dst22;
        int16_t m22 = dst20 - dst21 + dst22;
        int16_t m23 = dst22 * 2;

        int16_t m30 = dst30 * 2;
        int16_t m31 = dst30 + dst31 + dst32;
        int16_t m32 = dst30 - dst31 + dst32;
        int16_t m33 = dst32 * 2;

        *(dst_ic8_ptr + j * 4) = m00;
        *(dst_ic8_ptr + j * 4 + dst_step) = m01;
        *(dst_ic8_ptr + j * 4 + 2 * dst_step) = m02;
        *(dst_ic8_ptr + j * 4 + 3 * dst_step) = m03;

        *(dst_ic8_ptr + j * 4 + 4 * dst_step) = m10;
        *(dst_ic8_ptr + j * 4 + 5 * dst_step) = m11;
        *(dst_ic8_ptr + j * 4 + 6 * dst_step) = m12;
        *(dst_ic8_ptr + j * 4 + 7 * dst_step) = m13;

        *(dst_ic8_ptr + j * 4 + 8 * dst_step) = m20;
        *(dst_ic8_ptr + j * 4 + 9 * dst_step) = m21;
        *(dst_ic8_ptr + j * 4 + 10 * dst_step) = m22;
        *(dst_ic8_ptr + j * 4 + 11 * dst_step) = m23;

        *(dst_ic8_ptr + j * 4 + 12 * dst_step) = m30;
        *(dst_ic8_ptr + j * 4 + 13 * dst_step) = m31;
        *(dst_ic8_ptr + j * 4 + 14 * dst_step) = m32;
        *(dst_ic8_ptr + j * 4 + 15 * dst_step) = m33;
      }
#endif
    }
  }
}
// Winograd F(2x2,3x3) output transform for one 4x4 tile of C4NUM channels.
// Applies AT * m * A with
//   AT = | 1  1  1  0 |
//        | 0  1 -1 -1 |
// using a >>1 (or /2) after each pass to cancel the x2-per-side scaling of
// the filter transform, adds the 4-channel bias, requantizes with
// multiplier/left-shift/right-shift (per-channel when FILTER_PER_CHANNEL is
// set, per-tensor otherwise), adds the output zero-point and clamps to
// [out_min, out_max]. The resulting 2x2 spatial block is written to
// output_data; w_not_bound / h_not_bound gate the right column / bottom row
// when the tile overhangs the output edge.
// NOTE(review): real_num is accepted but never read; all C4NUM lanes are
// computed and written — callers must size destinations for full C4NUM
// blocks. Confirm against callers before relying on partial-channel tiles.
void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound,
                           bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) {
  int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_;
  int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_;
  int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
  int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
  int out_min = conv_param->conv_quant_arg_.out_act_min_[0];
  int out_max = conv_param->conv_quant_arg_.out_act_max_[0];

#ifdef ENABLE_ARM
  // Load the 16 accumulator positions, 4 channels per vector.
  int32x4_t bias_ptr = vld1q_s32(bias_data);

  int32x4_t s00 = vld1q_s32(gemm_out);
  int32x4_t s01 = vld1q_s32(gemm_out + 4);
  int32x4_t s02 = vld1q_s32(gemm_out + 8);
  int32x4_t s03 = vld1q_s32(gemm_out + 12);

  int32x4_t s10 = vld1q_s32(gemm_out + 16);
  int32x4_t s11 = vld1q_s32(gemm_out + 20);
  int32x4_t s12 = vld1q_s32(gemm_out + 24);
  int32x4_t s13 = vld1q_s32(gemm_out + 28);

  int32x4_t s20 = vld1q_s32(gemm_out + 32);
  int32x4_t s21 = vld1q_s32(gemm_out + 36);
  int32x4_t s22 = vld1q_s32(gemm_out + 40);
  int32x4_t s23 = vld1q_s32(gemm_out + 44);

  int32x4_t s30 = vld1q_s32(gemm_out + 48);
  int32x4_t s31 = vld1q_s32(gemm_out + 52);
  int32x4_t s32 = vld1q_s32(gemm_out + 56);
  int32x4_t s33 = vld1q_s32(gemm_out + 60);

  // Row pass (AT * s), halved to undo one factor-2 of the filter scaling.
  int32x4_t t00 = vshrq_n_s32(vaddq_s32(vaddq_s32(s00, s10), s20), 1);
  int32x4_t t01 = vshrq_n_s32(vaddq_s32(vaddq_s32(s01, s11), s21), 1);
  int32x4_t t02 = vshrq_n_s32(vaddq_s32(vaddq_s32(s02, s12), s22), 1);
  int32x4_t t03 = vshrq_n_s32(vaddq_s32(vaddq_s32(s03, s13), s23), 1);

  int32x4_t t10 = vshrq_n_s32(vsubq_s32(vsubq_s32(s10, s20), s30), 1);
  int32x4_t t11 = vshrq_n_s32(vsubq_s32(vsubq_s32(s11, s21), s31), 1);
  int32x4_t t12 = vshrq_n_s32(vsubq_s32(vsubq_s32(s12, s22), s32), 1);
  int32x4_t t13 = vshrq_n_s32(vsubq_s32(vsubq_s32(s13, s23), s33), 1);

  // Column pass (t * A), halved again, plus bias.
  int32x4_t d00 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t00, t01), t02), 1), bias_ptr);
  int32x4_t d01 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t01, t02), t03), 1), bias_ptr);

  int32x4_t d10 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t10, t11), t12), 1), bias_ptr);
  int32x4_t d11 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t11, t12), t13), 1), bias_ptr);

  // Requantization parameters: per-output-channel when enabled, otherwise
  // a single broadcast value.
  int32x4_t out_multiplier;
  int32x4_t ls;
  int32x4_t rs;
  if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
    out_multiplier = vld1q_s32(quant_multiplier + oc_start);
    ls = vld1q_s32(left_shift + oc_start);
    rs = vld1q_s32(right_shift + oc_start);
  } else {
    out_multiplier = vdupq_n_s32(quant_multiplier[0]);
    ls = vdupq_n_s32(left_shift[0]);
    rs = vdupq_n_s32(right_shift[0]);
  }
  int32x4_t out_zp = vdupq_n_s32(output_zp);
  int32x4_t output_min = vdupq_n_s32(out_min);
  int32x4_t output_max = vdupq_n_s32(out_max);

  // Requantize: saturating doubling high multiply, then rounding right
  // shift (the carry term is a sign-based rounding correction — presumably
  // matching RoundingDivideByPOT in the scalar path; verify if modified),
  // then zero-point add and clamp.
  d00 = vqshlq_s32(d00, ls);
  d00 = vqrdmulhq_s32(d00, out_multiplier);
  int32x4_t carry = vandq_s32(d00, rs);
  carry = vshrq_n_s32(carry, 31);
  d00 = vqaddq_s32(d00, carry);
  d00 = vqrshlq_s32(d00, rs);
  d00 = vaddq_s32(d00, out_zp);
  d00 = vmaxq_s32(d00, output_min);
  d00 = vminq_s32(d00, output_max);

  d01 = vqshlq_s32(d01, ls);
  d01 = vqrdmulhq_s32(d01, out_multiplier);
  carry = vandq_s32(d01, rs);
  carry = vshrq_n_s32(carry, 31);
  d01 = vqaddq_s32(d01, carry);
  d01 = vqrshlq_s32(d01, rs);
  d01 = vaddq_s32(d01, out_zp);
  d01 = vmaxq_s32(d01, output_min);
  d01 = vminq_s32(d01, output_max);

  d10 = vqshlq_s32(d10, ls);
  d10 = vqrdmulhq_s32(d10, out_multiplier);
  carry = vandq_s32(d10, rs);
  carry = vshrq_n_s32(carry, 31);
  d10 = vqaddq_s32(d10, carry);
  d10 = vqrshlq_s32(d10, rs);
  d10 = vaddq_s32(d10, out_zp);
  d10 = vmaxq_s32(d10, output_min);
  d10 = vminq_s32(d10, output_max);

  d11 = vqshlq_s32(d11, ls);
  d11 = vqrdmulhq_s32(d11, out_multiplier);
  carry = vandq_s32(d11, rs);
  carry = vshrq_n_s32(carry, 31);
  d11 = vqaddq_s32(d11, carry);
  d11 = vqrshlq_s32(d11, rs);
  d11 = vaddq_s32(d11, out_zp);
  d11 = vmaxq_s32(d11, output_min);
  d11 = vminq_s32(d11, output_max);

  // Write the 2x2 output block; right/bottom edges only when in bounds.
  (output_data)[0] = (int8_t)d00[0];
  (output_data + 1)[0] = (int8_t)d00[1];
  (output_data + 2)[0] = (int8_t)d00[2];
  (output_data + 3)[0] = (int8_t)d00[3];

  if (w_not_bound) {
    *(output_data + 4) = (int8_t)d01[0];
    *(output_data + 5) = (int8_t)d01[1];
    *(output_data + 6) = (int8_t)d01[2];
    *(output_data + 7) = (int8_t)d01[3];
  }
  if (h_not_bound) {
    *(output_data + output_w * 4) = (int8_t)d10[0];
    *(output_data + output_w * 4 + 1) = (int8_t)d10[1];
    *(output_data + output_w * 4 + 2) = (int8_t)d10[2];
    *(output_data + output_w * 4 + 3) = (int8_t)d10[3];
    if (w_not_bound) {
      *(output_data + output_w * 4 + 4) = (int8_t)d11[0];
      *(output_data + output_w * 4 + 5) = (int8_t)d11[1];
      *(output_data + output_w * 4 + 6) = (int8_t)d11[2];
      *(output_data + output_w * 4 + 7) = (int8_t)d11[3];
    }
  }
#else
  if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
    // Per-output-channel requantization: each channel i uses the quant
    // params at index oc_start + i.
    for (int i = 0; i < C4NUM; i++) {
      const int32_t *local_ptr = gemm_out + i;
      const int32_t *bias_ptr = bias_data + i;

      int32_t s00 = local_ptr[0];
      int32_t s01 = (local_ptr + 4)[0];
      int32_t s02 = (local_ptr + 8)[0];
      int32_t s03 = (local_ptr + 12)[0];

      int32_t s10 = (local_ptr + 16)[0];
      int32_t s11 = (local_ptr + 20)[0];
      int32_t s12 = (local_ptr + 24)[0];
      int32_t s13 = (local_ptr + 28)[0];

      int32_t s20 = (local_ptr + 32)[0];
      int32_t s21 = (local_ptr + 36)[0];
      int32_t s22 = (local_ptr + 40)[0];
      int32_t s23 = (local_ptr + 44)[0];

      int32_t s30 = (local_ptr + 48)[0];
      int32_t s31 = (local_ptr + 52)[0];
      int32_t s32 = (local_ptr + 56)[0];
      int32_t s33 = (local_ptr + 60)[0];

      // Row pass (AT * s) / 2.
      int32_t t00 = (s00 + s10 + s20) / 2;
      int32_t t01 = (s01 + s11 + s21) / 2;
      int32_t t02 = (s02 + s12 + s22) / 2;
      int32_t t03 = (s03 + s13 + s23) / 2;

      int32_t t10 = (s10 - s20 - s30) / 2;
      int32_t t11 = (s11 - s21 - s31) / 2;
      int32_t t12 = (s12 - s22 - s32) / 2;
      int32_t t13 = (s13 - s23 - s33) / 2;

      // Column pass (t * A) / 2, plus bias.
      int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0];
      int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0];

      int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0];
      int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0];

      int oc_index = oc_start + i;
      d00 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
        -right_shift[oc_index]);
      d00 += output_zp;
      d00 = d00 > out_min ? d00 : out_min;
      d00 = d00 < out_max ? d00 : out_max;

      d01 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
        -right_shift[oc_index]);
      d01 += output_zp;
      d01 = d01 > out_min ? d01 : out_min;
      d01 = d01 < out_max ? d01 : out_max;

      d10 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
        -right_shift[oc_index]);
      d10 += output_zp;
      d10 = d10 > out_min ? d10 : out_min;
      d10 = d10 < out_max ? d10 : out_max;

      d11 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
        -right_shift[oc_index]);
      d11 += output_zp;
      d11 = d11 > out_min ? d11 : out_min;
      d11 = d11 < out_max ? d11 : out_max;

      (output_data + i)[0] = (int8_t)d00;
      if (w_not_bound) {
        (output_data + i + C4NUM)[0] = (int8_t)d01;
      }
      if (h_not_bound) {
        (output_data + i + output_w * C4NUM)[0] = (int8_t)d10;
        if (w_not_bound) {
          (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11;
        }
      }
    }
  } else {
    // Per-tensor requantization: all channels share quant params at index 0.
    for (int i = 0; i < C4NUM; i++) {
      const int32_t *local_ptr = gemm_out + i;
      const int32_t *bias_ptr = bias_data + i;

      int32_t s00 = local_ptr[0];
      int32_t s01 = (local_ptr + 4)[0];
      int32_t s02 = (local_ptr + 8)[0];
      int32_t s03 = (local_ptr + 12)[0];

      int32_t s10 = (local_ptr + 16)[0];
      int32_t s11 = (local_ptr + 20)[0];
      int32_t s12 = (local_ptr + 24)[0];
      int32_t s13 = (local_ptr + 28)[0];

      int32_t s20 = (local_ptr + 32)[0];
      int32_t s21 = (local_ptr + 36)[0];
      int32_t s22 = (local_ptr + 40)[0];
      int32_t s23 = (local_ptr + 44)[0];

      int32_t s30 = (local_ptr + 48)[0];
      int32_t s31 = (local_ptr + 52)[0];
      int32_t s32 = (local_ptr + 56)[0];
      int32_t s33 = (local_ptr + 60)[0];

      int32_t t00 = (s00 + s10 + s20) / 2;
      int32_t t01 = (s01 + s11 + s21) / 2;
      int32_t t02 = (s02 + s12 + s22) / 2;
      int32_t t03 = (s03 + s13 + s23) / 2;

      int32_t t10 = (s10 - s20 - s30) / 2;
      int32_t t11 = (s11 - s21 - s31) / 2;
      int32_t t12 = (s12 - s22 - s32) / 2;
      int32_t t13 = (s13 - s23 - s33) / 2;

      int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0];
      int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0];

      int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0];
      int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0];

      d00 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
        -right_shift[0]);
      d00 += output_zp;
      d00 = d00 > out_min ? d00 : out_min;
      d00 = d00 < out_max ? d00 : out_max;

      d01 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
        -right_shift[0]);
      d01 += output_zp;
      d01 = d01 > out_min ? d01 : out_min;
      d01 = d01 < out_max ? d01 : out_max;

      d10 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
        -right_shift[0]);
      d10 += output_zp;
      d10 = d10 > out_min ? d10 : out_min;
      d10 = d10 < out_max ? d10 : out_max;

      d11 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
        -right_shift[0]);
      d11 += output_zp;
      d11 = d11 > out_min ? d11 : out_min;
      d11 = d11 < out_max ? d11 : out_max;

      (output_data + i)[0] = (int8_t)d00;
      if (w_not_bound) {
        (output_data + i + C4NUM)[0] = (int8_t)d01;
      }
      if (h_not_bound) {
        (output_data + i + output_w * C4NUM)[0] = (int8_t)d10;
        if (w_not_bound) {
          (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11;
        }
      }
    }
  }
#endif
}
// Output transform dispatcher: for each of the real_cal_num tiles starting
// at start_index, applies Conv3x3Int8OutputUnit per C4NUM-channel group and
// writes the resulting OUPUT_UNIT x OUPUT_UNIT (2x2) patch into the
// NC4HW4-laid-out destination (out_data).
void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index,
                                int real_cal_num, int out_w_block, ConvParameter *conv_param) {
  int output_channel = conv_param->output_channel_;
  int output_w = conv_param->output_w_;
  int output_h = conv_param->output_h_;
  const int oc4 = UP_DIV(output_channel, C4NUM);
  const int input_unit = 4;
  if (out_w_block == 0) {
    return;  // guard: avoids div/mod by zero below
  }
  for (int i = 0; i < real_cal_num; i++) {
    // Tile index -> (x, y) block coordinates in the output plane.
    int out_w_index = (start_index + i) % out_w_block;
    int out_h_index = (start_index + i) / out_w_block;
    int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit;
    int dst_tile_offset = C4NUM * (out_w_index * OUPUT_UNIT + out_h_index * OUPUT_UNIT * output_w);

    for (int j = 0; j < oc4; j++) {
      int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM;
      int dst_oc4_offset = dst_tile_offset + j * C4NUM * output_h * output_w;
      const int32_t *src_ptr = gemm_out + src_oc4_offset;
      const int32_t *bias_ptr = bias_data + j * C4NUM;
      int8_t *dst_ptr = out_data + dst_oc4_offset;

      // output transform
      int real_num = (output_channel - j * C4NUM) < C4NUM ? (output_channel - j * C4NUM) : C4NUM;
      // Edge tiles: suppress the second output column/row if out of range.
      bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w;
      bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h;
      Conv3x3Int8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, j * C4NUM,
                            conv_param);
    }
  }
}
// Input transform dispatcher: for each of the real_cal_num tiles starting at
// start_index, extracts the (possibly padded) 4x4 input patch per C8NUM
// channel group into tmp_data and runs Conv3x3Int8InputUnit on it.
// tmp_data is pre-filled with the input zero-point so padded positions
// contribute exact zeros after the zero-point subtraction in the unit.
void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index,
                               int real_cal_num, int out_w_block, ConvParameter *conv_param) {
  // input data format : nhwc
  int input_channel = conv_param->input_channel_;
  int input_width = conv_param->input_w_;
  int input_height = conv_param->input_h_;
  int pad_w = conv_param->pad_l_;
  int pad_h = conv_param->pad_u_;
  ConvQuantArg quant_arg = conv_param->conv_quant_arg_;
  int input_zp = quant_arg.input_quant_args_[0].zp_;
  const int ic8 = UP_DIV(input_channel, C8NUM);
  const int input_unit = 4;
  if (out_w_block == 0) {
    return;  // guard: avoids div/mod by zero below
  }
  for (int cal_id = 0; cal_id < real_cal_num; cal_id++) {
    int x_id = start_index + cal_id;
    // Top-left corner of the 4x4 patch in input coordinates (may be < 0 or
    // extend past the border due to padding).
    int origin_x = (x_id % out_w_block) * OUPUT_UNIT - pad_w;
    int origin_y = (x_id / out_w_block) * OUPUT_UNIT - pad_h;
    // In-bounds sub-rectangle of the patch.
    int real_x_start = origin_x > 0 ? 0 : -origin_x;
    int real_x_end = (origin_x + input_unit) < input_width ? input_unit : (input_width - origin_x);
    int real_y_start = origin_y > 0 ? 0 : -origin_y;
    int real_y_end = (origin_y + input_unit) < input_height ? input_unit : (input_height - origin_y);

    int src_plane_offset = C8NUM * (origin_y * input_width + origin_x);
    int dst_plane_offset = cal_id * C8NUM;
    for (int ic = 0; ic < ic8; ic++) {
      // copy data from origin input to tmp buffer; pre-fill with the input
      // zero-point so out-of-bounds (padded) positions cancel later.
      for (int i = 0; i < input_unit * input_unit * TILE_NUM; i++) tmp_data[i] = input_zp;

      int src_c8_offset = src_plane_offset + ic * C8NUM * input_height * input_width;

      for (int j = real_y_start; j < real_y_end; j++) {
        const int16_t *src = input_data + src_c8_offset + C8NUM * (j * input_width + real_x_start);
        // C4NUM here is the 4-element row stride of the tile (== input_unit).
        int16_t *dst = tmp_data + C8NUM * (C4NUM * j + real_x_start);
        memcpy(dst, src, (real_x_end - real_x_start) * C8NUM * sizeof(int16_t));
      }
      // input transform
      int dst_ic8_offset = dst_plane_offset + ic * TILE_NUM * C8NUM;
      size_t dst_step = ic8 * C8NUM * TILE_NUM;
      int16_t *trans_input_ptr = trans_input + dst_ic8_offset;
      Conv3x3Int8InputUnit(tmp_data, trans_input_ptr, dst_step, input_zp);
    }
  }
}
// Batched int16->int32 GEMM over the 16 Winograd transform positions:
// multiplies transformed inputs (src, ic8*C8NUM deep) by transformed
// weights for oc4 groups of C4NUM output channels.
// On ARM this delegates to the hand-written asm kernel; otherwise a
// reference loop nest accumulates each (tile, position, oc-lane) dot
// product over ic8*C8NUM input channels.
void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) {
  int oc4 = UP_DIV(oc, C4NUM);
#ifdef ENABLE_ARM
  IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, oc4 * 4 * 16 * sizeof(int32_t));
#else
  const int input_unit_square = 16;  // 4x4 transform positions per tile
  for (int c = 0; c < oc4; c++) {
    int filter_oc_offset = c * input_unit_square * ic8 * C8NUM * C4NUM;
    int dst_oc_offset = c * input_unit_square * C4NUM;
    for (int n = 0; n < real_cal_num; n++) {
      int src_tile_offset = n * C8NUM;
      int dst_tile_offset = dst_oc_offset + n * oc4 * C4NUM * input_unit_square;
      for (int i = 0; i < 4; i++) {  // transform row
        int filter_h_offset = filter_oc_offset + i * 4 * ic8 * C8NUM * C4NUM;
        int src_h_offset = src_tile_offset + i * C8NUM * ic8 * C8NUM * C4NUM;
        int dst_h_offset = dst_tile_offset + i * 4 * 4;
        for (int m = 0; m < 4; m++) {  // transform column
          int filter_w_offset = filter_h_offset + m * 4 * C8NUM * ic8;
          int src_w_offset = src_h_offset + m * 8 * ic8 * C8NUM;
          int dst_w_offset = dst_h_offset + m * C4NUM;

          int32_t acc[4] = {0};
          for (int z = 0; z < 4; z++) {  // output-channel lane within oc4
            int filter_offset = filter_w_offset + z;
            for (int j = 0; j < ic8; j++) {
              int filter_c8_offset = filter_offset + j * 4 * 8;
              int src_c8_offset = src_w_offset + j * 8 * 8;

              for (int k = 0; k < 8; k++) {
                const int16_t *w_ptr = weight + filter_c8_offset + k * 4;
                const int16_t *input_ptr = src + src_c8_offset + k;
                acc[z] += w_ptr[0] * input_ptr[0];
              }
            }
            (dst + dst_w_offset + z)[0] = acc[z];
          }
        }
      }
    }
  }
#endif
}
// Int8 3x3 convolution via Winograd F(2x2,3x3): per batch, processes the
// output in TILE_NUM-tile chunks distributed round-robin over threads by
// task_id — input transform, batched GEMM against the pre-transformed
// weights, then output transform into tmp_out.
// (output_data is kept for interface compatibility; results go to tmp_out.)
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
                 int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
                 int task_id, ConvParameter *conv_param) {
  const int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);
  const int oc4 = UP_DIV(conv_param->output_channel_, C4NUM);
  const int w_blocks = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
  const int h_blocks = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
  const int tile_total = w_blocks * h_blocks;
  const int tile_chunks = UP_DIV(tile_total, TILE_NUM);
  // Per-thread scratch areas, offset by task_id so threads never overlap.
  int16_t *thread_tile = tile_buffer + task_id * (TILE_NUM * 16 * ic8 * C8NUM);
  int16_t *thread_unit = block_unit_buffer + task_id * (16 * C8NUM);
  int32_t *thread_dst = tmp_dst_buffer + task_id * (TILE_NUM * 16 * oc4 * C4NUM);
  for (int batch = 0; batch < conv_param->input_batch_; batch++) {
    const int16_t *in_batch = input_data + batch * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_;
    int8_t *out_batch = tmp_out + batch * oc4 * C4NUM * conv_param->output_w_ * conv_param->output_h_;
    for (int chunk = task_id; chunk < tile_chunks; chunk += conv_param->thread_num_) {
      const int start_index = chunk * TILE_NUM;
      const int remain = tile_total - start_index;
      const int cal_num = remain < TILE_NUM ? remain : TILE_NUM;
      Conv3x3Int8InputTransform(in_batch, thread_tile, thread_unit, start_index, cal_num, w_blocks, conv_param);
      Conv3x3Int8Gemm(thread_dst, thread_tile, transed_weight, conv_param->output_channel_, ic8, cal_num);
      Conv3x3Int8OutputTransform(thread_dst, out_batch, bias_data, start_index, cal_num, w_blocks, conv_param);
    }
  }
}

View File

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_
#define MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_
#include <string.h>
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/pack.h"
#include "nnacl/op_base.h"
#include "nnacl/common_func.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/winograd_utils.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/int8/matmul_int8.h"
#include "nnacl/winograd_transform.h"
#include "nnacl/int8/common_func_int8.h"
#ifdef __cplusplus
extern "C" {
#endif
void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel,
int kernel_plane);
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
int task_id, ConvParameter *conv_param);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_

View File

@ -16,7 +16,7 @@
#include "nnacl/int8/conv_depthwise_int8.h"
#include <string.h>
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/int8/fixed_point.h"
#include "nnacl/int8/common_func_int8.h"
/*conv depthwise int8 begin*/

View File

@ -15,52 +15,6 @@
*/
#include "nnacl/int8/conv_int8.h"
#include <string.h>
#include "nnacl/winograd_transform.h"
#include "nnacl/int8/common_func_int8.h"
// Batched int16 -> int32 GEMM over the 16 (4x4) transformed positions of each
// winograd tile. Output channels are handled in blocks of C4NUM, input
// channels in blocks of C8NUM.
void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) {
  int oc4 = UP_DIV(oc, C4NUM);
#ifdef ENABLE_ARM
  // Assembly fast path; the #else branch below is the reference implementation.
  IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, oc4 * 4 * 16 * sizeof(int32_t));
#else
  const int input_unit_square = 16;  // 4x4 transformed positions per tile
  for (int c = 0; c < oc4; c++) {    // output-channel blocks of 4
    int filter_oc_offset = c * input_unit_square * ic8 * C8NUM * C4NUM;
    int dst_oc_offset = c * input_unit_square * C4NUM;
    for (int n = 0; n < real_cal_num; n++) {  // tiles in this batch
      int src_tile_offset = n * C8NUM;
      int dst_tile_offset = dst_oc_offset + n * oc4 * C4NUM * input_unit_square;
      for (int i = 0; i < 4; i++) {  // transformed-unit row
        int filter_h_offset = filter_oc_offset + i * 4 * ic8 * C8NUM * C4NUM;
        int src_h_offset = src_tile_offset + i * C8NUM * ic8 * C8NUM * C4NUM;
        int dst_h_offset = dst_tile_offset + i * 4 * 4;
        for (int m = 0; m < 4; m++) {  // transformed-unit column
          int filter_w_offset = filter_h_offset + m * 4 * C8NUM * ic8;
          int src_w_offset = src_h_offset + m * 8 * ic8 * C8NUM;
          int dst_w_offset = dst_h_offset + m * C4NUM;
          int32_t acc[4] = {0};  // one accumulator per output-channel lane
          for (int z = 0; z < 4; z++) {
            int filter_offset = filter_w_offset + z;
            for (int j = 0; j < ic8; j++) {  // input-channel blocks of 8
              int filter_c8_offset = filter_offset + j * 4 * 8;
              int src_c8_offset = src_w_offset + j * 8 * 8;
              for (int k = 0; k < 8; k++) {
                const int16_t *w_ptr = weight + filter_c8_offset + k * 4;
                const int16_t *input_ptr = src + src_c8_offset + k;
                acc[z] += w_ptr[0] * input_ptr[0];
              }
            }
            (dst + dst_w_offset + z)[0] = acc[z];
          }
        }
      }
    }
  }
#endif
}
void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight,
const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id,
@ -141,717 +95,3 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, in
}
}
}
/* Pack int8 input for the per-output-channel ("peroc") quantized 1x1
 * convolution and precompute the input-sum correction term.
 *
 * packed_input: rows grouped C8NUM at a time, channels blocked by C4NUM and
 *               zero-padded up to ic4 (see the portable #else branch).
 * input_sum:    for each row r and output channel oc,
 *               (sum of row r's input values) * filter_zp[oc], laid out in
 *               C8NUM x C8NUM blocks with oc blocks inputsum_stride apart.
 * The ARM64 branch is an assembly fast path; the #else branch is the portable
 * reference implementation of the same packing.
 */
void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
                        size_t output_channel, size_t plane_size, int32_t *filter_zp, size_t inputsum_stride) {
  int ic4 = UP_ROUND(input_channel, C4NUM);
  int oc8 = UP_ROUND(output_channel, C8NUM);
  int hw8 = UP_ROUND(plane_size, C8NUM);
  size_t hw_8div = plane_size / C8NUM * C8NUM;   // rows handled by the 8-row main loop
  size_t oc_8div = output_channel / C8NUM * C8NUM;
  size_t oc_8res = output_channel - oc_8div;
  size_t ic_4div = input_channel / C4NUM * C4NUM;
  const int8_t *src_r = src_input;
  int8_t *pack_r = packed_input;
  int32_t *input_sum_r = input_sum;
  /* Main loop: 8 spatial rows at a time. */
  for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
    const int8_t *src_ic = src_r;
    int8_t *pack_ic = pack_r;
    int32_t *input_sum_oc = input_sum_r;
#ifdef ENABLE_ARM64
    size_t src_stride = input_channel;
    size_t ic_4res = input_channel - ic_4div;
    size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4;  /* byte stride between oc blocks */
    /* Labels 1-6: pack 8 rows and accumulate per-row sums (v16/v17);
     * labels 7-17: multiply sums by filter_zp per output channel. */
    asm volatile(
      "dup v16.4s, wzr \n"
      "dup v17.4s, wzr \n"
      "mov x10, %[src_ic] \n"
      "mov x11, %[pack_ic] \n"
      "mov x0, #0 \n"
      "1: \n"
      "cmp x0, %[ic_4div] \n"
      "add x0, x0, #4\n"
      "mov x12, x10 \n"
      "add x10, x10, #4\n"
      "blt 2f \n"
      "cmp %[ic_4res], #0\n"
      "beq 6f \n"
      "cmp %[ic_4res], #1\n"
      "beq 3f \n"
      "cmp %[ic_4res], #2\n"
      "beq 4f \n"
      "cmp %[ic_4res], #3\n"
      "beq 5f \n"
      "2: \n"
      "ld1 {v0.s}[0], [x12], %[src_stride]\n"
      "ld1 {v0.s}[1], [x12], %[src_stride]\n"
      "ld1 {v0.s}[2], [x12], %[src_stride]\n"
      "ld1 {v0.s}[3], [x12], %[src_stride]\n"
      "ld1 {v1.s}[0], [x12], %[src_stride]\n"
      "ld1 {v1.s}[1], [x12], %[src_stride]\n"
      "ld1 {v1.s}[2], [x12], %[src_stride]\n"
      "ld1 {v1.s}[3], [x12], %[src_stride]\n"
      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 1b \n"
      "3: \n" /* col res 1 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"
      "ld1 {v0.b}[0], [x12], %[src_stride]\n"
      "ld1 {v0.b}[4], [x12], %[src_stride]\n"
      "ld1 {v0.b}[8], [x12], %[src_stride]\n"
      "ld1 {v0.b}[12], [x12], %[src_stride]\n"
      "ld1 {v1.b}[0], [x12], %[src_stride]\n"
      "ld1 {v1.b}[4], [x12], %[src_stride]\n"
      "ld1 {v1.b}[8], [x12], %[src_stride]\n"
      "ld1 {v1.b}[12], [x12], %[src_stride]\n"
      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"
      "4: \n" /* col res 2 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"
      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"
      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"
      "5: \n" /* col res 3 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"
      "add x13, x12, #2 \n"
      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.b}[2], [x13], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.b}[6], [x13], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.b}[10], [x13], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v0.b}[14], [x13], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.b}[2], [x13], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.b}[6], [x13], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.b}[10], [x13], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.b}[14], [x13], %[src_stride]\n"
      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"
      "6: \n"
      "dup v0.4s, v16.s[0] \n"
      "dup v1.4s, v16.s[1] \n"
      "dup v2.4s, v16.s[2] \n"
      "dup v3.4s, v16.s[3] \n"
      "dup v4.4s, v17.s[0] \n"
      "dup v5.4s, v17.s[1] \n"
      "dup v6.4s, v17.s[2] \n"
      "dup v7.4s, v17.s[3] \n"
      "mov x4, #0 \n"
      "mov x10, %[filter_zp] \n"
      "mov x11, %[input_sum_oc] \n"
      "7: \n"
      "cmp x4, %[oc_8div] \n"
      "beq 8f \n"
      "add x4, x4, #8\n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.4s}, [x10], #16\n"
      "mul v18.4s, v16.4s, v0.4s \n"
      "mul v19.4s, v17.4s, v0.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"
      "mul v20.4s, v16.4s, v1.4s \n"
      "mul v21.4s, v17.4s, v1.4s \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"
      "mul v22.4s, v16.4s, v2.4s \n"
      "mul v23.4s, v17.4s, v2.4s \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"
      "mul v24.4s, v16.4s, v3.4s \n"
      "mul v25.4s, v17.4s, v3.4s \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"
      "mul v18.4s, v16.4s, v4.4s \n"
      "mul v19.4s, v17.4s, v4.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"
      "mul v20.4s, v16.4s, v5.4s \n"
      "mul v21.4s, v17.4s, v5.4s \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"
      "mul v22.4s, v16.4s, v6.4s \n"
      "mul v23.4s, v17.4s, v6.4s \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"
      "mul v24.4s, v16.4s, v7.4s \n"
      "mul v25.4s, v17.4s, v7.4s \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"
      "add x11, x11, %[input_sum_stride] \n"
      "b 7b \n"
      "8: \n"
      "cmp %[oc_8res], #0\n"
      "beq 17f \n"
      "dup v16.4s, wzr \n"
      "dup v17.4s, wzr \n"
      "cmp %[oc_8res], #1\n"
      "beq 9f \n"
      "cmp %[oc_8res], #2\n"
      "beq 10f \n"
      "cmp %[oc_8res], #3\n"
      "beq 11f \n"
      "cmp %[oc_8res], #4\n"
      "beq 12f \n"
      "cmp %[oc_8res], #5\n"
      "beq 13f \n"
      "cmp %[oc_8res], #6\n"
      "beq 14f \n"
      "cmp %[oc_8res], #7\n"
      "beq 15f \n"
      "9: \n"
      "ld1 {v16.s}[0], [x10] \n"
      "b 16f \n"
      "10: \n"
      "ld1 {v16.d}[0], [x10] \n"
      "b 16f \n"
      "11: \n"
      "ld1 {v16.d}[0], [x10] \n"
      "add x10, x10, #8 \n"
      "ld1 {v16.s}[2], [x10] \n"
      "b 16f \n"
      "12: \n"
      "ld1 {v16.4s}, [x10] \n"
      "b 16f \n"
      "13: \n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.s}[0], [x10] \n"
      "b 16f \n"
      "14: \n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.d}[0], [x10] \n"
      "b 16f \n"
      "15: \n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.d}[0], [x10] \n"
      "add x10, x10, #8 \n"
      "ld1 {v17.s}[2], [x10] \n"
      "b 16f \n"
      "16: \n"
      "mul v18.4s, v16.4s, v0.4s \n"
      "mul v19.4s, v17.4s, v0.4s \n"
      "mul v20.4s, v16.4s, v1.4s \n"
      "mul v21.4s, v17.4s, v1.4s \n"
      "mul v22.4s, v16.4s, v2.4s \n"
      "mul v23.4s, v17.4s, v2.4s \n"
      "mul v24.4s, v16.4s, v3.4s \n"
      "mul v25.4s, v17.4s, v3.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"
      "mul v18.4s, v16.4s, v4.4s \n"
      "mul v19.4s, v17.4s, v4.4s \n"
      "mul v20.4s, v16.4s, v5.4s \n"
      "mul v21.4s, v17.4s, v5.4s \n"
      "mul v22.4s, v16.4s, v6.4s \n"
      "mul v23.4s, v17.4s, v6.4s \n"
      "mul v24.4s, v16.4s, v7.4s \n"
      "mul v25.4s, v17.4s, v7.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"
      "17: \n"
      :
      : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ filter_zp ] "r"(filter_zp),
        [ input_sum_oc ] "r"(input_sum_oc), [ input_sum_stride ] "r"(input_sum_stride), [ src_stride ] "r"(src_stride),
        [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ oc_8div ] "r"(oc_8div), [ oc_8res ] "r"(oc_8res)
      : "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16",
        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
#else
    /* Portable reference: pack 8 rows, channel-blocked by 4, and collect
     * per-row sums. */
    int32_t tmp_sum_value[8] = {0};
    for (int ici = 0; ici < ic_4div; ici += C4NUM) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[0 + i * input_channel];
        tmp_sum_value[i] += src_ic[1 + i * input_channel];
        tmp_sum_value[i] += src_ic[2 + i * input_channel];
        tmp_sum_value[i] += src_ic[3 + i * input_channel];
        pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
        pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
        pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
        pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
      }
      src_ic += C4NUM;
      pack_ic += C4NUM * C8NUM;
    }
    /* Remaining channels (input_channel not a multiple of 4). */
    for (int ici = ic_4div; ici < input_channel; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[i * input_channel];
        pack_ic[i * C4NUM] = src_ic[i * input_channel];
      }
      src_ic += 1;
      pack_ic += 1;
    }
    /* Zero-pad channels up to ic4. */
    for (int ici = input_channel; ici < ic4; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        pack_ic[i * C4NUM] = 0;
      }
      pack_ic += 1;
    }
    /* input_sum = row sum * per-channel filter zero point. */
    for (int oci = 0; oci < oc_8div; oci += C8NUM) {
      for (int ri = 0; ri < C8NUM; ri++) {
        input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0];
        input_sum_oc[ri * C8NUM + 1] = tmp_sum_value[ri] * filter_zp[oci + 1];
        input_sum_oc[ri * C8NUM + 2] = tmp_sum_value[ri] * filter_zp[oci + 2];
        input_sum_oc[ri * C8NUM + 3] = tmp_sum_value[ri] * filter_zp[oci + 3];
        input_sum_oc[ri * C8NUM + 4] = tmp_sum_value[ri] * filter_zp[oci + 4];
        input_sum_oc[ri * C8NUM + 5] = tmp_sum_value[ri] * filter_zp[oci + 5];
        input_sum_oc[ri * C8NUM + 6] = tmp_sum_value[ri] * filter_zp[oci + 6];
        input_sum_oc[ri * C8NUM + 7] = tmp_sum_value[ri] * filter_zp[oci + 7];
      }
      input_sum_oc += inputsum_stride;
    }
    if (oc_8div != output_channel) {
      for (int oci = 0; oci < oc_8res; oci += 1) {
        for (int ri = 0; ri < C8NUM; ri++) {
          input_sum_oc[ri * C8NUM + oci] = tmp_sum_value[ri] * filter_zp[oc_8div + oci];
        }
      }
      for (int oci = oc_8res; oci < C8NUM; oci += 1) {
        for (int ri = 0; ri < C8NUM; ri++) {
          input_sum_oc[ri * C8NUM + oci] = 0;
        }
      }
    } /* oc8 res done */
#endif
    src_r += input_channel * C8NUM;
    pack_r += ic4 * C8NUM;
    input_sum_r += C8NUM * C8NUM;
  }
  /* Residual rows (plane_size not a multiple of 8), handled one at a time. */
  if (hw_8div != plane_size) {
    memset(pack_r, 0, C8NUM * ic4);
    for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
      int32_t *input_sum_oc = input_sum_r;
      int32_t tmp_sum_value = 0;
      const int8_t *src_ic = src_r;
      int8_t *pack_ic = pack_r;
      for (int ici = 0; ici < ic_4div; ici += C4NUM) {
        tmp_sum_value += src_ic[0];
        tmp_sum_value += src_ic[1];
        tmp_sum_value += src_ic[2];
        tmp_sum_value += src_ic[3];
        pack_ic[0] = src_ic[0];
        pack_ic[1] = src_ic[1];
        pack_ic[2] = src_ic[2];
        pack_ic[3] = src_ic[3];
        src_ic += C4NUM;
        pack_ic += C4NUM * C8NUM;
      }
      for (int ici = ic_4div; ici < input_channel; ici += 1) {
        tmp_sum_value += src_ic[0];
        pack_ic[0] = src_ic[0];
        src_ic += 1;
        pack_ic += 1;
      }
      for (int oci = 0; oci < oc_8div; oci += C8NUM) {
        for (int curoi = 0; curoi < C8NUM; curoi++) {
          input_sum_oc[curoi] = tmp_sum_value * filter_zp[oci + curoi];
        }
        input_sum_oc += inputsum_stride;
      }
      if (oc_8div != output_channel) {
        for (int oci = 0; oci < oc_8res; oci += 1) {
          input_sum_oc[oci] = tmp_sum_value * filter_zp[oc_8div + oci];
        }
        for (int oci = oc_8res; oci < C8NUM; oci += 1) {
          input_sum_oc[oci] = 0;
        }
      } /* oc8 res done */
      src_r += input_channel;
      pack_r += C4NUM;
      input_sum_r += C8NUM;
    }
    /* Zero the input_sum padding rows up to hw8. */
    for (int hwi = plane_size; hwi < hw8; hwi++) {
      for (int oc = 0; oc < oc8; oc++) {
        int oc8div = oc / C8NUM, oc8res = oc % C8NUM;
        input_sum[oc8div * inputsum_stride + hwi * C8NUM + oc8res] = 0;
      }
    }
  }
  return;
}
/* Pack int8 input for the per-tensor ("pert", per-layer) quantized 1x1
 * convolution and precompute the input-sum correction term.
 *
 * Same packing layout as Conv1x1PreOptPeroc, but with a single filter zero
 * point: input_sum[hwi] = (sum of row hwi's input values) * filter_zp.
 * The ARM64 branch is an assembly fast path; the #else branch is the portable
 * reference implementation.
 */
void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
                       size_t plane_size, ConvParameter *conv_param) {
  int ic4 = UP_ROUND(input_channel, C4NUM);
  size_t hw_8div = plane_size / C8NUM * C8NUM;  // rows handled by the 8-row main loop
  size_t ic_4div = input_channel / C4NUM * C4NUM;
  int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;
  const int8_t *src_r = src_input;
  int8_t *pack_r = packed_input;
  /* per layer */
  for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
    const int8_t *src_ic = src_r;
    int8_t *pack_ic = pack_r;
    int32_t *input_sum_r = input_sum + hwi;
#ifdef ENABLE_ARM64
    size_t src_stride = input_channel;
    size_t ic_4res = input_channel - ic_4div;
    /* Labels 1-5: pack 8 rows and accumulate per-row sums (v16/v17);
     * label 6: scale the sums by the single filter zero point (v20). */
    asm volatile(
      "dup v16.4s, wzr \n"
      "dup v17.4s, wzr \n"
      "mov x14, %[input_sum_r] \n"
      "dup v20.4s, %w[filter_zp] \n"
      "mov x10, %[src_ic] \n"
      "mov x11, %[pack_ic] \n"
      "mov x0, #0 \n"
      "1: \n"
      "cmp x0, %[ic_4div] \n"
      "add x0, x0, #4\n"
      "mov x12, x10 \n"
      "add x10, x10, #4\n"
      "blt 2f \n"
      "cmp %[ic_4res], #0\n"
      "beq 6f \n"
      "cmp %[ic_4res], #1\n"
      "beq 3f \n"
      "cmp %[ic_4res], #2\n"
      "beq 4f \n"
      "cmp %[ic_4res], #3\n"
      "beq 5f \n"
      "2: \n"
      "ld1 {v0.s}[0], [x12], %[src_stride]\n"
      "ld1 {v0.s}[1], [x12], %[src_stride]\n"
      "ld1 {v0.s}[2], [x12], %[src_stride]\n"
      "ld1 {v0.s}[3], [x12], %[src_stride]\n"
      "ld1 {v1.s}[0], [x12], %[src_stride]\n"
      "ld1 {v1.s}[1], [x12], %[src_stride]\n"
      "ld1 {v1.s}[2], [x12], %[src_stride]\n"
      "ld1 {v1.s}[3], [x12], %[src_stride]\n"
      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 1b \n"
      "3: \n" /* col res 1 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"
      "ld1 {v0.b}[0], [x12], %[src_stride]\n"
      "ld1 {v0.b}[4], [x12], %[src_stride]\n"
      "ld1 {v0.b}[8], [x12], %[src_stride]\n"
      "ld1 {v0.b}[12], [x12], %[src_stride]\n"
      "ld1 {v1.b}[0], [x12], %[src_stride]\n"
      "ld1 {v1.b}[4], [x12], %[src_stride]\n"
      "ld1 {v1.b}[8], [x12], %[src_stride]\n"
      "ld1 {v1.b}[12], [x12], %[src_stride]\n"
      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"
      "4: \n" /* col res 2 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"
      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"
      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"
      "5: \n" /* col res 3 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"
      "add x13, x12, #2 \n"
      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.b}[2], [x13], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.b}[6], [x13], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.b}[10], [x13], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v0.b}[14], [x13], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.b}[2], [x13], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.b}[6], [x13], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.b}[10], [x13], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.b}[14], [x13], %[src_stride]\n"
      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"
      "6: \n"
      "mul v16.4s, v16.4s, v20.4s \n"
      "mul v17.4s, v17.4s, v20.4s \n"
      "st1 {v16.4s}, [x14], #16 \n"
      "st1 {v17.4s}, [x14], #16 \n"
      :
      : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
        [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp)
      : "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
        "v20");
#else
    /* Portable reference: pack 8 rows, channel-blocked by 4, collect row sums. */
    int32_t tmp_sum_value[8] = {0};
    for (int ici = 0; ici < ic_4div; ici += C4NUM) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[0 + i * input_channel];
        tmp_sum_value[i] += src_ic[1 + i * input_channel];
        tmp_sum_value[i] += src_ic[2 + i * input_channel];
        tmp_sum_value[i] += src_ic[3 + i * input_channel];
        pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
        pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
        pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
        pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
      }
      src_ic += C4NUM;
      pack_ic += C4NUM * C8NUM;
    }
    /* Remaining channels (input_channel not a multiple of 4). */
    for (int ici = ic_4div; ici < input_channel; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[i * input_channel];
        pack_ic[i * C4NUM] = src_ic[i * input_channel];
      }
      src_ic += 1;
      pack_ic += 1;
    }
    /* Zero-pad channels up to ic4. */
    for (int ici = input_channel; ici < ic4; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        pack_ic[i * C4NUM] = 0;
      }
      pack_ic += 1;
    }
    for (int i = 0; i < C8NUM; i++) {
      input_sum_r[i] = tmp_sum_value[i] * filter_zp;
    }
#endif
    src_r += input_channel * C8NUM;
    pack_r += ic4 * C8NUM;
  }
  /* Residual rows (plane_size not a multiple of 8), handled one at a time. */
  if (hw_8div != plane_size) {
    memset(pack_r, 0, C8NUM * ic4);
    for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
      int32_t tmp_sum_value = 0;
      const int8_t *src_ic = src_r;
      int8_t *pack_ic = pack_r;
      for (int ici = 0; ici < ic_4div; ici += C4NUM) {
        tmp_sum_value += src_ic[0];
        tmp_sum_value += src_ic[1];
        tmp_sum_value += src_ic[2];
        tmp_sum_value += src_ic[3];
        pack_ic[0] = src_ic[0];
        pack_ic[1] = src_ic[1];
        pack_ic[2] = src_ic[2];
        pack_ic[3] = src_ic[3];
        src_ic += C4NUM;
        pack_ic += C4NUM * C8NUM;
      }
      for (int ici = ic_4div; ici < input_channel; ici += 1) {
        tmp_sum_value += src_ic[0];
        pack_ic[0] = src_ic[0];
        src_ic += 1;
        pack_ic += 1;
      }
      input_sum[hwi] = tmp_sum_value * filter_zp;
      src_r += input_channel;
      pack_r += C4NUM;
    }
    /* Zero the input_sum padding rows up to the next multiple of 8. */
    for (int hwi = plane_size; hwi < UP_ROUND(plane_size, C8NUM); hwi++) {
      input_sum[hwi] = 0;
    }
  }
  return;
}
/* 1x1 int8 convolution, optimized path: forwards one packed tile to the
 * architecture-specific matmul kernel supplied by the caller (matmul_func).
 * A filter_arg_num_ of exactly one means per-tensor quantization; anything
 * else selects the per-output-channel path. */
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
                    const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
                    int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int *filter_zp) {
  int per_channel = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
  matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias,
              left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
              conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], per_channel,
              filter_zp);
}
/* 1x1 int8 convolution, generic path: forwards one packed tile to
 * MatmulInt8Opt. A filter_arg_num_ of exactly one means per-tensor
 * quantization; anything else selects the per-output-channel path. */
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
                 const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
                 int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp) {
  int per_channel = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
  MatmulInt8Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
                conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
                conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift,
                conv_param->output_channel_, per_channel, filter_zp);
}
// int8 convolution 3x3
// Winograd-style 3x3 int8 convolution entry point:
//   input transform -> per-tile GEMM -> output transform, batch by batch.
// Tiles of TILE_NUM output units are distributed across threads via
// task_id / conv_param->thread_num_; each thread uses its own slice of the
// scratch buffers (tile_buffer, block_unit_buffer, tmp_dst_buffer).
// NOTE(review): output_data is never referenced in this body; results land in
// tmp_out — confirm the caller unpacks tmp_out into output_data afterwards.
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
                 int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
                 int task_id, ConvParameter *conv_param) {
  int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);   // input channels, in blocks of 8
  int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
  int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
  int output_count = out_w_block * out_h_block;          // total output units in one plane
  int output_tile_count = UP_DIV(output_count, TILE_NUM);
  int oc4 = UP_DIV(conv_param->output_channel_, C4NUM);  // output channels, in blocks of 4
  int tile_buffer_offset = TILE_NUM * 16 * ic8 * C8NUM;  // per-thread slice sizes
  const int block_unit_buffer_offset = 16 * C8NUM;
  int tmp_dst_buffer_offset = TILE_NUM * 16 * oc4 * C4NUM;
  for (int batch = 0; batch < conv_param->input_batch_; batch++) {
    int in_batch_offset = batch * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_;
    int tmp_out_batch_offset = batch * oc4 * C4NUM * conv_param->output_w_ * conv_param->output_h_;
    // Thread k handles tiles k, k + thread_num_, k + 2*thread_num_, ...
    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
      int start_index = thread_id * TILE_NUM;
      // The last tile may contain fewer than TILE_NUM output units.
      int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
      Conv3x3Int8InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
                                block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
                                out_w_block, conv_param);
      Conv3x3Int8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset,
                      transed_weight, conv_param->output_channel_, ic8, real_cal_num);
      Conv3x3Int8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset,
                                 bias_data, start_index, real_cal_num, out_w_block, conv_param);
    }
  }
}

View File

@ -16,6 +16,7 @@
#ifndef MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_
#define MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_
#include <string.h>
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
@ -24,9 +25,10 @@
#include "nnacl/common_func.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/winograd_utils.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/int8/matmul_int8.h"
#include "nnacl/int8/common_func_int8.h"
#ifdef __cplusplus
extern "C" {
@ -36,23 +38,6 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, in
const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id,
ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func, bool is_optimize);
// int8 convolution 1x1
void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t output_channel, size_t plane_size, int32_t *filter_zp, size_t inputsum_stride);
void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t plane_size, ConvParameter *conv_param);
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp);
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int32_t *filter_zp);
// int8 convolution 3x3
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
int task_id, ConvParameter *conv_param);
#ifdef __cplusplus
}
#endif

View File

@ -17,7 +17,7 @@
#define MINDSPORE_LITE_NNACL_INT8_DEPTH_TO_SPACE_INT8_H_
#include "nnacl/depth_to_space_parameter.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/quantize.h"
#ifdef __cplusplus
extern "C" {

View File

@ -15,9 +15,6 @@
*/
#include "nnacl/int8/div_int8.h"
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/errorcode.h"
#include "nnacl/quantization/quantize.h"
int DivInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, DivQuantArg *para) {
int index = 0;

View File

@ -18,7 +18,9 @@
#define MINDSPORE_LITE_NNACL_INT8_DIV_INT8_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/errorcode.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/int8/fixed_point.h"
#ifdef __cplusplus
extern "C" {

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/int8/fixed_point.h"
// returns the high-32 bits of a * b with rounding
// assume that a and b is divided by 2^31, who fall into [-1, 1]
@ -107,7 +107,7 @@ int CountLeadingZeroBits(uint32_t x) {
if (x == 0) {
return 8 * sizeof(uint32_t);
}
const int32_t leading_positive = (int32_t)(1) << (8 * sizeof(uint32_t) - 1);
const int32_t leading_positive = (uint32_t)(1) << (8 * sizeof(uint32_t) - 1);
int leading_zeros = 0;
while (x < leading_positive) {
x <<= 1;

View File

@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHERND_INT8_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/quantize.h"
#ifdef __cplusplus
extern "C" {

View File

@ -16,7 +16,7 @@
*/
#include "nnacl/int8/gather_int8.h"
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/errorcode.h"
int GatherInt8(int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, const int *indices,

View File

@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHER_INT8_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/quantize.h"
#ifdef __cplusplus
extern "C" {

View File

@ -19,7 +19,7 @@
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/int8/fixed_point.h"
typedef struct HswishQuantArg {
double input_scale;

View File

@ -15,7 +15,7 @@
*/
#include "nnacl/int8/l2_norm_int8.h"
#include <limits.h>
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/int8/fixed_point.h"
#include "nnacl/errorcode.h"
int L2NormalizationInt8(const int8_t *input_data, int8_t *output_data, const L2NormParameter *param,

View File

@ -18,7 +18,7 @@
#include "nnacl/errorcode.h"
#include "nnacl/layer_norm_parameter.h"
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/int8/fixed_point.h"
#ifdef __cplusplus
extern "C" {

View File

@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_PRELU_INT8_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/quantize.h"
#ifdef __cplusplus
extern "C" {

View File

@ -15,7 +15,7 @@
*/
#include "nnacl/int8/matmul_int8.h"
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/int8/fixed_point.h"
void RowMajor2Row2x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
int col16 = UP_ROUND(col, C16NUM);

View File

@ -17,7 +17,6 @@
#ifndef MINDSPORE_LITE_NNACL_INT8_MATMUL_H_
#define MINDSPORE_LITE_NNACL_INT8_MATMUL_H_
#include <stdio.h>
#include <string.h>
#include "nnacl/op_base.h"
#include "nnacl/matmul_parameter.h"

View File

@ -15,15 +15,8 @@
*/
#include "nnacl/int8/mul_int8.h"
#include "nnacl/mul_parameter.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
#include "nnacl/int8/common_func_int8.h"
#endif
#include "nnacl/quantization/fixed_point.h"
#ifdef ENABLE_NEON
int16x4_t ClacSumHalfWordMul(int16x4_t scaled_input0, int16x4_t scaled_input1, int32x4_t left_shift_out_vec,
int32x4_t right_shift_out_vec, int32x4_t output_multiplier_vec) {
int32x4_t input_scale = vmull_s16(scaled_input0, scaled_input1);

View File

@ -19,6 +19,11 @@
#include "nnacl/op_base.h"
#include "nnacl/mul_parameter.h"
#include "nnacl/int8/common_func_int8.h"
#include "nnacl/int8/fixed_point.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#ifdef __cplusplus
extern "C" {

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_INT8_PACK_INT8_H_
#define MINDSPORE_LITE_NNACL_INT8_PACK_INT8_H_
#include <string.h>
#include "nnacl/op_base.h"
#include "nnacl/int8/matmul_int8.h"
#include "nnacl/conv_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param);
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16);
void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param);
void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param);
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num,
int block_index, int32_t *filter_zp, int32_t *input_sum, ConvParameter *conv_param,
bool per_channel, bool is_optimize);
#ifdef ENABLE_ARM
void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp);
void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div,
size_t oc_res, size_t stride);
#endif
void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param);
void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
ConvQuantArg *quant_qrg);
void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
ConvQuantArg *quant_qrg);
#ifdef __cplusplus
}
#endif
#endif  // MINDSPORE_LITE_NNACL_INT8_PACK_INT8_H_

View File

@ -19,7 +19,7 @@
#include "nnacl/op_base.h"
#include "nnacl/power_parameter.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/quantize.h"
#ifdef __cplusplus
extern "C" {

View File

@ -14,8 +14,7 @@
* limitations under the License.
*/
#include "nnacl/quantization/quantize.h"
#include <stdio.h>
#include "nnacl/int8/quantize.h"
const uint64_t dSignMask = 1ull << 63;
const uint64_t dExponentMask = 0x7ffull << 52;
@ -57,8 +56,6 @@ void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t
/* multipiler is in[0x40000000, 0x7FFFFF80] range */
*quantized_multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
if (quantized_multiplier[0] < INT32_C(0x40000000) || quantized_multiplier[0] > INT32_C(0x7FFFFF80)) {
printf("quantized multiplier must be in [0x40000000, 0x7FFFFF80] range, now multiplier is %d\n",
quantized_multiplier[0]);
return;
}
/* shift is in [0, 31] range */

View File

@ -28,11 +28,6 @@
#define FILTER_PER_CHANNEL 0b010
#define OUTPUT_PER_CHANNEL 0b100
typedef struct QuantArg {
float scale_;
int32_t zp_;
} QuantArg;
typedef struct ConvQuantArg {
RoundingMode round_mode_;
CalFixedMultiplierMode quant_multiplier_mode_;
@ -58,24 +53,6 @@ typedef struct ConcatQuantArg {
int8_t output_activation_max_;
} ConcatQuantArg;
typedef struct SqueezeQuantArg {
QuantArg *in_quant_args_;
QuantArg *out_quant_args_;
} SqueezeQuantArg;
typedef struct UnSqueezeQuantArg {
int *input_sizes_;
int output_size_;
int **input_shapes_;
int *output_shape_;
float alpha;
int axis_;
size_t input_num_;
size_t output_dim_;
QuantArg in_quant_args_;
QuantArg out_quant_args_;
} UnSqueezeQuantArg;
typedef struct PreluQuantArg {
int *input_sizes_;
int output_size_;
@ -103,22 +80,6 @@ typedef struct MatmulQuantArg {
int32_t quant_multiplier;
} MatmulQuantArg;
typedef struct PadQuantArg {
QuantArg *in_quant_args_;
QuantArg *out_quanr_args_;
int8_t *constant_value_;
} PadQuantArg;
typedef struct MulQuantArg {
QuantArg in_quant_args_[2];
QuantArg out_quant_arg_;
int output_multiplier_;
int output_activation_min_;
int output_activation_max_;
int shift_left_;
int shift_right_;
} MulQuantArg;
typedef struct CropQuantArg {
QuantArg in_args_;
QuantArg out_args_;
@ -142,13 +103,6 @@ typedef struct GatherQuantArg {
int zp_out_;
} GatherQuantArg;
typedef struct SplitQuantArg {
QuantArg in_args_;
QuantArg out_args_[20];
int output_activation_min_;
int output_activation_max_;
} SplitQuantArg;
typedef struct SoftmaxQuantArg {
QuantArg in_quant_args_;
QuantArg out_quant_arg_;
@ -159,19 +113,6 @@ typedef struct SoftmaxQuantArg {
int shift_right_;
} SoftmaxQuantArg;
typedef struct ReshapeQuantArg {
QuantArg in_args_;
QuantArg out_args_;
int output_activation_min_;
int output_activation_max_;
} ReshapeQuantArg;
typedef struct QuantMulArg {
int32_t multiplier_;
int left_shift_;
int right_shift_;
} QuantMulArg;
typedef struct SubQuantArg {
QuantArg in0_args_;
QuantArg in1_args_;
@ -227,21 +168,6 @@ typedef struct ReduceQuantArg {
int sum_square_right_shift_;
} ReduceQuantArg;
typedef struct SliceQuantArg {
QuantArg in_args_;
QuantArg out_args_;
int output_activation_min_;
int output_activation_max_;
} SliceQuantArg;
typedef struct PowerQuantArg {
QuantArg in_args_;
QuantArg exp_args_;
QuantArg out_args_;
int output_activation_min_;
int output_activation_max_;
} PowerQuantArg;
typedef struct LeakyReluQuantArg {
OpParameter op_parameter_;
PreluQuantArg quant_arg;

View File

@ -17,7 +17,7 @@
#include <stdint.h>
#include "nnacl/int8/reduce_int8.h"
#include "nnacl/errorcode.h"
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/int8/fixed_point.h"
#include "nnacl/common_func.h"
int ReduceMeanN(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) {

View File

@ -16,7 +16,9 @@
#ifndef MINDSPORE_LITE_NNACL_INT8_REDUCE_INT8_H_
#define MINDSPORE_LITE_NNACL_INT8_REDUCE_INT8_H_
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/quantize.h"
#ifdef __cplusplus
extern "C" {
#endif

View File

@ -19,8 +19,8 @@
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/fixed_point.h"
#include "nnacl/int8/quantize.h"
typedef struct ReluXQuantArg {
QuantArg input_arg;

View File

@ -16,6 +16,8 @@
#ifndef MINDSPORE_LITE_NNACL_INT8_RESHAHPE_INT8_H_
#define MINDSPORE_LITE_NNACL_INT8_RESHAHPE_INT8_H_
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/reshape_parameter.h"

View File

@ -16,7 +16,7 @@
#include <math.h>
#include "nnacl/int8/resize_int8.h"
#include "nnacl/common_func.h"
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/int8/fixed_point.h"
#include "nnacl/errorcode.h"
int ResizeBilinearInt8(const int8_t *input_ptr, int8_t *output_ptr, int batch, int in_h, int in_w, int out_h, int out_w,

View File

@ -21,7 +21,7 @@
#endif
#include <memory.h>
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/resize_parameter.h"
#ifdef __cplusplus

View File

@ -15,7 +15,7 @@
*/
#include "nnacl/int8/scale_int8.h"
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/int8/fixed_point.h"
#ifdef ENABLE_NEON
int16x4_t ClacSumHalfWordMul2(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec,

View File

@ -19,6 +19,8 @@
#include "nnacl/op_base.h"
#include "nnacl/scale.h"
#include "nnacl/nnacl_common.h"
#ifdef __cplusplus
extern "C" {
#endif

View File

@ -19,7 +19,7 @@
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/int8/fixed_point.h"
#ifdef __cplusplus
extern "C" {

View File

@ -15,8 +15,6 @@
*/
#include "nnacl/int8/slice_int8.h"
#include <string.h>
#include "nnacl/quantization/fixed_point.h"
int SliceInt8NoParallel(const int8_t *input, int8_t *output, SliceParameter *param) {
double input_scale = param->quant_arg_.in_args_.scale_;

View File

@ -16,8 +16,11 @@
#ifndef MINDSPORE_LITE_NNACL_INT8_SLICE_INT8_H_
#define MINDSPORE_LITE_NNACL_INT8_SLICE_INT8_H_
#include <math.h>
#include <string.h>
#include "nnacl/op_base.h"
#include "nnacl/slice_parameter.h"
#include "nnacl/int8/fixed_point.h"
#ifdef __cplusplus
extern "C" {

View File

@ -15,9 +15,6 @@
*/
#include "nnacl/int8/softmax_int8.h"
#include <math.h>
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/quantization/quantize.h"
int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int *exp_data, int *sum_data,
SoftmaxQuantArg quant_param, SoftmaxParameter *parameter) {

View File

@ -17,9 +17,11 @@
#ifndef MINDSPORE_LITE_NNACL_INT8_SOFTMAX_INT8_H_
#define MINDSPORE_LITE_NNACL_INT8_SOFTMAX_INT8_H_
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/softmax_parameter.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/fixed_point.h"
#include "nnacl/int8/quantize.h"
#ifdef __cplusplus
extern "C" {

View File

@ -16,6 +16,8 @@
#ifndef MINDSPORE_LITE_NNACL_INT8_SPLIT_INT8_H_
#define MINDSPORE_LITE_NNACL_INT8_SPLIT_INT8_H_
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/split_parameter.h"

View File

@ -17,7 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_SQUEEZE_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_SQUEEZE_INT8_H_
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/squeeze_parameter.h"
#ifdef __cplusplus

View File

@ -19,7 +19,7 @@
#include <arm_neon.h>
#include "nnacl/int8/common_func_int8.h"
#endif
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/int8/fixed_point.h"
#ifdef ENABLE_NEON

View File

@ -18,7 +18,7 @@
#define MINDSPORE_LITE_NNACL_INT8_SUB_INT8_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/int8/quantize.h"
#ifdef __cplusplus
extern "C" {

View File

@ -18,8 +18,8 @@
#define MINDSPORE_LITE_NNACL_INT8_TANH_INT8_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/int8/quantize.h"
#include "nnacl/int8/fixed_point.h"
#include "nnacl/int8/quant_dtype_cast_int8.h"
#include "nnacl/fp32/activation_fp32.h"

View File

@ -16,7 +16,6 @@
#include "nnacl/unsqueeze_parameter.h"
#include "nnacl/int8/unsqueeze_int8.h"
#include <string.h>
int Int8Unsqueeze(int8_t *input_ptr, int8_t *output_ptr, UnSqueezeParameter *para_, size_t data_size, int task_id) {
float output_scale = para_->quant_arg.out_quant_args_.scale_;

View File

@ -17,7 +17,7 @@
#define MINDSPORE_LITE_NNACL_L2NORM_PARAMETER_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "mindspore/lite/nnacl/int8/quantize.h"
typedef struct L2NormParameter {
// Primitive parameter

View File

@ -17,7 +17,7 @@
#define MINDSPORE_LITE_NNACL_LAYER_NORM_PARAMETER_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "mindspore/lite/nnacl/int8/quantize.h"
enum ElementwiseMode { ELEMENTWISE_NOT = 0, ELEMENTWISE_PER_CHANNEL = 1, ELEMENTWISE_PER_NUM = 2 };
typedef struct LayerNormParameter {

View File

@ -18,7 +18,6 @@
#define MINDSPORE_LITE_NNACL_MATMUL_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias);

View File

@ -18,7 +18,16 @@
#define MINDSPORE_LITE_NNACL_MUL_PARAMETER_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
typedef struct MulQuantArg {
QuantArg in_quant_args_[2];
QuantArg out_quant_arg_;
int output_multiplier_;
int output_activation_min_;
int output_activation_max_;
int shift_left_;
int shift_right_;
} MulQuantArg;
typedef struct MulParameter {
// Primitive parameter

View File

@ -14,11 +14,4 @@
* limitations under the License.
*/
#ifndef PYBIND_API_PYBIND_PATCH_H_
#define PYBIND_API_PYBIND_PATCH_H_
namespace pybind11 {
PYBIND11_RUNTIME_EXCEPTION(attribute_error, PyExc_AttributeError)
}
#endif // PYBIND_API_PYBIND_PATCH_H_
#include "nnacl/nnacl_common.h"

View File

@ -0,0 +1,35 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_NNACL_COMMON_H_
#define MINDSPORE_LITE_NNACL_NNACL_COMMON_H_
#ifdef __cplusplus
extern "C" {
#endif
/* Fill `strides` with row-major (C-order) strides for a tensor of `ndim`
 * dimensions described by `shape`: strides[ndim-1] == 1 and
 * strides[i] == strides[i+1] * shape[i+1].
 *
 * Declared `static inline` rather than plain `inline`: under C99 semantics a
 * header-only plain `inline` function needs an `extern inline` definition in
 * exactly one translation unit, otherwise any non-inlined call can fail to
 * link. `static inline` gives every including TU its own definition, which is
 * the conventional pattern for header utilities. */
static inline void ComputeStrides(const int *shape, int *strides, const int ndim) {
  int stride = 1;
  /* Walk dimensions from innermost to outermost, accumulating the product
   * of the sizes already visited. */
  for (int i = ndim - 1; i >= 0; i--) {
    strides[i] = stride;
    stride *= shape[i];
  }
}
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_NNACL_COMMON_H_

View File

@ -28,6 +28,7 @@
#include <stdint.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
#define C2NUM 2
#define C4NUM 4
@ -78,6 +79,17 @@ typedef struct OpParameter {
int thread_num_;
} OpParameter;
typedef struct QuantArg {
float scale_;
int32_t zp_;
} QuantArg;
typedef struct QuantMulArg {
int32_t multiplier_;
int left_shift_;
int right_shift_;
} QuantMulArg;
typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, ActType_Prelu } ActType;
typedef enum PadMode { Pad_No, Pad_Same, Pad_Valid } PadMode;
typedef enum RoundingMode { Rounding_No, Rounding_Away_from_zero, Rounding_Up } RoundingMode;

File diff suppressed because it is too large Load Diff

View File

@ -17,102 +17,12 @@
#ifndef MINDSPORE_LITE_NNACL_PACK_H_
#define MINDSPORE_LITE_NNACL_PACK_H_
#include <stdio.h>
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/conv_parameter.h"
#include "nnacl/op_base.h"
#include "nnacl/fp32/pack_fp32.h"
#include "nnacl/int8/pack_int8.h"
#ifdef __cplusplus
extern "C" {
#endif
void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num,
int block_index);
void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel);
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num,
int block_index, int32_t *filter_zp, int32_t *input_sum, ConvParameter *conv_param,
bool per_channel, bool is_optimize);
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16);
void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr,
size_t plane_size, size_t input_channel, size_t output_channel);
void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr,
size_t plane_size, size_t input_channel, size_t output_channel);
void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size);
void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param);
void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param);
void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param);
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel);
void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum);
void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param);
void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel);
void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel);
void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param);
void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
ConvQuantArg *quant_qrg);
void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
ConvQuantArg *quant_qrg);
#ifdef ENABLE_ARM
void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp);
void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div,
size_t oc_res, size_t stride);
#endif
#ifdef __cplusplus
}

View File

@ -17,11 +17,16 @@
#define MINDSPORE_LITE_NNACL_PAD_PARAMETER_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#define MAX_PAD_SIZE 8
#define DEFAULT_PAD_NDIMS 4
typedef struct PadQuantArg {
QuantArg *in_quant_args_;
QuantArg *out_quanr_args_;
int8_t *constant_value_;
} PadQuantArg;
typedef struct PadParameter {
// Primitive parameter
OpParameter op_parameter_;

View File

@ -17,7 +17,6 @@
#define MINDSPORE_LITE_NNACL_POOLING_PARAMETER_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
typedef enum PoolMode { PoolMode_No, PoolMode_MaxPool, PoolMode_AvgPool } PoolMode;

View File

@ -18,7 +18,14 @@
#define MINDSPORE_LITE_NNACL_POWER_PARAMETER_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
typedef struct PowerQuantArg {
QuantArg in_args_;
QuantArg exp_args_;
QuantArg out_args_;
int output_activation_min_;
int output_activation_max_;
} PowerQuantArg;
typedef struct PowerParameter {
// Primitive parameter

View File

@ -18,7 +18,13 @@
#define MINDSPORE_LITE_NNACL_RESHAHPE_PARAMETER_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
typedef struct ReshapeQuantArg {
QuantArg in_args_;
QuantArg out_args_;
int output_activation_min_;
int output_activation_max_;
} ReshapeQuantArg;
typedef struct ReshapeParameter {
// primitive parameter

View File

@ -17,8 +17,8 @@
#ifndef MINDSPORE_LITE_NNACL_SCALE_H_
#define MINDSPORE_LITE_NNACL_SCALE_H_
#include <mindspore/lite/nnacl/quantization/quantize.h>
#include "nnacl/op_base.h"
typedef struct ScaleParameter {
// primitive parameter
OpParameter op_parameter_;

View File

@ -16,7 +16,6 @@
#include "nnacl/scatter_nd.h"
#include <string.h>
#include <stdio.h>
#include "nnacl/errorcode.h"
int DoScatterND(float *output_ptr, const float *update, int *output_unit_offsets, int unit_size, int num_units) {

View File

@ -18,10 +18,16 @@
#define MINDSPORE_LITE_NNACL_SLICE_PARAMETER_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#define SLICE_SHAPE_MAX_SIZE 4
typedef struct SliceQuantArg {
QuantArg in_args_;
QuantArg out_args_;
int output_activation_min_;
int output_activation_max_;
} SliceQuantArg;
typedef struct SliceParameter {
// primitive parameter
OpParameter op_parameter_;

View File

@ -18,8 +18,16 @@
#define MINDSPORE_LITE_NNACL_SPLIT_PARAMETER_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#define SPLIT_STRIDES_SIZE 32
typedef struct SplitQuantArg {
QuantArg in_args_;
QuantArg out_args_[20];
int output_activation_min_;
int output_activation_max_;
} SplitQuantArg;
typedef struct SplitParameter {
// primitive parameter
OpParameter op_parameter_;

View File

@ -16,11 +16,16 @@
#ifndef MINDSPORE_LITE_NNACL_SQUEEZE_PARAMETER_H_
#define MINDSPORE_LITE_NNACL_SQUEEZE_PARAMETER_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#define SQUEEZE_OFFSET_MAX_SIZE 4
typedef struct SqueezeQuantArg {
QuantArg *in_quant_args_;
QuantArg *out_quant_args_;
} SqueezeQuantArg;
typedef struct SqueezeParameter {
// primitive parameter
OpParameter op_parameter_;

View File

@ -16,11 +16,26 @@
#ifndef MINDSPORE_LITE_NNACL_UNSQUEEZE_PARAMETER_H_
#define MINDSPORE_LITE_NNACL_UNSQUEEZE_PARAMETER_H_
#include <string.h>
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#define UNSQUEEZE_OFFSET_MAX_SIZE 4
typedef struct UnSqueezeQuantArg {
int *input_sizes_;
int output_size_;
int **input_shapes_;
int *output_shape_;
float alpha;
int axis_;
size_t input_num_;
size_t output_dim_;
QuantArg in_quant_args_;
QuantArg out_quant_args_;
} UnSqueezeQuantArg;
typedef struct UnSqueezeParameter {
// primitive parameter
OpParameter op_parameter_;

View File

@ -140,810 +140,3 @@ void WinogradOutputTransform(const float *gemm_out, float *out_data, const float
out_tile_index++;
}
}
// int8 conv3x3
// Winograd-style 4x4 input transform for one int8 conv3x3 tile.
// tmp_data holds a 4x4 spatial patch with C8NUM (8) channels per position,
// stored as 16 consecutive groups of 8 int16 values (row-major over the
// patch). The input zero-point is subtracted from every element, then the
// two-stage transform is applied per channel:
//   t = B^T * d   (combine patch rows),  m = t * B  (combine patch columns),
// with B^T rows [1,0,-1,0], [0,1,1,0], [0,-1,1,0], [0,1,0,-1] — the standard
// F(2x2,3x3) Winograd input matrix, matching the filter/output transforms
// used alongside this function.
// The 16 transformed 8-channel groups are written to trans_input_data with a
// spacing of `step` int16 elements between consecutive groups.
void Conv3x3Int8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) {
#ifdef ENABLE_ARM
  // NEON path: each int16x8_t vector carries all 8 channels of one position.
  int16x8_t zp = vdupq_n_s16(input_zp);

  // dRC = zero-point-adjusted input at patch row R, column C.
  int16x8_t d00 = vsubq_s16(vld1q_s16(tmp_data), zp);
  int16x8_t d01 = vsubq_s16(vld1q_s16(tmp_data + 8), zp);
  int16x8_t d02 = vsubq_s16(vld1q_s16(tmp_data + 2 * 8), zp);
  int16x8_t d03 = vsubq_s16(vld1q_s16(tmp_data + 3 * 8), zp);

  int16x8_t d10 = vsubq_s16(vld1q_s16(tmp_data + 4 * 8), zp);
  int16x8_t d11 = vsubq_s16(vld1q_s16(tmp_data + 5 * 8), zp);
  int16x8_t d12 = vsubq_s16(vld1q_s16(tmp_data + 6 * 8), zp);
  int16x8_t d13 = vsubq_s16(vld1q_s16(tmp_data + 7 * 8), zp);

  int16x8_t d20 = vsubq_s16(vld1q_s16(tmp_data + 8 * 8), zp);
  int16x8_t d21 = vsubq_s16(vld1q_s16(tmp_data + 9 * 8), zp);
  int16x8_t d22 = vsubq_s16(vld1q_s16(tmp_data + 10 * 8), zp);
  int16x8_t d23 = vsubq_s16(vld1q_s16(tmp_data + 11 * 8), zp);

  int16x8_t d30 = vsubq_s16(vld1q_s16(tmp_data + 12 * 8), zp);
  int16x8_t d31 = vsubq_s16(vld1q_s16(tmp_data + 13 * 8), zp);
  int16x8_t d32 = vsubq_s16(vld1q_s16(tmp_data + 14 * 8), zp);
  int16x8_t d33 = vsubq_s16(vld1q_s16(tmp_data + 15 * 8), zp);

  // Stage 1: tR* = B^T * d — mix rows 0..3 of the patch.
  int16x8_t t00 = vsubq_s16(d00, d20);
  int16x8_t t01 = vsubq_s16(d01, d21);
  int16x8_t t02 = vsubq_s16(d02, d22);
  int16x8_t t03 = vsubq_s16(d03, d23);

  int16x8_t t10 = vaddq_s16(d10, d20);
  int16x8_t t11 = vaddq_s16(d11, d21);
  int16x8_t t12 = vaddq_s16(d12, d22);
  int16x8_t t13 = vaddq_s16(d13, d23);

  int16x8_t t20 = vsubq_s16(d20, d10);
  int16x8_t t21 = vsubq_s16(d21, d11);
  int16x8_t t22 = vsubq_s16(d22, d12);
  int16x8_t t23 = vsubq_s16(d23, d13);

  int16x8_t t30 = vsubq_s16(d10, d30);
  int16x8_t t31 = vsubq_s16(d11, d31);
  int16x8_t t32 = vsubq_s16(d12, d32);
  int16x8_t t33 = vsubq_s16(d13, d33);

  // Stage 2: mR* = t * B — mix columns 0..3 with the same coefficients.
  int16x8_t m00 = vsubq_s16(t00, t02);
  int16x8_t m01 = vaddq_s16(t01, t02);
  int16x8_t m02 = vsubq_s16(t02, t01);
  int16x8_t m03 = vsubq_s16(t01, t03);

  int16x8_t m10 = vsubq_s16(t10, t12);
  int16x8_t m11 = vaddq_s16(t11, t12);
  int16x8_t m12 = vsubq_s16(t12, t11);
  int16x8_t m13 = vsubq_s16(t11, t13);

  int16x8_t m20 = vsubq_s16(t20, t22);
  int16x8_t m21 = vaddq_s16(t21, t22);
  int16x8_t m22 = vsubq_s16(t22, t21);
  int16x8_t m23 = vsubq_s16(t21, t23);

  int16x8_t m30 = vsubq_s16(t30, t32);
  int16x8_t m31 = vaddq_s16(t31, t32);
  int16x8_t m32 = vsubq_s16(t32, t31);
  int16x8_t m33 = vsubq_s16(t31, t33);

  // Scatter the 16 transformed groups `step` elements apart.
  vst1q_s16(trans_input_data, m00);
  vst1q_s16(trans_input_data + step, m01);
  vst1q_s16(trans_input_data + 2 * step, m02);
  vst1q_s16(trans_input_data + 3 * step, m03);

  vst1q_s16(trans_input_data + 4 * step, m10);
  vst1q_s16(trans_input_data + 5 * step, m11);
  vst1q_s16(trans_input_data + 6 * step, m12);
  vst1q_s16(trans_input_data + 7 * step, m13);

  vst1q_s16(trans_input_data + 8 * step, m20);
  vst1q_s16(trans_input_data + 9 * step, m21);
  vst1q_s16(trans_input_data + 10 * step, m22);
  vst1q_s16(trans_input_data + 11 * step, m23);

  vst1q_s16(trans_input_data + 12 * step, m30);
  vst1q_s16(trans_input_data + 13 * step, m31);
  vst1q_s16(trans_input_data + 14 * step, m32);
  vst1q_s16(trans_input_data + 15 * step, m33);
#else
  // Scalar fallback: process each of the 8 channels independently, using the
  // exact same arithmetic as the NEON path above.
  for (int i = 0; i < C8NUM; i++) {
    int16_t *local_ptr = tmp_data + i;
    int16_t d00 = local_ptr[0] - input_zp;
    int16_t d01 = (local_ptr + C8NUM)[0] - input_zp;
    int16_t d02 = (local_ptr + 2 * C8NUM)[0] - input_zp;
    int16_t d03 = (local_ptr + 3 * C8NUM)[0] - input_zp;

    int16_t d10 = (local_ptr + 4 * C8NUM)[0] - input_zp;
    int16_t d11 = (local_ptr + 5 * C8NUM)[0] - input_zp;
    int16_t d12 = (local_ptr + 6 * C8NUM)[0] - input_zp;
    int16_t d13 = (local_ptr + 7 * C8NUM)[0] - input_zp;

    int16_t d20 = (local_ptr + 8 * C8NUM)[0] - input_zp;
    int16_t d21 = (local_ptr + 9 * C8NUM)[0] - input_zp;
    int16_t d22 = (local_ptr + 10 * C8NUM)[0] - input_zp;
    int16_t d23 = (local_ptr + 11 * C8NUM)[0] - input_zp;

    int16_t d30 = (local_ptr + 12 * C8NUM)[0] - input_zp;
    int16_t d31 = (local_ptr + 13 * C8NUM)[0] - input_zp;
    int16_t d32 = (local_ptr + 14 * C8NUM)[0] - input_zp;
    int16_t d33 = (local_ptr + 15 * C8NUM)[0] - input_zp;

    // Stage 1: row mix (B^T * d).
    int16_t t00 = d00 - d20;
    int16_t t01 = d01 - d21;
    int16_t t02 = d02 - d22;
    int16_t t03 = d03 - d23;

    int16_t t10 = d10 + d20;
    int16_t t11 = d11 + d21;
    int16_t t12 = d12 + d22;
    int16_t t13 = d13 + d23;

    int16_t t20 = d20 - d10;
    int16_t t21 = d21 - d11;
    int16_t t22 = d22 - d12;
    int16_t t23 = d23 - d13;

    int16_t t30 = d10 - d30;
    int16_t t31 = d11 - d31;
    int16_t t32 = d12 - d32;
    int16_t t33 = d13 - d33;

    // Stage 2: column mix (t * B).
    int16_t m00 = t00 - t02;
    int16_t m01 = t01 + t02;
    int16_t m02 = t02 - t01;
    int16_t m03 = t01 - t03;

    int16_t m10 = t10 - t12;
    int16_t m11 = t11 + t12;
    int16_t m12 = t12 - t11;
    int16_t m13 = t11 - t13;

    int16_t m20 = t20 - t22;
    int16_t m21 = t21 + t22;
    int16_t m22 = t22 - t21;
    int16_t m23 = t21 - t23;

    int16_t m30 = t30 - t32;
    int16_t m31 = t31 + t32;
    int16_t m32 = t32 - t31;
    int16_t m33 = t31 - t33;

    // Scatter this channel's 16 results `step` elements apart.
    (trans_input_data + i)[0] = m00;
    (trans_input_data + i + step)[0] = m01;
    (trans_input_data + i + 2 * step)[0] = m02;
    (trans_input_data + i + 3 * step)[0] = m03;

    (trans_input_data + i + 4 * step)[0] = m10;
    (trans_input_data + i + 5 * step)[0] = m11;
    (trans_input_data + i + 6 * step)[0] = m12;
    (trans_input_data + i + 7 * step)[0] = m13;

    (trans_input_data + i + 8 * step)[0] = m20;
    (trans_input_data + i + 9 * step)[0] = m21;
    (trans_input_data + i + 10 * step)[0] = m22;
    (trans_input_data + i + 11 * step)[0] = m23;

    (trans_input_data + i + 12 * step)[0] = m30;
    (trans_input_data + i + 13 * step)[0] = m31;
    (trans_input_data + i + 14 * step)[0] = m32;
    (trans_input_data + i + 15 * step)[0] = m33;
  }
#endif
}
// Extracts `real_cal_num` 4x4 input patches (tile indices start_index ..
// start_index + real_cal_num - 1) from the int16 input, pads out-of-image
// positions with the input zero-point (which the unit transform subtracts,
// so padding contributes zero), and applies Conv3x3Int8InputUnit to each
// 8-channel block, packing results into trans_input.
//
// tmp_data is caller-provided scratch for one padded 4x4 x C8NUM patch.
// The offset arithmetic implies the input is blocked by 8 channels:
// for each ic8 block, H*W positions of C8NUM values — i.e. roughly
// [ic8][h][w][8] layout (inferred from src_c8_offset; confirm against the
// packing routines that produce input_data).
// OUPUT_UNIT (sic — project macro) is the output tile size per Winograd
// step; presumably 2 for F(2x2,3x3) — verify in the macro's definition.
void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index,
                               int real_cal_num, int out_w_block, ConvParameter *conv_param) {
  // input data format : nhwc
  int input_channel = conv_param->input_channel_;
  int input_width = conv_param->input_w_;
  int input_height = conv_param->input_h_;
  int pad_w = conv_param->pad_l_;
  int pad_h = conv_param->pad_u_;
  ConvQuantArg quant_arg = conv_param->conv_quant_arg_;
  int input_zp = quant_arg.input_quant_args_[0].zp_;
  const int ic8 = UP_DIV(input_channel, C8NUM);
  const int input_unit = 4;  // 4x4 input patch for the 3x3 Winograd transform
  // Guard against division/modulo by zero in the tile-index math below.
  if (out_w_block == 0) {
    return;
  }
  for (int cal_id = 0; cal_id < real_cal_num; cal_id++) {
    int x_id = start_index + cal_id;
    // Top-left corner of this patch in input coordinates (may be negative
    // or extend past the image edge because of padding).
    int origin_x = (x_id % out_w_block) * OUPUT_UNIT - pad_w;
    int origin_y = (x_id / out_w_block) * OUPUT_UNIT - pad_h;
    // Clamp the patch to the region that actually overlaps the image.
    int real_x_start = origin_x > 0 ? 0 : -origin_x;
    int real_x_end = (origin_x + input_unit) < input_width ? input_unit : (input_width - origin_x);
    int real_y_start = origin_y > 0 ? 0 : -origin_y;
    int real_y_end = (origin_y + input_unit) < input_height ? input_unit : (input_height - origin_y);
    int src_plane_offset = C8NUM * (origin_y * input_width + origin_x);
    int dst_plane_offset = cal_id * C8NUM;
    for (int ic = 0; ic < ic8; ic++) {
      // copy data from origin input to tmp buffer
      // Pre-fill with the zero-point so padded positions transform to zero.
      for (int i = 0; i < input_unit * input_unit * TILE_NUM; i++) tmp_data[i] = input_zp;
      int src_c8_offset = src_plane_offset + ic * C8NUM * input_height * input_width;
      for (int j = real_y_start; j < real_y_end; j++) {
        const int16_t *src = input_data + src_c8_offset + C8NUM * (j * input_width + real_x_start);
        int16_t *dst = tmp_data + C8NUM * (C4NUM * j + real_x_start);
        memcpy(dst, src, (real_x_end - real_x_start) * C8NUM * sizeof(int16_t));
      }
      // input transform
      int dst_ic8_offset = dst_plane_offset + ic * TILE_NUM * C8NUM;
      size_t dst_step = ic8 * C8NUM * TILE_NUM;  // spacing between transformed groups
      int16_t *trans_input_ptr = trans_input + dst_ic8_offset;
      Conv3x3Int8InputUnit(tmp_data, trans_input_ptr, dst_step, input_zp);
    }
  }
}
// Transforms 3x3 int8 convolution weights into the 4x4 Winograd F(2x2,3x3)
// domain: trans = G * g * G^T, using the integer-scaled matrix
//   G = [[2, 0, 0], [1, 1, 1], [1, -1, 1], [0, 0, 2]].
// NOTE(review): the x2 row scaling (vs. the canonical fractional G) keeps all
// arithmetic in int16; the paired input/output transforms presumably undo the
// scaling with right shifts -- confirm against those transforms.
//
// weight_data  source layout: [output_channel][iC8][kernel_plane][C8NUM].
// trans_weight destination: OC4-tiled; the 16 Winograd positions of one
//              (oc4, ic8) tile are dst_step int16 elements apart, and within a
//              position the channels are interleaved as [ic8][C8NUM][C4NUM].
void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel,
                                int kernel_plane) {
  const int input_unit = 4;  // transformed tile is 4x4 for F(2x2,3x3)
  int dst_step = iC8 * C8NUM * C4NUM;  // stride between consecutive Winograd positions
  for (int o = 0; o < output_channel; o++) {
    int oc4_block_num = o / C4NUM;
    int oc4_block_rem = o % C4NUM;
    int src_oc_offset = o * iC8 * C8NUM * kernel_plane;
    int dst_oc_offset = oc4_block_num * C4NUM * iC8 * C8NUM * input_unit * input_unit + oc4_block_rem;
    for (int i = 0; i < iC8; i++) {
      const int16_t *src_ic8_ptr = weight_data + src_oc_offset + i * kernel_plane * C8NUM;
      int16_t *dst_ic8_ptr = trans_weight + dst_oc_offset + i * C4NUM * C8NUM;
#ifdef ENABLE_ARM
      // Load the 3x3 kernel taps g<row><col> for 8 input channels at once.
      int16x8_t g00 = vld1q_s16(src_ic8_ptr);
      int16x8_t g01 = vld1q_s16(src_ic8_ptr + 8);
      int16x8_t g02 = vld1q_s16(src_ic8_ptr + 2 * 8);
      int16x8_t g10 = vld1q_s16(src_ic8_ptr + 3 * 8);
      int16x8_t g11 = vld1q_s16(src_ic8_ptr + 4 * 8);
      int16x8_t g12 = vld1q_s16(src_ic8_ptr + 5 * 8);
      int16x8_t g20 = vld1q_s16(src_ic8_ptr + 6 * 8);
      int16x8_t g21 = vld1q_s16(src_ic8_ptr + 7 * 8);
      int16x8_t g22 = vld1q_s16(src_ic8_ptr + 8 * 8);
      // First stage: dst = G * g (combine kernel rows; columns untouched).
      int16x8_t dst00 = vmulq_n_s16(g00, 2);
      int16x8_t dst01 = vmulq_n_s16(g01, 2);
      int16x8_t dst02 = vmulq_n_s16(g02, 2);
      int16x8_t dst10 = vaddq_s16(vaddq_s16(g00, g10), g20);
      int16x8_t dst11 = vaddq_s16(vaddq_s16(g01, g11), g21);
      int16x8_t dst12 = vaddq_s16(vaddq_s16(g02, g12), g22);
      int16x8_t dst20 = vaddq_s16(vsubq_s16(g00, g10), g20);
      int16x8_t dst21 = vaddq_s16(vsubq_s16(g01, g11), g21);
      int16x8_t dst22 = vaddq_s16(vsubq_s16(g02, g12), g22);
      int16x8_t dst30 = vmulq_n_s16(g20, 2);
      int16x8_t dst31 = vmulq_n_s16(g21, 2);
      int16x8_t dst32 = vmulq_n_s16(g22, 2);
      // Second stage: m = dst * G^T (combine kernel columns).
      int16x8_t m00 = vmulq_n_s16(dst00, 2);
      int16x8_t m01 = vaddq_s16(vaddq_s16(dst00, dst01), dst02);
      int16x8_t m02 = vaddq_s16(vsubq_s16(dst00, dst01), dst02);
      int16x8_t m03 = vmulq_n_s16(dst02, 2);
      int16x8_t m10 = vmulq_n_s16(dst10, 2);
      int16x8_t m11 = vaddq_s16(vaddq_s16(dst10, dst11), dst12);
      int16x8_t m12 = vaddq_s16(vsubq_s16(dst10, dst11), dst12);
      int16x8_t m13 = vmulq_n_s16(dst12, 2);
      int16x8_t m20 = vmulq_n_s16(dst20, 2);
      int16x8_t m21 = vaddq_s16(vaddq_s16(dst20, dst21), dst22);
      int16x8_t m22 = vaddq_s16(vsubq_s16(dst20, dst21), dst22);
      int16x8_t m23 = vmulq_n_s16(dst22, 2);
      int16x8_t m30 = vmulq_n_s16(dst30, 2);
      int16x8_t m31 = vaddq_s16(vaddq_s16(dst30, dst31), dst32);
      int16x8_t m32 = vaddq_s16(vsubq_s16(dst30, dst31), dst32);
      int16x8_t m33 = vmulq_n_s16(dst32, 2);
      // Scatter-store: lane c of m<r><c'> lands at channel slot c * C4NUM of
      // Winograd position (r*4 + c'); consecutive positions are dst_step apart.
      dst_ic8_ptr[0] = m00[0];
      dst_ic8_ptr[4] = m00[1];
      dst_ic8_ptr[8] = m00[2];
      dst_ic8_ptr[12] = m00[3];
      dst_ic8_ptr[16] = m00[4];
      dst_ic8_ptr[20] = m00[5];
      dst_ic8_ptr[24] = m00[6];
      dst_ic8_ptr[28] = m00[7];
      dst_ic8_ptr[0 + dst_step] = m01[0];
      dst_ic8_ptr[4 + dst_step] = m01[1];
      dst_ic8_ptr[8 + dst_step] = m01[2];
      dst_ic8_ptr[12 + dst_step] = m01[3];
      dst_ic8_ptr[16 + dst_step] = m01[4];
      dst_ic8_ptr[20 + dst_step] = m01[5];
      dst_ic8_ptr[24 + dst_step] = m01[6];
      dst_ic8_ptr[28 + dst_step] = m01[7];
      dst_ic8_ptr[0 + 2 * dst_step] = m02[0];
      dst_ic8_ptr[4 + 2 * dst_step] = m02[1];
      dst_ic8_ptr[8 + 2 * dst_step] = m02[2];
      dst_ic8_ptr[12 + 2 * dst_step] = m02[3];
      dst_ic8_ptr[16 + 2 * dst_step] = m02[4];
      dst_ic8_ptr[20 + 2 * dst_step] = m02[5];
      dst_ic8_ptr[24 + 2 * dst_step] = m02[6];
      dst_ic8_ptr[28 + 2 * dst_step] = m02[7];
      dst_ic8_ptr[0 + 3 * dst_step] = m03[0];
      dst_ic8_ptr[4 + 3 * dst_step] = m03[1];
      dst_ic8_ptr[8 + 3 * dst_step] = m03[2];
      dst_ic8_ptr[12 + 3 * dst_step] = m03[3];
      dst_ic8_ptr[16 + 3 * dst_step] = m03[4];
      dst_ic8_ptr[20 + 3 * dst_step] = m03[5];
      dst_ic8_ptr[24 + 3 * dst_step] = m03[6];
      dst_ic8_ptr[28 + 3 * dst_step] = m03[7];
      dst_ic8_ptr[0 + 4 * dst_step] = m10[0];
      dst_ic8_ptr[4 + 4 * dst_step] = m10[1];
      dst_ic8_ptr[8 + 4 * dst_step] = m10[2];
      dst_ic8_ptr[12 + 4 * dst_step] = m10[3];
      dst_ic8_ptr[16 + 4 * dst_step] = m10[4];
      dst_ic8_ptr[20 + 4 * dst_step] = m10[5];
      dst_ic8_ptr[24 + 4 * dst_step] = m10[6];
      dst_ic8_ptr[28 + 4 * dst_step] = m10[7];
      dst_ic8_ptr[0 + 5 * dst_step] = m11[0];
      dst_ic8_ptr[4 + 5 * dst_step] = m11[1];
      dst_ic8_ptr[8 + 5 * dst_step] = m11[2];
      dst_ic8_ptr[12 + 5 * dst_step] = m11[3];
      dst_ic8_ptr[16 + 5 * dst_step] = m11[4];
      dst_ic8_ptr[20 + 5 * dst_step] = m11[5];
      dst_ic8_ptr[24 + 5 * dst_step] = m11[6];
      dst_ic8_ptr[28 + 5 * dst_step] = m11[7];
      dst_ic8_ptr[0 + 6 * dst_step] = m12[0];
      dst_ic8_ptr[4 + 6 * dst_step] = m12[1];
      dst_ic8_ptr[8 + 6 * dst_step] = m12[2];
      dst_ic8_ptr[12 + 6 * dst_step] = m12[3];
      dst_ic8_ptr[16 + 6 * dst_step] = m12[4];
      dst_ic8_ptr[20 + 6 * dst_step] = m12[5];
      dst_ic8_ptr[24 + 6 * dst_step] = m12[6];
      dst_ic8_ptr[28 + 6 * dst_step] = m12[7];
      dst_ic8_ptr[0 + 7 * dst_step] = m13[0];
      dst_ic8_ptr[4 + 7 * dst_step] = m13[1];
      dst_ic8_ptr[8 + 7 * dst_step] = m13[2];
      dst_ic8_ptr[12 + 7 * dst_step] = m13[3];
      dst_ic8_ptr[16 + 7 * dst_step] = m13[4];
      dst_ic8_ptr[20 + 7 * dst_step] = m13[5];
      dst_ic8_ptr[24 + 7 * dst_step] = m13[6];
      dst_ic8_ptr[28 + 7 * dst_step] = m13[7];
      dst_ic8_ptr[0 + 8 * dst_step] = m20[0];
      dst_ic8_ptr[4 + 8 * dst_step] = m20[1];
      dst_ic8_ptr[8 + 8 * dst_step] = m20[2];
      dst_ic8_ptr[12 + 8 * dst_step] = m20[3];
      dst_ic8_ptr[16 + 8 * dst_step] = m20[4];
      dst_ic8_ptr[20 + 8 * dst_step] = m20[5];
      dst_ic8_ptr[24 + 8 * dst_step] = m20[6];
      dst_ic8_ptr[28 + 8 * dst_step] = m20[7];
      dst_ic8_ptr[0 + 9 * dst_step] = m21[0];
      dst_ic8_ptr[4 + 9 * dst_step] = m21[1];
      dst_ic8_ptr[8 + 9 * dst_step] = m21[2];
      dst_ic8_ptr[12 + 9 * dst_step] = m21[3];
      dst_ic8_ptr[16 + 9 * dst_step] = m21[4];
      dst_ic8_ptr[20 + 9 * dst_step] = m21[5];
      dst_ic8_ptr[24 + 9 * dst_step] = m21[6];
      dst_ic8_ptr[28 + 9 * dst_step] = m21[7];
      dst_ic8_ptr[0 + 10 * dst_step] = m22[0];
      dst_ic8_ptr[4 + 10 * dst_step] = m22[1];
      dst_ic8_ptr[8 + 10 * dst_step] = m22[2];
      dst_ic8_ptr[12 + 10 * dst_step] = m22[3];
      dst_ic8_ptr[16 + 10 * dst_step] = m22[4];
      dst_ic8_ptr[20 + 10 * dst_step] = m22[5];
      dst_ic8_ptr[24 + 10 * dst_step] = m22[6];
      dst_ic8_ptr[28 + 10 * dst_step] = m22[7];
      dst_ic8_ptr[0 + 11 * dst_step] = m23[0];
      dst_ic8_ptr[4 + 11 * dst_step] = m23[1];
      dst_ic8_ptr[8 + 11 * dst_step] = m23[2];
      dst_ic8_ptr[12 + 11 * dst_step] = m23[3];
      dst_ic8_ptr[16 + 11 * dst_step] = m23[4];
      dst_ic8_ptr[20 + 11 * dst_step] = m23[5];
      dst_ic8_ptr[24 + 11 * dst_step] = m23[6];
      dst_ic8_ptr[28 + 11 * dst_step] = m23[7];
      dst_ic8_ptr[0 + 12 * dst_step] = m30[0];
      dst_ic8_ptr[4 + 12 * dst_step] = m30[1];
      dst_ic8_ptr[8 + 12 * dst_step] = m30[2];
      dst_ic8_ptr[12 + 12 * dst_step] = m30[3];
      dst_ic8_ptr[16 + 12 * dst_step] = m30[4];
      dst_ic8_ptr[20 + 12 * dst_step] = m30[5];
      dst_ic8_ptr[24 + 12 * dst_step] = m30[6];
      dst_ic8_ptr[28 + 12 * dst_step] = m30[7];
      dst_ic8_ptr[0 + 13 * dst_step] = m31[0];
      dst_ic8_ptr[4 + 13 * dst_step] = m31[1];
      dst_ic8_ptr[8 + 13 * dst_step] = m31[2];
      dst_ic8_ptr[12 + 13 * dst_step] = m31[3];
      dst_ic8_ptr[16 + 13 * dst_step] = m31[4];
      dst_ic8_ptr[20 + 13 * dst_step] = m31[5];
      dst_ic8_ptr[24 + 13 * dst_step] = m31[6];
      dst_ic8_ptr[28 + 13 * dst_step] = m31[7];
      dst_ic8_ptr[0 + 14 * dst_step] = m32[0];
      dst_ic8_ptr[4 + 14 * dst_step] = m32[1];
      dst_ic8_ptr[8 + 14 * dst_step] = m32[2];
      dst_ic8_ptr[12 + 14 * dst_step] = m32[3];
      dst_ic8_ptr[16 + 14 * dst_step] = m32[4];
      dst_ic8_ptr[20 + 14 * dst_step] = m32[5];
      dst_ic8_ptr[24 + 14 * dst_step] = m32[6];
      dst_ic8_ptr[28 + 14 * dst_step] = m32[7];
      dst_ic8_ptr[0 + 15 * dst_step] = m33[0];
      dst_ic8_ptr[4 + 15 * dst_step] = m33[1];
      dst_ic8_ptr[8 + 15 * dst_step] = m33[2];
      dst_ic8_ptr[12 + 15 * dst_step] = m33[3];
      dst_ic8_ptr[16 + 15 * dst_step] = m33[4];
      dst_ic8_ptr[20 + 15 * dst_step] = m33[5];
      dst_ic8_ptr[24 + 15 * dst_step] = m33[6];
      dst_ic8_ptr[28 + 15 * dst_step] = m33[7];
#else
      // Scalar fallback: identical two-stage transform, processing one of the
      // C8NUM input channels per iteration (source taps are C8NUM apart).
      for (int j = 0; j < C8NUM; j++) {
        const int16_t *local_ptr = src_ic8_ptr + j;
        // dst = G * g
        int16_t dst00 = local_ptr[0] * 2;
        int16_t dst01 = (local_ptr + 8)[0] * 2;
        int16_t dst02 = (local_ptr + 16)[0] * 2;
        int16_t dst10 = local_ptr[0] + (local_ptr + 24)[0] + (local_ptr + 48)[0];
        int16_t dst11 = (local_ptr + 8)[0] + (local_ptr + 32)[0] + (local_ptr + 56)[0];
        int16_t dst12 = (local_ptr + 16)[0] + (local_ptr + 40)[0] + (local_ptr + 64)[0];
        int16_t dst20 = local_ptr[0] - (local_ptr + 24)[0] + (local_ptr + 48)[0];
        int16_t dst21 = (local_ptr + 8)[0] - (local_ptr + 32)[0] + (local_ptr + 56)[0];
        int16_t dst22 = (local_ptr + 16)[0] - (local_ptr + 40)[0] + (local_ptr + 64)[0];
        int16_t dst30 = (local_ptr + 48)[0] * 2;
        int16_t dst31 = (local_ptr + 56)[0] * 2;
        int16_t dst32 = (local_ptr + 64)[0] * 2;
        // m = dst * G^T
        int16_t m00 = dst00 * 2;
        int16_t m01 = dst00 + dst01 + dst02;
        int16_t m02 = dst00 - dst01 + dst02;
        int16_t m03 = dst02 * 2;
        int16_t m10 = dst10 * 2;
        int16_t m11 = dst10 + dst11 + dst12;
        int16_t m12 = dst10 - dst11 + dst12;
        int16_t m13 = dst12 * 2;
        int16_t m20 = dst20 * 2;
        int16_t m21 = dst20 + dst21 + dst22;
        int16_t m22 = dst20 - dst21 + dst22;
        int16_t m23 = dst22 * 2;
        int16_t m30 = dst30 * 2;
        int16_t m31 = dst30 + dst31 + dst32;
        int16_t m32 = dst30 - dst31 + dst32;
        int16_t m33 = dst32 * 2;
        // Same scatter layout as the NEON path: channel slot j * 4, one
        // dst_step per Winograd position.
        *(dst_ic8_ptr + j * 4) = m00;
        *(dst_ic8_ptr + j * 4 + dst_step) = m01;
        *(dst_ic8_ptr + j * 4 + 2 * dst_step) = m02;
        *(dst_ic8_ptr + j * 4 + 3 * dst_step) = m03;
        *(dst_ic8_ptr + j * 4 + 4 * dst_step) = m10;
        *(dst_ic8_ptr + j * 4 + 5 * dst_step) = m11;
        *(dst_ic8_ptr + j * 4 + 6 * dst_step) = m12;
        *(dst_ic8_ptr + j * 4 + 7 * dst_step) = m13;
        *(dst_ic8_ptr + j * 4 + 8 * dst_step) = m20;
        *(dst_ic8_ptr + j * 4 + 9 * dst_step) = m21;
        *(dst_ic8_ptr + j * 4 + 10 * dst_step) = m22;
        *(dst_ic8_ptr + j * 4 + 11 * dst_step) = m23;
        *(dst_ic8_ptr + j * 4 + 12 * dst_step) = m30;
        *(dst_ic8_ptr + j * 4 + 13 * dst_step) = m31;
        *(dst_ic8_ptr + j * 4 + 14 * dst_step) = m32;
        *(dst_ic8_ptr + j * 4 + 15 * dst_step) = m33;
      }
#endif
    }
  }
}
// Applies the inverse Winograd transform to one 4x4 tile of int32 GEMM
// accumulators and emits a 2x2 block of int8 outputs for C4NUM channels:
//   out = A^T * s * A, with A^T = [[1, 1, 1, 0], [0, 1, -1, -1]],
// right-shifting by 1 after each stage to undo the x2 scaling of the forward
// transforms, then bias add, requantization (saturating doubling high-mul
// followed by a rounding divide by power of two), output zero-point add, and
// clamping to [out_act_min, out_act_max].
// h_not_bound / w_not_bound gate the second output row / column so edge tiles
// do not write past the output tensor.
// NOTE(review): real_num is not read in either path -- all C4NUM channel
// slots are always stored; presumably the output buffer is C4-padded. Confirm
// against the callers before relying on tail-channel contents.
void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound,
                           bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) {
  int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_;
  int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_;
  int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
  int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
  int out_min = conv_param->conv_quant_arg_.out_act_min_[0];
  int out_max = conv_param->conv_quant_arg_.out_act_max_[0];
#ifdef ENABLE_ARM
  // Load the 4x4 tile s<row><col>; each vector holds C4NUM channels.
  int32x4_t bias_ptr = vld1q_s32(bias_data);
  int32x4_t s00 = vld1q_s32(gemm_out);
  int32x4_t s01 = vld1q_s32(gemm_out + 4);
  int32x4_t s02 = vld1q_s32(gemm_out + 8);
  int32x4_t s03 = vld1q_s32(gemm_out + 12);
  int32x4_t s10 = vld1q_s32(gemm_out + 16);
  int32x4_t s11 = vld1q_s32(gemm_out + 20);
  int32x4_t s12 = vld1q_s32(gemm_out + 24);
  int32x4_t s13 = vld1q_s32(gemm_out + 28);
  int32x4_t s20 = vld1q_s32(gemm_out + 32);
  int32x4_t s21 = vld1q_s32(gemm_out + 36);
  int32x4_t s22 = vld1q_s32(gemm_out + 40);
  int32x4_t s23 = vld1q_s32(gemm_out + 44);
  int32x4_t s30 = vld1q_s32(gemm_out + 48);
  int32x4_t s31 = vld1q_s32(gemm_out + 52);
  int32x4_t s32 = vld1q_s32(gemm_out + 56);
  int32x4_t s33 = vld1q_s32(gemm_out + 60);
  // First stage: t = A^T * s, halved (>> 1) to undo forward-transform scaling.
  int32x4_t t00 = vshrq_n_s32(vaddq_s32(vaddq_s32(s00, s10), s20), 1);
  int32x4_t t01 = vshrq_n_s32(vaddq_s32(vaddq_s32(s01, s11), s21), 1);
  int32x4_t t02 = vshrq_n_s32(vaddq_s32(vaddq_s32(s02, s12), s22), 1);
  int32x4_t t03 = vshrq_n_s32(vaddq_s32(vaddq_s32(s03, s13), s23), 1);
  int32x4_t t10 = vshrq_n_s32(vsubq_s32(vsubq_s32(s10, s20), s30), 1);
  int32x4_t t11 = vshrq_n_s32(vsubq_s32(vsubq_s32(s11, s21), s31), 1);
  int32x4_t t12 = vshrq_n_s32(vsubq_s32(vsubq_s32(s12, s22), s32), 1);
  int32x4_t t13 = vshrq_n_s32(vsubq_s32(vsubq_s32(s13, s23), s33), 1);
  // Second stage: d = t * A, halved again, plus bias. d<r><c> is output pixel (r, c).
  int32x4_t d00 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t00, t01), t02), 1), bias_ptr);
  int32x4_t d01 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t01, t02), t03), 1), bias_ptr);
  int32x4_t d10 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t10, t11), t12), 1), bias_ptr);
  int32x4_t d11 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t11, t12), t13), 1), bias_ptr);
  // Pick per-channel or per-tensor quantization parameters.
  int32x4_t out_multiplier;
  int32x4_t ls;
  int32x4_t rs;
  if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
    out_multiplier = vld1q_s32(quant_multiplier + oc_start);
    ls = vld1q_s32(left_shift + oc_start);
    rs = vld1q_s32(right_shift + oc_start);
  } else {
    out_multiplier = vdupq_n_s32(quant_multiplier[0]);
    ls = vdupq_n_s32(left_shift[0]);
    rs = vdupq_n_s32(right_shift[0]);
  }
  int32x4_t out_zp = vdupq_n_s32(output_zp);
  int32x4_t output_min = vdupq_n_s32(out_min);
  int32x4_t output_max = vdupq_n_s32(out_max);
  // Requantize d00: left shift, saturating doubling high multiply, then a
  // rounding right shift. The carry fixup subtracts 1 from negative values
  // first so the result matches RoundingDivideByPOT (round half away from
  // zero). NOTE(review): assumes right_shift_ holds non-positive shift
  // amounts, as the scalar path's -right_shift[...] usage suggests -- verify.
  d00 = vqshlq_s32(d00, ls);
  d00 = vqrdmulhq_s32(d00, out_multiplier);
  int32x4_t carry = vandq_s32(d00, rs);
  carry = vshrq_n_s32(carry, 31);
  d00 = vqaddq_s32(d00, carry);
  d00 = vqrshlq_s32(d00, rs);
  d00 = vaddq_s32(d00, out_zp);
  d00 = vmaxq_s32(d00, output_min);
  d00 = vminq_s32(d00, output_max);
  // Same requantize/clamp sequence for d01, d10, d11.
  d01 = vqshlq_s32(d01, ls);
  d01 = vqrdmulhq_s32(d01, out_multiplier);
  carry = vandq_s32(d01, rs);
  carry = vshrq_n_s32(carry, 31);
  d01 = vqaddq_s32(d01, carry);
  d01 = vqrshlq_s32(d01, rs);
  d01 = vaddq_s32(d01, out_zp);
  d01 = vmaxq_s32(d01, output_min);
  d01 = vminq_s32(d01, output_max);
  d10 = vqshlq_s32(d10, ls);
  d10 = vqrdmulhq_s32(d10, out_multiplier);
  carry = vandq_s32(d10, rs);
  carry = vshrq_n_s32(carry, 31);
  d10 = vqaddq_s32(d10, carry);
  d10 = vqrshlq_s32(d10, rs);
  d10 = vaddq_s32(d10, out_zp);
  d10 = vmaxq_s32(d10, output_min);
  d10 = vminq_s32(d10, output_max);
  d11 = vqshlq_s32(d11, ls);
  d11 = vqrdmulhq_s32(d11, out_multiplier);
  carry = vandq_s32(d11, rs);
  carry = vshrq_n_s32(carry, 31);
  d11 = vqaddq_s32(d11, carry);
  d11 = vqrshlq_s32(d11, rs);
  d11 = vaddq_s32(d11, out_zp);
  d11 = vmaxq_s32(d11, output_min);
  d11 = vminq_s32(d11, output_max);
  // Store the 2x2 output block (C4NUM channels each); right column and bottom
  // row are skipped on boundary tiles.
  (output_data)[0] = (int8_t)d00[0];
  (output_data + 1)[0] = (int8_t)d00[1];
  (output_data + 2)[0] = (int8_t)d00[2];
  (output_data + 3)[0] = (int8_t)d00[3];
  if (w_not_bound) {
    *(output_data + 4) = (int8_t)d01[0];
    *(output_data + 5) = (int8_t)d01[1];
    *(output_data + 6) = (int8_t)d01[2];
    *(output_data + 7) = (int8_t)d01[3];
  }
  if (h_not_bound) {
    *(output_data + output_w * 4) = (int8_t)d10[0];
    *(output_data + output_w * 4 + 1) = (int8_t)d10[1];
    *(output_data + output_w * 4 + 2) = (int8_t)d10[2];
    *(output_data + output_w * 4 + 3) = (int8_t)d10[3];
    if (w_not_bound) {
      *(output_data + output_w * 4 + 4) = (int8_t)d11[0];
      *(output_data + output_w * 4 + 5) = (int8_t)d11[1];
      *(output_data + output_w * 4 + 6) = (int8_t)d11[2];
      *(output_data + output_w * 4 + 7) = (int8_t)d11[3];
    }
  }
#else
  // Scalar path: per-channel quantization parameters indexed by oc_start + i.
  if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
    for (int i = 0; i < C4NUM; i++) {
      const int32_t *local_ptr = gemm_out + i;
      const int32_t *bias_ptr = bias_data + i;
      int32_t s00 = local_ptr[0];
      int32_t s01 = (local_ptr + 4)[0];
      int32_t s02 = (local_ptr + 8)[0];
      int32_t s03 = (local_ptr + 12)[0];
      int32_t s10 = (local_ptr + 16)[0];
      int32_t s11 = (local_ptr + 20)[0];
      int32_t s12 = (local_ptr + 24)[0];
      int32_t s13 = (local_ptr + 28)[0];
      int32_t s20 = (local_ptr + 32)[0];
      int32_t s21 = (local_ptr + 36)[0];
      int32_t s22 = (local_ptr + 40)[0];
      int32_t s23 = (local_ptr + 44)[0];
      int32_t s30 = (local_ptr + 48)[0];
      int32_t s31 = (local_ptr + 52)[0];
      int32_t s32 = (local_ptr + 56)[0];
      int32_t s33 = (local_ptr + 60)[0];
      // t = A^T * s, halved (mirrors the NEON >> 1 stage).
      int32_t t00 = (s00 + s10 + s20) / 2;
      int32_t t01 = (s01 + s11 + s21) / 2;
      int32_t t02 = (s02 + s12 + s22) / 2;
      int32_t t03 = (s03 + s13 + s23) / 2;
      int32_t t10 = (s10 - s20 - s30) / 2;
      int32_t t11 = (s11 - s21 - s31) / 2;
      int32_t t12 = (s12 - s22 - s32) / 2;
      int32_t t13 = (s13 - s23 - s33) / 2;
      // d = t * A, halved again, plus bias.
      int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0];
      int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0];
      int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0];
      int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0];
      int oc_index = oc_start + i;
      // Requantize, add output zero point, clamp to activation range.
      d00 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
        -right_shift[oc_index]);
      d00 += output_zp;
      d00 = d00 > out_min ? d00 : out_min;
      d00 = d00 < out_max ? d00 : out_max;
      d01 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
        -right_shift[oc_index]);
      d01 += output_zp;
      d01 = d01 > out_min ? d01 : out_min;
      d01 = d01 < out_max ? d01 : out_max;
      d10 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
        -right_shift[oc_index]);
      d10 += output_zp;
      d10 = d10 > out_min ? d10 : out_min;
      d10 = d10 < out_max ? d10 : out_max;
      d11 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
        -right_shift[oc_index]);
      d11 += output_zp;
      d11 = d11 > out_min ? d11 : out_min;
      d11 = d11 < out_max ? d11 : out_max;
      // Store the 2x2 block for channel slot i, honoring boundary flags.
      (output_data + i)[0] = (int8_t)d00;
      if (w_not_bound) {
        (output_data + i + C4NUM)[0] = (int8_t)d01;
      }
      if (h_not_bound) {
        (output_data + i + output_w * C4NUM)[0] = (int8_t)d10;
        if (w_not_bound) {
          (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11;
        }
      }
    }
  } else {
    // Per-tensor quantization: identical math with index-0 parameters.
    for (int i = 0; i < C4NUM; i++) {
      const int32_t *local_ptr = gemm_out + i;
      const int32_t *bias_ptr = bias_data + i;
      int32_t s00 = local_ptr[0];
      int32_t s01 = (local_ptr + 4)[0];
      int32_t s02 = (local_ptr + 8)[0];
      int32_t s03 = (local_ptr + 12)[0];
      int32_t s10 = (local_ptr + 16)[0];
      int32_t s11 = (local_ptr + 20)[0];
      int32_t s12 = (local_ptr + 24)[0];
      int32_t s13 = (local_ptr + 28)[0];
      int32_t s20 = (local_ptr + 32)[0];
      int32_t s21 = (local_ptr + 36)[0];
      int32_t s22 = (local_ptr + 40)[0];
      int32_t s23 = (local_ptr + 44)[0];
      int32_t s30 = (local_ptr + 48)[0];
      int32_t s31 = (local_ptr + 52)[0];
      int32_t s32 = (local_ptr + 56)[0];
      int32_t s33 = (local_ptr + 60)[0];
      int32_t t00 = (s00 + s10 + s20) / 2;
      int32_t t01 = (s01 + s11 + s21) / 2;
      int32_t t02 = (s02 + s12 + s22) / 2;
      int32_t t03 = (s03 + s13 + s23) / 2;
      int32_t t10 = (s10 - s20 - s30) / 2;
      int32_t t11 = (s11 - s21 - s31) / 2;
      int32_t t12 = (s12 - s22 - s32) / 2;
      int32_t t13 = (s13 - s23 - s33) / 2;
      int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0];
      int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0];
      int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0];
      int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0];
      d00 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
        -right_shift[0]);
      d00 += output_zp;
      d00 = d00 > out_min ? d00 : out_min;
      d00 = d00 < out_max ? d00 : out_max;
      d01 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
        -right_shift[0]);
      d01 += output_zp;
      d01 = d01 > out_min ? d01 : out_min;
      d01 = d01 < out_max ? d01 : out_max;
      d10 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
        -right_shift[0]);
      d10 += output_zp;
      d10 = d10 > out_min ? d10 : out_min;
      d10 = d10 < out_max ? d10 : out_max;
      d11 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
        -right_shift[0]);
      d11 += output_zp;
      d11 = d11 > out_min ? d11 : out_min;
      d11 = d11 < out_max ? d11 : out_max;
      (output_data + i)[0] = (int8_t)d00;
      if (w_not_bound) {
        (output_data + i + C4NUM)[0] = (int8_t)d01;
      }
      if (h_not_bound) {
        (output_data + i + output_w * C4NUM)[0] = (int8_t)d10;
        if (w_not_bound) {
          (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11;
        }
      }
    }
  }
#endif
}
// Drives the inverse Winograd transform for a batch of tiles: each tile maps
// to a 2x2 spatial block of the output, iterated over oc4 channel groups.
// start_index selects the first tile, real_cal_num how many tiles to emit.
void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index,
                                int real_cal_num, int out_w_block, ConvParameter *conv_param) {
  if (out_w_block == 0) {
    return;  // no width blocks: nothing to write, and avoids % / by zero below
  }
  const int input_unit = 4;  // transformed tile edge for F(2x2,3x3)
  const int output_channel = conv_param->output_channel_;
  const int output_w = conv_param->output_w_;
  const int output_h = conv_param->output_h_;
  const int oc4 = UP_DIV(output_channel, C4NUM);
  const int tile_stride = oc4 * C4NUM * input_unit * input_unit;  // int32 elems per tile in gemm_out
  for (int tile = 0; tile < real_cal_num; tile++) {
    const int tile_id = start_index + tile;
    const int w_index = tile_id % out_w_block;
    const int h_index = tile_id / out_w_block;
    const int32_t *tile_src = gemm_out + tile * tile_stride;
    int8_t *tile_dst = out_data + C4NUM * (w_index * OUPUT_UNIT + h_index * OUPUT_UNIT * output_w);
    // Boundary flags are per-tile; the output unit uses them to skip the
    // second column/row on edge tiles.
    const bool w_not_bound = w_index * OUPUT_UNIT + 1 < output_w;
    const bool h_not_bound = h_index * OUPUT_UNIT + 1 < output_h;
    for (int oc_block = 0; oc_block < oc4; oc_block++) {
      const int remaining = output_channel - oc_block * C4NUM;
      const int valid_channels = remaining < C4NUM ? remaining : C4NUM;
      const int32_t *src_ptr = tile_src + oc_block * input_unit * input_unit * C4NUM;
      int8_t *dst_ptr = tile_dst + oc_block * C4NUM * output_h * output_w;
      Conv3x3Int8OutputUnit(src_ptr, bias_data + oc_block * C4NUM, dst_ptr, h_not_bound, w_not_bound, output_w,
                            valid_channels, oc_block * C4NUM, conv_param);
    }
  }
}

Some files were not shown because too many files have changed in this diff Show More