forked from mindspore-Ecosystem/mindspore
[MSLITE] arithmetic optimize
This commit is contained in:
parent
77adcecedb
commit
8a78b1f362
|
@ -15,7 +15,7 @@ file(GLOB KERNEL_SRC
|
|||
${NNACL_DIR}/*.c
|
||||
${NNACL_DIR}/fp32/*.c
|
||||
${NNACL_DIR}/int8/*.c
|
||||
${NNACL_DIR}/quantization/*.c
|
||||
${NNACL_DIR}/base/*.c
|
||||
)
|
||||
|
||||
if (SUPPORT_TRAIN)
|
||||
|
|
|
@ -42,12 +42,4 @@ typedef struct ArithmeticParameter {
|
|||
int multiples1_[10];
|
||||
} ArithmeticParameter;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void CalcMultiplesAndStrides(ArithmeticParameter *param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_ARTITHMETIC_H_
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
|
||||
// For Abs, Cos, Exp, Log, Square, Sqrt, Rsqrt ops.
|
||||
typedef struct ArithmeticSelfParameter {
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/arithmetic.h"
|
||||
#include "nnacl/base/arithmetic_base.h"
|
||||
|
||||
void CalcMultiplesAndStrides(ArithmeticParameter *param) {
|
||||
NNACL_ASSERT(param->in_shape0_[i] != 0);
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_NNACL_BASE_ARITHMETIC_BASE_H_
|
||||
#define MINDSPORE_LITE_NNACL_BASE_ARITHMETIC_BASE_H_
|
||||
|
||||
#include "nnacl/arithmetic.h"
|
||||
#include "nnacl/nnacl_utils.h"
|
||||
#include "nnacl/nnacl_common.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
void CalcMultiplesAndStrides(ArithmeticParameter *param);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_BASE_ARITHMETIC_BASE_H_
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/base/conv1x1_base.h"
|
||||
|
||||
/* Gathers the input pixels touched by a 1x1 convolution (stride/pad aware)
 * into a dense NHWC buffer of shape output_h_ x output_w_ x input_channel_.
 *
 * src_ptr:    input tensor data, NHWC layout (read-only).
 * dst_ptr:    packed destination buffer, NHWC layout.
 * conv_param: convolution geometry (strides, pads, input/output shapes).
 * data_size:  size in bytes of one element (e.g. 4 for fp32).
 *
 * Positions whose source pixel falls outside the input image are skipped,
 * leaving dst untouched there (caller presumably pre-zeroes dst -- TODO
 * confirm against call sites). */
void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size) {
  /* support nhwc */
  const char *src = (const char *)src_ptr; /* const-correct: input is never written through */
  char *dst = (char *)dst_ptr;
  for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) {
    int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_u_;
    if (src_h < 0 || src_h >= conv_param->input_h_) {
      continue; /* row entirely in the padding region */
    }
    const char *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_ * data_size;
    char *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_ * data_size;
    for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) {
      int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_l_;
      if (src_w < 0 || src_w >= conv_param->input_w_) {
        continue; /* pixel in the padding region */
      }
      /* copy one full channel vector for this output position */
      memcpy(dst_h_ptr + dst_w * conv_param->input_channel_ * data_size,
             src_h_ptr + src_w * conv_param->input_channel_ * data_size, conv_param->input_channel_ * data_size);
    }
  }
}
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_NNACL_BASE_CONV1X1_BASE_H_
|
||||
#define MINDSPORE_LITE_NNACL_BASE_CONV1X1_BASE_H_
|
||||
|
||||
#include "nnacl/conv_parameter.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_BASE_CONV1X1_BASE_H_
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "nnacl/depth_to_space.h"
|
||||
#include <string.h>
|
||||
|
||||
#include "nnacl/base/depth_to_space_base.h"
|
||||
|
||||
void DepthToSpaceForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceParameter *param) {
|
||||
int32_t block_size = param->block_size_;
|
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
#ifndef MINDSPORE_LITE_NNACL_DEPTH_TO_SPACE_H_
|
||||
#define MINDSPORE_LITE_NNACL_DEPTH_TO_SPACE_H_
|
||||
|
||||
#include <string.h>
|
||||
#include "nnacl/depth_to_space_parameter.h"
|
||||
|
||||
#ifdef __cplusplus
|
|
@ -18,7 +18,7 @@
|
|||
|
||||
#include <math.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "mindspore/lite/nnacl/int8/fixed_point.h"
|
||||
|
||||
typedef struct ClipParameter {
|
||||
OpParameter op_parameter_;
|
||||
|
|
|
@ -17,11 +17,10 @@
|
|||
#ifndef MINDSPORE_LITE_NNACL_COMMON_FUNC_H_
|
||||
#define MINDSPORE_LITE_NNACL_COMMON_FUNC_H_
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/nnacl_common.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
@ -63,14 +62,6 @@ static inline int GetStride(int *strides, const int *shape, int length) {
|
|||
return stride;
|
||||
}
|
||||
|
||||
inline void ComputeStrides(const int *shape, int *strides, const int ndim) {
|
||||
int stride = 1;
|
||||
for (int i = ndim - 1; i >= 0; i--) {
|
||||
strides[i] = stride;
|
||||
stride *= shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void BiasAdd(const float *bias, float *data, size_t oc4, size_t plan_size);
|
||||
void BiasAddRelu6(const float *bias, float *data, size_t oc4, size_t plan_size);
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#define MINDSPORE_LITE_NNACL_CONCAT_PARAMETER_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "mindspore/lite/nnacl/int8/quantize.h"
|
||||
|
||||
typedef struct ConcatParameter {
|
||||
OpParameter op_parameter_;
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "mindspore/lite/nnacl/int8/quantize.h"
|
||||
|
||||
typedef struct ConvParameter {
|
||||
OpParameter op_parameter_;
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#define MINDSPORE_LITE_NNACL_CROP_PARAMETER_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "mindspore/lite/nnacl/int8/quantize.h"
|
||||
|
||||
#define CROP_OFFSET_MAX_SIZE 4
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#endif
|
||||
#include <math.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "mindspore/lite/nnacl/int8/fixed_point.h"
|
||||
|
||||
typedef struct ActivationParameter {
|
||||
OpParameter op_parameter_;
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/arithmetic.h"
|
||||
#include "nnacl/base/arithmetic_base.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
#include "nnacl/fp16/pack_fp16.h"
|
||||
#include <string.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num,
|
||||
int block_index) {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#ifndef MINDSPORE_LITE_NNACL_FP16_POOLING_FP16_H_
|
||||
#define MINDSPORE_LITE_NNACL_FP16_POOLING_FP16_H_
|
||||
|
||||
#include <math.h>
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
|
||||
#include <math.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "mindspore/lite/nnacl/int8/fixed_point.h"
|
||||
|
||||
typedef struct ActivationParameter {
|
||||
OpParameter op_parameter_;
|
||||
|
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/fp32/arg_min_max_fp32.h"
|
||||
#include <stdlib.h>
|
||||
#include <float.h>
|
||||
|
||||
int ArgCompareAscFp32(const void *a, const void *b) {
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/arithmetic.h"
|
||||
#include "nnacl/base/arithmetic_base.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -17,8 +17,6 @@
|
|||
#ifndef MINDSPORE_LITE_NNACL_FP32_COMMON_FUNC_H_
|
||||
#define MINDSPORE_LITE_NNACL_FP32_COMMON_FUNC_H_
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
|
|
|
@ -0,0 +1,479 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/fp32/pack_fp32.h"
|
||||
|
||||
/* Repacks a weight tensor from KHW (channel-major) layout to HWK
 * (plane-major) layout; implemented as an NCHW->NHWC transpose with
 * batch fixed at 1. */
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) {
  /* 'return <void expression>;' is a constraint violation in strict ISO C
   * (C11 6.8.6.4) -- call the helper as a plain statement instead. */
  PackNCHWToNHWCFp32(src, dst, 1, plane, channel);
}
|
||||
|
||||
/* Swaps the two spatial axes of an HWC tensor: element (h, w) of src
 * becomes element (w, h) of dst, with the innermost channel vector
 * copied unchanged. src and dst must not overlap. */
void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel) {
  size_t pixel_bytes = channel * sizeof(float);
  for (int w = 0; w < width; ++w) {
    for (int h = 0; h < height; ++h) {
      const float *from = src + (h * width + w) * channel;
      float *to = dst + (w * height + h) * channel;
      memcpy(to, from, pixel_bytes);
    }
  }
}
|
||||
|
||||
/* Im2col for fp32 convolution: gathers the receptive fields of
 * 'real_cal_num' consecutive output positions (starting at output index
 * 'block_index') from an NHWC input into 'packed_input', laid out as
 * [tile][kernel_h * kernel_w][in_channel]. Kernel taps that fall outside
 * the input image are clipped via kh_s/kh_e/kw_s/kw_e and never written
 * (caller presumably pre-zeroes packed_input -- TODO confirm). */
void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num,
                        int block_index) {
  // input format : nhwc
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  int kernel_plane = kernel_h * kernel_w;
  int dilation_h = conv_param->dilation_h_;
  int dilation_w = conv_param->dilation_w_;
  int in_channel = conv_param->input_channel_;
  int in_w = conv_param->input_w_;
  int out_w = conv_param->output_w_;

  for (int i = 0; i < real_cal_num; i++) {
    /* top-left corner of this output position's receptive field (may be
     * negative when it starts in the padding region) */
    int block_start = block_index + i;
    int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_;
    int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_;
    int input_stride = (input_h * in_w + input_w) * in_channel;
    /* clip the kernel window to the valid part of the input image */
    int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
    int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h));
    int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
    int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
    if (dilation_w == 1 && dilation_h == 1) {
      /* fast path: contiguous row of kernel taps can be copied in one memcpy */
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * in_w * in_channel + input_stride;
        int input_x_stride = input_y_stride + kw_s * in_channel;
        int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
        memcpy(packed_input + input_plane_offset, input_data + input_x_stride,
               (kw_e - kw_s) * in_channel * sizeof(float));
      }  // kernel_h loop
    } else {
      /* dilated kernel: taps are strided, copy one channel vector at a time */
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
        for (int k = kw_s; k < kw_e; ++k) {
          int input_x_stride = input_y_stride + k * dilation_w * in_channel;
          int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
          memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float));
        }
      }  // kernel_h loop
    }
  }  // tile num loop
}
|
||||
|
||||
/* Repacks NHWC fp32 data into NC4HW4: channels are grouped in blocks of
 * C4NUM, each block stored plane-contiguously with its 4 lanes innermost.
 * Full blocks are copied with NEON when available; the final (possibly
 * partial) block is copied lane by lane, so dst lanes past 'channel' in
 * the tail block are left untouched. */
void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  int c4_minus = c4 - 1; /* number of guaranteed-full channel blocks */
  for (int b = 0; b < batch; b++) {
    int src_oc_offset = b * plane * channel;
    int dst_oc_offset = b * plane * c4 * C4NUM;
    for (int k = 0; k < plane; k++) {
      int src_kernel_offset = src_oc_offset + k * channel;
      int dst_kernel_offset = dst_oc_offset + k * C4NUM;
      for (int j = 0; j < c4_minus; ++j) {
        int src_ic_offset = src_kernel_offset + j * C4NUM;
        int dst_ic_offset = dst_kernel_offset + j * plane * C4NUM;
#ifdef ENABLE_ARM
        /* copy one 4-lane channel group in a single vector load/store */
        vst1q_f32((float *)dst + dst_ic_offset, vld1q_f32((float *)src + src_ic_offset));
#else
        for (int i = 0; i < C4NUM; ++i) {
          ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i];
        }
#endif
      }
      /* tail block: only the res_c valid lanes are written */
      int tmp_c = c4_minus * C4NUM;
      int tmp_c_offset = tmp_c * plane;
      int res_c = channel - tmp_c;
      for (int l = 0; l < res_c; ++l) {
        int src_ic_offset = src_kernel_offset + tmp_c + l;
        int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l;
        ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0];
      }
    }
  }
}
|
||||
|
||||
void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
|
||||
int c4 = UP_DIV(channel, C4NUM);
|
||||
for (int b = 0; b < batch; b++) {
|
||||
int src_offset = b * plane * channel;
|
||||
int dst_offset = b * plane * c4 * C4NUM;
|
||||
for (int c = 0; c < channel; c++) {
|
||||
int c4_block_num = c / C4NUM;
|
||||
int c4_block_rem = c % C4NUM;
|
||||
int src_c_offset = src_offset + c * plane;
|
||||
int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM;
|
||||
for (int k = 0; k < plane; k++) {
|
||||
int src_kernel_offset = src_c_offset + k;
|
||||
int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem;
|
||||
((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
|
||||
int c4 = UP_DIV(channel, C4NUM);
|
||||
int c4_channel = c4 * C4NUM;
|
||||
int nhwc4_batch_unit_offset = c4 * C4NUM * plane;
|
||||
int ic_remainder_ = channel % C4NUM;
|
||||
if (ic_remainder_ != 0) {
|
||||
int nhwc4_batch_offset = 0;
|
||||
for (int b = 0; b < batch; b++) {
|
||||
int batch_offset = b * channel * plane;
|
||||
for (int i = 0; i < plane; i++) {
|
||||
float *dst_per_plane = (float *)dst + nhwc4_batch_offset + i * c4_channel;
|
||||
memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
|
||||
for (int j = channel; j < c4_channel; ++j) {
|
||||
dst_per_plane[j] = 0;
|
||||
}
|
||||
}
|
||||
nhwc4_batch_offset += nhwc4_batch_unit_offset;
|
||||
}
|
||||
} else {
|
||||
size_t ori_input_size = batch * plane * channel * sizeof(float);
|
||||
memcpy((float *)dst, (float *)src, ori_input_size);
|
||||
}
|
||||
}
|
||||
|
||||
void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
|
||||
int c8 = UP_DIV(channel, C8NUM);
|
||||
int c8_channel = c8 * C8NUM;
|
||||
int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
|
||||
int ic_remainder_ = channel % C8NUM;
|
||||
if (ic_remainder_ != 0) {
|
||||
int nhwc8_batch_offset = 0;
|
||||
for (int b = 0; b < batch; b++) {
|
||||
int batch_offset = b * channel * plane;
|
||||
for (int i = 0; i < plane; i++) {
|
||||
float *dst_per_plane = (float *)dst + nhwc8_batch_offset + i * c8_channel;
|
||||
memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
|
||||
for (int j = channel; j < c8_channel; ++j) {
|
||||
dst_per_plane[j] = 0;
|
||||
}
|
||||
}
|
||||
nhwc8_batch_offset += nhwc8_batch_unit_offset;
|
||||
}
|
||||
} else {
|
||||
size_t ori_input_size = batch * plane * channel * sizeof(float);
|
||||
memcpy((float *)dst, (float *)src, ori_input_size);
|
||||
}
|
||||
}
|
||||
|
||||
void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
|
||||
int c4 = UP_DIV(channel, C4NUM);
|
||||
int ic_remainder_ = channel % C4NUM;
|
||||
if (ic_remainder_ != 0) {
|
||||
int nhwc_batch_unit_offset = channel * plane;
|
||||
for (int b = 0; b < batch; b++) {
|
||||
int batch_offset = b * c4 * C4NUM * plane;
|
||||
for (int i = 0; i < plane; i++) {
|
||||
memcpy((float *)dst + b * nhwc_batch_unit_offset + i * channel, (float *)src + batch_offset + i * c4 * C4NUM,
|
||||
channel * sizeof(float));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
size_t ori_input_size = batch * plane * channel * sizeof(float);
|
||||
memcpy((float *)dst, (float *)src, ori_input_size);
|
||||
}
|
||||
}
|
||||
|
||||
/* Repacks NC4HW4 fp32 data into NHWC4 (channel blocks interleaved per
 * pixel). Only the 'channel' real lanes are moved; padding lanes of the
 * destination are left untouched.
 * NOTE(review): the per-batch dst stride is plane * channel while the
 * per-pixel dst stride is c4 * C4NUM -- consistent only when channel is a
 * multiple of C4NUM. Presumably call sites guarantee that; TODO confirm. */
void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane * c4 * C4NUM;
    int dst_offset = b * plane * channel;
    for (int c = 0; c < channel; c++) {
      int c4_block_num = c / C4NUM; /* which 4-channel group */
      int c4_block_res = c % C4NUM; /* lane inside the group */
      int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
      int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res;
      for (int k = 0; k < plane; k++) {
        int src_kernel_offset = src_c_offset + k * C4NUM;
        int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM;
        ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0];
      }
    }
  }
}
|
||||
|
||||
/* Repacks NC4HW4 fp32 data into tight NHWC. Full 4-lane channel groups
 * are copied with NEON when available; the final (possibly partial) group
 * is copied element by element so no padding lanes leak into dst. */
void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane * c4 * C4NUM;
    int dst_offset = b * plane * channel;
    for (int k = 0; k < plane; k++) {
      int src_kernel_offset = src_offset + k * C4NUM;
      int dst_kernel_offset = dst_offset + k * channel;
      /* first c4-1 groups are guaranteed full: move 4 lanes at a time */
      for (int c = 0; c < c4 - 1; c++) {
        int src_c_offset = src_kernel_offset + c * plane * C4NUM;
        int dst_c_offset = dst_kernel_offset + c * C4NUM;
#ifdef ENABLE_NEON
        vst1q_f32((float *)dst + dst_c_offset, vld1q_f32((float *)src + src_c_offset));
#else
        ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0];
        ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1];
        ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2];
        ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3];
