!5006 [MS][LITE][Develop]Deconv Matmul 12x8

Merge pull request !5006 from ling/deconv
This commit is contained in:
mindspore-ci-bot 2020-08-24 16:44:25 +08:00 committed by Gitee
commit 33a562de3d
19 changed files with 134 additions and 888 deletions

View File

@ -251,12 +251,6 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
}
}
// fp32 conv1x1 strassen matmul
int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output_data, float *tmp_ptr,
StrassenMatMulParameter matmul_param) {
return StrassenMatmul(input_data, weight_data, output_data, &matmul_param, FP32_STRASSEN_MAX_RECURSION, 0, tmp_ptr);
}
// fp32 conv winograd
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func,

View File

@ -24,7 +24,6 @@
#include "nnacl/op_base.h"
#include "nnacl/common_func.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/fp32/strassen_matmul.h"
#include "nnacl/winograd_utils.h"
#include "nnacl/fp32/conv_depthwise.h"
@ -52,10 +51,6 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
float *tmp_out_block, float *output_data, int task_id, ConvParameter *conv_param,
GEMM_FUNC_FP32 gemm_func);
// fp32 conv1x1 strassen matmul
int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output_data, float *tmp_ptr,
StrassenMatMulParameter matmul_param);
// fp32 convolution winograd
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func,

View File

@ -33,18 +33,18 @@ void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, in
return;
}
int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param) {
/* row8x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
int DeConvPostFp32C12x8(const float *src, float *tmp, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param) {
/* row12x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
size_t input_plane = conv_param->input_w_ * conv_param->input_h_;
size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
int oc8 = UP_ROUND(output_channel, C8NUM);
int in_plane8 = UP_ROUND(input_plane, C8NUM);
int in_plane12 = UP_ROUND(input_plane, C12NUM);
int src_iw_stride = C8NUM;
int src_ih_stride = conv_param->input_w_ * C8NUM;
int src_kw_stride = in_plane8 * C8NUM;
int src_kh_stride = in_plane8 * conv_param->kernel_w_ * C8NUM;
int src_kw_stride = in_plane12 * C8NUM;
int src_kh_stride = in_plane12 * conv_param->kernel_w_ * C8NUM;
int dst_oh_stride = conv_param->output_w_ * C8NUM;
int dst_ow_stride = C8NUM;
int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C8NUM;
@ -52,7 +52,7 @@ int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *d
for (int c = 0; c < oc8; c += 8) {
float *dst_ptr = tmp + c * output_plane;
const float *src_ptr = src + c * in_plane8 * kernel_plane;
const float *src_ptr = src + c * in_plane12 * kernel_plane;
memset(dst_ptr, 0, output_plane * C8NUM * sizeof(float));
for (int ih = 0; ih < conv_param->input_h_; ih++) {
@ -101,41 +101,3 @@ int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *d
conv_param->is_relu6_);
return NNACL_OK;
}
int DeConvPostFp32C4(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel,
int input_plane, int kernel_plane, int output_plane, ConvParameter *conv_param) {
int oc4 = UP_DIV(output_channel, C4NUM);
for (int c = 0; c < oc4; c++) {
float *dst_ptr = tmp_c4 + c * output_plane * C4NUM;
const float *src_ptr = src + c * input_plane * kernel_plane * C4NUM;
memset(dst_ptr, 0, output_plane * C4NUM * sizeof(float));
for (int ih = 0; ih < conv_param->input_h_; ih++) {
for (int iw = 0; iw < conv_param->input_w_; iw++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
for (int kh = kh_start; kh < kh_end; kh++) {
for (int kw = kw_start; kw < kw_end; kw++) {
int src_index = ih * conv_param->input_w_ * C4NUM + iw * C4NUM +
kh * input_plane * conv_param->kernel_w_ * C4NUM + kw * input_plane * C4NUM;
int dst_index = oh * conv_param->output_w_ * C4NUM + ow * C4NUM +
kh * conv_param->dilation_h_ * conv_param->output_w_ * C4NUM +
kw * conv_param->dilation_w_ * C4NUM;
for (int i = 0; i < C4NUM; i++) {
dst_ptr[dst_index + i] += src_ptr[src_index + i];
}
} /*kw*/
} /*kh*/
} /*iw*/
} /*ih*/
} /*oc4*/
PostConvFuncFp32C4(tmp_c4, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_,
conv_param->is_relu6_);
return NNACL_OK;
}

View File

@ -16,20 +16,19 @@
#ifndef MINDSPORE_LITE_NNACL_FP32_DECONV_H_
#define MINDSPORE_LITE_NNACL_FP32_DECONV_H_
#include <string.h>
#include "nnacl/pack.h"
#include "nnacl/op_base.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/fp32/strassen_matmul.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp32/common_func.h"
#ifdef __cplusplus
extern "C" {
#endif
void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane);
int DeConvPostFp32C4(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel,
int input_plane, int kernel_plane, int output_plane, ConvParameter *conv_param);
int DeConvPostFp32C8x8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param);
int DeConvPostFp32C12x8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param);
#ifdef __cplusplus
}
#endif

View File

@ -28,6 +28,18 @@ void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) {
return;
}
void RowMajor2Row12Major(float *src_ptr, float *dst_ptr, int row, int col) {
for (int r = 0; r < row; r++) {
float *src = src_ptr + r * col;
for (int c = 0; c < col; c++) {
int cd8 = c / C12NUM;
int cm8 = c % C12NUM;
dst_ptr[cd8 * C12NUM * row + r * C12NUM + cm8] = src[c];
}
}
return;
}
void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
size_t row12 = row / C12NUM * C12NUM;
size_t col4 = col / C4NUM * C4NUM;
@ -323,55 +335,9 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col
return;
}
void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
int col, int stride, bool write_nhwc) {
if (write_nhwc) {
/* col8-major * row8-major => col-major */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r8div = r / 8, r8mod = r % 8;
int c8div = c / 8, c8mod = c % 8;
size_t ci = r * stride + c;
float value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r8div * deep * 8 + d * 8 + r8mod;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
if (bias != NULL) value += bias[c];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
dst[ci] = value;
}
}
} else {
/* col8-major * row8-major => col8x8-major */
int col_8 = UP_ROUND(col, C8NUM);
int row_8 = UP_ROUND(row, C8NUM);
for (int r = 0; r < row_8; r++) {
for (int c = 0; c < col_8; c++) {
int r8div = r / 8, r8mod = r % 8;
int c8div = c / 8, c8mod = c % 8;
size_t ci = c8div * row_8 * 8 + r * 8 + c8mod;
float value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r8div * deep * 8 + d * 8 + r8mod;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
if (bias != NULL) value += bias[c];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
dst[ci] = value;
}
}
}
return;
}
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
int col, size_t stride, size_t writeNhwc, size_t writeC4) {
if (writeNhwc != 0) {
int col, int stride, int out_type) {
if (out_type == OutType_Nhwc) {
/* col8-major * row8-major => col-major */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
@ -390,24 +356,39 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
dst[ci] = value;
}
}
} else {
/* col8-major * row8-major => col12x8-major */
int col_8 = UP_ROUND(col, C8NUM);
int row_12 = UP_ROUND(row, C12NUM);
for (int r = 0; r < row_12; r++) {
for (int c = 0; c < col_8; c++) {
int r12div = r / C12NUM, r12mod = r % C12NUM;
int c8div = c / C8NUM, c8mod = c % C8NUM;
int c4div = c / C4NUM, c4mod = c % C4NUM;
size_t ci = (out_type == OutType_C4) ? (c4div * C4NUM * row_12 + r * C4NUM + c4mod)
: (c8div * C8NUM * row_12 + r * C8NUM + c8mod);
float value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
value = value + a[ai] * b[bi];
}
if (bias != NULL) value += bias[c];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
dst[ci] = value;
}
}
}
return;
}
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
int stride, bool write_nhwc) {
#ifdef ENABLE_ARM64
MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc);
#else
MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
#endif
}
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
int col, size_t stride, size_t writeNhwc, size_t writeC4) {
int col, size_t stride, int out_type) {
#ifdef ENABLE_ARM64
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, writeNhwc, writeC4);
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
(int)(out_type == OutType_C4));
#else
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, writeNhwc, writeC4);
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
#endif
}

