!5887 [MS][LITE][CPU]rewrite winogard using general method

Merge pull request !5887 from fuzhiye/tmp
This commit is contained in:
mindspore-ci-bot 2020-09-11 11:44:24 +08:00 committed by Gitee
commit e4bcb55322
29 changed files with 1997 additions and 9905 deletions

View File

@ -540,7 +540,7 @@ void UnPack3x3Relu6OutputFp16(const float16_t *src, float16_t *dst, int batch, i
// fp16 convolution winograd
void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data,
TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param,
InputTransformUnitFp16Func input_trans_func, OutputTransformUnitFp16Func output_trans_func) {
MatricesFp16 *matrices) {
int thread_num = conv_param->thread_num_;
int input_unit = conv_param->input_unit_;
int in_batch = conv_param->input_batch_;
@ -575,14 +575,14 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
cal_num = cal_num > tile_num ? tile_num : cal_num;
WinogradInputTransformFp16(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
input_trans_func);
matrices[2], matrices[3]);
// step 3 : gemm
IndirectGemmFp16_16x8(gemm_out + task_id * gemm_out_offset, trans_input + task_id * trans_input_offset,
trans_weight, NULL, input_unit_square, ic8 * 2, oc8 * C8NUM, output_offset, 1, 1, 0, 0);
// step 4 : output transform
WinogradOutputTransformFp16(gemm_out + task_id * gemm_out_offset, tmp_out_data + tmp_out_batch_offset, bias_data,
cal_num, out_tile_index, out_w_block, conv_param, output_trans_func);
cal_num, out_tile_index, out_w_block, conv_param, matrices[0], matrices[1]);
}
}
}

View File

@ -22,6 +22,7 @@
#include "nnacl/fp16/winograd_transform_fp16.h"
typedef float16_t *TmpBufferAddressFp16;
typedef float16_t *MatricesFp16;
#ifndef ENABLE_NEON
void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
@ -69,7 +70,7 @@ void UnPack3x3Relu6OutputFp16(const float16_t *src, float16_t *dst, int batch, i
// fp16 convolution winograd
void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data,
TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param,
InputTransformUnitFp16Func input_trans_func, OutputTransformUnitFp16Func output_trans_func);
MatricesFp16 *matrices);
void UnPackWinogradOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel,
int output_unit);

View File

@ -0,0 +1,65 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp16/matrix_fp16.h"
void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k,
int n) {
int count = 0;
for (int h = 0; h < m; h++) {
int h_offset = h * k;
for (int w = 0; w < n; w++) {
float16_t res = 0;
for (int i = 0; i < k; i++) {
res += *(matrix_a + h_offset + i) * *(matrix_b + w + i * n);
}
*(matrix_c + count) = res;
count++;
}
}
}
void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matrix_b, float16x8_t *matrix_c,
const float16_t *bias, int m, int k, int n) {
if (bias == NULL) {
int count = 0;
for (int h = 0; h < m; h++) {
int h_offset = h * k;
for (int w = 0; w < n; w++) {
float16x8_t res = vmovq_n_f16(0);
for (int i = 0; i < k; i++) {
res = vaddq_f16(res, vmulq_f16(matrix_a[h_offset + i], matrix_b[w + i * n]));
}
matrix_c[count] = res;
count++;
}
}
} else {
int count = 0;
float16x8_t bias_ptr = vld1q_f16(bias);
for (int h = 0; h < m; h++) {
int h_offset = h * k;
for (int w = 0; w < n; w++) {
float16x8_t res = vmovq_n_f16(0);
for (int i = 0; i < k; i++) {
res = vaddq_f16(res, vmulq_f16(matrix_a[h_offset + i], matrix_b[w + i * n]));
}
matrix_c[count] = vaddq_f16(res, bias_ptr);
count++;
}
}
}
}

View File

@ -14,14 +14,20 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_MATRIX_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_MATRIX_FP16_H_
#ifndef MINDSPORE_LITE_NNACL_FP16_MATRIX_FP16_H_
#define MINDSPORE_LITE_NNACL_FP16_MATRIX_FP16_H_
#include "src/runtime/kernel/arm/base/matrix.h"
#include <arm_neon.h>
namespace mindspore::kernel {
void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n,
bool row);
#ifdef __cplusplus
extern "C" {
#endif
void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n);
void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matrix_b, float16x8_t *matrix_c,
const float16_t *bias, int m, int k, int n);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_MATRIX_FP16_H_
#endif // MINDSPORE_LITE_NNACL_FP16_MATRIX_FP16_H_

View File

@ -569,8 +569,8 @@ void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data,
// fp16 common winograd
void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num,
int out_tile_index, int out_w_block_num, ConvParameter *conv_param,
InputTransformUnitFp16Func input_trans_func) {
int out_tile_index, int out_w_block_num, ConvParameter *conv_param, float16_t *matrix_b,
float16_t *matrix_bt) {
const int tile_num = 16;
int input_unit = conv_param->input_unit_;
int output_unit = conv_param->output_unit_;
@ -622,7 +622,7 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
int dst_ic8_offset = dst_plane_offset + ic * tile_num * C8NUM;
size_t dst_step = ic8 * C8NUM * tile_num;
float16_t *trans_input_ptr = trans_input + dst_ic8_offset;
input_trans_func(tmp_data, trans_input_ptr, C8NUM, dst_step);
GeneralInputTransformUnitFp16(tmp_data, trans_input_ptr, matrix_b, matrix_bt, C8NUM, dst_step, input_unit);
}
out_tile_index++;
} // cal_tile_num loop
@ -630,7 +630,7 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
void WinogradOutputTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
int cal_num, int out_tile_index, int output_unit_num, ConvParameter *conv_param,
OutputTransformUnitFp16Func output_trans_func) {
float16_t *matrix_a, float16_t *matrix_at) {
int output_unit = conv_param->output_unit_;
int output_w = conv_param->output_w_;
int output_h = conv_param->output_h_;
@ -655,7 +655,8 @@ void WinogradOutputTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_d
const float16_t *src_ptr = gemm_out + src_oc8_offset;
const float16_t *bias_ptr = bias_data + j * C8NUM;
float16_t *dst_ptr = tmp_out_data + dst_oc8_offset;
output_trans_func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w_unit_block * output_unit);
GeneralOutputTransformUnitFp16(src_ptr, dst_ptr, bias_ptr, matrix_a, matrix_at, C8NUM,
output_w_unit_block * output_unit, input_unit, output_unit);
}
out_tile_index++;
}

View File

@ -43,12 +43,12 @@ void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data,
// fp16 common winograd
void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num,
int out_tile_index, int out_w_block_num, ConvParameter *conv_param,
InputTransformUnitFp16Func input_trans_func);
int out_tile_index, int out_w_block_num, ConvParameter *conv_param, float16_t *matrix_b,
float16_t *matrix_bt);
void WinogradOutputTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
int cal_num, int out_tile_index, int output_unit_num, ConvParameter *conv_param,
OutputTransformUnitFp16Func output_trans_func);
float16_t *matrix_a, float16_t *matrix_at);
#ifdef __cplusplus
}
#endif

File diff suppressed because it is too large Load Diff

View File

@ -21,45 +21,17 @@
#include "nnacl/conv_parameter.h"
#include "nnacl/op_base.h"
typedef void (*InputTransformUnitFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
typedef void (*OutputTransformUnitFp16Func)(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
#define MAX_LEN 256
#ifdef __cplusplus
extern "C" {
#endif
void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
void OutputTransform4x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
void OutputTransform8x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
void OutputTransform8x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
void OutputTransform8x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
void OutputTransform8x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
void OutputTransform8x6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
void OutputTransform8x7UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
InputTransformUnitFp16Func GetInputTransFuncFp16(int input_unit);
OutputTransformUnitFp16Func GetOutputTransFuncFp16(int input_unit, int output_unit);
void GeneralInputTransformUnitFp16(const float16_t *src_data, float16_t *dst_data, float16_t *matrix_b,
float16_t *matrix_bt, int src_step, int dst_step, int in_unit);
void GeneralOutputTransformUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
float16_t *matrix_a, float16_t *matrix_at, int src_step, int dst_step, int in_unit,
int out_unit);
#ifdef __cplusplus
}
#endif

View File

@ -259,8 +259,8 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
// fp32 conv winograd
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func,
OutputTransformUnitFunc output_trans_func, GEMM_FUNC_FP32 gemm_func) {
int task_id, ConvParameter *conv_param, InputTransFunc in_func, OutputTransFunc out_func,
GEMM_FUNC_FP32 gemm_func) {
int thread_num = conv_param->thread_num_;
int input_unit = conv_param->input_unit_;
int in_batch = conv_param->input_batch_;
@ -296,7 +296,7 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
cal_num = cal_num > C12NUM ? C12NUM : cal_num;
WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
input_trans_func);
in_func);
// step 3 : gemm
float *src_ptr = trans_input + task_id * trans_input_offset;
float *dst_ptr = gemm_out + task_id * gemm_out_offset;
@ -309,7 +309,7 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
// step 4 : output transform
WinogradOutputTransform(dst_ptr, tmp_out_data + tmp_out_batch_offset, bias_data, cal_num, out_tile_index,
out_w_block, conv_param, output_trans_func);
out_w_block, conv_param, out_func);
}
}
}

View File

@ -28,6 +28,7 @@
#include "nnacl/fp32/conv_depthwise.h"
typedef float *TmpBufferAddress;
typedef float *Matrices;
typedef void (*GEMM_FUNC_FP32)(float *output, const float *input, const float *weight, const float *bias, size_t step,
size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4,
size_t relu, size_t relu6);
@ -53,8 +54,8 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
// fp32 convolution winograd
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func,
OutputTransformUnitFunc output_trans_func, GEMM_FUNC_FP32 gemm_func);
int task_id, ConvParameter *conv_param, InputTransFunc in_func, OutputTransFunc out_func,
GEMM_FUNC_FP32 gemm_func);
void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, int output_unit);

View File