#endif
      }
      // res part
      int res_c = channel - (c4 - 1) * C4NUM;
      for (int i = 0; i < res_c; i++) {
        int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i;
        int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i;
        ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0];
      }
    }
  }
}
|
||||
|
||||
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
|
||||
for (int n = 0; n < batch; n++) {
|
||||
for (int hw = 0; hw < plane; hw++) {
|
||||
for (int c = 0; c < channel; c++) {
|
||||
int c8div = c / C8NUM;
|
||||
int c8mod = c % C8NUM;
|
||||
int src_index = n * plane * channel + hw * channel + c;
|
||||
int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
|
||||
((float *)dst)[dst_index] = ((float *)src)[src_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) {
|
||||
int c4 = UP_DIV(channel, C4NUM);
|
||||
for (int c = 0; c < c4; c++) {
|
||||
int dst_off_c = c * C4NUM * height * width;
|
||||
for (int i = 0; i < C4NUM; i++) {
|
||||
int src_off_c = (c * C4NUM + i) * height * width;
|
||||
for (int kh = 0; kh < height; kh++) {
|
||||
int src_off_kh = src_off_c + kh * width;
|
||||
for (int kw = 0; kw < width; kw++) {
|
||||
int dst_off = dst_off_c + kw * height * C4NUM + kh * C4NUM + i;
|
||||
((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel) {
|
||||
int c8 = UP_DIV(channel, C8NUM);
|
||||
for (int c = 0; c < c8; c++) {
|
||||
int dst_off_c = c * C8NUM * height * width;
|
||||
for (int i = 0; i < C8NUM; i++) {
|
||||
int src_off_c = (c * C8NUM + i) * height * width;
|
||||
for (int kh = 0; kh < height; kh++) {
|
||||
int src_off_kh = src_off_c + kh * width;
|
||||
for (int kw = 0; kw < width; kw++) {
|
||||
int dst_off = dst_off_c + kw * height * C8NUM + kh * C8NUM + i;
|
||||
((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef ENABLE_SSE
/* Transposes NHWC fp32 data to NCHW, batch by batch. The hot path moves
 * 8x8 tiles (C8NUM rows of 8 channels) using hand-written AArch64/ARM32
 * NEON assembly (zip/trn lane permutes, vtrn/vswp on ARM32); leftover
 * rows/channels fall through to scalar loops.
 * NOTE(review): the asm reads/writes memory through registers not listed
 * as outputs and has no "memory" clobber -- presumably relies on the
 * surrounding code never caching src/dst contents; confirm before
 * reordering anything around the asm blocks. */
void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) {
  int hw8 = plane / C8NUM * C8NUM;   /* plane count rounded down to tile size */
  int c8 = channel / C8NUM * C8NUM;  /* channel count rounded down to tile size */
  int batch = plane * channel;
  for (int n = 0; n < batches; n++) {
    const float *src_batch = (const float *)src + n * batch;
    float *dst_batch = (float *)dst + n * batch;
    int hw = 0;
    for (; hw < hw8; hw += C8NUM) {
      int c = 0;
      for (; c < c8; c += C8NUM) {
        const float *src_ptr = src_batch + hw * channel + c;
        float *dst_ptr = dst_batch + c * plane + hw;
#ifdef ENABLE_ARM64
        /* 8x8 tile transpose: interleave rows with zip1/zip2, then swap
         * 64-bit halves with trn1/trn2, storing 8 transposed rows */
        size_t srcStride = channel * sizeof(float);
        size_t dstStride = plane * sizeof(float);
        asm volatile(
          "mov x10, %[src_ptr]\n"
          "mov x11, %[dst_ptr]\n"

          "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
          "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"

          "zip1 v8.4s, v0.4s, v2.4s\n"
          "zip2 v9.4s, v0.4s, v2.4s\n"
          "zip1 v12.4s, v1.4s, v3.4s\n"
          "zip2 v13.4s, v1.4s, v3.4s\n"

          "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
          "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"

          "zip1 v10.4s, v4.4s, v6.4s\n"
          "zip2 v11.4s, v4.4s, v6.4s\n"
          "zip1 v14.4s, v5.4s, v7.4s\n"
          "zip2 v15.4s, v5.4s, v7.4s\n"

          "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
          "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"

          "trn1 v16.2d, v8.2d, v10.2d\n"
          "trn2 v18.2d, v8.2d, v10.2d\n"
          "trn1 v20.2d, v9.2d, v11.2d\n"
          "trn2 v22.2d, v9.2d, v11.2d\n"

          "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
          "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"

          "trn1 v24.2d, v12.2d, v14.2d\n"
          "trn2 v26.2d, v12.2d, v14.2d\n"
          "trn1 v28.2d, v13.2d, v15.2d\n"
          "trn2 v30.2d, v13.2d, v15.2d\n"

          "zip1 v8.4s, v0.4s, v2.4s\n"
          "zip2 v9.4s, v0.4s, v2.4s\n"
          "zip1 v12.4s, v1.4s, v3.4s\n"
          "zip2 v13.4s, v1.4s, v3.4s\n"

          "zip1 v10.4s, v4.4s, v6.4s\n"
          "zip2 v11.4s, v4.4s, v6.4s\n"
          "zip1 v14.4s, v5.4s, v7.4s\n"
          "zip2 v15.4s, v5.4s, v7.4s\n"

          "trn1 v17.2d, v8.2d, v10.2d\n"
          "trn2 v19.2d, v8.2d, v10.2d\n"
          "trn1 v21.2d, v9.2d, v11.2d\n"
          "trn2 v23.2d, v9.2d, v11.2d\n"

          "trn1 v25.2d, v12.2d, v14.2d\n"
          "trn2 v27.2d, v12.2d, v14.2d\n"
          "trn1 v29.2d, v13.2d, v15.2d\n"
          "trn2 v31.2d, v13.2d, v15.2d\n"

          "st1 {v16.4s, v17.4s}, [x11], %[dstStride]\n"
          "st1 {v18.4s, v19.4s}, [x11], %[dstStride]\n"
          "st1 {v20.4s, v21.4s}, [x11], %[dstStride]\n"
          "st1 {v22.4s, v23.4s}, [x11], %[dstStride]\n"
          "st1 {v24.4s, v25.4s}, [x11], %[dstStride]\n"
          "st1 {v26.4s, v27.4s}, [x11], %[dstStride]\n"
          "st1 {v28.4s, v29.4s}, [x11], %[dstStride]\n"
          "st1 {v30.4s, v31.4s}, [x11], %[dstStride]\n"

          :
          :
          [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
          : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
            "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
            "v30", "v31");
#elif ENABLE_ARM32
        /* 8x8 tile transpose via vtrn (2x2 within q registers) + vswp
         * (swap 64-bit halves across register pairs) */
        size_t srcStride = channel * sizeof(float);
        size_t dstStride = plane * sizeof(float);
        asm volatile(
          "mov r10, %[src_ptr]\n"
          "mov r12, %[dst_ptr]\n"

          "vld1.32 {q0, q1}, [r10], %[srcStride]\n"
          "vld1.32 {q2, q3}, [r10], %[srcStride]\n"

          "vtrn.32 d0, d4\n"
          "vtrn.32 d1, d5\n"
          "vtrn.32 d2, d6\n"
          "vtrn.32 d3, d7\n"

          "vld1.32 {q4, q5}, [r10], %[srcStride]\n"
          "vld1.32 {q6, q7}, [r10], %[srcStride]\n"

          "vtrn.32 d8, d12\n"
          "vtrn.32 d9, d13\n"
          "vtrn.32 d10, d14\n"
          "vtrn.32 d11, d15\n"

          "vld1.32 {q8, q9}, [r10], %[srcStride]\n"
          "vld1.32 {q10, q11}, [r10], %[srcStride]\n"

          "vswp d1, d8\n"
          "vswp d3, d10\n"
          "vswp d5, d12\n"
          "vswp d7, d14\n"

          "vtrn.32 d16, d20\n"
          "vtrn.32 d17, d21\n"
          "vtrn.32 d18, d22\n"
          "vtrn.32 d19, d23\n"

          "vld1.32 {q12, q13}, [r10], %[srcStride]\n"
          "vld1.32 {q14, q15}, [r10], %[srcStride]\n"

          "vtrn.32 d24, d28\n"
          "vtrn.32 d25, d29\n"
          "vtrn.32 d26, d30\n"
          "vtrn.32 d27, d31\n"

          "vswp d17, d24\n"
          "vswp d19, d26\n"
          "vswp d21, d28\n"
          "vswp d23, d30\n"

          "add r10, r12, #16\n"
          "vst1.32 {q0}, [r12], %[dstStride]\n"
          "vst1.32 {q8}, [r10], %[dstStride]\n"
          "vst1.32 {q2}, [r12], %[dstStride]\n"
          "vst1.32 {q10}, [r10], %[dstStride]\n"
          "vst1.32 {q4}, [r12], %[dstStride]\n"
          "vst1.32 {q12}, [r10], %[dstStride]\n"
          "vst1.32 {q6}, [r12], %[dstStride]\n"
          "vst1.32 {q14}, [r10], %[dstStride]\n"
          "vst1.32 {q1}, [r12], %[dstStride]\n"
          "vst1.32 {q9}, [r10], %[dstStride]\n"
          "vst1.32 {q3}, [r12], %[dstStride]\n"
          "vst1.32 {q11}, [r10], %[dstStride]\n"
          "vst1.32 {q5}, [r12], %[dstStride]\n"
          "vst1.32 {q13}, [r10], %[dstStride]\n"
          "vst1.32 {q7}, [r12], %[dstStride]\n"
          "vst1.32 {q15}, [r10], %[dstStride]\n"

          :
          :
          [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
          : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
            "q15");
#else
        /* portable scalar 8x8 tile transpose */
        for (int tr = 0; tr < C8NUM; tr++) {
          for (int tc = 0; tc < C8NUM; tc++) {
            dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc];
          }
        }
#endif
      }
      /* leftover channels of this row-tile */
      for (; c < channel; c++) {
        const float *src_ptr = src_batch + hw * channel + c;
        float *dst_ptr = dst_batch + c * plane + hw;
        for (size_t i = 0; i < C8NUM; i++) {
          dst_ptr[i] = src_ptr[i * channel];
        }
      }
    }
    /* leftover rows below the last full tile */
    for (; hw < plane; hw++) {
      const float *src_ptr = src_batch + hw * channel;
      float *dst_ptr = dst_batch + hw;
      for (size_t i = 0; i < channel; i++) {
        dst_ptr[i * plane] = src_ptr[i];
      }
    }
  }
  return;
}
#endif
|
||||
|
||||
/* NCHW -> NHWC is the same 2-D transpose as NHWC -> NCHW with the plane
 * and channel extents swapped, so delegate. */
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
  /* 'return <void expression>;' is a constraint violation in strict ISO C
   * (C11 6.8.6.4) -- call the helper as a plain statement instead. */
  PackNHWCToNCHWFp32(src, dst, batch, channel, plane);
}
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_NNACL_FP32_PACK_H_
|
||||
#define MINDSPORE_LITE_NNACL_FP32_PACK_H_
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/conv_parameter.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel);
|
||||
void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel);
|
||||
void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel);
|
||||
void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel);
|
||||
void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num,
|
||||
int block_index);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_FP32_PAD_H_
|
|
@ -22,7 +22,7 @@
|
|||
#endif
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/pooling_parameter.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "mindspore/lite/nnacl/int8/quantize.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
|
||||
void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params) {
|
||||
int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_);
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#define MINDSPORE_LITE_NNACL_INT8_ARG_MIN_MAX_INT8_H_
|
||||
|
||||
#include "nnacl/arg_min_max_parameter.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
#define MINDSPORE_LITE_NNACL_INT8_ARITHMETIC_INT8_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/arithmetic.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
#include "nnacl/base/arithmetic_base.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <arm_neon.h>
|
||||
#include "nnacl/int8/common_func_int8.h"
|
||||
#endif
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
|
||||
int Int8ElementFloor(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) {
|
||||
float in_scale = para.in_args_.scale_;
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#endif
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
#ifndef MINDSPORE_LITE_NNACL_INT8_BATCH_TO_SPACE_INT8_H_
|
||||
#define MINDSPORE_LITE_NNACL_INT8_BATCH_TO_SPACE_INT8_H_
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "nnacl/int8/common_func_int8.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
|
||||
void PostConvFuncCommInt8(const int32_t *in, int8_t *out, const int32_t *bias, size_t oc, size_t plane,
|
||||
size_t out_oc_stride, size_t in_plane_stride, int32_t multiplier, int32_t mini, int32_t maxi,
|
||||
|
|
|
@ -17,8 +17,6 @@
|
|||
#ifndef MINDSPORE_LITE_NNACL_INT8_COMMON_FUNC_H_
|
||||
#define MINDSPORE_LITE_NNACL_INT8_COMMON_FUNC_H_
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
|
@ -29,9 +27,6 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
void PostFuncInt8C8(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, int32_t multiplier,
|
||||
int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, int32_t maxi);
|
||||
void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride,
|
||||
int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini,
|
||||
int32_t maxi);
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/int8/conv1x1_int8.h"
|
||||
|
||||
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
|
||||
const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
|
||||
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int *filter_zp) {
|
||||
int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
|
||||
matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias,
|
||||
left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
|
||||
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc,
|
||||
filter_zp);
|
||||
return;
|
||||
}
|
||||
|
||||
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
|
||||
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
|
||||
int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp) {
|
||||
int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
|
||||
MatmulInt8Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
|
||||
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
|
||||
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift,
|
||||
conv_param->output_channel_, is_per_oc, filter_zp);
|
||||
return;
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_NNACL_INT8_CONV1X1_INT8_H_
|
||||
#define MINDSPORE_LITE_NNACL_INT8_CONV1X1_INT8_H_
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/pack.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/common_func.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
#include "nnacl/matmul_parameter.h"
|
||||
#include "nnacl/int8/matmul_int8.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
|
||||
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
|
||||
int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp);
|
||||
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
|
||||
const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
|
||||
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int32_t *filter_zp);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_INT8_CONV1X1_INT8_H_
|
|
@ -0,0 +1,900 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/int8/conv3x3_int8.h"
|
||||
|
||||
void Conv3x3Int8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) {
|
||||
#ifdef ENABLE_ARM
|
||||
int16x8_t zp = vdupq_n_s16(input_zp);
|
||||
|
||||
int16x8_t d00 = vsubq_s16(vld1q_s16(tmp_data), zp);
|
||||
int16x8_t d01 = vsubq_s16(vld1q_s16(tmp_data + 8), zp);
|
||||
int16x8_t d02 = vsubq_s16(vld1q_s16(tmp_data + 2 * 8), zp);
|
||||
int16x8_t d03 = vsubq_s16(vld1q_s16(tmp_data + 3 * 8), zp);
|
||||
|
||||
int16x8_t d10 = vsubq_s16(vld1q_s16(tmp_data + 4 * 8), zp);
|
||||
int16x8_t d11 = vsubq_s16(vld1q_s16(tmp_data + 5 * 8), zp);
|
||||
int16x8_t d12 = vsubq_s16(vld1q_s16(tmp_data + 6 * 8), zp);
|
||||
int16x8_t d13 = vsubq_s16(vld1q_s16(tmp_data + 7 * 8), zp);
|
||||
|
||||
int16x8_t d20 = vsubq_s16(vld1q_s16(tmp_data + 8 * 8), zp);
|
||||
int16x8_t d21 = vsubq_s16(vld1q_s16(tmp_data + 9 * 8), zp);
|
||||
int16x8_t d22 = vsubq_s16(vld1q_s16(tmp_data + 10 * 8), zp);
|
||||
int16x8_t d23 = vsubq_s16(vld1q_s16(tmp_data + 11 * 8), zp);
|
||||
|
||||
int16x8_t d30 = vsubq_s16(vld1q_s16(tmp_data + 12 * 8), zp);
|
||||
int16x8_t d31 = vsubq_s16(vld1q_s16(tmp_data + 13 * 8), zp);
|
||||
int16x8_t d32 = vsubq_s16(vld1q_s16(tmp_data + 14 * 8), zp);
|
||||
int16x8_t d33 = vsubq_s16(vld1q_s16(tmp_data + 15 * 8), zp);
|
||||
|
||||
int16x8_t t00 = vsubq_s16(d00, d20);
|
||||
int16x8_t t01 = vsubq_s16(d01, d21);
|
||||
int16x8_t t02 = vsubq_s16(d02, d22);
|
||||
int16x8_t t03 = vsubq_s16(d03, d23);
|
||||
|
||||
int16x8_t t10 = vaddq_s16(d10, d20);
|
||||
int16x8_t t11 = vaddq_s16(d11, d21);
|
||||
int16x8_t t12 = vaddq_s16(d12, d22);
|
||||
int16x8_t t13 = vaddq_s16(d13, d23);
|
||||
|
||||
int16x8_t t20 = vsubq_s16(d20, d10);
|
||||
int16x8_t t21 = vsubq_s16(d21, d11);
|
||||
int16x8_t t22 = vsubq_s16(d22, d12);
|
||||
int16x8_t t23 = vsubq_s16(d23, d13);
|
||||
|
||||
int16x8_t t30 = vsubq_s16(d10, d30);
|
||||
int16x8_t t31 = vsubq_s16(d11, d31);
|
||||
int16x8_t t32 = vsubq_s16(d12, d32);
|
||||
int16x8_t t33 = vsubq_s16(d13, d33);
|
||||
|
||||
int16x8_t m00 = vsubq_s16(t00, t02);
|
||||
int16x8_t m01 = vaddq_s16(t01, t02);
|
||||
int16x8_t m02 = vsubq_s16(t02, t01);
|
||||
int16x8_t m03 = vsubq_s16(t01, t03);
|
||||
|
||||
int16x8_t m10 = vsubq_s16(t10, t12);
|
||||
int16x8_t m11 = vaddq_s16(t11, t12);
|
||||
int16x8_t m12 = vsubq_s16(t12, t11);
|
||||
int16x8_t m13 = vsubq_s16(t11, t13);
|
||||
|
||||
int16x8_t m20 = vsubq_s16(t20, t22);
|
||||
int16x8_t m21 = vaddq_s16(t21, t22);
|
||||
int16x8_t m22 = vsubq_s16(t22, t21);
|
||||
int16x8_t m23 = vsubq_s16(t21, t23);
|
||||
|
||||
int16x8_t m30 = vsubq_s16(t30, t32);
|
||||
int16x8_t m31 = vaddq_s16(t31, t32);
|
||||
int16x8_t m32 = vsubq_s16(t32, t31);
|
||||
int16x8_t m33 = vsubq_s16(t31, t33);
|
||||
|
||||
vst1q_s16(trans_input_data, m00);
|
||||
vst1q_s16(trans_input_data + step, m01);
|
||||
vst1q_s16(trans_input_data + 2 * step, m02);
|
||||
vst1q_s16(trans_input_data + 3 * step, m03);
|
||||
|
||||
vst1q_s16(trans_input_data + 4 * step, m10);
|
||||
vst1q_s16(trans_input_data + 5 * step, m11);
|
||||
vst1q_s16(trans_input_data + 6 * step, m12);
|
||||
vst1q_s16(trans_input_data + 7 * step, m13);
|
||||
|
||||
vst1q_s16(trans_input_data + 8 * step, m20);
|
||||
vst1q_s16(trans_input_data + 9 * step, m21);
|
||||
vst1q_s16(trans_input_data + 10 * step, m22);
|
||||
vst1q_s16(trans_input_data + 11 * step, m23);
|
||||
|
||||
vst1q_s16(trans_input_data + 12 * step, m30);
|
||||
vst1q_s16(trans_input_data + 13 * step, m31);
|
||||
vst1q_s16(trans_input_data + 14 * step, m32);
|
||||
vst1q_s16(trans_input_data + 15 * step, m33);
|
||||
#else
|
||||
for (int i = 0; i < C8NUM; i++) {
|
||||
int16_t *local_ptr = tmp_data + i;
|
||||
int16_t d00 = local_ptr[0] - input_zp;
|
||||
int16_t d01 = (local_ptr + C8NUM)[0] - input_zp;
|
||||
int16_t d02 = (local_ptr + 2 * C8NUM)[0] - input_zp;
|
||||
int16_t d03 = (local_ptr + 3 * C8NUM)[0] - input_zp;
|
||||
|
||||
int16_t d10 = (local_ptr + 4 * C8NUM)[0] - input_zp;
|
||||
int16_t d11 = (local_ptr + 5 * C8NUM)[0] - input_zp;
|
||||
int16_t d12 = (local_ptr + 6 * C8NUM)[0] - input_zp;
|
||||
int16_t d13 = (local_ptr + 7 * C8NUM)[0] - input_zp;
|
||||
|
||||
int16_t d20 = (local_ptr + 8 * C8NUM)[0] - input_zp;
|
||||
int16_t d21 = (local_ptr + 9 * C8NUM)[0] - input_zp;
|
||||
int16_t d22 = (local_ptr + 10 * C8NUM)[0] - input_zp;
|
||||
int16_t d23 = (local_ptr + 11 * C8NUM)[0] - input_zp;
|
||||
|
||||
int16_t d30 = (local_ptr + 12 * C8NUM)[0] - input_zp;
|
||||
int16_t d31 = (local_ptr + 13 * C8NUM)[0] - input_zp;
|
||||
int16_t d32 = (local_ptr + 14 * C8NUM)[0] - input_zp;
|
||||
int16_t d33 = (local_ptr + 15 * C8NUM)[0] - input_zp;
|
||||
|
||||
int16_t t00 = d00 - d20;
|
||||
int16_t t01 = d01 - d21;
|
||||
int16_t t02 = d02 - d22;
|
||||
int16_t t03 = d03 - d23;
|
||||
|
||||
int16_t t10 = d10 + d20;
|
||||
int16_t t11 = d11 + d21;
|
||||
int16_t t12 = d12 + d22;
|
||||
int16_t t13 = d13 + d23;
|
||||
|
||||
int16_t t20 = d20 - d10;
|
||||
int16_t t21 = d21 - d11;
|
||||
int16_t t22 = d22 - d12;
|
||||
int16_t t23 = d23 - d13;
|
||||
|
||||
int16_t t30 = d10 - d30;
|
||||
int16_t t31 = d11 - d31;
|
||||
int16_t t32 = d12 - d32;
|
||||
int16_t t33 = d13 - d33;
|
||||
|
||||
int16_t m00 = t00 - t02;
|
||||
int16_t m01 = t01 + t02;
|
||||
int16_t m02 = t02 - t01;
|
||||
int16_t m03 = t01 - t03;
|
||||
|
||||
int16_t m10 = t10 - t12;
|
||||
int16_t m11 = t11 + t12;
|
||||
int16_t m12 = t12 - t11;
|
||||
int16_t m13 = t11 - t13;
|
||||
|
||||
int16_t m20 = t20 - t22;
|
||||
int16_t m21 = t21 + t22;
|
||||
int16_t m22 = t22 - t21;
|
||||
int16_t m23 = t21 - t23;
|
||||
|
||||
int16_t m30 = t30 - t32;
|
||||
int16_t m31 = t31 + t32;
|
||||
int16_t m32 = t32 - t31;
|
||||
int16_t m33 = t31 - t33;
|
||||
|
||||
(trans_input_data + i)[0] = m00;
|
||||
(trans_input_data + i + step)[0] = m01;
|
||||
(trans_input_data + i + 2 * step)[0] = m02;
|
||||
(trans_input_data + i + 3 * step)[0] = m03;
|
||||
|
||||
(trans_input_data + i + 4 * step)[0] = m10;
|
||||
(trans_input_data + i + 5 * step)[0] = m11;
|
||||
(trans_input_data + i + 6 * step)[0] = m12;
|
||||
(trans_input_data + i + 7 * step)[0] = m13;
|
||||
|
||||
(trans_input_data + i + 8 * step)[0] = m20;
|
||||
(trans_input_data + i + 9 * step)[0] = m21;
|
||||
(trans_input_data + i + 10 * step)[0] = m22;
|
||||
(trans_input_data + i + 11 * step)[0] = m23;
|
||||
|
||||
(trans_input_data + i + 12 * step)[0] = m30;
|
||||
(trans_input_data + i + 13 * step)[0] = m31;
|
||||
(trans_input_data + i + 14 * step)[0] = m32;
|
||||
(trans_input_data + i + 15 * step)[0] = m33;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel,
|
||||
int kernel_plane) {
|
||||
const int input_unit = 4;
|
||||
int dst_step = iC8 * C8NUM * C4NUM;
|
||||
for (int o = 0; o < output_channel; o++) {
|
||||
int oc4_block_num = o / C4NUM;
|
||||
int oc4_block_rem = o % C4NUM;
|
||||
int src_oc_offset = o * iC8 * C8NUM * kernel_plane;
|
||||
int dst_oc_offset = oc4_block_num * C4NUM * iC8 * C8NUM * input_unit * input_unit + oc4_block_rem;
|
||||
for (int i = 0; i < iC8; i++) {
|
||||
const int16_t *src_ic8_ptr = weight_data + src_oc_offset + i * kernel_plane * C8NUM;
|
||||
int16_t *dst_ic8_ptr = trans_weight + dst_oc_offset + i * C4NUM * C8NUM;
|
||||
#ifdef ENABLE_ARM
|
||||
int16x8_t g00 = vld1q_s16(src_ic8_ptr);
|
||||
int16x8_t g01 = vld1q_s16(src_ic8_ptr + 8);
|
||||
int16x8_t g02 = vld1q_s16(src_ic8_ptr + 2 * 8);
|
||||
int16x8_t g10 = vld1q_s16(src_ic8_ptr + 3 * 8);
|
||||
int16x8_t g11 = vld1q_s16(src_ic8_ptr + 4 * 8);
|
||||
int16x8_t g12 = vld1q_s16(src_ic8_ptr + 5 * 8);
|
||||
int16x8_t g20 = vld1q_s16(src_ic8_ptr + 6 * 8);
|
||||
int16x8_t g21 = vld1q_s16(src_ic8_ptr + 7 * 8);
|
||||
int16x8_t g22 = vld1q_s16(src_ic8_ptr + 8 * 8);
|
||||
|
||||
int16x8_t dst00 = vmulq_n_s16(g00, 2);
|
||||
int16x8_t dst01 = vmulq_n_s16(g01, 2);
|
||||
int16x8_t dst02 = vmulq_n_s16(g02, 2);
|
||||
|
||||
int16x8_t dst10 = vaddq_s16(vaddq_s16(g00, g10), g20);
|
||||
int16x8_t dst11 = vaddq_s16(vaddq_s16(g01, g11), g21);
|
||||
int16x8_t dst12 = vaddq_s16(vaddq_s16(g02, g12), g22);
|
||||
|
||||
int16x8_t dst20 = vaddq_s16(vsubq_s16(g00, g10), g20);
|
||||
int16x8_t dst21 = vaddq_s16(vsubq_s16(g01, g11), g21);
|
||||
int16x8_t dst22 = vaddq_s16(vsubq_s16(g02, g12), g22);
|
||||
|
||||
int16x8_t dst30 = vmulq_n_s16(g20, 2);
|
||||
int16x8_t dst31 = vmulq_n_s16(g21, 2);
|
||||
int16x8_t dst32 = vmulq_n_s16(g22, 2);
|
||||
|
||||
int16x8_t m00 = vmulq_n_s16(dst00, 2);
|
||||
int16x8_t m01 = vaddq_s16(vaddq_s16(dst00, dst01), dst02);
|
||||
int16x8_t m02 = vaddq_s16(vsubq_s16(dst00, dst01), dst02);
|
||||
int16x8_t m03 = vmulq_n_s16(dst02, 2);
|
||||
|
||||
int16x8_t m10 = vmulq_n_s16(dst10, 2);
|
||||
int16x8_t m11 = vaddq_s16(vaddq_s16(dst10, dst11), dst12);
|
||||
int16x8_t m12 = vaddq_s16(vsubq_s16(dst10, dst11), dst12);
|
||||
int16x8_t m13 = vmulq_n_s16(dst12, 2);
|
||||
|
||||
int16x8_t m20 = vmulq_n_s16(dst20, 2);
|
||||
int16x8_t m21 = vaddq_s16(vaddq_s16(dst20, dst21), dst22);
|
||||
int16x8_t m22 = vaddq_s16(vsubq_s16(dst20, dst21), dst22);
|
||||
int16x8_t m23 = vmulq_n_s16(dst22, 2);
|
||||
|
||||
int16x8_t m30 = vmulq_n_s16(dst30, 2);
|
||||
int16x8_t m31 = vaddq_s16(vaddq_s16(dst30, dst31), dst32);
|
||||
int16x8_t m32 = vaddq_s16(vsubq_s16(dst30, dst31), dst32);
|
||||
int16x8_t m33 = vmulq_n_s16(dst32, 2);
|
||||
|
||||
dst_ic8_ptr[0] = m00[0];
|
||||
dst_ic8_ptr[4] = m00[1];
|
||||
dst_ic8_ptr[8] = m00[2];
|
||||
dst_ic8_ptr[12] = m00[3];
|
||||
dst_ic8_ptr[16] = m00[4];
|
||||
dst_ic8_ptr[20] = m00[5];
|
||||
dst_ic8_ptr[24] = m00[6];
|
||||
dst_ic8_ptr[28] = m00[7];
|
||||
|
||||
dst_ic8_ptr[0 + dst_step] = m01[0];
|
||||
dst_ic8_ptr[4 + dst_step] = m01[1];
|
||||
dst_ic8_ptr[8 + dst_step] = m01[2];
|
||||
dst_ic8_ptr[12 + dst_step] = m01[3];
|
||||
dst_ic8_ptr[16 + dst_step] = m01[4];
|
||||
dst_ic8_ptr[20 + dst_step] = m01[5];
|
||||
dst_ic8_ptr[24 + dst_step] = m01[6];
|
||||
dst_ic8_ptr[28 + dst_step] = m01[7];
|
||||
|
||||
dst_ic8_ptr[0 + 2 * dst_step] = m02[0];
|
||||
dst_ic8_ptr[4 + 2 * dst_step] = m02[1];
|
||||
dst_ic8_ptr[8 + 2 * dst_step] = m02[2];
|
||||
dst_ic8_ptr[12 + 2 * dst_step] = m02[3];
|
||||
dst_ic8_ptr[16 + 2 * dst_step] = m02[4];
|
||||
dst_ic8_ptr[20 + 2 * dst_step] = m02[5];
|
||||
dst_ic8_ptr[24 + 2 * dst_step] = m02[6];
|
||||
dst_ic8_ptr[28 + 2 * dst_step] = m02[7];
|
||||
|
||||
dst_ic8_ptr[0 + 3 * dst_step] = m03[0];
|
||||
dst_ic8_ptr[4 + 3 * dst_step] = m03[1];
|
||||
dst_ic8_ptr[8 + 3 * dst_step] = m03[2];
|
||||
dst_ic8_ptr[12 + 3 * dst_step] = m03[3];
|
||||
dst_ic8_ptr[16 + 3 * dst_step] = m03[4];
|
||||
dst_ic8_ptr[20 + 3 * dst_step] = m03[5];
|
||||
dst_ic8_ptr[24 + 3 * dst_step] = m03[6];
|
||||
dst_ic8_ptr[28 + 3 * dst_step] = m03[7];
|
||||
|
||||
dst_ic8_ptr[0 + 4 * dst_step] = m10[0];
|
||||
dst_ic8_ptr[4 + 4 * dst_step] = m10[1];
|
||||
dst_ic8_ptr[8 + 4 * dst_step] = m10[2];
|
||||
dst_ic8_ptr[12 + 4 * dst_step] = m10[3];
|
||||
dst_ic8_ptr[16 + 4 * dst_step] = m10[4];
|
||||
dst_ic8_ptr[20 + 4 * dst_step] = m10[5];
|
||||
dst_ic8_ptr[24 + 4 * dst_step] = m10[6];
|
||||
dst_ic8_ptr[28 + 4 * dst_step] = m10[7];
|
||||
|
||||
dst_ic8_ptr[0 + 5 * dst_step] = m11[0];
|
||||
dst_ic8_ptr[4 + 5 * dst_step] = m11[1];
|
||||
dst_ic8_ptr[8 + 5 * dst_step] = m11[2];
|
||||
dst_ic8_ptr[12 + 5 * dst_step] = m11[3];
|
||||
dst_ic8_ptr[16 + 5 * dst_step] = m11[4];
|
||||
dst_ic8_ptr[20 + 5 * dst_step] = m11[5];
|
||||
dst_ic8_ptr[24 + 5 * dst_step] = m11[6];
|
||||
dst_ic8_ptr[28 + 5 * dst_step] = m11[7];
|
||||
|
||||
dst_ic8_ptr[0 + 6 * dst_step] = m12[0];
|
||||
dst_ic8_ptr[4 + 6 * dst_step] = m12[1];
|
||||
dst_ic8_ptr[8 + 6 * dst_step] = m12[2];
|
||||
dst_ic8_ptr[12 + 6 * dst_step] = m12[3];
|
||||
dst_ic8_ptr[16 + 6 * dst_step] = m12[4];
|
||||
dst_ic8_ptr[20 + 6 * dst_step] = m12[5];
|
||||
dst_ic8_ptr[24 + 6 * dst_step] = m12[6];
|
||||
dst_ic8_ptr[28 + 6 * dst_step] = m12[7];
|
||||
|
||||
dst_ic8_ptr[0 + 7 * dst_step] = m13[0];
|
||||
dst_ic8_ptr[4 + 7 * dst_step] = m13[1];
|
||||
dst_ic8_ptr[8 + 7 * dst_step] = m13[2];
|
||||
dst_ic8_ptr[12 + 7 * dst_step] = m13[3];
|
||||
dst_ic8_ptr[16 + 7 * dst_step] = m13[4];
|
||||
dst_ic8_ptr[20 + 7 * dst_step] = m13[5];
|
||||
dst_ic8_ptr[24 + 7 * dst_step] = m13[6];
|
||||
dst_ic8_ptr[28 + 7 * dst_step] = m13[7];
|
||||
|
||||
dst_ic8_ptr[0 + 8 * dst_step] = m20[0];
|
||||
dst_ic8_ptr[4 + 8 * dst_step] = m20[1];
|
||||
dst_ic8_ptr[8 + 8 * dst_step] = m20[2];
|
||||
dst_ic8_ptr[12 + 8 * dst_step] = m20[3];
|
||||
dst_ic8_ptr[16 + 8 * dst_step] = m20[4];
|
||||
dst_ic8_ptr[20 + 8 * dst_step] = m20[5];
|
||||
dst_ic8_ptr[24 + 8 * dst_step] = m20[6];
|
||||
dst_ic8_ptr[28 + 8 * dst_step] = m20[7];
|
||||
|
||||
dst_ic8_ptr[0 + 9 * dst_step] = m21[0];
|
||||
dst_ic8_ptr[4 + 9 * dst_step] = m21[1];
|
||||
dst_ic8_ptr[8 + 9 * dst_step] = m21[2];
|
||||
dst_ic8_ptr[12 + 9 * dst_step] = m21[3];
|
||||
dst_ic8_ptr[16 + 9 * dst_step] = m21[4];
|
||||
dst_ic8_ptr[20 + 9 * dst_step] = m21[5];
|
||||
dst_ic8_ptr[24 + 9 * dst_step] = m21[6];
|
||||
dst_ic8_ptr[28 + 9 * dst_step] = m21[7];
|
||||
|
||||
dst_ic8_ptr[0 + 10 * dst_step] = m22[0];
|
||||
dst_ic8_ptr[4 + 10 * dst_step] = m22[1];
|
||||
dst_ic8_ptr[8 + 10 * dst_step] = m22[2];
|
||||
dst_ic8_ptr[12 + 10 * dst_step] = m22[3];
|
||||
dst_ic8_ptr[16 + 10 * dst_step] = m22[4];
|
||||
dst_ic8_ptr[20 + 10 * dst_step] = m22[5];
|
||||
dst_ic8_ptr[24 + 10 * dst_step] = m22[6];
|
||||
dst_ic8_ptr[28 + 10 * dst_step] = m22[7];
|
||||
|
||||
dst_ic8_ptr[0 + 11 * dst_step] = m23[0];
|
||||
dst_ic8_ptr[4 + 11 * dst_step] = m23[1];
|
||||
dst_ic8_ptr[8 + 11 * dst_step] = m23[2];
|
||||
dst_ic8_ptr[12 + 11 * dst_step] = m23[3];
|
||||
dst_ic8_ptr[16 + 11 * dst_step] = m23[4];
|
||||
dst_ic8_ptr[20 + 11 * dst_step] = m23[5];
|
||||
dst_ic8_ptr[24 + 11 * dst_step] = m23[6];
|
||||
dst_ic8_ptr[28 + 11 * dst_step] = m23[7];
|
||||
|
||||
dst_ic8_ptr[0 + 12 * dst_step] = m30[0];
|
||||
dst_ic8_ptr[4 + 12 * dst_step] = m30[1];
|
||||
dst_ic8_ptr[8 + 12 * dst_step] = m30[2];
|
||||
dst_ic8_ptr[12 + 12 * dst_step] = m30[3];
|
||||
dst_ic8_ptr[16 + 12 * dst_step] = m30[4];
|
||||
dst_ic8_ptr[20 + 12 * dst_step] = m30[5];
|
||||
dst_ic8_ptr[24 + 12 * dst_step] = m30[6];
|
||||
dst_ic8_ptr[28 + 12 * dst_step] = m30[7];
|
||||
|
||||
dst_ic8_ptr[0 + 13 * dst_step] = m31[0];
|
||||
dst_ic8_ptr[4 + 13 * dst_step] = m31[1];
|
||||
dst_ic8_ptr[8 + 13 * dst_step] = m31[2];
|
||||
dst_ic8_ptr[12 + 13 * dst_step] = m31[3];
|
||||
dst_ic8_ptr[16 + 13 * dst_step] = m31[4];
|
||||
dst_ic8_ptr[20 + 13 * dst_step] = m31[5];
|
||||
dst_ic8_ptr[24 + 13 * dst_step] = m31[6];
|
||||
dst_ic8_ptr[28 + 13 * dst_step] = m31[7];
|
||||
|
||||
dst_ic8_ptr[0 + 14 * dst_step] = m32[0];
|
||||
dst_ic8_ptr[4 + 14 * dst_step] = m32[1];
|
||||
dst_ic8_ptr[8 + 14 * dst_step] = m32[2];
|
||||
dst_ic8_ptr[12 + 14 * dst_step] = m32[3];
|
||||
dst_ic8_ptr[16 + 14 * dst_step] = m32[4];
|
||||
dst_ic8_ptr[20 + 14 * dst_step] = m32[5];
|
||||
dst_ic8_ptr[24 + 14 * dst_step] = m32[6];
|
||||
dst_ic8_ptr[28 + 14 * dst_step] = m32[7];
|
||||
|
||||
dst_ic8_ptr[0 + 15 * dst_step] = m33[0];
|
||||
dst_ic8_ptr[4 + 15 * dst_step] = m33[1];
|
||||
dst_ic8_ptr[8 + 15 * dst_step] = m33[2];
|
||||
dst_ic8_ptr[12 + 15 * dst_step] = m33[3];
|
||||
dst_ic8_ptr[16 + 15 * dst_step] = m33[4];
|
||||
dst_ic8_ptr[20 + 15 * dst_step] = m33[5];
|
||||
dst_ic8_ptr[24 + 15 * dst_step] = m33[6];
|
||||
dst_ic8_ptr[28 + 15 * dst_step] = m33[7];
|
||||
#else
|
||||
for (int j = 0; j < C8NUM; j++) {
|
||||
const int16_t *local_ptr = src_ic8_ptr + j;
|
||||
int16_t dst00 = local_ptr[0] * 2;
|
||||
int16_t dst01 = (local_ptr + 8)[0] * 2;
|
||||
int16_t dst02 = (local_ptr + 16)[0] * 2;
|
||||
|
||||
int16_t dst10 = local_ptr[0] + (local_ptr + 24)[0] + (local_ptr + 48)[0];
|
||||
int16_t dst11 = (local_ptr + 8)[0] + (local_ptr + 32)[0] + (local_ptr + 56)[0];
|
||||
int16_t dst12 = (local_ptr + 16)[0] + (local_ptr + 40)[0] + (local_ptr + 64)[0];
|
||||
|
||||
int16_t dst20 = local_ptr[0] - (local_ptr + 24)[0] + (local_ptr + 48)[0];
|
||||
int16_t dst21 = (local_ptr + 8)[0] - (local_ptr + 32)[0] + (local_ptr + 56)[0];
|
||||
int16_t dst22 = (local_ptr + 16)[0] - (local_ptr + 40)[0] + (local_ptr + 64)[0];
|
||||
|
||||
int16_t dst30 = (local_ptr + 48)[0] * 2;
|
||||
int16_t dst31 = (local_ptr + 56)[0] * 2;
|
||||
int16_t dst32 = (local_ptr + 64)[0] * 2;
|
||||
|
||||
int16_t m00 = dst00 * 2;
|
||||
int16_t m01 = dst00 + dst01 + dst02;
|
||||
int16_t m02 = dst00 - dst01 + dst02;
|
||||
int16_t m03 = dst02 * 2;
|
||||
|
||||
int16_t m10 = dst10 * 2;
|
||||
int16_t m11 = dst10 + dst11 + dst12;
|
||||
int16_t m12 = dst10 - dst11 + dst12;
|
||||
int16_t m13 = dst12 * 2;
|
||||
|
||||
int16_t m20 = dst20 * 2;
|
||||
int16_t m21 = dst20 + dst21 + dst22;
|
||||
int16_t m22 = dst20 - dst21 + dst22;
|
||||
int16_t m23 = dst22 * 2;
|
||||
|
||||
int16_t m30 = dst30 * 2;
|
||||
int16_t m31 = dst30 + dst31 + dst32;
|
||||
int16_t m32 = dst30 - dst31 + dst32;
|
||||
int16_t m33 = dst32 * 2;
|
||||
|
||||
*(dst_ic8_ptr + j * 4) = m00;
|
||||
*(dst_ic8_ptr + j * 4 + dst_step) = m01;
|
||||
*(dst_ic8_ptr + j * 4 + 2 * dst_step) = m02;
|
||||
*(dst_ic8_ptr + j * 4 + 3 * dst_step) = m03;
|
||||
|
||||
*(dst_ic8_ptr + j * 4 + 4 * dst_step) = m10;
|
||||
*(dst_ic8_ptr + j * 4 + 5 * dst_step) = m11;
|
||||
*(dst_ic8_ptr + j * 4 + 6 * dst_step) = m12;
|
||||
*(dst_ic8_ptr + j * 4 + 7 * dst_step) = m13;
|
||||
|
||||
*(dst_ic8_ptr + j * 4 + 8 * dst_step) = m20;
|
||||
*(dst_ic8_ptr + j * 4 + 9 * dst_step) = m21;
|
||||
*(dst_ic8_ptr + j * 4 + 10 * dst_step) = m22;
|
||||
*(dst_ic8_ptr + j * 4 + 11 * dst_step) = m23;
|
||||
|
||||
*(dst_ic8_ptr + j * 4 + 12 * dst_step) = m30;
|
||||
*(dst_ic8_ptr + j * 4 + 13 * dst_step) = m31;
|
||||
*(dst_ic8_ptr + j * 4 + 14 * dst_step) = m32;
|
||||
*(dst_ic8_ptr + j * 4 + 15 * dst_step) = m33;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound,
|
||||
bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) {
|
||||
int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_;
|
||||
int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_;
|
||||
int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
|
||||
int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
|
||||
int out_min = conv_param->conv_quant_arg_.out_act_min_[0];
|
||||
int out_max = conv_param->conv_quant_arg_.out_act_max_[0];
|
||||
|
||||
#ifdef ENABLE_ARM
|
||||
int32x4_t bias_ptr = vld1q_s32(bias_data);
|
||||
|
||||
int32x4_t s00 = vld1q_s32(gemm_out);
|
||||
int32x4_t s01 = vld1q_s32(gemm_out + 4);
|
||||
int32x4_t s02 = vld1q_s32(gemm_out + 8);
|
||||
int32x4_t s03 = vld1q_s32(gemm_out + 12);
|
||||
|
||||
int32x4_t s10 = vld1q_s32(gemm_out + 16);
|
||||
int32x4_t s11 = vld1q_s32(gemm_out + 20);
|
||||
int32x4_t s12 = vld1q_s32(gemm_out + 24);
|
||||
int32x4_t s13 = vld1q_s32(gemm_out + 28);
|
||||
|
||||
int32x4_t s20 = vld1q_s32(gemm_out + 32);
|
||||
int32x4_t s21 = vld1q_s32(gemm_out + 36);
|
||||
int32x4_t s22 = vld1q_s32(gemm_out + 40);
|
||||
int32x4_t s23 = vld1q_s32(gemm_out + 44);
|
||||
|
||||
int32x4_t s30 = vld1q_s32(gemm_out + 48);
|
||||
int32x4_t s31 = vld1q_s32(gemm_out + 52);
|
||||
int32x4_t s32 = vld1q_s32(gemm_out + 56);
|
||||
int32x4_t s33 = vld1q_s32(gemm_out + 60);
|
||||
|
||||
int32x4_t t00 = vshrq_n_s32(vaddq_s32(vaddq_s32(s00, s10), s20), 1);
|
||||
int32x4_t t01 = vshrq_n_s32(vaddq_s32(vaddq_s32(s01, s11), s21), 1);
|
||||
int32x4_t t02 = vshrq_n_s32(vaddq_s32(vaddq_s32(s02, s12), s22), 1);
|
||||
int32x4_t t03 = vshrq_n_s32(vaddq_s32(vaddq_s32(s03, s13), s23), 1);
|
||||
|
||||
int32x4_t t10 = vshrq_n_s32(vsubq_s32(vsubq_s32(s10, s20), s30), 1);
|
||||
int32x4_t t11 = vshrq_n_s32(vsubq_s32(vsubq_s32(s11, s21), s31), 1);
|
||||
int32x4_t t12 = vshrq_n_s32(vsubq_s32(vsubq_s32(s12, s22), s32), 1);
|
||||
int32x4_t t13 = vshrq_n_s32(vsubq_s32(vsubq_s32(s13, s23), s33), 1);
|
||||
|
||||
int32x4_t d00 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t00, t01), t02), 1), bias_ptr);
|
||||
int32x4_t d01 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t01, t02), t03), 1), bias_ptr);
|
||||
|
||||
int32x4_t d10 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t10, t11), t12), 1), bias_ptr);
|
||||
int32x4_t d11 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t11, t12), t13), 1), bias_ptr);
|
||||
|
||||
int32x4_t out_multiplier;
|
||||
int32x4_t ls;
|
||||
int32x4_t rs;
|
||||
if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
|
||||
out_multiplier = vld1q_s32(quant_multiplier + oc_start);
|
||||
ls = vld1q_s32(left_shift + oc_start);
|
||||
rs = vld1q_s32(right_shift + oc_start);
|
||||
} else {
|
||||
out_multiplier = vdupq_n_s32(quant_multiplier[0]);
|
||||
ls = vdupq_n_s32(left_shift[0]);
|
||||
rs = vdupq_n_s32(right_shift[0]);
|
||||
}
|
||||
int32x4_t out_zp = vdupq_n_s32(output_zp);
|
||||
int32x4_t output_min = vdupq_n_s32(out_min);
|
||||
int32x4_t output_max = vdupq_n_s32(out_max);
|
||||
|
||||
d00 = vqshlq_s32(d00, ls);
|
||||
d00 = vqrdmulhq_s32(d00, out_multiplier);
|
||||
int32x4_t carry = vandq_s32(d00, rs);
|
||||
carry = vshrq_n_s32(carry, 31);
|
||||
d00 = vqaddq_s32(d00, carry);
|
||||
d00 = vqrshlq_s32(d00, rs);
|
||||
d00 = vaddq_s32(d00, out_zp);
|
||||
d00 = vmaxq_s32(d00, output_min);
|
||||
d00 = vminq_s32(d00, output_max);
|
||||
|
||||
d01 = vqshlq_s32(d01, ls);
|
||||
d01 = vqrdmulhq_s32(d01, out_multiplier);
|
||||
carry = vandq_s32(d01, rs);
|
||||
carry = vshrq_n_s32(carry, 31);
|
||||
d01 = vqaddq_s32(d01, carry);
|
||||
d01 = vqrshlq_s32(d01, rs);
|
||||
d01 = vaddq_s32(d01, out_zp);
|
||||
d01 = vmaxq_s32(d01, output_min);
|
||||
d01 = vminq_s32(d01, output_max);
|
||||
|
||||
d10 = vqshlq_s32(d10, ls);
|
||||
d10 = vqrdmulhq_s32(d10, out_multiplier);
|
||||
carry = vandq_s32(d10, rs);
|
||||
carry = vshrq_n_s32(carry, 31);
|
||||
d10 = vqaddq_s32(d10, carry);
|
||||
d10 = vqrshlq_s32(d10, rs);
|
||||
d10 = vaddq_s32(d10, out_zp);
|
||||
d10 = vmaxq_s32(d10, output_min);
|
||||
d10 = vminq_s32(d10, output_max);
|
||||
|
||||
d11 = vqshlq_s32(d11, ls);
|
||||
d11 = vqrdmulhq_s32(d11, out_multiplier);
|
||||
carry = vandq_s32(d11, rs);
|
||||
carry = vshrq_n_s32(carry, 31);
|
||||
d11 = vqaddq_s32(d11, carry);
|
||||
d11 = vqrshlq_s32(d11, rs);
|
||||
d11 = vaddq_s32(d11, out_zp);
|
||||
d11 = vmaxq_s32(d11, output_min);
|
||||
d11 = vminq_s32(d11, output_max);
|
||||
|
||||
(output_data)[0] = (int8_t)d00[0];
|
||||
(output_data + 1)[0] = (int8_t)d00[1];
|
||||
(output_data + 2)[0] = (int8_t)d00[2];
|
||||
(output_data + 3)[0] = (int8_t)d00[3];
|
||||
|
||||
if (w_not_bound) {
|
||||
*(output_data + 4) = (int8_t)d01[0];
|
||||
*(output_data + 5) = (int8_t)d01[1];
|
||||
*(output_data + 6) = (int8_t)d01[2];
|
||||
*(output_data + 7) = (int8_t)d01[3];
|
||||
}
|
||||
if (h_not_bound) {
|
||||
*(output_data + output_w * 4) = (int8_t)d10[0];
|
||||
*(output_data + output_w * 4 + 1) = (int8_t)d10[1];
|
||||
*(output_data + output_w * 4 + 2) = (int8_t)d10[2];
|
||||
*(output_data + output_w * 4 + 3) = (int8_t)d10[3];
|
||||
if (w_not_bound) {
|
||||
*(output_data + output_w * 4 + 4) = (int8_t)d11[0];
|
||||
*(output_data + output_w * 4 + 5) = (int8_t)d11[1];
|
||||
*(output_data + output_w * 4 + 6) = (int8_t)d11[2];
|
||||
*(output_data + output_w * 4 + 7) = (int8_t)d11[3];
|
||||
}
|
||||
}
|
||||
#else
|
||||
if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
|
||||
for (int i = 0; i < C4NUM; i++) {
|
||||
const int32_t *local_ptr = gemm_out + i;
|
||||
const int32_t *bias_ptr = bias_data + i;
|
||||
|
||||
int32_t s00 = local_ptr[0];
|
||||
int32_t s01 = (local_ptr + 4)[0];
|
||||
int32_t s02 = (local_ptr + 8)[0];
|
||||
int32_t s03 = (local_ptr + 12)[0];
|
||||
|
||||
int32_t s10 = (local_ptr + 16)[0];
|
||||
int32_t s11 = (local_ptr + 20)[0];
|
||||
int32_t s12 = (local_ptr + 24)[0];
|
||||
int32_t s13 = (local_ptr + 28)[0];
|
||||
|
||||
int32_t s20 = (local_ptr + 32)[0];
|
||||
int32_t s21 = (local_ptr + 36)[0];
|
||||
int32_t s22 = (local_ptr + 40)[0];
|
||||
int32_t s23 = (local_ptr + 44)[0];
|
||||
|
||||
int32_t s30 = (local_ptr + 48)[0];
|
||||
int32_t s31 = (local_ptr + 52)[0];
|
||||
int32_t s32 = (local_ptr + 56)[0];
|
||||
int32_t s33 = (local_ptr + 60)[0];
|
||||
|
||||
int32_t t00 = (s00 + s10 + s20) / 2;
|
||||
int32_t t01 = (s01 + s11 + s21) / 2;
|
||||
int32_t t02 = (s02 + s12 + s22) / 2;
|
||||
int32_t t03 = (s03 + s13 + s23) / 2;
|
||||
|
||||
int32_t t10 = (s10 - s20 - s30) / 2;
|
||||
int32_t t11 = (s11 - s21 - s31) / 2;
|
||||
int32_t t12 = (s12 - s22 - s32) / 2;
|
||||
int32_t t13 = (s13 - s23 - s33) / 2;
|
||||
|
||||
int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0];
|
||||
int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0];
|
||||
|
||||
int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0];
|
||||
int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0];
|
||||
|
||||
int oc_index = oc_start + i;
|
||||
d00 = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
|
||||
-right_shift[oc_index]);
|
||||
d00 += output_zp;
|
||||
d00 = d00 > out_min ? d00 : out_min;
|
||||
d00 = d00 < out_max ? d00 : out_max;
|
||||
|
||||
d01 = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
|
||||
-right_shift[oc_index]);
|
||||
d01 += output_zp;
|
||||
d01 = d01 > out_min ? d01 : out_min;
|
||||
d01 = d01 < out_max ? d01 : out_max;
|
||||
|
||||
d10 = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
|
||||
-right_shift[oc_index]);
|
||||
d10 += output_zp;
|
||||
d10 = d10 > out_min ? d10 : out_min;
|
||||
d10 = d10 < out_max ? d10 : out_max;
|
||||
|
||||
d11 = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
|
||||
-right_shift[oc_index]);
|
||||
d11 += output_zp;
|
||||
d11 = d11 > out_min ? d11 : out_min;
|
||||
d11 = d11 < out_max ? d11 : out_max;
|
||||
|
||||
(output_data + i)[0] = (int8_t)d00;
|
||||
if (w_not_bound) {
|
||||
(output_data + i + C4NUM)[0] = (int8_t)d01;
|
||||
}
|
||||
if (h_not_bound) {
|
||||
(output_data + i + output_w * C4NUM)[0] = (int8_t)d10;
|
||||
if (w_not_bound) {
|
||||
(output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < C4NUM; i++) {
|
||||
const int32_t *local_ptr = gemm_out + i;
|
||||
const int32_t *bias_ptr = bias_data + i;
|
||||
|
||||
int32_t s00 = local_ptr[0];
|
||||
int32_t s01 = (local_ptr + 4)[0];
|
||||
int32_t s02 = (local_ptr + 8)[0];
|
||||
int32_t s03 = (local_ptr + 12)[0];
|
||||
|
||||
int32_t s10 = (local_ptr + 16)[0];
|
||||
int32_t s11 = (local_ptr + 20)[0];
|
||||
int32_t s12 = (local_ptr + 24)[0];
|
||||
int32_t s13 = (local_ptr + 28)[0];
|
||||
|
||||
int32_t s20 = (local_ptr + 32)[0];
|
||||
int32_t s21 = (local_ptr + 36)[0];
|
||||
int32_t s22 = (local_ptr + 40)[0];
|
||||
int32_t s23 = (local_ptr + 44)[0];
|
||||
|
||||
int32_t s30 = (local_ptr + 48)[0];
|
||||
int32_t s31 = (local_ptr + 52)[0];
|
||||
int32_t s32 = (local_ptr + 56)[0];
|
||||
int32_t s33 = (local_ptr + 60)[0];
|
||||
|
||||
int32_t t00 = (s00 + s10 + s20) / 2;
|
||||
int32_t t01 = (s01 + s11 + s21) / 2;
|
||||
int32_t t02 = (s02 + s12 + s22) / 2;
|
||||
int32_t t03 = (s03 + s13 + s23) / 2;
|
||||
|
||||
int32_t t10 = (s10 - s20 - s30) / 2;
|
||||
int32_t t11 = (s11 - s21 - s31) / 2;
|
||||
int32_t t12 = (s12 - s22 - s32) / 2;
|
||||
int32_t t13 = (s13 - s23 - s33) / 2;
|
||||
|
||||
int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0];
|
||||
int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0];
|
||||
|
||||
int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0];
|
||||
int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0];
|
||||
|
||||
d00 = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
|
||||
-right_shift[0]);
|
||||
d00 += output_zp;
|
||||
d00 = d00 > out_min ? d00 : out_min;
|
||||
d00 = d00 < out_max ? d00 : out_max;
|
||||
|
||||
d01 = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
|
||||
-right_shift[0]);
|
||||
d01 += output_zp;
|
||||
d01 = d01 > out_min ? d01 : out_min;
|
||||
d01 = d01 < out_max ? d01 : out_max;
|
||||
|
||||
d10 = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
|
||||
-right_shift[0]);
|
||||
d10 += output_zp;
|
||||
d10 = d10 > out_min ? d10 : out_min;
|
||||
d10 = d10 < out_max ? d10 : out_max;
|
||||
|
||||
d11 = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
|
||||
-right_shift[0]);
|
||||
d11 += output_zp;
|
||||
d11 = d11 > out_min ? d11 : out_min;
|
||||
d11 = d11 < out_max ? d11 : out_max;
|
||||
|
||||
(output_data + i)[0] = (int8_t)d00;
|
||||
if (w_not_bound) {
|
||||
(output_data + i + C4NUM)[0] = (int8_t)d01;
|
||||
}
|
||||
if (h_not_bound) {
|
||||
(output_data + i + output_w * C4NUM)[0] = (int8_t)d10;
|
||||
if (w_not_bound) {
|
||||
(output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Scatters the gemm results of `real_cal_num` tiles back into the NC4HW4 output
// buffer: for each tile it locates the 2x2 output unit (OUPUT_UNIT) position,
// then per 4-channel block hands post-processing (requant + store) off to
// Conv3x3Int8OutputUnit.  w/h "not bound" flags tell the unit whether the
// second output column/row is inside the feature map.
void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index,
                                int real_cal_num, int out_w_block, ConvParameter *conv_param) {
  const int output_channel = conv_param->output_channel_;
  const int output_w = conv_param->output_w_;
  const int output_h = conv_param->output_h_;
  const int oc4 = UP_DIV(output_channel, C4NUM);
  const int input_unit = 4;
  if (out_w_block == 0) {
    // degenerate output width: nothing to transform (also guards the % / below)
    return;
  }
  for (int tile = 0; tile < real_cal_num; tile++) {
    const int pos = start_index + tile;
    const int w_idx = pos % out_w_block;
    const int h_idx = pos / out_w_block;
    const int32_t *tile_src = gemm_out + tile * oc4 * C4NUM * input_unit * input_unit;
    int8_t *tile_dst = out_data + C4NUM * (w_idx * OUPUT_UNIT + h_idx * OUPUT_UNIT * output_w);
    // does the 2nd column / 2nd row of this 2x2 unit still fall inside the map?
    const bool w_not_bound = w_idx * OUPUT_UNIT + 1 < output_w;
    const bool h_not_bound = h_idx * OUPUT_UNIT + 1 < output_h;

    for (int j = 0; j < oc4; j++) {
      const int32_t *src_ptr = tile_src + j * input_unit * input_unit * C4NUM;
      int8_t *dst_ptr = tile_dst + j * C4NUM * output_h * output_w;
      const int32_t *bias_ptr = bias_data + j * C4NUM;

      // number of valid channels in the (possibly partial) last C4 block
      const int remain = output_channel - j * C4NUM;
      const int real_num = remain < C4NUM ? remain : C4NUM;
      Conv3x3Int8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, j * C4NUM,
                            conv_param);
    }
  }
}
|
||||
|
||||
// Gathers 4x4 input patches (with zero-point padding at the borders) for
// `real_cal_num` output tiles and runs the per-unit input transform.
// `input_data` is expected in C8-interleaved layout (stride C8NUM per pixel,
// plane stride C8NUM*H*W per 8-channel group) — see src offset math below.
// `tmp_data` is a per-call scratch patch of input_unit*input_unit*C8NUM int16s.
void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index,
                               int real_cal_num, int out_w_block, ConvParameter *conv_param) {
  // input data format : nhwc
  int input_channel = conv_param->input_channel_;
  int input_width = conv_param->input_w_;
  int input_height = conv_param->input_h_;
  int pad_w = conv_param->pad_l_;
  int pad_h = conv_param->pad_u_;
  ConvQuantArg quant_arg = conv_param->conv_quant_arg_;
  int input_zp = quant_arg.input_quant_args_[0].zp_;  // padding value = input zero point
  const int ic8 = UP_DIV(input_channel, C8NUM);
  const int input_unit = 4;
  if (out_w_block == 0) {
    // no output columns: avoid % and / by zero below
    return;
  }
  for (int cal_id = 0; cal_id < real_cal_num; cal_id++) {
    int x_id = start_index + cal_id;
    // top-left corner of this tile's 4x4 receptive patch, may be negative (padding)
    int origin_x = (x_id % out_w_block) * OUPUT_UNIT - pad_w;
    int origin_y = (x_id / out_w_block) * OUPUT_UNIT - pad_h;
    // clamp the patch to the real image: [real_x_start, real_x_end) within the 4x4 unit
    int real_x_start = origin_x > 0 ? 0 : -origin_x;
    int real_x_end = (origin_x + input_unit) < input_width ? input_unit : (input_width - origin_x);
    int real_y_start = origin_y > 0 ? 0 : -origin_y;
    int real_y_end = (origin_y + input_unit) < input_height ? input_unit : (input_height - origin_y);

    int src_plane_offset = C8NUM * (origin_y * input_width + origin_x);
    int dst_plane_offset = cal_id * C8NUM;
    for (int ic = 0; ic < ic8; ic++) {
      // copy data from origin input to tmp buffer
      // pre-fill with the zero point so out-of-image rows/cols act as padding
      // NOTE(review): fill count is input_unit*input_unit*TILE_NUM, presumably
      // equal to the 16*C8NUM scratch size — confirm TILE_NUM == C8NUM here.
      for (int i = 0; i < input_unit * input_unit * TILE_NUM; i++) tmp_data[i] = input_zp;

      int src_c8_offset = src_plane_offset + ic * C8NUM * input_height * input_width;
      for (int j = real_y_start; j < real_y_end; j++) {
        // one valid row of the patch: (real_x_end - real_x_start) pixels, C8 channels each
        const int16_t *src = input_data + src_c8_offset + C8NUM * (j * input_width + real_x_start);
        int16_t *dst = tmp_data + C8NUM * (C4NUM * j + real_x_start);  // row stride = 4 pixels
        memcpy(dst, src, (real_x_end - real_x_start) * C8NUM * sizeof(int16_t));
      }
      // input transform
      int dst_ic8_offset = dst_plane_offset + ic * TILE_NUM * C8NUM;
      size_t dst_step = ic8 * C8NUM * TILE_NUM;  // stride between the 16 unit positions in trans_input
      int16_t *trans_input_ptr = trans_input + dst_ic8_offset;
      Conv3x3Int8InputUnit(tmp_data, trans_input_ptr, dst_step, input_zp);
    }
  }
}
|
||||
|
||||
// Batched int16 x int16 -> int32 multiply for the transform domain of the 3x3
// int8 conv: for each of the 16 (4x4) unit positions, multiplies ic8*C8NUM
// input channels against C4NUM output channels per block and accumulates.
// On ARM the whole thing is one assembly call; the C fallback spells out the
// packed-layout offset arithmetic explicitly.
void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) {
  int oc4 = UP_DIV(oc, C4NUM);
#ifdef ENABLE_ARM
  // 16 = number of unit positions; last arg = per-tile dst stride in bytes
  IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, oc4 * 4 * 16 * sizeof(int32_t));
