diff --git a/mindspore/lite/nnacl/CMakeLists.txt b/mindspore/lite/nnacl/CMakeLists.txt index c854813bd59..cfcf2fa94d8 100644 --- a/mindspore/lite/nnacl/CMakeLists.txt +++ b/mindspore/lite/nnacl/CMakeLists.txt @@ -15,7 +15,7 @@ file(GLOB KERNEL_SRC ${NNACL_DIR}/*.c ${NNACL_DIR}/fp32/*.c ${NNACL_DIR}/int8/*.c - ${NNACL_DIR}/quantization/*.c + ${NNACL_DIR}/base/*.c ) if (SUPPORT_TRAIN) diff --git a/mindspore/lite/nnacl/arithmetic.h b/mindspore/lite/nnacl/arithmetic.h index 5b6babee17a..d4f6a11c4c7 100644 --- a/mindspore/lite/nnacl/arithmetic.h +++ b/mindspore/lite/nnacl/arithmetic.h @@ -42,12 +42,4 @@ typedef struct ArithmeticParameter { int multiples1_[10]; } ArithmeticParameter; -#ifdef __cplusplus -extern "C" { -#endif -void CalcMultiplesAndStrides(ArithmeticParameter *param); -#ifdef __cplusplus -} -#endif - #endif // MINDSPORE_LITE_NNACL_ARTITHMETIC_H_ diff --git a/mindspore/lite/nnacl/arithmetic_self_parameter.h b/mindspore/lite/nnacl/arithmetic_self_parameter.h index fdbb1c8dbc8..d98eb72613c 100644 --- a/mindspore/lite/nnacl/arithmetic_self_parameter.h +++ b/mindspore/lite/nnacl/arithmetic_self_parameter.h @@ -19,7 +19,7 @@ #include "nnacl/op_base.h" #include "nnacl/errorcode.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" // For Abs, Cos, Exp, Log, Square, Sqrt, Rsqrt ops. typedef struct ArithmeticSelfParameter { diff --git a/mindspore/lite/nnacl/arithmetic.c b/mindspore/lite/nnacl/base/arithmetic_base.c similarity index 96% rename from mindspore/lite/nnacl/arithmetic.c rename to mindspore/lite/nnacl/base/arithmetic_base.c index c595a376379..49cda36afee 100644 --- a/mindspore/lite/nnacl/arithmetic.c +++ b/mindspore/lite/nnacl/base/arithmetic_base.c @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "nnacl/arithmetic.h" +#include "nnacl/base/arithmetic_base.h" void CalcMultiplesAndStrides(ArithmeticParameter *param) { NNACL_ASSERT(param->in_shape0_[i] != 0); diff --git a/mindspore/lite/nnacl/base/arithmetic_base.h b/mindspore/lite/nnacl/base/arithmetic_base.h new file mode 100644 index 00000000000..3e77c944c60 --- /dev/null +++ b/mindspore/lite/nnacl/base/arithmetic_base.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
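CalcMultiplesAndStrides now lives under nnacl/base/ instead of nnacl/arithmetic.c; the hunks above only show the moved declaration and include changes, not the function body. As a rough, hedged sketch of what a broadcast-preparation helper of this kind computes (the loop body below is an assumption, not the nnacl implementation; only in_shape0_ and multiples1_ are visible in this patch, the remaining ArithmeticParameter field names are taken on trust from arithmetic.h, and ComputeStrides is assumed to be reachable via nnacl_common.h):

#include "nnacl/base/arithmetic_base.h"

/* Sketch only: derive per-dimension repeat factors and element strides for a
 * broadcasted binary op from the shapes stored in ArithmeticParameter. */
static void CalcMultiplesAndStridesSketch(ArithmeticParameter *param) {
  for (int i = 0; i < param->ndim_; i++) {
    NNACL_ASSERT(param->in_shape0_[i] != 0);
    NNACL_ASSERT(param->in_shape1_[i] != 0);
    /* how many copies of each input are needed along dimension i */
    param->multiples0_[i] = param->out_shape_[i] / param->in_shape0_[i];
    param->multiples1_[i] = param->out_shape_[i] / param->in_shape1_[i];
  }
  /* row-major strides for both inputs and the output */
  ComputeStrides(param->in_shape0_, param->in_strides0_, param->ndim_);
  ComputeStrides(param->in_shape1_, param->in_strides1_, param->ndim_);
  ComputeStrides(param->out_shape_, param->out_strides_, param->ndim_);
}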
+ */ + +#ifndef MINDSPORE_LITE_NNACL_BASE_ARITHMETIC_BASE_H_ +#define MINDSPORE_LITE_NNACL_BASE_ARITHMETIC_BASE_H_ + +#include "nnacl/arithmetic.h" +#include "nnacl/nnacl_utils.h" +#include "nnacl/nnacl_common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void CalcMultiplesAndStrides(ArithmeticParameter *param); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_BASE_ARITHMETIC_BASE_H_ diff --git a/mindspore/lite/nnacl/base/conv1x1_base.c b/mindspore/lite/nnacl/base/conv1x1_base.c new file mode 100644 index 00000000000..7898e735097 --- /dev/null +++ b/mindspore/lite/nnacl/base/conv1x1_base.c @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/conv1x1_base.h" + +void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size) { + /* support nhwc */ + char *src = (char *)src_ptr; + char *dst = (char *)dst_ptr; + for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) { + int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_u_; + if (src_h < 0 || src_h >= conv_param->input_h_) { + continue; + } + const char *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_ * data_size; + char *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_ * data_size; + for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) { + int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_l_; + if (src_w < 0 || src_w >= conv_param->input_w_) { + continue; + } + memcpy(dst_h_ptr + dst_w * conv_param->input_channel_ * data_size, + src_h_ptr + src_w * conv_param->input_channel_ * data_size, conv_param->input_channel_ * data_size); + } + } + return; +} diff --git a/mindspore/lite/nnacl/base/conv1x1_base.h b/mindspore/lite/nnacl/base/conv1x1_base.h new file mode 100644 index 00000000000..fc2b63d7b0f --- /dev/null +++ b/mindspore/lite/nnacl/base/conv1x1_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
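Conv1x1InputPack above is deliberately type-agnostic: it copies whole channel vectors with memcpy, so the caller passes the element size and the same routine serves fp32 and int8 tensors. Rows or columns that fall outside the padded input are simply skipped, so the destination should be zero-filled beforehand. A minimal fp32 usage sketch (every parameter value below is hypothetical):

#include <string.h>
#include "nnacl/base/conv1x1_base.h"

/* Sketch only: pack the strided input of a 1x1 convolution into a dense
 * output_h * output_w * input_channel buffer. */
void Conv1x1PackExample(const float *nhwc_input, float *packed, ConvParameter *p) {
  p->input_h_ = 16;  p->input_w_ = 16;  p->input_channel_ = 8;
  p->stride_h_ = 2;  p->stride_w_ = 2;  p->pad_u_ = 0;  p->pad_l_ = 0;
  p->output_h_ = 8;  p->output_w_ = 8;
  /* skipped (out-of-image) positions keep whatever was in the destination */
  memset(packed, 0, p->output_h_ * p->output_w_ * p->input_channel_ * sizeof(float));
  Conv1x1InputPack(nhwc_input, packed, p, sizeof(float));
}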
+ */ + +#ifndef MINDSPORE_LITE_NNACL_BASE_CONV1X1_BASE_H_ +#define MINDSPORE_LITE_NNACL_BASE_CONV1X1_BASE_H_ + +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_BASE_CONV1X1_BASE_H_ diff --git a/mindspore/lite/nnacl/depth_to_space.c b/mindspore/lite/nnacl/base/depth_to_space_base.c similarity index 97% rename from mindspore/lite/nnacl/depth_to_space.c rename to mindspore/lite/nnacl/base/depth_to_space_base.c index a41afb90159..e2b16837e44 100644 --- a/mindspore/lite/nnacl/depth_to_space.c +++ b/mindspore/lite/nnacl/base/depth_to_space_base.c @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "nnacl/depth_to_space.h" -#include + +#include "nnacl/base/depth_to_space_base.h" void DepthToSpaceForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceParameter *param) { int32_t block_size = param->block_size_; diff --git a/mindspore/lite/nnacl/depth_to_space.h b/mindspore/lite/nnacl/base/depth_to_space_base.h similarity index 97% rename from mindspore/lite/nnacl/depth_to_space.h rename to mindspore/lite/nnacl/base/depth_to_space_base.h index 25f2bf622c3..23474a4f44a 100644 --- a/mindspore/lite/nnacl/depth_to_space.h +++ b/mindspore/lite/nnacl/base/depth_to_space_base.h @@ -15,6 +15,8 @@ */ #ifndef MINDSPORE_LITE_NNACL_DEPTH_TO_SPACE_H_ #define MINDSPORE_LITE_NNACL_DEPTH_TO_SPACE_H_ + +#include #include "nnacl/depth_to_space_parameter.h" #ifdef __cplusplus diff --git a/mindspore/lite/nnacl/clip.h b/mindspore/lite/nnacl/clip.h index 9e04a7b429f..23eec2d9700 100644 --- a/mindspore/lite/nnacl/clip.h +++ b/mindspore/lite/nnacl/clip.h @@ -18,7 +18,7 @@ #include #include "nnacl/op_base.h" -#include "nnacl/quantization/fixed_point.h" +#include "mindspore/lite/nnacl/int8/fixed_point.h" typedef struct ClipParameter { OpParameter op_parameter_; diff --git a/mindspore/lite/nnacl/common_func.h b/mindspore/lite/nnacl/common_func.h index 1e6dc30d279..0ab52e62b77 100644 --- a/mindspore/lite/nnacl/common_func.h +++ b/mindspore/lite/nnacl/common_func.h @@ -17,11 +17,10 @@ #ifndef MINDSPORE_LITE_NNACL_COMMON_FUNC_H_ #define MINDSPORE_LITE_NNACL_COMMON_FUNC_H_ -#include -#include #include #include "nnacl/op_base.h" #include "nnacl/conv_parameter.h" +#include "nnacl/nnacl_common.h" #ifdef __cplusplus extern "C" { @@ -63,14 +62,6 @@ static inline int GetStride(int *strides, const int *shape, int length) { return stride; } -inline void ComputeStrides(const int *shape, int *strides, const int ndim) { - int stride = 1; - for (int i = ndim - 1; i >= 0; i--) { - strides[i] = stride; - stride *= shape[i]; - } -} - #ifdef ENABLE_ARM64 void BiasAdd(const float *bias, float *data, size_t oc4, size_t plan_size); void BiasAddRelu6(const float *bias, float *data, size_t oc4, size_t plan_size); diff --git a/mindspore/lite/nnacl/concat_parameter.h b/mindspore/lite/nnacl/concat_parameter.h index 30d07da7555..8b22e934687 100644 --- a/mindspore/lite/nnacl/concat_parameter.h +++ b/mindspore/lite/nnacl/concat_parameter.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_NNACL_CONCAT_PARAMETER_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" typedef struct ConcatParameter { OpParameter op_parameter_; diff --git a/mindspore/lite/nnacl/conv_parameter.h 
b/mindspore/lite/nnacl/conv_parameter.h index 3c314cfd1dd..4d1adc6a82c 100644 --- a/mindspore/lite/nnacl/conv_parameter.h +++ b/mindspore/lite/nnacl/conv_parameter.h @@ -21,7 +21,7 @@ #include #endif #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" typedef struct ConvParameter { OpParameter op_parameter_; diff --git a/mindspore/lite/nnacl/crop_parameter.h b/mindspore/lite/nnacl/crop_parameter.h index 9d0574b0ea7..c3c94224a35 100644 --- a/mindspore/lite/nnacl/crop_parameter.h +++ b/mindspore/lite/nnacl/crop_parameter.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_NNACL_CROP_PARAMETER_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #define CROP_OFFSET_MAX_SIZE 4 diff --git a/mindspore/lite/nnacl/fp16/activation_fp16.h b/mindspore/lite/nnacl/fp16/activation_fp16.h index d1ed088c4df..b4914aba681 100644 --- a/mindspore/lite/nnacl/fp16/activation_fp16.h +++ b/mindspore/lite/nnacl/fp16/activation_fp16.h @@ -21,7 +21,7 @@ #endif #include #include "nnacl/op_base.h" -#include "nnacl/quantization/fixed_point.h" +#include "mindspore/lite/nnacl/int8/fixed_point.h" typedef struct ActivationParameter { OpParameter op_parameter_; diff --git a/mindspore/lite/nnacl/fp16/arithmetic_fp16.h b/mindspore/lite/nnacl/fp16/arithmetic_fp16.h index 34a7ce96da2..840dfd2a856 100644 --- a/mindspore/lite/nnacl/fp16/arithmetic_fp16.h +++ b/mindspore/lite/nnacl/fp16/arithmetic_fp16.h @@ -20,7 +20,7 @@ #include #endif #include "nnacl/op_base.h" -#include "nnacl/arithmetic.h" +#include "nnacl/base/arithmetic_base.h" #include "nnacl/errorcode.h" #ifdef __cplusplus diff --git a/mindspore/lite/nnacl/fp16/pack_fp16.c b/mindspore/lite/nnacl/fp16/pack_fp16.c index 3baecff7cf5..67fe8d64557 100644 --- a/mindspore/lite/nnacl/fp16/pack_fp16.c +++ b/mindspore/lite/nnacl/fp16/pack_fp16.c @@ -16,7 +16,6 @@ #include "nnacl/fp16/pack_fp16.h" #include -#include void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num, int block_index) { diff --git a/mindspore/lite/nnacl/fp16/pooling_fp16.h b/mindspore/lite/nnacl/fp16/pooling_fp16.h index 5ae395f46e1..9dfd043ecc8 100644 --- a/mindspore/lite/nnacl/fp16/pooling_fp16.h +++ b/mindspore/lite/nnacl/fp16/pooling_fp16.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_LITE_NNACL_FP16_POOLING_FP16_H_ #define MINDSPORE_LITE_NNACL_FP16_POOLING_FP16_H_ +#include #ifdef ENABLE_NEON #include #endif diff --git a/mindspore/lite/nnacl/fp32/activation_fp32.h b/mindspore/lite/nnacl/fp32/activation_fp32.h index 999b04eb7a3..afae9869986 100644 --- a/mindspore/lite/nnacl/fp32/activation_fp32.h +++ b/mindspore/lite/nnacl/fp32/activation_fp32.h @@ -18,7 +18,7 @@ #include #include "nnacl/op_base.h" -#include "nnacl/quantization/fixed_point.h" +#include "mindspore/lite/nnacl/int8/fixed_point.h" typedef struct ActivationParameter { OpParameter op_parameter_; diff --git a/mindspore/lite/nnacl/fp32/arg_min_max_fp32.c b/mindspore/lite/nnacl/fp32/arg_min_max_fp32.c index 83dc83bbe5b..dcc03f1e1c5 100644 --- a/mindspore/lite/nnacl/fp32/arg_min_max_fp32.c +++ b/mindspore/lite/nnacl/fp32/arg_min_max_fp32.c @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + #include "nnacl/fp32/arg_min_max_fp32.h" -#include #include int ArgCompareAscFp32(const void *a, const void *b) { diff --git a/mindspore/lite/nnacl/fp32/arithmetic_fp32.h b/mindspore/lite/nnacl/fp32/arithmetic_fp32.h index 2d0496cfe50..d8bc67e8974 100644 --- a/mindspore/lite/nnacl/fp32/arithmetic_fp32.h +++ b/mindspore/lite/nnacl/fp32/arithmetic_fp32.h @@ -20,7 +20,7 @@ #include #endif #include "nnacl/op_base.h" -#include "nnacl/arithmetic.h" +#include "nnacl/base/arithmetic_base.h" #include "nnacl/errorcode.h" #ifdef __cplusplus diff --git a/mindspore/lite/nnacl/fp32/common_func_fp32.h b/mindspore/lite/nnacl/fp32/common_func_fp32.h index 898af91d645..81f73e1b8f6 100644 --- a/mindspore/lite/nnacl/fp32/common_func_fp32.h +++ b/mindspore/lite/nnacl/fp32/common_func_fp32.h @@ -17,8 +17,6 @@ #ifndef MINDSPORE_LITE_NNACL_FP32_COMMON_FUNC_H_ #define MINDSPORE_LITE_NNACL_FP32_COMMON_FUNC_H_ -#include -#include #include #include "nnacl/op_base.h" #include "nnacl/conv_parameter.h" diff --git a/mindspore/lite/nnacl/fp32/pack_fp32.c b/mindspore/lite/nnacl/fp32/pack_fp32.c new file mode 100644 index 00000000000..ff192c154a4 --- /dev/null +++ b/mindspore/lite/nnacl/fp32/pack_fp32.c @@ -0,0 +1,479 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl/fp32/pack_fp32.h" + +void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) { + return PackNCHWToNHWCFp32(src, dst, 1, plane, channel); +} + +void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel) { + for (int i = 0; i < height; ++i) { + for (int j = 0; j < width; ++j) { + memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float)); + } + } +} + +void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num, + int block_index) { + // input format : nhwc + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int kernel_plane = kernel_h * kernel_w; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int in_channel = conv_param->input_channel_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_; + int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_; + int input_stride = (input_h * in_w + input_w) * in_channel; + int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); + int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h)); + int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); + int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); + if (dilation_w == 1 && dilation_h == 1) { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * in_w * in_channel + input_stride; + int input_x_stride = input_y_stride + kw_s * in_channel; + int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, + (kw_e - kw_s) * in_channel * sizeof(float)); + } // kernel_h loop + } else { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; + for (int k = kw_s; k < kw_e; ++k) { + int input_x_stride = input_y_stride + k * dilation_w * in_channel; + int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float)); + } + } // kernel_h loop + } + } // tile num loop +} + +void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int c4_minus = c4 - 1; + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c4 * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C4NUM; + for (int j = 0; j < c4_minus; ++j) { + int src_ic_offset = src_kernel_offset + j * C4NUM; + int dst_ic_offset = dst_kernel_offset + j * plane * C4NUM; +#ifdef ENABLE_ARM + vst1q_f32((float *)dst + dst_ic_offset, vld1q_f32((float *)src + src_ic_offset)); +#else + for (int i = 0; i < C4NUM; ++i) { + ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i]; + } +#endif + } + int tmp_c = c4_minus * C4NUM; + int tmp_c_offset = tmp_c * plane; + int res_c = channel - tmp_c; + for (int l = 0; l < res_c; ++l) { + int src_ic_offset = src_kernel_offset + tmp_c + l; + int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l; + ((float *)dst + dst_ic_offset)[0] = ((float *)src + 
src_ic_offset)[0]; + } + } + } +} + +void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c4 * C4NUM; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_rem = c % C4NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem; + ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int c4_channel = c4 * C4NUM; + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc4_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + float *dst_per_plane = (float *)dst + nhwc4_batch_offset + i * c4_channel; + memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float)); + for (int j = channel; j < c4_channel; ++j) { + dst_per_plane[j] = 0; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy((float *)dst, (float *)src, ori_input_size); + } +} + +void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int c8_channel = c8 * C8NUM; + int nhwc8_batch_unit_offset = c8 * C8NUM * plane; + int ic_remainder_ = channel % C8NUM; + if (ic_remainder_ != 0) { + int nhwc8_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + float *dst_per_plane = (float *)dst + nhwc8_batch_offset + i * c8_channel; + memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float)); + for (int j = channel; j < c8_channel; ++j) { + dst_per_plane[j] = 0; + } + } + nhwc8_batch_offset += nhwc8_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy((float *)dst, (float *)src, ori_input_size); + } +} + +void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc_batch_unit_offset = channel * plane; + for (int b = 0; b < batch; b++) { + int batch_offset = b * c4 * C4NUM * plane; + for (int i = 0; i < plane; i++) { + memcpy((float *)dst + b * nhwc_batch_unit_offset + i * channel, (float *)src + batch_offset + i * c4 * C4NUM, + channel * sizeof(float)); + } + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy((float *)dst, (float *)src, ori_input_size); + } +} + +void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int 
dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM; + ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C4NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c4 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C4NUM; + int dst_c_offset = dst_kernel_offset + c * C4NUM; +#ifdef ENABLE_NEON + vst1q_f32((float *)dst + dst_c_offset, vld1q_f32((float *)src + src_c_offset)); +#else + ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0]; + ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1]; + ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2]; + ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3]; +#endif + } + // res part + int res_c = channel - (c4 - 1) * C4NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i; + ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0]; + } + } + } +} + +void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int hw = 0; hw < plane; hw++) { + for (int c = 0; c < channel; c++) { + int c8div = c / C8NUM; + int c8mod = c % C8NUM; + int src_index = n * plane * channel + hw * channel + c; + int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; + ((float *)dst)[dst_index] = ((float *)src)[src_index]; + } + } + } + return; +} + +void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int c = 0; c < c4; c++) { + int dst_off_c = c * C4NUM * height * width; + for (int i = 0; i < C4NUM; i++) { + int src_off_c = (c * C4NUM + i) * height * width; + for (int kh = 0; kh < height; kh++) { + int src_off_kh = src_off_c + kh * width; + for (int kw = 0; kw < width; kw++) { + int dst_off = dst_off_c + kw * height * C4NUM + kh * C4NUM + i; + ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw]; + } + } + } + } +} + +void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel) { + int c8 = UP_DIV(channel, C8NUM); + for (int c = 0; c < c8; c++) { + int dst_off_c = c * C8NUM * height * width; + for (int i = 0; i < C8NUM; i++) { + int src_off_c = (c * C8NUM + i) * height * width; + for (int kh = 0; kh < height; kh++) { + int src_off_kh = src_off_c + kh * width; + for (int kw = 0; kw < width; kw++) { + int dst_off = dst_off_c + kw * height * C8NUM + kh * C8NUM + i; + ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw]; + } + } + } + } +} + +#ifndef ENABLE_SSE +void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) { + int hw8 = plane / C8NUM * C8NUM; + int c8 = channel / C8NUM * C8NUM; + int batch = plane * channel; + for (int n = 0; n < batches; n++) { + const float *src_batch = (const float *)src + n * batch; + float *dst_batch = 
(float *)dst + n * batch; + int hw = 0; + for (; hw < hw8; hw += C8NUM) { + int c = 0; + for (; c < c8; c += C8NUM) { + const float *src_ptr = src_batch + hw * channel + c; + float *dst_ptr = dst_batch + c * plane + hw; +#ifdef ENABLE_ARM64 + size_t srcStride = channel * sizeof(float); + size_t dstStride = plane * sizeof(float); + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n" + "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n" + + "zip1 v8.4s, v0.4s, v2.4s\n" + "zip2 v9.4s, v0.4s, v2.4s\n" + "zip1 v12.4s, v1.4s, v3.4s\n" + "zip2 v13.4s, v1.4s, v3.4s\n" + + "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n" + "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n" + + "zip1 v10.4s, v4.4s, v6.4s\n" + "zip2 v11.4s, v4.4s, v6.4s\n" + "zip1 v14.4s, v5.4s, v7.4s\n" + "zip2 v15.4s, v5.4s, v7.4s\n" + + "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n" + "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n" + + "trn1 v16.2d, v8.2d, v10.2d\n" + "trn2 v18.2d, v8.2d, v10.2d\n" + "trn1 v20.2d, v9.2d, v11.2d\n" + "trn2 v22.2d, v9.2d, v11.2d\n" + + "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n" + "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n" + + "trn1 v24.2d, v12.2d, v14.2d\n" + "trn2 v26.2d, v12.2d, v14.2d\n" + "trn1 v28.2d, v13.2d, v15.2d\n" + "trn2 v30.2d, v13.2d, v15.2d\n" + + "zip1 v8.4s, v0.4s, v2.4s\n" + "zip2 v9.4s, v0.4s, v2.4s\n" + "zip1 v12.4s, v1.4s, v3.4s\n" + "zip2 v13.4s, v1.4s, v3.4s\n" + + "zip1 v10.4s, v4.4s, v6.4s\n" + "zip2 v11.4s, v4.4s, v6.4s\n" + "zip1 v14.4s, v5.4s, v7.4s\n" + "zip2 v15.4s, v5.4s, v7.4s\n" + + "trn1 v17.2d, v8.2d, v10.2d\n" + "trn2 v19.2d, v8.2d, v10.2d\n" + "trn1 v21.2d, v9.2d, v11.2d\n" + "trn2 v23.2d, v9.2d, v11.2d\n" + + "trn1 v25.2d, v12.2d, v14.2d\n" + "trn2 v27.2d, v12.2d, v14.2d\n" + "trn1 v29.2d, v13.2d, v15.2d\n" + "trn2 v31.2d, v13.2d, v15.2d\n" + + "st1 {v16.4s, v17.4s}, [x11], %[dstStride]\n" + "st1 {v18.4s, v19.4s}, [x11], %[dstStride]\n" + "st1 {v20.4s, v21.4s}, [x11], %[dstStride]\n" + "st1 {v22.4s, v23.4s}, [x11], %[dstStride]\n" + "st1 {v24.4s, v25.4s}, [x11], %[dstStride]\n" + "st1 {v26.4s, v27.4s}, [x11], %[dstStride]\n" + "st1 {v28.4s, v29.4s}, [x11], %[dstStride]\n" + "st1 {v30.4s, v31.4s}, [x11], %[dstStride]\n" + + : + : + [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31"); +#elif ENABLE_ARM32 + size_t srcStride = channel * sizeof(float); + size_t dstStride = plane * sizeof(float); + asm volatile( + "mov r10, %[src_ptr]\n" + "mov r12, %[dst_ptr]\n" + + "vld1.32 {q0, q1}, [r10], %[srcStride]\n" + "vld1.32 {q2, q3}, [r10], %[srcStride]\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + + "vld1.32 {q4, q5}, [r10], %[srcStride]\n" + "vld1.32 {q6, q7}, [r10], %[srcStride]\n" + + "vtrn.32 d8, d12\n" + "vtrn.32 d9, d13\n" + "vtrn.32 d10, d14\n" + "vtrn.32 d11, d15\n" + + "vld1.32 {q8, q9}, [r10], %[srcStride]\n" + "vld1.32 {q10, q11}, [r10], %[srcStride]\n" + + "vswp d1, d8\n" + "vswp d3, d10\n" + "vswp d5, d12\n" + "vswp d7, d14\n" + + "vtrn.32 d16, d20\n" + "vtrn.32 d17, d21\n" + "vtrn.32 d18, d22\n" + "vtrn.32 d19, d23\n" + + "vld1.32 {q12, q13}, [r10], %[srcStride]\n" + "vld1.32 {q14, q15}, [r10], %[srcStride]\n" + + "vtrn.32 d24, d28\n" + "vtrn.32 d25, d29\n" + "vtrn.32 d26, d30\n" + 
"vtrn.32 d27, d31\n" + + "vswp d17, d24\n" + "vswp d19, d26\n" + "vswp d21, d28\n" + "vswp d23, d30\n" + + "add r10, r12, #16\n" + "vst1.32 {q0}, [r12], %[dstStride]\n" + "vst1.32 {q8}, [r10], %[dstStride]\n" + "vst1.32 {q2}, [r12], %[dstStride]\n" + "vst1.32 {q10}, [r10], %[dstStride]\n" + "vst1.32 {q4}, [r12], %[dstStride]\n" + "vst1.32 {q12}, [r10], %[dstStride]\n" + "vst1.32 {q6}, [r12], %[dstStride]\n" + "vst1.32 {q14}, [r10], %[dstStride]\n" + "vst1.32 {q1}, [r12], %[dstStride]\n" + "vst1.32 {q9}, [r10], %[dstStride]\n" + "vst1.32 {q3}, [r12], %[dstStride]\n" + "vst1.32 {q11}, [r10], %[dstStride]\n" + "vst1.32 {q5}, [r12], %[dstStride]\n" + "vst1.32 {q13}, [r10], %[dstStride]\n" + "vst1.32 {q7}, [r12], %[dstStride]\n" + "vst1.32 {q15}, [r10], %[dstStride]\n" + + : + : + [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15"); +#else + for (int tr = 0; tr < C8NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; + } + } +#endif + } + for (; c < channel; c++) { + const float *src_ptr = src_batch + hw * channel + c; + float *dst_ptr = dst_batch + c * plane + hw; + for (size_t i = 0; i < C8NUM; i++) { + dst_ptr[i] = src_ptr[i * channel]; + } + } + } + for (; hw < plane; hw++) { + const float *src_ptr = src_batch + hw * channel; + float *dst_ptr = dst_batch + hw; + for (size_t i = 0; i < channel; i++) { + dst_ptr[i * plane] = src_ptr[i]; + } + } + } + return; +} +#endif + +void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { + return PackNHWCToNCHWFp32(src, dst, batch, channel, plane); +} diff --git a/mindspore/lite/nnacl/fp32/pack_fp32.h b/mindspore/lite/nnacl/fp32/pack_fp32.h new file mode 100644 index 00000000000..7bca84490fc --- /dev/null +++ b/mindspore/lite/nnacl/fp32/pack_fp32.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_LITE_NNACL_FP32_PACK_H_
+#define MINDSPORE_LITE_NNACL_FP32_PACK_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl/conv_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel);
+void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel);
+
+void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel);
+void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel);
+void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel);
+void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num,
+                        int block_index);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MINDSPORE_LITE_NNACL_FP32_PACK_H_
diff --git a/mindspore/lite/nnacl/fp32/pooling_fp32.h b/mindspore/lite/nnacl/fp32/pooling_fp32.h
index c897e2636bd..96c712bd776 100644
--- a/mindspore/lite/nnacl/fp32/pooling_fp32.h
+++ b/mindspore/lite/nnacl/fp32/pooling_fp32.h
@@ -22,7 +22,7 @@
 #endif
 #include "nnacl/op_base.h"
 #include "nnacl/pooling_parameter.h"
-#include "nnacl/quantization/quantize.h"
+#include "mindspore/lite/nnacl/int8/quantize.h"
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/mindspore/lite/nnacl/int8/add_int8.c b/mindspore/lite/nnacl/int8/add_int8.c
index 7b0c33067e7..aacbf42c25f 100644
--- a/mindspore/lite/nnacl/int8/add_int8.c
+++ b/mindspore/lite/nnacl/int8/add_int8.c
@@ -18,7 +18,7 @@
 #ifdef ENABLE_NEON
 #include <arm_neon.h>
 #endif
-#include "nnacl/quantization/fixed_point.h"
+#include "nnacl/int8/fixed_point.h"
 
 void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params) {
   int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_);
diff --git a/mindspore/lite/nnacl/int8/arg_min_max_int8.h b/mindspore/lite/nnacl/int8/arg_min_max_int8.h
index 7827fce9089..26854cdd346 100644
--- a/mindspore/lite/nnacl/int8/arg_min_max_int8.h
+++ b/mindspore/lite/nnacl/int8/arg_min_max_int8.h
@@ -17,7 +17,7 @@
 #define MINDSPORE_LITE_NNACL_INT8_ARG_MIN_MAX_INT8_H_
 
 #include "nnacl/arg_min_max_parameter.h"
-#include "nnacl/quantization/quantize.h"
+#include "nnacl/int8/quantize.h"
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/mindspore/lite/nnacl/int8/arithmetic_int8.h b/mindspore/lite/nnacl/int8/arithmetic_int8.h
index 98f3f27ac4f..ec9d5c6fa18 100644
--- a/mindspore/lite/nnacl/int8/arithmetic_int8.h
+++ b/mindspore/lite/nnacl/int8/arithmetic_int8.h
@@ -17,8 +17,8 @@
 #define MINDSPORE_LITE_NNACL_INT8_ARITHMETIC_INT8_H_
 
 #include "nnacl/op_base.h"
-#include "nnacl/arithmetic.h"
-#include

"nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" +#include "nnacl/base/arithmetic_base.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/int8/arithmetic_self_int8.c b/mindspore/lite/nnacl/int8/arithmetic_self_int8.c index d4148dc84c8..5a737b760d7 100644 --- a/mindspore/lite/nnacl/int8/arithmetic_self_int8.c +++ b/mindspore/lite/nnacl/int8/arithmetic_self_int8.c @@ -21,7 +21,7 @@ #include #include "nnacl/int8/common_func_int8.h" #endif -#include "nnacl/quantization/fixed_point.h" +#include "nnacl/int8/fixed_point.h" int Int8ElementFloor(int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { float in_scale = para.in_args_.scale_; diff --git a/mindspore/lite/nnacl/int8/arithmetic_self_int8.h b/mindspore/lite/nnacl/int8/arithmetic_self_int8.h index e792443d434..78ad1e00322 100644 --- a/mindspore/lite/nnacl/int8/arithmetic_self_int8.h +++ b/mindspore/lite/nnacl/int8/arithmetic_self_int8.h @@ -22,7 +22,7 @@ #endif #include "nnacl/op_base.h" #include "nnacl/errorcode.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/int8/batch_to_space_int8.h b/mindspore/lite/nnacl/int8/batch_to_space_int8.h index ce2b3f3a85f..a2e0a18cbc3 100644 --- a/mindspore/lite/nnacl/int8/batch_to_space_int8.h +++ b/mindspore/lite/nnacl/int8/batch_to_space_int8.h @@ -16,7 +16,7 @@ #ifndef MINDSPORE_LITE_NNACL_INT8_BATCH_TO_SPACE_INT8_H_ #define MINDSPORE_LITE_NNACL_INT8_BATCH_TO_SPACE_INT8_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/int8/common_func_int8.c b/mindspore/lite/nnacl/int8/common_func_int8.c index 816ddfc896b..fa6eb486d84 100644 --- a/mindspore/lite/nnacl/int8/common_func_int8.c +++ b/mindspore/lite/nnacl/int8/common_func_int8.c @@ -15,7 +15,7 @@ */ #include "nnacl/int8/common_func_int8.h" -#include "nnacl/quantization/fixed_point.h" +#include "nnacl/int8/fixed_point.h" void PostConvFuncCommInt8(const int32_t *in, int8_t *out, const int32_t *bias, size_t oc, size_t plane, size_t out_oc_stride, size_t in_plane_stride, int32_t multiplier, int32_t mini, int32_t maxi, diff --git a/mindspore/lite/nnacl/int8/common_func_int8.h b/mindspore/lite/nnacl/int8/common_func_int8.h index cd3ed70b02c..ae0e0c5ceda 100644 --- a/mindspore/lite/nnacl/int8/common_func_int8.h +++ b/mindspore/lite/nnacl/int8/common_func_int8.h @@ -17,8 +17,6 @@ #ifndef MINDSPORE_LITE_NNACL_INT8_COMMON_FUNC_H_ #define MINDSPORE_LITE_NNACL_INT8_COMMON_FUNC_H_ -#include -#include #include #ifdef ENABLE_NEON #include @@ -29,9 +27,6 @@ #ifdef __cplusplus extern "C" { #endif - -void PostFuncInt8C8(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, int32_t multiplier, - int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, int32_t maxi); void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, int32_t maxi); diff --git a/mindspore/lite/nnacl/int8/conv1x1_int8.c b/mindspore/lite/nnacl/int8/conv1x1_int8.c new file mode 100644 index 00000000000..e3d6840d2b3 --- /dev/null +++ b/mindspore/lite/nnacl/int8/conv1x1_int8.c @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance 
with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/conv1x1_int8.h" + +void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int *filter_zp) { + int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; + matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias, + left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc, + filter_zp); + return; +} + +void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp) { + int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; + MatmulInt8Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, + conv_param->output_channel_, is_per_oc, filter_zp); + return; +} diff --git a/mindspore/lite/nnacl/int8/conv1x1_int8.h b/mindspore/lite/nnacl/int8/conv1x1_int8.h new file mode 100644 index 00000000000..ec2ef268f51 --- /dev/null +++ b/mindspore/lite/nnacl/int8/conv1x1_int8.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
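Both wrappers above only decide between per-tensor and per-output-channel quantization (filter_arg_num_ != 1) and then forward to the int8 matmul kernels; the packed-buffer geometry is only implied by the parameter names. A hedged sketch of how a caller might size those buffers (the 16-element depth alignment is an assumption read off the deep16 name, not something this patch shows):

#include "nnacl/int8/conv1x1_int8.h"

/* Sketch only: matmul view of a 1x1 convolution. Values are derived from
 * ConvParameter fields; the C16NUM alignment is an assumption. */
void Conv1x1Int8Sizes(const ConvParameter *p, int *row, int *col, int *deep16) {
  *row = p->output_h_ * p->output_w_;             /* one output pixel per matmul row */
  *col = p->output_channel_;                      /* one output channel per column   */
  *deep16 = UP_ROUND(p->input_channel_, C16NUM);  /* packed input depth              */
}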
+ */ +#ifndef MINDSPORE_LITE_NNACL_INT8_CONV1X1_INT8_H_ +#define MINDSPORE_LITE_NNACL_INT8_CONV1X1_INT8_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/pack.h" +#include "nnacl/op_base.h" +#include "nnacl/common_func.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/int8/quantize.h" +#include "nnacl/matmul_parameter.h" +#include "nnacl/int8/matmul_int8.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp); +void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int32_t *filter_zp); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_INT8_CONV1X1_INT8_H_ diff --git a/mindspore/lite/nnacl/int8/conv3x3_int8.c b/mindspore/lite/nnacl/int8/conv3x3_int8.c new file mode 100644 index 00000000000..90c36cc92f4 --- /dev/null +++ b/mindspore/lite/nnacl/int8/conv3x3_int8.c @@ -0,0 +1,900 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
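The header above exposes both entry points: Conv1x1Int8 always calls MatmulInt8Opt with a 16-aligned depth, while Conv1x1Int8Opt takes a caller-supplied MATMUL_OPT_DP_FUNC (for instance a dot-product-instruction kernel) and a 4-aligned depth. A hedged dispatch sketch; support_opt and every buffer argument here are hypothetical:

#include <stdbool.h>
#include "nnacl/int8/conv1x1_int8.h"

/* Sketch only: pick the DP variant when a specialised matmul kernel is available. */
void Conv1x1Int8Dispatch(bool support_opt, MATMUL_OPT_DP_FUNC matmul_func, const int8_t *in,
                         const int8_t *weight, int8_t *out, const int32_t *input_sum,
                         const int32_t *bias, int row, int col, int deep4, int deep16,
                         int32_t *left_shift, int32_t *right_shift, int32_t *multiplier,
                         ConvParameter *conv_param, int32_t *filter_zp) {
  if (support_opt) {
    Conv1x1Int8Opt(in, weight, out, input_sum, bias, row, col, deep4, left_shift, right_shift,
                   multiplier, conv_param, matmul_func, filter_zp);
  } else {
    Conv1x1Int8(in, weight, out, input_sum, bias, row, col, deep16, left_shift, right_shift,
                multiplier, conv_param, filter_zp);
  }
}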
+ */ + +#include "nnacl/int8/conv3x3_int8.h" + +void Conv3x3Int8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) { +#ifdef ENABLE_ARM + int16x8_t zp = vdupq_n_s16(input_zp); + + int16x8_t d00 = vsubq_s16(vld1q_s16(tmp_data), zp); + int16x8_t d01 = vsubq_s16(vld1q_s16(tmp_data + 8), zp); + int16x8_t d02 = vsubq_s16(vld1q_s16(tmp_data + 2 * 8), zp); + int16x8_t d03 = vsubq_s16(vld1q_s16(tmp_data + 3 * 8), zp); + + int16x8_t d10 = vsubq_s16(vld1q_s16(tmp_data + 4 * 8), zp); + int16x8_t d11 = vsubq_s16(vld1q_s16(tmp_data + 5 * 8), zp); + int16x8_t d12 = vsubq_s16(vld1q_s16(tmp_data + 6 * 8), zp); + int16x8_t d13 = vsubq_s16(vld1q_s16(tmp_data + 7 * 8), zp); + + int16x8_t d20 = vsubq_s16(vld1q_s16(tmp_data + 8 * 8), zp); + int16x8_t d21 = vsubq_s16(vld1q_s16(tmp_data + 9 * 8), zp); + int16x8_t d22 = vsubq_s16(vld1q_s16(tmp_data + 10 * 8), zp); + int16x8_t d23 = vsubq_s16(vld1q_s16(tmp_data + 11 * 8), zp); + + int16x8_t d30 = vsubq_s16(vld1q_s16(tmp_data + 12 * 8), zp); + int16x8_t d31 = vsubq_s16(vld1q_s16(tmp_data + 13 * 8), zp); + int16x8_t d32 = vsubq_s16(vld1q_s16(tmp_data + 14 * 8), zp); + int16x8_t d33 = vsubq_s16(vld1q_s16(tmp_data + 15 * 8), zp); + + int16x8_t t00 = vsubq_s16(d00, d20); + int16x8_t t01 = vsubq_s16(d01, d21); + int16x8_t t02 = vsubq_s16(d02, d22); + int16x8_t t03 = vsubq_s16(d03, d23); + + int16x8_t t10 = vaddq_s16(d10, d20); + int16x8_t t11 = vaddq_s16(d11, d21); + int16x8_t t12 = vaddq_s16(d12, d22); + int16x8_t t13 = vaddq_s16(d13, d23); + + int16x8_t t20 = vsubq_s16(d20, d10); + int16x8_t t21 = vsubq_s16(d21, d11); + int16x8_t t22 = vsubq_s16(d22, d12); + int16x8_t t23 = vsubq_s16(d23, d13); + + int16x8_t t30 = vsubq_s16(d10, d30); + int16x8_t t31 = vsubq_s16(d11, d31); + int16x8_t t32 = vsubq_s16(d12, d32); + int16x8_t t33 = vsubq_s16(d13, d33); + + int16x8_t m00 = vsubq_s16(t00, t02); + int16x8_t m01 = vaddq_s16(t01, t02); + int16x8_t m02 = vsubq_s16(t02, t01); + int16x8_t m03 = vsubq_s16(t01, t03); + + int16x8_t m10 = vsubq_s16(t10, t12); + int16x8_t m11 = vaddq_s16(t11, t12); + int16x8_t m12 = vsubq_s16(t12, t11); + int16x8_t m13 = vsubq_s16(t11, t13); + + int16x8_t m20 = vsubq_s16(t20, t22); + int16x8_t m21 = vaddq_s16(t21, t22); + int16x8_t m22 = vsubq_s16(t22, t21); + int16x8_t m23 = vsubq_s16(t21, t23); + + int16x8_t m30 = vsubq_s16(t30, t32); + int16x8_t m31 = vaddq_s16(t31, t32); + int16x8_t m32 = vsubq_s16(t32, t31); + int16x8_t m33 = vsubq_s16(t31, t33); + + vst1q_s16(trans_input_data, m00); + vst1q_s16(trans_input_data + step, m01); + vst1q_s16(trans_input_data + 2 * step, m02); + vst1q_s16(trans_input_data + 3 * step, m03); + + vst1q_s16(trans_input_data + 4 * step, m10); + vst1q_s16(trans_input_data + 5 * step, m11); + vst1q_s16(trans_input_data + 6 * step, m12); + vst1q_s16(trans_input_data + 7 * step, m13); + + vst1q_s16(trans_input_data + 8 * step, m20); + vst1q_s16(trans_input_data + 9 * step, m21); + vst1q_s16(trans_input_data + 10 * step, m22); + vst1q_s16(trans_input_data + 11 * step, m23); + + vst1q_s16(trans_input_data + 12 * step, m30); + vst1q_s16(trans_input_data + 13 * step, m31); + vst1q_s16(trans_input_data + 14 * step, m32); + vst1q_s16(trans_input_data + 15 * step, m33); +#else + for (int i = 0; i < C8NUM; i++) { + int16_t *local_ptr = tmp_data + i; + int16_t d00 = local_ptr[0] - input_zp; + int16_t d01 = (local_ptr + C8NUM)[0] - input_zp; + int16_t d02 = (local_ptr + 2 * C8NUM)[0] - input_zp; + int16_t d03 = (local_ptr + 3 * C8NUM)[0] - input_zp; + + int16_t d10 = (local_ptr + 4 * C8NUM)[0] - 
input_zp; + int16_t d11 = (local_ptr + 5 * C8NUM)[0] - input_zp; + int16_t d12 = (local_ptr + 6 * C8NUM)[0] - input_zp; + int16_t d13 = (local_ptr + 7 * C8NUM)[0] - input_zp; + + int16_t d20 = (local_ptr + 8 * C8NUM)[0] - input_zp; + int16_t d21 = (local_ptr + 9 * C8NUM)[0] - input_zp; + int16_t d22 = (local_ptr + 10 * C8NUM)[0] - input_zp; + int16_t d23 = (local_ptr + 11 * C8NUM)[0] - input_zp; + + int16_t d30 = (local_ptr + 12 * C8NUM)[0] - input_zp; + int16_t d31 = (local_ptr + 13 * C8NUM)[0] - input_zp; + int16_t d32 = (local_ptr + 14 * C8NUM)[0] - input_zp; + int16_t d33 = (local_ptr + 15 * C8NUM)[0] - input_zp; + + int16_t t00 = d00 - d20; + int16_t t01 = d01 - d21; + int16_t t02 = d02 - d22; + int16_t t03 = d03 - d23; + + int16_t t10 = d10 + d20; + int16_t t11 = d11 + d21; + int16_t t12 = d12 + d22; + int16_t t13 = d13 + d23; + + int16_t t20 = d20 - d10; + int16_t t21 = d21 - d11; + int16_t t22 = d22 - d12; + int16_t t23 = d23 - d13; + + int16_t t30 = d10 - d30; + int16_t t31 = d11 - d31; + int16_t t32 = d12 - d32; + int16_t t33 = d13 - d33; + + int16_t m00 = t00 - t02; + int16_t m01 = t01 + t02; + int16_t m02 = t02 - t01; + int16_t m03 = t01 - t03; + + int16_t m10 = t10 - t12; + int16_t m11 = t11 + t12; + int16_t m12 = t12 - t11; + int16_t m13 = t11 - t13; + + int16_t m20 = t20 - t22; + int16_t m21 = t21 + t22; + int16_t m22 = t22 - t21; + int16_t m23 = t21 - t23; + + int16_t m30 = t30 - t32; + int16_t m31 = t31 + t32; + int16_t m32 = t32 - t31; + int16_t m33 = t31 - t33; + + (trans_input_data + i)[0] = m00; + (trans_input_data + i + step)[0] = m01; + (trans_input_data + i + 2 * step)[0] = m02; + (trans_input_data + i + 3 * step)[0] = m03; + + (trans_input_data + i + 4 * step)[0] = m10; + (trans_input_data + i + 5 * step)[0] = m11; + (trans_input_data + i + 6 * step)[0] = m12; + (trans_input_data + i + 7 * step)[0] = m13; + + (trans_input_data + i + 8 * step)[0] = m20; + (trans_input_data + i + 9 * step)[0] = m21; + (trans_input_data + i + 10 * step)[0] = m22; + (trans_input_data + i + 11 * step)[0] = m23; + + (trans_input_data + i + 12 * step)[0] = m30; + (trans_input_data + i + 13 * step)[0] = m31; + (trans_input_data + i + 14 * step)[0] = m32; + (trans_input_data + i + 15 * step)[0] = m33; + } +#endif +} + +void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel, + int kernel_plane) { + const int input_unit = 4; + int dst_step = iC8 * C8NUM * C4NUM; + for (int o = 0; o < output_channel; o++) { + int oc4_block_num = o / C4NUM; + int oc4_block_rem = o % C4NUM; + int src_oc_offset = o * iC8 * C8NUM * kernel_plane; + int dst_oc_offset = oc4_block_num * C4NUM * iC8 * C8NUM * input_unit * input_unit + oc4_block_rem; + for (int i = 0; i < iC8; i++) { + const int16_t *src_ic8_ptr = weight_data + src_oc_offset + i * kernel_plane * C8NUM; + int16_t *dst_ic8_ptr = trans_weight + dst_oc_offset + i * C4NUM * C8NUM; +#ifdef ENABLE_ARM + int16x8_t g00 = vld1q_s16(src_ic8_ptr); + int16x8_t g01 = vld1q_s16(src_ic8_ptr + 8); + int16x8_t g02 = vld1q_s16(src_ic8_ptr + 2 * 8); + int16x8_t g10 = vld1q_s16(src_ic8_ptr + 3 * 8); + int16x8_t g11 = vld1q_s16(src_ic8_ptr + 4 * 8); + int16x8_t g12 = vld1q_s16(src_ic8_ptr + 5 * 8); + int16x8_t g20 = vld1q_s16(src_ic8_ptr + 6 * 8); + int16x8_t g21 = vld1q_s16(src_ic8_ptr + 7 * 8); + int16x8_t g22 = vld1q_s16(src_ic8_ptr + 8 * 8); + + int16x8_t dst00 = vmulq_n_s16(g00, 2); + int16x8_t dst01 = vmulq_n_s16(g01, 2); + int16x8_t dst02 = vmulq_n_s16(g02, 2); + + int16x8_t dst10 = vaddq_s16(vaddq_s16(g00, g10), 
g20); + int16x8_t dst11 = vaddq_s16(vaddq_s16(g01, g11), g21); + int16x8_t dst12 = vaddq_s16(vaddq_s16(g02, g12), g22); + + int16x8_t dst20 = vaddq_s16(vsubq_s16(g00, g10), g20); + int16x8_t dst21 = vaddq_s16(vsubq_s16(g01, g11), g21); + int16x8_t dst22 = vaddq_s16(vsubq_s16(g02, g12), g22); + + int16x8_t dst30 = vmulq_n_s16(g20, 2); + int16x8_t dst31 = vmulq_n_s16(g21, 2); + int16x8_t dst32 = vmulq_n_s16(g22, 2); + + int16x8_t m00 = vmulq_n_s16(dst00, 2); + int16x8_t m01 = vaddq_s16(vaddq_s16(dst00, dst01), dst02); + int16x8_t m02 = vaddq_s16(vsubq_s16(dst00, dst01), dst02); + int16x8_t m03 = vmulq_n_s16(dst02, 2); + + int16x8_t m10 = vmulq_n_s16(dst10, 2); + int16x8_t m11 = vaddq_s16(vaddq_s16(dst10, dst11), dst12); + int16x8_t m12 = vaddq_s16(vsubq_s16(dst10, dst11), dst12); + int16x8_t m13 = vmulq_n_s16(dst12, 2); + + int16x8_t m20 = vmulq_n_s16(dst20, 2); + int16x8_t m21 = vaddq_s16(vaddq_s16(dst20, dst21), dst22); + int16x8_t m22 = vaddq_s16(vsubq_s16(dst20, dst21), dst22); + int16x8_t m23 = vmulq_n_s16(dst22, 2); + + int16x8_t m30 = vmulq_n_s16(dst30, 2); + int16x8_t m31 = vaddq_s16(vaddq_s16(dst30, dst31), dst32); + int16x8_t m32 = vaddq_s16(vsubq_s16(dst30, dst31), dst32); + int16x8_t m33 = vmulq_n_s16(dst32, 2); + + dst_ic8_ptr[0] = m00[0]; + dst_ic8_ptr[4] = m00[1]; + dst_ic8_ptr[8] = m00[2]; + dst_ic8_ptr[12] = m00[3]; + dst_ic8_ptr[16] = m00[4]; + dst_ic8_ptr[20] = m00[5]; + dst_ic8_ptr[24] = m00[6]; + dst_ic8_ptr[28] = m00[7]; + + dst_ic8_ptr[0 + dst_step] = m01[0]; + dst_ic8_ptr[4 + dst_step] = m01[1]; + dst_ic8_ptr[8 + dst_step] = m01[2]; + dst_ic8_ptr[12 + dst_step] = m01[3]; + dst_ic8_ptr[16 + dst_step] = m01[4]; + dst_ic8_ptr[20 + dst_step] = m01[5]; + dst_ic8_ptr[24 + dst_step] = m01[6]; + dst_ic8_ptr[28 + dst_step] = m01[7]; + + dst_ic8_ptr[0 + 2 * dst_step] = m02[0]; + dst_ic8_ptr[4 + 2 * dst_step] = m02[1]; + dst_ic8_ptr[8 + 2 * dst_step] = m02[2]; + dst_ic8_ptr[12 + 2 * dst_step] = m02[3]; + dst_ic8_ptr[16 + 2 * dst_step] = m02[4]; + dst_ic8_ptr[20 + 2 * dst_step] = m02[5]; + dst_ic8_ptr[24 + 2 * dst_step] = m02[6]; + dst_ic8_ptr[28 + 2 * dst_step] = m02[7]; + + dst_ic8_ptr[0 + 3 * dst_step] = m03[0]; + dst_ic8_ptr[4 + 3 * dst_step] = m03[1]; + dst_ic8_ptr[8 + 3 * dst_step] = m03[2]; + dst_ic8_ptr[12 + 3 * dst_step] = m03[3]; + dst_ic8_ptr[16 + 3 * dst_step] = m03[4]; + dst_ic8_ptr[20 + 3 * dst_step] = m03[5]; + dst_ic8_ptr[24 + 3 * dst_step] = m03[6]; + dst_ic8_ptr[28 + 3 * dst_step] = m03[7]; + + dst_ic8_ptr[0 + 4 * dst_step] = m10[0]; + dst_ic8_ptr[4 + 4 * dst_step] = m10[1]; + dst_ic8_ptr[8 + 4 * dst_step] = m10[2]; + dst_ic8_ptr[12 + 4 * dst_step] = m10[3]; + dst_ic8_ptr[16 + 4 * dst_step] = m10[4]; + dst_ic8_ptr[20 + 4 * dst_step] = m10[5]; + dst_ic8_ptr[24 + 4 * dst_step] = m10[6]; + dst_ic8_ptr[28 + 4 * dst_step] = m10[7]; + + dst_ic8_ptr[0 + 5 * dst_step] = m11[0]; + dst_ic8_ptr[4 + 5 * dst_step] = m11[1]; + dst_ic8_ptr[8 + 5 * dst_step] = m11[2]; + dst_ic8_ptr[12 + 5 * dst_step] = m11[3]; + dst_ic8_ptr[16 + 5 * dst_step] = m11[4]; + dst_ic8_ptr[20 + 5 * dst_step] = m11[5]; + dst_ic8_ptr[24 + 5 * dst_step] = m11[6]; + dst_ic8_ptr[28 + 5 * dst_step] = m11[7]; + + dst_ic8_ptr[0 + 6 * dst_step] = m12[0]; + dst_ic8_ptr[4 + 6 * dst_step] = m12[1]; + dst_ic8_ptr[8 + 6 * dst_step] = m12[2]; + dst_ic8_ptr[12 + 6 * dst_step] = m12[3]; + dst_ic8_ptr[16 + 6 * dst_step] = m12[4]; + dst_ic8_ptr[20 + 6 * dst_step] = m12[5]; + dst_ic8_ptr[24 + 6 * dst_step] = m12[6]; + dst_ic8_ptr[28 + 6 * dst_step] = m12[7]; + + dst_ic8_ptr[0 + 7 * dst_step] = m13[0]; + dst_ic8_ptr[4 + 7 
* dst_step] = m13[1]; + dst_ic8_ptr[8 + 7 * dst_step] = m13[2]; + dst_ic8_ptr[12 + 7 * dst_step] = m13[3]; + dst_ic8_ptr[16 + 7 * dst_step] = m13[4]; + dst_ic8_ptr[20 + 7 * dst_step] = m13[5]; + dst_ic8_ptr[24 + 7 * dst_step] = m13[6]; + dst_ic8_ptr[28 + 7 * dst_step] = m13[7]; + + dst_ic8_ptr[0 + 8 * dst_step] = m20[0]; + dst_ic8_ptr[4 + 8 * dst_step] = m20[1]; + dst_ic8_ptr[8 + 8 * dst_step] = m20[2]; + dst_ic8_ptr[12 + 8 * dst_step] = m20[3]; + dst_ic8_ptr[16 + 8 * dst_step] = m20[4]; + dst_ic8_ptr[20 + 8 * dst_step] = m20[5]; + dst_ic8_ptr[24 + 8 * dst_step] = m20[6]; + dst_ic8_ptr[28 + 8 * dst_step] = m20[7]; + + dst_ic8_ptr[0 + 9 * dst_step] = m21[0]; + dst_ic8_ptr[4 + 9 * dst_step] = m21[1]; + dst_ic8_ptr[8 + 9 * dst_step] = m21[2]; + dst_ic8_ptr[12 + 9 * dst_step] = m21[3]; + dst_ic8_ptr[16 + 9 * dst_step] = m21[4]; + dst_ic8_ptr[20 + 9 * dst_step] = m21[5]; + dst_ic8_ptr[24 + 9 * dst_step] = m21[6]; + dst_ic8_ptr[28 + 9 * dst_step] = m21[7]; + + dst_ic8_ptr[0 + 10 * dst_step] = m22[0]; + dst_ic8_ptr[4 + 10 * dst_step] = m22[1]; + dst_ic8_ptr[8 + 10 * dst_step] = m22[2]; + dst_ic8_ptr[12 + 10 * dst_step] = m22[3]; + dst_ic8_ptr[16 + 10 * dst_step] = m22[4]; + dst_ic8_ptr[20 + 10 * dst_step] = m22[5]; + dst_ic8_ptr[24 + 10 * dst_step] = m22[6]; + dst_ic8_ptr[28 + 10 * dst_step] = m22[7]; + + dst_ic8_ptr[0 + 11 * dst_step] = m23[0]; + dst_ic8_ptr[4 + 11 * dst_step] = m23[1]; + dst_ic8_ptr[8 + 11 * dst_step] = m23[2]; + dst_ic8_ptr[12 + 11 * dst_step] = m23[3]; + dst_ic8_ptr[16 + 11 * dst_step] = m23[4]; + dst_ic8_ptr[20 + 11 * dst_step] = m23[5]; + dst_ic8_ptr[24 + 11 * dst_step] = m23[6]; + dst_ic8_ptr[28 + 11 * dst_step] = m23[7]; + + dst_ic8_ptr[0 + 12 * dst_step] = m30[0]; + dst_ic8_ptr[4 + 12 * dst_step] = m30[1]; + dst_ic8_ptr[8 + 12 * dst_step] = m30[2]; + dst_ic8_ptr[12 + 12 * dst_step] = m30[3]; + dst_ic8_ptr[16 + 12 * dst_step] = m30[4]; + dst_ic8_ptr[20 + 12 * dst_step] = m30[5]; + dst_ic8_ptr[24 + 12 * dst_step] = m30[6]; + dst_ic8_ptr[28 + 12 * dst_step] = m30[7]; + + dst_ic8_ptr[0 + 13 * dst_step] = m31[0]; + dst_ic8_ptr[4 + 13 * dst_step] = m31[1]; + dst_ic8_ptr[8 + 13 * dst_step] = m31[2]; + dst_ic8_ptr[12 + 13 * dst_step] = m31[3]; + dst_ic8_ptr[16 + 13 * dst_step] = m31[4]; + dst_ic8_ptr[20 + 13 * dst_step] = m31[5]; + dst_ic8_ptr[24 + 13 * dst_step] = m31[6]; + dst_ic8_ptr[28 + 13 * dst_step] = m31[7]; + + dst_ic8_ptr[0 + 14 * dst_step] = m32[0]; + dst_ic8_ptr[4 + 14 * dst_step] = m32[1]; + dst_ic8_ptr[8 + 14 * dst_step] = m32[2]; + dst_ic8_ptr[12 + 14 * dst_step] = m32[3]; + dst_ic8_ptr[16 + 14 * dst_step] = m32[4]; + dst_ic8_ptr[20 + 14 * dst_step] = m32[5]; + dst_ic8_ptr[24 + 14 * dst_step] = m32[6]; + dst_ic8_ptr[28 + 14 * dst_step] = m32[7]; + + dst_ic8_ptr[0 + 15 * dst_step] = m33[0]; + dst_ic8_ptr[4 + 15 * dst_step] = m33[1]; + dst_ic8_ptr[8 + 15 * dst_step] = m33[2]; + dst_ic8_ptr[12 + 15 * dst_step] = m33[3]; + dst_ic8_ptr[16 + 15 * dst_step] = m33[4]; + dst_ic8_ptr[20 + 15 * dst_step] = m33[5]; + dst_ic8_ptr[24 + 15 * dst_step] = m33[6]; + dst_ic8_ptr[28 + 15 * dst_step] = m33[7]; +#else + for (int j = 0; j < C8NUM; j++) { + const int16_t *local_ptr = src_ic8_ptr + j; + int16_t dst00 = local_ptr[0] * 2; + int16_t dst01 = (local_ptr + 8)[0] * 2; + int16_t dst02 = (local_ptr + 16)[0] * 2; + + int16_t dst10 = local_ptr[0] + (local_ptr + 24)[0] + (local_ptr + 48)[0]; + int16_t dst11 = (local_ptr + 8)[0] + (local_ptr + 32)[0] + (local_ptr + 56)[0]; + int16_t dst12 = (local_ptr + 16)[0] + (local_ptr + 40)[0] + (local_ptr + 64)[0]; + + int16_t dst20 = 
local_ptr[0] - (local_ptr + 24)[0] + (local_ptr + 48)[0]; + int16_t dst21 = (local_ptr + 8)[0] - (local_ptr + 32)[0] + (local_ptr + 56)[0]; + int16_t dst22 = (local_ptr + 16)[0] - (local_ptr + 40)[0] + (local_ptr + 64)[0]; + + int16_t dst30 = (local_ptr + 48)[0] * 2; + int16_t dst31 = (local_ptr + 56)[0] * 2; + int16_t dst32 = (local_ptr + 64)[0] * 2; + + int16_t m00 = dst00 * 2; + int16_t m01 = dst00 + dst01 + dst02; + int16_t m02 = dst00 - dst01 + dst02; + int16_t m03 = dst02 * 2; + + int16_t m10 = dst10 * 2; + int16_t m11 = dst10 + dst11 + dst12; + int16_t m12 = dst10 - dst11 + dst12; + int16_t m13 = dst12 * 2; + + int16_t m20 = dst20 * 2; + int16_t m21 = dst20 + dst21 + dst22; + int16_t m22 = dst20 - dst21 + dst22; + int16_t m23 = dst22 * 2; + + int16_t m30 = dst30 * 2; + int16_t m31 = dst30 + dst31 + dst32; + int16_t m32 = dst30 - dst31 + dst32; + int16_t m33 = dst32 * 2; + + *(dst_ic8_ptr + j * 4) = m00; + *(dst_ic8_ptr + j * 4 + dst_step) = m01; + *(dst_ic8_ptr + j * 4 + 2 * dst_step) = m02; + *(dst_ic8_ptr + j * 4 + 3 * dst_step) = m03; + + *(dst_ic8_ptr + j * 4 + 4 * dst_step) = m10; + *(dst_ic8_ptr + j * 4 + 5 * dst_step) = m11; + *(dst_ic8_ptr + j * 4 + 6 * dst_step) = m12; + *(dst_ic8_ptr + j * 4 + 7 * dst_step) = m13; + + *(dst_ic8_ptr + j * 4 + 8 * dst_step) = m20; + *(dst_ic8_ptr + j * 4 + 9 * dst_step) = m21; + *(dst_ic8_ptr + j * 4 + 10 * dst_step) = m22; + *(dst_ic8_ptr + j * 4 + 11 * dst_step) = m23; + + *(dst_ic8_ptr + j * 4 + 12 * dst_step) = m30; + *(dst_ic8_ptr + j * 4 + 13 * dst_step) = m31; + *(dst_ic8_ptr + j * 4 + 14 * dst_step) = m32; + *(dst_ic8_ptr + j * 4 + 15 * dst_step) = m33; + } +#endif + } + } +} + +void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, + bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) { + int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; + int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; + int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; + int out_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int out_max = conv_param->conv_quant_arg_.out_act_max_[0]; + +#ifdef ENABLE_ARM + int32x4_t bias_ptr = vld1q_s32(bias_data); + + int32x4_t s00 = vld1q_s32(gemm_out); + int32x4_t s01 = vld1q_s32(gemm_out + 4); + int32x4_t s02 = vld1q_s32(gemm_out + 8); + int32x4_t s03 = vld1q_s32(gemm_out + 12); + + int32x4_t s10 = vld1q_s32(gemm_out + 16); + int32x4_t s11 = vld1q_s32(gemm_out + 20); + int32x4_t s12 = vld1q_s32(gemm_out + 24); + int32x4_t s13 = vld1q_s32(gemm_out + 28); + + int32x4_t s20 = vld1q_s32(gemm_out + 32); + int32x4_t s21 = vld1q_s32(gemm_out + 36); + int32x4_t s22 = vld1q_s32(gemm_out + 40); + int32x4_t s23 = vld1q_s32(gemm_out + 44); + + int32x4_t s30 = vld1q_s32(gemm_out + 48); + int32x4_t s31 = vld1q_s32(gemm_out + 52); + int32x4_t s32 = vld1q_s32(gemm_out + 56); + int32x4_t s33 = vld1q_s32(gemm_out + 60); + + int32x4_t t00 = vshrq_n_s32(vaddq_s32(vaddq_s32(s00, s10), s20), 1); + int32x4_t t01 = vshrq_n_s32(vaddq_s32(vaddq_s32(s01, s11), s21), 1); + int32x4_t t02 = vshrq_n_s32(vaddq_s32(vaddq_s32(s02, s12), s22), 1); + int32x4_t t03 = vshrq_n_s32(vaddq_s32(vaddq_s32(s03, s13), s23), 1); + + int32x4_t t10 = vshrq_n_s32(vsubq_s32(vsubq_s32(s10, s20), s30), 1); + int32x4_t t11 = vshrq_n_s32(vsubq_s32(vsubq_s32(s11, s21), s31), 1); + int32x4_t t12 = vshrq_n_s32(vsubq_s32(vsubq_s32(s12, s22), s32), 1); + 
int32x4_t t13 = vshrq_n_s32(vsubq_s32(vsubq_s32(s13, s23), s33), 1); + + int32x4_t d00 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t00, t01), t02), 1), bias_ptr); + int32x4_t d01 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t01, t02), t03), 1), bias_ptr); + + int32x4_t d10 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t10, t11), t12), 1), bias_ptr); + int32x4_t d11 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t11, t12), t13), 1), bias_ptr); + + int32x4_t out_multiplier; + int32x4_t ls; + int32x4_t rs; + if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + out_multiplier = vld1q_s32(quant_multiplier + oc_start); + ls = vld1q_s32(left_shift + oc_start); + rs = vld1q_s32(right_shift + oc_start); + } else { + out_multiplier = vdupq_n_s32(quant_multiplier[0]); + ls = vdupq_n_s32(left_shift[0]); + rs = vdupq_n_s32(right_shift[0]); + } + int32x4_t out_zp = vdupq_n_s32(output_zp); + int32x4_t output_min = vdupq_n_s32(out_min); + int32x4_t output_max = vdupq_n_s32(out_max); + + d00 = vqshlq_s32(d00, ls); + d00 = vqrdmulhq_s32(d00, out_multiplier); + int32x4_t carry = vandq_s32(d00, rs); + carry = vshrq_n_s32(carry, 31); + d00 = vqaddq_s32(d00, carry); + d00 = vqrshlq_s32(d00, rs); + d00 = vaddq_s32(d00, out_zp); + d00 = vmaxq_s32(d00, output_min); + d00 = vminq_s32(d00, output_max); + + d01 = vqshlq_s32(d01, ls); + d01 = vqrdmulhq_s32(d01, out_multiplier); + carry = vandq_s32(d01, rs); + carry = vshrq_n_s32(carry, 31); + d01 = vqaddq_s32(d01, carry); + d01 = vqrshlq_s32(d01, rs); + d01 = vaddq_s32(d01, out_zp); + d01 = vmaxq_s32(d01, output_min); + d01 = vminq_s32(d01, output_max); + + d10 = vqshlq_s32(d10, ls); + d10 = vqrdmulhq_s32(d10, out_multiplier); + carry = vandq_s32(d10, rs); + carry = vshrq_n_s32(carry, 31); + d10 = vqaddq_s32(d10, carry); + d10 = vqrshlq_s32(d10, rs); + d10 = vaddq_s32(d10, out_zp); + d10 = vmaxq_s32(d10, output_min); + d10 = vminq_s32(d10, output_max); + + d11 = vqshlq_s32(d11, ls); + d11 = vqrdmulhq_s32(d11, out_multiplier); + carry = vandq_s32(d11, rs); + carry = vshrq_n_s32(carry, 31); + d11 = vqaddq_s32(d11, carry); + d11 = vqrshlq_s32(d11, rs); + d11 = vaddq_s32(d11, out_zp); + d11 = vmaxq_s32(d11, output_min); + d11 = vminq_s32(d11, output_max); + + (output_data)[0] = (int8_t)d00[0]; + (output_data + 1)[0] = (int8_t)d00[1]; + (output_data + 2)[0] = (int8_t)d00[2]; + (output_data + 3)[0] = (int8_t)d00[3]; + + if (w_not_bound) { + *(output_data + 4) = (int8_t)d01[0]; + *(output_data + 5) = (int8_t)d01[1]; + *(output_data + 6) = (int8_t)d01[2]; + *(output_data + 7) = (int8_t)d01[3]; + } + if (h_not_bound) { + *(output_data + output_w * 4) = (int8_t)d10[0]; + *(output_data + output_w * 4 + 1) = (int8_t)d10[1]; + *(output_data + output_w * 4 + 2) = (int8_t)d10[2]; + *(output_data + output_w * 4 + 3) = (int8_t)d10[3]; + if (w_not_bound) { + *(output_data + output_w * 4 + 4) = (int8_t)d11[0]; + *(output_data + output_w * 4 + 5) = (int8_t)d11[1]; + *(output_data + output_w * 4 + 6) = (int8_t)d11[2]; + *(output_data + output_w * 4 + 7) = (int8_t)d11[3]; + } + } +#else + if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + for (int i = 0; i < C4NUM; i++) { + const int32_t *local_ptr = gemm_out + i; + const int32_t *bias_ptr = bias_data + i; + + int32_t s00 = local_ptr[0]; + int32_t s01 = (local_ptr + 4)[0]; + int32_t s02 = (local_ptr + 8)[0]; + int32_t s03 = (local_ptr + 12)[0]; + + int32_t s10 = (local_ptr + 16)[0]; + int32_t s11 = (local_ptr + 20)[0]; + int32_t s12 = (local_ptr + 24)[0]; + int32_t s13 = (local_ptr + 28)[0]; + + int32_t 
s20 = (local_ptr + 32)[0]; + int32_t s21 = (local_ptr + 36)[0]; + int32_t s22 = (local_ptr + 40)[0]; + int32_t s23 = (local_ptr + 44)[0]; + + int32_t s30 = (local_ptr + 48)[0]; + int32_t s31 = (local_ptr + 52)[0]; + int32_t s32 = (local_ptr + 56)[0]; + int32_t s33 = (local_ptr + 60)[0]; + + int32_t t00 = (s00 + s10 + s20) / 2; + int32_t t01 = (s01 + s11 + s21) / 2; + int32_t t02 = (s02 + s12 + s22) / 2; + int32_t t03 = (s03 + s13 + s23) / 2; + + int32_t t10 = (s10 - s20 - s30) / 2; + int32_t t11 = (s11 - s21 - s31) / 2; + int32_t t12 = (s12 - s22 - s32) / 2; + int32_t t13 = (s13 - s23 - s33) / 2; + + int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0]; + int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; + + int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; + int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; + + int oc_index = oc_start + i; + d00 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d00 += output_zp; + d00 = d00 > out_min ? d00 : out_min; + d00 = d00 < out_max ? d00 : out_max; + + d01 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d01 += output_zp; + d01 = d01 > out_min ? d01 : out_min; + d01 = d01 < out_max ? d01 : out_max; + + d10 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d10 += output_zp; + d10 = d10 > out_min ? d10 : out_min; + d10 = d10 < out_max ? d10 : out_max; + + d11 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d11 += output_zp; + d11 = d11 > out_min ? d11 : out_min; + d11 = d11 < out_max ? 
d11 : out_max; + + (output_data + i)[0] = (int8_t)d00; + if (w_not_bound) { + (output_data + i + C4NUM)[0] = (int8_t)d01; + } + if (h_not_bound) { + (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; + if (w_not_bound) { + (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; + } + } + } + } else { + for (int i = 0; i < C4NUM; i++) { + const int32_t *local_ptr = gemm_out + i; + const int32_t *bias_ptr = bias_data + i; + + int32_t s00 = local_ptr[0]; + int32_t s01 = (local_ptr + 4)[0]; + int32_t s02 = (local_ptr + 8)[0]; + int32_t s03 = (local_ptr + 12)[0]; + + int32_t s10 = (local_ptr + 16)[0]; + int32_t s11 = (local_ptr + 20)[0]; + int32_t s12 = (local_ptr + 24)[0]; + int32_t s13 = (local_ptr + 28)[0]; + + int32_t s20 = (local_ptr + 32)[0]; + int32_t s21 = (local_ptr + 36)[0]; + int32_t s22 = (local_ptr + 40)[0]; + int32_t s23 = (local_ptr + 44)[0]; + + int32_t s30 = (local_ptr + 48)[0]; + int32_t s31 = (local_ptr + 52)[0]; + int32_t s32 = (local_ptr + 56)[0]; + int32_t s33 = (local_ptr + 60)[0]; + + int32_t t00 = (s00 + s10 + s20) / 2; + int32_t t01 = (s01 + s11 + s21) / 2; + int32_t t02 = (s02 + s12 + s22) / 2; + int32_t t03 = (s03 + s13 + s23) / 2; + + int32_t t10 = (s10 - s20 - s30) / 2; + int32_t t11 = (s11 - s21 - s31) / 2; + int32_t t12 = (s12 - s22 - s32) / 2; + int32_t t13 = (s13 - s23 - s33) / 2; + + int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0]; + int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; + + int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; + int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; + + d00 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d00 += output_zp; + d00 = d00 > out_min ? d00 : out_min; + d00 = d00 < out_max ? d00 : out_max; + + d01 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d01 += output_zp; + d01 = d01 > out_min ? d01 : out_min; + d01 = d01 < out_max ? d01 : out_max; + + d10 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d10 += output_zp; + d10 = d10 > out_min ? d10 : out_min; + d10 = d10 < out_max ? d10 : out_max; + + d11 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d11 += output_zp; + d11 = d11 > out_min ? d11 : out_min; + d11 = d11 < out_max ? 
d11 : out_max; + + (output_data + i)[0] = (int8_t)d00; + if (w_not_bound) { + (output_data + i + C4NUM)[0] = (int8_t)d01; + } + if (h_not_bound) { + (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; + if (w_not_bound) { + (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; + } + } + } + } +#endif +} + +void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param) { + int output_channel = conv_param->output_channel_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + const int oc4 = UP_DIV(output_channel, C4NUM); + const int input_unit = 4; + if (out_w_block == 0) { + return; + } + for (int i = 0; i < real_cal_num; i++) { + int out_w_index = (start_index + i) % out_w_block; + int out_h_index = (start_index + i) / out_w_block; + int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit; + int dst_tile_offset = C4NUM * (out_w_index * OUPUT_UNIT + out_h_index * OUPUT_UNIT * output_w); + + for (int j = 0; j < oc4; j++) { + int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM; + int dst_oc4_offset = dst_tile_offset + j * C4NUM * output_h * output_w; + const int32_t *src_ptr = gemm_out + src_oc4_offset; + const int32_t *bias_ptr = bias_data + j * C4NUM; + int8_t *dst_ptr = out_data + dst_oc4_offset; + + // output transform + int real_num = (output_channel - j * C4NUM) < C4NUM ? (output_channel - j * C4NUM) : C4NUM; + bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w; + bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h; + Conv3x3Int8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, j * C4NUM, + conv_param); + } + } +} + +void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param) { + // input data format : nhwc + int input_channel = conv_param->input_channel_; + int input_width = conv_param->input_w_; + int input_height = conv_param->input_h_; + int pad_w = conv_param->pad_l_; + int pad_h = conv_param->pad_u_; + ConvQuantArg quant_arg = conv_param->conv_quant_arg_; + int input_zp = quant_arg.input_quant_args_[0].zp_; + const int ic8 = UP_DIV(input_channel, C8NUM); + const int input_unit = 4; + if (out_w_block == 0) { + return; + } + for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { + int x_id = start_index + cal_id; + int origin_x = (x_id % out_w_block) * OUPUT_UNIT - pad_w; + int origin_y = (x_id / out_w_block) * OUPUT_UNIT - pad_h; + int real_x_start = origin_x > 0 ? 0 : -origin_x; + int real_x_end = (origin_x + input_unit) < input_width ? input_unit : (input_width - origin_x); + int real_y_start = origin_y > 0 ? 0 : -origin_y; + int real_y_end = (origin_y + input_unit) < input_height ? 
input_unit : (input_height - origin_y); + + int src_plane_offset = C8NUM * (origin_y * input_width + origin_x); + int dst_plane_offset = cal_id * C8NUM; + for (int ic = 0; ic < ic8; ic++) { + // copy data from origin input to tmp buffer + for (int i = 0; i < input_unit * input_unit * TILE_NUM; i++) tmp_data[i] = input_zp; + + int src_c8_offset = src_plane_offset + ic * C8NUM * input_height * input_width; + for (int j = real_y_start; j < real_y_end; j++) { + const int16_t *src = input_data + src_c8_offset + C8NUM * (j * input_width + real_x_start); + int16_t *dst = tmp_data + C8NUM * (C4NUM * j + real_x_start); + memcpy(dst, src, (real_x_end - real_x_start) * C8NUM * sizeof(int16_t)); + } + // input transform + int dst_ic8_offset = dst_plane_offset + ic * TILE_NUM * C8NUM; + size_t dst_step = ic8 * C8NUM * TILE_NUM; + int16_t *trans_input_ptr = trans_input + dst_ic8_offset; + Conv3x3Int8InputUnit(tmp_data, trans_input_ptr, dst_step, input_zp); + } + } +} + +void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) { + int oc4 = UP_DIV(oc, C4NUM); +#ifdef ENABLE_ARM + IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, oc4 * 4 * 16 * sizeof(int32_t)); +#else + const int input_unit_square = 16; + for (int c = 0; c < oc4; c++) { + int filter_oc_offset = c * input_unit_square * ic8 * C8NUM * C4NUM; + int dst_oc_offset = c * input_unit_square * C4NUM; + for (int n = 0; n < real_cal_num; n++) { + int src_tile_offset = n * C8NUM; + int dst_tile_offset = dst_oc_offset + n * oc4 * C4NUM * input_unit_square; + for (int i = 0; i < 4; i++) { + int filter_h_offset = filter_oc_offset + i * 4 * ic8 * C8NUM * C4NUM; + int src_h_offset = src_tile_offset + i * C8NUM * ic8 * C8NUM * C4NUM; + int dst_h_offset = dst_tile_offset + i * 4 * 4; + for (int m = 0; m < 4; m++) { + int filter_w_offset = filter_h_offset + m * 4 * C8NUM * ic8; + int src_w_offset = src_h_offset + m * 8 * ic8 * C8NUM; + int dst_w_offset = dst_h_offset + m * C4NUM; + + int32_t acc[4] = {0}; + for (int z = 0; z < 4; z++) { + int filter_offset = filter_w_offset + z; + for (int j = 0; j < ic8; j++) { + int filter_c8_offset = filter_offset + j * 4 * 8; + int src_c8_offset = src_w_offset + j * 8 * 8; + + for (int k = 0; k < 8; k++) { + const int16_t *w_ptr = weight + filter_c8_offset + k * 4; + const int16_t *input_ptr = src + src_c8_offset + k; + acc[z] += w_ptr[0] * input_ptr[0]; + } + } + (dst + dst_w_offset + z)[0] = acc[z]; + } + } + } + } + } +#endif +} + +// int8 convolution 3x3 +void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, + int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out, + int task_id, ConvParameter *conv_param) { + int ic8 = UP_DIV(conv_param->input_channel_, C8NUM); + int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT); + int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT); + int output_count = out_w_block * out_h_block; + int output_tile_count = UP_DIV(output_count, TILE_NUM); + int oc4 = UP_DIV(conv_param->output_channel_, C4NUM); + int tile_buffer_offset = TILE_NUM * 16 * ic8 * C8NUM; + const int block_unit_buffer_offset = 16 * C8NUM; + int tmp_dst_buffer_offset = TILE_NUM * 16 * oc4 * C4NUM; + + for (int batch = 0; batch < conv_param->input_batch_; batch++) { + int in_batch_offset = batch * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_; + int tmp_out_batch_offset = batch * oc4 * C4NUM * conv_param->output_w_ * conv_param->output_h_; 
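+    /* Explanatory sketch of the tiling loop below (assumes OUPUT_UNIT == 2 and TILE_NUM == 8,
+     * their usual nnacl values): each transformed 4x4 input tile yields a 2x2 output block,
+     * so for a 16x16 output plane out_w_block = UP_DIV(16, 2) = 8, out_h_block = 8,
+     * output_count = 64 tiles and output_tile_count = UP_DIV(64, 8) = 8 tile groups.
+     * Threads start at task_id and stride by conv_param->thread_num_; real_cal_num clamps
+     * the final group when output_count is not a multiple of TILE_NUM. */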
+ for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { + int start_index = thread_id * TILE_NUM; + int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; + + Conv3x3Int8InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset, + block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, + out_w_block, conv_param); + + Conv3x3Int8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset, + transed_weight, conv_param->output_channel_, ic8, real_cal_num); + + Conv3x3Int8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset, + bias_data, start_index, real_cal_num, out_w_block, conv_param); + } + } +} diff --git a/mindspore/lite/nnacl/int8/conv3x3_int8.h b/mindspore/lite/nnacl/int8/conv3x3_int8.h new file mode 100644 index 00000000000..6111b46ef4a --- /dev/null +++ b/mindspore/lite/nnacl/int8/conv3x3_int8.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_ +#define MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_ + +#include +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/pack.h" +#include "nnacl/op_base.h" +#include "nnacl/common_func.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/winograd_utils.h" +#include "nnacl/int8/quantize.h" +#include "nnacl/matmul_parameter.h" +#include "nnacl/int8/matmul_int8.h" +#include "nnacl/winograd_transform.h" +#include "nnacl/int8/common_func_int8.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel, + int kernel_plane); + +void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, + int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out, + int task_id, ConvParameter *conv_param); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_ diff --git a/mindspore/lite/nnacl/int8/conv_depthwise_int8.c b/mindspore/lite/nnacl/int8/conv_depthwise_int8.c index 846ea9344f4..0bd2380fdd9 100644 --- a/mindspore/lite/nnacl/int8/conv_depthwise_int8.c +++ b/mindspore/lite/nnacl/int8/conv_depthwise_int8.c @@ -16,7 +16,7 @@ #include "nnacl/int8/conv_depthwise_int8.h" #include -#include "nnacl/quantization/fixed_point.h" +#include "nnacl/int8/fixed_point.h" #include "nnacl/int8/common_func_int8.h" /*conv depthwise int8 begin*/ diff --git a/mindspore/lite/nnacl/int8/conv_int8.c b/mindspore/lite/nnacl/int8/conv_int8.c index b54f7e45071..6e995abbade 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.c +++ b/mindspore/lite/nnacl/int8/conv_int8.c @@ -15,52 +15,6 @@ */ #include "nnacl/int8/conv_int8.h" -#include -#include "nnacl/winograd_transform.h" -#include "nnacl/int8/common_func_int8.h" - -void 
Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) { - int oc4 = UP_DIV(oc, C4NUM); -#ifdef ENABLE_ARM - IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, oc4 * 4 * 16 * sizeof(int32_t)); -#else - const int input_unit_square = 16; - for (int c = 0; c < oc4; c++) { - int filter_oc_offset = c * input_unit_square * ic8 * C8NUM * C4NUM; - int dst_oc_offset = c * input_unit_square * C4NUM; - for (int n = 0; n < real_cal_num; n++) { - int src_tile_offset = n * C8NUM; - int dst_tile_offset = dst_oc_offset + n * oc4 * C4NUM * input_unit_square; - for (int i = 0; i < 4; i++) { - int filter_h_offset = filter_oc_offset + i * 4 * ic8 * C8NUM * C4NUM; - int src_h_offset = src_tile_offset + i * C8NUM * ic8 * C8NUM * C4NUM; - int dst_h_offset = dst_tile_offset + i * 4 * 4; - for (int m = 0; m < 4; m++) { - int filter_w_offset = filter_h_offset + m * 4 * C8NUM * ic8; - int src_w_offset = src_h_offset + m * 8 * ic8 * C8NUM; - int dst_w_offset = dst_h_offset + m * C4NUM; - - int32_t acc[4] = {0}; - for (int z = 0; z < 4; z++) { - int filter_offset = filter_w_offset + z; - for (int j = 0; j < ic8; j++) { - int filter_c8_offset = filter_offset + j * 4 * 8; - int src_c8_offset = src_w_offset + j * 8 * 8; - - for (int k = 0; k < 8; k++) { - const int16_t *w_ptr = weight + filter_c8_offset + k * 4; - const int16_t *input_ptr = src + src_c8_offset + k; - acc[z] += w_ptr[0] * input_ptr[0]; - } - } - (dst + dst_w_offset + z)[0] = acc[z]; - } - } - } - } - } -#endif -} void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight, const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id, @@ -141,717 +95,3 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, in } } } - -void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, - size_t output_channel, size_t plane_size, int32_t *filter_zp, size_t inputsum_stride) { - int ic4 = UP_ROUND(input_channel, C4NUM); - int oc8 = UP_ROUND(output_channel, C8NUM); - int hw8 = UP_ROUND(plane_size, C8NUM); - size_t hw_8div = plane_size / C8NUM * C8NUM; - size_t oc_8div = output_channel / C8NUM * C8NUM; - size_t oc_8res = output_channel - oc_8div; - size_t ic_4div = input_channel / C4NUM * C4NUM; - - const int8_t *src_r = src_input; - int8_t *pack_r = packed_input; - int32_t *input_sum_r = input_sum; - - for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { - const int8_t *src_ic = src_r; - int8_t *pack_ic = pack_r; - int32_t *input_sum_oc = input_sum_r; -#ifdef ENABLE_ARM64 - size_t src_stride = input_channel; - size_t ic_4res = input_channel - ic_4div; - size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4; - asm volatile( - "dup v16.4s, wzr \n" - "dup v17.4s, wzr \n" - - "mov x10, %[src_ic] \n" - "mov x11, %[pack_ic] \n" - - "mov x0, #0 \n" - "1: \n" - "cmp x0, %[ic_4div] \n" - "add x0, x0, #4\n" - "mov x12, x10 \n" - "add x10, x10, #4\n" - "blt 2f \n" - "cmp %[ic_4res], #0\n" - "beq 6f \n" - "cmp %[ic_4res], #1\n" - "beq 3f \n" - "cmp %[ic_4res], #2\n" - "beq 4f \n" - "cmp %[ic_4res], #3\n" - "beq 5f \n" - - "2: \n" - "ld1 {v0.s}[0], [x12], %[src_stride]\n" - "ld1 {v0.s}[1], [x12], %[src_stride]\n" - "ld1 {v0.s}[2], [x12], %[src_stride]\n" - "ld1 {v0.s}[3], [x12], %[src_stride]\n" - "ld1 {v1.s}[0], [x12], %[src_stride]\n" - "ld1 {v1.s}[1], [x12], %[src_stride]\n" - "ld1 {v1.s}[2], [x12], %[src_stride]\n" - "ld1 {v1.s}[3], [x12], 
%[src_stride]\n" - - "st1 {v0.16b}, [x11], #16\n" - "st1 {v1.16b}, [x11], #16\n" - - "saddlp v4.8h, v0.16b \n" - "saddlp v5.8h, v1.16b \n" - "saddlp v0.4s, v4.8h \n" - "saddlp v1.4s, v5.8h \n" - "add v16.4s, v16.4s, v0.4s \n" - "add v17.4s, v17.4s, v1.4s \n" - "b 1b \n" - - "3: \n" /* col res 1 */ - "dup v0.4s, wzr \n" - "dup v1.4s, wzr \n" - - "ld1 {v0.b}[0], [x12], %[src_stride]\n" - "ld1 {v0.b}[4], [x12], %[src_stride]\n" - "ld1 {v0.b}[8], [x12], %[src_stride]\n" - "ld1 {v0.b}[12], [x12], %[src_stride]\n" - "ld1 {v1.b}[0], [x12], %[src_stride]\n" - "ld1 {v1.b}[4], [x12], %[src_stride]\n" - "ld1 {v1.b}[8], [x12], %[src_stride]\n" - "ld1 {v1.b}[12], [x12], %[src_stride]\n" - - "st1 {v0.16b}, [x11], #16\n" - "st1 {v1.16b}, [x11], #16\n" - - "saddlp v4.8h, v0.16b \n" - "saddlp v5.8h, v1.16b \n" - "saddlp v0.4s, v4.8h \n" - "saddlp v1.4s, v5.8h \n" - "add v16.4s, v16.4s, v0.4s \n" - "add v17.4s, v17.4s, v1.4s \n" - "b 6f \n" - - "4: \n" /* col res 2 */ - "dup v0.4s, wzr \n" - "dup v1.4s, wzr \n" - - "ld1 {v0.h}[0], [x12], %[src_stride]\n" - "ld1 {v0.h}[2], [x12], %[src_stride]\n" - "ld1 {v0.h}[4], [x12], %[src_stride]\n" - "ld1 {v0.h}[6], [x12], %[src_stride]\n" - "ld1 {v1.h}[0], [x12], %[src_stride]\n" - "ld1 {v1.h}[2], [x12], %[src_stride]\n" - "ld1 {v1.h}[4], [x12], %[src_stride]\n" - "ld1 {v1.h}[6], [x12], %[src_stride]\n" - - "st1 {v0.16b}, [x11], #16\n" - "st1 {v1.16b}, [x11], #16\n" - - "saddlp v4.8h, v0.16b \n" - "saddlp v5.8h, v1.16b \n" - "saddlp v0.4s, v4.8h \n" - "saddlp v1.4s, v5.8h \n" - "add v16.4s, v16.4s, v0.4s \n" - "add v17.4s, v17.4s, v1.4s \n" - "b 6f \n" - - "5: \n" /* col res 3 */ - "dup v0.4s, wzr \n" - "dup v1.4s, wzr \n" - "add x13, x12, #2 \n" - - "ld1 {v0.h}[0], [x12], %[src_stride]\n" - "ld1 {v0.b}[2], [x13], %[src_stride]\n" - "ld1 {v0.h}[2], [x12], %[src_stride]\n" - "ld1 {v0.b}[6], [x13], %[src_stride]\n" - "ld1 {v0.h}[4], [x12], %[src_stride]\n" - "ld1 {v0.b}[10], [x13], %[src_stride]\n" - "ld1 {v0.h}[6], [x12], %[src_stride]\n" - "ld1 {v0.b}[14], [x13], %[src_stride]\n" - "ld1 {v1.h}[0], [x12], %[src_stride]\n" - "ld1 {v1.b}[2], [x13], %[src_stride]\n" - "ld1 {v1.h}[2], [x12], %[src_stride]\n" - "ld1 {v1.b}[6], [x13], %[src_stride]\n" - "ld1 {v1.h}[4], [x12], %[src_stride]\n" - "ld1 {v1.b}[10], [x13], %[src_stride]\n" - "ld1 {v1.h}[6], [x12], %[src_stride]\n" - "ld1 {v1.b}[14], [x13], %[src_stride]\n" - - "st1 {v0.16b}, [x11], #16\n" - "st1 {v1.16b}, [x11], #16\n" - - "saddlp v4.8h, v0.16b \n" - "saddlp v5.8h, v1.16b \n" - "saddlp v0.4s, v4.8h \n" - "saddlp v1.4s, v5.8h \n" - "add v16.4s, v16.4s, v0.4s \n" - "add v17.4s, v17.4s, v1.4s \n" - "b 6f \n" - - "6: \n" - "dup v0.4s, v16.s[0] \n" - "dup v1.4s, v16.s[1] \n" - "dup v2.4s, v16.s[2] \n" - "dup v3.4s, v16.s[3] \n" - "dup v4.4s, v17.s[0] \n" - "dup v5.4s, v17.s[1] \n" - "dup v6.4s, v17.s[2] \n" - "dup v7.4s, v17.s[3] \n" - "mov x4, #0 \n" - "mov x10, %[filter_zp] \n" - "mov x11, %[input_sum_oc] \n" - - "7: \n" - "cmp x4, %[oc_8div] \n" - "beq 8f \n" - "add x4, x4, #8\n" - "ld1 {v16.4s}, [x10], #16\n" - "ld1 {v17.4s}, [x10], #16\n" - - "mul v18.4s, v16.4s, v0.4s \n" - "mul v19.4s, v17.4s, v0.4s \n" - "st1 {v18.4s}, [x11], #16 \n" - "st1 {v19.4s}, [x11], #16 \n" - - "mul v20.4s, v16.4s, v1.4s \n" - "mul v21.4s, v17.4s, v1.4s \n" - "st1 {v20.4s}, [x11], #16 \n" - "st1 {v21.4s}, [x11], #16 \n" - - "mul v22.4s, v16.4s, v2.4s \n" - "mul v23.4s, v17.4s, v2.4s \n" - "st1 {v22.4s}, [x11], #16 \n" - "st1 {v23.4s}, [x11], #16 \n" - - "mul v24.4s, v16.4s, v3.4s \n" - "mul v25.4s, v17.4s, v3.4s \n" - "st1 {v24.4s}, 
[x11], #16 \n" - "st1 {v25.4s}, [x11], #16 \n" - - "mul v18.4s, v16.4s, v4.4s \n" - "mul v19.4s, v17.4s, v4.4s \n" - "st1 {v18.4s}, [x11], #16 \n" - "st1 {v19.4s}, [x11], #16 \n" - - "mul v20.4s, v16.4s, v5.4s \n" - "mul v21.4s, v17.4s, v5.4s \n" - "st1 {v20.4s}, [x11], #16 \n" - "st1 {v21.4s}, [x11], #16 \n" - - "mul v22.4s, v16.4s, v6.4s \n" - "mul v23.4s, v17.4s, v6.4s \n" - "st1 {v22.4s}, [x11], #16 \n" - "st1 {v23.4s}, [x11], #16 \n" - - "mul v24.4s, v16.4s, v7.4s \n" - "mul v25.4s, v17.4s, v7.4s \n" - "st1 {v24.4s}, [x11], #16 \n" - "st1 {v25.4s}, [x11], #16 \n" - - "add x11, x11, %[input_sum_stride] \n" - "b 7b \n" - - "8: \n" - "cmp %[oc_8res], #0\n" - "beq 17f \n" - - "dup v16.4s, wzr \n" - "dup v17.4s, wzr \n" - "cmp %[oc_8res], #1\n" - "beq 9f \n" - "cmp %[oc_8res], #2\n" - "beq 10f \n" - "cmp %[oc_8res], #3\n" - "beq 11f \n" - "cmp %[oc_8res], #4\n" - "beq 12f \n" - "cmp %[oc_8res], #5\n" - "beq 13f \n" - "cmp %[oc_8res], #6\n" - "beq 14f \n" - "cmp %[oc_8res], #7\n" - "beq 15f \n" - - "9: \n" - "ld1 {v16.s}[0], [x10] \n" - "b 16f \n" - - "10: \n" - "ld1 {v16.d}[0], [x10] \n" - "b 16f \n" - - "11: \n" - "ld1 {v16.d}[0], [x10] \n" - "add x10, x10, #8 \n" - "ld1 {v16.s}[2], [x10] \n" - "b 16f \n" - - "12: \n" - "ld1 {v16.4s}, [x10] \n" - "b 16f \n" - - "13: \n" - "ld1 {v16.4s}, [x10], #16\n" - "ld1 {v17.s}[0], [x10] \n" - "b 16f \n" - - "14: \n" - "ld1 {v16.4s}, [x10], #16\n" - "ld1 {v17.d}[0], [x10] \n" - "b 16f \n" - - "15: \n" - "ld1 {v16.4s}, [x10], #16\n" - "ld1 {v17.d}[0], [x10] \n" - "add x10, x10, #8 \n" - "ld1 {v17.s}[2], [x10] \n" - "b 16f \n" - - "16: \n" - "mul v18.4s, v16.4s, v0.4s \n" - "mul v19.4s, v17.4s, v0.4s \n" - "mul v20.4s, v16.4s, v1.4s \n" - "mul v21.4s, v17.4s, v1.4s \n" - "mul v22.4s, v16.4s, v2.4s \n" - "mul v23.4s, v17.4s, v2.4s \n" - "mul v24.4s, v16.4s, v3.4s \n" - "mul v25.4s, v17.4s, v3.4s \n" - "st1 {v18.4s}, [x11], #16 \n" - "st1 {v19.4s}, [x11], #16 \n" - "st1 {v20.4s}, [x11], #16 \n" - "st1 {v21.4s}, [x11], #16 \n" - "st1 {v22.4s}, [x11], #16 \n" - "st1 {v23.4s}, [x11], #16 \n" - "st1 {v24.4s}, [x11], #16 \n" - "st1 {v25.4s}, [x11], #16 \n" - - "mul v18.4s, v16.4s, v4.4s \n" - "mul v19.4s, v17.4s, v4.4s \n" - "mul v20.4s, v16.4s, v5.4s \n" - "mul v21.4s, v17.4s, v5.4s \n" - "mul v22.4s, v16.4s, v6.4s \n" - "mul v23.4s, v17.4s, v6.4s \n" - "mul v24.4s, v16.4s, v7.4s \n" - "mul v25.4s, v17.4s, v7.4s \n" - "st1 {v18.4s}, [x11], #16 \n" - "st1 {v19.4s}, [x11], #16 \n" - "st1 {v20.4s}, [x11], #16 \n" - "st1 {v21.4s}, [x11], #16 \n" - "st1 {v22.4s}, [x11], #16 \n" - "st1 {v23.4s}, [x11], #16 \n" - "st1 {v24.4s}, [x11], #16 \n" - "st1 {v25.4s}, [x11], #16 \n" - - "17: \n" - - : - : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ filter_zp ] "r"(filter_zp), - [ input_sum_oc ] "r"(input_sum_oc), [ input_sum_stride ] "r"(input_sum_stride), [ src_stride ] "r"(src_stride), - [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ oc_8div ] "r"(oc_8div), [ oc_8res ] "r"(oc_8res) - : "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", - "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25"); -#else - int32_t tmp_sum_value[8] = {0}; - for (int ici = 0; ici < ic_4div; ici += C4NUM) { - for (int i = 0; i < C8NUM; i++) { - tmp_sum_value[i] += src_ic[0 + i * input_channel]; - tmp_sum_value[i] += src_ic[1 + i * input_channel]; - tmp_sum_value[i] += src_ic[2 + i * input_channel]; - tmp_sum_value[i] += src_ic[3 + i * input_channel]; - pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; - pack_ic[1 
+ i * C4NUM] = src_ic[1 + i * input_channel]; - pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; - pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; - } - src_ic += C4NUM; - pack_ic += C4NUM * C8NUM; - } - for (int ici = ic_4div; ici < input_channel; ici += 1) { - for (int i = 0; i < C8NUM; i++) { - tmp_sum_value[i] += src_ic[i * input_channel]; - pack_ic[i * C4NUM] = src_ic[i * input_channel]; - } - src_ic += 1; - pack_ic += 1; - } - - for (int ici = input_channel; ici < ic4; ici += 1) { - for (int i = 0; i < C8NUM; i++) { - pack_ic[i * C4NUM] = 0; - } - pack_ic += 1; - } - - for (int oci = 0; oci < oc_8div; oci += C8NUM) { - for (int ri = 0; ri < C8NUM; ri++) { - input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0]; - input_sum_oc[ri * C8NUM + 1] = tmp_sum_value[ri] * filter_zp[oci + 1]; - input_sum_oc[ri * C8NUM + 2] = tmp_sum_value[ri] * filter_zp[oci + 2]; - input_sum_oc[ri * C8NUM + 3] = tmp_sum_value[ri] * filter_zp[oci + 3]; - input_sum_oc[ri * C8NUM + 4] = tmp_sum_value[ri] * filter_zp[oci + 4]; - input_sum_oc[ri * C8NUM + 5] = tmp_sum_value[ri] * filter_zp[oci + 5]; - input_sum_oc[ri * C8NUM + 6] = tmp_sum_value[ri] * filter_zp[oci + 6]; - input_sum_oc[ri * C8NUM + 7] = tmp_sum_value[ri] * filter_zp[oci + 7]; - } - input_sum_oc += inputsum_stride; - } - if (oc_8div != output_channel) { - for (int oci = 0; oci < oc_8res; oci += 1) { - for (int ri = 0; ri < C8NUM; ri++) { - input_sum_oc[ri * C8NUM + oci] = tmp_sum_value[ri] * filter_zp[oc_8div + oci]; - } - } - for (int oci = oc_8res; oci < C8NUM; oci += 1) { - for (int ri = 0; ri < C8NUM; ri++) { - input_sum_oc[ri * C8NUM + oci] = 0; - } - } - } /* oc8 res done */ -#endif - src_r += input_channel * C8NUM; - pack_r += ic4 * C8NUM; - input_sum_r += C8NUM * C8NUM; - } - - if (hw_8div != plane_size) { - memset(pack_r, 0, C8NUM * ic4); - for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { - int32_t *input_sum_oc = input_sum_r; - int32_t tmp_sum_value = 0; - const int8_t *src_ic = src_r; - int8_t *pack_ic = pack_r; - for (int ici = 0; ici < ic_4div; ici += C4NUM) { - tmp_sum_value += src_ic[0]; - tmp_sum_value += src_ic[1]; - tmp_sum_value += src_ic[2]; - tmp_sum_value += src_ic[3]; - pack_ic[0] = src_ic[0]; - pack_ic[1] = src_ic[1]; - pack_ic[2] = src_ic[2]; - pack_ic[3] = src_ic[3]; - src_ic += C4NUM; - pack_ic += C4NUM * C8NUM; - } - for (int ici = ic_4div; ici < input_channel; ici += 1) { - tmp_sum_value += src_ic[0]; - pack_ic[0] = src_ic[0]; - src_ic += 1; - pack_ic += 1; - } - - for (int oci = 0; oci < oc_8div; oci += C8NUM) { - for (int curoi = 0; curoi < C8NUM; curoi++) { - input_sum_oc[curoi] = tmp_sum_value * filter_zp[oci + curoi]; - } - input_sum_oc += inputsum_stride; - } - if (oc_8div != output_channel) { - for (int oci = 0; oci < oc_8res; oci += 1) { - input_sum_oc[oci] = tmp_sum_value * filter_zp[oc_8div + oci]; - } - for (int oci = oc_8res; oci < C8NUM; oci += 1) { - input_sum_oc[oci] = 0; - } - } /* oc8 res done */ - - src_r += input_channel; - pack_r += C4NUM; - input_sum_r += C8NUM; - } - - for (int hwi = plane_size; hwi < hw8; hwi++) { - for (int oc = 0; oc < oc8; oc++) { - int oc8div = oc / C8NUM, oc8res = oc % C8NUM; - input_sum[oc8div * inputsum_stride + hwi * C8NUM + oc8res] = 0; - } - } - } - return; -} - -void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, - size_t plane_size, ConvParameter *conv_param) { - int ic4 = UP_ROUND(input_channel, C4NUM); - size_t hw_8div = plane_size / C8NUM * C8NUM; - size_t ic_4div = 
input_channel / C4NUM * C4NUM; - int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_; - - const int8_t *src_r = src_input; - int8_t *pack_r = packed_input; - /* per layer */ - for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { - const int8_t *src_ic = src_r; - int8_t *pack_ic = pack_r; - int32_t *input_sum_r = input_sum + hwi; -#ifdef ENABLE_ARM64 - size_t src_stride = input_channel; - size_t ic_4res = input_channel - ic_4div; - asm volatile( - "dup v16.4s, wzr \n" - "dup v17.4s, wzr \n" - "mov x14, %[input_sum_r] \n" - "dup v20.4s, %w[filter_zp] \n" - - "mov x10, %[src_ic] \n" - "mov x11, %[pack_ic] \n" - - "mov x0, #0 \n" - "1: \n" - "cmp x0, %[ic_4div] \n" - "add x0, x0, #4\n" - "mov x12, x10 \n" - "add x10, x10, #4\n" - "blt 2f \n" - "cmp %[ic_4res], #0\n" - "beq 6f \n" - "cmp %[ic_4res], #1\n" - "beq 3f \n" - "cmp %[ic_4res], #2\n" - "beq 4f \n" - "cmp %[ic_4res], #3\n" - "beq 5f \n" - - "2: \n" - "ld1 {v0.s}[0], [x12], %[src_stride]\n" - "ld1 {v0.s}[1], [x12], %[src_stride]\n" - "ld1 {v0.s}[2], [x12], %[src_stride]\n" - "ld1 {v0.s}[3], [x12], %[src_stride]\n" - "ld1 {v1.s}[0], [x12], %[src_stride]\n" - "ld1 {v1.s}[1], [x12], %[src_stride]\n" - "ld1 {v1.s}[2], [x12], %[src_stride]\n" - "ld1 {v1.s}[3], [x12], %[src_stride]\n" - - "st1 {v0.16b}, [x11], #16\n" - "st1 {v1.16b}, [x11], #16\n" - - "saddlp v4.8h, v0.16b \n" - "saddlp v5.8h, v1.16b \n" - - "saddlp v0.4s, v4.8h \n" - "saddlp v1.4s, v5.8h \n" - - "add v16.4s, v16.4s, v0.4s \n" - "add v17.4s, v17.4s, v1.4s \n" - "b 1b \n" - - "3: \n" /* col res 1 */ - "dup v0.4s, wzr \n" - "dup v1.4s, wzr \n" - - "ld1 {v0.b}[0], [x12], %[src_stride]\n" - "ld1 {v0.b}[4], [x12], %[src_stride]\n" - "ld1 {v0.b}[8], [x12], %[src_stride]\n" - "ld1 {v0.b}[12], [x12], %[src_stride]\n" - "ld1 {v1.b}[0], [x12], %[src_stride]\n" - "ld1 {v1.b}[4], [x12], %[src_stride]\n" - "ld1 {v1.b}[8], [x12], %[src_stride]\n" - "ld1 {v1.b}[12], [x12], %[src_stride]\n" - - "st1 {v0.16b}, [x11], #16\n" - "st1 {v1.16b}, [x11], #16\n" - "saddlp v4.8h, v0.16b \n" - "saddlp v5.8h, v1.16b \n" - "saddlp v0.4s, v4.8h \n" - "saddlp v1.4s, v5.8h \n" - "add v16.4s, v16.4s, v0.4s \n" - "add v17.4s, v17.4s, v1.4s \n" - "b 6f \n" - - "4: \n" /* col res 2 */ - "dup v0.4s, wzr \n" - "dup v1.4s, wzr \n" - - "ld1 {v0.h}[0], [x12], %[src_stride]\n" - "ld1 {v0.h}[2], [x12], %[src_stride]\n" - "ld1 {v0.h}[4], [x12], %[src_stride]\n" - "ld1 {v0.h}[6], [x12], %[src_stride]\n" - "ld1 {v1.h}[0], [x12], %[src_stride]\n" - "ld1 {v1.h}[2], [x12], %[src_stride]\n" - "ld1 {v1.h}[4], [x12], %[src_stride]\n" - "ld1 {v1.h}[6], [x12], %[src_stride]\n" - - "st1 {v0.16b}, [x11], #16\n" - "st1 {v1.16b}, [x11], #16\n" - "saddlp v4.8h, v0.16b \n" - "saddlp v5.8h, v1.16b \n" - "saddlp v0.4s, v4.8h \n" - "saddlp v1.4s, v5.8h \n" - "add v16.4s, v16.4s, v0.4s \n" - "add v17.4s, v17.4s, v1.4s \n" - "b 6f \n" - - "5: \n" /* col res 3 */ - "dup v0.4s, wzr \n" - "dup v1.4s, wzr \n" - "add x13, x12, #2 \n" - - "ld1 {v0.h}[0], [x12], %[src_stride]\n" - "ld1 {v0.b}[2], [x13], %[src_stride]\n" - "ld1 {v0.h}[2], [x12], %[src_stride]\n" - "ld1 {v0.b}[6], [x13], %[src_stride]\n" - "ld1 {v0.h}[4], [x12], %[src_stride]\n" - "ld1 {v0.b}[10], [x13], %[src_stride]\n" - "ld1 {v0.h}[6], [x12], %[src_stride]\n" - "ld1 {v0.b}[14], [x13], %[src_stride]\n" - "ld1 {v1.h}[0], [x12], %[src_stride]\n" - "ld1 {v1.b}[2], [x13], %[src_stride]\n" - "ld1 {v1.h}[2], [x12], %[src_stride]\n" - "ld1 {v1.b}[6], [x13], %[src_stride]\n" - "ld1 {v1.h}[4], [x12], %[src_stride]\n" - "ld1 {v1.b}[10], [x13], %[src_stride]\n" - "ld1 
{v1.h}[6], [x12], %[src_stride]\n" - "ld1 {v1.b}[14], [x13], %[src_stride]\n" - - "st1 {v0.16b}, [x11], #16\n" - "st1 {v1.16b}, [x11], #16\n" - "saddlp v4.8h, v0.16b \n" - "saddlp v5.8h, v1.16b \n" - "saddlp v0.4s, v4.8h \n" - "saddlp v1.4s, v5.8h \n" - "add v16.4s, v16.4s, v0.4s \n" - "add v17.4s, v17.4s, v1.4s \n" - "b 6f \n" - - "6: \n" - "mul v16.4s, v16.4s, v20.4s \n" - "mul v17.4s, v17.4s, v20.4s \n" - - "st1 {v16.4s}, [x14], #16 \n" - "st1 {v17.4s}, [x14], #16 \n" - - : - : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r), - [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp) - : "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", - "v20"); -#else - int32_t tmp_sum_value[8] = {0}; - for (int ici = 0; ici < ic_4div; ici += C4NUM) { - for (int i = 0; i < C8NUM; i++) { - tmp_sum_value[i] += src_ic[0 + i * input_channel]; - tmp_sum_value[i] += src_ic[1 + i * input_channel]; - tmp_sum_value[i] += src_ic[2 + i * input_channel]; - tmp_sum_value[i] += src_ic[3 + i * input_channel]; - pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; - pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; - pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; - pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; - } - src_ic += C4NUM; - pack_ic += C4NUM * C8NUM; - } - for (int ici = ic_4div; ici < input_channel; ici += 1) { - for (int i = 0; i < C8NUM; i++) { - tmp_sum_value[i] += src_ic[i * input_channel]; - pack_ic[i * C4NUM] = src_ic[i * input_channel]; - } - src_ic += 1; - pack_ic += 1; - } - - for (int ici = input_channel; ici < ic4; ici += 1) { - for (int i = 0; i < C8NUM; i++) { - pack_ic[i * C4NUM] = 0; - } - pack_ic += 1; - } - - for (int i = 0; i < C8NUM; i++) { - input_sum_r[i] = tmp_sum_value[i] * filter_zp; - } -#endif - src_r += input_channel * C8NUM; - pack_r += ic4 * C8NUM; - } - - if (hw_8div != plane_size) { - memset(pack_r, 0, C8NUM * ic4); - for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { - int32_t tmp_sum_value = 0; - const int8_t *src_ic = src_r; - int8_t *pack_ic = pack_r; - for (int ici = 0; ici < ic_4div; ici += C4NUM) { - tmp_sum_value += src_ic[0]; - tmp_sum_value += src_ic[1]; - tmp_sum_value += src_ic[2]; - tmp_sum_value += src_ic[3]; - pack_ic[0] = src_ic[0]; - pack_ic[1] = src_ic[1]; - pack_ic[2] = src_ic[2]; - pack_ic[3] = src_ic[3]; - src_ic += C4NUM; - pack_ic += C4NUM * C8NUM; - } - for (int ici = ic_4div; ici < input_channel; ici += 1) { - tmp_sum_value += src_ic[0]; - pack_ic[0] = src_ic[0]; - src_ic += 1; - pack_ic += 1; - } - input_sum[hwi] = tmp_sum_value * filter_zp; - src_r += input_channel; - pack_r += C4NUM; - } - for (int hwi = plane_size; hwi < UP_ROUND(plane_size, C8NUM); hwi++) { - input_sum[hwi] = 0; - } - } - return; -} - -void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, - const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, - int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int *filter_zp) { - int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; - matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias, - left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, - conv_param->conv_quant_arg_.out_act_min_[0], 
conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc, - filter_zp); - return; -} - -void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, - const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, - int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp) { - int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; - MatmulInt8Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias, - conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], - conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, - conv_param->output_channel_, is_per_oc, filter_zp); - return; -} - -// int8 convolution 3x3 -void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, - int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out, - int task_id, ConvParameter *conv_param) { - int ic8 = UP_DIV(conv_param->input_channel_, C8NUM); - int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT); - int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT); - int output_count = out_w_block * out_h_block; - int output_tile_count = UP_DIV(output_count, TILE_NUM); - int oc4 = UP_DIV(conv_param->output_channel_, C4NUM); - int tile_buffer_offset = TILE_NUM * 16 * ic8 * C8NUM; - const int block_unit_buffer_offset = 16 * C8NUM; - int tmp_dst_buffer_offset = TILE_NUM * 16 * oc4 * C4NUM; - - for (int batch = 0; batch < conv_param->input_batch_; batch++) { - int in_batch_offset = batch * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_; - int tmp_out_batch_offset = batch * oc4 * C4NUM * conv_param->output_w_ * conv_param->output_h_; - for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { - int start_index = thread_id * TILE_NUM; - int real_cal_num = (output_count - start_index) < TILE_NUM ? 
(output_count - start_index) : TILE_NUM; - - Conv3x3Int8InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset, - block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, - out_w_block, conv_param); - - Conv3x3Int8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset, - transed_weight, conv_param->output_channel_, ic8, real_cal_num); - - Conv3x3Int8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset, - bias_data, start_index, real_cal_num, out_w_block, conv_param); - } - } -} diff --git a/mindspore/lite/nnacl/int8/conv_int8.h b/mindspore/lite/nnacl/int8/conv_int8.h index 1f84f86ae3a..208b8dbef84 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.h +++ b/mindspore/lite/nnacl/int8/conv_int8.h @@ -16,6 +16,7 @@ #ifndef MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_ #define MINDSPORE_LITE_NNACL_INT8_CONV_INT8_H_ +#include #ifdef ENABLE_NEON #include #endif @@ -24,9 +25,10 @@ #include "nnacl/common_func.h" #include "nnacl/conv_parameter.h" #include "nnacl/winograd_utils.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" #include "nnacl/matmul_parameter.h" #include "nnacl/int8/matmul_int8.h" +#include "nnacl/int8/common_func_int8.h" #ifdef __cplusplus extern "C" { @@ -36,23 +38,6 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, in const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func, bool is_optimize); -// int8 convolution 1x1 -void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, - size_t output_channel, size_t plane_size, int32_t *filter_zp, size_t inputsum_stride); -void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, - size_t plane_size, ConvParameter *conv_param); -void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, - const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, - int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp); -void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, - const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, - int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int32_t *filter_zp); - -// int8 convolution 3x3 -void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, - int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out, - int task_id, ConvParameter *conv_param); - #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/int8/depth_to_space_int8.h b/mindspore/lite/nnacl/int8/depth_to_space_int8.h index f2dc6bcec78..5af96940d58 100644 --- a/mindspore/lite/nnacl/int8/depth_to_space_int8.h +++ b/mindspore/lite/nnacl/int8/depth_to_space_int8.h @@ -17,7 +17,7 @@ #define MINDSPORE_LITE_NNACL_INT8_DEPTH_TO_SPACE_INT8_H_ #include "nnacl/depth_to_space_parameter.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/int8/div_int8.c b/mindspore/lite/nnacl/int8/div_int8.c index 1f852cbb39b..cf9d9fa592e 100644 --- 
a/mindspore/lite/nnacl/int8/div_int8.c +++ b/mindspore/lite/nnacl/int8/div_int8.c @@ -15,9 +15,6 @@ */ #include "nnacl/int8/div_int8.h" -#include "nnacl/quantization/fixed_point.h" -#include "nnacl/errorcode.h" -#include "nnacl/quantization/quantize.h" int DivInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, DivQuantArg *para) { int index = 0; diff --git a/mindspore/lite/nnacl/int8/div_int8.h b/mindspore/lite/nnacl/int8/div_int8.h index c39b8854b63..8522f7247cd 100644 --- a/mindspore/lite/nnacl/int8/div_int8.h +++ b/mindspore/lite/nnacl/int8/div_int8.h @@ -18,7 +18,9 @@ #define MINDSPORE_LITE_NNACL_INT8_DIV_INT8_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/errorcode.h" +#include "nnacl/int8/quantize.h" +#include "nnacl/int8/fixed_point.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/quantization/fixed_point.c b/mindspore/lite/nnacl/int8/fixed_point.c similarity index 98% rename from mindspore/lite/nnacl/quantization/fixed_point.c rename to mindspore/lite/nnacl/int8/fixed_point.c index da4b940c4f1..0e4dd09e888 100644 --- a/mindspore/lite/nnacl/quantization/fixed_point.c +++ b/mindspore/lite/nnacl/int8/fixed_point.c @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "nnacl/quantization/fixed_point.h" +#include "nnacl/int8/fixed_point.h" // returns the high-32 bits of a * b with rounding // assume that a and b is divided by 2^31, who fall into [-1, 1] @@ -107,7 +107,7 @@ int CountLeadingZeroBits(uint32_t x) { if (x == 0) { return 8 * sizeof(uint32_t); } - const int32_t leading_positive = (int32_t)(1) << (8 * sizeof(uint32_t) - 1); + const int32_t leading_positive = (uint32_t)(1) << (8 * sizeof(uint32_t) - 1); int leading_zeros = 0; while (x < leading_positive) { x <<= 1; diff --git a/mindspore/lite/nnacl/quantization/fixed_point.h b/mindspore/lite/nnacl/int8/fixed_point.h similarity index 100% rename from mindspore/lite/nnacl/quantization/fixed_point.h rename to mindspore/lite/nnacl/int8/fixed_point.h diff --git a/mindspore/lite/nnacl/int8/gatherNd_int8.h b/mindspore/lite/nnacl/int8/gatherNd_int8.h index a507ca56feb..16f9749a46f 100644 --- a/mindspore/lite/nnacl/int8/gatherNd_int8.h +++ b/mindspore/lite/nnacl/int8/gatherNd_int8.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHERND_INT8_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/int8/gather_int8.c b/mindspore/lite/nnacl/int8/gather_int8.c index 34b71b583b7..014ccb06ceb 100644 --- a/mindspore/lite/nnacl/int8/gather_int8.c +++ b/mindspore/lite/nnacl/int8/gather_int8.c @@ -16,7 +16,7 @@ */ #include "nnacl/int8/gather_int8.h" #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" #include "nnacl/errorcode.h" int GatherInt8(int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, const int *indices, diff --git a/mindspore/lite/nnacl/int8/gather_int8.h b/mindspore/lite/nnacl/int8/gather_int8.h index 5563f9316f6..6666188cb78 100644 --- a/mindspore/lite/nnacl/int8/gather_int8.h +++ b/mindspore/lite/nnacl/int8/gather_int8.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHER_INT8_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/int8/hswish_int8.h 
b/mindspore/lite/nnacl/int8/hswish_int8.h index e54248fb4bd..fd18fe43716 100644 --- a/mindspore/lite/nnacl/int8/hswish_int8.h +++ b/mindspore/lite/nnacl/int8/hswish_int8.h @@ -19,7 +19,7 @@ #include #include "nnacl/op_base.h" #include "nnacl/errorcode.h" -#include "nnacl/quantization/fixed_point.h" +#include "nnacl/int8/fixed_point.h" typedef struct HswishQuantArg { double input_scale; diff --git a/mindspore/lite/nnacl/int8/l2_norm_int8.c b/mindspore/lite/nnacl/int8/l2_norm_int8.c index 4f86e657bbb..cea02a59d96 100644 --- a/mindspore/lite/nnacl/int8/l2_norm_int8.c +++ b/mindspore/lite/nnacl/int8/l2_norm_int8.c @@ -15,7 +15,7 @@ */ #include "nnacl/int8/l2_norm_int8.h" #include -#include "nnacl/quantization/fixed_point.h" +#include "nnacl/int8/fixed_point.h" #include "nnacl/errorcode.h" int L2NormalizationInt8(const int8_t *input_data, int8_t *output_data, const L2NormParameter *param, diff --git a/mindspore/lite/nnacl/int8/layer_norm_int8.h b/mindspore/lite/nnacl/int8/layer_norm_int8.h index 9708a37c2c2..91e4b0b63fd 100644 --- a/mindspore/lite/nnacl/int8/layer_norm_int8.h +++ b/mindspore/lite/nnacl/int8/layer_norm_int8.h @@ -18,7 +18,7 @@ #include "nnacl/errorcode.h" #include "nnacl/layer_norm_parameter.h" -#include "nnacl/quantization/fixed_point.h" +#include "nnacl/int8/fixed_point.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/int8/leaky_relu_int8.h b/mindspore/lite/nnacl/int8/leaky_relu_int8.h index afe8c1934fe..ee8808f3a25 100644 --- a/mindspore/lite/nnacl/int8/leaky_relu_int8.h +++ b/mindspore/lite/nnacl/int8/leaky_relu_int8.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_PRELU_INT8_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/int8/matmul_int8.c b/mindspore/lite/nnacl/int8/matmul_int8.c index bc45ea4a766..9db55b6ac67 100644 --- a/mindspore/lite/nnacl/int8/matmul_int8.c +++ b/mindspore/lite/nnacl/int8/matmul_int8.c @@ -15,7 +15,7 @@ */ #include "nnacl/int8/matmul_int8.h" -#include "nnacl/quantization/fixed_point.h" +#include "nnacl/int8/fixed_point.h" void RowMajor2Row2x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { int col16 = UP_ROUND(col, C16NUM); diff --git a/mindspore/lite/nnacl/int8/matmul_int8.h b/mindspore/lite/nnacl/int8/matmul_int8.h index 65316338a16..35ae749ddf9 100644 --- a/mindspore/lite/nnacl/int8/matmul_int8.h +++ b/mindspore/lite/nnacl/int8/matmul_int8.h @@ -17,7 +17,6 @@ #ifndef MINDSPORE_LITE_NNACL_INT8_MATMUL_H_ #define MINDSPORE_LITE_NNACL_INT8_MATMUL_H_ -#include #include #include "nnacl/op_base.h" #include "nnacl/matmul_parameter.h" diff --git a/mindspore/lite/nnacl/int8/mul_int8.c b/mindspore/lite/nnacl/int8/mul_int8.c index 2685f041f79..7dce85d201f 100644 --- a/mindspore/lite/nnacl/int8/mul_int8.c +++ b/mindspore/lite/nnacl/int8/mul_int8.c @@ -15,15 +15,8 @@ */ #include "nnacl/int8/mul_int8.h" -#include "nnacl/mul_parameter.h" -#ifdef ENABLE_NEON -#include -#include "nnacl/int8/common_func_int8.h" -#endif -#include "nnacl/quantization/fixed_point.h" #ifdef ENABLE_NEON - int16x4_t ClacSumHalfWordMul(int16x4_t scaled_input0, int16x4_t scaled_input1, int32x4_t left_shift_out_vec, int32x4_t right_shift_out_vec, int32x4_t output_multiplier_vec) { int32x4_t input_scale = vmull_s16(scaled_input0, scaled_input1); diff --git a/mindspore/lite/nnacl/int8/mul_int8.h b/mindspore/lite/nnacl/int8/mul_int8.h index cb198f7fb0f..61ccfd89181 100644 --- 
a/mindspore/lite/nnacl/int8/mul_int8.h +++ b/mindspore/lite/nnacl/int8/mul_int8.h @@ -19,6 +19,11 @@ #include "nnacl/op_base.h" #include "nnacl/mul_parameter.h" +#include "nnacl/int8/common_func_int8.h" +#include "nnacl/int8/fixed_point.h" +#ifdef ENABLE_NEON +#include +#endif #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/int8/pack_int8.c b/mindspore/lite/nnacl/int8/pack_int8.c new file mode 100644 index 00000000000..6055cb7d015 --- /dev/null +++ b/mindspore/lite/nnacl/int8/pack_int8.c @@ -0,0 +1,1296 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/pack_int8.h" + +#ifdef ENABLE_ARM32 +void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr, + size_t plane_size, size_t input_channel, size_t output_channel) { + size_t hw4 = UP_ROUND(plane_size, C4NUM); + size_t ic16 = UP_ROUND(input_channel, C16NUM); + +#ifdef ENABLE_ARM32 + size_t oc_div2 = output_channel / C2NUM * C2NUM; + size_t oc_res2 = output_channel - oc_div2; + size_t inputsun_stride = hw4 * C2NUM * 4 - C4NUM * C2NUM * 4; + PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div2, oc_res2, inputsun_stride); +#else + for (int ri = 0; ri < plane_size; ri++) { + int ri4div = ri / C4NUM, ri4mod = ri % C4NUM; + for (int ci = 0; ci < output_channel; ci++) { + int32_t tmp_sum_value = 0; + int ci2div = ci / C2NUM, ci2mod = ci % C2NUM; + int32_t filter_zp = filter_zp_ptr[ci]; + for (int di = 0; di < input_channel; di++) { + size_t di16div = di / C16NUM, di16mod = di % C16NUM; + int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod; + tmp_sum_value += input_value[src_index]; + } + int dst_index = ci2div * C2NUM * hw4 + ri * C2NUM + ci2mod; + input_sum[dst_index] = tmp_sum_value * filter_zp; + } + } +#endif + return; +} +#endif + +void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr, + size_t plane_size, size_t input_channel, size_t output_channel) { + size_t hw4 = UP_ROUND(plane_size, C4NUM); + size_t ic16 = UP_ROUND(input_channel, C16NUM); +#ifdef ENABLE_ARM64 + size_t oc_div4 = output_channel / C4NUM * C4NUM; + size_t oc_res4 = output_channel - oc_div4; + size_t inputsun_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4; + PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div4, oc_res4, inputsun_stride); +#else + + for (int ri = 0; ri < plane_size; ri++) { + int ri4div = ri / C4NUM, ri4mod = ri % C4NUM; + for (int ci = 0; ci < output_channel; ci++) { + int32_t tmp_sum_value = 0; + int ci4div = ci / C4NUM, ci4mod = ci % C4NUM; + int32_t filter_zp = filter_zp_ptr[ci]; + for (int di = 0; di < input_channel; di++) { + size_t di16div = di / C16NUM, di16mod = di % C16NUM; + int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod; + tmp_sum_value += input_value[src_index]; + } + int dst_index = ci4div * C4NUM * hw4 + ri * 
C4NUM + ci4mod; + input_sum[dst_index] = tmp_sum_value * filter_zp; + } + } +#endif + return; +} + +void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, + size_t output_channel, size_t plane_size, int32_t *filter_zp, size_t inputsum_stride) { + int ic4 = UP_ROUND(input_channel, C4NUM); + int oc8 = UP_ROUND(output_channel, C8NUM); + int hw8 = UP_ROUND(plane_size, C8NUM); + size_t hw_8div = plane_size / C8NUM * C8NUM; + size_t oc_8div = output_channel / C8NUM * C8NUM; + size_t oc_8res = output_channel - oc_8div; + size_t ic_4div = input_channel / C4NUM * C4NUM; + + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + int32_t *input_sum_r = input_sum; + + for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + int32_t *input_sum_oc = input_sum_r; +#ifdef ENABLE_ARM64 + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4; + asm volatile( + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + + "mov x10, %[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x0, #0 \n" + "1: \n" + "cmp x0, %[ic_4div] \n" + "add x0, x0, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + "ld1 {v1.s}[0], [x12], %[src_stride]\n" + "ld1 {v1.s}[1], [x12], %[src_stride]\n" + "ld1 {v1.s}[2], [x12], %[src_stride]\n" + "ld1 {v1.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 1b \n" + + "3: \n" /* col res 1 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + "ld1 {v1.b}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[8], [x12], %[src_stride]\n" + "ld1 {v1.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "4: \n" /* col res 2 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "5: \n" /* col res 3 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], 
[x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[2], [x13], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.b}[6], [x13], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[10], [x13], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "6: \n" + "dup v0.4s, v16.s[0] \n" + "dup v1.4s, v16.s[1] \n" + "dup v2.4s, v16.s[2] \n" + "dup v3.4s, v16.s[3] \n" + "dup v4.4s, v17.s[0] \n" + "dup v5.4s, v17.s[1] \n" + "dup v6.4s, v17.s[2] \n" + "dup v7.4s, v17.s[3] \n" + "mov x4, #0 \n" + "mov x10, %[filter_zp] \n" + "mov x11, %[input_sum_oc] \n" + + "7: \n" + "cmp x4, %[oc_8div] \n" + "beq 8f \n" + "add x4, x4, #8\n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.4s}, [x10], #16\n" + + "mul v18.4s, v16.4s, v0.4s \n" + "mul v19.4s, v17.4s, v0.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + + "mul v20.4s, v16.4s, v1.4s \n" + "mul v21.4s, v17.4s, v1.4s \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + + "mul v22.4s, v16.4s, v2.4s \n" + "mul v23.4s, v17.4s, v2.4s \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + + "mul v24.4s, v16.4s, v3.4s \n" + "mul v25.4s, v17.4s, v3.4s \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "mul v18.4s, v16.4s, v4.4s \n" + "mul v19.4s, v17.4s, v4.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + + "mul v20.4s, v16.4s, v5.4s \n" + "mul v21.4s, v17.4s, v5.4s \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + + "mul v22.4s, v16.4s, v6.4s \n" + "mul v23.4s, v17.4s, v6.4s \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + + "mul v24.4s, v16.4s, v7.4s \n" + "mul v25.4s, v17.4s, v7.4s \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "add x11, x11, %[input_sum_stride] \n" + "b 7b \n" + + "8: \n" + "cmp %[oc_8res], #0\n" + "beq 17f \n" + + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + "cmp %[oc_8res], #1\n" + "beq 9f \n" + "cmp %[oc_8res], #2\n" + "beq 10f \n" + "cmp %[oc_8res], #3\n" + "beq 11f \n" + "cmp %[oc_8res], #4\n" + "beq 12f \n" + "cmp %[oc_8res], #5\n" + "beq 13f \n" + "cmp %[oc_8res], #6\n" + "beq 14f \n" + "cmp %[oc_8res], #7\n" + "beq 15f \n" + + "9: \n" + "ld1 {v16.s}[0], [x10] \n" + "b 16f \n" + + "10: \n" + "ld1 {v16.d}[0], [x10] \n" + "b 16f \n" + + "11: \n" + "ld1 {v16.d}[0], [x10] \n" + "add x10, x10, #8 \n" + "ld1 {v16.s}[2], [x10] \n" + "b 16f \n" + + "12: \n" + "ld1 {v16.4s}, [x10] \n" + "b 16f \n" + + "13: \n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.s}[0], [x10] \n" + "b 16f \n" + + "14: \n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.d}[0], [x10] \n" + "b 16f \n" + + "15: \n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.d}[0], [x10] \n" + "add x10, x10, #8 \n" + "ld1 {v17.s}[2], [x10] \n" + "b 16f \n" + + "16: \n" + "mul v18.4s, v16.4s, v0.4s \n" + "mul v19.4s, v17.4s, v0.4s \n" + "mul v20.4s, v16.4s, v1.4s \n" + "mul v21.4s, v17.4s, v1.4s \n" + "mul v22.4s, v16.4s, 
v2.4s \n" + "mul v23.4s, v17.4s, v2.4s \n" + "mul v24.4s, v16.4s, v3.4s \n" + "mul v25.4s, v17.4s, v3.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "mul v18.4s, v16.4s, v4.4s \n" + "mul v19.4s, v17.4s, v4.4s \n" + "mul v20.4s, v16.4s, v5.4s \n" + "mul v21.4s, v17.4s, v5.4s \n" + "mul v22.4s, v16.4s, v6.4s \n" + "mul v23.4s, v17.4s, v6.4s \n" + "mul v24.4s, v16.4s, v7.4s \n" + "mul v25.4s, v17.4s, v7.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "17: \n" + + : + : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ filter_zp ] "r"(filter_zp), + [ input_sum_oc ] "r"(input_sum_oc), [ input_sum_stride ] "r"(input_sum_stride), [ src_stride ] "r"(src_stride), + [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ oc_8div ] "r"(oc_8div), [ oc_8res ] "r"(oc_8res) + : "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25"); +#else + int32_t tmp_sum_value[8] = {0}; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[0 + i * input_channel]; + tmp_sum_value[i] += src_ic[1 + i * input_channel]; + tmp_sum_value[i] += src_ic[2 + i * input_channel]; + tmp_sum_value[i] += src_ic[3 + i * input_channel]; + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[i * input_channel]; + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for (int ici = input_channel; ici < ic4; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + pack_ic[i * C4NUM] = 0; + } + pack_ic += 1; + } + + for (int oci = 0; oci < oc_8div; oci += C8NUM) { + for (int ri = 0; ri < C8NUM; ri++) { + input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0]; + input_sum_oc[ri * C8NUM + 1] = tmp_sum_value[ri] * filter_zp[oci + 1]; + input_sum_oc[ri * C8NUM + 2] = tmp_sum_value[ri] * filter_zp[oci + 2]; + input_sum_oc[ri * C8NUM + 3] = tmp_sum_value[ri] * filter_zp[oci + 3]; + input_sum_oc[ri * C8NUM + 4] = tmp_sum_value[ri] * filter_zp[oci + 4]; + input_sum_oc[ri * C8NUM + 5] = tmp_sum_value[ri] * filter_zp[oci + 5]; + input_sum_oc[ri * C8NUM + 6] = tmp_sum_value[ri] * filter_zp[oci + 6]; + input_sum_oc[ri * C8NUM + 7] = tmp_sum_value[ri] * filter_zp[oci + 7]; + } + input_sum_oc += inputsum_stride; + } + if (oc_8div != output_channel) { + for (int oci = 0; oci < oc_8res; oci += 1) { + for (int ri = 0; ri < C8NUM; ri++) { + input_sum_oc[ri * C8NUM + oci] = tmp_sum_value[ri] * filter_zp[oc_8div + oci]; + } + } + for (int oci = oc_8res; oci < C8NUM; oci += 1) { + for (int ri = 0; ri < C8NUM; ri++) { + input_sum_oc[ri * C8NUM + oci] = 0; + } + } + } /* oc8 res done */ +#endif + src_r += input_channel * C8NUM; + pack_r += ic4 * C8NUM; + 
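/* input_sum is laid out in 8x8 tiles (8 spatial rows x 8 output channels); step input_sum_r to the tile for the next hw8 block */ +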
input_sum_r += C8NUM * C8NUM; + } + + if (hw_8div != plane_size) { + memset(pack_r, 0, C8NUM * ic4); + for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { + int32_t *input_sum_oc = input_sum_r; + int32_t tmp_sum_value = 0; + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + tmp_sum_value += src_ic[0]; + tmp_sum_value += src_ic[1]; + tmp_sum_value += src_ic[2]; + tmp_sum_value += src_ic[3]; + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + tmp_sum_value += src_ic[0]; + pack_ic[0] = src_ic[0]; + src_ic += 1; + pack_ic += 1; + } + + for (int oci = 0; oci < oc_8div; oci += C8NUM) { + for (int curoi = 0; curoi < C8NUM; curoi++) { + input_sum_oc[curoi] = tmp_sum_value * filter_zp[oci + curoi]; + } + input_sum_oc += inputsum_stride; + } + if (oc_8div != output_channel) { + for (int oci = 0; oci < oc_8res; oci += 1) { + input_sum_oc[oci] = tmp_sum_value * filter_zp[oc_8div + oci]; + } + for (int oci = oc_8res; oci < C8NUM; oci += 1) { + input_sum_oc[oci] = 0; + } + } /* oc8 res done */ + + src_r += input_channel; + pack_r += C4NUM; + input_sum_r += C8NUM; + } + + for (int hwi = plane_size; hwi < hw8; hwi++) { + for (int oc = 0; oc < oc8; oc++) { + int oc8div = oc / C8NUM, oc8res = oc % C8NUM; + input_sum[oc8div * inputsum_stride + hwi * C8NUM + oc8res] = 0; + } + } + } + return; +} + +void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, + size_t plane_size, ConvParameter *conv_param) { + int ic4 = UP_ROUND(input_channel, C4NUM); + size_t hw_8div = plane_size / C8NUM * C8NUM; + size_t ic_4div = input_channel / C4NUM * C4NUM; + int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_; + + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + /* per layer */ + for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + int32_t *input_sum_r = input_sum + hwi; +#ifdef ENABLE_ARM64 + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + asm volatile( + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + "mov x14, %[input_sum_r] \n" + "dup v20.4s, %w[filter_zp] \n" + + "mov x10, %[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x0, #0 \n" + "1: \n" + "cmp x0, %[ic_4div] \n" + "add x0, x0, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + "ld1 {v1.s}[0], [x12], %[src_stride]\n" + "ld1 {v1.s}[1], [x12], %[src_stride]\n" + "ld1 {v1.s}[2], [x12], %[src_stride]\n" + "ld1 {v1.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 1b \n" + + "3: \n" /* col res 1 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], 
%[src_stride]\n" + "ld1 {v1.b}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[8], [x12], %[src_stride]\n" + "ld1 {v1.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "4: \n" /* col res 2 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "5: \n" /* col res 3 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[2], [x13], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.b}[6], [x13], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[10], [x13], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "6: \n" + "mul v16.4s, v16.4s, v20.4s \n" + "mul v17.4s, v17.4s, v20.4s \n" + + "st1 {v16.4s}, [x14], #16 \n" + "st1 {v17.4s}, [x14], #16 \n" + + : + : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r), + [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp) + : "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", + "v20"); +#else + int32_t tmp_sum_value[8] = {0}; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[0 + i * input_channel]; + tmp_sum_value[i] += src_ic[1 + i * input_channel]; + tmp_sum_value[i] += src_ic[2 + i * input_channel]; + tmp_sum_value[i] += src_ic[3 + i * input_channel]; + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[i * input_channel]; + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for (int ici = input_channel; ici < ic4; ici += 1) { + for (int i = 0; i < C8NUM; i++) 
{ + pack_ic[i * C4NUM] = 0; + } + pack_ic += 1; + } + + for (int i = 0; i < C8NUM; i++) { + input_sum_r[i] = tmp_sum_value[i] * filter_zp; + } +#endif + src_r += input_channel * C8NUM; + pack_r += ic4 * C8NUM; + } + + if (hw_8div != plane_size) { + memset(pack_r, 0, C8NUM * ic4); + for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { + int32_t tmp_sum_value = 0; + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + tmp_sum_value += src_ic[0]; + tmp_sum_value += src_ic[1]; + tmp_sum_value += src_ic[2]; + tmp_sum_value += src_ic[3]; + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + tmp_sum_value += src_ic[0]; + pack_ic[0] = src_ic[0]; + src_ic += 1; + pack_ic += 1; + } + input_sum[hwi] = tmp_sum_value * filter_zp; + src_r += input_channel; + pack_r += C4NUM; + } + for (int hwi = plane_size; hwi < UP_ROUND(plane_size, C8NUM); hwi++) { + input_sum[hwi] = 0; + } + } + return; +} + +void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num, + int block_index, int32_t *filter_zp, int32_t *input_sum, ConvParameter *conv_param, + bool per_channel, bool is_optimize) { + // input format : nhwc + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int kernel_plane = kernel_h * kernel_w; + + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * stride_h - pad_h; + int input_w = block_start % out_w * stride_w - pad_w; + int input_stride = input_h * in_w * in_channel + input_w * in_channel; + int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); + int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h)); + int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); + int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); + if (dilation_w == 1 && dilation_h == 1) { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * in_w * in_channel + input_stride; + int input_x_stride = input_y_stride + kw_s * in_channel; + int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane; + memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, (kw_e - kw_s) * in_channel); + } // kernel_h loop + } else { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; + for (int k = kw_s; k < kw_e; ++k) { + int input_x_stride = input_y_stride + k * dilation_w * in_channel; + int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane; + memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, in_channel); + } + } // kernel_h loop + } + } // tile num loop + int deep = kernel_plane * in_channel; + if (is_optimize) { + if (per_channel) { + Conv1x1PreOptPeroc(matmul_input, packed_input, input_sum, deep, conv_param->output_channel_, real_cal_num, + filter_zp, C8NUM * C8NUM); + } else { + Conv1x1PreOptPert(matmul_input, 
packed_input, input_sum, deep, real_cal_num, conv_param); + } + } else { + RowMajor2Row16x4MajorInt8(matmul_input, packed_input, real_cal_num, deep); + if (per_channel) { +#ifdef ENABLE_ARM32 + PackInputSum16x4PerChannelArm32(packed_input, input_sum, filter_zp, real_cal_num, deep, + conv_param->output_channel_); +#else + PackInputSum16x4PerChannel(packed_input, input_sum, filter_zp, real_cal_num, deep, conv_param->output_channel_); +#endif + } else { + size_t hw4 = UP_ROUND(real_cal_num, C4NUM); + size_t ic16 = UP_ROUND(deep, C16NUM); + PackInputSum16x4PerLayer(packed_input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, + ic16); + } + } +} + +void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param) { + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int ic8_round = UP_ROUND(in_channel, C8NUM); + int ic8 = in_channel / C8NUM * C8NUM; + int in_plane = in_h * in_w; + + for (int b = 0; b < in_batch; b++) { + int src_batch_offset = b * in_channel * in_plane; + int dst_batch_offset = b * ic8_round * in_plane; + for (int k = 0; k < in_plane; k++) { + int src_plane_offset = src_batch_offset + k * in_channel; + int dst_plane_offset = dst_batch_offset + k * C8NUM; + for (int i = 0; i < ic8; i += 8) { + int src_c_offset = src_plane_offset + i; + int dst_c_offset = dst_plane_offset + i * in_plane; +#ifdef ENABLE_ARM + vst1q_s16(packed_input + dst_c_offset, vmovl_s8(vld1_s8(input_data + src_c_offset))); +#else + for (int j = 0; j < C8NUM; ++j) { + (packed_input + dst_c_offset)[j] = (int16_t)(input_data + src_c_offset)[j]; + } +#endif + } // ic8_minus loop + int res_c = in_channel - ic8; + int tmp_ic_offset = ic8 * in_plane; + for (int l = 0; l < res_c; ++l) { + int src_c_offset = src_plane_offset + ic8 + l; + int dst_c_offset = dst_plane_offset + tmp_ic_offset + l; + (packed_input + dst_c_offset)[0] = (int16_t)(input_data + src_c_offset)[0]; + } // res ic loop + int res2 = ic8_round - in_channel; + for (int l = 0; l < res2; ++l) { + int dst_c_offset = dst_plane_offset + tmp_ic_offset + res_c + l; + (packed_input + dst_c_offset)[0] = 0; + } // res ic loop + } // kh * kw loop + } +} + +void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param) { + // origin weight format : ohwi + int input_channel = conv_param->input_channel_; + int ic8 = input_channel / C8NUM * C8NUM; + int ic8_round = UP_ROUND(input_channel, C8NUM); + int output_channel = conv_param->output_channel_; + QuantArg *filter_zp = conv_param->conv_quant_arg_.filter_quant_args_; + int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; + + for (int k = 0; k < kernel_plane; k++) { + int src_kernel_offset = k * input_channel; + int dst_kernel_offset = k * C8NUM; + for (int o = 0; o < output_channel; o++) { + int32_t zp; + if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { + zp = filter_zp[0].zp_; + } else { + zp = filter_zp[o].zp_; + } + int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; + int dst_oc_offset = dst_kernel_offset + o * ic8_round * kernel_plane; + int i = 0; + for (; i < ic8; i += C8NUM) { + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + i * kernel_plane; +#ifdef ENABLE_ARM64 + int8x8_t src_s8 = vld1_s8(origin_weight_data + src_ic_offset); + int16x8_t src_s16 = vmovl_s8(src_s8); + int16x4_t src1_s16 = vget_low_s16(src_s16); + int16x4_t 
src2_s16 = vget_high_s16(src_s16); + int32x4_t src1_s32 = vmovl_s16(src1_s16); + int32x4_t src2_s32 = vmovl_s16(src2_s16); + int32x4_t zp_s32 = vdupq_n_s32(zp); + int32x4_t dst1_s32 = vsubq_s32(src1_s32, zp_s32); + int32x4_t dst2_s32 = vsubq_s32(src2_s32, zp_s32); + int16x4_t dst1_s16 = vqmovn_s32(dst1_s32); + int16x4_t dst2_s16 = vqmovn_s32(dst2_s32); + vst1_s16(packed_weight_data + dst_ic_offset, dst1_s16); + vst1_s16(packed_weight_data + dst_ic_offset + 4, dst2_s16); +#else + for (int ci = 0; ci < C8NUM; ++ci) { + (packed_weight_data + dst_ic_offset + ci)[0] = (int16_t)((origin_weight_data + src_ic_offset + ci)[0] - zp); + } +#endif + } + dst_oc_offset += ic8 * kernel_plane; + for (; i < input_channel; i++) { + int c8_block_rem = i % C8NUM; + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + c8_block_rem; + (packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - zp); + } + } + } +} + +void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param) { + size_t hw = conv_param->output_h_ * conv_param->output_w_; + size_t hw4 = UP_ROUND(hw, C4NUM); + size_t ic16 = UP_ROUND(conv_param->input_channel_, C16NUM); + if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { + PackInputSum16x4PerLayer(input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16); + } else { +#ifdef ENABLE_ARM32 + PackInputSum16x4PerChannelArm32(input, input_sum, filter_zp, hw, conv_param->input_channel_, + conv_param->output_channel_); +#else + PackInputSum16x4PerChannel(input, input_sum, filter_zp, hw, conv_param->input_channel_, + conv_param->output_channel_); +#endif + } + return; +} + +void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) { + /* normal matmul : 4x16 * 16x4 -> 4x4 */ +#ifdef ENABLE_ARM + PreSum4x16Int8Pert(src, dst, row4, col16, filter_zp); +#else + for (int r = 0; r < row4; r++) { + int32_t tmp_value = 0; + for (int c = 0; c < col16; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM, c16div = c / C16NUM, c16mod = c % C16NUM; + int src_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod; + tmp_value += src[src_index]; + } + dst[r] = tmp_value * filter_zp; + } +#endif + return; +} +void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param) { + int input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; + int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); + int unit = conv_param->input_h_ * conv_param->input_w_; + + for (int b = 0; b < conv_param->input_batch_; b++) { + const int8_t *src_b = src + b * unit * conv_param->input_channel_; + int16_t *dst_b = dst + b * unit * ic4 * C4NUM; + for (int k = 0; k < unit; k++) { + const int8_t *src_k = src_b + k * conv_param->input_channel_; + int16_t *dst_k = dst_b + k * ic4 * C4NUM; + for (int c = 0; c < conv_param->input_channel_; c++) { + dst_k[c] = (int16_t)(src_k[c] - input_zp); + } + } + } +} + +void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, + ConvQuantArg *quant_qrg) { + int weight_zp = quant_qrg->filter_quant_args_[0].zp_; + for (int c = 0; c < channel; c++) { + if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) { + weight_zp = quant_qrg->filter_quant_args_[c].zp_; + } + int c8_block_num = c / C8NUM; + int c8_block_rem = c % C8NUM; + const int8_t *src_c = origin_weight + c * plane; + int16_t *dst_c = packed_weight_ + c8_block_num 
* plane * C8NUM; + for (int k = 0; k < plane; k++) { + const int8_t *src_kernel = src_c + k; + int16_t *dst_kernel = dst_c + C8NUM * k + c8_block_rem; + *dst_kernel = (int16_t)(src_kernel[0] - weight_zp); + } + } +} + +void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, + ConvQuantArg *quant_qrg) { + int weight_zp = quant_qrg->filter_quant_args_[0].zp_; + for (int c = 0; c < channel; c++) { + if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) { + weight_zp = quant_qrg->filter_quant_args_[c].zp_; + } + int c4_block_num = c / C4NUM; + int c4_block_rem = c % C4NUM; + const int8_t *src_c = origin_weight + c * plane; + int16_t *dst_c = packed_weight_ + c4_block_num * plane * C4NUM; + for (int k = 0; k < plane; k++) { + const int8_t *src_kernel = src_c + k; + int16_t *dst_kernel = dst_c + C4NUM * k + c4_block_rem; + *dst_kernel = (int16_t)(src_kernel[0] - weight_zp); + } + } +} +void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int c4_channel = c4 * C4NUM; + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc4_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + int8_t *dst_per_plane = (int8_t *)dst + nhwc4_batch_offset + i * c4_channel; + memcpy(dst_per_plane, (int8_t *)src + batch_offset + i * channel, channel); + for (int j = channel; j < c4_channel; ++j) { + dst_per_plane[j] = 0; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + int nhwc4_batch_offset = b * nhwc4_batch_unit_offset; + for (int i = 0; i < plane; i++) { + memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc4_batch_offset + i * c4 * C4NUM, + channel); + } + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int nhwc8_batch_unit_offset = c8 * C8NUM * plane; + int ic_remainder_ = channel % C8NUM; + if (ic_remainder_ != 0) { + int nhwc8_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + memcpy((int8_t *)dst + nhwc8_batch_offset + i * c8 * C8NUM, (int8_t *)src + batch_offset + i * channel, + channel); + } + nhwc8_batch_offset += nhwc8_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int nhwc8_batch_unit_offset = c8 * C8NUM * plane; + int ic_remainder_ = channel % C8NUM; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + int nhwc8_batch_offset = b * nhwc8_batch_unit_offset; + for (int i = 0; i < plane; i++) { 
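+ /* copy only the valid 'channel' bytes of this pixel; the alignment padding that rounds each source pixel up to c8 * C8NUM bytes is skipped */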
+ memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc8_batch_offset + i * c8 * C8NUM, + channel); + } + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNCHWToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c8 * C8NUM; + for (int c = 0; c < channel; c++) { + int c8_block_num = c / C8NUM; + int c8_block_rem = c % C8NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c8_block_num * plane * C8NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C8NUM * k + c8_block_rem; + ((int8_t *)dst + dst_kernel_offset)[0] = ((int8_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C4NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c4 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C4NUM; + int dst_c_offset = dst_kernel_offset + c * C4NUM; + ((int8_t *)dst + dst_c_offset)[0] = ((int8_t *)src + src_c_offset)[0]; + ((int8_t *)dst + dst_c_offset)[1] = ((int8_t *)src + src_c_offset)[1]; + ((int8_t *)dst + dst_c_offset)[2] = ((int8_t *)src + src_c_offset)[2]; + ((int8_t *)dst + dst_c_offset)[3] = ((int8_t *)src + src_c_offset)[3]; + } + // res part + int res_c = channel - (c4 - 1) * C4NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i; + ((int8_t *)dst + dst_res_c_offset)[0] = ((int8_t *)src + src_res_c_offset)[0]; + } + } + } +} + +void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int hw = 0; hw < plane; hw++) { + for (int c = 0; c < channel; c++) { + int c8div = c / C8NUM; + int c8mod = c % C8NUM; + int src_index = n * plane * channel + hw * channel + c; + int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; + ((int8_t *)dst)[dst_index] = ((int8_t *)src)[src_index]; + } + } + } + return; +} + +void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int c = 0; c < channel; c++) { + for (int hw = 0; hw < plane; hw++) { + int nhwc_index = n * channel * plane + hw * channel + c; + int nchw_index = n * channel * plane + c * plane + hw; + ((int8_t *)(dst))[nhwc_index] = ((const int8_t *)(src))[nchw_index]; + } + } + } + return; +} + +void PackNHWCToNCHWInt8(const void *src, void *dst, int batches, int plane, int channel) { + int hw8 = plane / C8NUM * C8NUM; + int c8 = channel / C8NUM * C8NUM; + int batch = plane * channel; + for (int n = 0; n < batches; n++) { + const int8_t *src_batch = (const int8_t *)src + n * batch; + int8_t *dst_batch = (int8_t *)dst + n * batch; + int hw = 0; + for (; hw < hw8; hw += C8NUM) { + int c = 0; + for (; c < c8; c += C8NUM) { + const int8_t *src_ptr = src_batch + hw * channel + c; + int8_t *dst_ptr = dst_batch + c * 
plane + hw; +#ifdef ENABLE_ARM64 + size_t srcStride = channel * sizeof(int8_t); + size_t dstStride = plane * sizeof(int8_t); + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8b}, [x10], %[srcStride]\n" + "ld1 {v1.8b}, [x10], %[srcStride]\n" + "ld1 {v2.8b}, [x10], %[srcStride]\n" + "ld1 {v3.8b}, [x10], %[srcStride]\n" + + "trn1 v4.8b, v0.8b, v1.8b\n" + "trn2 v5.8b, v0.8b, v1.8b\n" + "trn1 v6.8b, v2.8b, v3.8b\n" + "trn2 v7.8b, v2.8b, v3.8b\n" + + "ld1 {v0.8b}, [x10], %[srcStride]\n" + "ld1 {v1.8b}, [x10], %[srcStride]\n" + "ld1 {v2.8b}, [x10], %[srcStride]\n" + "ld1 {v3.8b}, [x10], %[srcStride]\n" + + "trn1 v8.4h, v4.4h, v6.4h\n" + "trn2 v9.4h, v4.4h, v6.4h\n" + "trn1 v10.4h, v5.4h, v7.4h\n" + "trn2 v11.4h, v5.4h, v7.4h\n" + + "trn1 v4.8b, v0.8b, v1.8b\n" + "trn2 v5.8b, v0.8b, v1.8b\n" + "trn1 v6.8b, v2.8b, v3.8b\n" + "trn2 v7.8b, v2.8b, v3.8b\n" + + "trn1 v12.4h, v4.4h, v6.4h\n" + "trn2 v13.4h, v4.4h, v6.4h\n" + "trn1 v14.4h, v5.4h, v7.4h\n" + "trn2 v15.4h, v5.4h, v7.4h\n" + + "trn1 v0.2s, v8.2s, v12.2s\n" + "trn2 v4.2s, v8.2s, v12.2s\n" + "trn1 v1.2s, v10.2s, v14.2s\n" + "trn2 v5.2s, v10.2s, v14.2s\n" + "trn1 v2.2s, v9.2s, v13.2s\n" + "trn2 v6.2s, v9.2s, v13.2s\n" + "trn1 v3.2s, v11.2s, v15.2s\n" + "trn2 v7.2s, v11.2s, v15.2s\n" + + "st1 {v0.8b}, [x11], %[dstStride]\n" + "st1 {v1.8b}, [x11], %[dstStride]\n" + "st1 {v2.8b}, [x11], %[dstStride]\n" + "st1 {v3.8b}, [x11], %[dstStride]\n" + "st1 {v4.8b}, [x11], %[dstStride]\n" + "st1 {v5.8b}, [x11], %[dstStride]\n" + "st1 {v6.8b}, [x11], %[dstStride]\n" + "st1 {v7.8b}, [x11], %[dstStride]\n" + : + : + [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31"); +#elif ENABLE_ARM32 + size_t srcStride = channel * sizeof(int8_t); + size_t dstStride = plane * sizeof(int8_t); + asm volatile( + "mov r10, %[src_ptr]\n" + "mov r12, %[dst_ptr]\n" + + "vld1.8 {d0}, [r10], %[srcStride]\n" + "vld1.8 {d1}, [r10], %[srcStride]\n" + "vld1.8 {d2}, [r10], %[srcStride]\n" + "vld1.8 {d3}, [r10], %[srcStride]\n" + "vld1.8 {d4}, [r10], %[srcStride]\n" + "vld1.8 {d5}, [r10], %[srcStride]\n" + "vld1.8 {d6}, [r10], %[srcStride]\n" + "vld1.8 {d7}, [r10], %[srcStride]\n" + + "vtrn.8 d0, d1\n" + "vtrn.8 d2, d3\n" + "vtrn.8 d4, d5\n" + "vtrn.8 d6, d7\n" + + "vtrn.16 d0, d2\n" + "vtrn.16 d1, d3\n" + "vtrn.16 d4, d6\n" + "vtrn.16 d5, d7\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + + "vst1.8 {d0}, [r12], %[dstStride]\n" + "vst1.8 {d1}, [r12], %[dstStride]\n" + "vst1.8 {d2}, [r12], %[dstStride]\n" + "vst1.8 {d3}, [r12], %[dstStride]\n" + "vst1.8 {d4}, [r12], %[dstStride]\n" + "vst1.8 {d5}, [r12], %[dstStride]\n" + "vst1.8 {d6}, [r12], %[dstStride]\n" + "vst1.8 {d7}, [r12], %[dstStride]\n" + : + : + [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15"); +#else + for (int tr = 0; tr < C8NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; + } + } +#endif + } + for (; c < channel; c++) { + const int8_t *src_ptr = src_batch + hw * channel + c; + int8_t *dst_ptr = dst_batch + c * plane + hw; 
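+ /* channels left over after the 8x8 blocks: copy this 8-pixel column element by element */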
+ for (size_t i = 0; i < C8NUM; i++) { + dst_ptr[i] = src_ptr[i * channel]; + } + } + } + for (; hw < plane; hw++) { + const int8_t *src_ptr = src_batch + hw * channel; + int8_t *dst_ptr = dst_batch + hw; + for (size_t i = 0; i < channel; i++) { + dst_ptr[i * plane] = src_ptr[i]; + } + } + } + return; +} diff --git a/mindspore/lite/nnacl/int8/pack_int8.h b/mindspore/lite/nnacl/int8/pack_int8.h new file mode 100644 index 00000000000..ea51ac89adc --- /dev/null +++ b/mindspore/lite/nnacl/int8/pack_int8.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_NNACL_INT8_PACK_INT8_H_ +#define MINDSPORE_LITE_NNACL_INT8_PACK_INT8_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/int8/matmul_int8.h" +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); +void PackNCHWToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel); +void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel); +void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel); + +void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param); +void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16); +void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param); +void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param); +void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num, + int block_index, int32_t *filter_zp, int32_t *input_sum, ConvParameter *conv_param, + bool per_channel, bool is_optimize); +#ifdef ENABLE_ARM +void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp); +void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div, + size_t oc_res, size_t stride); +#endif + +void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param); +void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, + ConvQuantArg *quant_qrg); +void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, + ConvQuantArg 
*quant_qrg); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_INT8_PAD_INT8_H_ diff --git a/mindspore/lite/nnacl/int8/power_int8.h b/mindspore/lite/nnacl/int8/power_int8.h index 3cac590aff7..86fe6509801 100644 --- a/mindspore/lite/nnacl/int8/power_int8.h +++ b/mindspore/lite/nnacl/int8/power_int8.h @@ -19,7 +19,7 @@ #include "nnacl/op_base.h" #include "nnacl/power_parameter.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/quantization/quantize.c b/mindspore/lite/nnacl/int8/quantize.c similarity index 96% rename from mindspore/lite/nnacl/quantization/quantize.c rename to mindspore/lite/nnacl/int8/quantize.c index 2da215a567a..a28ae44fffc 100644 --- a/mindspore/lite/nnacl/quantization/quantize.c +++ b/mindspore/lite/nnacl/int8/quantize.c @@ -14,8 +14,7 @@ * limitations under the License. */ -#include "nnacl/quantization/quantize.h" -#include +#include "nnacl/int8/quantize.h" const uint64_t dSignMask = 1ull << 63; const uint64_t dExponentMask = 0x7ffull << 52; @@ -57,8 +56,6 @@ void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t /* multipiler is in[0x40000000, 0x7FFFFF80] range */ *quantized_multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); if (quantized_multiplier[0] < INT32_C(0x40000000) || quantized_multiplier[0] > INT32_C(0x7FFFFF80)) { - printf("quantized multiplier must be in [0x40000000, 0x7FFFFF80] range, now multiplier is %d\n", - quantized_multiplier[0]); return; } /* shift is in [0, 31] range */ diff --git a/mindspore/lite/nnacl/quantization/quantize.h b/mindspore/lite/nnacl/int8/quantize.h similarity index 79% rename from mindspore/lite/nnacl/quantization/quantize.h rename to mindspore/lite/nnacl/int8/quantize.h index 971a533642d..72595b4cfc0 100644 --- a/mindspore/lite/nnacl/quantization/quantize.h +++ b/mindspore/lite/nnacl/int8/quantize.h @@ -28,11 +28,6 @@ #define FILTER_PER_CHANNEL 0b010 #define OUTPUT_PER_CHANNEL 0b100 -typedef struct QuantArg { - float scale_; - int32_t zp_; -} QuantArg; - typedef struct ConvQuantArg { RoundingMode round_mode_; CalFixedMultiplierMode quant_multiplier_mode_; @@ -58,24 +53,6 @@ typedef struct ConcatQuantArg { int8_t output_activation_max_; } ConcatQuantArg; -typedef struct SqueezeQuantArg { - QuantArg *in_quant_args_; - QuantArg *out_quant_args_; -} SqueezeQuantArg; - -typedef struct UnSqueezeQuantArg { - int *input_sizes_; - int output_size_; - int **input_shapes_; - int *output_shape_; - float alpha; - int axis_; - size_t input_num_; - size_t output_dim_; - QuantArg in_quant_args_; - QuantArg out_quant_args_; -} UnSqueezeQuantArg; - typedef struct PreluQuantArg { int *input_sizes_; int output_size_; @@ -103,22 +80,6 @@ typedef struct MatmulQuantArg { int32_t quant_multiplier; } MatmulQuantArg; -typedef struct PadQuantArg { - QuantArg *in_quant_args_; - QuantArg *out_quanr_args_; - int8_t *constant_value_; -} PadQuantArg; - -typedef struct MulQuantArg { - QuantArg in_quant_args_[2]; - QuantArg out_quant_arg_; - int output_multiplier_; - int output_activation_min_; - int output_activation_max_; - int shift_left_; - int shift_right_; -} MulQuantArg; - typedef struct CropQuantArg { QuantArg in_args_; QuantArg out_args_; @@ -142,13 +103,6 @@ typedef struct GatherQuantArg { int zp_out_; } GatherQuantArg; -typedef struct SplitQuantArg { - QuantArg in_args_; - QuantArg out_args_[20]; - int output_activation_min_; - int output_activation_max_; -} SplitQuantArg; - 
typedef struct SoftmaxQuantArg { QuantArg in_quant_args_; QuantArg out_quant_arg_; @@ -159,19 +113,6 @@ typedef struct SoftmaxQuantArg { int shift_right_; } SoftmaxQuantArg; -typedef struct ReshapeQuantArg { - QuantArg in_args_; - QuantArg out_args_; - int output_activation_min_; - int output_activation_max_; -} ReshapeQuantArg; - -typedef struct QuantMulArg { - int32_t multiplier_; - int left_shift_; - int right_shift_; -} QuantMulArg; - typedef struct SubQuantArg { QuantArg in0_args_; QuantArg in1_args_; @@ -227,21 +168,6 @@ typedef struct ReduceQuantArg { int sum_square_right_shift_; } ReduceQuantArg; -typedef struct SliceQuantArg { - QuantArg in_args_; - QuantArg out_args_; - int output_activation_min_; - int output_activation_max_; -} SliceQuantArg; - -typedef struct PowerQuantArg { - QuantArg in_args_; - QuantArg exp_args_; - QuantArg out_args_; - int output_activation_min_; - int output_activation_max_; -} PowerQuantArg; - typedef struct LeakyReluQuantArg { OpParameter op_parameter_; PreluQuantArg quant_arg; diff --git a/mindspore/lite/nnacl/int8/reduce_int8.c b/mindspore/lite/nnacl/int8/reduce_int8.c index 3b206156250..a2fa7204483 100644 --- a/mindspore/lite/nnacl/int8/reduce_int8.c +++ b/mindspore/lite/nnacl/int8/reduce_int8.c @@ -17,7 +17,7 @@ #include #include "nnacl/int8/reduce_int8.h" #include "nnacl/errorcode.h" -#include "nnacl/quantization/fixed_point.h" +#include "nnacl/int8/fixed_point.h" #include "nnacl/common_func.h" int ReduceMeanN(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { diff --git a/mindspore/lite/nnacl/int8/reduce_int8.h b/mindspore/lite/nnacl/int8/reduce_int8.h index c573514fc99..44d845f5ff9 100644 --- a/mindspore/lite/nnacl/int8/reduce_int8.h +++ b/mindspore/lite/nnacl/int8/reduce_int8.h @@ -16,7 +16,9 @@ #ifndef MINDSPORE_LITE_NNACL_INT8_REDUCE_INT8_H_ #define MINDSPORE_LITE_NNACL_INT8_REDUCE_INT8_H_ -#include "nnacl/quantization/quantize.h" + +#include "nnacl/int8/quantize.h" + #ifdef __cplusplus extern "C" { #endif diff --git a/mindspore/lite/nnacl/int8/relux_int8.h b/mindspore/lite/nnacl/int8/relux_int8.h index 4be944ae4ea..78b78596e82 100644 --- a/mindspore/lite/nnacl/int8/relux_int8.h +++ b/mindspore/lite/nnacl/int8/relux_int8.h @@ -19,8 +19,8 @@ #include #include "nnacl/op_base.h" #include "nnacl/errorcode.h" -#include "nnacl/quantization/fixed_point.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/fixed_point.h" +#include "nnacl/int8/quantize.h" typedef struct ReluXQuantArg { QuantArg input_arg; diff --git a/mindspore/lite/nnacl/int8/reshape_int8.h b/mindspore/lite/nnacl/int8/reshape_int8.h index fc0d2ac074d..5e88d859881 100644 --- a/mindspore/lite/nnacl/int8/reshape_int8.h +++ b/mindspore/lite/nnacl/int8/reshape_int8.h @@ -16,6 +16,8 @@ #ifndef MINDSPORE_LITE_NNACL_INT8_RESHAHPE_INT8_H_ #define MINDSPORE_LITE_NNACL_INT8_RESHAHPE_INT8_H_ + +#include #include "nnacl/op_base.h" #include "nnacl/reshape_parameter.h" diff --git a/mindspore/lite/nnacl/int8/resize_int8.c b/mindspore/lite/nnacl/int8/resize_int8.c index 228f70a40f9..31dd3e92b1d 100644 --- a/mindspore/lite/nnacl/int8/resize_int8.c +++ b/mindspore/lite/nnacl/int8/resize_int8.c @@ -16,7 +16,7 @@ #include #include "nnacl/int8/resize_int8.h" #include "nnacl/common_func.h" -#include "nnacl/quantization/fixed_point.h" +#include "nnacl/int8/fixed_point.h" #include "nnacl/errorcode.h" int ResizeBilinearInt8(const int8_t *input_ptr, int8_t *output_ptr, int batch, int in_h, int in_w, int out_h, int out_w, diff --git 
a/mindspore/lite/nnacl/int8/resize_int8.h b/mindspore/lite/nnacl/int8/resize_int8.h index 438bba4d7e3..49a328262a3 100644 --- a/mindspore/lite/nnacl/int8/resize_int8.h +++ b/mindspore/lite/nnacl/int8/resize_int8.h @@ -21,7 +21,7 @@ #endif #include #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" #include "nnacl/resize_parameter.h" #ifdef __cplusplus diff --git a/mindspore/lite/nnacl/int8/scale_int8.c b/mindspore/lite/nnacl/int8/scale_int8.c index dd9374c572f..bb33c643f17 100644 --- a/mindspore/lite/nnacl/int8/scale_int8.c +++ b/mindspore/lite/nnacl/int8/scale_int8.c @@ -15,7 +15,7 @@ */ #include "nnacl/int8/scale_int8.h" -#include "nnacl/quantization/fixed_point.h" +#include "nnacl/int8/fixed_point.h" #ifdef ENABLE_NEON int16x4_t ClacSumHalfWordMul2(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec, diff --git a/mindspore/lite/nnacl/int8/scale_int8.h b/mindspore/lite/nnacl/int8/scale_int8.h index a773d6df1ad..993e5b808ca 100644 --- a/mindspore/lite/nnacl/int8/scale_int8.h +++ b/mindspore/lite/nnacl/int8/scale_int8.h @@ -19,6 +19,8 @@ #include "nnacl/op_base.h" #include "nnacl/scale.h" +#include "nnacl/nnacl_common.h" + #ifdef __cplusplus extern "C" { #endif diff --git a/mindspore/lite/nnacl/int8/sigmoid_int8.h b/mindspore/lite/nnacl/int8/sigmoid_int8.h index 1e546f6cd4c..c2ecfdae59d 100644 --- a/mindspore/lite/nnacl/int8/sigmoid_int8.h +++ b/mindspore/lite/nnacl/int8/sigmoid_int8.h @@ -19,7 +19,7 @@ #include #include "nnacl/op_base.h" #include "nnacl/errorcode.h" -#include "nnacl/quantization/fixed_point.h" +#include "nnacl/int8/fixed_point.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/int8/slice_int8.c b/mindspore/lite/nnacl/int8/slice_int8.c index 6fb522441cb..9604e4c8693 100644 --- a/mindspore/lite/nnacl/int8/slice_int8.c +++ b/mindspore/lite/nnacl/int8/slice_int8.c @@ -15,8 +15,6 @@ */ #include "nnacl/int8/slice_int8.h" -#include -#include "nnacl/quantization/fixed_point.h" int SliceInt8NoParallel(const int8_t *input, int8_t *output, SliceParameter *param) { double input_scale = param->quant_arg_.in_args_.scale_; diff --git a/mindspore/lite/nnacl/int8/slice_int8.h b/mindspore/lite/nnacl/int8/slice_int8.h index 5a67c083ccf..70ac1fbd8c9 100644 --- a/mindspore/lite/nnacl/int8/slice_int8.h +++ b/mindspore/lite/nnacl/int8/slice_int8.h @@ -16,8 +16,11 @@ #ifndef MINDSPORE_LITE_NNACL_INT8_SLICE_INT8_H_ #define MINDSPORE_LITE_NNACL_INT8_SLICE_INT8_H_ +#include +#include #include "nnacl/op_base.h" #include "nnacl/slice_parameter.h" +#include "nnacl/int8/fixed_point.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/int8/softmax_int8.c b/mindspore/lite/nnacl/int8/softmax_int8.c index 4cf455aef72..58e151f4539 100644 --- a/mindspore/lite/nnacl/int8/softmax_int8.c +++ b/mindspore/lite/nnacl/int8/softmax_int8.c @@ -15,9 +15,6 @@ */ #include "nnacl/int8/softmax_int8.h" -#include -#include "nnacl/quantization/fixed_point.h" -#include "nnacl/quantization/quantize.h" int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int *exp_data, int *sum_data, SoftmaxQuantArg quant_param, SoftmaxParameter *parameter) { diff --git a/mindspore/lite/nnacl/int8/softmax_int8.h b/mindspore/lite/nnacl/int8/softmax_int8.h index a6adae25fb9..83c8aa7ca06 100644 --- a/mindspore/lite/nnacl/int8/softmax_int8.h +++ b/mindspore/lite/nnacl/int8/softmax_int8.h @@ -17,9 +17,11 @@ #ifndef MINDSPORE_LITE_NNACL_INT8_SOFTMAX_INT8_H_ #define MINDSPORE_LITE_NNACL_INT8_SOFTMAX_INT8_H_ +#include 
#include "nnacl/op_base.h" #include "nnacl/softmax_parameter.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/fixed_point.h" +#include "nnacl/int8/quantize.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/int8/split_int8.h b/mindspore/lite/nnacl/int8/split_int8.h index f2d18954e83..065674c56a1 100644 --- a/mindspore/lite/nnacl/int8/split_int8.h +++ b/mindspore/lite/nnacl/int8/split_int8.h @@ -16,6 +16,8 @@ #ifndef MINDSPORE_LITE_NNACL_INT8_SPLIT_INT8_H_ #define MINDSPORE_LITE_NNACL_INT8_SPLIT_INT8_H_ + +#include #include "nnacl/op_base.h" #include "nnacl/split_parameter.h" diff --git a/mindspore/lite/nnacl/int8/squeeze_int8.h b/mindspore/lite/nnacl/int8/squeeze_int8.h index 6119c4c6531..e7312ae6a42 100644 --- a/mindspore/lite/nnacl/int8/squeeze_int8.h +++ b/mindspore/lite/nnacl/int8/squeeze_int8.h @@ -17,7 +17,7 @@ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_SQUEEZE_INT8_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_SQUEEZE_INT8_H_ -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" #include "nnacl/squeeze_parameter.h" #ifdef __cplusplus diff --git a/mindspore/lite/nnacl/int8/sub_int8.c b/mindspore/lite/nnacl/int8/sub_int8.c index 1961dd0d805..ace1417b287 100644 --- a/mindspore/lite/nnacl/int8/sub_int8.c +++ b/mindspore/lite/nnacl/int8/sub_int8.c @@ -19,7 +19,7 @@ #include #include "nnacl/int8/common_func_int8.h" #endif -#include "nnacl/quantization/fixed_point.h" +#include "nnacl/int8/fixed_point.h" #ifdef ENABLE_NEON diff --git a/mindspore/lite/nnacl/int8/sub_int8.h b/mindspore/lite/nnacl/int8/sub_int8.h index 4ac8f500428..6764072e70b 100644 --- a/mindspore/lite/nnacl/int8/sub_int8.h +++ b/mindspore/lite/nnacl/int8/sub_int8.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_NNACL_INT8_SUB_INT8_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/int8/tanh_int8.h b/mindspore/lite/nnacl/int8/tanh_int8.h index e814ce01c10..1ad2cbecf41 100644 --- a/mindspore/lite/nnacl/int8/tanh_int8.h +++ b/mindspore/lite/nnacl/int8/tanh_int8.h @@ -18,8 +18,8 @@ #define MINDSPORE_LITE_NNACL_INT8_TANH_INT8_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" -#include "nnacl/quantization/fixed_point.h" +#include "nnacl/int8/quantize.h" +#include "nnacl/int8/fixed_point.h" #include "nnacl/int8/quant_dtype_cast_int8.h" #include "nnacl/fp32/activation_fp32.h" diff --git a/mindspore/lite/nnacl/int8/unsqueeze_int8.c b/mindspore/lite/nnacl/int8/unsqueeze_int8.c index a7c2b04984d..acf2bb9f581 100644 --- a/mindspore/lite/nnacl/int8/unsqueeze_int8.c +++ b/mindspore/lite/nnacl/int8/unsqueeze_int8.c @@ -16,7 +16,6 @@ #include "nnacl/unsqueeze_parameter.h" #include "nnacl/int8/unsqueeze_int8.h" -#include int Int8Unsqueeze(int8_t *input_ptr, int8_t *output_ptr, UnSqueezeParameter *para_, size_t data_size, int task_id) { float output_scale = para_->quant_arg.out_quant_args_.scale_; diff --git a/mindspore/lite/nnacl/l2_norm_parameter.h b/mindspore/lite/nnacl/l2_norm_parameter.h index 90a0b491b46..f9bf06bb6dc 100644 --- a/mindspore/lite/nnacl/l2_norm_parameter.h +++ b/mindspore/lite/nnacl/l2_norm_parameter.h @@ -17,7 +17,7 @@ #define MINDSPORE_LITE_NNACL_L2NORM_PARAMETER_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" typedef struct L2NormParameter { // Primitive parameter diff --git a/mindspore/lite/nnacl/layer_norm_parameter.h 
b/mindspore/lite/nnacl/layer_norm_parameter.h index e849c15e7ea..dcbb1c54154 100644 --- a/mindspore/lite/nnacl/layer_norm_parameter.h +++ b/mindspore/lite/nnacl/layer_norm_parameter.h @@ -17,7 +17,7 @@ #define MINDSPORE_LITE_NNACL_LAYER_NORM_PARAMETER_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" enum ElementwiseMode { ELEMENTWISE_NOT = 0, ELEMENTWISE_PER_CHANNEL = 1, ELEMENTWISE_PER_NUM = 2 }; typedef struct LayerNormParameter { diff --git a/mindspore/lite/nnacl/matmul_parameter.h b/mindspore/lite/nnacl/matmul_parameter.h index 189e7e2ce04..8b208218bd1 100644 --- a/mindspore/lite/nnacl/matmul_parameter.h +++ b/mindspore/lite/nnacl/matmul_parameter.h @@ -18,7 +18,6 @@ #define MINDSPORE_LITE_NNACL_MATMUL_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16, const int *input_sum, const int *bias); diff --git a/mindspore/lite/nnacl/mul_parameter.h b/mindspore/lite/nnacl/mul_parameter.h index 6b5c61cfde7..7c2d8fe181a 100644 --- a/mindspore/lite/nnacl/mul_parameter.h +++ b/mindspore/lite/nnacl/mul_parameter.h @@ -18,7 +18,16 @@ #define MINDSPORE_LITE_NNACL_MUL_PARAMETER_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" + +typedef struct MulQuantArg { + QuantArg in_quant_args_[2]; + QuantArg out_quant_arg_; + int output_multiplier_; + int output_activation_min_; + int output_activation_max_; + int shift_left_; + int shift_right_; +} MulQuantArg; typedef struct MulParameter { // Primitive parameter diff --git a/mindspore/ccsrc/pybind_api/pybind_patch.h b/mindspore/lite/nnacl/nnacl_common.c similarity index 75% rename from mindspore/ccsrc/pybind_api/pybind_patch.h rename to mindspore/lite/nnacl/nnacl_common.c index a71774b26ac..a07bdc8f903 100644 --- a/mindspore/ccsrc/pybind_api/pybind_patch.h +++ b/mindspore/lite/nnacl/nnacl_common.c @@ -14,11 +14,4 @@ * limitations under the License. */ -#ifndef PYBIND_API_PYBIND_PATCH_H_ -#define PYBIND_API_PYBIND_PATCH_H_ - -namespace pybind11 { -PYBIND11_RUNTIME_EXCEPTION(attribute_error, PyExc_AttributeError) -} - -#endif // PYBIND_API_PYBIND_PATCH_H_ +#include "nnacl/nnacl_common.h" diff --git a/mindspore/lite/nnacl/nnacl_common.h b/mindspore/lite/nnacl/nnacl_common.h new file mode 100644 index 00000000000..65ae6de172a --- /dev/null +++ b/mindspore/lite/nnacl/nnacl_common.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_NNACL_NNACL_COMMON_H_ +#define MINDSPORE_LITE_NNACL_NNACL_COMMON_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +inline void ComputeStrides(const int *shape, int *strides, const int ndim) { + int stride = 1; + for (int i = ndim - 1; i >= 0; i--) { + strides[i] = stride; + stride *= shape[i]; + } +} + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_LITE_NNACL_NNACL_COMMON_H_ diff --git a/mindspore/lite/nnacl/op_base.h b/mindspore/lite/nnacl/op_base.h index eb3b64fbe04..24d7cdecbb4 100644 --- a/mindspore/lite/nnacl/op_base.h +++ b/mindspore/lite/nnacl/op_base.h @@ -28,6 +28,7 @@ #include #include #include +#include #define C2NUM 2 #define C4NUM 4 @@ -78,6 +79,17 @@ typedef struct OpParameter { int thread_num_; } OpParameter; +typedef struct QuantArg { + float scale_; + int32_t zp_; +} QuantArg; + +typedef struct QuantMulArg { + int32_t multiplier_; + int left_shift_; + int right_shift_; +} QuantMulArg; + typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, ActType_Prelu } ActType; typedef enum PadMode { Pad_No, Pad_Same, Pad_Valid } PadMode; typedef enum RoundingMode { Rounding_No, Rounding_Away_from_zero, Rounding_Up } RoundingMode; diff --git a/mindspore/lite/nnacl/pack.c b/mindspore/lite/nnacl/pack.c index 13b5c9a13dd..63056a3ae53 100644 --- a/mindspore/lite/nnacl/pack.c +++ b/mindspore/lite/nnacl/pack.c @@ -14,1177 +14,4 @@ * limitations under the License. */ -#include -#include -#include "nnacl/int8/conv_int8.h" #include "nnacl/pack.h" - -void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) { - return PackNCHWToNHWCFp32(src, dst, 1, plane, channel); -} - -void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel) { - for (int i = 0; i < height; ++i) { - for (int j = 0; j < width; ++j) { - memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float)); - } - } -} - -void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum) { - // original weight format : ohwi - int kernel_h = conv_param->kernel_h_; - int kernel_w = conv_param->kernel_w_; - int in_channel = conv_param->input_channel_; - int out_channel = conv_param->output_channel_; - int oc4 = UP_DIV(out_channel, C4NUM); - int ic4 = UP_DIV(in_channel, C4NUM); - int kernel_plane = kernel_h * kernel_w; - int plane_c4 = UP_DIV(kernel_plane, C4NUM); - int pack_weight_size = oc4 * C4NUM * ic4 * C4NUM * plane_c4 * C4NUM; - int block_size = pack_weight_size / oc4; - QuantArg *filter_args = conv_param->conv_quant_arg_.filter_quant_args_; - - for (int m = 0; m < kernel_plane; m++) { - int kernel_plane_stride = m * in_channel; - int plane_block = m / C4NUM; - int plane_res = m % C4NUM; - int packed_kernel_plane_stride = plane_block * C4NUM * C4NUM * ic4 * C4NUM + plane_res * C4NUM; - for (int i = 0; i < ic4; i++) { - int channel_block_stride = kernel_plane_stride + i * C4NUM; - int packed_channel_block_size = packed_kernel_plane_stride + i * C4NUM * C4NUM * C4NUM; - int ic_remainder = in_channel - i * C4NUM; - int real_ic_num = ic_remainder < C4NUM ? 
ic_remainder : C4NUM; - for (int h = 0; h < real_ic_num; h++) { - int block_stride = channel_block_stride + h; - int packed_block_stride = packed_channel_block_size + h; - for (int j = 0; j < oc4; j++) { - int kernel_block_stride = block_stride + j * C4NUM * kernel_plane * in_channel; - int packed_kernel_block_size = packed_block_stride + j * block_size; - int oc_remainder = out_channel - j * C4NUM; - int real_oc_num = oc_remainder < C4NUM ? oc_remainder : C4NUM; - for (int k = 0; k < real_oc_num; k++) { - int8_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel; - int8_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k * C4NUM * C4NUM; - *packed_data_ptr = origin_data_ptr[0]; - int32_t f_zp; - if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) { - f_zp = filter_args[j * C4NUM + k].zp_; - } else { - f_zp = filter_args[0].zp_; - } - weight_sum[j * C4NUM + k] += (int32_t)(packed_data_ptr[0] - f_zp); - } - } // kernel block loop - } // inchannel block loop - } // channel block loop - } // kernel plane loop -} - -void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size) { - /* support nhwc */ - char *src = (char *)src_ptr; - char *dst = (char *)dst_ptr; - for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) { - int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_u_; - if (src_h < 0 || src_h >= conv_param->input_h_) { - continue; - } - const char *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_ * data_size; - char *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_ * data_size; - for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) { - int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_l_; - if (src_w < 0 || src_w >= conv_param->input_w_) { - continue; - } - memcpy(dst_h_ptr + dst_w * conv_param->input_channel_ * data_size, - src_h_ptr + src_w * conv_param->input_channel_ * data_size, conv_param->input_channel_ * data_size); - } - } - return; -} - -void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param) { - int c4 = UP_ROUND(conv_param->input_channel_, C4NUM); - for (int ic = 0; ic < conv_param->input_channel_; ic++) { - for (int oc = 0; oc < conv_param->output_channel_; oc++) { - int oc4mod = oc % 4; - int oc4div = oc / 4; - int dst_index = oc4div * c4 * C4NUM + ic * C4NUM + oc4mod; - int src_index = oc * conv_param->input_channel_ + ic; - packed_weight[dst_index] = weight_data[src_index]; - } - } - return; -} - -void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) { - /* normal matmul : 4x16 * 16x4 -> 4x4 */ -#ifdef ENABLE_ARM - PreSum4x16Int8Pert(src, dst, row4, col16, filter_zp); -#else - for (int r = 0; r < row4; r++) { - int32_t tmp_value = 0; - for (int c = 0; c < col16; c++) { - int r4div = r / C4NUM, r4mod = r % C4NUM, c16div = c / C16NUM, c16mod = c % C16NUM; - int src_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod; - tmp_value += src[src_index]; - } - dst[r] = tmp_value * filter_zp; - } -#endif - return; -} - -void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr, - size_t plane_size, size_t input_channel, size_t output_channel) { - size_t hw4 = UP_ROUND(plane_size, C4NUM); - size_t ic16 = UP_ROUND(input_channel, C16NUM); -#ifdef ENABLE_ARM64 - size_t oc_div4 = output_channel / C4NUM * C4NUM; - size_t oc_res4 = 
output_channel - oc_div4; - size_t inputsun_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4; - PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div4, oc_res4, inputsun_stride); -#else - - for (int ri = 0; ri < plane_size; ri++) { - int ri4div = ri / C4NUM, ri4mod = ri % C4NUM; - for (int ci = 0; ci < output_channel; ci++) { - int32_t tmp_sum_value = 0; - int ci4div = ci / C4NUM, ci4mod = ci % C4NUM; - int32_t filter_zp = filter_zp_ptr[ci]; - for (int di = 0; di < input_channel; di++) { - size_t di16div = di / C16NUM, di16mod = di % C16NUM; - int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod; - tmp_sum_value += input_value[src_index]; - } - int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod; - input_sum[dst_index] = tmp_sum_value * filter_zp; - } - } -#endif - return; -} - -void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr, - size_t plane_size, size_t input_channel, size_t output_channel) { - size_t hw4 = UP_ROUND(plane_size, C4NUM); - size_t ic16 = UP_ROUND(input_channel, C16NUM); - -#ifdef ENABLE_ARM32 - size_t oc_div2 = output_channel / C2NUM * C2NUM; - size_t oc_res2 = output_channel - oc_div2; - size_t inputsun_stride = hw4 * C2NUM * 4 - C4NUM * C2NUM * 4; - PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div2, oc_res2, inputsun_stride); -#else - for (int ri = 0; ri < plane_size; ri++) { - int ri4div = ri / C4NUM, ri4mod = ri % C4NUM; - for (int ci = 0; ci < output_channel; ci++) { - int32_t tmp_sum_value = 0; - int ci2div = ci / C2NUM, ci2mod = ci % C2NUM; - int32_t filter_zp = filter_zp_ptr[ci]; - for (int di = 0; di < input_channel; di++) { - size_t di16div = di / C16NUM, di16mod = di % C16NUM; - int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod; - tmp_sum_value += input_value[src_index]; - } - int dst_index = ci2div * C2NUM * hw4 + ri * C2NUM + ci2mod; - input_sum[dst_index] = tmp_sum_value * filter_zp; - } - } -#endif - return; -} - -void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param) { - size_t hw = conv_param->output_h_ * conv_param->output_w_; - size_t hw4 = UP_ROUND(hw, C4NUM); - size_t ic16 = UP_ROUND(conv_param->input_channel_, C16NUM); - if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { - PackInputSum16x4PerLayer(input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16); - } else { -#ifdef ENABLE_ARM32 - PackInputSum16x4PerChannelArm32(input, input_sum, filter_zp, hw, conv_param->input_channel_, - conv_param->output_channel_); -#else - PackInputSum16x4PerChannel(input, input_sum, filter_zp, hw, conv_param->input_channel_, - conv_param->output_channel_); -#endif - } - return; -} - -void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num, - int block_index) { - // input format : nhwc - int kernel_h = conv_param->kernel_h_; - int kernel_w = conv_param->kernel_w_; - int kernel_plane = kernel_h * kernel_w; - int dilation_h = conv_param->dilation_h_; - int dilation_w = conv_param->dilation_w_; - int in_channel = conv_param->input_channel_; - int in_w = conv_param->input_w_; - int out_w = conv_param->output_w_; - - for (int i = 0; i < real_cal_num; i++) { - int block_start = block_index + i; - int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_; - int input_w = block_start % out_w * conv_param->stride_w_ - 
conv_param->pad_l_; - int input_stride = (input_h * in_w + input_w) * in_channel; - int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); - int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h)); - int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); - int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); - if (dilation_w == 1 && dilation_h == 1) { - for (int j = kh_s; j < kh_e; j++) { - int input_y_stride = j * in_w * in_channel + input_stride; - int input_x_stride = input_y_stride + kw_s * in_channel; - int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane; - memcpy(packed_input + input_plane_offset, input_data + input_x_stride, - (kw_e - kw_s) * in_channel * sizeof(float)); - } // kernel_h loop - } else { - for (int j = kh_s; j < kh_e; j++) { - int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; - for (int k = kw_s; k < kw_e; ++k) { - int input_x_stride = input_y_stride + k * dilation_w * in_channel; - int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane; - memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float)); - } - } // kernel_h loop - } - } // tile num loop -} - -void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num, - int block_index, int32_t *filter_zp, int32_t *input_sum, ConvParameter *conv_param, - bool per_channel, bool is_optimize) { - // input format : nhwc - int kernel_h = conv_param->kernel_h_; - int kernel_w = conv_param->kernel_w_; - int stride_h = conv_param->stride_h_; - int stride_w = conv_param->stride_w_; - int pad_h = conv_param->pad_u_; - int pad_w = conv_param->pad_l_; - int dilation_h = conv_param->dilation_h_; - int dilation_w = conv_param->dilation_w_; - int in_channel = conv_param->input_channel_; - int in_h = conv_param->input_h_; - int in_w = conv_param->input_w_; - int out_w = conv_param->output_w_; - int kernel_plane = kernel_h * kernel_w; - - for (int i = 0; i < real_cal_num; i++) { - int block_start = block_index + i; - int input_h = block_start / out_w * stride_h - pad_h; - int input_w = block_start % out_w * stride_w - pad_w; - int input_stride = input_h * in_w * in_channel + input_w * in_channel; - int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); - int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h)); - int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); - int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); - if (dilation_w == 1 && dilation_h == 1) { - for (int j = kh_s; j < kh_e; j++) { - int input_y_stride = j * in_w * in_channel + input_stride; - int input_x_stride = input_y_stride + kw_s * in_channel; - int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane; - memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, (kw_e - kw_s) * in_channel); - } // kernel_h loop - } else { - for (int j = kh_s; j < kh_e; j++) { - int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; - for (int k = kw_s; k < kw_e; ++k) { - int input_x_stride = input_y_stride + k * dilation_w * in_channel; - int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane; - memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, in_channel); - } - } // kernel_h loop - } - } // tile num loop - int deep = kernel_plane * in_channel; - if (is_optimize) { - if (per_channel) { - Conv1x1PreOptPeroc(matmul_input, packed_input, 
input_sum, deep, conv_param->output_channel_, real_cal_num, - filter_zp, C8NUM * C8NUM); - } else { - Conv1x1PreOptPert(matmul_input, packed_input, input_sum, deep, real_cal_num, conv_param); - } - } else { - RowMajor2Row16x4MajorInt8(matmul_input, packed_input, real_cal_num, deep); - if (per_channel) { -#ifdef ENABLE_ARM32 - PackInputSum16x4PerChannelArm32(packed_input, input_sum, filter_zp, real_cal_num, deep, - conv_param->output_channel_); -#else - PackInputSum16x4PerChannel(packed_input, input_sum, filter_zp, real_cal_num, deep, conv_param->output_channel_); -#endif - } else { - size_t hw4 = UP_ROUND(real_cal_num, C4NUM); - size_t ic16 = UP_ROUND(deep, C16NUM); - PackInputSum16x4PerLayer(packed_input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, - ic16); - } - } -} - -void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param) { - int in_batch = conv_param->input_batch_; - int in_channel = conv_param->input_channel_; - int in_h = conv_param->input_h_; - int in_w = conv_param->input_w_; - int ic8_round = UP_ROUND(in_channel, C8NUM); - int ic8 = in_channel / C8NUM * C8NUM; - int in_plane = in_h * in_w; - - for (int b = 0; b < in_batch; b++) { - int src_batch_offset = b * in_channel * in_plane; - int dst_batch_offset = b * ic8_round * in_plane; - for (int k = 0; k < in_plane; k++) { - int src_plane_offset = src_batch_offset + k * in_channel; - int dst_plane_offset = dst_batch_offset + k * C8NUM; - for (int i = 0; i < ic8; i += 8) { - int src_c_offset = src_plane_offset + i; - int dst_c_offset = dst_plane_offset + i * in_plane; -#ifdef ENABLE_ARM - vst1q_s16(packed_input + dst_c_offset, vmovl_s8(vld1_s8(input_data + src_c_offset))); -#else - for (int j = 0; j < C8NUM; ++j) { - (packed_input + dst_c_offset)[j] = (int16_t)(input_data + src_c_offset)[j]; - } -#endif - } // ic8_minus loop - int res_c = in_channel - ic8; - int tmp_ic_offset = ic8 * in_plane; - for (int l = 0; l < res_c; ++l) { - int src_c_offset = src_plane_offset + ic8 + l; - int dst_c_offset = dst_plane_offset + tmp_ic_offset + l; - (packed_input + dst_c_offset)[0] = (int16_t)(input_data + src_c_offset)[0]; - } // res ic loop - int res2 = ic8_round - in_channel; - for (int l = 0; l < res2; ++l) { - int dst_c_offset = dst_plane_offset + tmp_ic_offset + res_c + l; - (packed_input + dst_c_offset)[0] = 0; - } // res ic loop - } // kh * kw loop - } -} - -void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param) { - // origin weight format : ohwi - int input_channel = conv_param->input_channel_; - int ic8 = input_channel / C8NUM * C8NUM; - int ic8_round = UP_ROUND(input_channel, C8NUM); - int output_channel = conv_param->output_channel_; - QuantArg *filter_zp = conv_param->conv_quant_arg_.filter_quant_args_; - int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; - - for (int k = 0; k < kernel_plane; k++) { - int src_kernel_offset = k * input_channel; - int dst_kernel_offset = k * C8NUM; - for (int o = 0; o < output_channel; o++) { - int32_t zp; - if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { - zp = filter_zp[0].zp_; - } else { - zp = filter_zp[o].zp_; - } - int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; - int dst_oc_offset = dst_kernel_offset + o * ic8_round * kernel_plane; - int i = 0; - for (; i < ic8; i += C8NUM) { - int src_ic_offset = src_oc_offset + i; - int dst_ic_offset = dst_oc_offset + i * kernel_plane; -#ifdef ENABLE_ARM64 - int8x8_t src_s8 = 
vld1_s8(origin_weight_data + src_ic_offset); - int16x8_t src_s16 = vmovl_s8(src_s8); - int16x4_t src1_s16 = vget_low_s16(src_s16); - int16x4_t src2_s16 = vget_high_s16(src_s16); - int32x4_t src1_s32 = vmovl_s16(src1_s16); - int32x4_t src2_s32 = vmovl_s16(src2_s16); - int32x4_t zp_s32 = vdupq_n_s32(zp); - int32x4_t dst1_s32 = vsubq_s32(src1_s32, zp_s32); - int32x4_t dst2_s32 = vsubq_s32(src2_s32, zp_s32); - int16x4_t dst1_s16 = vqmovn_s32(dst1_s32); - int16x4_t dst2_s16 = vqmovn_s32(dst2_s32); - vst1_s16(packed_weight_data + dst_ic_offset, dst1_s16); - vst1_s16(packed_weight_data + dst_ic_offset + 4, dst2_s16); -#else - for (int ci = 0; ci < C8NUM; ++ci) { - (packed_weight_data + dst_ic_offset + ci)[0] = (int16_t)((origin_weight_data + src_ic_offset + ci)[0] - zp); - } -#endif - } - dst_oc_offset += ic8 * kernel_plane; - for (; i < input_channel; i++) { - int c8_block_rem = i % C8NUM; - int src_ic_offset = src_oc_offset + i; - int dst_ic_offset = dst_oc_offset + c8_block_rem; - (packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - zp); - } - } - } -} - -void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - int c4_minus = c4 - 1; - for (int b = 0; b < batch; b++) { - int src_oc_offset = b * plane * channel; - int dst_oc_offset = b * plane * c4 * C4NUM; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_oc_offset + k * channel; - int dst_kernel_offset = dst_oc_offset + k * C4NUM; - for (int j = 0; j < c4_minus; ++j) { - int src_ic_offset = src_kernel_offset + j * C4NUM; - int dst_ic_offset = dst_kernel_offset + j * plane * C4NUM; -#ifdef ENABLE_ARM - vst1q_f32((float *)dst + dst_ic_offset, vld1q_f32((float *)src + src_ic_offset)); -#else - for (int i = 0; i < C4NUM; ++i) { - ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i]; - } -#endif - } - int tmp_c = c4_minus * C4NUM; - int tmp_c_offset = tmp_c * plane; - int res_c = channel - tmp_c; - for (int l = 0; l < res_c; ++l) { - int src_ic_offset = src_kernel_offset + tmp_c + l; - int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l; - ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0]; - } - } - } -} - -void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - for (int b = 0; b < batch; b++) { - int src_offset = b * plane * channel; - int dst_offset = b * plane * c4 * C4NUM; - for (int c = 0; c < channel; c++) { - int c4_block_num = c / C4NUM; - int c4_block_rem = c % C4NUM; - int src_c_offset = src_offset + c * plane; - int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_c_offset + k; - int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem; - ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0]; - } - } - } -} - -void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - int c4_channel = c4 * C4NUM; - int nhwc4_batch_unit_offset = c4 * C4NUM * plane; - int ic_remainder_ = channel % C4NUM; - if (ic_remainder_ != 0) { - int nhwc4_batch_offset = 0; - for (int b = 0; b < batch; b++) { - int batch_offset = b * channel * plane; - for (int i = 0; i < plane; i++) { - float *dst_per_plane = (float *)dst + nhwc4_batch_offset + i * c4_channel; - memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float)); - for 
(int j = channel; j < c4_channel; ++j) { - dst_per_plane[j] = 0; - } - } - nhwc4_batch_offset += nhwc4_batch_unit_offset; - } - } else { - size_t ori_input_size = batch * plane * channel * sizeof(float); - memcpy((float *)dst, (float *)src, ori_input_size); - } -} - -void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel) { - int c8 = UP_DIV(channel, C8NUM); - int c8_channel = c8 * C8NUM; - int nhwc8_batch_unit_offset = c8 * C8NUM * plane; - int ic_remainder_ = channel % C8NUM; - if (ic_remainder_ != 0) { - int nhwc8_batch_offset = 0; - for (int b = 0; b < batch; b++) { - int batch_offset = b * channel * plane; - for (int i = 0; i < plane; i++) { - float *dst_per_plane = (float *)dst + nhwc8_batch_offset + i * c8_channel; - memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float)); - for (int j = channel; j < c8_channel; ++j) { - dst_per_plane[j] = 0; - } - } - nhwc8_batch_offset += nhwc8_batch_unit_offset; - } - } else { - size_t ori_input_size = batch * plane * channel * sizeof(float); - memcpy((float *)dst, (float *)src, ori_input_size); - } -} - -void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - int ic_remainder_ = channel % C4NUM; - if (ic_remainder_ != 0) { - int nhwc_batch_unit_offset = channel * plane; - for (int b = 0; b < batch; b++) { - int batch_offset = b * c4 * C4NUM * plane; - for (int i = 0; i < plane; i++) { - memcpy((float *)dst + b * nhwc_batch_unit_offset + i * channel, (float *)src + batch_offset + i * c4 * C4NUM, - channel * sizeof(float)); - } - } - } else { - size_t ori_input_size = batch * plane * channel * sizeof(float); - memcpy((float *)dst, (float *)src, ori_input_size); - } -} - -void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - for (int b = 0; b < batch; b++) { - int src_offset = b * plane * c4 * C4NUM; - int dst_offset = b * plane * channel; - for (int c = 0; c < channel; c++) { - int c4_block_num = c / C4NUM; - int c4_block_res = c % C4NUM; - int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; - int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_c_offset + k * C4NUM; - int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM; - ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0]; - } - } - } -} - -void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - for (int b = 0; b < batch; b++) { - int src_offset = b * plane * c4 * C4NUM; - int dst_offset = b * plane * channel; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_offset + k * C4NUM; - int dst_kernel_offset = dst_offset + k * channel; - for (int c = 0; c < c4 - 1; c++) { - int src_c_offset = src_kernel_offset + c * plane * C4NUM; - int dst_c_offset = dst_kernel_offset + c * C4NUM; -#ifdef ENABLE_NEON - vst1q_f32((float *)dst + dst_c_offset, vld1q_f32((float *)src + src_c_offset)); -#else - ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0]; - ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1]; - ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2]; - ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3]; -#endif - } - // res part - int res_c = channel - (c4 - 1) * C4NUM; - for (int i = 0; i < res_c; i++) { - int 
src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i; - int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i; - ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0]; - } - } - } -} - -void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) { - for (int n = 0; n < batch; n++) { - for (int hw = 0; hw < plane; hw++) { - for (int c = 0; c < channel; c++) { - int c8div = c / C8NUM; - int c8mod = c % C8NUM; - int src_index = n * plane * channel + hw * channel + c; - int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; - ((float *)dst)[dst_index] = ((float *)src)[src_index]; - } - } - } - return; -} - -void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) { - int c4 = UP_DIV(channel, C4NUM); - for (int c = 0; c < c4; c++) { - int dst_off_c = c * C4NUM * height * width; - for (int i = 0; i < C4NUM; i++) { - int src_off_c = (c * C4NUM + i) * height * width; - for (int kh = 0; kh < height; kh++) { - int src_off_kh = src_off_c + kh * width; - for (int kw = 0; kw < width; kw++) { - int dst_off = dst_off_c + kw * height * C4NUM + kh * C4NUM + i; - ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw]; - } - } - } - } -} - -void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel) { - int c8 = UP_DIV(channel, C8NUM); - for (int c = 0; c < c8; c++) { - int dst_off_c = c * C8NUM * height * width; - for (int i = 0; i < C8NUM; i++) { - int src_off_c = (c * C8NUM + i) * height * width; - for (int kh = 0; kh < height; kh++) { - int src_off_kh = src_off_c + kh * width; - for (int kw = 0; kw < width; kw++) { - int dst_off = dst_off_c + kw * height * C8NUM + kh * C8NUM + i; - ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw]; - } - } - } - } -} - -void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - int c4_channel = c4 * C4NUM; - int nhwc4_batch_unit_offset = c4 * C4NUM * plane; - int ic_remainder_ = channel % C4NUM; - if (ic_remainder_ != 0) { - int nhwc4_batch_offset = 0; - for (int b = 0; b < batch; b++) { - int batch_offset = b * channel * plane; - for (int i = 0; i < plane; i++) { - int8_t *dst_per_plane = (int8_t *)dst + nhwc4_batch_offset + i * c4_channel; - memcpy(dst_per_plane, (int8_t *)src + batch_offset + i * channel, channel); - for (int j = channel; j < c4_channel; ++j) { - dst_per_plane[j] = 0; - } - } - nhwc4_batch_offset += nhwc4_batch_unit_offset; - } - } else { - size_t ori_input_size = batch * plane * channel; - memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); - } -} - -void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - int nhwc4_batch_unit_offset = c4 * C4NUM * plane; - int ic_remainder_ = channel % C4NUM; - if (ic_remainder_ != 0) { - for (int b = 0; b < batch; b++) { - int batch_offset = b * channel * plane; - int nhwc4_batch_offset = b * nhwc4_batch_unit_offset; - for (int i = 0; i < plane; i++) { - memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc4_batch_offset + i * c4 * C4NUM, - channel); - } - } - } else { - size_t ori_input_size = batch * plane * channel; - memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); - } -} - -void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel) { - int c8 = UP_DIV(channel, C8NUM); - int nhwc8_batch_unit_offset = c8 * C8NUM * 
plane; - int ic_remainder_ = channel % C8NUM; - if (ic_remainder_ != 0) { - int nhwc8_batch_offset = 0; - for (int b = 0; b < batch; b++) { - int batch_offset = b * channel * plane; - for (int i = 0; i < plane; i++) { - memcpy((int8_t *)dst + nhwc8_batch_offset + i * c8 * C8NUM, (int8_t *)src + batch_offset + i * channel, - channel); - } - nhwc8_batch_offset += nhwc8_batch_unit_offset; - } - } else { - size_t ori_input_size = batch * plane * channel; - memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); - } -} - -void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { - int c8 = UP_DIV(channel, C8NUM); - int nhwc8_batch_unit_offset = c8 * C8NUM * plane; - int ic_remainder_ = channel % C8NUM; - if (ic_remainder_ != 0) { - for (int b = 0; b < batch; b++) { - int batch_offset = b * channel * plane; - int nhwc8_batch_offset = b * nhwc8_batch_unit_offset; - for (int i = 0; i < plane; i++) { - memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc8_batch_offset + i * c8 * C8NUM, - channel); - } - } - } else { - size_t ori_input_size = batch * plane * channel; - memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); - } -} - -void PackNCHWToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel) { - int c8 = UP_DIV(channel, C8NUM); - for (int b = 0; b < batch; b++) { - int src_offset = b * plane * channel; - int dst_offset = b * plane * c8 * C8NUM; - for (int c = 0; c < channel; c++) { - int c8_block_num = c / C8NUM; - int c8_block_rem = c % C8NUM; - int src_c_offset = src_offset + c * plane; - int dst_c_offset = dst_offset + c8_block_num * plane * C8NUM; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_c_offset + k; - int dst_kernel_offset = dst_c_offset + C8NUM * k + c8_block_rem; - ((int8_t *)dst + dst_kernel_offset)[0] = ((int8_t *)src + src_kernel_offset)[0]; - } - } - } -} - -void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - for (int b = 0; b < batch; b++) { - int src_offset = b * plane * c4 * C4NUM; - int dst_offset = b * plane * channel; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_offset + k * C4NUM; - int dst_kernel_offset = dst_offset + k * channel; - for (int c = 0; c < c4 - 1; c++) { - int src_c_offset = src_kernel_offset + c * plane * C4NUM; - int dst_c_offset = dst_kernel_offset + c * C4NUM; - ((int8_t *)dst + dst_c_offset)[0] = ((int8_t *)src + src_c_offset)[0]; - ((int8_t *)dst + dst_c_offset)[1] = ((int8_t *)src + src_c_offset)[1]; - ((int8_t *)dst + dst_c_offset)[2] = ((int8_t *)src + src_c_offset)[2]; - ((int8_t *)dst + dst_c_offset)[3] = ((int8_t *)src + src_c_offset)[3]; - } - // res part - int res_c = channel - (c4 - 1) * C4NUM; - for (int i = 0; i < res_c; i++) { - int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i; - int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i; - ((int8_t *)dst + dst_res_c_offset)[0] = ((int8_t *)src + src_res_c_offset)[0]; - } - } - } -} - -void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel) { - for (int n = 0; n < batch; n++) { - for (int hw = 0; hw < plane; hw++) { - for (int c = 0; c < channel; c++) { - int c8div = c / C8NUM; - int c8mod = c % C8NUM; - int src_index = n * plane * channel + hw * channel + c; - int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; - ((int8_t *)dst)[dst_index] = ((int8_t *)src)[src_index]; - } - } - } - return; -} - -void 
PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { - for (int n = 0; n < batch; n++) { - for (int c = 0; c < channel; c++) { - for (int hw = 0; hw < plane; hw++) { - int nhwc_index = n * channel * plane + hw * channel + c; - int nchw_index = n * channel * plane + c * plane + hw; - ((int8_t *)(dst))[nhwc_index] = ((const int8_t *)(src))[nchw_index]; - } - } - } - return; -} - -#ifndef ENABLE_SSE -void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) { - int hw8 = plane / C8NUM * C8NUM; - int c8 = channel / C8NUM * C8NUM; - int batch = plane * channel; - for (int n = 0; n < batches; n++) { - const float *src_batch = (const float *)src + n * batch; - float *dst_batch = (float *)dst + n * batch; - int hw = 0; - for (; hw < hw8; hw += C8NUM) { - int c = 0; - for (; c < c8; c += C8NUM) { - const float *src_ptr = src_batch + hw * channel + c; - float *dst_ptr = dst_batch + c * plane + hw; -#ifdef ENABLE_ARM64 - size_t srcStride = channel * sizeof(float); - size_t dstStride = plane * sizeof(float); - asm volatile( - "mov x10, %[src_ptr]\n" - "mov x11, %[dst_ptr]\n" - - "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n" - "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n" - - "zip1 v8.4s, v0.4s, v2.4s\n" - "zip2 v9.4s, v0.4s, v2.4s\n" - "zip1 v12.4s, v1.4s, v3.4s\n" - "zip2 v13.4s, v1.4s, v3.4s\n" - - "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n" - "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n" - - "zip1 v10.4s, v4.4s, v6.4s\n" - "zip2 v11.4s, v4.4s, v6.4s\n" - "zip1 v14.4s, v5.4s, v7.4s\n" - "zip2 v15.4s, v5.4s, v7.4s\n" - - "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n" - "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n" - - "trn1 v16.2d, v8.2d, v10.2d\n" - "trn2 v18.2d, v8.2d, v10.2d\n" - "trn1 v20.2d, v9.2d, v11.2d\n" - "trn2 v22.2d, v9.2d, v11.2d\n" - - "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n" - "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n" - - "trn1 v24.2d, v12.2d, v14.2d\n" - "trn2 v26.2d, v12.2d, v14.2d\n" - "trn1 v28.2d, v13.2d, v15.2d\n" - "trn2 v30.2d, v13.2d, v15.2d\n" - - "zip1 v8.4s, v0.4s, v2.4s\n" - "zip2 v9.4s, v0.4s, v2.4s\n" - "zip1 v12.4s, v1.4s, v3.4s\n" - "zip2 v13.4s, v1.4s, v3.4s\n" - - "zip1 v10.4s, v4.4s, v6.4s\n" - "zip2 v11.4s, v4.4s, v6.4s\n" - "zip1 v14.4s, v5.4s, v7.4s\n" - "zip2 v15.4s, v5.4s, v7.4s\n" - - "trn1 v17.2d, v8.2d, v10.2d\n" - "trn2 v19.2d, v8.2d, v10.2d\n" - "trn1 v21.2d, v9.2d, v11.2d\n" - "trn2 v23.2d, v9.2d, v11.2d\n" - - "trn1 v25.2d, v12.2d, v14.2d\n" - "trn2 v27.2d, v12.2d, v14.2d\n" - "trn1 v29.2d, v13.2d, v15.2d\n" - "trn2 v31.2d, v13.2d, v15.2d\n" - - "st1 {v16.4s, v17.4s}, [x11], %[dstStride]\n" - "st1 {v18.4s, v19.4s}, [x11], %[dstStride]\n" - "st1 {v20.4s, v21.4s}, [x11], %[dstStride]\n" - "st1 {v22.4s, v23.4s}, [x11], %[dstStride]\n" - "st1 {v24.4s, v25.4s}, [x11], %[dstStride]\n" - "st1 {v26.4s, v27.4s}, [x11], %[dstStride]\n" - "st1 {v28.4s, v29.4s}, [x11], %[dstStride]\n" - "st1 {v30.4s, v31.4s}, [x11], %[dstStride]\n" - - : - : - [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride) - : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", - "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31"); -#elif ENABLE_ARM32 - size_t srcStride = channel * sizeof(float); - size_t dstStride = plane * sizeof(float); - asm volatile( - "mov r10, %[src_ptr]\n" - "mov r12, %[dst_ptr]\n" - - "vld1.32 {q0, q1}, [r10], %[srcStride]\n" - 
"vld1.32 {q2, q3}, [r10], %[srcStride]\n" - - "vtrn.32 d0, d4\n" - "vtrn.32 d1, d5\n" - "vtrn.32 d2, d6\n" - "vtrn.32 d3, d7\n" - - "vld1.32 {q4, q5}, [r10], %[srcStride]\n" - "vld1.32 {q6, q7}, [r10], %[srcStride]\n" - - "vtrn.32 d8, d12\n" - "vtrn.32 d9, d13\n" - "vtrn.32 d10, d14\n" - "vtrn.32 d11, d15\n" - - "vld1.32 {q8, q9}, [r10], %[srcStride]\n" - "vld1.32 {q10, q11}, [r10], %[srcStride]\n" - - "vswp d1, d8\n" - "vswp d3, d10\n" - "vswp d5, d12\n" - "vswp d7, d14\n" - - "vtrn.32 d16, d20\n" - "vtrn.32 d17, d21\n" - "vtrn.32 d18, d22\n" - "vtrn.32 d19, d23\n" - - "vld1.32 {q12, q13}, [r10], %[srcStride]\n" - "vld1.32 {q14, q15}, [r10], %[srcStride]\n" - - "vtrn.32 d24, d28\n" - "vtrn.32 d25, d29\n" - "vtrn.32 d26, d30\n" - "vtrn.32 d27, d31\n" - - "vswp d17, d24\n" - "vswp d19, d26\n" - "vswp d21, d28\n" - "vswp d23, d30\n" - - "add r10, r12, #16\n" - "vst1.32 {q0}, [r12], %[dstStride]\n" - "vst1.32 {q8}, [r10], %[dstStride]\n" - "vst1.32 {q2}, [r12], %[dstStride]\n" - "vst1.32 {q10}, [r10], %[dstStride]\n" - "vst1.32 {q4}, [r12], %[dstStride]\n" - "vst1.32 {q12}, [r10], %[dstStride]\n" - "vst1.32 {q6}, [r12], %[dstStride]\n" - "vst1.32 {q14}, [r10], %[dstStride]\n" - "vst1.32 {q1}, [r12], %[dstStride]\n" - "vst1.32 {q9}, [r10], %[dstStride]\n" - "vst1.32 {q3}, [r12], %[dstStride]\n" - "vst1.32 {q11}, [r10], %[dstStride]\n" - "vst1.32 {q5}, [r12], %[dstStride]\n" - "vst1.32 {q13}, [r10], %[dstStride]\n" - "vst1.32 {q7}, [r12], %[dstStride]\n" - "vst1.32 {q15}, [r10], %[dstStride]\n" - - : - : - [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride) - : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15"); -#else - for (int tr = 0; tr < C8NUM; tr++) { - for (int tc = 0; tc < C8NUM; tc++) { - dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; - } - } -#endif - } - for (; c < channel; c++) { - const float *src_ptr = src_batch + hw * channel + c; - float *dst_ptr = dst_batch + c * plane + hw; - for (size_t i = 0; i < C8NUM; i++) { - dst_ptr[i] = src_ptr[i * channel]; - } - } - } - for (; hw < plane; hw++) { - const float *src_ptr = src_batch + hw * channel; - float *dst_ptr = dst_batch + hw; - for (size_t i = 0; i < channel; i++) { - dst_ptr[i * plane] = src_ptr[i]; - } - } - } - return; -} -#endif - -void PackNHWCToNCHWInt8(const void *src, void *dst, int batches, int plane, int channel) { - int hw8 = plane / C8NUM * C8NUM; - int c8 = channel / C8NUM * C8NUM; - int batch = plane * channel; - for (int n = 0; n < batches; n++) { - const int8_t *src_batch = (const int8_t *)src + n * batch; - int8_t *dst_batch = (int8_t *)dst + n * batch; - int hw = 0; - for (; hw < hw8; hw += C8NUM) { - int c = 0; - for (; c < c8; c += C8NUM) { - const int8_t *src_ptr = src_batch + hw * channel + c; - int8_t *dst_ptr = dst_batch + c * plane + hw; -#ifdef ENABLE_ARM64 - size_t srcStride = channel * sizeof(int8_t); - size_t dstStride = plane * sizeof(int8_t); - asm volatile( - "mov x10, %[src_ptr]\n" - "mov x11, %[dst_ptr]\n" - - "ld1 {v0.8b}, [x10], %[srcStride]\n" - "ld1 {v1.8b}, [x10], %[srcStride]\n" - "ld1 {v2.8b}, [x10], %[srcStride]\n" - "ld1 {v3.8b}, [x10], %[srcStride]\n" - - "trn1 v4.8b, v0.8b, v1.8b\n" - "trn2 v5.8b, v0.8b, v1.8b\n" - "trn1 v6.8b, v2.8b, v3.8b\n" - "trn2 v7.8b, v2.8b, v3.8b\n" - - "ld1 {v0.8b}, [x10], %[srcStride]\n" - "ld1 {v1.8b}, [x10], %[srcStride]\n" - "ld1 {v2.8b}, [x10], %[srcStride]\n" - "ld1 {v3.8b}, [x10], %[srcStride]\n" - - "trn1 v8.4h, v4.4h, 
v6.4h\n" - "trn2 v9.4h, v4.4h, v6.4h\n" - "trn1 v10.4h, v5.4h, v7.4h\n" - "trn2 v11.4h, v5.4h, v7.4h\n" - - "trn1 v4.8b, v0.8b, v1.8b\n" - "trn2 v5.8b, v0.8b, v1.8b\n" - "trn1 v6.8b, v2.8b, v3.8b\n" - "trn2 v7.8b, v2.8b, v3.8b\n" - - "trn1 v12.4h, v4.4h, v6.4h\n" - "trn2 v13.4h, v4.4h, v6.4h\n" - "trn1 v14.4h, v5.4h, v7.4h\n" - "trn2 v15.4h, v5.4h, v7.4h\n" - - "trn1 v0.2s, v8.2s, v12.2s\n" - "trn2 v4.2s, v8.2s, v12.2s\n" - "trn1 v1.2s, v10.2s, v14.2s\n" - "trn2 v5.2s, v10.2s, v14.2s\n" - "trn1 v2.2s, v9.2s, v13.2s\n" - "trn2 v6.2s, v9.2s, v13.2s\n" - "trn1 v3.2s, v11.2s, v15.2s\n" - "trn2 v7.2s, v11.2s, v15.2s\n" - - "st1 {v0.8b}, [x11], %[dstStride]\n" - "st1 {v1.8b}, [x11], %[dstStride]\n" - "st1 {v2.8b}, [x11], %[dstStride]\n" - "st1 {v3.8b}, [x11], %[dstStride]\n" - "st1 {v4.8b}, [x11], %[dstStride]\n" - "st1 {v5.8b}, [x11], %[dstStride]\n" - "st1 {v6.8b}, [x11], %[dstStride]\n" - "st1 {v7.8b}, [x11], %[dstStride]\n" - : - : - [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride) - : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", - "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31"); -#elif ENABLE_ARM32 - size_t srcStride = channel * sizeof(int8_t); - size_t dstStride = plane * sizeof(int8_t); - asm volatile( - "mov r10, %[src_ptr]\n" - "mov r12, %[dst_ptr]\n" - - "vld1.8 {d0}, [r10], %[srcStride]\n" - "vld1.8 {d1}, [r10], %[srcStride]\n" - "vld1.8 {d2}, [r10], %[srcStride]\n" - "vld1.8 {d3}, [r10], %[srcStride]\n" - "vld1.8 {d4}, [r10], %[srcStride]\n" - "vld1.8 {d5}, [r10], %[srcStride]\n" - "vld1.8 {d6}, [r10], %[srcStride]\n" - "vld1.8 {d7}, [r10], %[srcStride]\n" - - "vtrn.8 d0, d1\n" - "vtrn.8 d2, d3\n" - "vtrn.8 d4, d5\n" - "vtrn.8 d6, d7\n" - - "vtrn.16 d0, d2\n" - "vtrn.16 d1, d3\n" - "vtrn.16 d4, d6\n" - "vtrn.16 d5, d7\n" - - "vtrn.32 d0, d4\n" - "vtrn.32 d1, d5\n" - "vtrn.32 d2, d6\n" - "vtrn.32 d3, d7\n" - - "vst1.8 {d0}, [r12], %[dstStride]\n" - "vst1.8 {d1}, [r12], %[dstStride]\n" - "vst1.8 {d2}, [r12], %[dstStride]\n" - "vst1.8 {d3}, [r12], %[dstStride]\n" - "vst1.8 {d4}, [r12], %[dstStride]\n" - "vst1.8 {d5}, [r12], %[dstStride]\n" - "vst1.8 {d6}, [r12], %[dstStride]\n" - "vst1.8 {d7}, [r12], %[dstStride]\n" - : - : - [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride) - : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15"); -#else - for (int tr = 0; tr < C8NUM; tr++) { - for (int tc = 0; tc < C8NUM; tc++) { - dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; - } - } -#endif - } - for (; c < channel; c++) { - const int8_t *src_ptr = src_batch + hw * channel + c; - int8_t *dst_ptr = dst_batch + c * plane + hw; - for (size_t i = 0; i < C8NUM; i++) { - dst_ptr[i] = src_ptr[i * channel]; - } - } - } - for (; hw < plane; hw++) { - const int8_t *src_ptr = src_batch + hw * channel; - int8_t *dst_ptr = dst_batch + hw; - for (size_t i = 0; i < channel; i++) { - dst_ptr[i * plane] = src_ptr[i]; - } - } - } - return; -} - -void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { - return PackNHWCToNCHWFp32(src, dst, batch, channel, plane); -} - -void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param) { - int input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; - int ic4 = 
UP_DIV(conv_param->input_channel_, C4NUM); - int unit = conv_param->input_h_ * conv_param->input_w_; - - for (int b = 0; b < conv_param->input_batch_; b++) { - const int8_t *src_b = src + b * unit * conv_param->input_channel_; - int16_t *dst_b = dst + b * unit * ic4 * C4NUM; - for (int k = 0; k < unit; k++) { - const int8_t *src_k = src_b + k * conv_param->input_channel_; - int16_t *dst_k = dst_b + k * ic4 * C4NUM; - for (int c = 0; c < conv_param->input_channel_; c++) { - dst_k[c] = (int16_t)(src_k[c] - input_zp); - } - } - } -} - -void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, - ConvQuantArg *quant_qrg) { - int weight_zp = quant_qrg->filter_quant_args_[0].zp_; - for (int c = 0; c < channel; c++) { - if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) { - weight_zp = quant_qrg->filter_quant_args_[c].zp_; - } - int c8_block_num = c / C8NUM; - int c8_block_rem = c % C8NUM; - const int8_t *src_c = origin_weight + c * plane; - int16_t *dst_c = packed_weight_ + c8_block_num * plane * C8NUM; - for (int k = 0; k < plane; k++) { - const int8_t *src_kernel = src_c + k; - int16_t *dst_kernel = dst_c + C8NUM * k + c8_block_rem; - *dst_kernel = (int16_t)(src_kernel[0] - weight_zp); - } - } -} - -void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, - ConvQuantArg *quant_qrg) { - int weight_zp = quant_qrg->filter_quant_args_[0].zp_; - for (int c = 0; c < channel; c++) { - if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) { - weight_zp = quant_qrg->filter_quant_args_[c].zp_; - } - int c4_block_num = c / C4NUM; - int c4_block_rem = c % C4NUM; - const int8_t *src_c = origin_weight + c * plane; - int16_t *dst_c = packed_weight_ + c4_block_num * plane * C4NUM; - for (int k = 0; k < plane; k++) { - const int8_t *src_kernel = src_c + k; - int16_t *dst_kernel = dst_c + C4NUM * k + c4_block_rem; - *dst_kernel = (int16_t)(src_kernel[0] - weight_zp); - } - } -} diff --git a/mindspore/lite/nnacl/pack.h b/mindspore/lite/nnacl/pack.h index 13288c50e77..f51ec8410a9 100644 --- a/mindspore/lite/nnacl/pack.h +++ b/mindspore/lite/nnacl/pack.h @@ -17,102 +17,12 @@ #ifndef MINDSPORE_LITE_NNACL_PACK_H_ #define MINDSPORE_LITE_NNACL_PACK_H_ -#include -#ifdef ENABLE_NEON -#include -#endif -#include "nnacl/conv_parameter.h" -#include "nnacl/op_base.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/int8/pack_int8.h" #ifdef __cplusplus extern "C" { #endif -void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num, - int block_index); - -void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel); - -void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num, - int block_index, int32_t *filter_zp, int32_t *input_sum, ConvParameter *conv_param, - bool per_channel, bool is_optimize); - -void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16); - -void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr, - size_t plane_size, size_t input_channel, size_t output_channel); - -void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr, - size_t plane_size, size_t input_channel, size_t output_channel); - -void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size); - -void Pack1x1WeightFp32(const 
float *weight_data, float *packed_weight, ConvParameter *conv_param); - -void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param); - -void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param); - -void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel); - -void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum); - -void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param); - -void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel); - -void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel); - -void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); - -void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel); - -void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel); - -void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel); - -void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); - -void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); - -void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); - -void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); - -void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel); - -void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel); - -void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel); - -void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); - -void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); - -void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel); - -void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); - -void PackNCHWToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel); - -void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); - -void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel); - -void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); - -void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param); - -void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, - ConvQuantArg *quant_qrg); - -void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, - ConvQuantArg *quant_qrg); - -#ifdef ENABLE_ARM -void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp); -void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div, - size_t oc_res, size_t stride); -#endif #ifdef __cplusplus } diff --git a/mindspore/lite/nnacl/pad_parameter.h b/mindspore/lite/nnacl/pad_parameter.h index ed378f1f3a8..ff2629d5368 100644 --- a/mindspore/lite/nnacl/pad_parameter.h +++ b/mindspore/lite/nnacl/pad_parameter.h @@ -17,11 +17,16 @@ #define MINDSPORE_LITE_NNACL_PAD_PARAMETER_H_ #include 
"nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" #define MAX_PAD_SIZE 8 #define DEFAULT_PAD_NDIMS 4 +typedef struct PadQuantArg { + QuantArg *in_quant_args_; + QuantArg *out_quanr_args_; + int8_t *constant_value_; +} PadQuantArg; + typedef struct PadParameter { // Primitive parameter OpParameter op_parameter_; diff --git a/mindspore/lite/nnacl/pooling_parameter.h b/mindspore/lite/nnacl/pooling_parameter.h index 6e7db32fef9..84f7e1b068f 100644 --- a/mindspore/lite/nnacl/pooling_parameter.h +++ b/mindspore/lite/nnacl/pooling_parameter.h @@ -17,7 +17,6 @@ #define MINDSPORE_LITE_NNACL_POOLING_PARAMETER_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" typedef enum PoolMode { PoolMode_No, PoolMode_MaxPool, PoolMode_AvgPool } PoolMode; diff --git a/mindspore/lite/nnacl/power_parameter.h b/mindspore/lite/nnacl/power_parameter.h index a81bd24f923..34f46a73e82 100644 --- a/mindspore/lite/nnacl/power_parameter.h +++ b/mindspore/lite/nnacl/power_parameter.h @@ -18,7 +18,14 @@ #define MINDSPORE_LITE_NNACL_POWER_PARAMETER_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" + +typedef struct PowerQuantArg { + QuantArg in_args_; + QuantArg exp_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; +} PowerQuantArg; typedef struct PowerParameter { // Primitive parameter diff --git a/mindspore/lite/nnacl/reshape_parameter.h b/mindspore/lite/nnacl/reshape_parameter.h index 2e07660fbbb..101ca27329a 100644 --- a/mindspore/lite/nnacl/reshape_parameter.h +++ b/mindspore/lite/nnacl/reshape_parameter.h @@ -18,7 +18,13 @@ #define MINDSPORE_LITE_NNACL_RESHAHPE_PARAMETER_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" + +typedef struct ReshapeQuantArg { + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; +} ReshapeQuantArg; typedef struct ReshapeParameter { // primitive parameter diff --git a/mindspore/lite/nnacl/scale.h b/mindspore/lite/nnacl/scale.h index 2695edc0f6d..dbca958234d 100644 --- a/mindspore/lite/nnacl/scale.h +++ b/mindspore/lite/nnacl/scale.h @@ -17,8 +17,8 @@ #ifndef MINDSPORE_LITE_NNACL_SCALE_H_ #define MINDSPORE_LITE_NNACL_SCALE_H_ -#include #include "nnacl/op_base.h" + typedef struct ScaleParameter { // primitive parameter OpParameter op_parameter_; diff --git a/mindspore/lite/nnacl/scatter_nd.c b/mindspore/lite/nnacl/scatter_nd.c index 56fabec0657..2aefdc4c487 100644 --- a/mindspore/lite/nnacl/scatter_nd.c +++ b/mindspore/lite/nnacl/scatter_nd.c @@ -16,7 +16,6 @@ #include "nnacl/scatter_nd.h" #include -#include #include "nnacl/errorcode.h" int DoScatterND(float *output_ptr, const float *update, int *output_unit_offsets, int unit_size, int num_units) { diff --git a/mindspore/lite/nnacl/slice_parameter.h b/mindspore/lite/nnacl/slice_parameter.h index 506cfde62e1..5ebad629653 100644 --- a/mindspore/lite/nnacl/slice_parameter.h +++ b/mindspore/lite/nnacl/slice_parameter.h @@ -18,10 +18,16 @@ #define MINDSPORE_LITE_NNACL_SLICE_PARAMETER_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" #define SLICE_SHAPE_MAX_SIZE 4 +typedef struct SliceQuantArg { + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; +} SliceQuantArg; + typedef struct SliceParameter { // primitive parameter OpParameter op_parameter_; diff --git a/mindspore/lite/nnacl/split_parameter.h b/mindspore/lite/nnacl/split_parameter.h index 5b9e543ab96..7eeb4a6212c 100644 --- a/mindspore/lite/nnacl/split_parameter.h 
+++ b/mindspore/lite/nnacl/split_parameter.h @@ -18,8 +18,16 @@ #define MINDSPORE_LITE_NNACL_SPLIT_PARAMETER_H_ #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" + #define SPLIT_STRIDES_SIZE 32 + +typedef struct SplitQuantArg { + QuantArg in_args_; + QuantArg out_args_[20]; + int output_activation_min_; + int output_activation_max_; +} SplitQuantArg; + typedef struct SplitParameter { // primitive parameter OpParameter op_parameter_; diff --git a/mindspore/lite/nnacl/squeeze_parameter.h b/mindspore/lite/nnacl/squeeze_parameter.h index 091eac00a57..a1634e43db1 100644 --- a/mindspore/lite/nnacl/squeeze_parameter.h +++ b/mindspore/lite/nnacl/squeeze_parameter.h @@ -16,11 +16,16 @@ #ifndef MINDSPORE_LITE_NNACL_SQUEEZE_PARAMETER_H_ #define MINDSPORE_LITE_NNACL_SQUEEZE_PARAMETER_H_ + #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" #define SQUEEZE_OFFSET_MAX_SIZE 4 +typedef struct SqueezeQuantArg { + QuantArg *in_quant_args_; + QuantArg *out_quant_args_; +} SqueezeQuantArg; + typedef struct SqueezeParameter { // primitive parameter OpParameter op_parameter_; diff --git a/mindspore/lite/nnacl/unsqueeze_parameter.h b/mindspore/lite/nnacl/unsqueeze_parameter.h index 9c739d68764..e3b1c949b7e 100644 --- a/mindspore/lite/nnacl/unsqueeze_parameter.h +++ b/mindspore/lite/nnacl/unsqueeze_parameter.h @@ -16,11 +16,26 @@ #ifndef MINDSPORE_LITE_NNACL_UNSQUEEZE_PARAMETER_H_ #define MINDSPORE_LITE_NNACL_UNSQUEEZE_PARAMETER_H_ + +#include +#include #include "nnacl/op_base.h" -#include "nnacl/quantization/quantize.h" #define UNSQUEEZE_OFFSET_MAX_SIZE 4 +typedef struct UnSqueezeQuantArg { + int *input_sizes_; + int output_size_; + int **input_shapes_; + int *output_shape_; + float alpha; + int axis_; + size_t input_num_; + size_t output_dim_; + QuantArg in_quant_args_; + QuantArg out_quant_args_; +} UnSqueezeQuantArg; + typedef struct UnSqueezeParameter { // primitive parameter OpParameter op_parameter_; diff --git a/mindspore/lite/nnacl/winograd_transform.c b/mindspore/lite/nnacl/winograd_transform.c index 483755c808a..a14a0266006 100644 --- a/mindspore/lite/nnacl/winograd_transform.c +++ b/mindspore/lite/nnacl/winograd_transform.c @@ -140,810 +140,3 @@ void WinogradOutputTransform(const float *gemm_out, float *out_data, const float out_tile_index++; } } - -// int8 conv3x3 -void Conv3x3Int8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) { -#ifdef ENABLE_ARM - int16x8_t zp = vdupq_n_s16(input_zp); - - int16x8_t d00 = vsubq_s16(vld1q_s16(tmp_data), zp); - int16x8_t d01 = vsubq_s16(vld1q_s16(tmp_data + 8), zp); - int16x8_t d02 = vsubq_s16(vld1q_s16(tmp_data + 2 * 8), zp); - int16x8_t d03 = vsubq_s16(vld1q_s16(tmp_data + 3 * 8), zp); - - int16x8_t d10 = vsubq_s16(vld1q_s16(tmp_data + 4 * 8), zp); - int16x8_t d11 = vsubq_s16(vld1q_s16(tmp_data + 5 * 8), zp); - int16x8_t d12 = vsubq_s16(vld1q_s16(tmp_data + 6 * 8), zp); - int16x8_t d13 = vsubq_s16(vld1q_s16(tmp_data + 7 * 8), zp); - - int16x8_t d20 = vsubq_s16(vld1q_s16(tmp_data + 8 * 8), zp); - int16x8_t d21 = vsubq_s16(vld1q_s16(tmp_data + 9 * 8), zp); - int16x8_t d22 = vsubq_s16(vld1q_s16(tmp_data + 10 * 8), zp); - int16x8_t d23 = vsubq_s16(vld1q_s16(tmp_data + 11 * 8), zp); - - int16x8_t d30 = vsubq_s16(vld1q_s16(tmp_data + 12 * 8), zp); - int16x8_t d31 = vsubq_s16(vld1q_s16(tmp_data + 13 * 8), zp); - int16x8_t d32 = vsubq_s16(vld1q_s16(tmp_data + 14 * 8), zp); - int16x8_t d33 = vsubq_s16(vld1q_s16(tmp_data + 15 * 8), zp); - - int16x8_t t00 = vsubq_s16(d00, d20); - int16x8_t t01 = 
vsubq_s16(d01, d21); - int16x8_t t02 = vsubq_s16(d02, d22); - int16x8_t t03 = vsubq_s16(d03, d23); - - int16x8_t t10 = vaddq_s16(d10, d20); - int16x8_t t11 = vaddq_s16(d11, d21); - int16x8_t t12 = vaddq_s16(d12, d22); - int16x8_t t13 = vaddq_s16(d13, d23); - - int16x8_t t20 = vsubq_s16(d20, d10); - int16x8_t t21 = vsubq_s16(d21, d11); - int16x8_t t22 = vsubq_s16(d22, d12); - int16x8_t t23 = vsubq_s16(d23, d13); - - int16x8_t t30 = vsubq_s16(d10, d30); - int16x8_t t31 = vsubq_s16(d11, d31); - int16x8_t t32 = vsubq_s16(d12, d32); - int16x8_t t33 = vsubq_s16(d13, d33); - - int16x8_t m00 = vsubq_s16(t00, t02); - int16x8_t m01 = vaddq_s16(t01, t02); - int16x8_t m02 = vsubq_s16(t02, t01); - int16x8_t m03 = vsubq_s16(t01, t03); - - int16x8_t m10 = vsubq_s16(t10, t12); - int16x8_t m11 = vaddq_s16(t11, t12); - int16x8_t m12 = vsubq_s16(t12, t11); - int16x8_t m13 = vsubq_s16(t11, t13); - - int16x8_t m20 = vsubq_s16(t20, t22); - int16x8_t m21 = vaddq_s16(t21, t22); - int16x8_t m22 = vsubq_s16(t22, t21); - int16x8_t m23 = vsubq_s16(t21, t23); - - int16x8_t m30 = vsubq_s16(t30, t32); - int16x8_t m31 = vaddq_s16(t31, t32); - int16x8_t m32 = vsubq_s16(t32, t31); - int16x8_t m33 = vsubq_s16(t31, t33); - - vst1q_s16(trans_input_data, m00); - vst1q_s16(trans_input_data + step, m01); - vst1q_s16(trans_input_data + 2 * step, m02); - vst1q_s16(trans_input_data + 3 * step, m03); - - vst1q_s16(trans_input_data + 4 * step, m10); - vst1q_s16(trans_input_data + 5 * step, m11); - vst1q_s16(trans_input_data + 6 * step, m12); - vst1q_s16(trans_input_data + 7 * step, m13); - - vst1q_s16(trans_input_data + 8 * step, m20); - vst1q_s16(trans_input_data + 9 * step, m21); - vst1q_s16(trans_input_data + 10 * step, m22); - vst1q_s16(trans_input_data + 11 * step, m23); - - vst1q_s16(trans_input_data + 12 * step, m30); - vst1q_s16(trans_input_data + 13 * step, m31); - vst1q_s16(trans_input_data + 14 * step, m32); - vst1q_s16(trans_input_data + 15 * step, m33); -#else - for (int i = 0; i < C8NUM; i++) { - int16_t *local_ptr = tmp_data + i; - int16_t d00 = local_ptr[0] - input_zp; - int16_t d01 = (local_ptr + C8NUM)[0] - input_zp; - int16_t d02 = (local_ptr + 2 * C8NUM)[0] - input_zp; - int16_t d03 = (local_ptr + 3 * C8NUM)[0] - input_zp; - - int16_t d10 = (local_ptr + 4 * C8NUM)[0] - input_zp; - int16_t d11 = (local_ptr + 5 * C8NUM)[0] - input_zp; - int16_t d12 = (local_ptr + 6 * C8NUM)[0] - input_zp; - int16_t d13 = (local_ptr + 7 * C8NUM)[0] - input_zp; - - int16_t d20 = (local_ptr + 8 * C8NUM)[0] - input_zp; - int16_t d21 = (local_ptr + 9 * C8NUM)[0] - input_zp; - int16_t d22 = (local_ptr + 10 * C8NUM)[0] - input_zp; - int16_t d23 = (local_ptr + 11 * C8NUM)[0] - input_zp; - - int16_t d30 = (local_ptr + 12 * C8NUM)[0] - input_zp; - int16_t d31 = (local_ptr + 13 * C8NUM)[0] - input_zp; - int16_t d32 = (local_ptr + 14 * C8NUM)[0] - input_zp; - int16_t d33 = (local_ptr + 15 * C8NUM)[0] - input_zp; - - int16_t t00 = d00 - d20; - int16_t t01 = d01 - d21; - int16_t t02 = d02 - d22; - int16_t t03 = d03 - d23; - - int16_t t10 = d10 + d20; - int16_t t11 = d11 + d21; - int16_t t12 = d12 + d22; - int16_t t13 = d13 + d23; - - int16_t t20 = d20 - d10; - int16_t t21 = d21 - d11; - int16_t t22 = d22 - d12; - int16_t t23 = d23 - d13; - - int16_t t30 = d10 - d30; - int16_t t31 = d11 - d31; - int16_t t32 = d12 - d32; - int16_t t33 = d13 - d33; - - int16_t m00 = t00 - t02; - int16_t m01 = t01 + t02; - int16_t m02 = t02 - t01; - int16_t m03 = t01 - t03; - - int16_t m10 = t10 - t12; - int16_t m11 = t11 + t12; - int16_t m12 = t12 - t11; - int16_t 
m13 = t11 - t13; - - int16_t m20 = t20 - t22; - int16_t m21 = t21 + t22; - int16_t m22 = t22 - t21; - int16_t m23 = t21 - t23; - - int16_t m30 = t30 - t32; - int16_t m31 = t31 + t32; - int16_t m32 = t32 - t31; - int16_t m33 = t31 - t33; - - (trans_input_data + i)[0] = m00; - (trans_input_data + i + step)[0] = m01; - (trans_input_data + i + 2 * step)[0] = m02; - (trans_input_data + i + 3 * step)[0] = m03; - - (trans_input_data + i + 4 * step)[0] = m10; - (trans_input_data + i + 5 * step)[0] = m11; - (trans_input_data + i + 6 * step)[0] = m12; - (trans_input_data + i + 7 * step)[0] = m13; - - (trans_input_data + i + 8 * step)[0] = m20; - (trans_input_data + i + 9 * step)[0] = m21; - (trans_input_data + i + 10 * step)[0] = m22; - (trans_input_data + i + 11 * step)[0] = m23; - - (trans_input_data + i + 12 * step)[0] = m30; - (trans_input_data + i + 13 * step)[0] = m31; - (trans_input_data + i + 14 * step)[0] = m32; - (trans_input_data + i + 15 * step)[0] = m33; - } -#endif -} - -void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, - int real_cal_num, int out_w_block, ConvParameter *conv_param) { - // input data format : nhwc - int input_channel = conv_param->input_channel_; - int input_width = conv_param->input_w_; - int input_height = conv_param->input_h_; - int pad_w = conv_param->pad_l_; - int pad_h = conv_param->pad_u_; - ConvQuantArg quant_arg = conv_param->conv_quant_arg_; - int input_zp = quant_arg.input_quant_args_[0].zp_; - const int ic8 = UP_DIV(input_channel, C8NUM); - const int input_unit = 4; - if (out_w_block == 0) { - return; - } - for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { - int x_id = start_index + cal_id; - int origin_x = (x_id % out_w_block) * OUPUT_UNIT - pad_w; - int origin_y = (x_id / out_w_block) * OUPUT_UNIT - pad_h; - int real_x_start = origin_x > 0 ? 0 : -origin_x; - int real_x_end = (origin_x + input_unit) < input_width ? input_unit : (input_width - origin_x); - int real_y_start = origin_y > 0 ? 0 : -origin_y; - int real_y_end = (origin_y + input_unit) < input_height ? 
input_unit : (input_height - origin_y); - - int src_plane_offset = C8NUM * (origin_y * input_width + origin_x); - int dst_plane_offset = cal_id * C8NUM; - for (int ic = 0; ic < ic8; ic++) { - // copy data from origin input to tmp buffer - for (int i = 0; i < input_unit * input_unit * TILE_NUM; i++) tmp_data[i] = input_zp; - - int src_c8_offset = src_plane_offset + ic * C8NUM * input_height * input_width; - for (int j = real_y_start; j < real_y_end; j++) { - const int16_t *src = input_data + src_c8_offset + C8NUM * (j * input_width + real_x_start); - int16_t *dst = tmp_data + C8NUM * (C4NUM * j + real_x_start); - memcpy(dst, src, (real_x_end - real_x_start) * C8NUM * sizeof(int16_t)); - } - // input transform - int dst_ic8_offset = dst_plane_offset + ic * TILE_NUM * C8NUM; - size_t dst_step = ic8 * C8NUM * TILE_NUM; - int16_t *trans_input_ptr = trans_input + dst_ic8_offset; - Conv3x3Int8InputUnit(tmp_data, trans_input_ptr, dst_step, input_zp); - } - } -} - -void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel, - int kernel_plane) { - const int input_unit = 4; - int dst_step = iC8 * C8NUM * C4NUM; - for (int o = 0; o < output_channel; o++) { - int oc4_block_num = o / C4NUM; - int oc4_block_rem = o % C4NUM; - int src_oc_offset = o * iC8 * C8NUM * kernel_plane; - int dst_oc_offset = oc4_block_num * C4NUM * iC8 * C8NUM * input_unit * input_unit + oc4_block_rem; - for (int i = 0; i < iC8; i++) { - const int16_t *src_ic8_ptr = weight_data + src_oc_offset + i * kernel_plane * C8NUM; - int16_t *dst_ic8_ptr = trans_weight + dst_oc_offset + i * C4NUM * C8NUM; -#ifdef ENABLE_ARM - int16x8_t g00 = vld1q_s16(src_ic8_ptr); - int16x8_t g01 = vld1q_s16(src_ic8_ptr + 8); - int16x8_t g02 = vld1q_s16(src_ic8_ptr + 2 * 8); - int16x8_t g10 = vld1q_s16(src_ic8_ptr + 3 * 8); - int16x8_t g11 = vld1q_s16(src_ic8_ptr + 4 * 8); - int16x8_t g12 = vld1q_s16(src_ic8_ptr + 5 * 8); - int16x8_t g20 = vld1q_s16(src_ic8_ptr + 6 * 8); - int16x8_t g21 = vld1q_s16(src_ic8_ptr + 7 * 8); - int16x8_t g22 = vld1q_s16(src_ic8_ptr + 8 * 8); - - int16x8_t dst00 = vmulq_n_s16(g00, 2); - int16x8_t dst01 = vmulq_n_s16(g01, 2); - int16x8_t dst02 = vmulq_n_s16(g02, 2); - - int16x8_t dst10 = vaddq_s16(vaddq_s16(g00, g10), g20); - int16x8_t dst11 = vaddq_s16(vaddq_s16(g01, g11), g21); - int16x8_t dst12 = vaddq_s16(vaddq_s16(g02, g12), g22); - - int16x8_t dst20 = vaddq_s16(vsubq_s16(g00, g10), g20); - int16x8_t dst21 = vaddq_s16(vsubq_s16(g01, g11), g21); - int16x8_t dst22 = vaddq_s16(vsubq_s16(g02, g12), g22); - - int16x8_t dst30 = vmulq_n_s16(g20, 2); - int16x8_t dst31 = vmulq_n_s16(g21, 2); - int16x8_t dst32 = vmulq_n_s16(g22, 2); - - int16x8_t m00 = vmulq_n_s16(dst00, 2); - int16x8_t m01 = vaddq_s16(vaddq_s16(dst00, dst01), dst02); - int16x8_t m02 = vaddq_s16(vsubq_s16(dst00, dst01), dst02); - int16x8_t m03 = vmulq_n_s16(dst02, 2); - - int16x8_t m10 = vmulq_n_s16(dst10, 2); - int16x8_t m11 = vaddq_s16(vaddq_s16(dst10, dst11), dst12); - int16x8_t m12 = vaddq_s16(vsubq_s16(dst10, dst11), dst12); - int16x8_t m13 = vmulq_n_s16(dst12, 2); - - int16x8_t m20 = vmulq_n_s16(dst20, 2); - int16x8_t m21 = vaddq_s16(vaddq_s16(dst20, dst21), dst22); - int16x8_t m22 = vaddq_s16(vsubq_s16(dst20, dst21), dst22); - int16x8_t m23 = vmulq_n_s16(dst22, 2); - - int16x8_t m30 = vmulq_n_s16(dst30, 2); - int16x8_t m31 = vaddq_s16(vaddq_s16(dst30, dst31), dst32); - int16x8_t m32 = vaddq_s16(vsubq_s16(dst30, dst31), dst32); - int16x8_t m33 = vmulq_n_s16(dst32, 2); - - dst_ic8_ptr[0] = m00[0]; - 
dst_ic8_ptr[4] = m00[1]; - dst_ic8_ptr[8] = m00[2]; - dst_ic8_ptr[12] = m00[3]; - dst_ic8_ptr[16] = m00[4]; - dst_ic8_ptr[20] = m00[5]; - dst_ic8_ptr[24] = m00[6]; - dst_ic8_ptr[28] = m00[7]; - - dst_ic8_ptr[0 + dst_step] = m01[0]; - dst_ic8_ptr[4 + dst_step] = m01[1]; - dst_ic8_ptr[8 + dst_step] = m01[2]; - dst_ic8_ptr[12 + dst_step] = m01[3]; - dst_ic8_ptr[16 + dst_step] = m01[4]; - dst_ic8_ptr[20 + dst_step] = m01[5]; - dst_ic8_ptr[24 + dst_step] = m01[6]; - dst_ic8_ptr[28 + dst_step] = m01[7]; - - dst_ic8_ptr[0 + 2 * dst_step] = m02[0]; - dst_ic8_ptr[4 + 2 * dst_step] = m02[1]; - dst_ic8_ptr[8 + 2 * dst_step] = m02[2]; - dst_ic8_ptr[12 + 2 * dst_step] = m02[3]; - dst_ic8_ptr[16 + 2 * dst_step] = m02[4]; - dst_ic8_ptr[20 + 2 * dst_step] = m02[5]; - dst_ic8_ptr[24 + 2 * dst_step] = m02[6]; - dst_ic8_ptr[28 + 2 * dst_step] = m02[7]; - - dst_ic8_ptr[0 + 3 * dst_step] = m03[0]; - dst_ic8_ptr[4 + 3 * dst_step] = m03[1]; - dst_ic8_ptr[8 + 3 * dst_step] = m03[2]; - dst_ic8_ptr[12 + 3 * dst_step] = m03[3]; - dst_ic8_ptr[16 + 3 * dst_step] = m03[4]; - dst_ic8_ptr[20 + 3 * dst_step] = m03[5]; - dst_ic8_ptr[24 + 3 * dst_step] = m03[6]; - dst_ic8_ptr[28 + 3 * dst_step] = m03[7]; - - dst_ic8_ptr[0 + 4 * dst_step] = m10[0]; - dst_ic8_ptr[4 + 4 * dst_step] = m10[1]; - dst_ic8_ptr[8 + 4 * dst_step] = m10[2]; - dst_ic8_ptr[12 + 4 * dst_step] = m10[3]; - dst_ic8_ptr[16 + 4 * dst_step] = m10[4]; - dst_ic8_ptr[20 + 4 * dst_step] = m10[5]; - dst_ic8_ptr[24 + 4 * dst_step] = m10[6]; - dst_ic8_ptr[28 + 4 * dst_step] = m10[7]; - - dst_ic8_ptr[0 + 5 * dst_step] = m11[0]; - dst_ic8_ptr[4 + 5 * dst_step] = m11[1]; - dst_ic8_ptr[8 + 5 * dst_step] = m11[2]; - dst_ic8_ptr[12 + 5 * dst_step] = m11[3]; - dst_ic8_ptr[16 + 5 * dst_step] = m11[4]; - dst_ic8_ptr[20 + 5 * dst_step] = m11[5]; - dst_ic8_ptr[24 + 5 * dst_step] = m11[6]; - dst_ic8_ptr[28 + 5 * dst_step] = m11[7]; - - dst_ic8_ptr[0 + 6 * dst_step] = m12[0]; - dst_ic8_ptr[4 + 6 * dst_step] = m12[1]; - dst_ic8_ptr[8 + 6 * dst_step] = m12[2]; - dst_ic8_ptr[12 + 6 * dst_step] = m12[3]; - dst_ic8_ptr[16 + 6 * dst_step] = m12[4]; - dst_ic8_ptr[20 + 6 * dst_step] = m12[5]; - dst_ic8_ptr[24 + 6 * dst_step] = m12[6]; - dst_ic8_ptr[28 + 6 * dst_step] = m12[7]; - - dst_ic8_ptr[0 + 7 * dst_step] = m13[0]; - dst_ic8_ptr[4 + 7 * dst_step] = m13[1]; - dst_ic8_ptr[8 + 7 * dst_step] = m13[2]; - dst_ic8_ptr[12 + 7 * dst_step] = m13[3]; - dst_ic8_ptr[16 + 7 * dst_step] = m13[4]; - dst_ic8_ptr[20 + 7 * dst_step] = m13[5]; - dst_ic8_ptr[24 + 7 * dst_step] = m13[6]; - dst_ic8_ptr[28 + 7 * dst_step] = m13[7]; - - dst_ic8_ptr[0 + 8 * dst_step] = m20[0]; - dst_ic8_ptr[4 + 8 * dst_step] = m20[1]; - dst_ic8_ptr[8 + 8 * dst_step] = m20[2]; - dst_ic8_ptr[12 + 8 * dst_step] = m20[3]; - dst_ic8_ptr[16 + 8 * dst_step] = m20[4]; - dst_ic8_ptr[20 + 8 * dst_step] = m20[5]; - dst_ic8_ptr[24 + 8 * dst_step] = m20[6]; - dst_ic8_ptr[28 + 8 * dst_step] = m20[7]; - - dst_ic8_ptr[0 + 9 * dst_step] = m21[0]; - dst_ic8_ptr[4 + 9 * dst_step] = m21[1]; - dst_ic8_ptr[8 + 9 * dst_step] = m21[2]; - dst_ic8_ptr[12 + 9 * dst_step] = m21[3]; - dst_ic8_ptr[16 + 9 * dst_step] = m21[4]; - dst_ic8_ptr[20 + 9 * dst_step] = m21[5]; - dst_ic8_ptr[24 + 9 * dst_step] = m21[6]; - dst_ic8_ptr[28 + 9 * dst_step] = m21[7]; - - dst_ic8_ptr[0 + 10 * dst_step] = m22[0]; - dst_ic8_ptr[4 + 10 * dst_step] = m22[1]; - dst_ic8_ptr[8 + 10 * dst_step] = m22[2]; - dst_ic8_ptr[12 + 10 * dst_step] = m22[3]; - dst_ic8_ptr[16 + 10 * dst_step] = m22[4]; - dst_ic8_ptr[20 + 10 * dst_step] = m22[5]; - dst_ic8_ptr[24 + 10 * dst_step] = m22[6]; 
- dst_ic8_ptr[28 + 10 * dst_step] = m22[7]; - - dst_ic8_ptr[0 + 11 * dst_step] = m23[0]; - dst_ic8_ptr[4 + 11 * dst_step] = m23[1]; - dst_ic8_ptr[8 + 11 * dst_step] = m23[2]; - dst_ic8_ptr[12 + 11 * dst_step] = m23[3]; - dst_ic8_ptr[16 + 11 * dst_step] = m23[4]; - dst_ic8_ptr[20 + 11 * dst_step] = m23[5]; - dst_ic8_ptr[24 + 11 * dst_step] = m23[6]; - dst_ic8_ptr[28 + 11 * dst_step] = m23[7]; - - dst_ic8_ptr[0 + 12 * dst_step] = m30[0]; - dst_ic8_ptr[4 + 12 * dst_step] = m30[1]; - dst_ic8_ptr[8 + 12 * dst_step] = m30[2]; - dst_ic8_ptr[12 + 12 * dst_step] = m30[3]; - dst_ic8_ptr[16 + 12 * dst_step] = m30[4]; - dst_ic8_ptr[20 + 12 * dst_step] = m30[5]; - dst_ic8_ptr[24 + 12 * dst_step] = m30[6]; - dst_ic8_ptr[28 + 12 * dst_step] = m30[7]; - - dst_ic8_ptr[0 + 13 * dst_step] = m31[0]; - dst_ic8_ptr[4 + 13 * dst_step] = m31[1]; - dst_ic8_ptr[8 + 13 * dst_step] = m31[2]; - dst_ic8_ptr[12 + 13 * dst_step] = m31[3]; - dst_ic8_ptr[16 + 13 * dst_step] = m31[4]; - dst_ic8_ptr[20 + 13 * dst_step] = m31[5]; - dst_ic8_ptr[24 + 13 * dst_step] = m31[6]; - dst_ic8_ptr[28 + 13 * dst_step] = m31[7]; - - dst_ic8_ptr[0 + 14 * dst_step] = m32[0]; - dst_ic8_ptr[4 + 14 * dst_step] = m32[1]; - dst_ic8_ptr[8 + 14 * dst_step] = m32[2]; - dst_ic8_ptr[12 + 14 * dst_step] = m32[3]; - dst_ic8_ptr[16 + 14 * dst_step] = m32[4]; - dst_ic8_ptr[20 + 14 * dst_step] = m32[5]; - dst_ic8_ptr[24 + 14 * dst_step] = m32[6]; - dst_ic8_ptr[28 + 14 * dst_step] = m32[7]; - - dst_ic8_ptr[0 + 15 * dst_step] = m33[0]; - dst_ic8_ptr[4 + 15 * dst_step] = m33[1]; - dst_ic8_ptr[8 + 15 * dst_step] = m33[2]; - dst_ic8_ptr[12 + 15 * dst_step] = m33[3]; - dst_ic8_ptr[16 + 15 * dst_step] = m33[4]; - dst_ic8_ptr[20 + 15 * dst_step] = m33[5]; - dst_ic8_ptr[24 + 15 * dst_step] = m33[6]; - dst_ic8_ptr[28 + 15 * dst_step] = m33[7]; -#else - for (int j = 0; j < C8NUM; j++) { - const int16_t *local_ptr = src_ic8_ptr + j; - int16_t dst00 = local_ptr[0] * 2; - int16_t dst01 = (local_ptr + 8)[0] * 2; - int16_t dst02 = (local_ptr + 16)[0] * 2; - - int16_t dst10 = local_ptr[0] + (local_ptr + 24)[0] + (local_ptr + 48)[0]; - int16_t dst11 = (local_ptr + 8)[0] + (local_ptr + 32)[0] + (local_ptr + 56)[0]; - int16_t dst12 = (local_ptr + 16)[0] + (local_ptr + 40)[0] + (local_ptr + 64)[0]; - - int16_t dst20 = local_ptr[0] - (local_ptr + 24)[0] + (local_ptr + 48)[0]; - int16_t dst21 = (local_ptr + 8)[0] - (local_ptr + 32)[0] + (local_ptr + 56)[0]; - int16_t dst22 = (local_ptr + 16)[0] - (local_ptr + 40)[0] + (local_ptr + 64)[0]; - - int16_t dst30 = (local_ptr + 48)[0] * 2; - int16_t dst31 = (local_ptr + 56)[0] * 2; - int16_t dst32 = (local_ptr + 64)[0] * 2; - - int16_t m00 = dst00 * 2; - int16_t m01 = dst00 + dst01 + dst02; - int16_t m02 = dst00 - dst01 + dst02; - int16_t m03 = dst02 * 2; - - int16_t m10 = dst10 * 2; - int16_t m11 = dst10 + dst11 + dst12; - int16_t m12 = dst10 - dst11 + dst12; - int16_t m13 = dst12 * 2; - - int16_t m20 = dst20 * 2; - int16_t m21 = dst20 + dst21 + dst22; - int16_t m22 = dst20 - dst21 + dst22; - int16_t m23 = dst22 * 2; - - int16_t m30 = dst30 * 2; - int16_t m31 = dst30 + dst31 + dst32; - int16_t m32 = dst30 - dst31 + dst32; - int16_t m33 = dst32 * 2; - - *(dst_ic8_ptr + j * 4) = m00; - *(dst_ic8_ptr + j * 4 + dst_step) = m01; - *(dst_ic8_ptr + j * 4 + 2 * dst_step) = m02; - *(dst_ic8_ptr + j * 4 + 3 * dst_step) = m03; - - *(dst_ic8_ptr + j * 4 + 4 * dst_step) = m10; - *(dst_ic8_ptr + j * 4 + 5 * dst_step) = m11; - *(dst_ic8_ptr + j * 4 + 6 * dst_step) = m12; - *(dst_ic8_ptr + j * 4 + 7 * dst_step) = m13; - - *(dst_ic8_ptr + j * 4 + 8 * 
dst_step) = m20; - *(dst_ic8_ptr + j * 4 + 9 * dst_step) = m21; - *(dst_ic8_ptr + j * 4 + 10 * dst_step) = m22; - *(dst_ic8_ptr + j * 4 + 11 * dst_step) = m23; - - *(dst_ic8_ptr + j * 4 + 12 * dst_step) = m30; - *(dst_ic8_ptr + j * 4 + 13 * dst_step) = m31; - *(dst_ic8_ptr + j * 4 + 14 * dst_step) = m32; - *(dst_ic8_ptr + j * 4 + 15 * dst_step) = m33; - } -#endif - } - } -} - -void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, - bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) { - int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; - int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; - int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; - int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; - int out_min = conv_param->conv_quant_arg_.out_act_min_[0]; - int out_max = conv_param->conv_quant_arg_.out_act_max_[0]; - -#ifdef ENABLE_ARM - int32x4_t bias_ptr = vld1q_s32(bias_data); - - int32x4_t s00 = vld1q_s32(gemm_out); - int32x4_t s01 = vld1q_s32(gemm_out + 4); - int32x4_t s02 = vld1q_s32(gemm_out + 8); - int32x4_t s03 = vld1q_s32(gemm_out + 12); - - int32x4_t s10 = vld1q_s32(gemm_out + 16); - int32x4_t s11 = vld1q_s32(gemm_out + 20); - int32x4_t s12 = vld1q_s32(gemm_out + 24); - int32x4_t s13 = vld1q_s32(gemm_out + 28); - - int32x4_t s20 = vld1q_s32(gemm_out + 32); - int32x4_t s21 = vld1q_s32(gemm_out + 36); - int32x4_t s22 = vld1q_s32(gemm_out + 40); - int32x4_t s23 = vld1q_s32(gemm_out + 44); - - int32x4_t s30 = vld1q_s32(gemm_out + 48); - int32x4_t s31 = vld1q_s32(gemm_out + 52); - int32x4_t s32 = vld1q_s32(gemm_out + 56); - int32x4_t s33 = vld1q_s32(gemm_out + 60); - - int32x4_t t00 = vshrq_n_s32(vaddq_s32(vaddq_s32(s00, s10), s20), 1); - int32x4_t t01 = vshrq_n_s32(vaddq_s32(vaddq_s32(s01, s11), s21), 1); - int32x4_t t02 = vshrq_n_s32(vaddq_s32(vaddq_s32(s02, s12), s22), 1); - int32x4_t t03 = vshrq_n_s32(vaddq_s32(vaddq_s32(s03, s13), s23), 1); - - int32x4_t t10 = vshrq_n_s32(vsubq_s32(vsubq_s32(s10, s20), s30), 1); - int32x4_t t11 = vshrq_n_s32(vsubq_s32(vsubq_s32(s11, s21), s31), 1); - int32x4_t t12 = vshrq_n_s32(vsubq_s32(vsubq_s32(s12, s22), s32), 1); - int32x4_t t13 = vshrq_n_s32(vsubq_s32(vsubq_s32(s13, s23), s33), 1); - - int32x4_t d00 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t00, t01), t02), 1), bias_ptr); - int32x4_t d01 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t01, t02), t03), 1), bias_ptr); - - int32x4_t d10 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t10, t11), t12), 1), bias_ptr); - int32x4_t d11 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t11, t12), t13), 1), bias_ptr); - - int32x4_t out_multiplier; - int32x4_t ls; - int32x4_t rs; - if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { - out_multiplier = vld1q_s32(quant_multiplier + oc_start); - ls = vld1q_s32(left_shift + oc_start); - rs = vld1q_s32(right_shift + oc_start); - } else { - out_multiplier = vdupq_n_s32(quant_multiplier[0]); - ls = vdupq_n_s32(left_shift[0]); - rs = vdupq_n_s32(right_shift[0]); - } - int32x4_t out_zp = vdupq_n_s32(output_zp); - int32x4_t output_min = vdupq_n_s32(out_min); - int32x4_t output_max = vdupq_n_s32(out_max); - - d00 = vqshlq_s32(d00, ls); - d00 = vqrdmulhq_s32(d00, out_multiplier); - int32x4_t carry = vandq_s32(d00, rs); - carry = vshrq_n_s32(carry, 31); - d00 = vqaddq_s32(d00, carry); - d00 = vqrshlq_s32(d00, rs); - d00 = vaddq_s32(d00, out_zp); - d00 = vmaxq_s32(d00, output_min); - d00 = 
vminq_s32(d00, output_max); - - d01 = vqshlq_s32(d01, ls); - d01 = vqrdmulhq_s32(d01, out_multiplier); - carry = vandq_s32(d01, rs); - carry = vshrq_n_s32(carry, 31); - d01 = vqaddq_s32(d01, carry); - d01 = vqrshlq_s32(d01, rs); - d01 = vaddq_s32(d01, out_zp); - d01 = vmaxq_s32(d01, output_min); - d01 = vminq_s32(d01, output_max); - - d10 = vqshlq_s32(d10, ls); - d10 = vqrdmulhq_s32(d10, out_multiplier); - carry = vandq_s32(d10, rs); - carry = vshrq_n_s32(carry, 31); - d10 = vqaddq_s32(d10, carry); - d10 = vqrshlq_s32(d10, rs); - d10 = vaddq_s32(d10, out_zp); - d10 = vmaxq_s32(d10, output_min); - d10 = vminq_s32(d10, output_max); - - d11 = vqshlq_s32(d11, ls); - d11 = vqrdmulhq_s32(d11, out_multiplier); - carry = vandq_s32(d11, rs); - carry = vshrq_n_s32(carry, 31); - d11 = vqaddq_s32(d11, carry); - d11 = vqrshlq_s32(d11, rs); - d11 = vaddq_s32(d11, out_zp); - d11 = vmaxq_s32(d11, output_min); - d11 = vminq_s32(d11, output_max); - - (output_data)[0] = (int8_t)d00[0]; - (output_data + 1)[0] = (int8_t)d00[1]; - (output_data + 2)[0] = (int8_t)d00[2]; - (output_data + 3)[0] = (int8_t)d00[3]; - - if (w_not_bound) { - *(output_data + 4) = (int8_t)d01[0]; - *(output_data + 5) = (int8_t)d01[1]; - *(output_data + 6) = (int8_t)d01[2]; - *(output_data + 7) = (int8_t)d01[3]; - } - if (h_not_bound) { - *(output_data + output_w * 4) = (int8_t)d10[0]; - *(output_data + output_w * 4 + 1) = (int8_t)d10[1]; - *(output_data + output_w * 4 + 2) = (int8_t)d10[2]; - *(output_data + output_w * 4 + 3) = (int8_t)d10[3]; - if (w_not_bound) { - *(output_data + output_w * 4 + 4) = (int8_t)d11[0]; - *(output_data + output_w * 4 + 5) = (int8_t)d11[1]; - *(output_data + output_w * 4 + 6) = (int8_t)d11[2]; - *(output_data + output_w * 4 + 7) = (int8_t)d11[3]; - } - } -#else - if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { - for (int i = 0; i < C4NUM; i++) { - const int32_t *local_ptr = gemm_out + i; - const int32_t *bias_ptr = bias_data + i; - - int32_t s00 = local_ptr[0]; - int32_t s01 = (local_ptr + 4)[0]; - int32_t s02 = (local_ptr + 8)[0]; - int32_t s03 = (local_ptr + 12)[0]; - - int32_t s10 = (local_ptr + 16)[0]; - int32_t s11 = (local_ptr + 20)[0]; - int32_t s12 = (local_ptr + 24)[0]; - int32_t s13 = (local_ptr + 28)[0]; - - int32_t s20 = (local_ptr + 32)[0]; - int32_t s21 = (local_ptr + 36)[0]; - int32_t s22 = (local_ptr + 40)[0]; - int32_t s23 = (local_ptr + 44)[0]; - - int32_t s30 = (local_ptr + 48)[0]; - int32_t s31 = (local_ptr + 52)[0]; - int32_t s32 = (local_ptr + 56)[0]; - int32_t s33 = (local_ptr + 60)[0]; - - int32_t t00 = (s00 + s10 + s20) / 2; - int32_t t01 = (s01 + s11 + s21) / 2; - int32_t t02 = (s02 + s12 + s22) / 2; - int32_t t03 = (s03 + s13 + s23) / 2; - - int32_t t10 = (s10 - s20 - s30) / 2; - int32_t t11 = (s11 - s21 - s31) / 2; - int32_t t12 = (s12 - s22 - s32) / 2; - int32_t t13 = (s13 - s23 - s33) / 2; - - int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0]; - int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; - - int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; - int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; - - int oc_index = oc_start + i; - d00 = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), - -right_shift[oc_index]); - d00 += output_zp; - d00 = d00 > out_min ? d00 : out_min; - d00 = d00 < out_max ? 
d00 : out_max; - - d01 = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), - -right_shift[oc_index]); - d01 += output_zp; - d01 = d01 > out_min ? d01 : out_min; - d01 = d01 < out_max ? d01 : out_max; - - d10 = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), - -right_shift[oc_index]); - d10 += output_zp; - d10 = d10 > out_min ? d10 : out_min; - d10 = d10 < out_max ? d10 : out_max; - - d11 = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), - -right_shift[oc_index]); - d11 += output_zp; - d11 = d11 > out_min ? d11 : out_min; - d11 = d11 < out_max ? d11 : out_max; - - (output_data + i)[0] = (int8_t)d00; - if (w_not_bound) { - (output_data + i + C4NUM)[0] = (int8_t)d01; - } - if (h_not_bound) { - (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; - if (w_not_bound) { - (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; - } - } - } - } else { - for (int i = 0; i < C4NUM; i++) { - const int32_t *local_ptr = gemm_out + i; - const int32_t *bias_ptr = bias_data + i; - - int32_t s00 = local_ptr[0]; - int32_t s01 = (local_ptr + 4)[0]; - int32_t s02 = (local_ptr + 8)[0]; - int32_t s03 = (local_ptr + 12)[0]; - - int32_t s10 = (local_ptr + 16)[0]; - int32_t s11 = (local_ptr + 20)[0]; - int32_t s12 = (local_ptr + 24)[0]; - int32_t s13 = (local_ptr + 28)[0]; - - int32_t s20 = (local_ptr + 32)[0]; - int32_t s21 = (local_ptr + 36)[0]; - int32_t s22 = (local_ptr + 40)[0]; - int32_t s23 = (local_ptr + 44)[0]; - - int32_t s30 = (local_ptr + 48)[0]; - int32_t s31 = (local_ptr + 52)[0]; - int32_t s32 = (local_ptr + 56)[0]; - int32_t s33 = (local_ptr + 60)[0]; - - int32_t t00 = (s00 + s10 + s20) / 2; - int32_t t01 = (s01 + s11 + s21) / 2; - int32_t t02 = (s02 + s12 + s22) / 2; - int32_t t03 = (s03 + s13 + s23) / 2; - - int32_t t10 = (s10 - s20 - s30) / 2; - int32_t t11 = (s11 - s21 - s31) / 2; - int32_t t12 = (s12 - s22 - s32) / 2; - int32_t t13 = (s13 - s23 - s33) / 2; - - int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0]; - int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; - - int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; - int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; - - d00 = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), - -right_shift[0]); - d00 += output_zp; - d00 = d00 > out_min ? d00 : out_min; - d00 = d00 < out_max ? d00 : out_max; - - d01 = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), - -right_shift[0]); - d01 += output_zp; - d01 = d01 > out_min ? d01 : out_min; - d01 = d01 < out_max ? d01 : out_max; - - d10 = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), - -right_shift[0]); - d10 += output_zp; - d10 = d10 > out_min ? d10 : out_min; - d10 = d10 < out_max ? d10 : out_max; - - d11 = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), - -right_shift[0]); - d11 += output_zp; - d11 = d11 > out_min ? d11 : out_min; - d11 = d11 < out_max ? 
d11 : out_max; - - (output_data + i)[0] = (int8_t)d00; - if (w_not_bound) { - (output_data + i + C4NUM)[0] = (int8_t)d01; - } - if (h_not_bound) { - (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; - if (w_not_bound) { - (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; - } - } - } - } -#endif -} - -void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, - int real_cal_num, int out_w_block, ConvParameter *conv_param) { - int output_channel = conv_param->output_channel_; - int output_w = conv_param->output_w_; - int output_h = conv_param->output_h_; - const int oc4 = UP_DIV(output_channel, C4NUM); - const int input_unit = 4; - if (out_w_block == 0) { - return; - } - for (int i = 0; i < real_cal_num; i++) { - int out_w_index = (start_index + i) % out_w_block; - int out_h_index = (start_index + i) / out_w_block; - int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit; - int dst_tile_offset = C4NUM * (out_w_index * OUPUT_UNIT + out_h_index * OUPUT_UNIT * output_w); - - for (int j = 0; j < oc4; j++) { - int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM; - int dst_oc4_offset = dst_tile_offset + j * C4NUM * output_h * output_w; - const int32_t *src_ptr = gemm_out + src_oc4_offset; - const int32_t *bias_ptr = bias_data + j * C4NUM; - int8_t *dst_ptr = out_data + dst_oc4_offset; - - // output transform - int real_num = (output_channel - j * C4NUM) < C4NUM ? (output_channel - j * C4NUM) : C4NUM; - bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w; - bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h; - Conv3x3Int8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, j * C4NUM, - conv_param); - } - } -} diff --git a/mindspore/lite/nnacl/winograd_transform.h b/mindspore/lite/nnacl/winograd_transform.h index f92b6caa9f3..98a0a00f032 100644 --- a/mindspore/lite/nnacl/winograd_transform.h +++ b/mindspore/lite/nnacl/winograd_transform.h @@ -24,7 +24,7 @@ #include "nnacl/pack.h" #include "nnacl/fp32/conv_fp32.h" #include "nnacl/winograd_utils.h" -#include "nnacl/quantization/fixed_point.h" +#include "mindspore/lite/nnacl/int8/fixed_point.h" #define OUPUT_UNIT 2 @@ -40,20 +40,6 @@ void WinogradOutputTransform(const float *gemm_out, float *out_data, const float int out_tile_index, int output_unit_num, const ConvParameter *conv_param, OutputTransFunc func); -// for int8 convolution 3x3 filter/input/output transform -void Conv3x3Int8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp); - -void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, - int real_cal_num, int out_w_block, ConvParameter *conv_param); - -void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel, - int kernel_plane); - -void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, - bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param); - -void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, - int real_cal_num, int out_w_block, ConvParameter *conv_param); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/winograd_utils.c b/mindspore/lite/nnacl/winograd_utils.c index ed95361a4c8..06231970bf6 100644 --- a/mindspore/lite/nnacl/winograd_utils.c +++ 
b/mindspore/lite/nnacl/winograd_utils.c @@ -15,7 +15,6 @@ */ #include "nnacl/winograd_utils.h" -#include #include "nnacl/minimal_filtering_generator.h" #define MIN_UNIT 2 diff --git a/mindspore/lite/nnacl/zeroslike.c b/mindspore/lite/nnacl/zeroslike.c index 92712b6ea3e..6e8e725673f 100644 --- a/mindspore/lite/nnacl/zeroslike.c +++ b/mindspore/lite/nnacl/zeroslike.c @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include "nnacl/zeroslike.h" -#include #include void ApproximateZerosLike(float *output, int number) { memset(output, 0.0, number * sizeof(float)); } diff --git a/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc index 1e01125c9e0..1906cf44b6c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc @@ -14,20 +14,11 @@ * limitations under the License. */ #include "src/runtime/kernel/arm/base/depth_to_space_base.h" -#include "nnacl/depth_to_space.h" -#include "src/runtime/kernel/arm/fp32/depth_to_space_fp32.h" -#include "nnacl/common_func.h" -#include "schema/model_generated.h" -#include "src/kernel_registry.h" -#include "include/errorcode.h" -#include "include/context.h" -using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_FORMAT_ERR; using mindspore::lite::RET_OK; using mindspore::lite::RET_PARAM_INVALID; -using mindspore::schema::PrimitiveType_DepthToSpace; namespace mindspore::kernel { int DepthToSpaceBaseCPUKernel::ReSize() { diff --git a/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h index a5b49edf461..3ff99bfedc0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h @@ -18,9 +18,10 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEPTH_TO_SPACE_BASE_H_ #include -#include "include/errorcode.h" #include "src/lite_kernel.h" -#include "nnacl/depth_to_space.h" +#include "include/errorcode.h" +#include "include/context.h" +#include "nnacl/nnacl_common.h" #include "nnacl/depth_to_space_parameter.h" namespace mindspore::kernel { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc index d3c531817af..d3b93176bf4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc @@ -15,7 +15,7 @@ */ #include "src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h" -#include "nnacl/fp16/conv_fp16.h" +#include "nnacl/base/conv1x1_base.h" #include "nnacl/fp16/cast_fp16.h" #include "nnacl/fp16/pack_fp16.h" #include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.h index 8594784fb85..1d8b82aa7f3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.h @@ -22,10 +22,9 @@ #include "src/lite_kernel.h" #include "include/errorcode.h" #include "nnacl/op_base.h" -#include "nnacl/winograd_transform.h" #include "src/runtime/kernel/arm/base/convolution_base.h" #include 
"src/runtime/kernel/arm/base/layout_transform.h" -#include "nnacl/fp32/conv_fp32.h" +#include "nnacl/base/conv1x1_base.h" #include "nnacl/fp32/common_func_fp32.h" #include "nnacl/matmul_parameter.h" #include "nnacl/fp32/matmul_fp32.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space_fp32.h index c2d8c2b0d44..3681e6c7b03 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space_fp32.h @@ -18,7 +18,7 @@ #include #include "include/errorcode.h" -#include "nnacl/depth_to_space.h" +#include "nnacl/base/depth_to_space_base.h" #include "src/runtime/kernel/arm/base/depth_to_space_base.h" namespace mindspore::kernel { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc index 9fd0251e125..c44a4f0b70f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc @@ -15,7 +15,7 @@ */ #include "src/runtime/kernel/arm/int8/add_int8.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" #include "src/runtime/runtime_api.h" #include "src/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h index 9dbc601ff24..642c457b143 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h @@ -17,7 +17,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARGMINMAX_INT8_H_ #include -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #include "nnacl/int8/arg_min_max_int8.h" #include "nnacl/common_func.h" #include "include/errorcode.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h index 7ebe3d5b88c..21e81fe446a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h @@ -22,7 +22,8 @@ #include "include/errorcode.h" #include "schema/model_generated.h" #include "src/runtime/kernel/arm/base/convolution_base.h" -#include "nnacl/int8/conv_int8.h" +#include "nnacl/int8/conv1x1_int8.h" +#include "nnacl/base/conv1x1_base.h" #include "nnacl/int8/matmul_int8.h" #include "nnacl/matmul_parameter.h" #include "src/common/utils.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.cc index 9134ed4f20d..7df36b64313 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.cc @@ -15,7 +15,7 @@ */ #include "src/runtime/kernel/arm/int8/convolution_3x3_int8.h" -#include "nnacl/int8/conv_int8.h" +#include "nnacl/int8/conv3x3_int8.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h index 7a5ea651a30..0dbcdfc03c1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h @@ -18,9 +18,9 @@ #include #include 
"include/errorcode.h" -#include "nnacl/depth_to_space.h" +#include "nnacl/base/depth_to_space_base.h" #include "nnacl/int8/depth_to_space_int8.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #include "src/runtime/kernel/arm/base/depth_to_space_base.h" namespace mindspore::kernel { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h index 3c2b6e7dce0..5ca9ac46cc8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h @@ -20,7 +20,7 @@ #include #include "src/lite_kernel.h" #include "include/errorcode.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #include "nnacl/common_func.h" #include "nnacl/int8/common_func_int8.h" #include "nnacl/int8/matmul_int8.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/gatherNd_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/gatherNd_int8.h index 3007530d133..b6529f00e44 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/gatherNd_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/gatherNd_int8.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GATHERND_INT8_H_ #include -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #include "src/lite_kernel.h" namespace mindspore::kernel { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc index c2f1ee43b49..11ac11a5aa1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc @@ -17,7 +17,7 @@ #include #include "nnacl/gather_parameter.h" #include "nnacl/int8/gather_int8.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "src/runtime/runtime_api.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.h index 7972630eed8..04a546a4876 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.h @@ -19,7 +19,7 @@ #include #include "nnacl/gather_parameter.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #include "src/lite_kernel.h" namespace mindspore::kernel { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.h index 00008ae2637..3ff202cf7a5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.h @@ -20,7 +20,7 @@ #include #include "src/lite_kernel.h" #include "nnacl/int8/hswish_int8.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" namespace mindspore::kernel { class HswishInt8CPUKernel : public LiteKernel { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h index acd7fe5280e..601a2ac2098 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h @@ -20,7 +20,7 @@ #include #include "include/context.h" #include "nnacl/matmul_parameter.h" -#include 
"nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #include "src/lite_kernel.h" using mindspore::lite::InnerContext; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.cc index 10e80bef6cf..828de0c4ac5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.cc @@ -15,10 +15,6 @@ */ #include "src/runtime/kernel/arm/int8/pad_int8.h" -#include -#include "include/errorcode.h" -#include "nnacl/errorcode.h" -#include "nnacl/int8/pad_int8.h" #include "src/runtime/runtime_api.h" #include "src/kernel_registry.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.h index 1062b47ceda..c0bf6735024 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.h @@ -16,12 +16,15 @@ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_PAD_INT8_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_PAD_INT8_H_ +#include #include #include "include/errorcode.h" #include "src/lite_kernel.h" #include "src/runtime/runtime_api.h" +#include "nnacl/errorcode.h" #include "nnacl/pad_parameter.h" #include "nnacl/int8/pad_int8.h" +#include "nnacl/int8/quantize.h" namespace mindspore::kernel { class PadInt8CPUKernel : public LiteKernel { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/power_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/power_int8.h index 393ac7755d5..1928542cd37 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/power_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/power_int8.h @@ -19,7 +19,7 @@ #include #include "src/lite_kernel.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #include "nnacl/power_parameter.h" namespace mindspore::kernel { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.cc index 9a304e0279c..479f2e62859 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.cc @@ -18,7 +18,7 @@ #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "src/runtime/runtime_api.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #include "nnacl/pack.h" #include "include/errorcode.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.h index 79b6405affd..1745e64bb2b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.h @@ -21,7 +21,7 @@ #include "src/lite_kernel.h" #include "nnacl/reduce_parameter.h" #include "nnacl/int8/reduce_int8.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #include "src/runtime/kernel/arm/base/reduce_base.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.h index 8ee8dd7bfc3..568b07c72e4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.h @@ -19,7 +19,7 @@ #include #include "src/lite_kernel.h" #include "src/runtime/kernel/arm/base/resize_base.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" using 
mindspore::schema::PrimitiveType_Resize; using mindspore::schema::ResizeMethod; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/scale_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/scale_int8.h index c7e07207c57..b0835f61262 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/scale_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/scale_int8.h @@ -21,7 +21,7 @@ #include #include "src/lite_kernel.h" #include "nnacl/scale.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" #include "nnacl/int8/arithmetic_int8.h" #include "nnacl/int8/scale_int8.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/sigmoid_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/sigmoid_int8.cc index 091c263b275..7289577948e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/sigmoid_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/sigmoid_int8.cc @@ -18,7 +18,7 @@ #include #include #include "nnacl/int8/sigmoid_int8.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "src/runtime/runtime_api.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/slice_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/slice_int8.h index 79ec68fdace..95db3cc5c83 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/slice_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/slice_int8.h @@ -19,7 +19,7 @@ #include #include "src/runtime/kernel/arm/fp32/slice_fp32.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" namespace mindspore::kernel { class SliceInt8CPUKernel : public SliceCPUKernel { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h index 761a8257b3c..ecbb1ba62a4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h @@ -19,7 +19,7 @@ #include #include "src/runtime/kernel/arm/base/softmax_base.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" namespace mindspore::kernel { class SoftmaxInt8CPUKernel : public SoftmaxBaseCPUKernel { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.h index 728ced359a0..1b5f83d9ecd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.h @@ -20,7 +20,7 @@ #include #include #include "nnacl/int8/arithmetic_int8.h" -#include "nnacl/quantization/quantize.h" +#include "nnacl/int8/quantize.h" #include "src/lite_kernel.h" #include "nnacl/int8/sub_int8.h" #include "src/runtime/runtime_api.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/tanh_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/tanh_int8.h index 54495b88a44..d23bb987569 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/tanh_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/tanh_int8.h @@ -22,7 +22,7 @@ #include #include "src/lite_kernel.h" #include "nnacl/int8/tanh_int8.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #include "include/errorcode.h" namespace mindspore::kernel { diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index d3af5d8f4a7..037a37de7b3 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -30,7 +30,7 @@ 
file(GLOB KERNEL_OP_SRC ${LITE_DIR}/nnacl/*.c ${LITE_DIR}/nnacl/fp32/*.c ${LITE_DIR}/nnacl/int8/*.c - ${LITE_DIR}/nnacl/quantization/*.c + ${LITE_DIR}/nnacl/base/*.c ) file(GLOB KERNEL_OP_TRAIN_SRC diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc index 613e0e7c86f..fb2b29b9995 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc @@ -169,42 +169,4 @@ TEST_F(TestPack, PackInputFp16) { } #endif -TEST_F(TestPack, PackWeightUint8) { - auto conv_param = new ConvParameter; - InitConvParamPack(conv_param); - - int k_h = conv_param->kernel_h_; - int k_w = conv_param->kernel_w_; - int in_channel = conv_param->input_channel_; - int out_channel = conv_param->output_channel_; - int ic4 = UP_DIV(in_channel, C4NUM); - int oc4 = UP_DIV(out_channel, C4NUM); - - size_t weight_size; - std::string weight_path = "./test_data/conv/convuint8_weight_32_3_3_3.bin"; - auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); - auto int8_weight = reinterpret_cast(malloc(weight_size)); - for (unsigned int i = 0; i < weight_size; i++) { - int8_weight[i] = (int8_t)(weight_data[i] - 128); - } - int32_t filter_zp = 20; - - int32_t *weight_sum = reinterpret_cast(malloc(sizeof(int32_t) * out_channel)); - for (int i = 0; i < out_channel; i++) weight_sum[i] = filter_zp * ic4 * C4NUM * k_h * k_w; - auto packed_weight = reinterpret_cast(malloc(k_h * k_w * ic4 * C4NUM * oc4 * C4NUM)); - PackWeightInt8(int8_weight, conv_param, packed_weight, weight_sum); - - printf("==================output data=================\n"); - for (int i = 0; i < 20; i++) { - std::cout << static_cast(packed_weight[i]) << " ,"; - } - std::cout << std::endl; - - free(weight_sum); - free(int8_weight); - free(packed_weight); - delete conv_param; - - MS_LOG(INFO) << "TestPackWeightUint8 passed"; -} } // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc index f1e7b668f36..71c037bd6e1 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc @@ -140,36 +140,6 @@ TEST_F(TestConv1x1Fp32, Input1x1PrePack4) { delete conv_param; } -TEST_F(TestConv1x1Fp32, Conv1x1WeightTest1) { - auto *conv_param = new ConvParameter(); - float in[] = {0.214637, 0.3815, 0.811557, 0.982146, 0.09123, 0.687198, 0.02742, 0.3360, 0.853275, - 0.674123, 0.81337, 0.57188, 0.706416, 0.2740942, 0.9045, 0.07155, 0.130864, 0.037712, - 0.5369175, 0.97283, 0.92133, 0.3588165, 0.7432479, 0.7886823, 0.870324, 0.230946, 0.343969, - 0.095415, 0.50036, 0.396918, 0.09029, 0.934583, 0.91616, 0.206713, 0.9756054, 0.614025, - 0.432057, 0.1493, 0.6787, 0.10642, 0.736823, 0.377668, 0.2464896, 0.93152, 0.315917, - 0.35745, 0.52233, 0.0263, 0.339392, 0.99447, 0.49129, 0.675686, 0.75703, 0.6665356, - 0.0491, 0.1070, 0.18899, 0.929156, 0.4633427, 0.08585, 0.040709, 0.2478724, 0.5238441, - 0.0579918, 0.531636, 0.085524, 0.640923, 0.336395, 0.218651, 0.630491}; - float co[] = {0.214637, 0.81337, 0.92133, 0.09029, 0.3815, 0.57188, 0.3588165, 0.934583, 0.811557, - 0.706416, 0.7432479, 0.91616, 0.982146, 0.2740942, 0.7886823, 0.206713, 0.09123, 0.9045, - 0.870324, 0.9756054, 0.687198, 0.07155, 0.230946, 0.614025, 0.02742, 0.130864, 0.343969, - 
0.432057, 0.3360, 0.037712, 0.095415, 0.1493, 0.853275, 0.5369175, 0.50036, 0.6787, - 0.674123, 0.97283, 0.396918, 0.10642, 0, 0, 0, 0, 0, - 0, 0, 0, 0.736823, 0.49129, 0.040709, 0, 0.377668, 0.675686, - 0.2478724, 0, 0.2464896, 0.75703, 0.5238441, 0, 0.93152, 0.6665356, 0.0579918, - 0, 0.315917, 0.0491, 0.531636, 0, 0.35745, 0.1070, 0.085524, 0, - 0.52233, 0.18899, 0.640923, 0, 0.0263, 0.929156, 0.336395, 0, 0.339392, - 0.4633427, 0.218651, 0, 0.99447, 0.08585, 0.630491, 0, 0, 0, - 0, 0, 0, 0, 0, 0}; - - conv_param->input_channel_ = 10; - conv_param->output_channel_ = 7; - float out[96] = {0}; - Pack1x1WeightFp32(in, out, conv_param); - EXPECT_EQ(0, CompareOutputData(out, co, 96)); - delete conv_param; -} - int Conv1x1TestInit1(std::vector *inputs_, std::vector *outputs_, ConvParameter *conv_param, float **correct) { auto *in_t = new lite::Tensor(kNumberTypeFloat, {1, 2, 3, 4}, schema::Format_NHWC, lite::Tensor::VAR); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/depth_to_space_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/depth_to_space_fp32_test.cc index 0578b2bd6d1..c1e48155e23 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/depth_to_space_fp32_test.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/depth_to_space_fp32_test.cc @@ -15,7 +15,7 @@ */ #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "mindspore/lite/nnacl/depth_to_space.h" +#include "mindspore/lite/nnacl/base/depth_to_space_base.h" #include "mindspore/lite/nnacl/common_func.h" namespace mindspore { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/conv_1x1_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/conv_1x1_int8_tests.cc index bea6ae291a0..7910a435a74 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/conv_1x1_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/conv_1x1_int8_tests.cc @@ -17,7 +17,7 @@ #include "common/common_test.h" #include "mindspore/lite/src/lite_kernel.h" #include "src/common/file_utils.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #include "nnacl/common_func.h" #include "mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/div_int8_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/div_int8_test.cc deleted file mode 100644 index a019ab63900..00000000000 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/div_int8_test.cc +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include "schema/inner/model_generated.h" -#include "common/common_test.h" -#include "mindspore/lite/src/runtime/kernel/arm/int8/div_int8.h" -#include "mindspore/lite/src/kernel_registry.h" -#include "mindspore/lite/include/context.h" - -namespace mindspore { -class TestDivInt8 : public mindspore::CommonTest { - public: - TestDivInt8() {} -}; - -TEST_F(TestDivInt8, DivInt8) { - lite::Tensor in_tensor0(kNumberTypeInt8, {1, 1, 2, 5}); - lite::Tensor in_tensor1(kNumberTypeInt8, {1, 1, 2, 5}); - lite::Tensor out_tensor(kNumberTypeInt8, {1, 1, 2, 5}); - - int8_t input_data0[] = {105, 35, -27, 0, -63, 99, 16, 45, 67, -49}; - int8_t input_data1[] = {126, -38, -115, 106, -98, 119, 103, 81, -114, 68}; - int8_t output_data[10] = {0}; - in_tensor0.set_data(input_data0); - in_tensor1.set_data(input_data1); - out_tensor.set_data(output_data); - - const lite::QuantArg quant_in0 = {0.00784314f, 0}; // -1.0--1.0 -> 0--255 - const lite::QuantArg quant_in1 = {0.00784314f, 0}; - const lite::QuantArg quant_out = {0.00784314f, 0}; - in_tensor0.AddQuantParam(quant_in0); - in_tensor1.AddQuantParam(quant_in1); - out_tensor.AddQuantParam(quant_out); - - std::vector inputs = {&in_tensor0, &in_tensor1}; - std::vector outputs = {&out_tensor}; - - OpParameter parameter = {}; - kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Div}; - - auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); - ASSERT_NE(creator, nullptr); - - auto ctx = std::make_shared(); - ASSERT_EQ(lite::RET_OK, ctx->Init()); - auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc, nullptr); - ASSERT_NE(kernel, nullptr); - - auto ret = kernel->Run(); - EXPECT_EQ(0, ret); - - int8_t expect0[10] = {106, -117, 30, 0, 82, 106, 20, 71, -75, -92}; - for (int i = 0; i < 10; ++i) { - EXPECT_EQ(output_data[i], expect0[i]); - } - - in_tensor0.set_data(nullptr); - in_tensor1.set_data(nullptr); - out_tensor.set_data(nullptr); -} -} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc index 72c4ee2c121..1330b43f7de 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc @@ -18,7 +18,7 @@ #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h" #include "mindspore/lite/nnacl/common_func.h" -#include "mindspore/lite/nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #include "mindspore/lite/src/kernel_registry.h" #include "mindspore/lite/src/lite_kernel.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc index f1cd13aa25f..e9b57c3975e 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc @@ -18,7 +18,7 @@ #include "src/common/log_adapter.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h" -#include "nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #include "nnacl/common_func.h" #include "nnacl/int8/matmul_int8.h" #include "mindspore/lite/src/kernel_registry.h" diff --git 
a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc index 52169d0c4df..e1b2c75465f 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc @@ -18,7 +18,7 @@ #include "schema/inner/model_generated.h" #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "mindspore/lite/nnacl/quantization/quantize.h" +#include "mindspore/lite/nnacl/int8/quantize.h" #include "mindspore/lite/src/kernel_registry.h" #include "mindspore/lite/src/lite_kernel.h" #include "mindspore/lite/src/tensor.h" diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 3bdf681010d..fa3fac0a1ef 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -108,7 +108,7 @@ file(GLOB KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../nnacl/*.c ${CMAKE_CURRENT_SOURCE_DIR}/../../nnacl/fp32/*.c ${CMAKE_CURRENT_SOURCE_DIR}/../../nnacl/int8/*.c - ${CMAKE_CURRENT_SOURCE_DIR}/../../nnacl/quantization/*.c + ${CMAKE_CURRENT_SOURCE_DIR}/../../nnacl/base/*.c ${ARM_DIR}/fp32/*.cc ${ARM_DIR}/int8/*.cc )
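For readers tracing the int8 3x3 Winograd code removed from winograd_transform.c above: its output stage re-quantizes every int32 accumulator with the gemmlowp-style fixed-point helpers now pulled in through nnacl/int8/fixed_point.h (SaturatingRoundingDoublingHighMul followed by RoundingDivideByPOT), then adds the output zero point and clamps to the activation range before narrowing to int8. The sketch below isolates that step so the deleted arithmetic is easier to follow; the two helper bodies are illustrative re-implementations of the gemmlowp semantics, not the library's own code, and the wrapper name RequantizeToInt8 is hypothetical.

#include <stdint.h>
#include <limits.h>

/* High 32 bits of 2*a*b with rounding; saturates the single overflow case
 * a == b == INT32_MIN (gemmlowp semantics, illustrative re-implementation). */
int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
  if (a == INT32_MIN && b == INT32_MIN) {
    return INT32_MAX;
  }
  int64_t ab = (int64_t)a * (int64_t)b;
  int64_t nudge = ab >= 0 ? (1ll << 30) : (1 - (1ll << 30));
  return (int32_t)((ab + nudge) / (1ll << 31));
}

/* Rounding arithmetic shift right by a non-negative exponent. */
int32_t RoundingDivideByPOT(int32_t x, int exponent) {
  int32_t mask = (int32_t)((1ll << exponent) - 1);
  int32_t remainder = x & mask;
  int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

/* One output value of the removed Conv3x3Int8OutputUnit, reduced to its
 * requantization core. left_shift, right_shift and multiplier correspond to
 * the conv_quant_arg_ fields used in the diff; right_shift is passed here as
 * the already-negated, non-negative exponent (the original code calls
 * RoundingDivideByPOT with -right_shift[oc_index]). */
int8_t RequantizeToInt8(int32_t acc_plus_bias, int left_shift, int right_shift,
                        int32_t multiplier, int32_t out_zp, int32_t act_min, int32_t act_max) {
  int32_t v = SaturatingRoundingDoublingHighMul(acc_plus_bias * (1 << (unsigned int)left_shift), multiplier);
  v = RoundingDivideByPOT(v, right_shift);
  v += out_zp;
  v = v < act_min ? act_min : v;
  v = v > act_max ? act_max : v;
  return (int8_t)v;
}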