@ -1,507 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/matrix_table.h"
void MatrixG4x2(float *matrix_data) {
matrix_data[0] = 1.0f;
matrix_data[1] = 0.0f;
matrix_data[2] = 1.0f;
matrix_data[3] = 0.5f;
matrix_data[4] = 1.0f;
matrix_data[5] = -0.5f;
matrix_data[6] = 0.0f;
matrix_data[7] = 1.0f;
}
void MatrixGT2x4(float *matrix_data) {
matrix_data[0] = 1.0f;
matrix_data[1] = 1.0f;
matrix_data[2] = 1.0f;
matrix_data[3] = 0.0f;
matrix_data[4] = 0.0f;
matrix_data[5] = 0.5f;
matrix_data[6] = -0.5f;
matrix_data[7] = 1.0f;
}
void MatrixG8x2(float *matrix_data) {
matrix_data[0] = 1.0f;
matrix_data[1] = 0.0f;
matrix_data[2] = 1.0f;
matrix_data[3] = 0.5f;
matrix_data[4] = 1.0f;
matrix_data[5] = -0.5f;
matrix_data[6] = 1.0f;
matrix_data[7] = 1.0f;
matrix_data[8] = 1.0f;
matrix_data[9] = -1.0f;
matrix_data[10] = 1.0f;
matrix_data[11] = 1.5f;
matrix_data[12] = 1.0f;
matrix_data[13] = -1.5f;
matrix_data[14] = 0.0f;
matrix_data[15] = 1.0f;
}
void MatrixGT2x8(float *matrix_data) {
matrix_data[0] = 1.0f;
matrix_data[1] = 1.0f;
matrix_data[2] = 1.0f;
matrix_data[3] = 1.5f;
matrix_data[4] = 1.0f;
matrix_data[5] = 1.0f;
matrix_data[6] = 1.0f;
matrix_data[7] = 0.0f;
matrix_data[8] = 0.0f;
matrix_data[9] = 0.5f;
matrix_data[10] = -0.5f;
matrix_data[11] = 1.0f;
matrix_data[12] = -1.0f;
matrix_data[13] = 1.5f;
matrix_data[14] = -1.5f;
matrix_data[15] = 1.0f;
}
void MatrixG8x3(float *matrix_data) {
matrix_data[0] = 1.0f;
matrix_data[1] = 0.0f;
matrix_data[2] = 0.0f;
matrix_data[3] = 1.0f;
matrix_data[4] = 0.5f;
matrix_data[5] = 0.25f;
matrix_data[6] = 1.0f;
matrix_data[7] = -0.5f;
matrix_data[8] = 0.25f;
matrix_data[9] = 1.0f;
matrix_data[10] = 1.0f;
matrix_data[11] = 1.0f;
matrix_data[12] = 1.0f;
matrix_data[13] = -1.0f;
matrix_data[14] = 1.0f;
matrix_data[15] = 1.0f;
matrix_data[16] = 1.5f;
matrix_data[17] = 2.25f;
matrix_data[18] = 1.0f;
matrix_data[19] = -1.5f;
matrix_data[20] = 2.25f;
matrix_data[21] = 0.0f;
matrix_data[22] = 0.0f;
matrix_data[23] = 1.0f;
}
void MatrixGT3x8(float *matrix_data) {
matrix_data[0] = 1.0f;
matrix_data[1] = 1.0f;
matrix_data[2] = 1.0f;
matrix_data[3] = 1.0f;
matrix_data[4] = 1.0f;
matrix_data[5] = 1.0f;
matrix_data[6] = 1.0f;
matrix_data[7] = 0.0f;
matrix_data[8] = 0.0f;
matrix_data[9] = 0.5f;
matrix_data[10] = -0.5f;
matrix_data[11] = 1.0f;
matrix_data[12] = -1.0f;
matrix_data[13] = 1.5f;
matrix_data[14] = -1.5f;
matrix_data[15] = 0.0f;
matrix_data[16] = 0.0f;
matrix_data[17] = 0.25f;
matrix_data[18] = 0.25f;
matrix_data[19] = 1.0f;
matrix_data[20] = 1.0f;
matrix_data[21] = 2.25f;
matrix_data[22] = 2.25f;
matrix_data[23] = 1.0f;
}
void MatrixG8x4(float *matrix_data) {
matrix_data[0] = 1.0f;
matrix_data[1] = 0.0f;
matrix_data[2] = 0.0f;
matrix_data[3] = 0.0f;
matrix_data[4] = 1.0f;
matrix_data[5] = 0.5f;
matrix_data[6] = 0.25f;
matrix_data[7] = 0.125f;
matrix_data[8] = 1.0f;
matrix_data[9] = -0.5f;
matrix_data[10] = 0.25f;
matrix_data[11] = -0.125f;
matrix_data[12] = 1.0f;
matrix_data[13] = 1.0f;
matrix_data[14] = 1.0f;
matrix_data[15] = 1.0f;
matrix_data[16] = 1.0f;
matrix_data[17] = -1.0f;
matrix_data[18] = 1.0f;
matrix_data[19] = -1.0f;
matrix_data[20] = 1.0f;
matrix_data[21] = 1.5f;
matrix_data[22] = 2.25f;
matrix_data[23] = 3.375f;
matrix_data[24] = 1.0f;
matrix_data[25] = -1.5f;
matrix_data[26] = 2.25f;
matrix_data[27] = -3.375f;
matrix_data[28] = 0.0f;
matrix_data[29] = 0.0f;
matrix_data[30] = 0.0f;
matrix_data[31] = 1.0f;
}
void MatrixGT4x8(float *matrix_data) {
matrix_data[0] = 1.0f;
matrix_data[1] = 1.0f;
matrix_data[2] = 1.0f;
matrix_data[3] = 1.0f;
matrix_data[4] = 1.0f;
matrix_data[5] = 1.0f;
matrix_data[6] = 1.0f;
matrix_data[7] = 0.0f;
matrix_data[8] = 0.0f;
matrix_data[9] = 0.5f;
matrix_data[10] = -0.5f;
matrix_data[11] = 1.0f;
matrix_data[12] = -1.0f;
matrix_data[13] = 1.5f;
matrix_data[14] = -1.5f;
matrix_data[15] = 0.0f;
matrix_data[16] = 0.0f;
matrix_data[17] = 0.25f;
matrix_data[18] = 0.25f;
matrix_data[19] = 1.0f;
matrix_data[20] = 1.0f;
matrix_data[21] = 2.25f;
matrix_data[22] = 2.25f;
matrix_data[23] = 0.0f;
matrix_data[24] = 0.0f;
matrix_data[25] = 0.125f;
matrix_data[26] = -0.125f;
matrix_data[27] = 1.0f;
matrix_data[28] = -1.0f;
matrix_data[29] = 3.375f;
matrix_data[30] = -3.375f;
matrix_data[31] = 1.0f;
}
void MatrixG8x5(float *matrix_data) {
matrix_data[0] = 1.0f;
matrix_data[1] = 0.0f;
matrix_data[2] = 0.0f;
matrix_data[3] = 0.0f;
matrix_data[4] = 0.0f;
matrix_data[5] = 1.0f;
matrix_data[6] = 0.5f;
matrix_data[7] = 0.25f;
matrix_data[8] = 0.125f;
matrix_data[9] = 0.0625f;
matrix_data[10] = 1.0f;
matrix_data[11] = -0.5f;
matrix_data[12] = 0.25f;
matrix_data[13] = -0.125f;
matrix_data[14] = 0.0625f;
matrix_data[15] = 1.0f;
matrix_data[16] = 1.0f;
matrix_data[17] = 1.0f;
matrix_data[18] = 1.0f;
matrix_data[19] = 1.0f;
matrix_data[20] = 1.0f;
matrix_data[21] = -1.0f;
matrix_data[22] = 1.0f;
matrix_data[23] = -1.0f;
matrix_data[24] = 1.0f;
matrix_data[25] = 1.0f;
matrix_data[26] = 1.5f;
matrix_data[27] = 2.25f;
matrix_data[28] = 3.375f;
matrix_data[29] = 5.0625f;
matrix_data[30] = 1.0f;
matrix_data[31] = -1.5f;
matrix_data[32] = 2.25f;
matrix_data[33] = -3.375f;
matrix_data[34] = 5.0625f;
matrix_data[35] = 0.0f;
matrix_data[36] = 0.0f;
matrix_data[37] = 0.0f;
matrix_data[38] = 0.0f;
matrix_data[39] = 1.0f;
}
void MatrixGT5x8(float *matrix_data) {
matrix_data[0] = 1.0f;
matrix_data[1] = 1.0f;
matrix_data[2] = 1.0f;
matrix_data[3] = 1.0f;
matrix_data[4] = 1.0f;
matrix_data[5] = 1.0f;
matrix_data[6] = 1.0f;
matrix_data[7] = 0.0f;
matrix_data[8] = 0.0f;
matrix_data[9] = 0.5f;
matrix_data[10] = -0.5f;
matrix_data[11] = 1.0f;
matrix_data[12] = -1.0f;
matrix_data[13] = 1.5f;
matrix_data[14] = -1.5f;
matrix_data[15] = 0.0f;
matrix_data[16] = 0.0f;
matrix_data[17] = 0.25f;
matrix_data[18] = 0.25f;
matrix_data[19] = 1.0f;
matrix_data[20] = 1.0f;
matrix_data[21] = 2.25f;
matrix_data[22] = 2.25f;
matrix_data[23] = 0.0f;
matrix_data[24] = 0.0f;
matrix_data[25] = 0.125f;
matrix_data[26] = -0.125f;
matrix_data[27] = 1.0f;
matrix_data[28] = -1.0f;
matrix_data[29] = 3.375f;
matrix_data[30] = -3.375f;
matrix_data[31] = 0.0f;
matrix_data[32] = 0.0f;
matrix_data[33] = 0.0625f;
matrix_data[34] = 0.0625f;
matrix_data[35] = 1.0f;
matrix_data[36] = 1.0f;
matrix_data[37] = 5.0625f;
matrix_data[38] = 5.0625f;
matrix_data[39] = 1.0f;
}
void MatrixG8x6(float *matrix_data) {
matrix_data[0] = 1.0f;
matrix_data[1] = 0.0f;
matrix_data[2] = 0.0f;
matrix_data[3] = 0.0f;
matrix_data[4] = 0.0f;
matrix_data[5] = 0.0f;
matrix_data[6] = 1.0f;
matrix_data[7] = 0.5f;
matrix_data[8] = 0.25f;
matrix_data[9] = 0.125f;
matrix_data[10] = 0.0625f;
matrix_data[11] = 0.03125f;
matrix_data[12] = 1.0f;
matrix_data[13] = -0.5f;
matrix_data[14] = 0.25f;
matrix_data[15] = -0.125f;
matrix_data[16] = 0.0625f;
matrix_data[17] = -0.03125f;
matrix_data[18] = 1.0f;
matrix_data[19] = 1.0f;
matrix_data[20] = 1.0f;
matrix_data[21] = 1.0f;
matrix_data[22] = 1.0f;
matrix_data[23] = 1.0f;
matrix_data[24] = 1.0f;
matrix_data[25] = -1.0f;
matrix_data[26] = 1.0f;
matrix_data[27] = -1.0f;
matrix_data[28] = 1.0f;
matrix_data[29] = -1.0f;
matrix_data[30] = 1.0f;
matrix_data[31] = 1.5f;
matrix_data[32] = 2.25f;
matrix_data[33] = 3.375f;
matrix_data[34] = 5.0625f;
matrix_data[35] = 7.59375f;
matrix_data[36] = 1.0f;
matrix_data[37] = -1.5f;
matrix_data[38] = 2.25f;
matrix_data[39] = -3.375f;
matrix_data[40] = 5.0625f;
matrix_data[41] = -7.59375f;
matrix_data[42] = 0.0f;
matrix_data[43] = 0.0f;
matrix_data[44] = 0.0f;
matrix_data[45] = 0.0f;
matrix_data[46] = 0.0f;
matrix_data[47] = 1.0f;
}
void MatrixGT6x8(float *matrix_data) {
matrix_data[0] = 1.0f;
matrix_data[1] = 1.0f;
matrix_data[2] = 1.0f;
matrix_data[3] = 1.0f;
matrix_data[4] = 1.0f;
matrix_data[5] = 1.0f;
matrix_data[6] = 1.0f;
matrix_data[7] = 0.0f;
matrix_data[8] = 0.0f;
matrix_data[9] = 0.5f;
matrix_data[10] = -0.5f;
matrix_data[11] = 1.0f;
matrix_data[12] = -1.0f;
matrix_data[13] = 1.5f;
matrix_data[14] = -1.5f;
matrix_data[15] = 0.0f;
matrix_data[16] = 0.0f;
matrix_data[17] = 0.25f;
matrix_data[18] = 0.25f;
matrix_data[19] = 1.0f;
matrix_data[20] = 1.0f;
matrix_data[21] = 2.25f;
matrix_data[22] = 2.25f;
matrix_data[23] = 0.0f;
matrix_data[24] = 0.0f;
matrix_data[25] = 0.125f;
matrix_data[26] = -0.125f;
matrix_data[27] = 1.0f;
matrix_data[28] = -1.0f;
matrix_data[29] = 3.375f;
matrix_data[30] = -3.375f;
matrix_data[31] = 0.0f;
matrix_data[32] = 0.0f;
matrix_data[33] = 0.0625f;
matrix_data[34] = 0.0625f;
matrix_data[35] = 1.0f;
matrix_data[36] = 1.0f;
matrix_data[37] = 5.0625f;
matrix_data[38] = 5.0625f;
matrix_data[39] = 0.0f;
matrix_data[40] = 0.0;
matrix_data[41] = 0.03125f;
matrix_data[42] = -0.03125f;
matrix_data[43] = 1.0f;
matrix_data[44] = -1.0f;
matrix_data[45] = 7.59375f;
matrix_data[46] = -7.59375f;
matrix_data[47] = 0.0f;
matrix_data[48] = 1.0f;
}
void MatrixG8x7(float *matrix_data) {
matrix_data[0] = 1.0f;
matrix_data[1] = 0.0f;
matrix_data[2] = 0.0f;
matrix_data[3] = 0.0f;
matrix_data[4] = 0.0f;
matrix_data[5] = 0.0f;
matrix_data[6] = 0.0f;
matrix_data[7] = 1.0f;
matrix_data[8] = 0.5f;
matrix_data[9] = 0.25f;
matrix_data[10] = 0.125f;
matrix_data[11] = 0.0625f;
matrix_data[12] = 0.03125f;
matrix_data[13] = 0.015625f;
matrix_data[14] = 1.0f;
matrix_data[15] = -0.5f;
matrix_data[16] = 0.25f;
matrix_data[17] = -0.125f;
matrix_data[18] = 0.0625f;
matrix_data[19] = -0.03125f;
matrix_data[20] = 0.015625f;
matrix_data[21] = 1.0f;
matrix_data[22] = 1.0f;
matrix_data[23] = 1.0f;
matrix_data[24] = 1.0f;
matrix_data[25] = 1.0f;
matrix_data[26] = 1.0f;
matrix_data[27] = 1.0f;
matrix_data[28] = 1.0f;
matrix_data[29] = -1.0f;
matrix_data[30] = 1.0f;
matrix_data[31] = -1.0f;
matrix_data[32] = 1.0f;
matrix_data[33] = -1.0f;
matrix_data[34] = 1.0f;
matrix_data[35] = 1.0f;
matrix_data[36] = 1.5f;
matrix_data[37] = 2.25f;
matrix_data[38] = 3.375f;
matrix_data[39] = 5.0625f;
matrix_data[40] = 7.59375f;
matrix_data[41] = 11.390625f;
matrix_data[42] = 1.0f;
matrix_data[43] = -1.5f;
matrix_data[44] = 2.25f;
matrix_data[45] = -3.375f;
matrix_data[46] = 5.0625f;
matrix_data[47] = -7.59375f;
matrix_data[48] = 11.390625f;
matrix_data[49] = 0.0f;
matrix_data[50] = 0.0f;
matrix_data[51] = 0.0f;
matrix_data[52] = 0.0f;
matrix_data[53] = 0.0f;
matrix_data[54] = 0.0f;
matrix_data[55] = 1.0f;
}
void MatrixGT7x8(float *matrix_data) {
matrix_data[0] = 1.0f;
matrix_data[1] = 1.0f;
matrix_data[2] = 1.0f;
matrix_data[3] = 1.0f;
matrix_data[4] = 1.0f;
matrix_data[5] = 1.0f;
matrix_data[6] = 1.0f;
matrix_data[7] = 0.0f;
matrix_data[8] = 0.0f;
matrix_data[9] = 0.5f;
matrix_data[10] = -0.5f;
matrix_data[11] = 1.0f;
matrix_data[12] = -1.0f;
matrix_data[13] = 1.5f;
matrix_data[14] = -1.5f;
matrix_data[15] = 0.0f;
matrix_data[16] = 0.0f;
matrix_data[17] = 0.25f;
matrix_data[18] = 0.25f;
matrix_data[19] = 1.0f;
matrix_data[20] = 1.0f;
matrix_data[21] = 2.25f;
matrix_data[22] = 2.25f;
matrix_data[23] = 0.0f;
matrix_data[24] = 0.0f;
matrix_data[25] = 0.125f;
matrix_data[26] = -0.125f;
matrix_data[27] = 1.0f;
matrix_data[28] = -1.0f;
matrix_data[29] = 3.375f;
matrix_data[30] = -3.375f;
matrix_data[31] = 0.0f;
matrix_data[32] = 0.0f;
matrix_data[33] = 0.0625f;
matrix_data[34] = 0.0625f;
matrix_data[35] = 1.0f;
matrix_data[36] = 1.0f;
matrix_data[37] = 5.0625f;
matrix_data[38] = 5.0625f;
matrix_data[39] = 0.0f;
matrix_data[40] = 0.0;
matrix_data[41] = 0.03125f;
matrix_data[42] = -0.03125f;
matrix_data[43] = 1.0f;
matrix_data[44] = -1.0f;
matrix_data[45] = 7.59375f;
matrix_data[46] = -7.59375f;
matrix_data[47] = 0.0f;
matrix_data[48] = 0.0f;
matrix_data[49] = 0.015625f;
matrix_data[50] = 0.015625f;
matrix_data[51] = 1.0f;
matrix_data[52] = 1.0f;
matrix_data[53] = 11.390625f;
matrix_data[54] = 11.390625f;
matrix_data[55] = 1.0f;
}

