This commit is contained in:
fuzhiye 2020-08-18 18:55:52 +08:00
parent 030af09f60
commit 127b089a11
9 changed files with 142 additions and 131 deletions

View File

@ -296,16 +296,6 @@ int ConvolutionBaseCPUKernel::SetQuantParam() {
MS_LOG(ERROR) << "Set Output Tensor Quant Param Failed.";
return ret;
}
ret = SetQuantMultiplier();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Set Quant Multiplier Failed.";
return ret;
}
// now only consider per tensor for output
CalculateActivationRangeQuantized(
conv_param_->is_relu_, conv_param_->is_relu6_, conv_param_->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param_->conv_quant_arg_.output_quant_args_[0].scale_, &conv_param_->conv_quant_arg_.out_act_min_[0],
&conv_param_->conv_quant_arg_.out_act_max_[0]);
ret = SetIfPerChannel();
if (ret != RET_OK) {
@ -317,6 +307,18 @@ int ConvolutionBaseCPUKernel::SetQuantParam() {
MS_LOG(ERROR) << "Set if per asymmetric failed.";
return ret;
}
ret = SetQuantMultiplier();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Set Quant Multiplier Failed.";
return ret;
}
// now only consider per tensor for output
CalculateActivationRangeQuantized(
conv_param_->is_relu_, conv_param_->is_relu6_, conv_param_->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param_->conv_quant_arg_.output_quant_args_[0].scale_, &conv_param_->conv_quant_arg_.out_act_min_[0],
&conv_param_->conv_quant_arg_.out_act_max_[0]);
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -17,6 +17,7 @@
#include "src/runtime/kernel/arm/fp16/convolution_fp16.h"
#include <vector>
#include "src/runtime/kernel/arm/fp16/convolution_sw_fp16.h"
#include "src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h"
#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h"
#include "src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
@ -243,6 +244,10 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Ten
InputTransformUnitFunc input_trans_func = nullptr;
OutputTransformUnitFunc output_trans_func = nullptr;
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
if (use_winograd) {
kernel = new (std::nothrow)
kernel::ConvolutionWinogradFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit);
}
if (kernel_h != 1 && kernel_w != 1 && !use_winograd) {
kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
}

View File

@ -51,7 +51,7 @@ int ConvolutionCPUKernel::InitWeightBias() {
// #endif
int pack_weight_size = oc_block_num * oc_block * ic4 * C4NUM * kernel_plane;
// init weight
// =====================init weight==========================//
auto origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data());
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
if (packed_weight_ == nullptr) {
@ -61,7 +61,7 @@ int ConvolutionCPUKernel::InitWeightBias() {
memset(packed_weight_, 0, pack_weight_size * sizeof(float));
PackWeightFp32(origin_weight, conv_param_, packed_weight_, oc_block, oc_block_num);
// init bias
// =======================init bias==========================//
bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * oc_block * sizeof(float)));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias failed.";

View File

@ -633,7 +633,9 @@ void WinogradOutputTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_d
OutputTransformUnitFp16Func output_trans_func) {
int output_unit = conv_param->output_unit_;
int output_w = conv_param->output_w_;
int output_unit_block = UP_DIV(output_w, output_unit);
int output_h = conv_param->output_h_;
int output_w_unit_block = UP_DIV(output_w, output_unit);
int output_h_unit_block = UP_DIV(output_h, output_unit);
int output_channel = conv_param->output_channel_;
int oc8 = UP_DIV(output_channel, C8NUM);
int input_unit = conv_param->input_unit_;
@ -644,16 +646,16 @@ void WinogradOutputTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_d
int dst_x_s = out_tile_index % output_unit_num;
int dst_y_s = out_tile_index / output_unit_num;
int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit;
int dst_tile_offset = C8NUM * output_unit * (dst_x_s + dst_y_s * output_unit_block * output_unit);
int dst_tile_offset = C8NUM * output_unit * (dst_x_s + dst_y_s * output_w_unit_block * output_unit);
for (int j = 0; j < oc8; j++) {
int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM;
int dst_oc8_offset =
dst_tile_offset + j * C8NUM * output_unit_block * output_unit_block * output_unit * output_unit;
dst_tile_offset + j * C8NUM * output_h_unit_block * output_w_unit_block * output_unit * output_unit;
const float16_t *src_ptr = gemm_out + src_oc8_offset;
const float16_t *bias_ptr = bias_data + j * C8NUM;
float16_t *dst_ptr = tmp_out_data + dst_oc8_offset;
output_trans_func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_unit_block * output_unit);
output_trans_func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w_unit_block * output_unit);
}
out_tile_index++;
}

View File