#else
  const int input_unit_square = 16;  // 4x4 transformed-unit positions
  for (int c = 0; c < oc4; c++) {    // per 4-output-channel block
    int filter_oc_offset = c * input_unit_square * ic8 * C8NUM * C4NUM;
    int dst_oc_offset = c * input_unit_square * C4NUM;
    // NOTE(review): n (int) is compared against real_cal_num (size_t) — fine for
    // realistic tile counts, but a signed/unsigned mix.
    for (int n = 0; n < real_cal_num; n++) {  // per input tile
      int src_tile_offset = n * C8NUM;
      int dst_tile_offset = dst_oc_offset + n * oc4 * C4NUM * input_unit_square;
      for (int i = 0; i < 4; i++) {  // unit row
        int filter_h_offset = filter_oc_offset + i * 4 * ic8 * C8NUM * C4NUM;
        int src_h_offset = src_tile_offset + i * C8NUM * ic8 * C8NUM * C4NUM;
        int dst_h_offset = dst_tile_offset + i * 4 * 4;
        for (int m = 0; m < 4; m++) {  // unit column
          int filter_w_offset = filter_h_offset + m * 4 * C8NUM * ic8;
          int src_w_offset = src_h_offset + m * 8 * ic8 * C8NUM;
          int dst_w_offset = dst_h_offset + m * C4NUM;

          int32_t acc[4] = {0};  // one accumulator per output channel of this C4 block
          for (int z = 0; z < 4; z++) {
            int filter_offset = filter_w_offset + z;
            for (int j = 0; j < ic8; j++) {  // 8-channel input groups
              int filter_c8_offset = filter_offset + j * 4 * 8;
              int src_c8_offset = src_w_offset + j * 8 * 8;

              for (int k = 0; k < 8; k++) {
                // weight laid out channel-minor with stride C4NUM per input channel
                const int16_t *w_ptr = weight + filter_c8_offset + k * 4;
                const int16_t *input_ptr = src + src_c8_offset + k;
                acc[z] += w_ptr[0] * input_ptr[0];
              }
            }
            (dst + dst_w_offset + z)[0] = acc[z];
          }
        }
      }
    }
  }