View File

@ -26,11 +26,11 @@
#ifdef __cplusplus
extern "C" {
#endif
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col,
int stride, bool write_nhwc);
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row,
int col, size_t stride, size_t writeNhwc, size_t writeC4);
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
int col, size_t stride, int out_type);
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Row12Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
@ -38,7 +38,7 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col
void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, size_t stride, bool write_nhwc);
void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, size_t stride, size_t writeNhwc, size_t writeC4);
int col, size_t stride, size_t write_nhwc, size_t write_c4);
#endif
#ifdef __cplusplus
}

View File

@ -1,204 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/strassen_matmul.h"
bool CheckRecursion(int row, int col, int deep, int max_recursion, int cur_recursion) {
if (cur_recursion >= max_recursion) {
return false;
}
if (row % 2 != 0 || col % 2 != 0 || deep % 2 != 0) {
return false;
}
int row2 = row / 2;
int col2 = col / 2;
int deep2 = deep / 2;
float save_cost = row * col * 4 * deep * 4 * 2 + row * col * 4 -
7 * (row2 * col2 * 4 * deep2 * 4 * 2 - row2 * col2 * 4) - 4 * (row2 * deep2 * 4 * 3) -
4 * (deep2 * 4 * col2 * 4 * 3) - 7 * (row2 * col2 * 4 * 3);
return (save_cost > 0.f);
}
void GemmMatMulComm(const float *a_ptr, const float *b_ptr, float *dst_ptr, int row, int col, int deep, int b_stride,
int c_stride) {
int row4mod = row % 4;
int row4div = row / 4;
for (int r = 0; r < row; r++) {
int r4mod = r % 4;
int r4div = r / 4;
for (int c = 0; c < col * 4; c++) {
float value = 0;
int ic = c / 4 * c_stride + r * 4 + c % 4;
for (int d = 0; d < deep * 4; d++) {
int d4mod = d % 4;
int d4div = d / 4;
int a_stride = (r < (row4div * 4)) ? 4 : row4mod;
int ai = r4div * 4 * deep * 4 + d4div * a_stride * 4 + r4mod * 4 + d4mod;
int bi = c / 4 * b_stride + d * 4 + c % 4;
value = value + a_ptr[ai] * b_ptr[bi];
}
dst_ptr[ic] = value;
}
}
return;
}
void GemmMatMul(const float *a_ptr, const float *b_ptr, float *dst_ptr, int row, int col, int deep, int b_stride,
int c_stride) {
int row4mod = row % 4;
int row4div = row / 4;
if (row4div > 0) {
GemmMatMulComm(a_ptr, b_ptr, dst_ptr, row4div * 4, col, deep, b_stride, c_stride);
}
if (row4mod != 0) {
GemmMatMulComm(a_ptr + row4div * deep * 4 * 4, b_ptr, dst_ptr + row4div * 4 * 4, row4mod, col, deep, b_stride,
c_stride);
}
return;
}
int RecursionMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
int max_recursion, int cur_recursion, float *tmp_a_ptr) {
size_t row2 = matmul_param->row_ / 2;
size_t deep2 = matmul_param->deep_ / 2;
size_t col2 = matmul_param->col_ / 2;
size_t a_stride = matmul_param->a_stride_;
size_t b_stride = matmul_param->b_stride_;
size_t c_stride = matmul_param->c_stride_;
StrassenMatMulParameter rec_matmul;
rec_matmul.row_ = row2;
rec_matmul.deep_ = deep2;
rec_matmul.col_ = col2;
float *x_ptr = (float *)(malloc(row2 * MSMAX(deep2, col2) * FP32_STRASSEN_UINT * sizeof(float)));
if (x_ptr == NULL) {
return NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC;
}
float *y_ptr = (float *)(malloc(col2 * deep2 * FP32_STRASSEN_WEIGHT_UINT * sizeof(float)));
if (y_ptr == NULL) {
free(x_ptr);
return NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC;
}
size_t x_stride = row2 * FP32_STRASSEN_UINT;
size_t y_stride = deep2 * FP32_STRASSEN_WEIGHT_UINT;
const float *a11 = a_ptr;
const float *a12 = a_ptr + deep2 * a_stride;
const float *a21 = a_ptr + row2 * FP32_STRASSEN_UINT;
const float *a22 = a_ptr + deep2 * a_stride + row2 * FP32_STRASSEN_UINT;
const float *b11 = b_ptr;
const float *b12 = b_ptr + col2 * b_stride;
const float *b21 = b_ptr + deep2 * FP32_STRASSEN_WEIGHT_UINT;
const float *b22 = b_ptr + col2 * b_stride + deep2 * FP32_STRASSEN_WEIGHT_UINT;
float *c11 = c_ptr;
float *c12 = c_ptr + col2 * c_stride;
float *c21 = c_ptr + row2 * FP32_STRASSEN_UINT;
float *c22 = c_ptr + col2 * c_stride + row2 * FP32_STRASSEN_UINT;
/* S3 = A11 - A21 */
MatrixSub(a11, a21, x_ptr, a_stride, a_stride, x_stride, row2, deep2);
/* T3 = B22 - B12 */
MatrixSub(b22, b12, y_ptr, b_stride, b_stride, y_stride, deep2 * 4, col2);
/* P7 = S3T3 */
rec_matmul.a_stride_ = x_stride;
rec_matmul.b_stride_ = y_stride;
rec_matmul.c_stride_ = c_stride;
StrassenMatmul(x_ptr, y_ptr, c21, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
/* S1 = A21 + A22 */
MatrixAdd(a21, a22, x_ptr, a_stride, a_stride, x_stride, row2, deep2);
/* T1 = B12 - B11 */
MatrixSub(b12, b11, y_ptr, b_stride, b_stride, y_stride, deep2 * 4, col2);
/* P5 = S1T1 */
StrassenMatmul(x_ptr, y_ptr, c22, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
/* S2 = S1 - A11 */
MatrixSub(x_ptr, a11, x_ptr, x_stride, a_stride, x_stride, row2, deep2);
/* T2 = B22 - T1 */
MatrixSub(b22, y_ptr, y_ptr, b_stride, y_stride, y_stride, deep2 * 4, col2);
/* P6 = S2T2 */
StrassenMatmul(x_ptr, y_ptr, c12, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
/* S4 = A12 - S2 */
MatrixSub(a12, x_ptr, x_ptr, a_stride, x_stride, x_stride, row2, deep2);
/* P3 = S4B22 */
rec_matmul.b_stride_ = b_stride;
StrassenMatmul(x_ptr, b22, c11, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
/* P1 = A11B11 */
rec_matmul.a_stride_ = a_stride;
rec_matmul.c_stride_ = row2 * FP32_STRASSEN_UINT;
StrassenMatmul(a11, b11, x_ptr, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
/* U2 = P1 + P6
U3 = U2 + P7
U4 = U2 + P5
U7 = U3 + P5
U5 = U4 + P3 */
MatrixMultiAdd(c11, c12, c21, c22, x_ptr, row2, col2, c_stride, x_stride);
/* T4 = T2 - B21 */
MatrixSub(y_ptr, b21, y_ptr, y_stride, b_stride, y_stride, deep2 * 4, col2);
/* P4 = A22T4 */
rec_matmul.b_stride_ = y_stride;
rec_matmul.c_stride_ = c_stride;
StrassenMatmul(a22, y_ptr, c11, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
/* U6 = U3 - P4 */
MatrixSub(c21, c11, c21, c_stride, c_stride, c_stride, row2, col2);
/* P2 = A12B21 */
rec_matmul.b_stride_ = b_stride;
StrassenMatmul(a12, b21, c11, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
/* U1 = P1 + P2 */
MatrixAdd(x_ptr, c11, c11, x_stride, c_stride, c_stride, row2, col2);
free(x_ptr);
free(y_ptr);
return NNACL_OK;
}
int CommonMatMul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
float *tmp_a_ptr) {
MatrixPack(a_ptr, tmp_a_ptr, matmul_param->row_, matmul_param->deep_, matmul_param->a_stride_);
GemmMatMul(tmp_a_ptr, b_ptr, c_ptr, matmul_param->row_, matmul_param->col_, matmul_param->deep_,
matmul_param->b_stride_, matmul_param->c_stride_);
return NNACL_OK;
}
int StrassenMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
int max_recursion, int cur_recursion, float *tmp_a_ptr) {
if (CheckRecursion(matmul_param->row_, matmul_param->col_, matmul_param->deep_, cur_recursion, max_recursion)) {
return RecursionMatmul(a_ptr, b_ptr, c_ptr, matmul_param, max_recursion, cur_recursion, tmp_a_ptr);
}
return CommonMatMul(a_ptr, b_ptr, c_ptr, matmul_param, tmp_a_ptr);
}

View File

@ -1,45 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_STRASSEN_MATMUL_H_
#define MINDSPORE_LITE_NNACL_FP32_STRASSEN_MATMUL_H_
#include <memory.h>
#include "nnacl/pack.h"
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
#include "nnacl/strassen_matmul.h"
#include "nnacl/fp32/common_func.h"
#define FP32_STRASSEN_UINT C4NUM
#define FP32_STRASSEN_WEIGHT_UINT (C4NUM * C4NUM)
#define FP32_STRASSEN_MAX_RECURSION 5
#ifdef __cplusplus
extern "C" {
#endif
int RecursionMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
int max_recursion, int, float *tmp_a_ptr);
int CommonMatMul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *Matmul_param,
float *tmp_a_ptr);
int StrassenMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
int max_recursion, int cur_recursion, float *tmp_a_ptr);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP32_STRASSEN_MATMUL_H_

View File

@ -31,6 +31,8 @@ typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col);
typedef enum ActType { ActType_No, ActType_Relu, ActType_Relu6 } ActType;
typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_C4 = 2 } OutType;
typedef struct MatMulParameter {
OpParameter op_parameter_;
int row_;

View File

@ -1,33 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_STRASSEN_MATMUL_H_
#define MINDSPORE_LITE_NNACL_STRASSEN_MATMUL_H_
#include "nnacl/op_base.h"
/* hw*inc4 X inc4*oc4 */
typedef struct StrassenMatMulParameter {
OpParameter op_parameter;
int row_; /* h * w */
int col_; /* oc4 / 4 */
int deep_; /* inc4 / 4 */
int a_stride_; /* h * w * 4 */
int b_stride_; /* inc4 * 4 */
int c_stride_; /* h * w * 4 */
} StrassenMatMulParameter;
#endif // MINDSPORE_LITE_NNACL_STRASSEN_MATMUL_H_

View File

@ -39,6 +39,10 @@ void Convolution1x1CPUKernel::FreeTmpBuffer() {
free(pack_input_);
pack_input_ = nullptr;
}
if (pre_trans_input_ && input_ptr_ != nullptr) {
free(input_ptr_);
input_ptr_ = nullptr;
}
return;
}
@ -106,6 +110,16 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
return RET_MEMORY_FAILED;
}
memset(pack_input_, 0, matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float));
if (pre_trans_input_) {
input_ptr_ = reinterpret_cast<float *>(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(float)));
if (input_ptr_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc input_ptr_ error!";
return RET_MEMORY_FAILED;
}
memset(input_ptr_, 0, matmul_param_->row_ * matmul_param_->deep_ * sizeof(float));
}
return RET_OK;
}
@ -140,13 +154,10 @@ int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
if (cur_oc <= 0) {
return RET_OK;
}
auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id;
MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
matmul_param_->row_, cur_oc, matmul_param_->col_, 1, 0);
output_ptr_ + task_id * thread_stride_, reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id,
matmul_param_->act_type_, matmul_param_->deep_, matmul_param_->row_, cur_oc, matmul_param_->col_,
OutType_Nhwc);
return RET_OK;
}
@ -169,15 +180,6 @@ int Convolution1x1CPUKernel::Run() {
auto src_in = reinterpret_cast<float *>(in_tensors_[0]->Data());
auto src_out = reinterpret_cast<float *>(out_tensors_[0]->Data());
if (pre_trans_input_) {
input_ptr_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(float)));
if (input_ptr_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc input_ptr_ error!";
return RET_MEMORY_FAILED;
}
}
for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
Pre1x1Trans(src_in + batch_index * conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_,
src_out + batch_index * matmul_param_->row_ * matmul_param_->col_);
@ -189,10 +191,6 @@ int Convolution1x1CPUKernel::Run() {
}
}
if (pre_trans_input_) {
ctx_->allocator->Free(input_ptr_);
input_ptr_ = nullptr;
}
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -95,13 +95,13 @@ int DeConvolutionCPUKernel::InitParam() {
matmul_param_->row_ = input_plane_;
matmul_param_->deep_ = conv_param_->input_channel_;
matmul_param_->col_ = conv_param_->output_channel_ * kernel_plane_;
matmul_param_->row_8_ = UP_ROUND(matmul_param_->row_, C8NUM);
matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM);
matmul_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * kernel_plane_;
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(conv_param_->output_channel_, C8NUM));
thread_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C8NUM), thread_count_);
pack_input_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float)));
pack_input_ = reinterpret_cast<float *>(malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float)));
if (pack_input_ == nullptr) {
MS_LOG(ERROR) << "deconv Malloc pack_input_ error!";
return RET_ERROR;
@ -126,14 +126,14 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) {
return RET_OK;
}
auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_;
MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, tmp_buffer,
nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_,
matmul_param_->col_, false);
auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_12_;
MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_12_, oc * C8NUM * kernel_plane_,
matmul_param_->col_, OutType_C8);
DeConvPostFp32C8x8(tmp_buffer, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_,
reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM,
output_ptr_ + task_id * thread_stride_ * C8NUM, oc_res, conv_param_);
DeConvPostFp32C12x8(tmp_buffer, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_,
reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM,
output_ptr_ + task_id * thread_stride_ * C8NUM, oc_res, conv_param_);
return RET_OK;
}
@ -165,7 +165,7 @@ int DeConvolutionCPUKernel::InitRunBuf() {
}
tmp_buffer_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float)));
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->col_8_ * sizeof(float)));
if (tmp_buffer_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc tmp_buffer_ error!";
return RET_NULL_PTR;
@ -192,7 +192,7 @@ int DeConvolutionCPUKernel::Run() {
input_ptr_ = src_in + batch_index * input_plane_ * conv_param_->input_channel_;
output_ptr_ = src_out + batch_index * output_plane_ * conv_param_->output_channel_;
RowMajor2Col8Major(input_ptr_, pack_input_, input_plane_, conv_param_->input_channel_);
RowMajor2Col12Major(input_ptr_, pack_input_, input_plane_, conv_param_->input_channel_);
error_code = LiteBackendParallelLaunch(DeConvFp32Run, this, thread_count_);
if (error_code != RET_OK) {

View File

@ -27,18 +27,14 @@ FullconnectionCPUKernel::~FullconnectionCPUKernel() {
}
void FullconnectionCPUKernel::FreeBuf() {
if (a_c8_ptr_ != nullptr) {
free(a_c8_ptr_);
a_c8_ptr_ = nullptr;
if (a_c12_ptr_ != nullptr) {
free(a_c12_ptr_);
a_c12_ptr_ = nullptr;
}
if (b_r8_ptr_ != nullptr) {
free(b_r8_ptr_);
b_r8_ptr_ = nullptr;
}
if (c_r8x8_ptr_ != nullptr) {
free(c_r8x8_ptr_);
c_r8x8_ptr_ = nullptr;
}
if (bias_ptr_ != nullptr) {
free(bias_ptr_);
bias_ptr_ = nullptr;
@ -51,8 +47,8 @@ int FullconnectionCPUKernel::ReSize() {
fc_param_->col_ = (in_tensors_[1]->shape())[0];
fc_param_->deep_ = (in_tensors_[1]->shape())[1];
fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8);
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8);
fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM);
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM);
thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_);
@ -63,11 +59,11 @@ int FullconnectionCPUKernel::ReSize() {
memcpy(bias_ptr_, in_tensors_[2]->Data(), fc_param_->col_ * sizeof(float));
}
a_c8_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_8_ * fc_param_->deep_ * sizeof(float)));
if (a_c8_ptr_ == nullptr) {
a_c12_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_12_ * fc_param_->deep_ * sizeof(float)));
if (a_c12_ptr_ == nullptr) {
return RET_MEMORY_FAILED;
}
memset(a_c8_ptr_, 0, fc_param_->row_8_ * fc_param_->deep_ * sizeof(float));
memset(a_c12_ptr_, 0, fc_param_->row_12_ * fc_param_->deep_ * sizeof(float));
b_r8_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(float)));
if (b_r8_ptr_ == nullptr) {
@ -76,16 +72,9 @@ int FullconnectionCPUKernel::ReSize() {
}
memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(float));
c_r8x8_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float)));
if (c_r8x8_ptr_ == nullptr) {
FreeBuf();
return RET_MEMORY_FAILED;
}
memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float));
fc_param_->a_const_ = false;
fc_param_->b_const_ = false;
InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->Data()), a_c8_ptr_);
InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->Data()), a_c12_ptr_);
InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->Data()), b_r8_ptr_);
return RET_OK;
}
@ -105,7 +94,7 @@ void FullconnectionCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) {
return;
}
fc_param_->a_const_ = true;
RowMajor2Col8Major(src_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_);
RowMajor2Col12Major(src_ptr, a_c12_ptr_, fc_param_->row_, fc_param_->deep_);
return;
}
@ -132,15 +121,14 @@ int FcFp32MatmulRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
}
int FullconnectionCPUKernel::DoMatmul(int task_id) {
int cur_oc = MSMIN(thread_stride_, UP_DIV(fc_param_->col_8_, 8) - task_id * thread_stride_);
int cur_oc = MSMIN(thread_stride_ * C8NUM, fc_param_->col_ - task_id * thread_stride_ * C8NUM);
if (cur_oc <= 0) {
return RET_OK;
}
MatMul(a_c8_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_,
c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_,
bias_ptr_ + task_id * thread_stride_ * C8NUM, fc_param_->act_type_, fc_param_->deep_, fc_param_->row_8_,
cur_oc * 8, 0, false);
MatMulOpt(a_c12_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_,
c_r_ptr + task_id * thread_stride_ * C8NUM, bias_ptr_ + task_id * thread_stride_ * C8NUM,
fc_param_->act_type_, fc_param_->deep_, fc_param_->row_, cur_oc, fc_param_->col_, OutType_Nhwc);
return RET_OK;
}
@ -152,14 +140,13 @@ int FullconnectionCPUKernel::Run() {
}
auto a_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->Data());
auto b_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->Data());
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
c_r_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
InitMatrixA(a_ptr, a_c8_ptr_);
InitMatrixA(a_ptr, a_c12_ptr_);
InitMatrixB(b_ptr, b_r8_ptr_);
LiteBackendParallelLaunch(FcFp32MatmulRun, this, thread_count_);
Row8x8Major2RowMajor(c_r8x8_ptr_, output_ptr, fc_param_->row_, fc_param_->col_, fc_param_->col_);
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -47,9 +47,9 @@ class FullconnectionCPUKernel : public FullconnectionBaseCPUKernel {
void InitMatrixB(float *src_ptr, float *dst_ptr);
private:
float *a_c8_ptr_ = nullptr;
float *a_c12_ptr_ = nullptr;
float *b_r8_ptr_ = nullptr;
float *c_r8x8_ptr_ = nullptr;
float *c_r_ptr = nullptr;
float *bias_ptr_ = nullptr;
};
} // namespace mindspore::kernel