View File

@ -1,54 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_MATRIX_TABLE_H_
#define MINDSPORE_LITE_NNACL_MATRIX_TABLE_H_
#ifdef __cplusplus
extern "C" {
#endif
void MatrixG4x2(float *matrix_data);
void MatrixGT2x4(float *matrix_data);
void MatrixG8x2(float *matrix_data);
void MatrixGT2x8(float *matrix_data);
void MatrixG8x3(float *matrix_data);
void MatrixGT3x8(float *matrix_data);
void MatrixG8x4(float *matrix_data);
void MatrixGT4x8(float *matrix_data);
void MatrixG8x5(float *matrix_data);
void MatrixGT5x8(float *matrix_data);
void MatrixG8x6(float *matrix_data);
void MatrixGT6x8(float *matrix_data);
void MatrixG8x7(float *matrix_data);
void MatrixGT7x8(float *matrix_data);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_MATRIX_TABLE_H_

View File

@ -0,0 +1,233 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/minimal_filtering_generator.h"
#include <string.h>
#include <math.h>
#include <stdlib.h>
void Polynomial(float *interval, float *m, int degree) {
for (int i = 0; i < degree; ++i) {
float mul = 1;
for (int j = 0; j < degree; ++j) {
if (i == j) continue;
mul *= (interval[i] - interval[j]);
}
m[i] = mul;
}
}
void DiagonalPlusMatrix(float *matrix, float *diagonal_matrix, int degree) {
int data_num = (degree + 1) * (degree + 1);
memset(diagonal_matrix, 0, data_num * sizeof(float));
for (int i = 0; i < (degree + 1); ++i) {
for (int j = 0; j < (degree + 1); ++j) {
if (j == i) diagonal_matrix[i * (degree + 1) + j] = matrix[i];
}
}
diagonal_matrix[data_num - 1] = 1;
}
void ResidueMatrix(float *interval, float *b, int row, int col) {
// row : input unit, col : output_unit
// result : matrix b
int len = row * col;
memset(b, 0, len * sizeof(float));
for (int i = 0; i < row - 1; ++i) {
for (int j = 0; j < col; ++j) {
b[i * col + j] = pow(interval[i], j);
}
}
b[len - 1] = 1;
}
void LT(float *poly_array, float *matrix_lt, int n) {
float *coefficient_array = (float *)malloc(n * sizeof(float));
float *poly = (float *)malloc(n * sizeof(float));
Polynomial(poly_array, poly, n);
for (int i = 0; i < n; ++i) {
// get coefficient
int index = 1;
memset(coefficient_array, 0, n * sizeof(float));
coefficient_array[0] = 1;
for (int j = 0; j < n; ++j) {
if (j == i) continue;
float poly_coe = poly_array[j] == 0 ? 0 : -poly_array[j];
coefficient_array[index] = 1;
for (int k = index - 1; k > 0; --k) {
coefficient_array[k] = coefficient_array[k] * poly_coe + coefficient_array[k - 1];
}
coefficient_array[0] *= poly_coe;
index++;
}
// lx[i, 0].nth(j) / f[i]
int setp = i * n;
for (int l = 0; l < n; ++l) {
matrix_lt[setp + l] = coefficient_array[l] / poly[i];
}
} // matrix L row loop
free(coefficient_array);
free(poly);
}
void T(float *poly_array, float *matrix_t, int n) {
memset(matrix_t, 0, n * (n + 1) * sizeof(float));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n + 1; ++j) {
if (j == i) matrix_t[i * (n + 1) + j] = 1;
if (j == n) {
if (poly_array[i] == 0) {
matrix_t[i * (n + 1) + j] = 0;
} else {
matrix_t[i * (n + 1) + j] = -pow(poly_array[i], n);
}
}
}
}
}
void B(float *poly_array, float *matrix_b, int in_unit) {
memset(matrix_b, 0, in_unit * in_unit * sizeof(float));
int n = in_unit - 1;
float *matrix_l = (float *)malloc(n * n * sizeof(float));
float *matrix_lt = (float *)malloc(n * n * sizeof(float));
float *matrix_t = (float *)malloc(n * in_unit * sizeof(float));
T(poly_array, matrix_t, n);
LT(poly_array, matrix_lt, n);
MatrixTranspose(matrix_lt, matrix_l, n, n);
MatrixMultiply(matrix_l, matrix_t, matrix_b, n, n, in_unit);
matrix_b[in_unit * in_unit - 1] = 1;
free(matrix_l);
free(matrix_lt);
free(matrix_t);
}
void GenerateIntervalArray(float *array, float interval, int degree) {
array[0] = 0;
for (int i = 1; i < degree; ++i) {
int coefficient = pow(-1, i - 1);
array[i] = array[i - 1] + interval * i * coefficient;
}
}
void MatrixTranspose(float *matrix, float *trans_matrix, int row, int col) {
for (int i = 0; i < col; ++i) {
for (int j = 0; j < row; ++j) {
trans_matrix[i * row + j] = matrix[j * col + i];
}
}
}
void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n) {
int count = 0;
for (int h = 0; h < m; h++) {
int h_offset = h * k;
for (int w = 0; w < n; w++) {
float res = 0;
for (int i = 0; i < k; i++) {
res += *(matrix_a + h_offset + i) * *(matrix_b + w + i * n);
}
*(matrix_c + count) = res;
count++;
}
}
}
void CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *matrix_bt, float *matrix_g,
float *matrix_gt, float coefficient, int out_unit, int filter_size) {
int in_unit = out_unit + filter_size - 1;
int degree = in_unit - 1;
float *polynomial_m = malloc(degree * sizeof(float));
float *diagonal_matrix = malloc(in_unit * in_unit * sizeof(float));
float *inverse_diagonal_matrix = malloc(in_unit * in_unit * sizeof(float));
// get diagonal matrix
float *interval = malloc(degree * sizeof(float));
GenerateIntervalArray(interval, coefficient, degree);
Polynomial(interval, polynomial_m, degree);
DiagonalPlusMatrix(polynomial_m, diagonal_matrix, degree);
if (diagonal_matrix[0] < 0) {
for (int i = 0; i < in_unit; ++i) {
if (diagonal_matrix[i] != 0) diagonal_matrix[i] *= -1;
}
}
// inverse diagonal matrix
for (int j = 0; j < in_unit * in_unit; ++j) {
if (diagonal_matrix[j] != 0) {
inverse_diagonal_matrix[j] = 1.0 / diagonal_matrix[j];
} else {
inverse_diagonal_matrix[j] = 0;
}
}
// get matrix A && AT
ResidueMatrix(interval, matrix_a, in_unit, out_unit);
MatrixTranspose(matrix_a, matrix_at, in_unit, out_unit);
// get matrix B
B(interval, matrix_bt, in_unit);
MatrixTranspose(matrix_bt, matrix_b, in_unit, in_unit);
MatrixMultiply(diagonal_matrix, matrix_b, matrix_bt, in_unit, in_unit, in_unit);
MatrixTranspose(matrix_bt, matrix_b, in_unit, in_unit);
// get matrix G && GT
float *tmp_g = malloc(in_unit * filter_size * sizeof(float));
ResidueMatrix(interval, matrix_g, in_unit, filter_size);
MatrixTranspose(matrix_g, tmp_g, in_unit, filter_size);
MatrixMultiply(tmp_g, inverse_diagonal_matrix, matrix_gt, filter_size, in_unit, in_unit);
MatrixTranspose(matrix_gt, matrix_g, filter_size, in_unit);
free(interval);
free(polynomial_m);
free(diagonal_matrix);
free(inverse_diagonal_matrix);
free(tmp_g);
}
#ifdef ENABLE_ARM
void MatrixMultiplyVec(const float32x4_t *matrix_a, const float32x4_t *matrix_b, float32x4_t *matrix_c,
const float *bias, int m, int k, int n) {
if (bias == NULL) {
int count = 0;
for (int h = 0; h < m; h++) {
int h_offset = h * k;
for (int w = 0; w < n; w++) {
float32x4_t res = vmovq_n_f32(0);
for (int i = 0; i < k; i++) {
res = vmlaq_f32(res, matrix_a[h_offset + i], matrix_b[w + i * n]);
}
matrix_c[count] = res;
count++;
}
}
} else {
int count = 0;
float32x4_t bias_ptr = vld1q_f32(bias);
for (int h = 0; h < m; h++) {
int h_offset = h * k;
for (int w = 0; w < n; w++) {
float32x4_t res = vmovq_n_f32(0);
for (int i = 0; i < k; i++) {
res = vmlaq_f32(res, matrix_a[h_offset + i], matrix_b[w + i * n]);
}
matrix_c[count] = vaddq_f32(res, bias_ptr);
count++;
}
}
}
}
#endif