#endif
}
|
||||
|
||||
// int8 convolution 3x3
|
||||
// Entry point of the int8 3x3 convolution (Winograd-style 4x4 unit pipeline):
// per batch, each thread grabs every thread_num_-th tile group and runs
// input transform -> gemm -> output transform into the tmp_out staging buffer.
// NOTE(review): output_data is not consumed here — results land in tmp_out;
// presumably unpacked by the caller.
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
                 int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
                 int task_id, ConvParameter *conv_param) {
  const int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);
  const int oc4 = UP_DIV(conv_param->output_channel_, C4NUM);
  const int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
  const int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
  const int output_count = out_w_block * out_h_block;
  const int output_tile_count = UP_DIV(output_count, TILE_NUM);

  // Carve this thread's private scratch regions out of the shared buffers.
  int16_t *thread_tile_buf = tile_buffer + task_id * (TILE_NUM * 16 * ic8 * C8NUM);
  int16_t *thread_unit_buf = block_unit_buffer + task_id * (16 * C8NUM);
  int32_t *thread_gemm_buf = tmp_dst_buffer + task_id * (TILE_NUM * 16 * oc4 * C4NUM);

  for (int batch = 0; batch < conv_param->input_batch_; batch++) {
    const int16_t *batch_in = input_data + batch * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_;
    int8_t *batch_out = tmp_out + batch * oc4 * C4NUM * conv_param->output_w_ * conv_param->output_h_;

    // Tiles are distributed round-robin across threads.
    for (int tile = task_id; tile < output_tile_count; tile += conv_param->thread_num_) {
      const int start_index = tile * TILE_NUM;
      const int rest = output_count - start_index;
      const int real_cal_num = rest < TILE_NUM ? rest : TILE_NUM;

      Conv3x3Int8InputTransform(batch_in, thread_tile_buf, thread_unit_buf, start_index, real_cal_num, out_w_block,
                                conv_param);

      Conv3x3Int8Gemm(thread_gemm_buf, thread_tile_buf, transed_weight, conv_param->output_channel_, ic8,
                      real_cal_num);

      Conv3x3Int8OutputTransform(thread_gemm_buf, batch_out, bias_data, start_index, real_cal_num, out_w_block,
                                 conv_param);
    }
  }
}
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_
|
||||
#define MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_
|
||||
|
||||
#include <string.h>
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/pack.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/common_func.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/winograd_utils.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
#include "nnacl/matmul_parameter.h"
|
||||
#include "nnacl/int8/matmul_int8.h"
|
||||
#include "nnacl/winograd_transform.h"
|
||||
#include "nnacl/int8/common_func_int8.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/* Repacks 3x3 int8 filter weights into the 4x4 transform-domain block layout
 * consumed by Conv3x3Int8. iC8 = number of 8-channel input groups.
 * (Exact packing order is defined in the implementation — see the .c file.) */
void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel,
                                int kernel_plane);

/* Int8 3x3 convolution via 4x4-unit transform: per batch, the calling thread
 * (task_id) processes its share of output tiles using the provided per-thread
 * scratch buffers; results are staged in tmp_out. */
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
                 int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
                 int task_id, ConvParameter *conv_param);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
#include "nnacl/int8/conv_depthwise_int8.h"
|
||||
#include <string.h>
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
#include "nnacl/int8/common_func_int8.h"
|
||||
|
||||
/*conv depthwise int8 begin*/
|
||||
|
|
|
@ -15,52 +15,6 @@
|
|||
*/
|
||||
|
||||
#include "nnacl/int8/conv_int8.h"
|
||||
#include <string.h>
|
||||
#include "nnacl/winograd_transform.h"
|
||||
#include "nnacl/int8/common_func_int8.h"
|
||||
|
||||
// Batched int16 x int16 -> int32 multiply in the transform domain of the 3x3
// int8 conv: for each of the 16 (4x4) unit positions, multiplies ic8*C8NUM
// input channels against C4NUM output channels per block and accumulates.
// ARM builds delegate to one assembly call; the C fallback carries the packed
// layout's offset arithmetic inline.
void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) {
  int oc4 = UP_DIV(oc, C4NUM);
#ifdef ENABLE_ARM
  // 16 = unit positions per tile; last arg = per-tile dst stride in bytes
  IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, oc4 * 4 * 16 * sizeof(int32_t));
#else
  const int input_unit_square = 16;  // 4x4 transformed-unit positions
  for (int c = 0; c < oc4; c++) {    // per 4-output-channel block
    int filter_oc_offset = c * input_unit_square * ic8 * C8NUM * C4NUM;
    int dst_oc_offset = c * input_unit_square * C4NUM;
    for (int n = 0; n < real_cal_num; n++) {  // per input tile (int vs size_t compare — benign for sane sizes)
      int src_tile_offset = n * C8NUM;
      int dst_tile_offset = dst_oc_offset + n * oc4 * C4NUM * input_unit_square;
      for (int i = 0; i < 4; i++) {  // unit row
        int filter_h_offset = filter_oc_offset + i * 4 * ic8 * C8NUM * C4NUM;
        int src_h_offset = src_tile_offset + i * C8NUM * ic8 * C8NUM * C4NUM;
        int dst_h_offset = dst_tile_offset + i * 4 * 4;
        for (int m = 0; m < 4; m++) {  // unit column
          int filter_w_offset = filter_h_offset + m * 4 * C8NUM * ic8;
          int src_w_offset = src_h_offset + m * 8 * ic8 * C8NUM;
          int dst_w_offset = dst_h_offset + m * C4NUM;

          int32_t acc[4] = {0};  // one accumulator per output channel in this C4 block
          for (int z = 0; z < 4; z++) {
            int filter_offset = filter_w_offset + z;
            for (int j = 0; j < ic8; j++) {  // 8-channel input groups
              int filter_c8_offset = filter_offset + j * 4 * 8;
              int src_c8_offset = src_w_offset + j * 8 * 8;

              for (int k = 0; k < 8; k++) {
                // weights are channel-minor with stride C4NUM per input channel
                const int16_t *w_ptr = weight + filter_c8_offset + k * 4;
                const int16_t *input_ptr = src + src_c8_offset + k;
                acc[z] += w_ptr[0] * input_ptr[0];
              }
            }
            (dst + dst_w_offset + z)[0] = acc[z];
          }
        }
      }
    }
  }