View File

@ -28,18 +28,14 @@ namespace mindspore::kernel {
MatmulCPUKernel::~MatmulCPUKernel() { FreeTmpBuffer(); }
void MatmulCPUKernel::FreeTmpBuffer() {
if (a_c8_ptr_ != nullptr) {
ctx_->allocator->Free(a_c8_ptr_);
a_c8_ptr_ = nullptr;
if (a_c12_ptr_ != nullptr) {
ctx_->allocator->Free(a_c12_ptr_);
a_c12_ptr_ = nullptr;
}
if (b_r8_ptr_ != nullptr) {
ctx_->allocator->Free(b_r8_ptr_);
b_r8_ptr_ = nullptr;
}
if (c_r8x8_ptr_ != nullptr) {
ctx_->allocator->Free(c_r8x8_ptr_);
c_r8x8_ptr_ = nullptr;
}
if (bias_ptr_ != nullptr) {
ctx_->allocator->Free(bias_ptr_);
bias_ptr_ = nullptr;
@ -66,45 +62,37 @@ int MatmulCPUKernel::ReSize() {
params_->row_ = c_shape[c_shape.size() - 2];
params_->col_ = c_shape[c_shape.size() - 1];
params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
params_->row_8_ = UP_ROUND(params_->row_, 8);
params_->row_12_ = UP_ROUND(params_->row_, C12NUM);
params_->col_8_ = UP_ROUND(params_->col_, 8);
thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
a_c8_ptr_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(params_->row_8_ * params_->deep_ * sizeof(float)));
if (a_c8_ptr_ == nullptr) {
a_c12_ptr_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(params_->row_12_ * params_->deep_ * sizeof(float)));
if (a_c12_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
memset(a_c8_ptr_, 0, params_->row_8_ * params_->deep_ * sizeof(float));
memset(a_c12_ptr_, 0, params_->row_12_ * params_->deep_ * sizeof(float));
b_r8_ptr_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(params_->col_8_ * params_->deep_ * sizeof(float)));
if (b_r8_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
memset(b_r8_ptr_, 0, params_->col_8_ * params_->deep_ * sizeof(float));
c_r8x8_ptr_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(params_->row_8_ * params_->col_8_ * sizeof(float)));
if (c_r8x8_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
memset(c_r8x8_ptr_, 0, params_->row_8_ * params_->col_8_ * sizeof(float));
params_->a_const_ = false;
params_->b_const_ = false;
InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->Data()), a_c8_ptr_);
InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->Data()), a_c12_ptr_);
InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->Data()), b_r8_ptr_);
bias_ptr_ = reinterpret_cast<float *>(malloc(params_->col_8_ * sizeof(float)));
if (bias_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
memset(bias_ptr_, 0, params_->col_8_ * sizeof(float));
if (in_tensors_.size() == 3) {
bias_ptr_ = reinterpret_cast<float *>(malloc(params_->col_8_ * sizeof(float)));
if (bias_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
memset(bias_ptr_, 0, params_->col_8_ * sizeof(float));
memcpy(bias_ptr_, in_tensors_[2]->Data(), params_->col_ * sizeof(float));
} else {
bias_ptr_ = nullptr;
}
return RET_OK;
@ -120,9 +108,9 @@ void MatmulCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) {
params_->a_const_ = true;
if (params_->a_transpose_) {
RowMajor2Row8Major(src_ptr, dst_ptr, params_->deep_, params_->row_);
RowMajor2Row12Major(src_ptr, dst_ptr, params_->deep_, params_->row_);
} else {
RowMajor2Col8Major(src_ptr, a_c8_ptr_, params_->row_, params_->deep_);
RowMajor2Col12Major(src_ptr, dst_ptr, params_->row_, params_->deep_);
}
return;
}
@ -152,18 +140,13 @@ int MatmulCPUKernel::Init() {
}
int MatmulCPUKernel::RunImpl(int task_id) {
int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_8_, 8) - task_id * thread_stride_);
int cur_oc = MSMIN(thread_stride_ * C8NUM, params_->col_ - task_id * thread_stride_ * C8NUM);
if (cur_oc <= 0) {
return RET_OK;
}
auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_;
auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_;
if (bias_ptr_) {
auto cur_bias = bias_ptr_ + task_id * thread_stride_ * C8NUM;
MatMul(a_c8_ptr_, cur_b, cur_c, cur_bias, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8, 0, false);
} else {
MatMul(a_c8_ptr_, cur_b, cur_c, NULL, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8, 0, false);
}
MatMulOpt(a_c12_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_,
c_r_ptr_ + task_id * thread_stride_ * C8NUM, bias_ptr_ + task_id * thread_stride_ * C8NUM, ActType_No,
params_->deep_, params_->row_, cur_oc, params_->col_, OutType_Nhwc);
return RET_OK;
}
@ -192,13 +175,12 @@ int MatmulCPUKernel::Run() {
for (int i = 0; i < params_->batch; ++i) {
auto cur_a_ptr = a_ptr + i * a_stride;
auto cur_b_ptr = b_ptr + i * b_stride;
auto cur_c_ptr = c_ptr + i * c_stride;
c_r_ptr_ = c_ptr + i * c_stride;
InitMatrixA(cur_a_ptr, a_c8_ptr_);
InitMatrixA(cur_a_ptr, a_c12_ptr_);
InitMatrixB(cur_b_ptr, b_r8_ptr_);
LiteBackendParallelLaunch(MatmulFloatRun, this, thread_count_);
Row8x8Major2RowMajor(c_r8x8_ptr_, cur_c_ptr, params_->row_, params_->col_, params_->col_);
}
return RET_OK;
}