View File

@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_MINIMAL_FILTERING_GENERATOR_H_
#define MINDSPORE_LITE_NNACL_MINIMAL_FILTERING_GENERATOR_H_
#ifdef ENABLE_ARM
#include <arm_neon.h>
#endif
#ifdef __cplusplus
extern "C" {
#endif
void Polynomial(float *interval, float *m, int degree);
void DiagonalPlusMatrix(float *matrix, float *diagonal_matrix, int degree);
void ResidueMatrix(float *interval, float *b, int row, int col);
void LT(float *poly_array, float *matrix_lt, int n);
void T(float *poly_array, float *matrix_t, int n);
void B(float *poly_array, float *matrix_b, int in_unit);
void GenerateIntervalArray(float *array, float interval, int degree);
void MatrixTranspose(float *matrix, float *trans_matrix, int row, int col);
void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n);
void CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *matrix_bt, float *matrix_g,
float *matrix_gt, float coefficient, int out_unit, int filter_size);
#ifdef ENABLE_ARM
void MatrixMultiplyVec(const float32x4_t *matrix_a, const float32x4_t *matrix_b, float32x4_t *matrix_c,
const float *bias, int m, int k, int n);
#endif
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_MINIMAL_FILTERING_GENERATOR_H_

View File

@ -101,10 +101,6 @@ void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *pack
int8_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel;
int8_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k * C4NUM * C4NUM;
*packed_data_ptr = origin_data_ptr[0];
// value of weight must between [-127, 127]
if (packed_data_ptr[0] == -128) {
packed_data_ptr[0] = -127;
}
weight_sum[j * C4NUM + k] += (int32_t)packed_data_ptr[0];
}
} // kernel block loop
@ -146,9 +142,6 @@ void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *p
int8_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel;
int8_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k * C4NUM;
*packed_data_ptr = origin_data_ptr[0];
if (packed_data_ptr[0] == -128) {
packed_data_ptr[0] = -127;
}
weight_sum[j * C4NUM + k] += (int32_t)(packed_data_ptr[0]);
}
} // kernel block loop

View File

@ -18,8 +18,7 @@
// fp32 conv winograd
void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num,
int out_tile_index, int out_w_block_num, ConvParameter *conv_param,
InputTransformUnitFunc input_trans_func) {
int out_tile_index, int out_w_block_num, ConvParameter *conv_param, InputTransFunc func) {
int input_unit = conv_param->input_unit_;
int output_unit = conv_param->output_unit_;
int in_channel = conv_param->input_channel_;
@ -31,6 +30,7 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
if (out_w_block_num == 0) {
return;
}
for (int c = 0; c < cal_num; c++) { // actual tiled number
int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w;
int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h;
@ -70,15 +70,15 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
int dst_ic4_offset = dst_plane_offset + ic * C4NUM;
size_t dst_step = C12NUM * ic4 * C4NUM;
float *trans_input_ptr = trans_input + dst_ic4_offset;
input_trans_func(tmp_data, trans_input_ptr, C4NUM, dst_step);
func(tmp_data, trans_input_ptr, C4NUM, dst_step);
// GeneralInputTransformUnit(tmp_data, trans_input_ptr, matrix_b, matrix_bt, C4NUM, dst_step, input_unit);
}
out_tile_index++;
} // cal_tile_num loop
}
void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const float *bias_data, int cal_num,
int out_tile_index, int output_unit_num, ConvParameter *conv_param,
OutputTransformUnitFunc output_trans_func) {
int out_tile_index, int output_unit_num, ConvParameter *conv_param, OutputTransFunc func) {
int output_unit = conv_param->output_unit_;
int output_w = conv_param->output_w_;
int output_h = conv_param->output_h_;
@ -106,7 +106,9 @@ void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const f
const float *src_ptr = gemm_out + src_oc4_offset;
const float *bias_ptr = bias_data + j * C4NUM;
float *dst_ptr = tmp_out_data + dst_oc4_offset;
output_trans_func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w_unit_block * output_unit);
func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w_unit_block * output_unit);
// GeneralOutputTransformUnit(src_ptr, dst_ptr, bias_ptr, matrix_a, matrix_at, C8NUM,
// output_w_unit_block * output_unit, input_unit, output_unit);
}
out_tile_index++;
}
@ -865,7 +867,7 @@ void Conv3x3Int8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t s
}
void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index,
int real_cal_num, int out_w_block, ConvParameter *conv_param) {
int real_cal_num, int out_w_block, ConvParameter *conv_param) {
// input data format : nhwc
int input_channel = conv_param->input_channel_;
int input_width = conv_param->input_w_;
@ -1176,7 +1178,7 @@ void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weigh
}
void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound,
bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) {
bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param) {
int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_;
int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_;
int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
@ -1457,7 +1459,7 @@ void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, in
}
void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index,
int real_cal_num, int out_w_block, ConvParameter *conv_param) {
int real_cal_num, int out_w_block, ConvParameter *conv_param) {
int output_channel = conv_param->output_channel_;
int output_w = conv_param->output_w_;
int output_h = conv_param->output_h_;
@ -1484,7 +1486,7 @@ void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const
bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w;
bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h;
Conv3x3Int8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, j * C4NUM,
conv_param);
conv_param);
}
}
}

View File

@ -33,12 +33,10 @@ extern "C" {
#endif
// for fp32 winograd input/output transform
void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num,
int out_tile_index, int out_w_block_num, ConvParameter *conv_param,
InputTransformUnitFunc input_trans_func);
int out_tile_index, int out_w_block_num, ConvParameter *conv_param, InputTransFunc func);
void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const float *bias_data, int cal_num,
int out_tile_index, int output_unit_num, ConvParameter *conv_param,
OutputTransformUnitFunc output_trans_func);
int out_tile_index, int output_unit_num, ConvParameter *conv_param, OutputTransFunc func);
// for fp32 convolution 3x3 filter/input/output transform
void Conv3x3Fp32InputUnit(const float *tmp_data, float *trans_input_data, size_t step);

File diff suppressed because it is too large Load Diff

View File