#endif
}
|
||||
|
||||
void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight,
|
||||
const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id,
|
||||
|
@ -141,717 +95,3 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, in
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pre-processing for the per-output-channel (per-oc) quantized 1x1 conv:
// packs the input into C8-row / C4-channel blocks and, at the same time,
// precomputes input_sum[hw][oc] = (sum of input channels at pixel hw) *
// filter_zp[oc], which the matmul later subtracts for asymmetric weights.
// Main body handles full groups of C8NUM pixels (ARM64 assembly fast path or
// the portable C fallback); the tails handle leftover pixels and zero-pad the
// hw8/oc8 round-up region.
void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
                        size_t output_channel, size_t plane_size, int32_t *filter_zp, size_t inputsum_stride) {
  int ic4 = UP_ROUND(input_channel, C4NUM);
  int oc8 = UP_ROUND(output_channel, C8NUM);
  int hw8 = UP_ROUND(plane_size, C8NUM);
  size_t hw_8div = plane_size / C8NUM * C8NUM;  // pixels handled by the 8-wide main loop
  size_t oc_8div = output_channel / C8NUM * C8NUM;
  size_t oc_8res = output_channel - oc_8div;
  size_t ic_4div = input_channel / C4NUM * C4NUM;

  const int8_t *src_r = src_input;
  int8_t *pack_r = packed_input;
  int32_t *input_sum_r = input_sum;

  /* main loop: one iteration packs C8NUM pixels and their channel sums */
  // NOTE(review): hwi is int while hw_8div is size_t — fine for realistic plane
  // sizes but a signed/unsigned comparison.
  for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
    const int8_t *src_ic = src_r;
    int8_t *pack_ic = pack_r;
    int32_t *input_sum_oc = input_sum_r;
#ifdef ENABLE_ARM64
    size_t src_stride = input_channel;
    size_t ic_4res = input_channel - ic_4div;
    // byte stride between successive oc8 groups in input_sum, minus the
    // C8NUM*C8NUM*4 bytes the stores inside the loop already advanced
    size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4;
    asm volatile(
      // v16/v17 accumulate per-pixel channel sums (8 pixels total)
      "dup v16.4s, wzr \n"
      "dup v17.4s, wzr \n"

      "mov x10, %[src_ic] \n"
      "mov x11, %[pack_ic] \n"

      // pass 1 (labels 1-6): pack 4 channels x 8 pixels per iteration and
      // fold the signed bytes into the running sums via pairwise saddlp
      "mov x0, #0 \n"
      "1: \n"
      "cmp x0, %[ic_4div] \n"
      "add x0, x0, #4\n"
      "mov x12, x10 \n"
      "add x10, x10, #4\n"
      "blt 2f \n"
      "cmp %[ic_4res], #0\n"
      "beq 6f \n"
      "cmp %[ic_4res], #1\n"
      "beq 3f \n"
      "cmp %[ic_4res], #2\n"
      "beq 4f \n"
      "cmp %[ic_4res], #3\n"
      "beq 5f \n"

      "2: \n"
      "ld1 {v0.s}[0], [x12], %[src_stride]\n"
      "ld1 {v0.s}[1], [x12], %[src_stride]\n"
      "ld1 {v0.s}[2], [x12], %[src_stride]\n"
      "ld1 {v0.s}[3], [x12], %[src_stride]\n"
      "ld1 {v1.s}[0], [x12], %[src_stride]\n"
      "ld1 {v1.s}[1], [x12], %[src_stride]\n"
      "ld1 {v1.s}[2], [x12], %[src_stride]\n"
      "ld1 {v1.s}[3], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"

      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 1b \n"

      "3: \n" /* col res 1 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"

      "ld1 {v0.b}[0], [x12], %[src_stride]\n"
      "ld1 {v0.b}[4], [x12], %[src_stride]\n"
      "ld1 {v0.b}[8], [x12], %[src_stride]\n"
      "ld1 {v0.b}[12], [x12], %[src_stride]\n"
      "ld1 {v1.b}[0], [x12], %[src_stride]\n"
      "ld1 {v1.b}[4], [x12], %[src_stride]\n"
      "ld1 {v1.b}[8], [x12], %[src_stride]\n"
      "ld1 {v1.b}[12], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"

      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      "4: \n" /* col res 2 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"

      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"

      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      "5: \n" /* col res 3 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"
      "add x13, x12, #2 \n"

      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.b}[2], [x13], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.b}[6], [x13], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.b}[10], [x13], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v0.b}[14], [x13], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.b}[2], [x13], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.b}[6], [x13], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.b}[10], [x13], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.b}[14], [x13], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"

      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      // pass 2 (labels 6-17): broadcast each pixel's channel sum (v0..v7) and
      // multiply by 8 filter zero points at a time, storing input_sum
      "6: \n"
      "dup v0.4s, v16.s[0] \n"
      "dup v1.4s, v16.s[1] \n"
      "dup v2.4s, v16.s[2] \n"
      "dup v3.4s, v16.s[3] \n"
      "dup v4.4s, v17.s[0] \n"
      "dup v5.4s, v17.s[1] \n"
      "dup v6.4s, v17.s[2] \n"
      "dup v7.4s, v17.s[3] \n"
      "mov x4, #0 \n"
      "mov x10, %[filter_zp] \n"
      "mov x11, %[input_sum_oc] \n"

      "7: \n"
      "cmp x4, %[oc_8div] \n"
      "beq 8f \n"
      "add x4, x4, #8\n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.4s}, [x10], #16\n"

      "mul v18.4s, v16.4s, v0.4s \n"
      "mul v19.4s, v17.4s, v0.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"

      "mul v20.4s, v16.4s, v1.4s \n"
      "mul v21.4s, v17.4s, v1.4s \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"

      "mul v22.4s, v16.4s, v2.4s \n"
      "mul v23.4s, v17.4s, v2.4s \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"

      "mul v24.4s, v16.4s, v3.4s \n"
      "mul v25.4s, v17.4s, v3.4s \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"

      "mul v18.4s, v16.4s, v4.4s \n"
      "mul v19.4s, v17.4s, v4.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"

      "mul v20.4s, v16.4s, v5.4s \n"
      "mul v21.4s, v17.4s, v5.4s \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"

      "mul v22.4s, v16.4s, v6.4s \n"
      "mul v23.4s, v17.4s, v6.4s \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"

      "mul v24.4s, v16.4s, v7.4s \n"
      "mul v25.4s, v17.4s, v7.4s \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"

      "add x11, x11, %[input_sum_stride] \n"
      "b 7b \n"

      // labels 8-16: load the 1..7 leftover filter zero points (rest of
      // v16/v17 stays zero) and emit one final block of products
      "8: \n"
      "cmp %[oc_8res], #0\n"
      "beq 17f \n"

      "dup v16.4s, wzr \n"
      "dup v17.4s, wzr \n"
      "cmp %[oc_8res], #1\n"
      "beq 9f \n"
      "cmp %[oc_8res], #2\n"
      "beq 10f \n"
      "cmp %[oc_8res], #3\n"
      "beq 11f \n"
      "cmp %[oc_8res], #4\n"
      "beq 12f \n"
      "cmp %[oc_8res], #5\n"
      "beq 13f \n"
      "cmp %[oc_8res], #6\n"
      "beq 14f \n"
      "cmp %[oc_8res], #7\n"
      "beq 15f \n"

      "9: \n"
      "ld1 {v16.s}[0], [x10] \n"
      "b 16f \n"

      "10: \n"
      "ld1 {v16.d}[0], [x10] \n"
      "b 16f \n"

      "11: \n"
      "ld1 {v16.d}[0], [x10] \n"
      "add x10, x10, #8 \n"
      "ld1 {v16.s}[2], [x10] \n"
      "b 16f \n"

      "12: \n"
      "ld1 {v16.4s}, [x10] \n"
      "b 16f \n"

      "13: \n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.s}[0], [x10] \n"
      "b 16f \n"

      "14: \n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.d}[0], [x10] \n"
      "b 16f \n"

      "15: \n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.d}[0], [x10] \n"
      "add x10, x10, #8 \n"
      "ld1 {v17.s}[2], [x10] \n"
      "b 16f \n"

      "16: \n"
      "mul v18.4s, v16.4s, v0.4s \n"
      "mul v19.4s, v17.4s, v0.4s \n"
      "mul v20.4s, v16.4s, v1.4s \n"
      "mul v21.4s, v17.4s, v1.4s \n"
      "mul v22.4s, v16.4s, v2.4s \n"
      "mul v23.4s, v17.4s, v2.4s \n"
      "mul v24.4s, v16.4s, v3.4s \n"
      "mul v25.4s, v17.4s, v3.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"

      "mul v18.4s, v16.4s, v4.4s \n"
      "mul v19.4s, v17.4s, v4.4s \n"
      "mul v20.4s, v16.4s, v5.4s \n"
      "mul v21.4s, v17.4s, v5.4s \n"
      "mul v22.4s, v16.4s, v6.4s \n"
      "mul v23.4s, v17.4s, v6.4s \n"
      "mul v24.4s, v16.4s, v7.4s \n"
      "mul v25.4s, v17.4s, v7.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"

      "17: \n"

      :
      : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ filter_zp ] "r"(filter_zp),
        [ input_sum_oc ] "r"(input_sum_oc), [ input_sum_stride ] "r"(input_sum_stride), [ src_stride ] "r"(src_stride),
        [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ oc_8div ] "r"(oc_8div), [ oc_8res ] "r"(oc_8res)
      : "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16",
        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
#else
    // C fallback: same packing and sums as the assembly above
    int32_t tmp_sum_value[8] = {0};  // per-pixel channel sums for the 8 pixels
    for (int ici = 0; ici < ic_4div; ici += C4NUM) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[0 + i * input_channel];
        tmp_sum_value[i] += src_ic[1 + i * input_channel];
        tmp_sum_value[i] += src_ic[2 + i * input_channel];
        tmp_sum_value[i] += src_ic[3 + i * input_channel];
        pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
        pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
        pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
        pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
      }
      src_ic += C4NUM;
      pack_ic += C4NUM * C8NUM;
    }
    // leftover channels (input_channel not a multiple of C4NUM)
    for (int ici = ic_4div; ici < input_channel; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[i * input_channel];
        pack_ic[i * C4NUM] = src_ic[i * input_channel];
      }
      src_ic += 1;
      pack_ic += 1;
    }

    // zero-pad the channel round-up region
    for (int ici = input_channel; ici < ic4; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        pack_ic[i * C4NUM] = 0;
      }
      pack_ic += 1;
    }

    // input_sum[pixel][oc] = channel sum * filter zero point, 8 ocs at a time
    for (int oci = 0; oci < oc_8div; oci += C8NUM) {
      for (int ri = 0; ri < C8NUM; ri++) {
        input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0];
        input_sum_oc[ri * C8NUM + 1] = tmp_sum_value[ri] * filter_zp[oci + 1];
        input_sum_oc[ri * C8NUM + 2] = tmp_sum_value[ri] * filter_zp[oci + 2];
        input_sum_oc[ri * C8NUM + 3] = tmp_sum_value[ri] * filter_zp[oci + 3];
        input_sum_oc[ri * C8NUM + 4] = tmp_sum_value[ri] * filter_zp[oci + 4];
        input_sum_oc[ri * C8NUM + 5] = tmp_sum_value[ri] * filter_zp[oci + 5];
        input_sum_oc[ri * C8NUM + 6] = tmp_sum_value[ri] * filter_zp[oci + 6];
        input_sum_oc[ri * C8NUM + 7] = tmp_sum_value[ri] * filter_zp[oci + 7];
      }
      input_sum_oc += inputsum_stride;
    }
    // leftover output channels: valid entries get real products, the rest zero
    if (oc_8div != output_channel) {
      for (int oci = 0; oci < oc_8res; oci += 1) {
        for (int ri = 0; ri < C8NUM; ri++) {
          input_sum_oc[ri * C8NUM + oci] = tmp_sum_value[ri] * filter_zp[oc_8div + oci];
        }
      }
      for (int oci = oc_8res; oci < C8NUM; oci += 1) {
        for (int ri = 0; ri < C8NUM; ri++) {
          input_sum_oc[ri * C8NUM + oci] = 0;
        }
      }
    } /* oc8 res done */
#endif
    src_r += input_channel * C8NUM;
    pack_r += ic4 * C8NUM;
    input_sum_r += C8NUM * C8NUM;
  }

  /* tail: pixels beyond the last full C8NUM group, processed one at a time */
  if (hw_8div != plane_size) {
    memset(pack_r, 0, C8NUM * ic4);  // pre-zero the whole tail pack block once
    for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
      int32_t *input_sum_oc = input_sum_r;
      int32_t tmp_sum_value = 0;
      const int8_t *src_ic = src_r;
      int8_t *pack_ic = pack_r;
      for (int ici = 0; ici < ic_4div; ici += C4NUM) {
        tmp_sum_value += src_ic[0];
        tmp_sum_value += src_ic[1];
        tmp_sum_value += src_ic[2];
        tmp_sum_value += src_ic[3];
        pack_ic[0] = src_ic[0];
        pack_ic[1] = src_ic[1];
        pack_ic[2] = src_ic[2];
        pack_ic[3] = src_ic[3];
        src_ic += C4NUM;
        pack_ic += C4NUM * C8NUM;
      }
      for (int ici = ic_4div; ici < input_channel; ici += 1) {
        tmp_sum_value += src_ic[0];
        pack_ic[0] = src_ic[0];
        src_ic += 1;
        pack_ic += 1;
      }

      for (int oci = 0; oci < oc_8div; oci += C8NUM) {
        for (int curoi = 0; curoi < C8NUM; curoi++) {
          input_sum_oc[curoi] = tmp_sum_value * filter_zp[oci + curoi];
        }
        input_sum_oc += inputsum_stride;
      }
      if (oc_8div != output_channel) {
        for (int oci = 0; oci < oc_8res; oci += 1) {
          input_sum_oc[oci] = tmp_sum_value * filter_zp[oc_8div + oci];
        }
        for (int oci = oc_8res; oci < C8NUM; oci += 1) {
          input_sum_oc[oci] = 0;
        }
      } /* oc8 res done */

      src_r += input_channel;
      pack_r += C4NUM;
      input_sum_r += C8NUM;
    }

    // zero input_sum rows for the padded pixels up to the hw8 round-up
    for (int hwi = plane_size; hwi < hw8; hwi++) {
      for (int oc = 0; oc < oc8; oc++) {
        int oc8div = oc / C8NUM, oc8res = oc % C8NUM;
        input_sum[oc8div * inputsum_stride + hwi * C8NUM + oc8res] = 0;
      }
    }
  }
  return;
}
|
||||
|
||||
void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
|
||||
size_t plane_size, ConvParameter *conv_param) {
|
||||
int ic4 = UP_ROUND(input_channel, C4NUM);
|
||||
size_t hw_8div = plane_size / C8NUM * C8NUM;
|
||||
size_t ic_4div = input_channel / C4NUM * C4NUM;
|
||||
int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;
|
||||
|
||||
const int8_t *src_r = src_input;
|
||||
int8_t *pack_r = packed_input;
|
||||
/* per layer */
|
||||
for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
|
||||
const int8_t *src_ic = src_r;
|
||||
int8_t *pack_ic = pack_r;
|
||||
int32_t *input_sum_r = input_sum + hwi;
|
||||
#ifdef ENABLE_ARM64
|
||||
size_t src_stride = input_channel;
|
||||
size_t ic_4res = input_channel - ic_4div;
|
||||
asm volatile(
|
||||
"dup v16.4s, wzr \n"
|
||||
"dup v17.4s, wzr \n"
|
||||
"mov x14, %[input_sum_r] \n"
|
||||
"dup v20.4s, %w[filter_zp] \n"
|
||||
|
||||
"mov x10, %[src_ic] \n"
|
||||
"mov x11, %[pack_ic] \n"
|
||||
|
||||
"mov x0, #0 \n"
|
||||
"1: \n"
|
||||
"cmp x0, %[ic_4div] \n"
|
||||
"add x0, x0, #4\n"
|
||||
"mov x12, x10 \n"
|
||||
"add x10, x10, #4\n"
|
||||
"blt 2f \n"
|
||||
"cmp %[ic_4res], #0\n"
|
||||
"beq 6f \n"
|
||||
"cmp %[ic_4res], #1\n"
|
||||
"beq 3f \n"
|
||||
"cmp %[ic_4res], #2\n"
|
||||
"beq 4f \n"
|
||||
"cmp %[ic_4res], #3\n"
|
||||
"beq 5f \n"
|
||||
|
||||
"2: \n"
|
||||
"ld1 {v0.s}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.s}[1], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.s}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.s}[3], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.s}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.s}[1], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.s}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.s}[3], [x12], %[src_stride]\n"
|
||||
|
||||
"st1 {v0.16b}, [x11], #16\n"
|
||||
"st1 {v1.16b}, [x11], #16\n"
|
||||
|
||||
"saddlp v4.8h, v0.16b \n"
|
||||
"saddlp v5.8h, v1.16b \n"
|
||||
|
||||
"saddlp v0.4s, v4.8h \n"
|
||||
"saddlp v1.4s, v5.8h \n"
|
||||
|
||||
"add v16.4s, v16.4s, v0.4s \n"
|
||||
"add v17.4s, v17.4s, v1.4s \n"
|
||||
"b 1b \n"
|
||||
|
||||
"3: \n" /* col res 1 */
|
||||
"dup v0.4s, wzr \n"
|
||||
"dup v1.4s, wzr \n"
|
||||
|
||||
"ld1 {v0.b}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[8], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[12], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[8], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[12], [x12], %[src_stride]\n"
|
||||
|
||||
"st1 {v0.16b}, [x11], #16\n"
|
||||
"st1 {v1.16b}, [x11], #16\n"
|
||||
"saddlp v4.8h, v0.16b \n"
|
||||
"saddlp v5.8h, v1.16b \n"
|
||||
"saddlp v0.4s, v4.8h \n"
|
||||
"saddlp v1.4s, v5.8h \n"
|
||||
"add v16.4s, v16.4s, v0.4s \n"
|
||||
"add v17.4s, v17.4s, v1.4s \n"
|
||||
"b 6f \n"
|
||||
|
||||
"4: \n" /* col res 2 */
|
||||
"dup v0.4s, wzr \n"
|
||||
"dup v1.4s, wzr \n"
|
||||
|
||||
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
|
||||
|
||||
"st1 {v0.16b}, [x11], #16\n"
|
||||
"st1 {v1.16b}, [x11], #16\n"
|
||||
"saddlp v4.8h, v0.16b \n"
|
||||
"saddlp v5.8h, v1.16b \n"
|
||||
"saddlp v0.4s, v4.8h \n"
|
||||
"saddlp v1.4s, v5.8h \n"
|
||||
"add v16.4s, v16.4s, v0.4s \n"
|
||||
"add v17.4s, v17.4s, v1.4s \n"
|
||||
"b 6f \n"
|
||||
|
||||
"5: \n" /* col res 3 */
|
||||
"dup v0.4s, wzr \n"
|
||||
"dup v1.4s, wzr \n"
|
||||
"add x13, x12, #2 \n"
|
||||
|
||||
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[2], [x13], %[src_stride]\n"
|
||||
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[6], [x13], %[src_stride]\n"
|
||||
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[10], [x13], %[src_stride]\n"
|
||||
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[14], [x13], %[src_stride]\n"
|
||||
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[2], [x13], %[src_stride]\n"
|
||||
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[6], [x13], %[src_stride]\n"
|
||||
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[10], [x13], %[src_stride]\n"
|
||||
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[14], [x13], %[src_stride]\n"
|
||||
|
||||
"st1 {v0.16b}, [x11], #16\n"
|
||||
"st1 {v1.16b}, [x11], #16\n"
|
||||
"saddlp v4.8h, v0.16b \n"
|
||||
"saddlp v5.8h, v1.16b \n"
|
||||
"saddlp v0.4s, v4.8h \n"
|
||||
"saddlp v1.4s, v5.8h \n"
|
||||
"add v16.4s, v16.4s, v0.4s \n"
|
||||
"add v17.4s, v17.4s, v1.4s \n"
|
||||
"b 6f \n"
|
||||
|
||||
"6: \n"
|
||||
"mul v16.4s, v16.4s, v20.4s \n"
|
||||
"mul v17.4s, v17.4s, v20.4s \n"
|
||||
|
||||
"st1 {v16.4s}, [x14], #16 \n"
|
||||
"st1 {v17.4s}, [x14], #16 \n"
|
||||
|
||||
:
|
||||
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
|
||||
[ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp)
|
||||
: "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
|
||||
"v20");
|
||||
#else
|
||||
int32_t tmp_sum_value[8] = {0};
|
||||
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
|
||||
for (int i = 0; i < C8NUM; i++) {
|
||||
tmp_sum_value[i] += src_ic[0 + i * input_channel];
|
||||
tmp_sum_value[i] += src_ic[1 + i * input_channel];
|
||||
tmp_sum_value[i] += src_ic[2 + i * input_channel];
|
||||
tmp_sum_value[i] += src_ic[3 + i * input_channel];
|
||||
pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
|
||||
pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
|
||||
pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
|
||||
pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
|
||||
}
|
||||
src_ic += C4NUM;
|
||||
pack_ic += C4NUM * C8NUM;
|
||||
}
|
||||
for (int ici = ic_4div; ici < input_channel; ici += 1) {
|
||||
for (int i = 0; i < C8NUM; i++) {
|
||||
tmp_sum_value[i] += src_ic[i * input_channel];
|
||||
pack_ic[i * C4NUM] = src_ic[i * input_channel];
|
||||
}
|
||||
src_ic += 1;
|
||||
pack_ic += 1;
|
||||
}
|
||||
|
||||
for (int ici = input_channel; ici < ic4; ici += 1) {
|
||||
for (int i = 0; i < C8NUM; i++) {
|
||||
pack_ic[i * C4NUM] = 0;
|
||||
}
|
||||
pack_ic += 1;
|
||||
}
|
||||
|
||||
for (int i = 0; i < C8NUM; i++) {
|
||||
input_sum_r[i] = tmp_sum_value[i] * filter_zp;
|
||||
}
|
||||
#endif
|
||||
src_r += input_channel * C8NUM;
|
||||
pack_r += ic4 * C8NUM;
|
||||
}
|
||||
|
||||
if (hw_8div != plane_size) {
|
||||
memset(pack_r, 0, C8NUM * ic4);
|
||||
for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
|
||||
int32_t tmp_sum_value = 0;
|
||||
const int8_t *src_ic = src_r;
|
||||
int8_t *pack_ic = pack_r;
|
||||
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
|
||||
tmp_sum_value += src_ic[0];
|
||||
tmp_sum_value += src_ic[1];
|
||||
tmp_sum_value += src_ic[2];
|
||||
tmp_sum_value += src_ic[3];
|
||||
pack_ic[0] = src_ic[0];
|
||||
pack_ic[1] = src_ic[1];
|
||||
pack_ic[2] = src_ic[2];
|
||||
pack_ic[3] = src_ic[3];
|
||||
src_ic += C4NUM;
|
||||
pack_ic += C4NUM * C8NUM;
|
||||
}
|
||||
for (int ici = ic_4div; ici < input_channel; ici += 1) {
|
||||
tmp_sum_value += src_ic[0];
|
||||
pack_ic[0] = src_ic[0];
|
||||
src_ic += 1;
|
||||
pack_ic += 1;
|
||||
}
|
||||
input_sum[hwi] = tmp_sum_value * filter_zp;
|
||||
src_r += input_channel;
|
||||
pack_r += C4NUM;
|
||||
}
|
||||
for (int hwi = plane_size; hwi < UP_ROUND(plane_size, C8NUM); hwi++) {
|
||||
input_sum[hwi] = 0;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
|
||||
const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
|
||||
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int *filter_zp) {
|
||||
int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
|
||||
matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias,
|
||||
left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
|
||||
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc,
|
||||
filter_zp);
|
||||
return;
|
||||
}
|
||||
|
||||
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
|
||||
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
|
||||
int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp) {
|
||||
int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
|
||||
MatmulInt8Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
|
||||
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
|
||||
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift,
|
||||
conv_param->output_channel_, is_per_oc, filter_zp);
|
||||
return;
|
||||
}
|
||||
|
||||
// int8 convolution 3x3
|
||||
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
|
||||
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
|
||||
int task_id, ConvParameter *conv_param) {
|
||||
int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);
|
||||
int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
|
||||
int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
|
||||
int output_count = out_w_block * out_h_block;
|
||||
int output_tile_count = UP_DIV(output_count, TILE_NUM);
|
||||
int oc4 = UP_DIV(conv_param->output_channel_, C4NUM);
|
||||
int tile_buffer_offset = TILE_NUM * 16 * ic8 * C8NUM;
|
||||
const int block_unit_buffer_offset = 16 * C8NUM;
|
||||
int tmp_dst_buffer_offset = TILE_NUM * 16 * oc4 * C4NUM;
|
||||
|
||||
for (int batch = 0; batch < conv_param->input_batch_; batch++) {
|
||||
int in_batch_offset = batch * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_;
|
||||
int tmp_out_batch_offset = batch * oc4 * C4NUM * conv_param->output_w_ * conv_param->output_h_;
|
||||
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
|
||||
int start_index = thread_id * TILE_NUM;
|
||||
int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
|
||||
|
||||
Conv3x3Int8InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
|
||||
block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
|
||||
out_w_block, conv_param);
|
||||
|
||||
Conv3x3Int8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset,
|
||||
transed_weight, conv_param->output_channel_, ic8, real_cal_num);
|
||||
|
||||
Conv3x3Int8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset,
|
||||
bias_data, start_index, real_cal_num, out_w_block, conv_param);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#ifndef MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_
|
||||
#define MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_
|
||||
|
||||
#include <string.h>
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
@ -24,9 +25,10 @@
|
|||
#include "nnacl/common_func.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/winograd_utils.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
#include "nnacl/matmul_parameter.h"
|
||||
#include "nnacl/int8/matmul_int8.h"
|
||||
#include "nnacl/int8/common_func_int8.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
@ -36,23 +38,6 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, in
|
|||
const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id,
|
||||
ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func, bool is_optimize);
|
||||
|
||||
// int8 convolution 1x1
|
||||
void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
|
||||
size_t output_channel, size_t plane_size, int32_t *filter_zp, size_t inputsum_stride);
|
||||
void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
|
||||
size_t plane_size, ConvParameter *conv_param);
|
||||
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
|
||||
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
|
||||
int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp);
|
||||
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
|
||||
const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
|
||||
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int32_t *filter_zp);
|
||||
|
||||
// int8 convolution 3x3
|
||||
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
|
||||
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
|
||||
int task_id, ConvParameter *conv_param);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#define MINDSPORE_LITE_NNACL_INT8_DEPTH_TO_SPACE_INT8_H_
|
||||
|
||||
#include "nnacl/depth_to_space_parameter.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -15,9 +15,6 @@
|
|||
*/
|
||||
|
||||
#include "nnacl/int8/div_int8.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
int DivInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, DivQuantArg *para) {
|
||||
int index = 0;
|
||||
|
|
|
@ -18,7 +18,9 @@
|
|||
#define MINDSPORE_LITE_NNACL_INT8_DIV_INT8_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
|
||||
// returns the high-32 bits of a * b with rounding
|
||||
// assume that a and b is divided by 2^31, who fall into [-1, 1]
|
||||
|
@ -107,7 +107,7 @@ int CountLeadingZeroBits(uint32_t x) {
|
|||
if (x == 0) {
|
||||
return 8 * sizeof(uint32_t);
|
||||
}
|
||||
const int32_t leading_positive = (int32_t)(1) << (8 * sizeof(uint32_t) - 1);
|
||||
const int32_t leading_positive = (uint32_t)(1) << (8 * sizeof(uint32_t) - 1);
|
||||
int leading_zeros = 0;
|
||||
while (x < leading_positive) {
|
||||
x <<= 1;
|
|
@ -18,7 +18,7 @@
|
|||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHERND_INT8_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
*/
|
||||
#include "nnacl/int8/gather_int8.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int GatherInt8(int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, const int *indices,
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHER_INT8_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include <math.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
|
||||
typedef struct HswishQuantArg {
|
||||
double input_scale;
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
#include "nnacl/int8/l2_norm_int8.h"
|
||||
#include <limits.h>
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int L2NormalizationInt8(const int8_t *input_data, int8_t *output_data, const L2NormParameter *param,
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/layer_norm_parameter.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_PRELU_INT8_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "nnacl/int8/matmul_int8.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
|
||||
void RowMajor2Row2x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
|
||||
int col16 = UP_ROUND(col, C16NUM);
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
#ifndef MINDSPORE_LITE_NNACL_INT8_MATMUL_H_
|
||||
#define MINDSPORE_LITE_NNACL_INT8_MATMUL_H_
|
||||
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/matmul_parameter.h"
|
||||
|
|
|
@ -15,15 +15,8 @@
|
|||
*/
|
||||
|
||||
#include "nnacl/int8/mul_int8.h"
|
||||
#include "nnacl/mul_parameter.h"
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#include "nnacl/int8/common_func_int8.h"
|
||||
#endif
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
|
||||
int16x4_t ClacSumHalfWordMul(int16x4_t scaled_input0, int16x4_t scaled_input1, int32x4_t left_shift_out_vec,
|
||||
int32x4_t right_shift_out_vec, int32x4_t output_multiplier_vec) {
|
||||
int32x4_t input_scale = vmull_s16(scaled_input0, scaled_input1);
|
||||
|
|
|
@ -19,6 +19,11 @@
|
|||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/mul_parameter.h"
|
||||
#include "nnacl/int8/common_func_int8.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,62 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_NNACL_INT8_PACK_INT8_H_
|
||||
#define MINDSPORE_LITE_NNACL_INT8_PACK_INT8_H_
|
||||
|
||||
#include <string.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/int8/matmul_int8.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNCHWToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param);
|
||||
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16);
|
||||
void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param);
|
||||
void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param);
|
||||
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num,
|
||||
int block_index, int32_t *filter_zp, int32_t *input_sum, ConvParameter *conv_param,
|
||||
bool per_channel, bool is_optimize);
|
||||
#ifdef ENABLE_ARM
|
||||
void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp);
|
||||
void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div,
|
||||
size_t oc_res, size_t stride);
|
||||
#endif
|
||||
|
||||
void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param);
|
||||
void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
|
||||
ConvQuantArg *quant_qrg);
|
||||
void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
|
||||
ConvQuantArg *quant_qrg);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_INT8_PAD_INT8_H_
|
|
@ -19,7 +19,7 @@
|
|||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/power_parameter.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -14,8 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include <stdio.h>
|
||||
#include "nnacl/int8/quantize.h"
|
||||
|
||||
const uint64_t dSignMask = 1ull << 63;
|
||||
const uint64_t dExponentMask = 0x7ffull << 52;
|
||||
|
@ -57,8 +56,6 @@ void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t
|
|||
/* multipiler is in[0x40000000, 0x7FFFFF80] range */
|
||||
*quantized_multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
|
||||
if (quantized_multiplier[0] < INT32_C(0x40000000) || quantized_multiplier[0] > INT32_C(0x7FFFFF80)) {
|
||||
printf("quantized multiplier must be in [0x40000000, 0x7FFFFF80] range, now multiplier is %d\n",
|
||||
quantized_multiplier[0]);
|
||||
return;
|
||||
}
|
||||
/* shift is in [0, 31] range */
|
|
@ -28,11 +28,6 @@
|
|||
#define FILTER_PER_CHANNEL 0b010
|
||||
#define OUTPUT_PER_CHANNEL 0b100
|
||||
|
||||
typedef struct QuantArg {
|
||||
float scale_;
|
||||
int32_t zp_;
|
||||
} QuantArg;
|
||||
|
||||
typedef struct ConvQuantArg {
|
||||
RoundingMode round_mode_;
|
||||
CalFixedMultiplierMode quant_multiplier_mode_;
|
||||
|
@ -58,24 +53,6 @@ typedef struct ConcatQuantArg {
|
|||
int8_t output_activation_max_;
|
||||
} ConcatQuantArg;
|
||||
|
||||
typedef struct SqueezeQuantArg {
|
||||
QuantArg *in_quant_args_;
|
||||
QuantArg *out_quant_args_;
|
||||
} SqueezeQuantArg;
|
||||
|
||||
typedef struct UnSqueezeQuantArg {
|
||||
int *input_sizes_;
|
||||
int output_size_;
|
||||
int **input_shapes_;
|
||||
int *output_shape_;
|
||||
float alpha;
|
||||
int axis_;
|
||||
size_t input_num_;
|
||||
size_t output_dim_;
|
||||
QuantArg in_quant_args_;
|
||||
QuantArg out_quant_args_;
|
||||
} UnSqueezeQuantArg;
|
||||
|
||||
typedef struct PreluQuantArg {
|
||||
int *input_sizes_;
|
||||
int output_size_;
|
||||
|
@ -103,22 +80,6 @@ typedef struct MatmulQuantArg {
|
|||
int32_t quant_multiplier;
|
||||
} MatmulQuantArg;
|
||||
|
||||
typedef struct PadQuantArg {
|
||||
QuantArg *in_quant_args_;
|
||||
QuantArg *out_quanr_args_;
|
||||
int8_t *constant_value_;
|
||||
} PadQuantArg;
|
||||
|
||||
typedef struct MulQuantArg {
|
||||
QuantArg in_quant_args_[2];
|
||||
QuantArg out_quant_arg_;
|
||||
int output_multiplier_;
|
||||
int output_activation_min_;
|
||||
int output_activation_max_;
|
||||
int shift_left_;
|
||||
int shift_right_;
|
||||
} MulQuantArg;
|
||||
|
||||
typedef struct CropQuantArg {
|
||||
QuantArg in_args_;
|
||||
QuantArg out_args_;
|
||||
|
@ -142,13 +103,6 @@ typedef struct GatherQuantArg {
|
|||
int zp_out_;
|
||||
} GatherQuantArg;
|
||||
|
||||
typedef struct SplitQuantArg {
|
||||
QuantArg in_args_;
|
||||
QuantArg out_args_[20];
|
||||
int output_activation_min_;
|
||||
int output_activation_max_;
|
||||
} SplitQuantArg;
|
||||
|
||||
typedef struct SoftmaxQuantArg {
|
||||
QuantArg in_quant_args_;
|
||||
QuantArg out_quant_arg_;
|
||||
|
@ -159,19 +113,6 @@ typedef struct SoftmaxQuantArg {
|
|||
int shift_right_;
|
||||
} SoftmaxQuantArg;
|
||||
|
||||
typedef struct ReshapeQuantArg {
|
||||
QuantArg in_args_;
|
||||
QuantArg out_args_;
|
||||
int output_activation_min_;
|
||||
int output_activation_max_;
|
||||
} ReshapeQuantArg;
|
||||
|
||||
typedef struct QuantMulArg {
|
||||
int32_t multiplier_;
|
||||
int left_shift_;
|
||||
int right_shift_;
|
||||
} QuantMulArg;
|
||||
|
||||
typedef struct SubQuantArg {
|
||||
QuantArg in0_args_;
|
||||
QuantArg in1_args_;
|
||||
|
@ -227,21 +168,6 @@ typedef struct ReduceQuantArg {
|
|||
int sum_square_right_shift_;
|
||||
} ReduceQuantArg;
|
||||
|
||||
typedef struct SliceQuantArg {
|
||||
QuantArg in_args_;
|
||||
QuantArg out_args_;
|
||||
int output_activation_min_;
|
||||
int output_activation_max_;
|
||||
} SliceQuantArg;
|
||||
|
||||
typedef struct PowerQuantArg {
|
||||
QuantArg in_args_;
|
||||
QuantArg exp_args_;
|
||||
QuantArg out_args_;
|
||||
int output_activation_min_;
|
||||
int output_activation_max_;
|
||||
} PowerQuantArg;
|
||||
|
||||
typedef struct LeakyReluQuantArg {
|
||||
OpParameter op_parameter_;
|
||||
PreluQuantArg quant_arg;
|
|
@ -17,7 +17,7 @@
|
|||
#include <stdint.h>
|
||||
#include "nnacl/int8/reduce_int8.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
#include "nnacl/common_func.h"
|
||||
|
||||
int ReduceMeanN(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) {
|
||||
|
|
|
@ -16,7 +16,9 @@
|
|||
|
||||
#ifndef MINDSPORE_LITE_NNACL_INT8_REDUCE_INT8_H_
|
||||
#define MINDSPORE_LITE_NNACL_INT8_REDUCE_INT8_H_
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
#include "nnacl/int8/quantize.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
|
|
@ -19,8 +19,8 @@
|
|||
#include <math.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
|
||||
typedef struct ReluXQuantArg {
|
||||
QuantArg input_arg;
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#ifndef MINDSPORE_LITE_NNACL_INT8_RESHAHPE_INT8_H_
|
||||
#define MINDSPORE_LITE_NNACL_INT8_RESHAHPE_INT8_H_
|
||||
|
||||
#include <math.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/reshape_parameter.h"
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
#include <math.h>
|
||||
#include "nnacl/int8/resize_int8.h"
|
||||
#include "nnacl/common_func.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int ResizeBilinearInt8(const int8_t *input_ptr, int8_t *output_ptr, int batch, int in_h, int in_w, int out_h, int out_w,
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#endif
|
||||
#include <memory.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
#include "nnacl/resize_parameter.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "nnacl/int8/scale_int8.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
int16x4_t ClacSumHalfWordMul2(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec,
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/scale.h"
|
||||
#include "nnacl/nnacl_common.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include <math.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -15,8 +15,6 @@
|
|||
*/
|
||||
|
||||
#include "nnacl/int8/slice_int8.h"
|
||||
#include <string.h>
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
|
||||
int SliceInt8NoParallel(const int8_t *input, int8_t *output, SliceParameter *param) {
|
||||
double input_scale = param->quant_arg_.in_args_.scale_;
|
||||
|
|
|
@ -16,8 +16,11 @@
|
|||
#ifndef MINDSPORE_LITE_NNACL_INT8_SLICE_INT8_H_
|
||||
#define MINDSPORE_LITE_NNACL_INT8_SLICE_INT8_H_
|
||||
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/slice_parameter.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -15,9 +15,6 @@
|
|||
*/
|
||||
|
||||
#include "nnacl/int8/softmax_int8.h"
|
||||
#include <math.h>
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int *exp_data, int *sum_data,
|
||||
SoftmaxQuantArg quant_param, SoftmaxParameter *parameter) {
|
||||
|
|
|
@ -17,9 +17,11 @@
|
|||
#ifndef MINDSPORE_LITE_NNACL_INT8_SOFTMAX_INT8_H_
|
||||
#define MINDSPORE_LITE_NNACL_INT8_SOFTMAX_INT8_H_
|
||||
|
||||
#include <math.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/softmax_parameter.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#ifndef MINDSPORE_LITE_NNACL_INT8_SPLIT_INT8_H_
|
||||
#define MINDSPORE_LITE_NNACL_INT8_SPLIT_INT8_H_
|
||||
|
||||
#include <math.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/split_parameter.h"
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_SQUEEZE_INT8_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_SQUEEZE_INT8_H_
|
||||
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
#include "nnacl/squeeze_parameter.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include <arm_neon.h>
|
||||
#include "nnacl/int8/common_func_int8.h"
|
||||
#endif
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#define MINDSPORE_LITE_NNACL_INT8_SUB_INT8_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -18,8 +18,8 @@
|
|||
#define MINDSPORE_LITE_NNACL_INT8_TANH_INT8_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
#include "nnacl/int8/quant_dtype_cast_int8.h"
|
||||
#include "nnacl/fp32/activation_fp32.h"
|
||||
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
#include "nnacl/unsqueeze_parameter.h"
|
||||
#include "nnacl/int8/unsqueeze_int8.h"
|
||||
#include <string.h>
|
||||
|
||||
int Int8Unsqueeze(int8_t *input_ptr, int8_t *output_ptr, UnSqueezeParameter *para_, size_t data_size, int task_id) {
|
||||
float output_scale = para_->quant_arg.out_quant_args_.scale_;
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#define MINDSPORE_LITE_NNACL_L2NORM_PARAMETER_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "mindspore/lite/nnacl/int8/quantize.h"
|
||||
|
||||
typedef struct L2NormParameter {
|
||||
// Primitive parameter
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#define MINDSPORE_LITE_NNACL_LAYER_NORM_PARAMETER_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "mindspore/lite/nnacl/int8/quantize.h"
|
||||
|
||||
enum ElementwiseMode { ELEMENTWISE_NOT = 0, ELEMENTWISE_PER_CHANNEL = 1, ELEMENTWISE_PER_NUM = 2 };
|
||||
typedef struct LayerNormParameter {
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
#define MINDSPORE_LITE_NNACL_MATMUL_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
|
||||
const int *input_sum, const int *bias);
|
||||
|
|
|
@ -18,7 +18,16 @@
|
|||
#define MINDSPORE_LITE_NNACL_MUL_PARAMETER_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
typedef struct MulQuantArg {
|
||||
QuantArg in_quant_args_[2];
|
||||
QuantArg out_quant_arg_;
|
||||
int output_multiplier_;
|
||||
int output_activation_min_;
|
||||
int output_activation_max_;
|
||||
int shift_left_;
|
||||
int shift_right_;
|
||||
} MulQuantArg;
|
||||
|
||||
typedef struct MulParameter {
|
||||
// Primitive parameter
|
||||
|
|
|
@ -14,11 +14,4 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef PYBIND_API_PYBIND_PATCH_H_
|
||||
#define PYBIND_API_PYBIND_PATCH_H_
|
||||
|
||||
namespace pybind11 {
|
||||
PYBIND11_RUNTIME_EXCEPTION(attribute_error, PyExc_AttributeError)
|
||||
}
|
||||
|
||||
#endif // PYBIND_API_PYBIND_PATCH_H_
|
||||
#include "nnacl/nnacl_common.h"
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_NNACL_NNACL_COMMON_H_
|
||||
#define MINDSPORE_LITE_NNACL_NNACL_COMMON_H_
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
inline void ComputeStrides(const int *shape, int *strides, const int ndim) {
|
||||
int stride = 1;
|
||||
for (int i = ndim - 1; i >= 0; i--) {
|
||||
strides[i] = stride;
|
||||
stride *= shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_LITE_NNACL_NNACL_COMMON_H_
|
|
@ -28,6 +28,7 @@
|
|||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
#include <stdbool.h>
|
||||
#include <string.h>
|
||||
|
||||
#define C2NUM 2
|
||||
#define C4NUM 4
|
||||
|
@ -78,6 +79,17 @@ typedef struct OpParameter {
|
|||
int thread_num_;
|
||||
} OpParameter;
|
||||
|
||||
typedef struct QuantArg {
|
||||
float scale_;
|
||||
int32_t zp_;
|
||||
} QuantArg;
|
||||
|
||||
typedef struct QuantMulArg {
|
||||
int32_t multiplier_;
|
||||
int left_shift_;
|
||||
int right_shift_;
|
||||
} QuantMulArg;
|
||||
|
||||
typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, ActType_Prelu } ActType;
|
||||
typedef enum PadMode { Pad_No, Pad_Same, Pad_Valid } PadMode;
|
||||
typedef enum RoundingMode { Rounding_No, Rounding_Away_from_zero, Rounding_Up } RoundingMode;
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -17,102 +17,12 @@
|
|||
#ifndef MINDSPORE_LITE_NNACL_PACK_H_
|
||||
#define MINDSPORE_LITE_NNACL_PACK_H_
|
||||
|
||||
#include <stdio.h>
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/fp32/pack_fp32.h"
|
||||
#include "nnacl/int8/pack_int8.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num,
|
||||
int block_index);
|
||||
|
||||
void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel);
|
||||
|
||||
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num,
|
||||
int block_index, int32_t *filter_zp, int32_t *input_sum, ConvParameter *conv_param,
|
||||
bool per_channel, bool is_optimize);
|
||||
|
||||
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16);
|
||||
|
||||
void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr,
|
||||
size_t plane_size, size_t input_channel, size_t output_channel);
|
||||
|
||||
void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr,
|
||||
size_t plane_size, size_t input_channel, size_t output_channel);
|
||||
|
||||
void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size);
|
||||
|
||||
void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param);
|
||||
|
||||
void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param);
|
||||
|
||||
void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param);
|
||||
|
||||
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel);
|
||||
|
||||
void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum);
|
||||
|
||||
void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param);
|
||||
|
||||
void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel);
|
||||
|
||||
void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel);
|
||||
|
||||
void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNCHWToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param);
|
||||
|
||||
void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
|
||||
ConvQuantArg *quant_qrg);
|
||||
|
||||
void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
|
||||
ConvQuantArg *quant_qrg);
|
||||
|
||||
#ifdef ENABLE_ARM
|
||||
void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp);
|
||||
void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div,
|
||||
size_t oc_res, size_t stride);
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -17,11 +17,16 @@
|
|||
#define MINDSPORE_LITE_NNACL_PAD_PARAMETER_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
#define MAX_PAD_SIZE 8
|
||||
#define DEFAULT_PAD_NDIMS 4
|
||||
|
||||
typedef struct PadQuantArg {
|
||||
QuantArg *in_quant_args_;
|
||||
QuantArg *out_quanr_args_;
|
||||
int8_t *constant_value_;
|
||||
} PadQuantArg;
|
||||
|
||||
typedef struct PadParameter {
|
||||
// Primitive parameter
|
||||
OpParameter op_parameter_;
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
#define MINDSPORE_LITE_NNACL_POOLING_PARAMETER_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
typedef enum PoolMode { PoolMode_No, PoolMode_MaxPool, PoolMode_AvgPool } PoolMode;
|
||||
|
||||
|
|
|
@ -18,7 +18,14 @@
|
|||
#define MINDSPORE_LITE_NNACL_POWER_PARAMETER_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
typedef struct PowerQuantArg {
|
||||
QuantArg in_args_;
|
||||
QuantArg exp_args_;
|
||||
QuantArg out_args_;
|
||||
int output_activation_min_;
|
||||
int output_activation_max_;
|
||||
} PowerQuantArg;
|
||||
|
||||
typedef struct PowerParameter {
|
||||
// Primitive parameter
|
||||
|
|
|
@ -18,7 +18,13 @@
|
|||
#define MINDSPORE_LITE_NNACL_RESHAHPE_PARAMETER_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
typedef struct ReshapeQuantArg {
|
||||
QuantArg in_args_;
|
||||
QuantArg out_args_;
|
||||
int output_activation_min_;
|
||||
int output_activation_max_;
|
||||
} ReshapeQuantArg;
|
||||
|
||||
typedef struct ReshapeParameter {
|
||||
// primitive parameter
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
#ifndef MINDSPORE_LITE_NNACL_SCALE_H_
|
||||
#define MINDSPORE_LITE_NNACL_SCALE_H_
|
||||
|
||||
#include <mindspore/lite/nnacl/quantization/quantize.h>
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
typedef struct ScaleParameter {
|
||||
// primitive parameter
|
||||
OpParameter op_parameter_;
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
#include "nnacl/scatter_nd.h"
|
||||
#include <string.h>
|
||||
#include <stdio.h>
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int DoScatterND(float *output_ptr, const float *update, int *output_unit_offsets, int unit_size, int num_units) {
|
||||
|
|
|
@ -18,10 +18,16 @@
|
|||
#define MINDSPORE_LITE_NNACL_SLICE_PARAMETER_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
#define SLICE_SHAPE_MAX_SIZE 4
|
||||
|
||||
typedef struct SliceQuantArg {
|
||||
QuantArg in_args_;
|
||||
QuantArg out_args_;
|
||||
int output_activation_min_;
|
||||
int output_activation_max_;
|
||||
} SliceQuantArg;
|
||||
|
||||
typedef struct SliceParameter {
|
||||
// primitive parameter
|
||||
OpParameter op_parameter_;
|
||||
|
|
|
@ -18,8 +18,16 @@
|
|||
#define MINDSPORE_LITE_NNACL_SPLIT_PARAMETER_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
#define SPLIT_STRIDES_SIZE 32
|
||||
|
||||
typedef struct SplitQuantArg {
|
||||
QuantArg in_args_;
|
||||
QuantArg out_args_[20];
|
||||
int output_activation_min_;
|
||||
int output_activation_max_;
|
||||
} SplitQuantArg;
|
||||
|
||||
typedef struct SplitParameter {
|
||||
// primitive parameter
|
||||
OpParameter op_parameter_;
|
||||
|
|
|
@ -16,11 +16,16 @@
|
|||
|
||||
#ifndef MINDSPORE_LITE_NNACL_SQUEEZE_PARAMETER_H_
|
||||
#define MINDSPORE_LITE_NNACL_SQUEEZE_PARAMETER_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
#define SQUEEZE_OFFSET_MAX_SIZE 4
|
||||
|
||||
typedef struct SqueezeQuantArg {
|
||||
QuantArg *in_quant_args_;
|
||||
QuantArg *out_quant_args_;
|
||||
} SqueezeQuantArg;
|
||||
|
||||
typedef struct SqueezeParameter {
|
||||
// primitive parameter
|
||||
OpParameter op_parameter_;
|
||||
|
|
|
@ -16,11 +16,26 @@
|
|||
|
||||
#ifndef MINDSPORE_LITE_NNACL_UNSQUEEZE_PARAMETER_H_
|
||||
#define MINDSPORE_LITE_NNACL_UNSQUEEZE_PARAMETER_H_
|
||||
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
#define UNSQUEEZE_OFFSET_MAX_SIZE 4
|
||||
|
||||
typedef struct UnSqueezeQuantArg {
|
||||
int *input_sizes_;
|
||||
int output_size_;
|
||||
int **input_shapes_;
|
||||
int *output_shape_;
|
||||
float alpha;
|
||||
int axis_;
|
||||
size_t input_num_;
|
||||
size_t output_dim_;
|
||||
QuantArg in_quant_args_;
|
||||
QuantArg out_quant_args_;
|
||||
} UnSqueezeQuantArg;
|
||||
|
||||
typedef struct UnSqueezeParameter {
|
||||
// primitive parameter
|
||||
OpParameter op_parameter_;
|
||||
|
|
|
@ -140,810 +140,3 @@ void WinogradOutputTransform(const float *gemm_out, float *out_data, const float
|
|||
out_tile_index++;
|
||||
}
|
||||
}
|
||||
|
||||
// int8 conv3x3
|
||||
void Conv3x3Int8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) {
|
||||
#ifdef ENABLE_ARM
|
||||
int16x8_t zp = vdupq_n_s16(input_zp);
|
||||
|
||||
int16x8_t d00 = vsubq_s16(vld1q_s16(tmp_data), zp);
|
||||
int16x8_t d01 = vsubq_s16(vld1q_s16(tmp_data + 8), zp);
|
||||
int16x8_t d02 = vsubq_s16(vld1q_s16(tmp_data + 2 * 8), zp);
|
||||
int16x8_t d03 = vsubq_s16(vld1q_s16(tmp_data + 3 * 8), zp);
|
||||
|
||||
int16x8_t d10 = vsubq_s16(vld1q_s16(tmp_data + 4 * 8), zp);
|
||||
int16x8_t d11 = vsubq_s16(vld1q_s16(tmp_data + 5 * 8), zp);
|
||||
int16x8_t d12 = vsubq_s16(vld1q_s16(tmp_data + 6 * 8), zp);
|
||||
int16x8_t d13 = vsubq_s16(vld1q_s16(tmp_data + 7 * 8), zp);
|
||||
|
||||
int16x8_t d20 = vsubq_s16(vld1q_s16(tmp_data + 8 * 8), zp);
|
||||
int16x8_t d21 = vsubq_s16(vld1q_s16(tmp_data + 9 * 8), zp);
|
||||
int16x8_t d22 = vsubq_s16(vld1q_s16(tmp_data + 10 * 8), zp);
|
||||
int16x8_t d23 = vsubq_s16(vld1q_s16(tmp_data + 11 * 8), zp);
|
||||
|
||||
int16x8_t d30 = vsubq_s16(vld1q_s16(tmp_data + 12 * 8), zp);
|
||||
int16x8_t d31 = vsubq_s16(vld1q_s16(tmp_data + 13 * 8), zp);
|
||||
int16x8_t d32 = vsubq_s16(vld1q_s16(tmp_data + 14 * 8), zp);
|
||||
int16x8_t d33 = vsubq_s16(vld1q_s16(tmp_data + 15 * 8), zp);
|
||||
|
||||
int16x8_t t00 = vsubq_s16(d00, d20);
|
||||
int16x8_t t01 = vsubq_s16(d01, d21);
|
||||
int16x8_t t02 = vsubq_s16(d02, d22);
|
||||
int16x8_t t03 = vsubq_s16(d03, d23);
|
||||
|
||||
int16x8_t t10 = vaddq_s16(d10, d20);
|
||||
int16x8_t t11 = vaddq_s16(d11, d21);
|
||||
int16x8_t t12 = vaddq_s16(d12, d22);
|
||||
int16x8_t t13 = vaddq_s16(d13, d23);
|
||||
|
||||
int16x8_t t20 = vsubq_s16(d20, d10);
|
||||
int16x8_t t21 = vsubq_s16(d21, d11);
|
||||
int16x8_t t22 = vsubq_s16(d22, d12);
|
||||
int16x8_t t23 = vsubq_s16(d23, d13);
|
||||
|
||||
int16x8_t t30 = vsubq_s16(d10, d30);
|
||||
int16x8_t t31 = vsubq_s16(d11, d31);
|
||||
int16x8_t t32 = vsubq_s16(d12, d32);
|
||||
int16x8_t t33 = vsubq_s16(d13, d33);
|
||||
|
||||
int16x8_t m00 = vsubq_s16(t00, t02);
|
||||
int16x8_t m01 = vaddq_s16(t01, t02);
|
||||
int16x8_t m02 = vsubq_s16(t02, t01);
|
||||
int16x8_t m03 = vsubq_s16(t01, t03);
|
||||
|
||||
int16x8_t m10 = vsubq_s16(t10, t12);
|
||||
int16x8_t m11 = vaddq_s16(t11, t12);
|
||||
int16x8_t m12 = vsubq_s16(t12, t11);
|
||||
int16x8_t m13 = vsubq_s16(t11, t13);
|
||||
|
||||
int16x8_t m20 = vsubq_s16(t20, t22);
|
||||
int16x8_t m21 = vaddq_s16(t21, t22);
|
||||
int16x8_t m22 = vsubq_s16(t22, t21);
|
||||
int16x8_t m23 = vsubq_s16(t21, t23);
|
||||
|
||||
int16x8_t m30 = vsubq_s16(t30, t32);
|
||||
int16x8_t m31 = vaddq_s16(t31, t32);
|
||||
int16x8_t m32 = vsubq_s16(t32, t31);
|
||||
int16x8_t m33 = vsubq_s16(t31, t33);
|
||||
|
||||
vst1q_s16(trans_input_data, m00);
|
||||
vst1q_s16(trans_input_data + step, m01);
|
||||
vst1q_s16(trans_input_data + 2 * step, m02);
|
||||
vst1q_s16(trans_input_data + 3 * step, m03);
|
||||
|
||||
vst1q_s16(trans_input_data + 4 * step, m10);
|
||||
vst1q_s16(trans_input_data + 5 * step, m11);
|
||||
vst1q_s16(trans_input_data + 6 * step, m12);
|
||||
vst1q_s16(trans_input_data + 7 * step, m13);
|
||||
|
||||
vst1q_s16(trans_input_data + 8 * step, m20);
|
||||
vst1q_s16(trans_input_data + 9 * step, m21);
|
||||
vst1q_s16(trans_input_data + 10 * step, m22);
|
||||
vst1q_s16(trans_input_data + 11 * step, m23);
|
||||
|
||||
vst1q_s16(trans_input_data + 12 * step, m30);
|
||||
vst1q_s16(trans_input_data + 13 * step, m31);
|
||||
vst1q_s16(trans_input_data + 14 * step, m32);
|
||||
vst1q_s16(trans_input_data + 15 * step, m33);
|
||||
#else
|
||||
for (int i = 0; i < C8NUM; i++) {
|
||||
int16_t *local_ptr = tmp_data + i;
|
||||
int16_t d00 = local_ptr[0] - input_zp;
|
||||
int16_t d01 = (local_ptr + C8NUM)[0] - input_zp;
|
||||
int16_t d02 = (local_ptr + 2 * C8NUM)[0] - input_zp;
|
||||
int16_t d03 = (local_ptr + 3 * C8NUM)[0] - input_zp;
|
||||
|
||||
int16_t d10 = (local_ptr + 4 * C8NUM)[0] - input_zp;
|
||||
int16_t d11 = (local_ptr + 5 * C8NUM)[0] - input_zp;
|
||||
int16_t d12 = (local_ptr + 6 * C8NUM)[0] - input_zp;
|
||||
int16_t d13 = (local_ptr + 7 * C8NUM)[0] - input_zp;
|
||||
|
||||
int16_t d20 = (local_ptr + 8 * C8NUM)[0] - input_zp;
|
||||
int16_t d21 = (local_ptr + 9 * C8NUM)[0] - input_zp;
|
||||
int16_t d22 = (local_ptr + 10 * C8NUM)[0] - input_zp;
|
||||
int16_t d23 = (local_ptr + 11 * C8NUM)[0] - input_zp;
|
||||
|
||||
int16_t d30 = (local_ptr + 12 * C8NUM)[0] - input_zp;
|
||||
int16_t d31 = (local_ptr + 13 * C8NUM)[0] - input_zp;
|
||||
int16_t d32 = (local_ptr + 14 * C8NUM)[0] - input_zp;
|
||||
int16_t d33 = (local_ptr + 15 * C8NUM)[0] - input_zp;
|
||||
|
||||
int16_t t00 = d00 - d20;
|
||||
int16_t t01 = d01 - d21;
|
||||
int16_t t02 = d02 - d22;
|
||||
int16_t t03 = d03 - d23;
|
||||
|
||||
int16_t t10 = d10 + d20;
|
||||
int16_t t11 = d11 + d21;
|
||||
int16_t t12 = d12 + d22;
|
||||
int16_t t13 = d13 + d23;
|
||||
|
||||
int16_t t20 = d20 - d10;
|
||||
int16_t t21 = d21 - d11;
|
||||
int16_t t22 = d22 - d12;
|
||||
int16_t t23 = d23 - d13;
|
||||
|
||||
int16_t t30 = d10 - d30;
|
||||
int16_t t31 = d11 - d31;
|
||||
int16_t t32 = d12 - d32;
|
||||
int16_t t33 = d13 - d33;
|
||||
|
||||
int16_t m00 = t00 - t02;
|
||||
int16_t m01 = t01 + t02;
|
||||
int16_t m02 = t02 - t01;
|
||||
int16_t m03 = t01 - t03;
|
||||
|
||||
int16_t m10 = t10 - t12;
|
||||
int16_t m11 = t11 + t12;
|
||||
int16_t m12 = t12 - t11;
|
||||
int16_t m13 = t11 - t13;
|
||||
|
||||
int16_t m20 = t20 - t22;
|
||||
int16_t m21 = t21 + t22;
|
||||
int16_t m22 = t22 - t21;
|
||||
int16_t m23 = t21 - t23;
|
||||
|
||||
int16_t m30 = t30 - t32;
|
||||
int16_t m31 = t31 + t32;
|
||||
int16_t m32 = t32 - t31;
|
||||
int16_t m33 = t31 - t33;
|
||||
|
||||
(trans_input_data + i)[0] = m00;
|
||||
(trans_input_data + i + step)[0] = m01;
|
||||
(trans_input_data + i + 2 * step)[0] = m02;
|
||||
(trans_input_data + i + 3 * step)[0] = m03;
|
||||
|
||||
(trans_input_data + i + 4 * step)[0] = m10;
|
||||
(trans_input_data + i + 5 * step)[0] = m11;
|
||||
(trans_input_data + i + 6 * step)[0] = m12;
|
||||
(trans_input_data + i + 7 * step)[0] = m13;
|
||||
|
||||
(trans_input_data + i + 8 * step)[0] = m20;
|
||||
(trans_input_data + i + 9 * step)[0] = m21;
|
||||
(trans_input_data + i + 10 * step)[0] = m22;
|
||||
(trans_input_data + i + 11 * step)[0] = m23;
|
||||
|
||||
(trans_input_data + i + 12 * step)[0] = m30;
|
||||
(trans_input_data + i + 13 * step)[0] = m31;
|
||||
(trans_input_data + i + 14 * step)[0] = m32;
|
||||
(trans_input_data + i + 15 * step)[0] = m33;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index,
|
||||
int real_cal_num, int out_w_block, ConvParameter *conv_param) {
|
||||
// input data format : nhwc
|
||||
int input_channel = conv_param->input_channel_;
|
||||
int input_width = conv_param->input_w_;
|
||||
int input_height = conv_param->input_h_;
|
||||
int pad_w = conv_param->pad_l_;
|
||||
int pad_h = conv_param->pad_u_;
|
||||
ConvQuantArg quant_arg = conv_param->conv_quant_arg_;
|
||||
int input_zp = quant_arg.input_quant_args_[0].zp_;
|
||||
const int ic8 = UP_DIV(input_channel, C8NUM);
|
||||
const int input_unit = 4;
|
||||
if (out_w_block == 0) {
|
||||
return;
|
||||
}
|
||||
for (int cal_id = 0; cal_id < real_cal_num; cal_id++) {
|
||||
int x_id = start_index + cal_id;
|
||||
int origin_x = (x_id % out_w_block) * OUPUT_UNIT - pad_w;
|
||||
int origin_y = (x_id / out_w_block) * OUPUT_UNIT - pad_h;
|
||||
int real_x_start = origin_x > 0 ? 0 : -origin_x;
|
||||
int real_x_end = (origin_x + input_unit) < input_width ? input_unit : (input_width - origin_x);
|
||||
int real_y_start = origin_y > 0 ? 0 : -origin_y;
|
||||
int real_y_end = (origin_y + input_unit) < input_height ? input_unit : (input_height - origin_y);
|
||||
|
||||
int src_plane_offset = C8NUM * (origin_y * input_width + origin_x);
|
||||
int dst_plane_offset = cal_id * C8NUM;
|
||||
for (int ic = 0; ic < ic8; ic++) {
|
||||
// copy data from origin input to tmp buffer
|
||||
for (int i = 0; i < input_unit * input_unit * TILE_NUM; i++) tmp_data[i] = input_zp;
|
||||
|
||||
int src_c8_offset = src_plane_offset + ic * C8NUM * input_height * input_width;
|
||||
for (int j = real_y_start; j < real_y_end; j++) {
|
||||
const int16_t *src = input_data + src_c8_offset + C8NUM * (j * input_width + real_x_start);
|
||||
int16_t *dst = tmp_data + C8NUM * (C4NUM * j + real_x_start);
|
||||
memcpy(dst, src, (real_x_end - real_x_start) * C8NUM * sizeof(int16_t));
|
||||
}
|
||||
// input transform
|
||||
int dst_ic8_offset = dst_plane_offset + ic * TILE_NUM * C8NUM;
|
||||
size_t dst_step = ic8 * C8NUM * TILE_NUM;
|
||||
int16_t *trans_input_ptr = trans_input + dst_ic8_offset;
|
||||
Conv3x3Int8InputUnit(tmp_data, trans_input_ptr, dst_step, input_zp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel,
|
||||
int kernel_plane) {
|
||||
const int input_unit = 4;
|
||||
int dst_step = iC8 * C8NUM * C4NUM;
|
||||
for (int o = 0; o < output_channel; o++) {
|
||||
int oc4_block_num = o / C4NUM;
|
||||
int oc4_block_rem = o % C4NUM;
|
||||
int src_oc_offset = o * iC8 * C8NUM * kernel_plane;
|
||||
int dst_oc_offset = oc4_block_num * C4NUM * iC8 * C8NUM * input_unit * input_unit + oc4_block_rem;
|
||||
for (int i = 0; i < iC8; i++) {
|
||||
const int16_t *src_ic8_ptr = weight_data + src_oc_offset + i * kernel_plane * C8NUM;
|
||||
int16_t *dst_ic8_ptr = trans_weight + dst_oc_offset + i * C4NUM * C8NUM;
|
||||
#ifdef ENABLE_ARM
|
||||
int16x8_t g00 = vld1q_s16(src_ic8_ptr);
|
||||
int16x8_t g01 = vld1q_s16(src_ic8_ptr + 8);
|
||||
int16x8_t g02 = vld1q_s16(src_ic8_ptr + 2 * 8);
|
||||
int16x8_t g10 = vld1q_s16(src_ic8_ptr + 3 * 8);
|
||||
int16x8_t g11 = vld1q_s16(src_ic8_ptr + 4 * 8);
|
||||
int16x8_t g12 = vld1q_s16(src_ic8_ptr + 5 * 8);
|
||||
int16x8_t g20 = vld1q_s16(src_ic8_ptr + 6 * 8);
|
||||
int16x8_t g21 = vld1q_s16(src_ic8_ptr + 7 * 8);
|
||||
int16x8_t g22 = vld1q_s16(src_ic8_ptr + 8 * 8);
|
||||
|
||||
int16x8_t dst00 = vmulq_n_s16(g00, 2);
|
||||
int16x8_t dst01 = vmulq_n_s16(g01, 2);
|
||||
int16x8_t dst02 = vmulq_n_s16(g02, 2);
|
||||
|
||||
int16x8_t dst10 = vaddq_s16(vaddq_s16(g00, g10), g20);
|
||||
int16x8_t dst11 = vaddq_s16(vaddq_s16(g01, g11), g21);
|
||||
int16x8_t dst12 = vaddq_s16(vaddq_s16(g02, g12), g22);
|
||||
|
||||
int16x8_t dst20 = vaddq_s16(vsubq_s16(g00, g10), g20);
|
||||
int16x8_t dst21 = vaddq_s16(vsubq_s16(g01, g11), g21);
|
||||
int16x8_t dst22 = vaddq_s16(vsubq_s16(g02, g12), g22);
|
||||
|
||||
int16x8_t dst30 = vmulq_n_s16(g20, 2);
|
||||
int16x8_t dst31 = vmulq_n_s16(g21, 2);
|
||||
int16x8_t dst32 = vmulq_n_s16(g22, 2);
|
||||
|
||||
int16x8_t m00 = vmulq_n_s16(dst00, 2);
|
||||
int16x8_t m01 = vaddq_s16(vaddq_s16(dst00, dst01), dst02);
|
||||
int16x8_t m02 = vaddq_s16(vsubq_s16(dst00, dst01), dst02);
|
||||
int16x8_t m03 = vmulq_n_s16(dst02, 2);
|
||||
|
||||
int16x8_t m10 = vmulq_n_s16(dst10, 2);
|
||||
int16x8_t m11 = vaddq_s16(vaddq_s16(dst10, dst11), dst12);
|
||||
int16x8_t m12 = vaddq_s16(vsubq_s16(dst10, dst11), dst12);
|
||||
int16x8_t m13 = vmulq_n_s16(dst12, 2);
|
||||
|
||||
int16x8_t m20 = vmulq_n_s16(dst20, 2);
|
||||
int16x8_t m21 = vaddq_s16(vaddq_s16(dst20, dst21), dst22);
|
||||
int16x8_t m22 = vaddq_s16(vsubq_s16(dst20, dst21), dst22);
|
||||
int16x8_t m23 = vmulq_n_s16(dst22, 2);
|
||||
|
||||
int16x8_t m30 = vmulq_n_s16(dst30, 2);
|
||||
int16x8_t m31 = vaddq_s16(vaddq_s16(dst30, dst31), dst32);
|
||||
int16x8_t m32 = vaddq_s16(vsubq_s16(dst30, dst31), dst32);
|
||||
int16x8_t m33 = vmulq_n_s16(dst32, 2);
|
||||
|
||||
dst_ic8_ptr[0] = m00[0];
|
||||
dst_ic8_ptr[4] = m00[1];
|
||||
dst_ic8_ptr[8] = m00[2];
|
||||
dst_ic8_ptr[12] = m00[3];
|
||||
dst_ic8_ptr[16] = m00[4];
|
||||
dst_ic8_ptr[20] = m00[5];
|
||||
dst_ic8_ptr[24] = m00[6];
|
||||
dst_ic8_ptr[28] = m00[7];
|
||||
|
||||
dst_ic8_ptr[0 + dst_step] = m01[0];
|
||||
dst_ic8_ptr[4 + dst_step] = m01[1];
|
||||
dst_ic8_ptr[8 + dst_step] = m01[2];
|
||||
dst_ic8_ptr[12 + dst_step] = m01[3];
|
||||
dst_ic8_ptr[16 + dst_step] = m01[4];
|
||||
dst_ic8_ptr[20 + dst_step] = m01[5];
|
||||
dst_ic8_ptr[24 + dst_step] = m01[6];
|
||||
dst_ic8_ptr[28 + dst_step] = m01[7];
|
||||
|
||||
dst_ic8_ptr[0 + 2 * dst_step] = m02[0];
|
||||
dst_ic8_ptr[4 + 2 * dst_step] = m02[1];
|
||||
dst_ic8_ptr[8 + 2 * dst_step] = m02[2];
|
||||
dst_ic8_ptr[12 + 2 * dst_step] = m02[3];
|
||||
dst_ic8_ptr[16 + 2 * dst_step] = m02[4];
|
||||
dst_ic8_ptr[20 + 2 * dst_step] = m02[5];
|
||||
dst_ic8_ptr[24 + 2 * dst_step] = m02[6];
|
||||
dst_ic8_ptr[28 + 2 * dst_step] = m02[7];
|
||||
|
||||
dst_ic8_ptr[0 + 3 * dst_step] = m03[0];
|
||||
dst_ic8_ptr[4 + 3 * dst_step] = m03[1];
|
||||
dst_ic8_ptr[8 + 3 * dst_step] = m03[2];
|
||||
dst_ic8_ptr[12 + 3 * dst_step] = m03[3];
|
||||
dst_ic8_ptr[16 + 3 * dst_step] = m03[4];
|
||||
dst_ic8_ptr[20 + 3 * dst_step] = m03[5];
|
||||
dst_ic8_ptr[24 + 3 * dst_step] = m03[6];
|
||||
dst_ic8_ptr[28 + 3 * dst_step] = m03[7];
|
||||
|
||||
dst_ic8_ptr[0 + 4 * dst_step] = m10[0];
|
||||
dst_ic8_ptr[4 + 4 * dst_step] = m10[1];
|
||||
dst_ic8_ptr[8 + 4 * dst_step] = m10[2];
|
||||
dst_ic8_ptr[12 + 4 * dst_step] = m10[3];
|
||||
dst_ic8_ptr[16 + 4 * dst_step] = m10[4];
|
||||
dst_ic8_ptr[20 + 4 * dst_step] = m10[5];
|
||||
dst_ic8_ptr[24 + 4 * dst_step] = m10[6];
|
||||
dst_ic8_ptr[28 + 4 * dst_step] = m10[7];
|
||||
|
||||
dst_ic8_ptr[0 + 5 * dst_step] = m11[0];
|
||||
dst_ic8_ptr[4 + 5 * dst_step] = m11[1];
|
||||
dst_ic8_ptr[8 + 5 * dst_step] = m11[2];
|
||||
dst_ic8_ptr[12 + 5 * dst_step] = m11[3];
|
||||
dst_ic8_ptr[16 + 5 * dst_step] = m11[4];
|
||||
dst_ic8_ptr[20 + 5 * dst_step] = m11[5];
|
||||
dst_ic8_ptr[24 + 5 * dst_step] = m11[6];
|
||||
dst_ic8_ptr[28 + 5 * dst_step] = m11[7];
|
||||
|
||||
dst_ic8_ptr[0 + 6 * dst_step] = m12[0];
|
||||
dst_ic8_ptr[4 + 6 * dst_step] = m12[1];
|
||||
dst_ic8_ptr[8 + 6 * dst_step] = m12[2];
|
||||
dst_ic8_ptr[12 + 6 * dst_step] = m12[3];
|
||||
dst_ic8_ptr[16 + 6 * dst_step] = m12[4];
|
||||
dst_ic8_ptr[20 + 6 * dst_step] = m12[5];
|
||||
dst_ic8_ptr[24 + 6 * dst_step] = m12[6];
|
||||
dst_ic8_ptr[28 + 6 * dst_step] = m12[7];
|
||||
|
||||
dst_ic8_ptr[0 + 7 * dst_step] = m13[0];
|
||||
dst_ic8_ptr[4 + 7 * dst_step] = m13[1];
|
||||
dst_ic8_ptr[8 + 7 * dst_step] = m13[2];
|
||||
dst_ic8_ptr[12 + 7 * dst_step] = m13[3];
|
||||
dst_ic8_ptr[16 + 7 * dst_step] = m13[4];
|
||||
dst_ic8_ptr[20 + 7 * dst_step] = m13[5];
|
||||
dst_ic8_ptr[24 + 7 * dst_step] = m13[6];
|
||||
dst_ic8_ptr[28 + 7 * dst_step] = m13[7];
|
||||
|
||||
dst_ic8_ptr[0 + 8 * dst_step] = m20[0];
|
||||
dst_ic8_ptr[4 + 8 * dst_step] = m20[1];
|
||||
dst_ic8_ptr[8 + 8 * dst_step] = m20[2];
|
||||
dst_ic8_ptr[12 + 8 * dst_step] = m20[3];
|
||||
dst_ic8_ptr[16 + 8 * dst_step] = m20[4];
|
||||
dst_ic8_ptr[20 + 8 * dst_step] = m20[5];
|
||||
dst_ic8_ptr[24 + 8 * dst_step] = m20[6];
|
||||
dst_ic8_ptr[28 + 8 * dst_step] = m20[7];
|
||||
|
||||
dst_ic8_ptr[0 + 9 * dst_step] = m21[0];
|
||||
dst_ic8_ptr[4 + 9 * dst_step] = m21[1];
|
||||
dst_ic8_ptr[8 + 9 * dst_step] = m21[2];
|
||||
dst_ic8_ptr[12 + 9 * dst_step] = m21[3];
|
||||
dst_ic8_ptr[16 + 9 * dst_step] = m21[4];
|
||||
dst_ic8_ptr[20 + 9 * dst_step] = m21[5];
|
||||
dst_ic8_ptr[24 + 9 * dst_step] = m21[6];
|
||||
dst_ic8_ptr[28 + 9 * dst_step] = m21[7];
|
||||
|
||||
dst_ic8_ptr[0 + 10 * dst_step] = m22[0];
|
||||
dst_ic8_ptr[4 + 10 * dst_step] = m22[1];
|
||||
dst_ic8_ptr[8 + 10 * dst_step] = m22[2];
|
||||
dst_ic8_ptr[12 + 10 * dst_step] = m22[3];
|
||||
dst_ic8_ptr[16 + 10 * dst_step] = m22[4];
|
||||
dst_ic8_ptr[20 + 10 * dst_step] = m22[5];
|
||||
dst_ic8_ptr[24 + 10 * dst_step] = m22[6];
|
||||
dst_ic8_ptr[28 + 10 * dst_step] = m22[7];
|
||||
|
||||
dst_ic8_ptr[0 + 11 * dst_step] = m23[0];
|
||||
dst_ic8_ptr[4 + 11 * dst_step] = m23[1];
|
||||
dst_ic8_ptr[8 + 11 * dst_step] = m23[2];
|
||||
dst_ic8_ptr[12 + 11 * dst_step] = m23[3];
|
||||
dst_ic8_ptr[16 + 11 * dst_step] = m23[4];
|
||||
dst_ic8_ptr[20 + 11 * dst_step] = m23[5];
|
||||
dst_ic8_ptr[24 + 11 * dst_step] = m23[6];
|
||||
dst_ic8_ptr[28 + 11 * dst_step] = m23[7];
|
||||
|
||||
dst_ic8_ptr[0 + 12 * dst_step] = m30[0];
|
||||
dst_ic8_ptr[4 + 12 * dst_step] = m30[1];
|
||||
dst_ic8_ptr[8 + 12 * dst_step] = m30[2];
|
||||
dst_ic8_ptr[12 + 12 * dst_step] = m30[3];
|
||||
dst_ic8_ptr[16 + 12 * dst_step] = m30[4];
|
||||
dst_ic8_ptr[20 + 12 * dst_step] = m30[5];
|
||||
dst_ic8_ptr[24 + 12 * dst_step] = m30[6];
|
||||
dst_ic8_ptr[28 + 12 * dst_step] = m30[7];
|
||||
|
||||
dst_ic8_ptr[0 + 13 * dst_step] = m31[0];
|
||||
dst_ic8_ptr[4 + 13 * dst_step] = m31[1];
|
||||
dst_ic8_ptr[8 + 13 * dst_step] = m31[2];
|
||||
dst_ic8_ptr[12 + 13 * dst_step] = m31[3];
|
||||
dst_ic8_ptr[16 + 13 * dst_step] = m31[4];
|
||||
dst_ic8_ptr[20 + 13 * dst_step] = m31[5];
|
||||
dst_ic8_ptr[24 + 13 * dst_step] = m31[6];
|
||||
dst_ic8_ptr[28 + 13 * dst_step] = m31[7];
|
||||
|
||||
dst_ic8_ptr[0 + 14 * dst_step] = m32[0];
|
||||
dst_ic8_ptr[4 + 14 * dst_step] = m32[1];
|
||||
dst_ic8_ptr[8 + 14 * dst_step] = m32[2];
|
||||
dst_ic8_ptr[12 + 14 * dst_step] = m32[3];
|
||||
dst_ic8_ptr[16 + 14 * dst_step] = m32[4];
|
||||
dst_ic8_ptr[20 + 14 * dst_step] = m32[5];
|
||||
dst_ic8_ptr[24 + 14 * dst_step] = m32[6];
|
||||
dst_ic8_ptr[28 + 14 * dst_step] = m32[7];
|
||||
|
||||
dst_ic8_ptr[0 + 15 * dst_step] = m33[0];
|
||||
dst_ic8_ptr[4 + 15 * dst_step] = m33[1];
|
||||
dst_ic8_ptr[8 + 15 * dst_step] = m33[2];
|
||||
dst_ic8_ptr[12 + 15 * dst_step] = m33[3];
|
||||
dst_ic8_ptr[16 + 15 * dst_step] = m33[4];
|
||||
dst_ic8_ptr[20 + 15 * dst_step] = m33[5];
|
||||
dst_ic8_ptr[24 + 15 * dst_step] = m33[6];
|
||||
dst_ic8_ptr[28 + 15 * dst_step] = m33[7];
|
||||
#else
|
||||
for (int j = 0; j < C8NUM; j++) {
|
||||
const int16_t *local_ptr = src_ic8_ptr + j;
|
||||
int16_t dst00 = local_ptr[0] * 2;
|
||||
int16_t dst01 = (local_ptr + 8)[0] * 2;
|
||||
int16_t dst02 = (local_ptr + 16)[0] * 2;
|
||||
|
||||
int16_t dst10 = local_ptr[0] + (local_ptr + 24)[0] + (local_ptr + 48)[0];
|
||||
int16_t dst11 = (local_ptr + 8)[0] + (local_ptr + 32)[0] + (local_ptr + 56)[0];
|
||||
int16_t dst12 = (local_ptr + 16)[0] + (local_ptr + 40)[0] + (local_ptr + 64)[0];
|
||||
|
||||
int16_t dst20 = local_ptr[0] - (local_ptr + 24)[0] + (local_ptr + 48)[0];
|
||||
int16_t dst21 = (local_ptr + 8)[0] - (local_ptr + 32)[0] + (local_ptr + 56)[0];
|
||||
int16_t dst22 = (local_ptr + 16)[0] - (local_ptr + 40)[0] + (local_ptr + 64)[0];
|
||||
|
||||
int16_t dst30 = (local_ptr + 48)[0] * 2;
|
||||
int16_t dst31 = (local_ptr + 56)[0] * 2;
|
||||
int16_t dst32 = (local_ptr + 64)[0] * 2;
|
||||
|
||||
int16_t m00 = dst00 * 2;
|
||||
int16_t m01 = dst00 + dst01 + dst02;
|
||||
int16_t m02 = dst00 - dst01 + dst02;
|
||||
int16_t m03 = dst02 * 2;
|
||||
|
||||
int16_t m10 = dst10 * 2;
|
||||
int16_t m11 = dst10 + dst11 + dst12;
|
||||
int16_t m12 = dst10 - dst11 + dst12;
|
||||
int16_t m13 = dst12 * 2;
|
||||
|
||||
int16_t m20 = dst20 * 2;
|
||||
int16_t m21 = dst20 + dst21 + dst22;
|
||||
int16_t m22 = dst20 - dst21 + dst22;
|
||||
int16_t m23 = dst22 * 2;
|
||||
|
||||
int16_t m30 = dst30 * 2;
|
||||
int16_t m31 = dst30 + dst31 + dst32;
|
||||
int16_t m32 = dst30 - dst31 + dst32;
|
||||
int16_t m33 = dst32 * 2;
|
||||
|
||||
*(dst_ic8_ptr + j * 4) = m00;
|
||||
*(dst_ic8_ptr + j * 4 + dst_step) = m01;
|
||||
*(dst_ic8_ptr + j * 4 + 2 * dst_step) = m02;
|
||||
*(dst_ic8_ptr + j * 4 + 3 * dst_step) = m03;
|
||||
|
||||
*(dst_ic8_ptr + j * 4 + 4 * dst_step) = m10;
|
||||
*(dst_ic8_ptr + j * 4 + 5 * dst_step) = m11;
|
||||
*(dst_ic8_ptr + j * 4 + 6 * dst_step) = m12;
|
||||
*(dst_ic8_ptr + j * 4 + 7 * dst_step) = m13;
|
||||
|
||||
*(dst_ic8_ptr + j * 4 + 8 * dst_step) = m20;
|
||||
*(dst_ic8_ptr + j * 4 + 9 * dst_step) = m21;
|
||||
*(dst_ic8_ptr + j * 4 + 10 * dst_step) = m22;
|
||||
*(dst_ic8_ptr + j * 4 + 11 * dst_step) = m23;
|
||||
|
||||
*(dst_ic8_ptr + j * 4 + 12 * dst_step) = m30;
|
||||
*(dst_ic8_ptr + j * 4 + 13 * dst_step) = m31;
|
||||
*(dst_ic8_ptr + j * 4 + 14 * dst_step) = m32;
|
||||
*(dst_ic8_ptr + j * 4 + 15 * dst_step) = m33;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Output transform of the 3x3 int8 winograd convolution: folds one 4x4 tile of
// int32 GEMM accumulators into a 2x2 tile of requantized int8 outputs.
//
// gemm_out    : 16 groups of 4 int32 accumulators (64 values) for one tile of
//               one 4-channel slice; element (r, c) lives at gemm_out + (r*4+c)*4.
// bias_data   : 4 bias values, added to every transformed element.
// output_data : destination of the 2x2 output tile; the second output row is
//               output_w * 4 elements further on.
// h_not_bound : true when the second output row fits inside the output plane.
// w_not_bound : true when the second output column fits inside the output plane.
// output_w    : output width, used as the row stride (in 4-channel units).
// real_num    : number of valid channels in this slice (not used in this body).
// oc_start    : first output channel of this slice; indexes the per-channel
//               quantization parameters when FILTER_PER_CHANNEL is set.
// conv_param  : supplies the quantization arguments (shifts, multiplier,
//               output zero point, activation clamp range).
void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound,
                           bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) {
  int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_;
  int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_;
  int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
  int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
  int out_min = conv_param->conv_quant_arg_.out_act_min_[0];
  int out_max = conv_param->conv_quant_arg_.out_act_max_[0];

#ifdef ENABLE_ARM
  int32x4_t bias_ptr = vld1q_s32(bias_data);

  // Load the 4x4 accumulator tile; each s{r}{c} holds the 4 channels of element (r, c).
  int32x4_t s00 = vld1q_s32(gemm_out);
  int32x4_t s01 = vld1q_s32(gemm_out + 4);
  int32x4_t s02 = vld1q_s32(gemm_out + 8);
  int32x4_t s03 = vld1q_s32(gemm_out + 12);

  int32x4_t s10 = vld1q_s32(gemm_out + 16);
  int32x4_t s11 = vld1q_s32(gemm_out + 20);
  int32x4_t s12 = vld1q_s32(gemm_out + 24);
  int32x4_t s13 = vld1q_s32(gemm_out + 28);

  int32x4_t s20 = vld1q_s32(gemm_out + 32);
  int32x4_t s21 = vld1q_s32(gemm_out + 36);
  int32x4_t s22 = vld1q_s32(gemm_out + 40);
  int32x4_t s23 = vld1q_s32(gemm_out + 44);

  int32x4_t s30 = vld1q_s32(gemm_out + 48);
  int32x4_t s31 = vld1q_s32(gemm_out + 52);
  int32x4_t s32 = vld1q_s32(gemm_out + 56);
  int32x4_t s33 = vld1q_s32(gemm_out + 60);

  // Vertical transform: collapse 4 accumulator rows into 2 intermediate rows.
  // The arithmetic >> 1 plays the role of the scalar path's "/ 2".
  int32x4_t t00 = vshrq_n_s32(vaddq_s32(vaddq_s32(s00, s10), s20), 1);
  int32x4_t t01 = vshrq_n_s32(vaddq_s32(vaddq_s32(s01, s11), s21), 1);
  int32x4_t t02 = vshrq_n_s32(vaddq_s32(vaddq_s32(s02, s12), s22), 1);
  int32x4_t t03 = vshrq_n_s32(vaddq_s32(vaddq_s32(s03, s13), s23), 1);

  int32x4_t t10 = vshrq_n_s32(vsubq_s32(vsubq_s32(s10, s20), s30), 1);
  int32x4_t t11 = vshrq_n_s32(vsubq_s32(vsubq_s32(s11, s21), s31), 1);
  int32x4_t t12 = vshrq_n_s32(vsubq_s32(vsubq_s32(s12, s22), s32), 1);
  int32x4_t t13 = vshrq_n_s32(vsubq_s32(vsubq_s32(s13, s23), s33), 1);

  // Horizontal transform: collapse 4 columns into the 2 output columns and add bias.
  int32x4_t d00 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t00, t01), t02), 1), bias_ptr);
  int32x4_t d01 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t01, t02), t03), 1), bias_ptr);

  int32x4_t d10 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t10, t11), t12), 1), bias_ptr);
  int32x4_t d11 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t11, t12), t13), 1), bias_ptr);

  // Requantization parameters: per output channel when FILTER_PER_CHANNEL is
  // set, otherwise a single broadcast set of values.
  int32x4_t out_multiplier;
  int32x4_t ls;
  int32x4_t rs;
  if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
    out_multiplier = vld1q_s32(quant_multiplier + oc_start);
    ls = vld1q_s32(left_shift + oc_start);
    rs = vld1q_s32(right_shift + oc_start);
  } else {
    out_multiplier = vdupq_n_s32(quant_multiplier[0]);
    ls = vdupq_n_s32(left_shift[0]);
    rs = vdupq_n_s32(right_shift[0]);
  }
  int32x4_t out_zp = vdupq_n_s32(output_zp);
  int32x4_t output_min = vdupq_n_s32(out_min);
  int32x4_t output_max = vdupq_n_s32(out_max);

  // Requantize d00: saturating left shift, doubling-high multiply, then a
  // rounding right shift.  The carry term adjusts negative values before
  // vqrshlq so the result matches the scalar path's RoundingDivideByPOT
  // (NOTE(review): that equivalence is the apparent intent — confirm against
  // the scalar branch below).
  d00 = vqshlq_s32(d00, ls);
  d00 = vqrdmulhq_s32(d00, out_multiplier);
  int32x4_t carry = vandq_s32(d00, rs);
  carry = vshrq_n_s32(carry, 31);
  d00 = vqaddq_s32(d00, carry);
  d00 = vqrshlq_s32(d00, rs);
  d00 = vaddq_s32(d00, out_zp);
  d00 = vmaxq_s32(d00, output_min);
  d00 = vminq_s32(d00, output_max);

  // Same requantization sequence for the remaining three output elements.
  d01 = vqshlq_s32(d01, ls);
  d01 = vqrdmulhq_s32(d01, out_multiplier);
  carry = vandq_s32(d01, rs);
  carry = vshrq_n_s32(carry, 31);
  d01 = vqaddq_s32(d01, carry);
  d01 = vqrshlq_s32(d01, rs);
  d01 = vaddq_s32(d01, out_zp);
  d01 = vmaxq_s32(d01, output_min);
  d01 = vminq_s32(d01, output_max);

  d10 = vqshlq_s32(d10, ls);
  d10 = vqrdmulhq_s32(d10, out_multiplier);
  carry = vandq_s32(d10, rs);
  carry = vshrq_n_s32(carry, 31);
  d10 = vqaddq_s32(d10, carry);
  d10 = vqrshlq_s32(d10, rs);
  d10 = vaddq_s32(d10, out_zp);
  d10 = vmaxq_s32(d10, output_min);
  d10 = vminq_s32(d10, output_max);

  d11 = vqshlq_s32(d11, ls);
  d11 = vqrdmulhq_s32(d11, out_multiplier);
  carry = vandq_s32(d11, rs);
  carry = vshrq_n_s32(carry, 31);
  d11 = vqaddq_s32(d11, carry);
  d11 = vqrshlq_s32(d11, rs);
  d11 = vaddq_s32(d11, out_zp);
  d11 = vmaxq_s32(d11, output_min);
  d11 = vminq_s32(d11, output_max);

  // Store the 2x2 tile.  d01 / d10 / d11 are written only when the second
  // column / row actually lies inside the output plane.
  (output_data)[0] = (int8_t)d00[0];
  (output_data + 1)[0] = (int8_t)d00[1];
  (output_data + 2)[0] = (int8_t)d00[2];
  (output_data + 3)[0] = (int8_t)d00[3];

  if (w_not_bound) {
    *(output_data + 4) = (int8_t)d01[0];
    *(output_data + 5) = (int8_t)d01[1];
    *(output_data + 6) = (int8_t)d01[2];
    *(output_data + 7) = (int8_t)d01[3];
  }
  if (h_not_bound) {
    *(output_data + output_w * 4) = (int8_t)d10[0];
    *(output_data + output_w * 4 + 1) = (int8_t)d10[1];
    *(output_data + output_w * 4 + 2) = (int8_t)d10[2];
    *(output_data + output_w * 4 + 3) = (int8_t)d10[3];
    if (w_not_bound) {
      *(output_data + output_w * 4 + 4) = (int8_t)d11[0];
      *(output_data + output_w * 4 + 5) = (int8_t)d11[1];
      *(output_data + output_w * 4 + 6) = (int8_t)d11[2];
      *(output_data + output_w * 4 + 7) = (int8_t)d11[3];
    }
  }