View File

@ -41,9 +41,9 @@ class MatmulCPUKernel : public MatmulBaseCPUKernel {
void FreeTmpBuffer();
private:
float *a_c8_ptr_ = nullptr;
float *a_c12_ptr_ = nullptr;
float *b_r8_ptr_ = nullptr;
float *c_r8x8_ptr_ = nullptr;
float *c_r_ptr_ = nullptr;
float *bias_ptr_ = nullptr;
};
} // namespace mindspore::kernel

View File

@ -19,9 +19,8 @@
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "src/common/file_utils.h"
#include "src/runtime/kernel/arm/fp32/convolution_1x1.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/strassen_matmul.h"
#include "src/runtime/kernel/arm/fp32/convolution_1x1.h"
namespace mindspore {
using mindspore::lite::tensor::Tensor;

View File

@ -548,14 +548,14 @@ TEST_F(TestDeConvolutionFp32, DeConvTest2) {
float *correct;
int total_size = DeConvTestInit2(&inputs_, &outputs_, deconv_param, &correct);
lite::Context *ctx = new lite::Context;
ctx->thread_num_ = 4;
ctx->thread_num_ = 1;
kernel::DeConvolutionCPUKernel *deconv =
new kernel::DeConvolutionCPUKernel(reinterpret_cast<OpParameter *>(deconv_param), inputs_, outputs_, ctx, nullptr);
deconv->Init();
deconv->Run();
EXPECT_EQ(0, lite::CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size));
delete deconv_param;
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
delete deconv;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
@ -635,7 +635,6 @@ TEST_F(TestDeConvolutionFp32, DeConvTest3) {
deconv->Run();
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
delete deconv_param;
delete deconv;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
@ -723,7 +722,6 @@ TEST_F(TestDeConvolutionFp32, DeConvTest4) {
uint64_t time_avg = cost / loop_count;
printf("deconv fp32 average time : %f ms\n", time_avg / 1000.0f);
delete deconv_param;
delete deconv;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;

View File

@ -1,369 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "src/common/file_utils.h"
#include "mindspore/lite/nnacl/pack.h"
#include "mindspore/lite/nnacl/fp32/strassen_matmul.h"
#include "mindspore/lite/nnacl/conv_parameter.h"
namespace mindspore {
class TestStrassenFp32 : public mindspore::CommonTest {
public:
TestStrassenFp32() {}
};
TEST_F(TestStrassenFp32, MatrixAdd1) {
float a[] = {0.06796285, 0.6176181, 0.33195993, 0.2752791, 0.36864007, 0.04605605, 0.33899087, 0.9820137,
0.49804246, 0.8242412, 0.8458231, 0.6530539, 0.6336898, 0.8367749, 0.57166654, 0.25895607,
0.90079665, 0.10585558, 0.8215811, 0.48977906, 0.7895138, 0.41816455, 0.18999523, 0.28736928,
0.5882977, 0.44262612, 0.65245426, 0.7834421, 0.60903394, 0.82289135, 0.03855767, 0.30543327,
0.37747085, 0, 0, 0, 0.590335, 0, 0, 0,
0.7578682, 0, 0, 0, 0.81001425, 0, 0, 0,
0.9487712, 0, 0, 0, 0.11742989, 0, 0, 0,
0.60004807, 0, 0, 0, 0.05973052, 0, 0, 0};
float b[] = {0.112120815, 0.6869974, 0.08290442, 0.43003577, 0.044390075, 0.23077105, 0.23964432, 0.4426781,
0.6612115, 0.14988606, 0.84881437, 0.032587975, 0.35028255, 0.41838303, 0.12859282, 0.060378596,
0.8272769, 0.6949804, 0.9120368, 0.12399232, 0.9292184, 0.7566025, 0.10235854, 0.015936268,
0.20426726, 0.9926392, 0.54714125, 0.7022856, 0.58746314, 0.95714045, 0.26433542, 0.9030878,
0.8596953, 0, 0, 0, 0.8341476, 0, 0, 0,
0.72301114, 0, 0, 0, 0.40733734, 0, 0, 0,
0.2873559, 0, 0, 0, 0.612321, 0, 0, 0,
0.5008707, 0, 0, 0, 0.2586266, 0, 0, 0};
float add[] = {0.18008366, 1.3046155, 0.41486436, 0.7053149, 0.41303015, 0.2768271, 0.5786352, 1.4246918,
1.159254, 0.9741273, 1.6946375, 0.6856419, 0.9839724, 1.255158, 0.7002593, 0.3193347,
1.7280736, 0.80083597, 1.7336179, 0.6137714, 1.7187322, 1.174767, 0.29235378, 0.30330554,
0.792565, 1.4352653, 1.1995955, 1.4857277, 1.1964971, 1.7800318, 0.3028931, 1.2085211,
1.2371662, 0, 0, 0, 1.4244826, 0, 0, 0,
1.4808793, 0, 0, 0, 1.2173516, 0, 0, 0,
1.2361271, 0, 0, 0, 0.72975093, 0, 0, 0,
1.1009188, 0, 0, 0, 0.31835714, 0, 0, 0};
float out[64] = {0};
MatrixAdd(a, b, out, 32, 32, 32, 8, 2);
EXPECT_EQ(0, lite::CompareOutputData(out, add, 64));
}
TEST_F(TestStrassenFp32, MatrixAdd2) {
float a[] = {0.06796285, 0.6176181, 0.33195993, 0.2752791, 0.36864007, 0.04605605, 0.33899087, 0.9820137,
0.49804246, 0.8242412, 0.8458231, 0.6530539, 0.6336898, 0.8367749, 0.57166654, 0.25895607,
0.90079665, 0.10585558, 0.8215811, 0.48977906, 0.7895138, 0.41816455, 0.18999523, 0.28736928,
0.5882977, 0.44262612, 0.65245426, 0.7834421, 0.60903394, 0.82289135, 0.03855767, 0.30543327,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0.37747085, 0, 0, 0,
0.590335, 0, 0, 0, 0.7578682, 0, 0, 0,
0.81001425, 0, 0, 0, 0.9487712, 0, 0, 0,
0.11742989, 0, 0, 0, 0.60004807, 0, 0, 0,
0.05973052, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
float b[] = {0.112120815, 0.6869974, 0.08290442, 0.43003577, 0.044390075, 0.23077105, 0.23964432, 0.4426781,
0.6612115, 0.14988606, 0.84881437, 0.032587975, 0.35028255, 0.41838303, 0.12859282, 0.060378596,
0.8272769, 0.6949804, 0.9120368, 0.12399232, 0.9292184, 0.7566025, 0.10235854, 0.015936268,
0.20426726, 0.9926392, 0.54714125, 0.7022856, 0.58746314, 0.95714045, 0.26433542, 0.9030878,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0.8596953, 0, 0, 0, 0.8341476, 0, 0, 0,
0.72301114, 0, 0, 0, 0.40733734, 0, 0, 0,
0.2873559, 0, 0, 0, 0.612321, 0, 0, 0,
0.5008707, 0, 0, 0, 0.2586266, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
float add[] = {0.18008366, 1.3046155, 0.41486436, 0.7053149, 0.41303015, 0.2768271, 0.5786352, 1.4246918,
1.159254, 0.9741273, 1.6946375, 0.6856419, 0.9839724, 1.255158, 0.7002593, 0.3193347,
1.7280736, 0.80083597, 1.7336179, 0.6137714, 1.7187322, 1.174767, 0.29235378, 0.30330554,
0.792565, 1.4352653, 1.1995955, 1.4857277, 1.1964971, 1.7800318, 0.3028931, 1.2085211,
0, 0, 0, 0, 1.2371662, 0, 0, 0,
1.4244826, 0, 0, 0, 1.4808793, 0, 0, 0,
1.2173516, 0, 0, 0, 1.2361271, 0, 0, 0,
0.72975093, 0, 0, 0, 1.1009188, 0, 0, 0,
0.31835714, 0, 0, 0, 0, 0, 0, 0};
float out[72] = {0};
MatrixAdd(a, b, out, 44, 56, 36, 8, 2);
EXPECT_EQ(0, lite::CompareOutputData(out, add, 72));
}
TEST_F(TestStrassenFp32, MatrixSub1) {
float a[] = {0.4160896, 0.55011475, 0.60395557, 0.964036, 0.8010256, 0.908257, 0.60170764, 0.008877548,
0.4973592, 0.6104505, 0.2957374, 0.39589414, 0.0151615525, 0.45663023, 0.3815148, 0.6419536,
0.9118046, 0.5312479, 0.104496025, 0.5972911, 0.9671534, 0.7195669, 0.23360363, 0.22078007,
0.31118092, 0.7438336, 0.5592656, 0.7212792, 0.97856164, 0.26012093, 0.18205991, 0.90656054,
0.24593723, 0, 0, 0, 0.5024593, 0, 0, 0,
0.42271087, 0, 0, 0, 0.48668534, 0, 0, 0,
0.4374295, 0, 0, 0, 0.22822042, 0, 0, 0,
0.88180095, 0, 0, 0, 0.7505223, 0, 0, 0};
float b[] = {0.14911577, 0.63214976, 0.74834836, 0.36854064, 0.5801671, 0.24166176, 0.64528674, 0.04887214,
0.23637155, 0.34321627, 0.69035923, 0.6114065, 0.73006815, 0.575073, 0.88130534, 0.72951907,
0.17092401, 0.652334, 0.6288812, 0.62121505, 0.12793411, 0.16503152, 0.7564361, 0.51976234,
0.19353953, 0.5795124, 0.6671185, 0.10646773, 0.13608798, 0.37959677, 0.24294423, 0.1790138,
0.85054415, 0, 0, 0, 0.18541782, 0, 0, 0,
0.72714496, 0, 0, 0, 0.43221787, 0, 0, 0,
0.7200413, 0, 0, 0, 0.15780604, 0, 0, 0,
0.30473796, 0, 0, 0, 0.37719592, 0, 0, 0};
float s[] = {0.26697382, -0.082035, -0.14439279, 0.59549534, 0.22085851, 0.6665952, -0.0435791, -0.03999459,
0.26098764, 0.26723424, -0.39462185, -0.21551237, -0.7149066, -0.11844277, -0.49979055, -0.08756548,
0.7408806, -0.12108606, -0.5243852, -0.02392393, 0.8392193, 0.5545354, -0.5228325, -0.29898226,
0.11764139, 0.16432118, -0.10785288, 0.6148115, 0.8424736, -0.11947584, -0.06088431, 0.72754675,
-0.6046069, 0., 0., 0., 0.31704146, 0., 0., 0.,
-0.3044341, 0., 0., 0., 0.05446747, 0., 0., 0.,
-0.2826118, 0., 0., 0., 0.07041438, 0., 0., 0.,
0.57706296, 0., 0., 0., 0.3733264, 0., 0., 0.};
float out[64] = {0};
MatrixSub(a, b, out, 32, 32, 32, 8, 2);
EXPECT_EQ(0, lite::CompareOutputData(out, s, 64));
}
TEST_F(TestStrassenFp32, MatrixSub2) {
float a[] = {0.4160896, 0.55011475, 0.60395557, 0.964036, 0.8010256, 0.908257, 0.60170764, 0.008877548,
0.4973592, 0.6104505, 0.2957374, 0.39589414, 0.0151615525, 0.45663023, 0.3815148, 0.6419536,
0.9118046, 0.5312479, 0.104496025, 0.5972911, 0.9671534, 0.7195669, 0.23360363, 0.22078007,
0.31118092, 0.7438336, 0.5592656, 0.7212792, 0.97856164, 0.26012093, 0.18205991, 0.90656054,
0.24593723, 0, 0, 0, 0.5024593, 0, 0, 0,
0.42271087, 0, 0, 0, 0.48668534, 0, 0, 0,
0.4374295, 0, 0, 0, 0.22822042, 0, 0, 0,
0.88180095, 0, 0, 0, 0.7505223, 0, 0, 0};
float b[] = {0.14911577, 0.63214976, 0.74834836, 0.36854064, 0.5801671, 0.24166176, 0.64528674, 0.04887214,
0.23637155, 0.34321627, 0.69035923, 0.6114065, 0.73006815, 0.575073, 0.88130534, 0.72951907,
0.17092401, 0.652334, 0.6288812, 0.62121505, 0.12793411, 0.16503152, 0.7564361, 0.51976234,
0.19353953, 0.5795124, 0.6671185, 0.10646773, 0.13608798, 0.37959677, 0.24294423, 0.1790138,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0.85054415, 0, 0, 0,
0.18541782, 0, 0, 0, 0.72714496, 0, 0, 0,
0.43221787, 0, 0, 0, 0.7200413, 0, 0, 0,
0.15780604, 0, 0, 0, 0.30473796, 0, 0, 0,
0.37719592, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
float s[] = {0.26697382, -0.082035, -0.14439279, 0.59549534, 0.22085851, 0.6665952, -0.0435791, -0.03999459,
0.26098764, 0.26723424, -0.39462185, -0.21551237, -0.7149066, -0.11844277, -0.49979055, -0.08756548,
0.7408806, -0.12108606, -0.5243852, -0.02392393, 0.8392193, 0.5545354, -0.5228325, -0.29898226,
0.11764139, 0.16432118, -0.10785288, 0.6148115, 0.8424736, -0.11947584, -0.06088431, 0.72754675,
0, 0, 0, 0, -0.6046069, 0., 0., 0.,
0.31704146, 0., 0., 0., -0.3044341, 0., 0., 0.,
0.05446747, 0., 0., 0., -0.2826118, 0., 0., 0.,
0.07041438, 0., 0., 0., 0.57706296, 0., 0., 0.,
0.3733264, 0., 0., 0, 0, 0, 0, 0.};
float out[72] = {0};
MatrixSub(a, b, out, 32, 44, 36, 8, 2);
EXPECT_EQ(0, lite::CompareOutputData(out, s, 72));
}
TEST_F(TestStrassenFp32, MatrixPack1) {
float in[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36,
-0.784, 37.925, -0.081, 6.1298, 37.998, 13.719, 11.029, 1.7127, 9.0560, 14.988, 3.1866, 0.0562,
14.530, -14.10, -8.115, -8.071, 19.250, 17.923, 13.584, 3.3293, -1.514, -0.293, 18.686, 0.0873,
19.899, 8.5562, 0.0, 0.0, 9.5542, 18.974, 0.0, 0.0, 15.370, 4.3049, 0.0, 0.0,
0.6721, -1.517, 0.0, 0.0, -1.770, 41.903, 0.0, 0.0, 8.1381, 9.1391, 0.0, 0.0,
-8.158, 7.7566, 0.0, 0.0, 9.7341, 18.834, 0.0, 0.0, 4.2010, -2.253, 0.0, 0.0};
float correct[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36,
-0.784, 37.925, -0.081, 6.1298, 19.899, 8.5562, 0.0, 0.0, 9.5542, 18.974, 0.0, 0.0,
15.370, 4.3049, 0.0, 0.0, 0.6721, -1.517, 0.0, 0.0, 37.998, 13.719, 11.029, 1.7127,
9.0560, 14.988, 3.1866, 0.0562, 14.530, -14.10, -8.115, -8.071, -1.770, 41.903, 0.0, 0.0,
8.1381, 9.1391, 0.0, 0.0, -8.158, 7.7566, 0.0, 0.0};
float out[56] = {0};
MatrixPack(in, out, 7, 2, 36);
EXPECT_EQ(0, lite::CompareOutputData(out, correct, 56));
}
TEST_F(TestStrassenFp32, MatrixPack2) {
float in[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36,
-0.784, 37.925, -0.081, 6.1298, 37.998, 13.719, 11.029, 1.7127, 9.0560, 14.988, 3.1866, 0.0562,
14.530, -14.10, -8.115, -8.071, 19.250, 17.923, 13.584, 3.3293, -1.514, -0.293, 18.686, 0.0873,
19.899, 8.5562, 0.0, 0.0, 9.5542, 18.974, 0.0, 0.0, 15.370, 4.3049, 0.0, 0.0,
0.6721, -1.517, 0.0, 0.0, -1.770, 41.903, 0.0, 0.0, 8.1381, 9.1391, 0.0, 0.0,
-8.158, 7.7566, 0.0, 0.0, 9.7341, 18.834, 0.0, 0.0, 4.2010, -2.253, 0.0, 0.0};
float correct[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36,
-0.784, 37.925, -0.081, 6.1298, 19.899, 8.5562, 0.0, 0.0, 9.5542, 18.974, 0.0, 0.0,
15.370, 4.3049, 0.0, 0.0, 0.6721, -1.517, 0.0, 0.0, 37.998, 13.719, 11.029, 1.7127,
9.0560, 14.988, 3.1866, 0.0562, 14.530, -14.10, -8.115, -8.071, 19.250, 17.923, 13.584, 3.3293,
-1.770, 41.903, 0.0, 0.0, 8.1381, 9.1391, 0.0, 0.0, -8.158, 7.7566, 0.0, 0.0,
9.7341, 18.834, 0.0, 0.0, -1.514, -0.293, 18.686, 0.0873, 4.2010, -2.253, 0.0, 0.0};
float out[72] = {0};
MatrixPack(in, out, 9, 2, 36);
EXPECT_EQ(0, lite::CompareOutputData(out, correct, 72));
}
TEST_F(TestStrassenFp32, CommonMatmul1) {
float a_ptr[] = {7.756654, 19.250782, 17.923292, 0, 13.584222, 3.3293908, 9.734102, 0,
18.83455, -1.51425, -0.29382, 0, 18.686155, 0.0873076, 4.2010098, 0,
-2.2539594, 4.1795673, 13.14235, 0, -3.59393, 16.50578, 19.899279, 0,
8.556229, 19.969376, -6.2355065, 0, -2.380469, -9.027744, 9.5542, 0};
float b_ptr[] = {0.2674241, 0.089372, -0.081915, 2.0580146, -0.295045, 1.377944, 0.703658, 1.055378,
1.204049, -0.256505, -0.309640, 0.560465, 0, 0, 0, 0,
0.646906, 0, 0, 0, -0.168206, 0, 0, 0,
-0.95630, 0, 0, 0, 0, 0, 0, 0};
float correct[] = {17.97499, 22.622334, 7.360805, 46.325558, 14.37076, 3.304931, -1.784072, 36.925926,
5.129812, -0.3278886, -2.517368, 36.99899, 10.029593, 0.7127603, -2.77004, 40.90305,
13.988123, 2.186689, -0.943787, 7.138184, 18.128653, 17.31859, 5.7472067, 21.176342,
-11.11159, 29.880829, 15.281498, 35.1893, 13.530734, -15.10318, -9.11581, -9.071925,
-15.36046, 0, 0, 0, -1.081104, 0, 0, 0,
12.719885, 0, 0, 0, 8.056052, 0, 0, 0,
-14.72927, 0, 0, 0, -24.1311, 0, 0, 0,
8.139168, 0, 0, 0, -9.158176, 0, 0, 0};
StrassenMatMulParameter *matmul_param = new StrassenMatMulParameter();
matmul_param->row_ = 8;
matmul_param->deep_ = 1;
matmul_param->col_ = 2;
matmul_param->a_stride_ = 32;
matmul_param->b_stride_ = 16;
matmul_param->c_stride_ = 32;
float c_ptr[64] = {0};
float tmp_ptr[32];
CommonMatMul(a_ptr, b_ptr, c_ptr, matmul_param, tmp_ptr);
EXPECT_EQ(0, lite::CompareOutputData(c_ptr, correct, 64));
delete matmul_param;
}
TEST_F(TestStrassenFp32, CommonMatmul2) {
StrassenMatMulParameter *matmul_param = new StrassenMatMulParameter();
float a[] = {4.864725, 6.830073, 0.76780415, 8.922394, 5.096872, 2.4946148, 4.2148714, 1.7762588, 0.89195687,
9.703938, 2.0654619, 9.048538, 2.358036, 5.643526, 2.5152204, 3.512572, 3.7913973, 3.7136157,
8.820186, 1.5324963, 3.135459, 7.5792265, 7.1820426, 0.267987, 8.737802, 4.064117, 2.7232447,
0.27355433, 0, 0, 0, 0, 0, 0, 0, 0,
6.320409, 9.479354, 0, 0, 1.6220464, 0.57753897, 0, 0, 9.786372,
6.0404425, 0, 0, 2.1067812, 4.8034563, 0, 0, 2.1140356, 8.204062,
0, 0, 3.29985, 1.2034118, 0, 0, 7.6059656, 4.162436, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0};
float b[] = {
4.4558744, 0.6383263, 0.05037839, 9.730914, 8.1542015, 4.3625517, 8.654026, 3.805875, 9.845131, 4.08051,
9.667656, 7.73955, 9.283867, 8.465257, 2.292051, 9.853942, 0.13320169, 3.8789113, 9.460265, 4.2616735,
0.23831692, 4.420147, 0.5355651, 7.829217, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1.9866786, 0, 0, 0, 6.0188327, 0,
0, 0, 6.6249146, 0, 0, 0, 3.5639563, 0, 0, 0,
0.14810833, 0, 0, 0, 7.4168983, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
float c[] = {170.86482, 177.98166, 152.0957, 268.3473, 101.39282, 55.216248, 82.31873, 120.65008, 190.18558,
192.58974, 220.54767, 239.75931, 115.32386, 95.52758, 103.82857, 145.08948, 150.4757, 112.04814,
145.50496, 207.63342, 149.6962, 84.76027, 167.65851, 141.06763, 103.42963, 84.63687, 136.74927,
189.26935, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 158.90288, 0, 0, 0, 63.917973,
0, 0, 0, 152.3613, 0, 0, 0, 103.77265, 0,
0, 0, 154.94044, 0, 0, 0, 109.79707, 0, 0,
0, 92.83551, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
matmul_param->row_ = 7;
matmul_param->deep_ = 2;
matmul_param->col_ = 2;
matmul_param->a_stride_ = 36;
matmul_param->b_stride_ = 64;
matmul_param->c_stride_ = 40;
float out[80] = {0};
float tmp_ptr[1000];
CommonMatMul(a, b, out, matmul_param, tmp_ptr);
EXPECT_EQ(0, lite::CompareOutputData(out, c, 80));
delete (matmul_param);
}
TEST_F(TestStrassenFp32, RecMatmul1) {
StrassenMatMulParameter *matmul_param = new StrassenMatMulParameter();
matmul_param->row_ = 4;
matmul_param->deep_ = 2;
matmul_param->col_ = 2;
matmul_param->a_stride_ = 16;
matmul_param->b_stride_ = 32;
matmul_param->c_stride_ = 16;
float a[] = {9.02165, 8.657163, 0.56371903, 0.7272156, 1.6258951, 9.919627, 7.47593, 3.5311592,
8.958062, 0.55338514, 9.611276, 7.429841, 8.23804, 3.7503464, 1.2829816, 6.4470887,
4.303486, 6.282502, 0, 0, 9.4194765, 7.8199654, 0, 0,
6.738705, 7.5398073, 0, 0, 0.47684374, 0.87746763, 0, 0};
float b[] = {1.8100919, 6.016964, 5.733568, 5.768448, 2.2823029, 2.173359, 0.56861514, 7.134393,
0.26377398, 3.9010656, 4.868408, 0.33401546, 1.7973539, 8.21896, 5.62239, 8.54786,
0.97356945, 1.0714527, 6.447588, 6.161091, 3.332229, 2.8775468, 6.558747, 2.6986659,
0, 0, 0, 0, 0, 0, 0, 0,
1.9830805, 0, 0, 0, 8.44718, 0, 0, 0,
9.360418, 0, 0, 0, 6.220693, 0, 0, 0,
1.8369701, 0, 0, 0, 4.3965054, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
float c[] = {62.668518, 103.9633, 132.43439, 163.67749, 69.12974, 122.12326, 183.23413, 191.96806,
65.052124, 182.57918, 233.14148, 184.20694, 38.785316, 118.74806, 100.689575, 135.12036,
136.34613, 0, 0, 0, 230.64507, 0, 0, 0,
204.15103, 0, 0, 0, 104.86488, 0, 0, 0};
float out[32] = {0};
float tmp_ptr[1000];
RecursionMatmul(a, b, out, matmul_param, 1, 0, tmp_ptr);
EXPECT_EQ(0, lite::CompareOutputData(out, c, 32));
delete (matmul_param);
}
TEST_F(TestStrassenFp32, RecMatmul2) {
StrassenMatMulParameter *matmul_param = new StrassenMatMulParameter();
matmul_param->row_ = 4;
matmul_param->deep_ = 2;
matmul_param->col_ = 2;
matmul_param->a_stride_ = 32;
matmul_param->b_stride_ = 64;
matmul_param->c_stride_ = 32;
float a[] = {9.02165, 8.657163, 0.56371903, 0.7272156, 1.6258951, 9.919627, 7.47593, 3.5311592,
8.958062, 0.55338514, 9.611276, 7.429841, 8.23804, 3.7503464, 1.2829816, 6.4470887,
1, 2, 3, 4, 1, 2, 3, 4,
3, 2, 3, 4, 4, 2, 3, 4,
4.303486, 6.282502, 0, 0, 9.4194765, 7.8199654, 0, 0,
6.738705, 7.5398073, 0, 0, 0.47684374, 0.87746763, 0, 0,
1, 2, 3, 4, 1, 2, 3, 4,
3, 2, 3, 4, 4, 2, 3, 4};
float b[] = {
1.8100919, 6.016964, 5.733568, 5.768448, 2.2823029, 2.173359, 0.56861514, 7.134393, 0.26377398, 3.9010656,
4.868408, 0.33401546, 1.7973539, 8.21896, 5.62239, 8.54786, 0.97356945, 1.0714527, 6.447588, 6.161091,
3.332229, 2.8775468, 6.558747, 2.6986659, 0, 0, 0, 0, 0, 0,
0, 0, 11, 2, 3, 4, 22, 2, 3, 4,
33, 3, 3, 4, 44, 2, 3, 4, 11, 2,
3, 4, 22, 2, 3, 4, 33, 3, 3, 4,
44, 2, 3, 4, 1.9830805, 0, 0, 0, 8.44718, 0,
0, 0, 9.360418, 0, 0, 0, 6.220693, 0, 0, 0,
1.8369701, 0, 0, 0, 4.3965054, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 11, 2, 3, 4,
22, 2, 3, 4, 33, 3, 3, 4, 44, 2,
3, 4, 11, 2, 3, 4, 22, 2, 3, 4,
33, 3, 3, 4, 44, 2, 3, 4};
float c[] = {62.668518, 103.9633, 132.43439, 163.67749, 69.12974, 122.12326, 183.23413, 191.96806,
65.052124, 182.57918, 233.14148, 184.20694, 38.785316, 118.74806, 100.689575, 135.12036,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
136.34613, 0, 0, 0, 230.64507, 0, 0, 0,
204.15103, 0, 0, 0, 104.86488, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
float out[64] = {0};
float tmp_ptr[1000];
RecursionMatmul(a, b, out, matmul_param, 1, 0, tmp_ptr);
EXPECT_EQ(0, lite::CompareOutputData(out, c, 64));
delete (matmul_param);
}
} // namespace mindspore