@ -1066,7 +1066,7 @@ void OutputTransform4x3UnitFp16(const float16_t *src_data, float16_t *dst_data,
const float16_t t10 = 0.5f * (src_data_10 - src_data_20);
const float16_t t11 = 0.5f * (src_data_11 - src_data_21);
const float16_t t12 = 0.5f * (src_data_12 - src_data_22);
const const float16_t t13 = 0.5f * (src_data_13 - src_data_23);
const float16_t t13 = 0.5f * (src_data_13 - src_data_23);
const float16_t t20 = 0.25f * (src_data_10 + src_data_20) + src_data_30;
const float16_t t21 = 0.25f * (src_data_11 + src_data_21) + src_data_31;
@ -2232,7 +2232,7 @@ void OutputTransform8x4UnitFp16(const float16_t *src_data, float16_t *dst_data,
const float16_t t24 = 0.25f * d35 + d45 + 2.25f * d55;
const float16_t t25 = 0.25f * d36 + d46 + 2.25f * d56;
const float16_t t26 = 0.25f * d37 + d47 + 2.25f * d57;
const const float16_t t27 = 0.25f * d38 + d48 + 2.25f * d58;
const float16_t t27 = 0.25f * d38 + d48 + 2.25f * d58;
const float16_t t30 = 0.125f * d01 + d11 + 3.375f * d21 + src_data_70;
const float16_t t31 = 0.125f * d02 + d12 + 3.375f * d22 + src_data_71;
@ -3392,7 +3392,7 @@ void OutputTransform8x6UnitFp16(const float16_t *src_data, float16_t *dst_data,
const float16_t t52 = 0.03125f * d03 + d13 + 7.59375f * d23 + src_data_72;
const float16_t t53 = 0.03125f * d04 + d14 + 7.59375f * d24 + src_data_73;
const float16_t t54 = 0.03125f * d05 + d15 + 7.59375f * d25 + src_data_74;
const const float16_t t55 = 0.03125f * d06 + d16 + 7.59375f * d26 + src_data_75;
const float16_t t55 = 0.03125f * d06 + d16 + 7.59375f * d26 + src_data_75;
const float16_t t56 = 0.03125f * d07 + d17 + 7.59375f * d27 + src_data_76;
const float16_t t57 = 0.03125f * d08 + d18 + 7.59375f * d28 + src_data_77;

View File

@ -325,7 +325,6 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
int start_index = thread_id * tile_n;
int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
// todo
int32_t *tmp_input_sum = input_sum + task_id * tile_n;
int8_t *gemm_input = packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset;
// clear tmp buffer before compute

View File

@ -295,12 +295,12 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
} // kernel_w loop
} // kernel_h loop
if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC)) {
return;
continue;
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
int cal_num_offset = i * conv_param->output_channel_;
for (int l = 0; l < conv_param->output_channel_; ++l) {
input_sum[cal_num_offset + l] = input_accumulator * filter_arg[i].zp_;
input_sum[cal_num_offset + l] = input_accumulator * filter_arg[l].zp_;
}
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
!(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
@ -367,12 +367,12 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r
}
}
if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC)) {
return;
continue;
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
int cal_num_offset = i * conv_param->output_channel_;
for (int l = 0; l < conv_param->output_channel_; ++l) {
input_sum[cal_num_offset + l] = input_accumulator * filter_arg[i].zp_;
input_sum[cal_num_offset + l] = input_accumulator * filter_arg[l].zp_;
}
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
!(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
@ -870,8 +870,8 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
int c8 = channel / C8NUM * C8NUM;
int batch = plane * channel;
for (int n = 0; n < batches; n++) {
const float *src_batch = (const float*) src + n * batch;
float *dst_batch = (float*) dst + n * batch;
const float *src_batch = (const float *)src + n * batch;
float *dst_batch = (float *)dst + n * batch;
int hw = 0;
for (; hw < hw8; hw += C8NUM) {
int c = 0;
@ -947,9 +947,10 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
"st1 {v30.4s, v31.4s}, [x11], %[dstStride]\n"
:
: [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
:
[ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
"v30", "v31");
#else
for (int tr = 0; tr < C8NUM; tr++) {

View File

@ -81,7 +81,9 @@ void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const f
OutputTransformUnitFunc output_trans_func) {
int output_unit = conv_param->output_unit_;
int output_w = conv_param->output_w_;
int output_unit_block = UP_DIV(output_w, output_unit);
int output_h = conv_param->output_h_;
int output_w_unit_block = UP_DIV(output_w, output_unit);
int output_h_unit_block = UP_DIV(output_h, output_unit);
int output_channel = conv_param->output_channel_;
int oc4 = UP_DIV(output_channel, C4NUM);
int input_unit = conv_param->input_unit_;
@ -92,16 +94,16 @@ void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const f
int dst_x_s = out_tile_index % output_unit_num;
int dst_y_s = out_tile_index / output_unit_num;
int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit;
int dst_tile_offset = C4NUM * output_unit * (dst_x_s + dst_y_s * output_unit_block * output_unit);
int dst_tile_offset = C4NUM * output_unit * (dst_x_s + dst_y_s * output_w_unit_block * output_unit);
for (int j = 0; j < oc4; j++) {
int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM;
int dst_oc4_offset =
dst_tile_offset + j * C4NUM * output_unit_block * output_unit_block * output_unit * output_unit;
dst_tile_offset + j * C4NUM * output_h_unit_block * output_w_unit_block * output_unit * output_unit;
const float *src_ptr = gemm_out + src_oc4_offset;
const float *bias_ptr = bias_data + j * C4NUM;
float *dst_ptr = tmp_out_data + dst_oc4_offset;
output_trans_func(src_ptr, dst_ptr, bias_ptr, C4NUM, output_unit_block * output_unit);
output_trans_func(src_ptr, dst_ptr, bias_ptr, C4NUM, output_w_unit_block * output_unit);
}
out_tile_index++;
}

View File

@ -988,122 +988,122 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
1.77777777777777778f * src_data_67;
const float t10 = 1.5f * src_data_10 + 3.0f * src_data_20 - 2.1666666666666667f * src_data_30 -
4.333333333333333333f * src_data_40 + 0.66666666666666667f * src_data_50 +
1.333333333333333f * src_data_60;
4.333333333333333333f * src_data_40 + 0.66666666666666667f * src_data_50 +
1.333333333333333f * src_data_60;
const float t11 = 1.5f * src_data_11 + 3.0f * src_data_21 - 2.1666666666666667f * src_data_31 -
4.333333333333333333f * src_data_41 + 0.66666666666666667f * src_data_51 +
1.333333333333333f * src_data_61;
4.333333333333333333f * src_data_41 + 0.66666666666666667f * src_data_51 +
1.333333333333333f * src_data_61;
const float t12 = 1.5f * src_data_12 + 3.0f * src_data_22 - 2.1666666666666667f * src_data_32 -
4.333333333333333333f * src_data_42 + 0.66666666666666667f * src_data_52 +
1.333333333333333f * src_data_62;
4.333333333333333333f * src_data_42 + 0.66666666666666667f * src_data_52 +
1.333333333333333f * src_data_62;
const float t13 = 1.5f * src_data_13 + 3.0f * src_data_23 - 2.1666666666666667f * src_data_33 -
4.333333333333333333f * src_data_43 + 0.66666666666666667f * src_data_53 +
1.333333333333333f * src_data_63;
4.333333333333333333f * src_data_43 + 0.66666666666666667f * src_data_53 +
1.333333333333333f * src_data_63;
const float t14 = 1.5f * src_data_14 + 3.0f * src_data_24 - 2.1666666666666667f * src_data_34 -
4.333333333333333333f * src_data_44 + 0.66666666666666667f * src_data_54 +
1.333333333333333f * src_data_64;
4.333333333333333333f * src_data_44 + 0.66666666666666667f * src_data_54 +
1.333333333333333f * src_data_64;
const float t15 = 1.5f * src_data_15 + 3.0f * src_data_25 - 2.1666666666666667f * src_data_35 -
4.333333333333333333f * src_data_45 + 0.66666666666666667f * src_data_55 +
1.333333333333333f * src_data_65;
4.333333333333333333f * src_data_45 + 0.66666666666666667f * src_data_55 +
1.333333333333333f * src_data_65;
const float t16 = 1.5f * src_data_16 + 3.0f * src_data_26 - 2.1666666666666667f * src_data_36 -
4.333333333333333333f * src_data_46 + 0.66666666666666667f * src_data_56 +
1.333333333333333f * src_data_66;
4.333333333333333333f * src_data_46 + 0.66666666666666667f * src_data_56 +
1.333333333333333f * src_data_66;
const float t17 = 1.5f * src_data_17 + 3.0f * src_data_27 - 2.1666666666666667f * src_data_37 -
4.333333333333333333f * src_data_47 + 0.66666666666666667f * src_data_57 +
1.333333333333333f * src_data_67;
4.333333333333333333f * src_data_47 + 0.66666666666666667f * src_data_57 +
1.333333333333333f * src_data_67;
const float t20 = -1.5f * src_data_10 + 3.0f * src_data_20 + 2.1666666666666667f * src_data_30 -
4.333333333333333333f * src_data_40 - 0.66666666666666667f * src_data_50 +
1.333333333333333f * src_data_60;
4.333333333333333333f * src_data_40 - 0.66666666666666667f * src_data_50 +
1.333333333333333f * src_data_60;
const float t21 = -1.5f * src_data_11 + 3.0f * src_data_21 + 2.1666666666666667f * src_data_31 -
4.333333333333333333f * src_data_41 - 0.66666666666666667f * src_data_51 +
1.333333333333333f * src_data_61;
4.333333333333333333f * src_data_41 - 0.66666666666666667f * src_data_51 +
1.333333333333333f * src_data_61;
const float t22 = -1.5f * src_data_12 + 3.0f * src_data_22 + 2.1666666666666667f * src_data_32 -
4.333333333333333333f * src_data_42 - 0.66666666666666667f * src_data_52 +
1.333333333333333f * src_data_62;
4.333333333333333333f * src_data_42 - 0.66666666666666667f * src_data_52 +
1.333333333333333f * src_data_62;
const float t23 = -1.5f * src_data_13 + 3.0f * src_data_23 + 2.1666666666666667f * src_data_33 -
4.333333333333333333f * src_data_43 - 0.66666666666666667f * src_data_53 +
1.333333333333333f * src_data_63;
4.333333333333333333f * src_data_43 - 0.66666666666666667f * src_data_53 +
1.333333333333333f * src_data_63;
const float t24 = -1.5f * src_data_14 + 3.0f * src_data_24 + 2.1666666666666667f * src_data_34 -
4.333333333333333333f * src_data_44 - 0.66666666666666667f * src_data_54 +
1.333333333333333f * src_data_64;
4.333333333333333333f * src_data_44 - 0.66666666666666667f * src_data_54 +
1.333333333333333f * src_data_64;
const float t25 = -1.5f * src_data_15 + 3.0f * src_data_25 + 2.1666666666666667f * src_data_35 -
4.333333333333333333f * src_data_45 - 0.66666666666666667f * src_data_55 +
1.333333333333333f * src_data_65;
4.333333333333333333f * src_data_45 - 0.66666666666666667f * src_data_55 +
1.333333333333333f * src_data_65;
const float t26 = -1.5f * src_data_16 + 3.0f * src_data_26 + 2.1666666666666667f * src_data_36 -
4.333333333333333333f * src_data_46 - 0.66666666666666667f * src_data_56 +
1.333333333333333f * src_data_66;
4.333333333333333333f * src_data_46 - 0.66666666666666667f * src_data_56 +
1.333333333333333f * src_data_66;
const float t27 = -1.5f * src_data_17 + 3.0f * src_data_27 + 2.1666666666666667f * src_data_37 -
4.333333333333333333f * src_data_47 - 0.66666666666666667f * src_data_57 +
1.333333333333333f * src_data_67;
4.333333333333333333f * src_data_47 - 0.66666666666666667f * src_data_57 +
1.333333333333333f * src_data_67;
const float t30 = -0.3f * (src_data_10 + src_data_20) + 1.33333333333333f * (src_data_30 + src_data_40) -
0.53333333333f * (src_data_50 + src_data_60);
0.53333333333f * (src_data_50 + src_data_60);
const float t31 = -0.3f * (src_data_11 + src_data_21) + 1.33333333333333f * (src_data_31 + src_data_41) -
0.53333333333f * (src_data_51 + src_data_61);
0.53333333333f * (src_data_51 + src_data_61);
const float t32 = -0.3f * (src_data_12 + src_data_22) + 1.33333333333333f * (src_data_32 + src_data_42) -
0.53333333333f * (src_data_52 + src_data_62);
0.53333333333f * (src_data_52 + src_data_62);
const float t33 = -0.3f * (src_data_13 + src_data_23) + 1.33333333333333f * (src_data_33 + src_data_43) -
0.53333333333f * (src_data_53 + src_data_63);
0.53333333333f * (src_data_53 + src_data_63);
const float t34 = -0.3f * (src_data_14 + src_data_24) + 1.33333333333333f * (src_data_34 + src_data_44) -
0.53333333333f * (src_data_54 + src_data_64);
0.53333333333f * (src_data_54 + src_data_64);
const float t35 = -0.3f * (src_data_15 + src_data_25) + 1.33333333333333f * (src_data_35 + src_data_45) -
0.53333333333f * (src_data_55 + src_data_65);
const const float t36 = -0.3f * (src_data_16 + src_data_26) + 1.33333333333333f * (src_data_36 + src_data_46) -
0.53333333333f * (src_data_56 + src_data_66);
const const float t37 = -0.3f * (src_data_17 + src_data_27) + 1.33333333333333f * (src_data_37 + src_data_47) -
0.53333333333f * (src_data_57 + src_data_67);
0.53333333333f * (src_data_55 + src_data_65);
const float t36 = -0.3f * (src_data_16 + src_data_26) + 1.33333333333333f * (src_data_36 + src_data_46) -
0.53333333333f * (src_data_56 + src_data_66);
const float t37 = -0.3f * (src_data_17 + src_data_27) + 1.33333333333333f * (src_data_37 + src_data_47) -
0.53333333333f * (src_data_57 + src_data_67);
const float t40 = 0.3f * (src_data_10 - src_data_20) + 1.33333333333333f * (src_data_40 - src_data_30) +
0.53333333333f * (src_data_50 - src_data_60);
0.53333333333f * (src_data_50 - src_data_60);
const float t41 = 0.3f * (src_data_11 - src_data_21) + 1.33333333333333f * (src_data_41 - src_data_31) +
0.53333333333f * (src_data_51 - src_data_61);
0.53333333333f * (src_data_51 - src_data_61);
const float t42 = 0.3f * (src_data_12 - src_data_22) + 1.33333333333333f * (src_data_42 - src_data_32) +
0.53333333333f * (src_data_52 - src_data_62);
0.53333333333f * (src_data_52 - src_data_62);
const float t43 = 0.3f * (src_data_13 - src_data_23) + 1.33333333333333f * (src_data_43 - src_data_33) +
0.53333333333f * (src_data_53 - src_data_63);
0.53333333333f * (src_data_53 - src_data_63);
const float t44 = 0.3f * (src_data_14 - src_data_24) + 1.33333333333333f * (src_data_44 - src_data_34) +
0.53333333333f * (src_data_54 - src_data_64);
0.53333333333f * (src_data_54 - src_data_64);
const float t45 = 0.3f * (src_data_15 - src_data_25) + 1.33333333333333f * (src_data_45 - src_data_35) +
0.53333333333f * (src_data_55 - src_data_65);
0.53333333333f * (src_data_55 - src_data_65);
const float t46 = 0.3f * (src_data_16 - src_data_26) + 1.33333333333333f * (src_data_46 - src_data_36) +
0.53333333333f * (src_data_56 - src_data_66);
0.53333333333f * (src_data_56 - src_data_66);
const float t47 = 0.3f * (src_data_17 - src_data_27) + 1.33333333333333f * (src_data_47 - src_data_37) +
0.53333333333f * (src_data_57 - src_data_67);
0.53333333333f * (src_data_57 - src_data_67);
const float t50 = 0.0333333333f * src_data_10 + 0.02222222f * src_data_20 - 0.1666666666f * src_data_30 -
0.1111111111f * src_data_40 + 0.1333333f * src_data_50 + 0.0888888f * src_data_60;
0.1111111111f * src_data_40 + 0.1333333f * src_data_50 + 0.0888888f * src_data_60;
const float t51 = 0.0333333333f * src_data_11 + 0.02222222f * src_data_21 - 0.1666666666f * src_data_31 -
0.1111111111f * src_data_41 + 0.1333333f * src_data_51 + 0.0888888f * src_data_61;
0.1111111111f * src_data_41 + 0.1333333f * src_data_51 + 0.0888888f * src_data_61;
const float t52 = 0.0333333333f * src_data_12 + 0.02222222f * src_data_22 - 0.1666666666f * src_data_32 -
0.1111111111f * src_data_42 + 0.1333333f * src_data_52 + 0.0888888f * src_data_62;
0.1111111111f * src_data_42 + 0.1333333f * src_data_52 + 0.0888888f * src_data_62;
const float t53 = 0.0333333333f * src_data_13 + 0.02222222f * src_data_23 - 0.1666666666f * src_data_33 -
0.1111111111f * src_data_43 + 0.1333333f * src_data_53 + 0.0888888f * src_data_63;
0.1111111111f * src_data_43 + 0.1333333f * src_data_53 + 0.0888888f * src_data_63;
const float t54 = 0.0333333333f * src_data_14 + 0.02222222f * src_data_24 - 0.1666666666f * src_data_34 -
0.1111111111f * src_data_44 + 0.1333333f * src_data_54 + 0.0888888f * src_data_64;
0.1111111111f * src_data_44 + 0.1333333f * src_data_54 + 0.0888888f * src_data_64;
const float t55 = 0.0333333333f * src_data_15 + 0.02222222f * src_data_25 - 0.1666666666f * src_data_35 -
0.1111111111f * src_data_45 + 0.1333333f * src_data_55 + 0.0888888f * src_data_65;
0.1111111111f * src_data_45 + 0.1333333f * src_data_55 + 0.0888888f * src_data_65;
const float t56 = 0.0333333333f * src_data_16 + 0.02222222f * src_data_26 - 0.1666666666f * src_data_36 -
0.1111111111f * src_data_46 + 0.1333333f * src_data_56 + 0.0888888f * src_data_66;
0.1111111111f * src_data_46 + 0.1333333f * src_data_56 + 0.0888888f * src_data_66;
const float t57 = 0.0333333333f * src_data_17 + 0.02222222f * src_data_27 - 0.1666666666f * src_data_37 -
0.1111111111f * src_data_47 + 0.1333333f * src_data_57 + 0.0888888f * src_data_67;
0.1111111111f * src_data_47 + 0.1333333f * src_data_57 + 0.0888888f * src_data_67;
const float t60 = -0.0333333333f * src_data_10 + 0.02222222f * src_data_20 + 0.1666666666f * src_data_30 -
0.1111111111f * src_data_40 - 0.1333333f * src_data_50 + 0.0888888f * src_data_60;
0.1111111111f * src_data_40 - 0.1333333f * src_data_50 + 0.0888888f * src_data_60;
const float t61 = -0.0333333333f * src_data_11 + 0.02222222f * src_data_21 + 0.1666666666f * src_data_31 -
0.1111111111f * src_data_41 - 0.1333333f * src_data_51 + 0.0888888f * src_data_61;
0.1111111111f * src_data_41 - 0.1333333f * src_data_51 + 0.0888888f * src_data_61;
const float t62 = -0.0333333333f * src_data_12 + 0.02222222f * src_data_22 + 0.1666666666f * src_data_32 -
0.1111111111f * src_data_42 - 0.1333333f * src_data_52 + 0.0888888f * src_data_62;
0.1111111111f * src_data_42 - 0.1333333f * src_data_52 + 0.0888888f * src_data_62;
const float t63 = -0.0333333333f * src_data_13 + 0.02222222f * src_data_23 + 0.1666666666f * src_data_33 -
0.1111111111f * src_data_43 - 0.1333333f * src_data_53 + 0.0888888f * src_data_63;
0.1111111111f * src_data_43 - 0.1333333f * src_data_53 + 0.0888888f * src_data_63;
const float t64 = -0.0333333333f * src_data_14 + 0.02222222f * src_data_24 + 0.1666666666f * src_data_34 -
0.1111111111f * src_data_44 - 0.1333333f * src_data_54 + 0.0888888f * src_data_64;
0.1111111111f * src_data_44 - 0.1333333f * src_data_54 + 0.0888888f * src_data_64;
const float t65 = -0.0333333333f * src_data_15 + 0.02222222f * src_data_25 + 0.1666666666f * src_data_35 -
0.1111111111f * src_data_45 - 0.1333333f * src_data_55 + 0.0888888f * src_data_65;
0.1111111111f * src_data_45 - 0.1333333f * src_data_55 + 0.0888888f * src_data_65;
const float t66 = -0.0333333333f * src_data_16 + 0.02222222f * src_data_26 + 0.1666666666f * src_data_36 -
0.1111111111f * src_data_46 - 0.1333333f * src_data_56 + 0.0888888f * src_data_66;
0.1111111111f * src_data_46 - 0.1333333f * src_data_56 + 0.0888888f * src_data_66;
const float t67 = -0.0333333333f * src_data_17 + 0.02222222f * src_data_27 + 0.1666666666f * src_data_37 -
0.1111111111f * src_data_47 - 0.1333333f * src_data_57 + 0.0888888f * src_data_67;
0.1111111111f * src_data_47 - 0.1333333f * src_data_57 + 0.0888888f * src_data_67;
const float t70 = -0.5625f * src_data_10 + 3.0625f * src_data_30 - 3.5f * src_data_50 + src_data_70;
const float t71 = -0.5625f * src_data_11 + 3.0625f * src_data_31 - 3.5f * src_data_51 + src_data_71;
@ -1114,111 +1114,111 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
const float t76 = -0.5625f * src_data_16 + 3.0625f * src_data_36 - 3.5f * src_data_56 + src_data_76;
const float t77 = -0.5625f * src_data_17 + 3.0625f * src_data_37 - 3.5f * src_data_57 + src_data_77;
const float m00 = t00 - 5.444444444444444445125f * t02 + 6.222222222222222222223f * t04 -
1.77777777777777778f * t06;
const float m00 =
t00 - 5.444444444444444445125f * t02 + 6.222222222222222222223f * t04 - 1.77777777777777778f * t06;
const float m01 = 1.5f * t01 + 3.0f * t02 - 2.1666666666666667f * t03 - 4.333333333333333333f * t04 +
0.66666666666666667f * t05 + 1.333333333333333f * t06;
0.66666666666666667f * t05 + 1.333333333333333f * t06;
const float m02 = -1.5f * t01 + 3.0f * t02 + 2.1666666666666667f * t03 - 4.333333333333333333f * t04 -
0.66666666666666667f * t05 + 1.333333333333333f * t06;
0.66666666666666667f * t05 + 1.333333333333333f * t06;
const float m03 = -0.3f * (t01 + t02) + 1.33333333333333f * (t03 + t04) - 0.53333333333f * (t05 + t06);
const float m04 = 0.3f * (t01 - t02) + 1.33333333333333f * (t04 - t03) + 0.53333333333f * (t05 - t06);
const float m05 = 0.0333333333f * t01 + 0.02222222f * t02 - 0.1666666666f * t03 - 0.1111111111f * t04 +
0.1333333f * t05 + 0.0888888f * t06;
const float m06 = -0.0333333333f * t01 + 0.02222222f * t02 + 0.1666666666f * t03 - 0.1111111111f * t04 -
0.1333333f * t05 + 0.0888888f * t06;
0.1333333f * t05 + 0.0888888f * t06;
const float m07 = -0.5625f * t01 + 3.0625f * t03 - 3.5f * t05 + t07;
float m10 = t10 - 5.444444444444444445125f * t12 + 6.222222222222222222223f * t14 - 1.77777777777777778f * t16;
const float m11 = 1.5f * t11 + 3.0f * t12 - 2.1666666666666667f * t13 - 4.333333333333333333f * t14 +
0.66666666666666667f * t15 + 1.333333333333333f * t16;
0.66666666666666667f * t15 + 1.333333333333333f * t16;
const float m12 = -1.5f * t11 + 3.0f * t12 + 2.1666666666666667f * t13 - 4.333333333333333333f * t14 -
0.66666666666666667f * t15 + 1.333333333333333f * t16;
0.66666666666666667f * t15 + 1.333333333333333f * t16;
const float m13 = -0.3f * (t11 + t12) + 1.33333333333333f * (t13 + t14) - 0.53333333333f * (t15 + t16);
const float m14 = 0.3f * (t11 - t12) + 1.33333333333333f * (t14 - t13) + 0.53333333333f * (t15 - t16);
const float m15 = 0.0333333333f * t11 + 0.02222222f * t12 - 0.1666666666f * t13 - 0.1111111111f * t14 +
0.1333333f * t15 + 0.0888888f * t16;
const float m16 = -0.0333333333f * t11 + 0.02222222f * t12 + 0.1666666666f * t13 - 0.1111111111f * t14 -
0.1333333f * t15 + 0.0888888f * t16;
0.1333333f * t15 + 0.0888888f * t16;
const float m17 = -0.5625f * t11 + 3.0625f * t13 - 3.5f * t15 + t17;
const float m20 = t20 - 5.444444444444444445125f * t22 + 6.222222222222222222223f * t24 -
1.77777777777777778f * t26;
const float m20 =
t20 - 5.444444444444444445125f * t22 + 6.222222222222222222223f * t24 - 1.77777777777777778f * t26;
const float m21 = 1.5f * t21 + 3.0f * t22 - 2.1666666666666667f * t23 - 4.333333333333333333f * t24 +
0.66666666666666667f * t25 + 1.333333333333333f * t26;
0.66666666666666667f * t25 + 1.333333333333333f * t26;
const float m22 = -1.5f * t21 + 3.0f * t22 + 2.1666666666666667f * t23 - 4.333333333333333333f * t24 -
0.66666666666666667f * t25 + 1.333333333333333f * t26;
0.66666666666666667f * t25 + 1.333333333333333f * t26;
const float m23 = -0.3f * (t21 + t22) + 1.33333333333333f * (t23 + t24) - 0.53333333333f * (t25 + t26);
const float m24 = 0.3f * (t21 - t22) + 1.33333333333333f * (t24 - t23) + 0.53333333333f * (t25 - t26);
const float m25 = 0.0333333333f * t21 + 0.02222222f * t22 - 0.1666666666f * t23 - 0.1111111111f * t24 +
0.1333333f * t25 + 0.0888888f * t26;
const float m26 = -0.0333333333f * t21 + 0.02222222f * t22 + 0.1666666666f * t23 - 0.1111111111f * t24 -
0.1333333f * t25 + 0.0888888f * t26;
0.1333333f * t25 + 0.0888888f * t26;
const float m27 = -0.5625f * t21 + 3.0625f * t23 - 3.5f * t25 + t27;
float m30 = t30 - 5.444444444444444445125f * t32 + 6.222222222222222222223f * t34 - 1.77777777777777778f * t36;
const float m31 = 1.5f * t31 + 3.0f * t32 - 2.1666666666666667f * t33 - 4.333333333333333333f * t34 +
0.66666666666666667f * t35 + 1.333333333333333f * t36;
0.66666666666666667f * t35 + 1.333333333333333f * t36;
const float m32 = -1.5f * t31 + 3.0f * t32 + 2.1666666666666667f * t33 - 4.333333333333333333f * t34 -
0.66666666666666667f * t35 + 1.333333333333333f * t36;
0.66666666666666667f * t35 + 1.333333333333333f * t36;
const float m33 = -0.3f * (t31 + t32) + 1.33333333333333f * (t33 + t34) - 0.53333333333f * (t35 + t36);
const float m34 = 0.3f * (t31 - t32) + 1.33333333333333f * (t34 - t33) + 0.53333333333f * (t35 - t36);
const float m35 = 0.0333333333f * t31 + 0.02222222f * t32 - 0.1666666666f * t33 - 0.1111111111f * t34 +
0.1333333f * t35 + 0.0888888f * t36;
const float m36 = -0.0333333333f * t31 + 0.02222222f * t32 + 0.1666666666f * t33 - 0.1111111111f * t34 -
0.1333333f * t35 + 0.0888888f * t36;
0.1333333f * t35 + 0.0888888f * t36;
const float m37 = -0.5625f * t31 + 3.0625f * t33 - 3.5f * t35 + t37;
const float m40 = t40 - 5.444444444444444445125f * t42 + 6.222222222222222222223f * t44 -
1.77777777777777778f * t46;
const float m40 =
t40 - 5.444444444444444445125f * t42 + 6.222222222222222222223f * t44 - 1.77777777777777778f * t46;
const float m41 = 1.5f * t41 + 3.0f * t42 - 2.1666666666666667f * t43 - 4.333333333333333333f * t44 +
0.66666666666666667f * t45 + 1.333333333333333f * t46;
0.66666666666666667f * t45 + 1.333333333333333f * t46;
const float m42 = -1.5f * t41 + 3.0f * t42 + 2.1666666666666667f * t43 - 4.333333333333333333f * t44 -
0.66666666666666667f * t45 + 1.333333333333333f * t46;
0.66666666666666667f * t45 + 1.333333333333333f * t46;
const float m43 = -0.3f * (t41 + t42) + 1.33333333333333f * (t43 + t44) - 0.53333333333f * (t45 + t46);
const float m44 = 0.3f * (t41 - t42) + 1.33333333333333f * (t44 - t43) + 0.53333333333f * (t45 - t46);
const float m45 = 0.0333333333f * t41 + 0.02222222f * t42 - 0.1666666666f * t43 - 0.1111111111f * t44 +
0.1333333f * t45 + 0.0888888f * t46;
const float m46 = -0.0333333333f * t41 + 0.02222222f * t42 + 0.1666666666f * t43 - 0.1111111111f * t44 -
0.1333333f * t45 + 0.0888888f * t46;
0.1333333f * t45 + 0.0888888f * t46;
const float m47 = -0.5625f * t41 + 3.0625f * t43 - 3.5f * t45 + t47;
float m50 = t50 - 5.444444444444444445125f * t52 + 6.222222222222222222223f * t54 - 1.77777777777777778f * t56;
const float m51 = 1.5f * t51 + 3.0f * t52 - 2.1666666666666667f * t53 - 4.333333333333333333f * t54 +
0.66666666666666667f * t55 + 1.333333333333333f * t56;
0.66666666666666667f * t55 + 1.333333333333333f * t56;
const float m52 = -1.5f * t51 + 3.0f * t52 + 2.1666666666666667f * t53 - 4.333333333333333333f * t54 -
0.66666666666666667f * t55 + 1.333333333333333f * t56;
0.66666666666666667f * t55 + 1.333333333333333f * t56;
const float m53 = -0.3f * (t51 + t52) + 1.33333333333333f * (t53 + t54) - 0.53333333333f * (t55 + t56);
const float m54 = 0.3f * (t51 - t52) + 1.33333333333333f * (t54 - t53) + 0.53333333333f * (t55 - t56);
const float m55 = 0.0333333333f * t51 + 0.02222222f * t52 - 0.1666666666f * t53 - 0.1111111111f * t54 +
0.1333333f * t55 + 0.0888888f * t56;
const float m56 = -0.0333333333f * t51 + 0.02222222f * t52 + 0.1666666666f * t53 - 0.1111111111f * t54 -
0.1333333f * t55 + 0.0888888f * t56;
0.1333333f * t55 + 0.0888888f * t56;
const float m57 = -0.5625f * t51 + 3.0625f * t53 - 3.5f * t55 + t57;
float m60 = t60 - 5.444444444444444445125f * t62 + 6.222222222222222222223f * t64 - 1.77777777777777778f * t66;
const float m61 = 1.5f * t61 + 3.0f * t62 - 2.1666666666666667f * t63 - 4.333333333333333333f * t64 +
0.66666666666666667f * t65 + 1.333333333333333f * t66;
0.66666666666666667f * t65 + 1.333333333333333f * t66;
const float m62 = -1.5f * t61 + 3.0f * t62 + 2.1666666666666667f * t63 - 4.333333333333333333f * t64 -
0.66666666666666667f * t65 + 1.333333333333333f * t66;
0.66666666666666667f * t65 + 1.333333333333333f * t66;
const float m63 = -0.3f * (t61 + t62) + 1.33333333333333f * (t63 + t64) - 0.53333333333f * (t65 + t66);
const float m64 = 0.3f * (t61 - t62) + 1.33333333333333f * (t64 - t63) + 0.53333333333f * (t65 - t66);
const float m65 = 0.0333333333f * t61 + 0.02222222f * t62 - 0.1666666666f * t63 - 0.1111111111f * t64 +
0.1333333f * t65 + 0.0888888f * t66;
const float m66 = -0.0333333333f * t61 + 0.02222222f * t62 + 0.1666666666f * t63 - 0.1111111111f * t64 -
0.1333333f * t65 + 0.0888888f * t66;
0.1333333f * t65 + 0.0888888f * t66;
const float m67 = -0.5625f * t61 + 3.0625f * t63 - 3.5f * t65 + t67;
float m70 = t70 - 5.444444444444444445125f * t72 + 6.222222222222222222223f * t74 - 1.77777777777777778f * t76;
const float m71 = 1.5f * t71 + 3.0f * t72 - 2.1666666666666667f * t73 - 4.333333333333333333f * t74 +
0.66666666666666667f * t75 + 1.333333333333333f * t76;
0.66666666666666667f * t75 + 1.333333333333333f * t76;
const float m72 = -1.5f * t71 + 3.0f * t72 + 2.1666666666666667f * t73 - 4.333333333333333333f * t74 -
0.66666666666666667f * t75 + 1.333333333333333f * t76;
0.66666666666666667f * t75 + 1.333333333333333f * t76;
const float m73 = -0.3f * (t71 + t72) + 1.33333333333333f * (t73 + t74) - 0.53333333333f * (t75 + t76);
const float m74 = 0.3f * (t71 - t72) + 1.33333333333333f * (t74 - t73) + 0.53333333333f * (t75 - t76);
const float m75 = 0.0333333333f * t71 + 0.02222222f * t72 - 0.1666666666f * t73 - 0.1111111111f * t74 +
0.1333333f * t75 + 0.0888888f * t76;
const float m76 = -0.0333333333f * t71 + 0.02222222f * t72 + 0.1666666666f * t73 - 0.1111111111f * t74 -
0.1333333f * t75 + 0.0888888f * t76;
0.1333333f * t75 + 0.0888888f * t76;
const float m77 = -0.5625f * t71 + 3.0625f * t73 - 3.5f * t75 + t77;
(dst_data + i)[0] = m00;
@ -1460,7 +1460,7 @@ void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float
const float t10 = 0.5f * (src_data_10 - src_data_20);
const float t11 = 0.5f * (src_data_11 - src_data_21);
const float t12 = 0.5f * (src_data_12 - src_data_22);
const const float t13 = 0.5f * (src_data_13 - src_data_23);
const float t13 = 0.5f * (src_data_13 - src_data_23);
const float t20 = 0.25f * (src_data_10 + src_data_20) + src_data_30;
const float t21 = 0.25f * (src_data_11 + src_data_21) + src_data_31;
@ -2626,7 +2626,7 @@ void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float
const float t24 = 0.25f * d35 + d45 + 2.25f * d55;
const float t25 = 0.25f * d36 + d46 + 2.25f * d56;
const float t26 = 0.25f * d37 + d47 + 2.25f * d57;
const const float t27 = 0.25f * d38 + d48 + 2.25f * d58;
const float t27 = 0.25f * d38 + d48 + 2.25f * d58;
const float t30 = 0.125f * d01 + d11 + 3.375f * d21 + src_data_70;
const float t31 = 0.125f * d02 + d12 + 3.375f * d22 + src_data_71;
@ -3786,7 +3786,7 @@ void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float
const float t52 = 0.03125f * d03 + d13 + 7.59375f * d23 + src_data_72;
const float t53 = 0.03125f * d04 + d14 + 7.59375f * d24 + src_data_73;
const float t54 = 0.03125f * d05 + d15 + 7.59375f * d25 + src_data_74;
const const float t55 = 0.03125f * d06 + d16 + 7.59375f * d26 + src_data_75;
const float t55 = 0.03125f * d06 + d16 + 7.59375f * d26 + src_data_75;
const float t56 = 0.03125f * d07 + d17 + 7.59375f * d27 + src_data_76;
const float t57 = 0.03125f * d08 + d18 + 7.59375f * d28 + src_data_77;