#else
  // Scalar fallback: one channel at a time.  The two branches differ only in
  // whether the quantization parameters are indexed per channel or broadcast
  // from index 0.
  if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
    for (int i = 0; i < C4NUM; i++) {
      const int32_t *local_ptr = gemm_out + i;
      const int32_t *bias_ptr = bias_data + i;

      // Load this channel's 4x4 accumulator tile (stride 4 between elements).
      int32_t s00 = local_ptr[0];
      int32_t s01 = (local_ptr + 4)[0];
      int32_t s02 = (local_ptr + 8)[0];
      int32_t s03 = (local_ptr + 12)[0];

      int32_t s10 = (local_ptr + 16)[0];
      int32_t s11 = (local_ptr + 20)[0];
      int32_t s12 = (local_ptr + 24)[0];
      int32_t s13 = (local_ptr + 28)[0];

      int32_t s20 = (local_ptr + 32)[0];
      int32_t s21 = (local_ptr + 36)[0];
      int32_t s22 = (local_ptr + 40)[0];
      int32_t s23 = (local_ptr + 44)[0];

      int32_t s30 = (local_ptr + 48)[0];
      int32_t s31 = (local_ptr + 52)[0];
      int32_t s32 = (local_ptr + 56)[0];
      int32_t s33 = (local_ptr + 60)[0];

      // Vertical transform: 4 rows -> 2 intermediate rows.
      int32_t t00 = (s00 + s10 + s20) / 2;
      int32_t t01 = (s01 + s11 + s21) / 2;
      int32_t t02 = (s02 + s12 + s22) / 2;
      int32_t t03 = (s03 + s13 + s23) / 2;

      int32_t t10 = (s10 - s20 - s30) / 2;
      int32_t t11 = (s11 - s21 - s31) / 2;
      int32_t t12 = (s12 - s22 - s32) / 2;
      int32_t t13 = (s13 - s23 - s33) / 2;

      // Horizontal transform: 4 columns -> 2 output columns, plus bias.
      int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0];
      int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0];

      int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0];
      int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0];

      // Requantize with this channel's multiplier/shift, add the output zero
      // point, then clamp to the activation range.
      int oc_index = oc_start + i;
      d00 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
        -right_shift[oc_index]);
      d00 += output_zp;
      d00 = d00 > out_min ? d00 : out_min;
      d00 = d00 < out_max ? d00 : out_max;

      d01 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
        -right_shift[oc_index]);
      d01 += output_zp;
      d01 = d01 > out_min ? d01 : out_min;
      d01 = d01 < out_max ? d01 : out_max;

      d10 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
        -right_shift[oc_index]);
      d10 += output_zp;
      d10 = d10 > out_min ? d10 : out_min;
      d10 = d10 < out_max ? d10 : out_max;

      d11 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]),
        -right_shift[oc_index]);
      d11 += output_zp;
      d11 = d11 > out_min ? d11 : out_min;
      d11 = d11 < out_max ? d11 : out_max;

      // Store the 2x2 tile, guarding the second column/row against the edge.
      (output_data + i)[0] = (int8_t)d00;
      if (w_not_bound) {
        (output_data + i + C4NUM)[0] = (int8_t)d01;
      }
      if (h_not_bound) {
        (output_data + i + output_w * C4NUM)[0] = (int8_t)d10;
        if (w_not_bound) {
          (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11;
        }
      }
    }
  } else {
    // Per-tensor quantization: identical transform, parameters taken from index 0.
    for (int i = 0; i < C4NUM; i++) {
      const int32_t *local_ptr = gemm_out + i;
      const int32_t *bias_ptr = bias_data + i;

      int32_t s00 = local_ptr[0];
      int32_t s01 = (local_ptr + 4)[0];
      int32_t s02 = (local_ptr + 8)[0];
      int32_t s03 = (local_ptr + 12)[0];

      int32_t s10 = (local_ptr + 16)[0];
      int32_t s11 = (local_ptr + 20)[0];
      int32_t s12 = (local_ptr + 24)[0];
      int32_t s13 = (local_ptr + 28)[0];

      int32_t s20 = (local_ptr + 32)[0];
      int32_t s21 = (local_ptr + 36)[0];
      int32_t s22 = (local_ptr + 40)[0];
      int32_t s23 = (local_ptr + 44)[0];

      int32_t s30 = (local_ptr + 48)[0];
      int32_t s31 = (local_ptr + 52)[0];
      int32_t s32 = (local_ptr + 56)[0];
      int32_t s33 = (local_ptr + 60)[0];

      int32_t t00 = (s00 + s10 + s20) / 2;
      int32_t t01 = (s01 + s11 + s21) / 2;
      int32_t t02 = (s02 + s12 + s22) / 2;
      int32_t t03 = (s03 + s13 + s23) / 2;

      int32_t t10 = (s10 - s20 - s30) / 2;
      int32_t t11 = (s11 - s21 - s31) / 2;
      int32_t t12 = (s12 - s22 - s32) / 2;
      int32_t t13 = (s13 - s23 - s33) / 2;

      int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0];
      int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0];

      int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0];
      int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0];

      d00 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
        -right_shift[0]);
      d00 += output_zp;
      d00 = d00 > out_min ? d00 : out_min;
      d00 = d00 < out_max ? d00 : out_max;

      d01 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
        -right_shift[0]);
      d01 += output_zp;
      d01 = d01 > out_min ? d01 : out_min;
      d01 = d01 < out_max ? d01 : out_max;

      d10 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
        -right_shift[0]);
      d10 += output_zp;
      d10 = d10 > out_min ? d10 : out_min;
      d10 = d10 < out_max ? d10 : out_max;

      d11 = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]),
        -right_shift[0]);
      d11 += output_zp;
      d11 = d11 > out_min ? d11 : out_min;
      d11 = d11 < out_max ? d11 : out_max;

      (output_data + i)[0] = (int8_t)d00;
      if (w_not_bound) {
        (output_data + i + C4NUM)[0] = (int8_t)d01;
      }
      if (h_not_bound) {
        (output_data + i + output_w * C4NUM)[0] = (int8_t)d10;
        if (w_not_bound) {
          (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11;
        }
      }
    }
  }