@ -20,45 +20,237 @@
#ifdef ENABLE_ARM
#include <arm_neon.h>
#endif
#include "nnacl/matrix_table.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/op_base.h"
typedef void (*InputTransformUnitFunc)(const float *src_data, float *dst_data, int src_step, int dst_step);
typedef void (*OutputTransformUnitFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step);
#define MAX_LEN 256
#ifdef __cplusplus
extern "C" {
#endif
typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step);
typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step);
void GeneralInputTransformUnit(const float *src_data, float *dst_data, float *matrix_b, float *matrix_bt, int src_step,
int dst_step, int in_unit);
void GeneralOutputTransformUnit(const float *src_data, float *dst_data, const float *bias_data, float *matrix_a,
float *matrix_at, int src_step, int dst_step, int in_unit, int out_unit);
#define Load16Data \
src[0] = vld1q_f32(src_data + 0 * src_step); \
src[1] = vld1q_f32(src_data + 1 * src_step); \
src[2] = vld1q_f32(src_data + 2 * src_step); \
src[3] = vld1q_f32(src_data + 3 * src_step); \
src[4] = vld1q_f32(src_data + 4 * src_step); \
src[5] = vld1q_f32(src_data + 5 * src_step); \
src[6] = vld1q_f32(src_data + 6 * src_step); \
src[7] = vld1q_f32(src_data + 7 * src_step); \
src[8] = vld1q_f32(src_data + 8 * src_step); \
src[9] = vld1q_f32(src_data + 9 * src_step); \
src[10] = vld1q_f32(src_data + 10 * src_step); \
src[11] = vld1q_f32(src_data + 11 * src_step); \
src[12] = vld1q_f32(src_data + 12 * src_step); \
src[13] = vld1q_f32(src_data + 13 * src_step); \
src[14] = vld1q_f32(src_data + 14 * src_step); \
src[15] = vld1q_f32(src_data + 15 * src_step);
#define Load36Data \
src[0] = vld1q_f32(src_data + 0 * src_step); \
src[1] = vld1q_f32(src_data + 1 * src_step); \
src[2] = vld1q_f32(src_data + 2 * src_step); \
src[3] = vld1q_f32(src_data + 3 * src_step); \
src[4] = vld1q_f32(src_data + 4 * src_step); \
src[5] = vld1q_f32(src_data + 5 * src_step); \
src[6] = vld1q_f32(src_data + 6 * src_step); \
src[7] = vld1q_f32(src_data + 7 * src_step); \
src[8] = vld1q_f32(src_data + 8 * src_step); \
src[9] = vld1q_f32(src_data + 9 * src_step); \
src[10] = vld1q_f32(src_data + 10 * src_step); \
src[11] = vld1q_f32(src_data + 11 * src_step); \
src[12] = vld1q_f32(src_data + 12 * src_step); \
src[13] = vld1q_f32(src_data + 13 * src_step); \
src[14] = vld1q_f32(src_data + 14 * src_step); \
src[15] = vld1q_f32(src_data + 15 * src_step); \
src[16] = vld1q_f32(src_data + 16 * src_step); \
src[17] = vld1q_f32(src_data + 17 * src_step); \
src[18] = vld1q_f32(src_data + 18 * src_step); \
src[19] = vld1q_f32(src_data + 19 * src_step); \
src[20] = vld1q_f32(src_data + 20 * src_step); \
src[21] = vld1q_f32(src_data + 21 * src_step); \
src[22] = vld1q_f32(src_data + 22 * src_step); \
src[23] = vld1q_f32(src_data + 23 * src_step); \
src[24] = vld1q_f32(src_data + 24 * src_step); \
src[25] = vld1q_f32(src_data + 25 * src_step); \
src[26] = vld1q_f32(src_data + 26 * src_step); \
src[27] = vld1q_f32(src_data + 27 * src_step); \
src[28] = vld1q_f32(src_data + 28 * src_step); \
src[29] = vld1q_f32(src_data + 29 * src_step); \
src[30] = vld1q_f32(src_data + 30 * src_step); \
src[31] = vld1q_f32(src_data + 31 * src_step); \
src[32] = vld1q_f32(src_data + 32 * src_step); \
src[33] = vld1q_f32(src_data + 33 * src_step); \
src[34] = vld1q_f32(src_data + 34 * src_step); \
src[35] = vld1q_f32(src_data + 35 * src_step);
#define Load64Data \
src[0] = vld1q_f32(src_data + 0 * src_step); \
src[1] = vld1q_f32(src_data + 1 * src_step); \
src[2] = vld1q_f32(src_data + 2 * src_step); \
src[3] = vld1q_f32(src_data + 3 * src_step); \
src[4] = vld1q_f32(src_data + 4 * src_step); \
src[5] = vld1q_f32(src_data + 5 * src_step); \
src[6] = vld1q_f32(src_data + 6 * src_step); \
src[7] = vld1q_f32(src_data + 7 * src_step); \
src[8] = vld1q_f32(src_data + 8 * src_step); \
src[9] = vld1q_f32(src_data + 9 * src_step); \
src[10] = vld1q_f32(src_data + 10 * src_step); \
src[11] = vld1q_f32(src_data + 11 * src_step); \
src[12] = vld1q_f32(src_data + 12 * src_step); \
src[13] = vld1q_f32(src_data + 13 * src_step); \
src[14] = vld1q_f32(src_data + 14 * src_step); \
src[15] = vld1q_f32(src_data + 15 * src_step); \
src[16] = vld1q_f32(src_data + 16 * src_step); \
src[17] = vld1q_f32(src_data + 17 * src_step); \
src[18] = vld1q_f32(src_data + 18 * src_step); \
src[19] = vld1q_f32(src_data + 19 * src_step); \
src[20] = vld1q_f32(src_data + 20 * src_step); \
src[21] = vld1q_f32(src_data + 21 * src_step); \
src[22] = vld1q_f32(src_data + 22 * src_step); \
src[23] = vld1q_f32(src_data + 23 * src_step); \
src[24] = vld1q_f32(src_data + 24 * src_step); \
src[25] = vld1q_f32(src_data + 25 * src_step); \
src[26] = vld1q_f32(src_data + 26 * src_step); \
src[27] = vld1q_f32(src_data + 27 * src_step); \
src[28] = vld1q_f32(src_data + 28 * src_step); \
src[29] = vld1q_f32(src_data + 29 * src_step); \
src[30] = vld1q_f32(src_data + 30 * src_step); \
src[31] = vld1q_f32(src_data + 31 * src_step); \
src[32] = vld1q_f32(src_data + 32 * src_step); \
src[33] = vld1q_f32(src_data + 33 * src_step); \
src[34] = vld1q_f32(src_data + 34 * src_step); \
src[35] = vld1q_f32(src_data + 35 * src_step); \
src[36] = vld1q_f32(src_data + 36 * src_step); \
src[37] = vld1q_f32(src_data + 37 * src_step); \
src[38] = vld1q_f32(src_data + 38 * src_step); \
src[39] = vld1q_f32(src_data + 39 * src_step); \
src[40] = vld1q_f32(src_data + 40 * src_step); \
src[41] = vld1q_f32(src_data + 41 * src_step); \
src[42] = vld1q_f32(src_data + 42 * src_step); \
src[43] = vld1q_f32(src_data + 43 * src_step); \
src[44] = vld1q_f32(src_data + 44 * src_step); \
src[45] = vld1q_f32(src_data + 45 * src_step); \
src[46] = vld1q_f32(src_data + 46 * src_step); \
src[47] = vld1q_f32(src_data + 47 * src_step); \
src[48] = vld1q_f32(src_data + 48 * src_step); \
src[49] = vld1q_f32(src_data + 49 * src_step); \
src[50] = vld1q_f32(src_data + 50 * src_step); \
src[51] = vld1q_f32(src_data + 51 * src_step); \
src[52] = vld1q_f32(src_data + 52 * src_step); \
src[53] = vld1q_f32(src_data + 53 * src_step); \
src[54] = vld1q_f32(src_data + 54 * src_step); \
src[55] = vld1q_f32(src_data + 55 * src_step); \
src[56] = vld1q_f32(src_data + 56 * src_step); \
src[57] = vld1q_f32(src_data + 57 * src_step); \
src[58] = vld1q_f32(src_data + 58 * src_step); \
src[59] = vld1q_f32(src_data + 59 * src_step); \
src[60] = vld1q_f32(src_data + 60 * src_step); \
src[61] = vld1q_f32(src_data + 61 * src_step); \
src[62] = vld1q_f32(src_data + 62 * src_step); \
src[63] = vld1q_f32(src_data + 63 * src_step);
InputTransFunc GetInputTransFunc(int input_unit);
void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step);
void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step);
void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step);
void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit);
#define Store4Data \
vst1q_f32(dst_data, m[0]); \
vst1q_f32(dst_data + C4NUM, m[1]); \
vst1q_f32(dst_data + dst_step * C4NUM, m[2]); \
vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[3]);
#define Store9Data \
vst1q_f32(dst_data, m[0]); \
vst1q_f32(dst_data + C4NUM, m[1]); \
vst1q_f32(dst_data + 2 * C4NUM, m[2]); \
vst1q_f32(dst_data + dst_step * C4NUM, m[3]); \
vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[4]); \
vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, m[5]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM, m[6]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, m[7]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, m[8]);
#define Store16Data \
vst1q_f32(dst_data, m[0]); \
vst1q_f32(dst_data + C4NUM, m[1]); \
vst1q_f32(dst_data + 2 * C4NUM, m[2]); \
vst1q_f32(dst_data + 3 * C4NUM, m[3]); \
vst1q_f32(dst_data + dst_step * C4NUM, m[4]); \
vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[5]); \
vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, m[6]); \
vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, m[7]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM, m[8]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, m[9]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, m[10]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, m[11]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM, m[12]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, m[13]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, m[14]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM + 3 * C4NUM, m[15]);
#define Store25Data \
vst1q_f32(dst_data, m[0]); \
vst1q_f32(dst_data + C4NUM, m[1]); \
vst1q_f32(dst_data + 2 * C4NUM, m[2]); \
vst1q_f32(dst_data + 3 * C4NUM, m[3]); \
vst1q_f32(dst_data + 4 * C4NUM, m[4]); \
vst1q_f32(dst_data + dst_step * C4NUM, m[5]); \
vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[6]); \
vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, m[7]); \
vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, m[8]); \
vst1q_f32(dst_data + dst_step * C4NUM + 4 * C4NUM, m[9]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM, m[10]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, m[11]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, m[12]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, m[13]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 4 * C4NUM, m[14]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM, m[15]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, m[16]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, m[17]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM + 3 * C4NUM, m[18]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM + 4 * C4NUM, m[19]); \
vst1q_f32(dst_data + 4 * dst_step * C4NUM, m[20]); \
vst1q_f32(dst_data + 4 * dst_step * C4NUM + C4NUM, m[21]); \
vst1q_f32(dst_data + 4 * dst_step * C4NUM + 2 * C4NUM, m[22]); \
vst1q_f32(dst_data + 4 * dst_step * C4NUM + 3 * C4NUM, m[23]); \
vst1q_f32(dst_data + 4 * dst_step * C4NUM + 4 * C4NUM, m[24]);
void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
int SelectOutputUnit(ConvParameter *conv_param);
InputTransformUnitFunc GetInputTransFunc(int input_unit);
OutputTransformUnitFunc GetOutputTransFunc(int input_unit, int output_unit);
void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param,
InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func);
void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param);
#ifdef __cplusplus
}
#endif

View File

@ -1,86 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/base/matrix.h"
#include "utils/log_adapter.h"
namespace mindspore::kernel {
Matrix *TransformMatrixGenerator(int m, int k) {
auto matrix = new (std::nothrow) Matrix;
if (matrix == nullptr) {
MS_LOG(ERROR) << "matrix is nullptr.";
return nullptr;
}
auto data = malloc(m * k * sizeof(float));
if (data == nullptr) {
MS_LOG(ERROR) << "Malloc matrix data failed.";
return nullptr;
}
matrix->SetData(data);
matrix->SetNum(m, k);
return matrix;
}
void ChooseMatrixG(Matrix *matrix_g, Matrix *matrix_gt) {
int m = matrix_g->GetM();
int k = matrix_g->GetK();
auto matrix_g_data = reinterpret_cast<float *>(matrix_g->GetData());
auto matrix_gt_data = reinterpret_cast<float *>(matrix_gt->GetData());
// m represents input unit, only 4 or 8 can be accepted for input unit.
// k represents kernel unit, varies from 2 to 7.
if (m == 4 && k == 2) {
MatrixG4x2(matrix_g_data);
MatrixGT2x4(matrix_gt_data);
} else if (m == 8 && k == 2) {
MatrixG8x2(matrix_g_data);
MatrixGT2x8(matrix_gt_data);
} else if (m == 8 && k == 3) {
MatrixG8x3(matrix_g_data);
MatrixGT3x8(matrix_gt_data);
} else if (m == 8 && k == 4) {
MatrixG8x4(matrix_g_data);
MatrixGT4x8(matrix_gt_data);
} else if (m == 8 && k == 5) {
MatrixG8x5(matrix_g_data);
MatrixGT5x8(matrix_gt_data);
} else if (m == 8 && k == 6) {
MatrixG8x6(matrix_g_data);
MatrixGT6x8(matrix_gt_data);
} else if (m == 8 && k == 7) {
MatrixG8x7(matrix_g_data);
MatrixGT7x8(matrix_gt_data);
} else {
MS_LOG(ERROR) << "Unsupported input unit or kernel unit.";
return;
}
}
void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, bool row) {
// row-major implementation
int count = 0;
for (int h = 0; h < m; h++) {
int h_offset = h * k;
for (int w = 0; w < n; w++) {
float res = 0;
for (int i = 0; i < k; i++) {
res += *(matrix_a + h_offset + i) * *(matrix_b + w + i * n);
}
*(matrix_c + count) = res;
count++;
}
}
}
} // namespace mindspore::kernel

View File

@ -1,77 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MATRIX_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MATRIX_H_
#include <stdlib.h>
#include <vector>
#include "nnacl/winograd_utils.h"
namespace mindspore::kernel {
class Matrix {
public:
Matrix() = default;
~Matrix() {
if (data_ != nullptr) {
free(data_);
}
}
void SetData(void *data) { this->data_ = data; }
void *GetData() { return this->data_; }
void SetNDim(int dim) { this->n_dim_ = dim; }
int GetNDim() { return this->n_dim_; }
void SetShape(std::vector<int> shape) { this->shape_ = shape; }
std::vector<int> GetShape() { return this->shape_; }
void SetStride(std::vector<int> stride) { this->stride_ = stride; }
std::vector<int> GetStride() { return this->stride_; }
void SetNum(int m, int k) {
this->m_ = m;
this->k_ = k;
}
int GetM() { return this->m_; }
int GetK() { return this->k_; }
protected:
void *data_ = nullptr;
std::vector<int> shape_;
std::vector<int> stride_;
int m_;
int k_;
int n_dim_;
bool row_major_;
};
Matrix *TransformMatrixGenerator(int m, int k);
// Chinese Remainder Theorem interp: 0.5
void ChooseMatrixG(Matrix *matrix_g, Matrix *matrix_gt);
void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, bool row);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MATRIX_H_

View File

@ -238,9 +238,7 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
} else {
bool use_winograd = false;
int out_unit;
InputTransformUnitFunc input_trans_func = nullptr;
OutputTransformUnitFunc output_trans_func = nullptr;
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param);
if (use_winograd) {
kernel = new (std::nothrow)
kernel::ConvolutionWinogradFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit);

View File

@ -15,7 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h"
#include "src/runtime/kernel/arm/fp16/matrix_fp16.h"
#include "nnacl/fp16/matrix_fp16.h"
#include "nnacl/fp16/conv_fp16.h"
#include "nnacl/fp16/cast_fp16.h"
#include "nnacl/fp16/pack_fp16.h"
@ -34,43 +34,35 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;
namespace mindspore::kernel {
int WinogradFilterTransformFp16(const float16_t *weight_data, Matrix *trans_weight, int kernel_unit, int input_unit,
ConvParameter *conv_param, int oc_block) {
int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_t *weight_data, float *matrix_g,
float *matrix_gt, int oc_block) {
// original weight format : ohwi
auto channel_in = conv_param->input_channel_;
auto channel_out = conv_param->output_channel_;
int input_unit_square = input_unit * input_unit;
auto channel_in = conv_param_->input_channel_;
auto channel_out = conv_param_->output_channel_;
int ic8 = UP_DIV(channel_in, C8NUM);
int ic4 = ic8 * 2;
int input_unit_square = input_unit_ * input_unit_;
int oc_block_num = UP_DIV(channel_out, oc_block);
// generate matrix_G && matrix_GT
auto matrix_g = TransformMatrixGenerator(input_unit, kernel_unit);
if (matrix_g == nullptr) {
MS_LOG(ERROR) << "matrix_g is null.";
delete matrix_g;
return RET_ERROR;
}
auto matrix_gt = TransformMatrixGenerator(kernel_unit, input_unit);
if (matrix_gt == nullptr) {
MS_LOG(ERROR) << "matrix_gt is null.";
delete matrix_g;
delete matrix_gt;
return RET_ERROR;
}
ChooseMatrixG(matrix_g, matrix_gt);
auto matrix_g_data = reinterpret_cast<float *>(matrix_g->GetData());
auto matrix_gt_data = reinterpret_cast<float *>(matrix_gt->GetData());
auto matrix_g_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit * kernel_unit * sizeof(float16_t)));
auto matrix_gt_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit * kernel_unit * sizeof(float16_t)));
Float32ToFloat16(matrix_g_data, matrix_g_data_fp16, input_unit * kernel_unit);
Float32ToFloat16(matrix_gt_data, matrix_gt_data_fp16, input_unit * kernel_unit);
auto matrix_g_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t)));
auto matrix_gt_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t)));
Float32ToFloat16(matrix_g, matrix_g_data_fp16, input_unit_ * kernel_unit_);
Float32ToFloat16(matrix_gt, matrix_gt_data_fp16, input_unit_ * kernel_unit_);
// trans_filter = G*g*GT (g represents weight_data)
// separate into two steps ===> tmp = G*g ===> out = tmp * GT
auto tmp_weight_data = reinterpret_cast<float16_t *>(malloc(kernel_unit * kernel_unit * sizeof(float16_t)));
auto tmp_data = reinterpret_cast<float16_t *>(malloc(input_unit * kernel_unit * sizeof(float16_t)));
auto trans_out_data = reinterpret_cast<float16_t *>(malloc(input_unit * input_unit * sizeof(float16_t)));
bool row = true;
auto trans_weight_data = reinterpret_cast<float16_t *>(trans_weight->GetData());
std::vector<int> strides = trans_weight->GetStride();
auto tmp_weight_data = reinterpret_cast<float16_t *>(malloc(kernel_unit_ * kernel_unit_ * sizeof(float16_t)));
auto tmp_data = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t)));
auto trans_out_data = reinterpret_cast<float16_t *>(malloc(input_unit_ * input_unit_ * sizeof(float16_t)));
std::vector<int> shape{input_unit_ * input_unit_, oc_block_num, ic4, C4NUM, oc_block};
std::vector<int> strides;
for (int i = 0; i < 4; i++) {
int stride = 1;
for (int j = i + 1; j < 5; j++) {
stride *= shape[j];
}
strides.push_back(stride);
}
int kernel_plane_stride = channel_in;
if (oc_block == 0) {
@ -80,33 +72,31 @@ int WinogradFilterTransformFp16(const float16_t *weight_data, Matrix *trans_weig
free(trans_out_data);
free(matrix_g_data_fp16);
free(matrix_gt_data_fp16);
delete matrix_g;
delete matrix_gt;
return RET_ERROR;
}
for (int i = 0; i < channel_out; i++) {
int out_c_block = i / oc_block;
int out_c_res = i % oc_block;
int input_oz_offset = i * kernel_unit * kernel_unit * channel_in;
int output_oz_offset = out_c_block * strides[1] * input_unit * input_unit + out_c_res;
int input_oz_offset = i * kernel_unit_ * kernel_unit_ * channel_in;
int output_oz_offset = out_c_block * strides[1] * input_unit_ * input_unit_ + out_c_res;
for (int j = 0; j < channel_in; j++) {
int ic4_block = j / C4NUM;
int ic4_res = j % C4NUM;
int input_iz_offset = input_oz_offset + j;
int output_iz_offset = output_oz_offset + ic4_block * strides[2] + ic4_res * strides[3];
for (int k = 0; k < kernel_unit * kernel_unit; k++) {
for (int k = 0; k < kernel_unit_ * kernel_unit_; k++) {
int input_xy_offset = input_iz_offset + k * kernel_plane_stride;
tmp_weight_data[k] = *(weight_data + input_xy_offset);
}
// now we only support row-major matrix-multiply
// tmp = G * g
MatrixMultiplyFp16(matrix_g_data_fp16, tmp_weight_data, tmp_data, input_unit, kernel_unit, kernel_unit, row);
MatrixMultiplyFp16(matrix_g_data_fp16, tmp_weight_data, tmp_data, input_unit_, kernel_unit_, kernel_unit_);
// out = tmp * GT
MatrixMultiplyFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit, kernel_unit, input_unit, row);
MatrixMultiplyFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit_, kernel_unit_, input_unit_);
for (int z = 0; z < input_unit_square; z++) {
int output_xy_offset = output_iz_offset + z * strides[1];
*(trans_weight_data + output_xy_offset) = trans_out_data[z];
trans_weight_[output_xy_offset] = trans_out_data[z];
}
}
}
@ -115,15 +105,58 @@ int WinogradFilterTransformFp16(const float16_t *weight_data, Matrix *trans_weig
free(trans_out_data);
free(matrix_g_data_fp16);
free(matrix_gt_data_fp16);
delete matrix_g;
delete matrix_gt;
return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::MallocTransformMatrices() {
matrix_a_ = reinterpret_cast<float16_t *>(malloc(input_unit_ * output_unit_ * sizeof(float16_t)));
if (matrix_a_ == nullptr) {
MS_LOG(ERROR) << "malloc matrix_a_ failed.";
return RET_ERROR;
}
matrix_at_ = reinterpret_cast<float16_t *>(malloc(input_unit_ * output_unit_ * sizeof(float16_t)));
if (matrix_at_ == nullptr) {
MS_LOG(ERROR) << "malloc matrix_at_ failed.";
return RET_ERROR;
}
matrix_b_ = reinterpret_cast<float16_t *>(malloc(input_unit_ * input_unit_ * sizeof(float16_t)));
if (matrix_b_ == nullptr) {
MS_LOG(ERROR) << "malloc matrix_b_ failed.";
return RET_ERROR;
}
matrix_bt_ = reinterpret_cast<float16_t *>(malloc(input_unit_ * input_unit_ * sizeof(float16_t)));
if (matrix_bt_ == nullptr) {
MS_LOG(ERROR) << "malloc matrix_bt_ failed.";
return RET_ERROR;
}
return RET_OK;
}
void ConvolutionWinogradFP16CPUKernel::FreeTransformMatrices() {
if (matrix_a_ != nullptr) {
free(matrix_a_);
matrix_a_ = nullptr;
}
if (matrix_at_ != nullptr) {
free(matrix_at_);
matrix_at_ = nullptr;
}
if (matrix_b_ != nullptr) {
free(matrix_b_);
matrix_b_ = nullptr;
}
if (matrix_bt_ != nullptr) {
free(matrix_bt_);
matrix_bt_ = nullptr;
}
return;
}
int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
auto filter_tensor = in_tensors_.at(kWeightIndex);
int in_channel = filter_tensor->Channel();
int out_channel = filter_tensor->Batch();
int ic8 = UP_DIV(in_channel, C8NUM);
conv_param_->input_channel_ = in_channel;
conv_param_->output_channel_ = out_channel;
@ -132,19 +165,43 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
oc_block_num = UP_DIV(out_channel, C8NUM);
// init weight
auto ret = MallocFilterMatrix(oc_block, oc_block_num);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Malloc filter matrix failed.";
return RET_ERROR;
}
ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get Execute filter failed.";
return ret;
}
ret = WinogradFilterTransformFp16(execute_weight_, trans_weight_, kernel_unit_, input_unit_, conv_param_, oc_block);
// set data
auto trans_matrix_data_size = input_unit_ * input_unit_ * ic8 * C8NUM * oc_block_num * oc_block * sizeof(float16_t);
trans_weight_ = reinterpret_cast<float16_t *>(malloc(trans_matrix_data_size));
if (trans_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc trans_weight_ failed.";
return RET_ERROR;
}
memset(trans_weight_, 0, trans_matrix_data_size);
auto *matrix_g = reinterpret_cast<float *>(malloc(input_unit_ * kernel_unit_ * sizeof(float)));
auto matrix_gt = reinterpret_cast<float *>(malloc(input_unit_ * kernel_unit_ * sizeof(float)));
ret = MallocTransformMatrices();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Malloc transform matrices failed.";
return ret;
}
float matrix_a[MAX_LEN];
float matrix_at[MAX_LEN];
float matrix_b[MAX_LEN];
float matrix_bt[MAX_LEN];
CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, 0.5f, output_unit_, kernel_unit_);
Float32ToFloat16(matrix_a, matrix_a_, input_unit_ * output_unit_);
Float32ToFloat16(matrix_at, matrix_at_, input_unit_ * output_unit_);
Float32ToFloat16(matrix_b, matrix_b_, input_unit_ * input_unit_);
Float32ToFloat16(matrix_bt, matrix_bt_, input_unit_ * input_unit_);
matrices_[0] = matrix_a_;
matrices_[1] = matrix_at_;
matrices_[2] = matrix_b_;
matrices_[3] = matrix_bt_;
ret = WinogradFilterTransformFp16(execute_weight_, matrix_g, matrix_gt, oc_block);
if (ret != RET_OK) {
MS_LOG(ERROR) << "winograd filter transfrom failed.";
return ret;
@ -166,49 +223,8 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
} else {
MS_ASSERT(inputs_.size() == kInputSize1);
}
return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::MallocFilterMatrix(int oc_block, int oc_block_num) {
int channel_in = conv_param_->input_channel_;
int ic8 = UP_DIV(channel_in, C8NUM);
int ic4 = ic8 * 2;
// set data
auto trans_matrix_data_size = input_unit_ * input_unit_ * ic8 * C8NUM * oc_block_num * oc_block * sizeof(float);
auto matrix_buffer = malloc(trans_matrix_data_size);
if (matrix_buffer == nullptr) {
MS_LOG(ERROR) << "malloc matrix_buffer failed.";
return RET_ERROR;
}
memset(matrix_buffer, 0, trans_matrix_data_size);
trans_weight_ = new (std::nothrow) Matrix();
if (trans_weight_ == nullptr) {
MS_LOG(ERROR) << "new Matrix fail!";
free(matrix_buffer);
return RET_ERROR;
}
trans_weight_->SetData(matrix_buffer);
trans_weight_->SetNDim(5);
std::vector<int> shapes;
std::vector<int> strides;
// set shape
shapes.push_back(input_unit_ * input_unit_);
shapes.push_back(oc_block_num);
shapes.push_back(ic4);
shapes.push_back(C4NUM);
shapes.push_back(oc_block);
// set stride
for (int i = 0; i < 4; i++) {
int stride = 1;
for (int j = i + 1; j < 5; j++) {
stride *= shapes[j];
}
strides.push_back(stride);
}
trans_weight_->SetShape(shapes);
trans_weight_->SetStride(strides);
free(matrix_g);
free(matrix_gt);
return RET_OK;
}
@ -260,19 +276,7 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
int ConvolutionWinogradFP16CPUKernel::ConfigInputOutput() {
auto output_tensor = out_tensors_.at(kOutputIndex);
output_tensor->SetFormat(schema::Format::Format_NHWC);
// choose input transformer function (4x4 unit or 8x8 unit)
input_trans_func_ = GetInputTransFuncFp16(input_unit_);
if (input_trans_func_ == nullptr) {
MS_LOG(ERROR) << "Get input_trans_func failed.";
return RET_ERROR;
}
output_trans_func_ = GetOutputTransFuncFp16(input_unit_, output_unit_);
if (output_trans_func_ == nullptr) {
MS_LOG(ERROR) << "Get output_trans_func_ failed.";
return RET_ERROR;
}
output_tensor->SetFormat(schema::Format_NHWC);
return RET_OK;
}
@ -334,9 +338,9 @@ int ConvolutionWinogradFP16CPUKernel::ReSize() {
}
int ConvolutionWinogradFP16CPUKernel::RunImpl(int task_id) {
ConvWinogardFp16(reinterpret_cast<float16_t *>(nhwc4_input_), reinterpret_cast<float16_t *>(trans_weight_->GetData()),
ConvWinogardFp16(reinterpret_cast<float16_t *>(nhwc4_input_), trans_weight_,
reinterpret_cast<const float16_t *>(bias_data_), tmp_buffer_address_list_, task_id, conv_param_,
input_trans_func_, output_trans_func_);
matrices_);
return RET_OK;
}

View File

@ -22,9 +22,9 @@
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
#include "nnacl/fp16/conv_fp16.h"
#include "src/runtime/kernel/arm/fp16/matrix_fp16.h"
#include "nnacl/fp16/winograd_utils_fp16.h"
#include "nnacl/optimized_kernel.h"
#include "nnacl/minimal_filtering_generator.h"
namespace mindspore::kernel {
class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
@ -39,9 +39,10 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
fp16_weight_ = nullptr;
}
if (trans_weight_ != nullptr) {
delete trans_weight_;
free(trans_weight_);
trans_weight_ = nullptr;
}
FreeTransformMatrices();
}
int Init() override;
@ -49,10 +50,12 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
int Run() override;
int RunImpl(int task_id);
int InitWeightBias();
int MallocFilterMatrix(int oc_block, int oc_block_num);
int MallocTransformMatrices();
void FreeTransformMatrices();
int InitTmpBuffer();
int ConfigInputOutput();
int PostProcess();
int WinogradFilterTransformFp16(const float16_t *weight_data, float *matrix_g, float *matrix_gt, int oc_block);
private:
void FreeTmpBuffer() {
@ -80,13 +83,14 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
float16_t *trans_input_ = nullptr;
float16_t *gemm_out_ = nullptr;
float16_t *tmp_out_data_ = nullptr;
Matrix *trans_weight_ = nullptr;
InputTransformUnitFp16Func input_trans_func_;
OutputTransformUnitFp16Func output_trans_func_;
float16_t *matrix_a_ = nullptr;
float16_t *matrix_at_ = nullptr;
float16_t *matrix_b_ = nullptr;
float16_t *matrix_bt_ = nullptr;
float16_t *trans_weight_ = nullptr;
TmpBufferAddressFp16 tmp_buffer_address_list_[4];
MatricesFp16 matrices_[4];
};
int WinogradFilterTransformFp16(const float16_t *weight_data, Matrix *trans_weight, int kernel_unit, int input_unit,
ConvParameter *conv_param, int oc_block);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_WINOGRAD_FP16_H_

View File

@ -1,36 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp16/matrix_fp16.h"
namespace mindspore::kernel {
void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n,
bool row) {
// row-major implementation
int count = 0;
for (int h = 0; h < m; h++) {
int h_offset = h * k;
for (int w = 0; w < n; w++) {
float16_t res = 0;
for (int i = 0; i < k; i++) {
res += *(matrix_a + h_offset + i) * *(matrix_b + w + i * n);
}
*(matrix_c + count) = res;
count++;
}
}
}
} // namespace mindspore::kernel

View File

@ -228,10 +228,8 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
conv_param->op_parameter_.thread_num_ = ctx->thread_num_;
bool use_winograd = false;
int out_unit;
InputTransformUnitFunc input_trans_func = nullptr;
OutputTransformUnitFunc output_trans_func = nullptr;
if (primitive != nullptr && primitive->GetInferFlag()) {
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param);
}
auto *weight_tensor = inputs.at(kWeightIndex);

View File

@ -28,39 +28,29 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;
namespace mindspore::kernel {
int WinogradFilterTransform(const float *weight_data, Matrix *trans_weight, int kernel_unit, int input_unit,
ConvParameter *conv_param, int oc_block) {
int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_data, float *matrix_g, float *matrix_gt,
int oc_block) {
// original weight format : ohwi
auto channel_in = conv_param->input_channel_;
auto channel_out = conv_param->output_channel_;
int input_unit_square = input_unit * input_unit;
// generate matrix_G && matrix_GT
auto matrix_g = TransformMatrixGenerator(input_unit, kernel_unit);
if (matrix_g == nullptr) {
MS_LOG(ERROR) << "matrix_g is null.";
delete matrix_g;
return RET_ERROR;
}
auto matrix_gt = TransformMatrixGenerator(kernel_unit, input_unit);
if (matrix_gt == nullptr) {
MS_LOG(ERROR) << "matrix_gt is null.";
delete matrix_g;
delete matrix_gt;
return RET_ERROR;
}
ChooseMatrixG(matrix_g, matrix_gt);
auto matrix_g_data = reinterpret_cast<float *>(matrix_g->GetData());
auto matrix_gt_data = reinterpret_cast<float *>(matrix_gt->GetData());
auto channel_in = conv_param_->input_channel_;
auto channel_out = conv_param_->output_channel_;
int input_unit_square = input_unit_ * input_unit_;
int ic4 = UP_DIV(channel_in, C4NUM);
int oc_block_num = UP_DIV(channel_out, oc_block);
// trans_filter = G*g*GT (g represents weight_data)
// separate into two steps ===> tmp = G*g ===> out = tmp * GT
auto tmp_weight_data = reinterpret_cast<float *>(malloc(kernel_unit * kernel_unit * sizeof(float)));
auto tmp_data = reinterpret_cast<float *>(malloc(input_unit * kernel_unit * sizeof(float)));
auto trans_out_data = reinterpret_cast<float *>(malloc(input_unit * input_unit * sizeof(float)));
bool row = true;
auto trans_weight_data = reinterpret_cast<float *>(trans_weight->GetData());
std::vector<int> strides = trans_weight->GetStride();
auto tmp_weight_data = reinterpret_cast<float *>(malloc(kernel_unit_ * kernel_unit_ * sizeof(float)));
auto tmp_data = reinterpret_cast<float *>(malloc(input_unit_ * kernel_unit_ * sizeof(float)));
auto trans_out_data = reinterpret_cast<float *>(malloc(input_unit_ * input_unit_ * sizeof(float)));
std::vector<int> shape{input_unit_ * input_unit_, oc_block_num, ic4, C4NUM, oc_block};
std::vector<int> strides;
for (int i = 0; i < 4; i++) {
int stride = 1;
for (int j = i + 1; j < 5; j++) {
stride *= shape[j];
}
strides.push_back(stride);
}
int kernel_plane_stride = channel_in;
if (oc_block == 0) {
@ -68,41 +58,37 @@ int WinogradFilterTransform(const float *weight_data, Matrix *trans_weight, int
free(tmp_weight_data);
free(tmp_data);
free(trans_out_data);
delete matrix_g;
delete matrix_gt;
return RET_ERROR;
}
for (int i = 0; i < channel_out; i++) {
int out_c_block = i / oc_block;
int out_c_res = i % oc_block;
int input_oz_offset = i * kernel_unit * kernel_unit * channel_in;
int input_oz_offset = i * kernel_unit_ * kernel_unit_ * channel_in;
int output_oz_offset = out_c_block * strides[1] + out_c_res;
for (int j = 0; j < channel_in; j++) {
int ic4_block = j / C4NUM;
int ic4_res = j % C4NUM;
int input_iz_offset = input_oz_offset + j;
int output_iz_offset = output_oz_offset + ic4_block * strides[2] + ic4_res * strides[3];
for (int k = 0; k < kernel_unit * kernel_unit; k++) {
for (int k = 0; k < kernel_unit_ * kernel_unit_; k++) {
int input_xy_offset = input_iz_offset + k * kernel_plane_stride;
tmp_weight_data[k] = *(weight_data + input_xy_offset);
}
// now we only support row-major matrix-multiply
// tmp = G * g
MatrixMultiply(matrix_g_data, tmp_weight_data, tmp_data, input_unit, kernel_unit, kernel_unit, row);
MatrixMultiply(matrix_g, tmp_weight_data, tmp_data, input_unit_, kernel_unit_, kernel_unit_);
// out = tmp * GT
MatrixMultiply(tmp_data, matrix_gt_data, trans_out_data, input_unit, kernel_unit, input_unit, row);
MatrixMultiply(tmp_data, matrix_gt, trans_out_data, input_unit_, kernel_unit_, input_unit_);
for (int z = 0; z < input_unit_square; z++) {
int output_xy_offset = output_iz_offset + z * strides[0];
*(trans_weight_data + output_xy_offset) = trans_out_data[z];
*(trans_weight_ + output_xy_offset) = trans_out_data[z];
}
}
}
free(tmp_weight_data);
free(tmp_data);
free(trans_out_data);
delete matrix_g;
delete matrix_gt;
return RET_OK;
}
@ -110,6 +96,7 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
auto filter_tensor = in_tensors_.at(kWeightIndex);
int in_channel = filter_tensor->Channel();
int out_channel = filter_tensor->Batch();
int ic4 = UP_DIV(in_channel, C4NUM);
conv_param_->input_channel_ = in_channel;
conv_param_->output_channel_ = out_channel;
@ -118,14 +105,26 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
oc_block = C8NUM;
oc_block_num = UP_DIV(out_channel, C8NUM);
// init weight
auto ret = MallocFilterMatrix(oc_block, oc_block_num);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Malloc filter matrix failed.";
// set data
auto trans_matrix_data_size = input_unit_ * input_unit_ * ic4 * C4NUM * oc_block_num * oc_block * sizeof(float);
trans_weight_ = reinterpret_cast<float *>(malloc(trans_matrix_data_size));
if (trans_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc matrix_buffer failed.";
return RET_ERROR;
}
memset(trans_weight_, 0, trans_matrix_data_size);
float matrix_g[64];
float matrix_gt[64];
float matrix_a[64];
float matrix_at[64];
float matrix_b[64];
float matrix_bt[64];
CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, 1.0f, output_unit_, kernel_unit_);
auto weight_data = reinterpret_cast<float *>(filter_tensor->MutableData());
ret = WinogradFilterTransform(weight_data, trans_weight_, kernel_unit_, input_unit_, conv_param_, oc_block);
auto ret = WinogradFilterTransform(weight_data, matrix_g, matrix_gt, oc_block);
if (ret != RET_OK) {
MS_LOG(ERROR) << "winograd filter transfrom failed.";
return ret;
@ -144,48 +143,6 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
return RET_OK;
}
int ConvolutionWinogradCPUKernel::MallocFilterMatrix(int oc_block, int oc_block_num) {
int channel_in = conv_param_->input_channel_;
int ic4 = UP_DIV(channel_in, C4NUM);
// set data
auto trans_matrix_data_size = input_unit_ * input_unit_ * ic4 * C4NUM * oc_block_num * oc_block * sizeof(float);
auto matrix_buffer = malloc(trans_matrix_data_size);
if (matrix_buffer == nullptr) {
MS_LOG(ERROR) << "malloc matrix_buffer failed.";
return RET_ERROR;
}
memset(matrix_buffer, 0, trans_matrix_data_size);
trans_weight_ = new (std::nothrow) Matrix();
if (trans_weight_ == nullptr) {
MS_LOG(ERROR) << "new Matrix fail!";
free(matrix_buffer);
return RET_ERROR;
}
trans_weight_->SetData(matrix_buffer);
trans_weight_->SetNDim(5);
std::vector<int> shapes;
std::vector<int> strides;
// set shape
shapes.push_back(input_unit_ * input_unit_);
shapes.push_back(oc_block_num);
shapes.push_back(ic4);
shapes.push_back(C4NUM);
shapes.push_back(oc_block);
// set stride
for (int i = 0; i < 4; i++) {
int stride = 1;
for (int j = i + 1; j < 5; j++) {
stride *= shapes[j];
}
strides.push_back(stride);
}
trans_weight_->SetShape(shapes);
trans_weight_->SetStride(strides);
return RET_OK;
}
int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
int channel_out = conv_param_->output_channel_;
int output_h = conv_param_->output_h_;
@ -245,17 +202,17 @@ int ConvolutionWinogradCPUKernel::ConfigInputOutput() {
auto output_tensor = out_tensors_.at(kOutputIndex);
output_tensor->SetFormat(schema::Format::Format_NHWC);
// choose input transformer function (4x4 unit or 8x8 unit)
input_trans_func_ = GetInputTransFunc(input_unit_);
if (input_trans_func_ == nullptr) {
MS_LOG(ERROR) << "Get input_trans_func failed.";
in_func_ = GetInputTransFunc(input_unit_);
if (in_func_ == nullptr) {
MS_LOG(ERROR) << "in_func_ is null.";
return RET_ERROR;
}
output_trans_func_ = GetOutputTransFunc(input_unit_, output_unit_);
if (output_trans_func_ == nullptr) {
MS_LOG(ERROR) << "Get output_trans_func_ failed.";
out_func_ = GetOutputTransFunc(input_unit_, output_unit_);
if (out_func_ == nullptr) {
MS_LOG(ERROR) << "out_func_ is null.";
return RET_ERROR;
}
// #ifdef ENABLE_ARM32
// gemm_func_ = IndirectGemmFp32_8x4;
// #else
@ -326,9 +283,8 @@ int ConvolutionWinogradCPUKernel::RunImpl(int task_id) {
MS_LOG(ERROR) << "gemm_func is nullptr.";
return RET_ERROR;
}
ConvWinogardFp32(reinterpret_cast<float *>(nhwc4_input_), reinterpret_cast<float *>(trans_weight_->GetData()),
reinterpret_cast<const float *>(bias_data_), tmp_buffer_address_list_, task_id, conv_param_,
input_trans_func_, output_trans_func_, gemm_func_);
ConvWinogardFp32(reinterpret_cast<float *>(nhwc4_input_), trans_weight_, reinterpret_cast<const float *>(bias_data_),
tmp_buffer_address_list_, task_id, conv_param_, in_func_, out_func_, gemm_func_);
return RET_OK;
}

View File

@ -19,10 +19,9 @@
#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/winograd_transform.h"
#include "nnacl/minimal_filtering_generator.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "src/runtime/kernel/arm/base/matrix.h"
namespace mindspore::kernel {
class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
@ -35,7 +34,7 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
trans_weight_(nullptr) {}
~ConvolutionWinogradCPUKernel() override {
if (trans_weight_ != nullptr) {
delete trans_weight_;
free(trans_weight_);
trans_weight_ = nullptr;
}
};
@ -44,10 +43,10 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
int Run() override;
int RunImpl(int task_id);
int InitWeightBias();
int MallocFilterMatrix(int oc_block, int oc_block_num);
int InitTmpBuffer();
int ConfigInputOutput();
int PostProcess();
int WinogradFilterTransform(const float *weight_data, float *matrix_g, float *matrix_gt, int oc_block);
private:
void FreeTmpBuffer() {
@ -80,13 +79,12 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
float *gemm_out_ = nullptr;
float *tmp_out_data_ = nullptr;
float *col_buffer_ = nullptr;
Matrix *trans_weight_ = nullptr;
InputTransformUnitFunc input_trans_func_;
OutputTransformUnitFunc output_trans_func_;
float *trans_weight_ = nullptr;
TmpBufferAddress tmp_buffer_address_list_[5];
InputTransFunc in_func_;
OutputTransFunc out_func_;
GEMM_FUNC_FP32 gemm_func_ = nullptr;
};
int WinogradFilterTransform(const float *weight_data, Matrix *trans_weight, int kernel_unit, int input_unit,
ConvParameter *conv_param, int oc_block);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_H_

View File

@ -181,9 +181,18 @@ class MS_API Benchmark {
auto tolerance = absoluteTolerance + relativeTolerance * fabs(calibTensor->data.at(j));
auto absoluteError = std::fabs(msTensorData[j] - calibTensor->data.at(j));
if (absoluteError > tolerance) {
// just assume that atol = rtol
meanError += absoluteError / (fabs(calibTensor->data.at(j)) + FLT_MIN);
errorCount++;
if (fabs(calibTensor->data.at(j)) == 0) {
if (absoluteError > 1e-5) {
meanError += absoluteError;
errorCount++;
} else {
continue;
}
} else {
// just assume that atol = rtol
meanError += absoluteError / (fabs(calibTensor->data.at(j)) + FLT_MIN);
errorCount++;
}
}
}
std::cout << std::endl;