forked from mindspore-Ecosystem/mindspore

[MSLITE][Develop] fix arm32 deconv bug

parent 788ff3366f
commit d257fec69c
@@ -300,6 +300,7 @@ LoopCol:
     vst1.32 {q10, q11}, [r2]!
     vst1.32 {q12, q13}, [r2]!
     vst1.32 {q14, q15}, [r2]!
+    str r2, [sp, #-40]
     b WriteEnd
 WriteWino:
     vst1.32 {q8, q9}, [r2]
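Note: judging by this hunk's one-line growth, the added line is the `str r2, [sp, #-40]`. The write-out path advances the output pointer in r2 with post-incrementing stores, so it must spill the advanced value back to its stack slot before branching to WriteEnd, or the next LoopCol iteration reloads a stale pointer. A minimal C sketch of the reload/advance/spill pattern (names hypothetical):

#include <string.h>
/* Sketch of the pointer write-back the fix adds (hypothetical names). */
static void write_tile(float **cursor_slot, const float *tile, size_t n) {
  float *dst = *cursor_slot;        /* reload, like "ldr r2, [sp, #-40]" */
  memcpy(dst, tile, n * sizeof(float));
  dst += n;                         /* like the "!" post-increment in vst1.32 */
  *cursor_slot = dst;               /* the fix, like "str r2, [sp, #-40]" */
}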
@@ -105,7 +105,6 @@ void RowMajor2Col16MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row,
     float16_t *src_c = src_r + ci;
     float16_t *dst_c = dst_r + ci * C16NUM;

     /* 16*8 row-major to col-major */
 #ifdef ENABLE_ARM64
     size_t stride = col * 2;
     asm volatile(
@@ -256,16 +255,6 @@ void Fp32RowMajor2Fp16Col16Major(float *src, float16_t *dst, size_t row, size_t
     }
   }

-void Fp16RowMajor2Fp16Col16Major(float16_t *src, float16_t *dst, size_t row, size_t col) {
-  for (int r = 0; r < row; r++) {
-    for (int c = 0; c < col; c++) {
-      int r_div16 = r / 16;
-      int r_mod16 = r % 16;
-      dst[r_div16 * 16 * col + c * 16 + r_mod16] = src[r * col + c];
-    }
-  }
-}
-
 void Fp32RowMajor2Fp16Row16Major(float *src, float16_t *dst, size_t row, size_t col) {
   for (int r = 0; r < row; r++) {
     for (int c = 0; c < col; c++) {
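This hunk deletes the scalar Fp16RowMajor2Fp16Col16Major; its callers (in the kernel files below) switch to the existing RowMajor2Col16MajorFp16, which produces the same layout and also carries a NEON fast path. The layout itself: element (r, c) of the row-major source lands in 16-row block r/16 at offset c*16 + r%16. A reference sketch of that mapping, assuming plain uint16_t storage for float16_t:

#include <stdint.h>
/* Reference col16-major packing, mirroring the deleted scalar loop. */
static void pack_col16(const uint16_t *src, uint16_t *dst, int row, int col) {
  for (int r = 0; r < row; r++) {
    for (int c = 0; c < col; c++) {
      dst[(r / 16) * 16 * col + c * 16 + (r % 16)] = src[r * col + c];
    }
  }
}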
@@ -41,8 +41,6 @@ void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, cons

 void Fp32RowMajor2Fp16Col16Major(float *src, float16_t *dst, size_t row, size_t col);

-void Fp16RowMajor2Fp16Col16Major(float16_t *src, float16_t *dst, size_t row, size_t col);
-
 void Fp32RowMajor2Fp16Row16Major(float *src, float16_t *dst, size_t row, size_t col);

 void Fp16RowMajor2Fp16Row16Major(float16_t *src, float16_t *dst, size_t row, size_t col);
@@ -378,112 +378,9 @@ void RowMajor2Col4Major(float *src_ptr, float *dst_ptr, size_t row, size_t col)
   return;
 }

-void MatrixUnPackUnit(const void *src, void *dst, size_t row, size_t col, size_t src_stride, size_t dst_stride,
-                      size_t data_lenth) {
-  size_t copy_size = col * data_lenth;
-  size_t src_size = src_stride * data_lenth;
-  size_t dst_size = dst_stride * data_lenth;
-  char *src_ptr = (char *)src;
-  char *dst_ptr = (char *)dst;
-  for (int r = 0; r < row; r++) {
-    memcpy(dst_ptr, src_ptr, copy_size);
-    src_ptr += src_size;
-    dst_ptr += dst_size;
-  }
-}
-
-void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride) {
-  size_t row_up8 = UP_ROUND(row, C8NUM);
-  size_t row_8div = row / C8NUM * C8NUM;
-  size_t row_8res = row - row_8div;
-  size_t col_8div = col / C8NUM * C8NUM;
-  size_t col_8res = col - col_8div;
-  float *src_c = src_ptr;
-  float *dst_c = dst_ptr;
-
-  for (size_t ci = 0; ci < col_8div; ci += C8NUM) {
-#ifdef ENABLE_ARM64
-    size_t offset = stride * 4 - 16;
-    asm volatile(
-      "mov x0, #0 \n"
-      "mov x1, %[row_8div] \n"
-      "mov x10, %[src_c] \n"
-      "mov x11, %[dst_c] \n"
-
-      "1: \n"
-      "cmp x0, x1 \n"
-      "beq 2f \n"
-
-      "ld1 {v0.4s}, [x10], #16\n"
-      "ld1 {v1.4s}, [x10], #16\n"
-      "ld1 {v2.4s}, [x10], #16\n"
-      "ld1 {v3.4s}, [x10], #16\n"
-      "ld1 {v4.4s}, [x10], #16\n"
-      "ld1 {v5.4s}, [x10], #16\n"
-      "ld1 {v6.4s}, [x10], #16\n"
-      "ld1 {v7.4s}, [x10], #16\n"
-      "ld1 {v8.4s}, [x10], #16\n"
-      "ld1 {v9.4s}, [x10], #16\n"
-      "ld1 {v10.4s}, [x10], #16\n"
-      "ld1 {v11.4s}, [x10], #16\n"
-      "ld1 {v12.4s}, [x10], #16\n"
-      "ld1 {v13.4s}, [x10], #16\n"
-      "ld1 {v14.4s}, [x10], #16\n"
-      "ld1 {v15.4s}, [x10], #16\n"
-
-      "add x0, x0, #8\n"
-
-      "st1 {v0.4s}, [x11], #16\n"
-      "st1 {v1.4s}, [x11], %[offset]\n"
-      "st1 {v2.4s}, [x11], #16\n"
-      "st1 {v3.4s}, [x11], %[offset]\n"
-      "st1 {v4.4s}, [x11], #16\n"
-      "st1 {v5.4s}, [x11], %[offset]\n"
-      "st1 {v6.4s}, [x11], #16\n"
-      "st1 {v7.4s}, [x11], %[offset]\n"
-      "st1 {v8.4s}, [x11], #16\n"
-      "st1 {v9.4s}, [x11], %[offset]\n"
-      "st1 {v10.4s}, [x11], #16\n"
-      "st1 {v11.4s}, [x11], %[offset]\n"
-      "st1 {v12.4s}, [x11], #16\n"
-      "st1 {v13.4s}, [x11], %[offset]\n"
-      "st1 {v14.4s}, [x11], #16\n"
-      "st1 {v15.4s}, [x11], %[offset]\n"
-      "b 1b\n"
-
-      "2:\n"
-
-      :
-      : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ offset ] "r"(offset), [ row_8div ] "r"(row_8div)
-      : "x0", "x1", "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
-        "v13", "v14", "v15");
-#else
-    for (size_t ri = 0; ri < row_8div; ri += C8NUM) {
-      float *src_r = src_c + ri * C8NUM;
-      float *dst_r = dst_c + ri * stride;
-      MatrixUnPackUnit(src_r, dst_r, C8NUM, C8NUM, C8NUM, stride, sizeof(float));
-    }
-#endif
-
-    if (row != row_8div) {
-      float *src_r = src_c + row_8div * C8NUM;
-      float *dst_r = dst_c + row_8div * stride;
-      MatrixUnPackUnit(src_r, dst_r, row_8res, C8NUM, C8NUM, stride, sizeof(float));
-    }
-    src_c += row_up8 * C8NUM;
-    dst_c += C8NUM;
-  }
-
-  if (col != col_8div) {
-    MatrixUnPackUnit(src_c, dst_c, row, col_8res, C8NUM, stride, sizeof(float));
-  }
-  return;
-}
-
 void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
                 int col, int stride, int out_type) {
   if (out_type == OutType_Nhwc) {
     /* col8-major * row8-major => col-major */
     for (int r = 0; r < row; r++) {
       for (int c = 0; c < col; c++) {
         int r12div = r / 12, r12mod = r % 12;
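This hunk removes Row8x8Major2RowMajor (and its helper MatrixUnPackUnit) outright; the matmul kernels now write row-major/NHWC output directly, so the 8x8-tile unpack step is dead code, and its unit tests at the bottom of this commit go with it. For reference, a scalar sketch of the mapping the deleted routine implemented, assuming C8NUM == 8:

/* Scalar equivalent of the removed Row8x8Major2RowMajor (C8NUM assumed 8):
 * src holds 8-column panels; within a panel, each row is 8 contiguous floats. */
static void row8x8_to_row_major(const float *src, float *dst, size_t row, size_t col, size_t stride) {
  size_t row_up8 = (row + 7) / 8 * 8; /* UP_ROUND(row, 8) */
  for (size_t r = 0; r < row; r++) {
    for (size_t c = 0; c < col; c++) {
      const float *panel = src + (c / 8) * row_up8 * 8;
      dst[r * stride + c] = panel[r * 8 + c % 8];
    }
  }
}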
@@ -502,7 +399,6 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
       }
     }
   } else if (out_type == OutType_C8) {
     /* col8-major * row8-major => col12x8-major */
     int col_8 = UP_ROUND(col, C8NUM);
     int row_12 = UP_ROUND(row, C12NUM);
     for (int r = 0; r < row_12; r++) {
@@ -545,6 +441,32 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
   return;
 }

+void MatMul4x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
+               int col, int stride, int out_type) {
+  if (out_type == OutType_C8) {
+    int col_8 = UP_ROUND(col, C8NUM);
+    int row_4 = UP_ROUND(row, C4NUM);
+    for (int r = 0; r < row_4; r++) {
+      for (int c = 0; c < col_8; c++) {
+        int r4div = r / C4NUM, r4mod = r % C4NUM;
+        int c8div = c / C8NUM, c8mod = c % C8NUM;
+        size_t ci = (c8div * C8NUM * row_4 + r * C8NUM + c8mod);
+        float value = 0;
+        for (int d = 0; d < deep; d++) {
+          size_t ai = r4div * deep * C4NUM + d * C4NUM + r4mod;
+          size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
+          value = value + a[ai] * b[bi];
+        }
+        if (bias != NULL) value += bias[c];
+        if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
+        if (act_type != ActType_No) value = MSMAX(0.0f, value);
+        dst[ci] = value;
+      }
+    }
+  }
+  return;
+}
+
 void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
                int col, size_t stride, int out_type) {
 #ifdef ENABLE_ARM64
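MatMul4x8 is the new scalar reference for the ARM32 tile shape: A packed col4-major (deep-by-4 panels), B in 8-column panels, output in the C8 layout where C is stored as 8-wide column panels of height UP_ROUND(row, 4). A sizing sketch consistent with the ci index above:

/* Output element count for MatMul4x8's OutType_C8 layout
 * (follows from ci = c8div * 8 * row_4 + r * 8 + c8mod above). */
#define UP_ROUND(x, n) (((x) + (n)-1) / (n) * (n))
static size_t c8_output_elems(int row, int col) {
  return (size_t)UP_ROUND(col, 8) * (size_t)UP_ROUND(row, 4);
}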
@@ -556,7 +478,7 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
   }
 #elif ENABLE_ARM32
   MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
-                    (int)(out_type == OutType_TileC8));
+                       (int)(out_type == OutType_TileC8));
 #else
   MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
 #endif
@@ -28,14 +28,12 @@ extern "C" {
 #endif
 void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
                int col, size_t stride, int out_type);

 void RowMajor2Row4Major(float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Row12Major(float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Col4Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
 void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
 void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
-void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
 #ifdef ENABLE_ARM64
 void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
                        int col, size_t stride, bool write_nhwc);
@@ -33,34 +33,22 @@ class OptimizeModule {
 public:
  OptimizeModule() {
    bool support_optimize_ops = false;
#ifdef __ANDROID__

#ifdef ENABLE_ARM64
    int hwcap_type = 16;
    uint32_t hwcap = getHwCap(hwcap_type);
#ifdef ENABLE_ARM64
    if (hwcap & HWCAP_FPHP) {
#elif defined(__arm__)
    if (hwcap & HWCAP_HALF) {
#endif

#ifdef ENABLE_ARM64
    }
#elif defined(__arm__)
    }
#endif

#ifdef ENABLE_ARM64
    if (hwcap & HWCAP_ASIMDDP) {
      MS_LOG(INFO) << "Hw cap support SMID Dot Product, hwcap: 0x" << hwcap;
      support_optimize_ops = true;
    } else {
      MS_LOG(INFO) << "Hw cap NOT support SIMD Dot Product, hwcap: 0x" << hwcap;
    }
#endif
#endif
    if (support_optimize_ops == false) {
      return;
    }
#ifndef _WIN32
#ifdef ENABLE_ARM64
    optimized_op_handler_ = dlopen(OPTIMIZE_SHARED_LIBRARY_PATH, RTLD_LAZY);
    if (optimized_op_handler_ == nullptr) {
      MS_LOG(INFO) << "Open optimize shared library failed: " << dlerror();
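getHwCap here reads the AT_HWCAP auxiliary vector; on Linux/Android the same bits are exposed through getauxval, and HWCAP_ASIMDDP is the ARMv8.2 dot-product feature bit the optimized kernels depend on. A standalone sketch of the same probe (standard aarch64 Linux headers assumed):

#include <stdbool.h>
#if defined(__aarch64__) && defined(__linux__)
#include <sys/auxv.h>
#include <asm/hwcap.h>
#endif
/* Probe the ARMv8.2 SDOT/UDOT extension, as OptimizeModule does via getHwCap. */
static bool cpu_supports_dotprod(void) {
#if defined(__aarch64__) && defined(__linux__)
  return (getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0;
#else
  return false;
#endif
}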
@@ -81,26 +69,19 @@ class Float16Module {
 public:
  Float16Module() {
    bool support_fp16 = false;
#ifdef __ANDROID__
#ifdef ENABLE_ARM64
    int hwcap_type = 16;
    uint32_t hwcap = getHwCap(hwcap_type);
#ifdef ENABLE_ARM64

    if (hwcap & HWCAP_FPHP) {
#elif defined(__arm__)
    if (hwcap & HWCAP_HALF) {
#endif
      MS_LOG(INFO) << "Hw cap support FP16, hwcap: 0x" << hwcap;
      support_fp16 = true;
#ifdef ENABLE_ARM64
    }
#elif defined(__arm__)
    }
#endif
#endif
    if (support_fp16 == false) {
      return;
    }
#ifndef _WIN32
#ifdef ENABLE_ARM64
    float16_op_handler_ = dlopen(FLOAT16_SHARED_LIBRARY_PATH, RTLD_LAZY);
    if (float16_op_handler_ == nullptr) {
      MS_LOG(INFO) << "Open optimize shared library failed: " << dlerror();
@@ -99,7 +99,7 @@ void FullconnectionFP16CPUKernel::InitMatrixA(float *a_ptr, float16_t *a_pack_pt
 }

 void FullconnectionFP16CPUKernel::InitMatrixA(float16_t *a_ptr, float16_t *a_pack_ptr) {
-  Fp16RowMajor2Fp16Col16Major(a_ptr, a_pack_ptr, fc_param_->row_, fc_param_->deep_);
+  RowMajor2Col16MajorFp16(a_ptr, a_pack_ptr, fc_param_->row_, fc_param_->deep_);
 }

 void FullconnectionFP16CPUKernel::InitMatrixB(float *b_ptr, float16_t *b_pack_ptr) {
@@ -118,8 +118,8 @@ int MatmulFP16CPUKernel::ReSize() {
     output_ptr_ = reinterpret_cast<float16_t *>(
       ctx_->allocator->Malloc(params_->batch * params_->row_ * params_->col_ * sizeof(float16_t)));
     if (output_ptr_ == nullptr) {
-    MS_LOG(ERROR) << "malloc output_ptr_ failed.";
-    return RET_MEMORY_FAILED;
+      MS_LOG(ERROR) << "malloc output_ptr_ failed.";
+      return RET_MEMORY_FAILED;
     }
   }
   return RET_OK;
@@ -144,7 +144,7 @@ void MatmulFP16CPUKernel::InitMatrixA(float16_t *a_ptr, float16_t *a_pack_ptr) {
     if (params_->a_transpose_) {
       Fp16RowMajor2Fp16Row16Major(src, dst, params_->deep_, params_->row_);
     } else {
-      Fp16RowMajor2Fp16Col16Major(src, dst, params_->row_, params_->deep_);
+      RowMajor2Col16MajorFp16(src, dst, params_->row_, params_->deep_);
     }
   }
 }
@@ -212,7 +212,11 @@ int DeConvolutionCPUKernel::Run() {
     input_ptr_ = src_in + batch_index * input_plane_ * conv_param_->input_channel_;
     output_ptr_ = src_out + batch_index * output_plane_ * conv_param_->output_channel_;

-    RowMajor2Col12Major(input_ptr_, pack_input_, input_plane_, conv_param_->input_channel_);
+#ifdef ENABLE_ARM32
+    RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
+#else
+    RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
+#endif

     error_code = ParallelLaunch(this->context_->thread_pool_, DeConvFp32Run, this, thread_count_);
     if (error_code != RET_OK) {
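This hunk is the arm32 deconv fix itself: on ARM32, MatMulOpt dispatches to MatmulFloatNeon32Opt, whose 4x8 kernel expects the left matrix packed col4-major, but the old code packed col12-major (the 12x8 kernel's layout) and sized the pack from input_plane_ rather than the matmul row/deep parameters. A scalar sketch of the col4 layout now used on ARM32 (padding behavior of RowMajor2Col4Major assumed):

/* Col4-major packing sketch: element (r, c) goes to 4-row block r/4,
 * offset c*4 + r%4, matching the 4x8 kernel's expected A layout. */
static void pack_col4(const float *src, float *dst, size_t row, size_t col) {
  for (size_t r = 0; r < row; r++) {
    for (size_t c = 0; c < col; c++) {
      dst[(r / 4) * col * 4 + c * 4 + (r % 4)] = src[r * col + c];
    }
  }
}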
@@ -68,117 +68,6 @@ TEST_F(TestMatMulFp32, Row2Col8Test2) {
   CompareOutputData(out, co, 120, 0.0001);
 }

-TEST_F(TestMatMulFp32, Row8x82RowTest1) {
-  float in[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0, 0, 0, 0.68, 0.02, 0.33, 0.85, 0.67, 0, 0, 0,
-                0.81, 0.57, 0.70, 0.27, 0.90, 0, 0, 0, 0.07, 0.13, 0.03, 0.53, 0.97, 0, 0, 0,
-                0.92, 0.35, 0.74, 0.78, 0.87, 0, 0, 0, 0.23, 0.34, 0.09, 0.50, 0.39, 0, 0, 0,
-                0.09, 0.93, 0.91, 0.20, 0.97, 0, 0, 0, 0.61, 0.43, 0.14, 0.67, 0.10, 0, 0, 0,
-                0.73, 0.37, 0.24, 0.93, 0.31, 0, 0, 0, 0.35, 0.52, 0.02, 0.33, 0.99, 0, 0, 0,
-                0.49, 0.67, 0.75, 0.66, 0.04, 0, 0, 0, 0.10, 0.18, 0.92, 0.46, 0.08, 0, 0, 0,
-                0.04, 0.24, 0.52, 0.43, 0.14, 0, 0, 0, 0.67, 0.10, 0.73, 0.37, 0.24, 0, 0, 0,
-                0.93, 0.31, 0.35, 0.52, 0.02, 0, 0, 0, 0.33, 0.99, 0.49, 0.67, 0.75, 0, 0, 0,
-                0.66, 0.04, 0.10, 0.18, 0.92, 0, 0, 0, 0.46, 0.08, 0.04, 0.24, 0.52, 0, 0, 0,
-                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
-  float co[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90,
-                0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39,
-                0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31,
-                0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08,
-                0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02,
-                0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52};
-  float out[90] = {0};
-  Row8x8Major2RowMajor(in, out, 18, 5, 5);
-  CompareOutputData(out, co, 90, 0.0001);
-}
-
-TEST_F(TestMatMulFp32, Row8x82RowTest2) {
-  float in[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0, 0, 0, 0.68, 0.02, 0.33, 0.85, 0.67, 0, 0, 0,
-                0.81, 0.57, 0.70, 0.27, 0.90, 0, 0, 0, 0.07, 0.13, 0.03, 0.53, 0.97, 0, 0, 0,
-                0.92, 0.35, 0.74, 0.78, 0.87, 0, 0, 0, 0.23, 0.34, 0.09, 0.50, 0.39, 0, 0, 0,
-                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
-  float co[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90,
-                0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39};
-  float out[30] = {0};
-  Row8x8Major2RowMajor(in, out, 6, 5, 5);
-  CompareOutputData(out, co, 30, 0.0001);
-}
-
-TEST_F(TestMatMulFp32, Row8x82RowTest3) {
-  float in[] = {
-    0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.67, 0.10, 0.73,
-    0.37, 0.24, 0.93, 0.31, 0.35, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04,
-    0.10, 0.18, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.93,
-    0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.75, 0.66, 0.04, 0.10,
-    0.18, 0.92, 0.46, 0.08, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02,
-    0.33, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.92, 0.46,
-    0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.87, 0.23, 0.34, 0.09, 0.50,
-    0.39, 0.09, 0.93, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24,
-    0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.21, 0.38, 0.81,
-    0.98, 0.09, 0.68, 0.02, 0.33, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-    0, 0, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.52,
-    0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.92, 0.46, 0.08, 0.04,
-    0.24, 0.52, 0.21, 0.38, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67,
-    0.10, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.04, 0.24,
-    0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.85, 0.67, 0.81, 0.57, 0.70,
-    0.27, 0.90, 0.07, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66,
-    0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.91, 0.20, 0.97,
-    0.61, 0.43, 0.14, 0.67, 0.10, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33,
-    0.99, 0.49, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.85,
-    0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-    0, 0, 0, 0, 0.13, 0.03, 0.53, 0, 0, 0, 0, 0, 0.61, 0.43, 0.14, 0, 0, 0, 0,
-    0, 0.04, 0.10, 0.18, 0, 0, 0, 0, 0, 0.52, 0.02, 0.33, 0, 0, 0, 0, 0, 0.81, 0.98,
-    0.09, 0, 0, 0, 0, 0, 0.35, 0.74, 0.78, 0, 0, 0, 0, 0, 0.73, 0.37, 0.24, 0, 0,
-    0, 0, 0, 0.08, 0.04, 0.24, 0, 0, 0, 0, 0, 0.67, 0.75, 0.67, 0, 0, 0, 0, 0,
-    0.37, 0.24, 0.93, 0, 0, 0, 0, 0, 0.04, 0.24, 0.52, 0, 0, 0, 0, 0, 0.13, 0.03, 0.53,
-    0, 0, 0, 0, 0, 0.61, 0.43, 0.14, 0, 0, 0, 0, 0, 0.04, 0.10, 0.18, 0, 0, 0,
-    0, 0, 0.52, 0.02, 0.33, 0, 0, 0, 0, 0, 0.35, 0.74, 0.78, 0, 0, 0, 0, 0, 0.73,
-    0.37, 0.24, 0, 0, 0, 0, 0, 0.08, 0.04, 0.24, 0, 0, 0, 0, 0, 0.67, 0.75, 0.67, 0,
-    0, 0, 0, 0, 0.37, 0.24, 0.93, 0, 0, 0, 0, 0, 0.04, 0.24, 0.52, 0, 0, 0, 0,
-    0, 0.13, 0.03, 0.53, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-    0, 0, 0, 0, 0, 0};
-  float co[] = {
-    0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53,
-    0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14,
-    0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18,
-    0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33,
-    0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.21, 0.38, 0.81, 0.98, 0.09,
-    0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78,
-    0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24,
-    0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24,
-    0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.67,
-    0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93,
-    0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52,
-    0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53,
-    0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14,
-    0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18,
-    0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33,
-    0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78,
-    0.87, 0.23, 0.34, 0.09, 0.50, 0.39, 0.09, 0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24,
-    0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24,
-    0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.67,
-    0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52, 0.43, 0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93,
-    0.31, 0.35, 0.52, 0.02, 0.33, 0.99, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92, 0.46, 0.08, 0.04, 0.24, 0.52,
-    0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53};
-  float out[418] = {0};
-  Row8x8Major2RowMajor(in, out, 22, 19, 19);
-  CompareOutputData(out, co, 418, 0.0001);
-}
-
-TEST_F(TestMatMulFp32, Row8x82RowTest4) {
-  float in[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.27,
-                0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97,
-                0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92,
-                0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.27, 0.39};
-  float co[] = {0.21, 0.38, 0.81, 0.98, 0.09, 0.68, 0.02, 0.33, 0.85, 0.67, 0.81, 0.57, 0.70, 0.27, 0.90, 0.27,
-                0.14, 0.67, 0.10, 0.73, 0.37, 0.24, 0.93, 0.31, 0.70, 0.27, 0.90, 0.07, 0.13, 0.03, 0.53, 0.97,
-                0.93, 0.91, 0.20, 0.97, 0.61, 0.43, 0.14, 0.67, 0.49, 0.67, 0.75, 0.66, 0.04, 0.10, 0.18, 0.92,
-                0.07, 0.13, 0.03, 0.53, 0.97, 0.92, 0.35, 0.74, 0.78, 0.87, 0.23, 0.34, 0.09, 0.50, 0.27, 0.39};
-  float out[64] = {0};
-  Row8x8Major2RowMajor(in, out, 8, 8, 8);
-  CompareOutputData(out, co, 64, 0.0001);
-}
-
 int MMTestInit(std::vector<lite::Tensor *> *inputs_, std::vector<lite::Tensor *> *outputs_, float *a_ptr, float *b_ptr,
                std::vector<int> a_shape, std::vector<int> b_shape, std::vector<int> c_shape) {
   auto in_t = new lite::Tensor(kNumberTypeFloat, a_shape, schema::Format_NHWC, lite::Tensor::Category::CONST);