#endif
}
|
||||
|
||||
// Scatters the winograd GEMM results back into the int8 output tensor.
// Each of the real_cal_num tiles (starting at linear tile index start_index,
// out_w_block tiles per row) is transformed per 4-channel slice by
// Conv3x3Int8OutputUnit, which also handles tiles that overhang the
// right/bottom edge of the output plane.
void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index,
                                int real_cal_num, int out_w_block, ConvParameter *conv_param) {
  const int channels = conv_param->output_channel_;
  const int width = conv_param->output_w_;
  const int height = conv_param->output_h_;
  const int oc4 = UP_DIV(channels, C4NUM);
  const int input_unit = 4;
  if (out_w_block == 0) {
    return;
  }
  for (int tile = 0; tile < real_cal_num; tile++) {
    const int tile_index = start_index + tile;
    const int w_idx = tile_index % out_w_block;
    const int h_idx = tile_index / out_w_block;
    const int src_tile_offset = tile * oc4 * C4NUM * input_unit * input_unit;
    const int dst_tile_offset = C4NUM * (w_idx * OUPUT_UNIT + h_idx * OUPUT_UNIT * width);
    // Whether the second column/row of this output tile lies inside the plane;
    // invariant across the channel-slice loop, so computed once per tile.
    const bool w_not_bound = w_idx * OUPUT_UNIT + 1 < width;
    const bool h_not_bound = h_idx * OUPUT_UNIT + 1 < height;

    for (int oc = 0; oc < oc4; oc++) {
      const int32_t *src_ptr = gemm_out + src_tile_offset + oc * input_unit * input_unit * C4NUM;
      const int32_t *bias_ptr = bias_data + oc * C4NUM;
      int8_t *dst_ptr = out_data + dst_tile_offset + oc * C4NUM * height * width;

      // Number of genuine channels in this 4-channel slice (last slice may be partial).
      const int remain = channels - oc * C4NUM;
      const int real_num = remain < C4NUM ? remain : C4NUM;
      Conv3x3Int8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, width, real_num, oc * C4NUM,
                            conv_param);
    }
  }
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue