From a0ad5777bbacac84aa47ad774915f9f47549613a Mon Sep 17 00:00:00 2001 From: lzk Date: Thu, 24 Jun 2021 20:30:01 -0700 Subject: [PATCH] deconv --- .../cpu/nnacl/fp32/matmul_fp32.c | 12 +- .../cpu/nnacl/fp32/matmul_fp32.h | 7 +- .../cpu/nnacl/fp32/pack_fp32.c | 21 +- .../cpu/nnacl/fp32/pack_fp32.h | 1 + .../cpu/nnacl/intrinsics/sse/MatMul_Sse.c | 231 +++++++++++++----- .../kernel_compiler/cpu/nnacl/op_base.h | 2 + .../kernel/arm/fp32/deconvolution_fp32.cc | 96 ++++---- .../kernel/arm/fp32/deconvolution_fp32.h | 2 + 8 files changed, 261 insertions(+), 111 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c index 2af8fb553ab..4a2f2955453 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c @@ -960,17 +960,9 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); } #elif ENABLE_AVX - if (out_type == OutType_C8) { - MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); - } else { - MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type)); - } + MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type)); #elif ENABLE_SSE - if (out_type == OutType_C8) { - MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); - } else { - MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); - } + MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); #else MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type); #endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h 
b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h index 513e525a935..a82587fd3bc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h @@ -73,11 +73,14 @@ void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float void MatmulFloatNeon32Opt12x4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, int col, int stride, int write_mode); #elif ENABLE_SSE -void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, - int col, int stride, size_t writeNhwc, size_t WriteWino); +void DeconvMatmulFloatSse(const float *a, const float *b, float *c, int depth, int row, int col); void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, int col, int stride, int write_mode); #ifdef ENABLE_AVX +typedef void (*DeconvAvxKernel)(const float *src, const float *weight, float *dst, int col, int row, int depth, + int stride); +void DeconvMatmulFloatAvx(const float *a, const float *b, float *c, int depth, int row, int col, int kernel_plane); +void DeconvAvxColXRowKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride); void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, size_t act_type, size_t depth, size_t row, size_t col, size_t stride, size_t write_mode); typedef void (*MatVecMulKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/pack_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/pack_fp32.c index 67f8e7a835d..e520385c51a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/pack_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/pack_fp32.c @@ -293,7 
+293,26 @@ void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int } } } - return; +} + +void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel) { + // pack weight NHWC to C24HWN24 (Priority 24)=>C16HWN16 (Not satisfied 24)=>C8HWN8 (Not satisfied 16) + int oc_block = 0; + int oc_block_num = UP_DIV(channel, C8NUM); + for (int i = 0; i < oc_block_num; i += oc_block) { + oc_block = MSMIN(C3NUM, oc_block_num - i); // at most C3NUM (3) blocks of C8NUM channels + int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM); + for (int p = 0; p < plane; ++p) { + int index_plane = i * C8NUM + p * channel; + for (int b = 0; b < batch; ++b) { + int index_batch = index_plane + b * plane * channel; + for (int oc = 0; oc < oc_remainder; ++oc) { + dst[oc] = src[index_batch + oc]; + } + dst += oc_block * C8NUM; + } + } + } } void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/pack_fp32.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/pack_fp32.h index 9ac7f191a40..a318dff8427 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/pack_fp32.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/pack_fp32.h @@ -40,6 +40,7 @@ void PackNHWCXToNHWCFp32(const void *src, void *dst, int batch, int plane, int c void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel); void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel); void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel); diff --git 
a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/sse/MatMul_Sse.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/sse/MatMul_Sse.c index 17981ccf5eb..d58ff403dd6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/sse/MatMul_Sse.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/sse/MatMul_Sse.c @@ -16,6 +16,7 @@ #ifdef ENABLE_SSE #include +#include "nnacl/fp32/matmul_fp32.h" #include "nnacl/op_base.h" #include "nnacl/matmul_parameter.h" #include "nnacl/intrinsics/sse/sse_common.h" @@ -204,9 +205,7 @@ void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float * } } -void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, - int col, int stride, size_t writeNhwc, size_t WriteWino) { - size_t DstWinoSteps = stride * C8NUM, WriteWinoSteps = stride * col; +void DeconvMatmulFloatSse(const float *a, const float *b, float *c, int depth, int row, int col) { for (int col_tmp = col; col_tmp > 0; col_tmp -= C8NUM) { const float *srca_d = a; float *dst = c; @@ -231,63 +230,181 @@ void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bia dst7 = _mm_add_ps(dst7, tmp3), dst8 = _mm_add_ps(dst8, tmp4); srcb_d += C8NUM, srca_d += C4NUM; } - - if (bias != NULL) { - DoBiasBlock8(bias, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8); - } - - ActBlock8(&dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, act_type); - - if (WriteWino != 0) { // WriteWino - _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2); - dst += WriteWinoSteps; - _mm_storeu_ps(dst, dst3), _mm_storeu_ps(dst + 4, dst4); - dst += WriteWinoSteps; - _mm_storeu_ps(dst, dst5), _mm_storeu_ps(dst + 4, dst6); - dst += WriteWinoSteps; - _mm_storeu_ps(dst, dst7), _mm_storeu_ps(dst + 4, dst8); - dst += WriteWinoSteps; - } else if (writeNhwc == 0) { // WriteC8 - _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2); - _mm_storeu_ps(dst + 8, 
dst3), _mm_storeu_ps(dst + 12, dst4); - _mm_storeu_ps(dst + 16, dst5), _mm_storeu_ps(dst + 20, dst6); - _mm_storeu_ps(dst + 24, dst7), _mm_storeu_ps(dst + 28, dst8); - dst += 32; - c = dst; - } else { - switch (col) { - case 1: // write1 - WriteCol1(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); - case 2: // write2 - WriteCol2(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, r); - case 3: // write3 - WriteCol3(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); - case 4: // write4 - WriteCol4(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); - case 5: // // write - WriteCol5(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); - case 6: // write6 - WriteCol6(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); - case 7: // write7 - WriteCol7(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); - default: // write8 - WriteCol8(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r); - } - } - if (r <= C4NUM) { // WriteEnd - break; - } + _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2); + _mm_storeu_ps(dst + 8, dst3), _mm_storeu_ps(dst + 12, dst4); + _mm_storeu_ps(dst + 16, dst5), _mm_storeu_ps(dst + 20, dst6); + _mm_storeu_ps(dst + 24, dst7), _mm_storeu_ps(dst + 28, dst8); + dst += 32; + c = dst; } b += depth * C8NUM; - bias += (bias != NULL) ? 
C8NUM : 0; - if (WriteWino != 0) { - c += DstWinoSteps; - } else if (writeNhwc != 0) { - c += C8NUM; - } - if (col_tmp <= C8NUM) { - break; + } +} + +#ifdef ENABLE_AVX +void DeconvAvx4X8Kernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) { + __m256 res1 = _mm256_setzero_ps(); + __m256 res4 = _mm256_setzero_ps(); + __m256 res7 = _mm256_setzero_ps(); + __m256 res10 = _mm256_setzero_ps(); + + for (int d = 0; d < depth; ++d) { + __m256 w0 = _mm256_loadu_ps(weight); + __m256 tmp = _mm256_set1_ps(*src); + res1 = _mm256_fmadd_ps(tmp, w0, res1); + tmp = _mm256_set1_ps(*(src + 1)); + res4 = _mm256_fmadd_ps(tmp, w0, res4); + tmp = _mm256_set1_ps(*(src + 2)); + res7 = _mm256_fmadd_ps(tmp, w0, res7); + tmp = _mm256_set1_ps(*(src + 3)); + res10 = _mm256_fmadd_ps(tmp, w0, res10); + weight += C8NUM; + src += C4NUM; + } + // write + _mm256_storeu_ps(dst, res1); + _mm256_storeu_ps(dst + C8NUM, res4); + _mm256_storeu_ps(dst + C16NUM, res7); + _mm256_storeu_ps(dst + C24NUM, res10); +} + +void DeconvAvx4X16Kernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) { + __m256 res1 = _mm256_setzero_ps(); + __m256 res2 = _mm256_setzero_ps(); + __m256 res4 = _mm256_setzero_ps(); + __m256 res5 = _mm256_setzero_ps(); + __m256 res7 = _mm256_setzero_ps(); + __m256 res8 = _mm256_setzero_ps(); + __m256 res10 = _mm256_setzero_ps(); + __m256 res11 = _mm256_setzero_ps(); + + for (int d = 0; d < depth; ++d) { + __m256 w0 = _mm256_loadu_ps(weight); + __m256 w1 = _mm256_loadu_ps(weight + C8NUM); + weight += C16NUM; + __m256 tmp = _mm256_set1_ps(*src); + res1 = _mm256_fmadd_ps(tmp, w0, res1); + res2 = _mm256_fmadd_ps(tmp, w1, res2); + tmp = _mm256_set1_ps(*(src + 1)); + res4 = _mm256_fmadd_ps(tmp, w0, res4); + res5 = _mm256_fmadd_ps(tmp, w1, res5); + tmp = _mm256_set1_ps(*(src + 2)); + res7 = _mm256_fmadd_ps(tmp, w0, res7); + res8 = _mm256_fmadd_ps(tmp, w1, res8); + tmp = _mm256_set1_ps(*(src + 3)); + res10 = 
_mm256_fmadd_ps(tmp, w0, res10); + res11 = _mm256_fmadd_ps(tmp, w1, res11); + src += C4NUM; + } + // write + _mm256_storeu_ps(dst, res1); + _mm256_storeu_ps(dst + C8NUM, res4); + _mm256_storeu_ps(dst + C16NUM, res7); + _mm256_storeu_ps(dst + C24NUM, res10); + + _mm256_storeu_ps(dst + stride, res2); + _mm256_storeu_ps(dst + stride + C8NUM, res5); + _mm256_storeu_ps(dst + stride + C16NUM, res8); + _mm256_storeu_ps(dst + stride + C24NUM, res11); +} + +void DeconvAvx4X24Kernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) { + __m256 res1 = _mm256_setzero_ps(); + __m256 res2 = _mm256_setzero_ps(); + __m256 res3 = _mm256_setzero_ps(); + __m256 res4 = _mm256_setzero_ps(); + __m256 res5 = _mm256_setzero_ps(); + __m256 res6 = _mm256_setzero_ps(); + __m256 res7 = _mm256_setzero_ps(); + __m256 res8 = _mm256_setzero_ps(); + __m256 res9 = _mm256_setzero_ps(); + __m256 res10 = _mm256_setzero_ps(); + __m256 res11 = _mm256_setzero_ps(); + __m256 res12 = _mm256_setzero_ps(); + + for (int d = 0; d < depth; ++d) { + __m256 w0 = _mm256_loadu_ps(weight); + __m256 w1 = _mm256_loadu_ps(weight + C8NUM); + __m256 w2 = _mm256_loadu_ps(weight + C16NUM); + __m256 tmp = _mm256_set1_ps(*src); + res1 = _mm256_fmadd_ps(tmp, w0, res1); + res2 = _mm256_fmadd_ps(tmp, w1, res2); + res3 = _mm256_fmadd_ps(tmp, w2, res3); + tmp = _mm256_set1_ps(*(src + 1)); + res4 = _mm256_fmadd_ps(tmp, w0, res4); + res5 = _mm256_fmadd_ps(tmp, w1, res5); + res6 = _mm256_fmadd_ps(tmp, w2, res6); + tmp = _mm256_set1_ps(*(src + 2)); + res7 = _mm256_fmadd_ps(tmp, w0, res7); + res8 = _mm256_fmadd_ps(tmp, w1, res8); + res9 = _mm256_fmadd_ps(tmp, w2, res9); + tmp = _mm256_set1_ps(*(src + 3)); + res10 = _mm256_fmadd_ps(tmp, w0, res10); + res11 = _mm256_fmadd_ps(tmp, w1, res11); + res12 = _mm256_fmadd_ps(tmp, w2, res12); + weight += C24NUM; + src += C4NUM; + } + // write + _mm256_storeu_ps(dst, res1); + _mm256_storeu_ps(dst + C8NUM, res4); + _mm256_storeu_ps(dst + C16NUM, res7); + 
_mm256_storeu_ps(dst + C24NUM, res10); + + _mm256_storeu_ps(dst + stride, res2); + _mm256_storeu_ps(dst + stride + C8NUM, res5); + _mm256_storeu_ps(dst + stride + C16NUM, res8); + _mm256_storeu_ps(dst + stride + C24NUM, res11); + + _mm256_storeu_ps(dst + 2 * stride, res3); + _mm256_storeu_ps(dst + 2 * stride + C8NUM, res6); + _mm256_storeu_ps(dst + 2 * stride + C16NUM, res9); + _mm256_storeu_ps(dst + 2 * stride + C24NUM, res12); +} + +void DeconvMatmulFloatAvx(const float *a, const float *b, float *c, int depth, int row, int col, int plane) { + int col_num = 0; + int col_block = UP_DIV(col / plane, C8NUM); + DeconvAvxKernel kernel[3] = {DeconvAvx4X8Kernel, DeconvAvx4X16Kernel, DeconvAvx4X24Kernel}; + for (int col_tmp = 0; col_tmp < col_block; col_tmp += col_num) { + col_num = MSMIN(C3NUM, col_block - col_tmp); + for (int p = 0; p < plane; ++p) { + for (int r = 0; r < row; r += C4NUM) { + kernel[col_num - 1](a + r * depth, b + (col_tmp * plane + p * col_num) * C8NUM * depth, + c + (col_tmp * plane + p) * C8NUM * row + r * C8NUM, col_num, C4NUM, depth, + row * C8NUM * plane); + } } } } + +void DeconvAvxRowXColKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, + int stride) { + __m256 res[C12NUM]; + __m256 w[C3NUM]; + for (int i = 0; i < C12NUM; ++i) { + res[i] = _mm256_setzero_ps(); + } + for (int d = 0; d < depth; ++d) { + for (int c = 0; c < col; ++c) { + w[c] = _mm256_loadu_ps(weight + c * C8NUM); + } + weight += col * C8NUM; + for (int r = 0; r < row; ++r) { // C4NUM + __m256 tmp = _mm256_set1_ps(*src); + for (int c = 0; c < col; ++c) { // 3 * C8NUM + res[r * col + c] = _mm256_fmadd_ps(tmp, w[c], res[r * col + c]); + } + src += 1; + } + } + // write + for (int i = 0; i < col; ++i) { + for (int j = 0; j < row; ++j) { + _mm256_storeu_ps(dst + j * C8NUM, res[j * col + i]); + } + dst += stride; + } +} +#endif #endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/op_base.h 
b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/op_base.h index d699639c32b..4f27a3bb735 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/op_base.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/op_base.h @@ -26,11 +26,13 @@ #endif #define C2NUM 2 +#define C3NUM 3 #define C4NUM 4 #define C6NUM 6 #define C8NUM 8 #define C12NUM 12 #define C16NUM 16 +#define C24NUM 24 #define C32NUM 32 #define TILE_NUM 8 diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.cc index 1e657d9d6c3..29997d9a7ca 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.cc @@ -31,10 +31,8 @@ DeConvolutionCPUKernel::~DeConvolutionCPUKernel() { delete matmul_param_; matmul_param_ = nullptr; } - if (weight_ptr_ != nullptr) { - free(weight_ptr_); - weight_ptr_ = nullptr; - } + FreeAlignedData(reinterpret_cast(&weight_ptr_)); + FreeAlignedData(reinterpret_cast(&bias_ptr)); } int DeConvolutionCPUKernel::ReSize() { @@ -58,34 +56,39 @@ int DeConvolutionCPUKernel::InitWeightBias() { auto output_channel = weight_tensor->Channel(); auto kernel_h_ = weight_tensor->Height(); auto kernel_w_ = weight_tensor->Width(); - - bias_data_ = malloc(UP_ROUND(output_channel, C8NUM) * sizeof(float)); - if (bias_data_ == nullptr) { - MS_LOG(ERROR) << "deconv malloc bias_data_ error!"; + int output_aligned_size = UP_ROUND(output_channel, C8NUM); + bias_ptr = reinterpret_cast(MallocAlignedData(C32NUM, output_aligned_size * sizeof(float))); + if (bias_ptr == nullptr) { + MS_LOG(ERROR) << "deconv malloc bias_ptr error!"; return RET_ERROR; } - memset(bias_data_, 0, UP_ROUND(output_channel, C8NUM) * sizeof(float)); - if (in_tensors_.size() == 3) { - if (in_tensors_.at(kBiasIndex)->shape().size() == 1 && + memset(bias_ptr, 0, output_aligned_size * sizeof(float)); + if (in_tensors_.size() == DIMENSION_3D) { + if 
(in_tensors_.at(kBiasIndex)->shape().size() == DIMENSION_1D && in_tensors_.at(kBiasIndex)->DimensionSize(0) == output_channel) { MS_ASSERT(in_tensors_.at(kBiasIndex)->data_c() != nullptr); - memcpy(bias_data_, in_tensors_.at(kBiasIndex)->data_c(), output_channel * sizeof(float)); + memcpy(bias_ptr, in_tensors_.at(kBiasIndex)->data_c(), output_channel * sizeof(float)); } else { MS_LOG(ERROR) << "unsupported bias shape for deconv!"; return RET_ERROR; } } - size_t weight_pack_size = input_channel * kernel_w_ * kernel_h_ * UP_ROUND(output_channel, C8NUM) * sizeof(float); - weight_ptr_ = reinterpret_cast(malloc(weight_pack_size)); + size_t weight_pack_size = input_channel * kernel_w_ * kernel_h_ * output_aligned_size * sizeof(float); + weight_ptr_ = reinterpret_cast(MallocAlignedData(C32NUM, weight_pack_size)); if (weight_ptr_ == nullptr) { MS_LOG(ERROR) << "deconv malloc weight_ptr_ error!"; return RET_ERROR; } memset(weight_ptr_, 0, weight_pack_size); MS_ASSERT(in_tensors_.at(kWeightIndex)->data_c() != nullptr); +#ifdef ENABLE_AVX + PackNHWCToCXHWNXFp32(reinterpret_cast(in_tensors_.at(kWeightIndex)->data_c()), weight_ptr_, input_channel, + kernel_w_ * kernel_h_, output_channel); +#else PackNHWCToC8HWN8Fp32(reinterpret_cast(in_tensors_.at(kWeightIndex)->data_c()), weight_ptr_, input_channel, kernel_w_ * kernel_h_, output_channel); +#endif return RET_OK; } @@ -97,12 +100,15 @@ int DeConvolutionCPUKernel::InitParam() { matmul_param_->row_ = input_plane_; matmul_param_->deep_ = conv_param_->input_channel_; matmul_param_->col_ = conv_param_->output_channel_ * kernel_plane_; - matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM); - matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM); + matmul_param_->row_align_ = UP_ROUND(matmul_param_->row_, row_tile_); matmul_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * kernel_plane_; thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(conv_param_->output_channel_, C8NUM)); +#ifdef ENABLE_AVX + 
thread_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C8NUM * C3NUM), thread_count_) * C3NUM; +#else thread_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C8NUM), thread_count_); +#endif return RET_OK; } @@ -125,26 +131,33 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) { if (oc <= 0 || oc_res <= 0) { return RET_OK; } - -#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) - auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_4_; - MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, - tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_4_, oc * C8NUM * kernel_plane_, - matmul_param_->col_, OutType_C8); + auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_align_; +#ifdef ENABLE_AVX + DeconvMatmulFloatAvx( + pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, tmp_buffer, + matmul_param_->deep_, matmul_param_->row_align_, oc * C8NUM * kernel_plane_, kernel_plane_); +#elif ENABLE_SSE + DeconvMatmulFloatSse(pack_input_, + weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, + tmp_buffer, matmul_param_->deep_, matmul_param_->row_align_, oc * C8NUM * kernel_plane_); #else - auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_12_; MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, - tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_12_, oc * C8NUM * kernel_plane_, - matmul_param_->col_, OutType_C8); + tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_align_, + oc * C8NUM * kernel_plane_, matmul_param_->col_, OutType_C8); #endif DeConvPostFp32C8(tmp_buffer, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_, - 
reinterpret_cast(bias_data_) + thread_stride_ * task_id * C8NUM, + reinterpret_cast(bias_ptr) + thread_stride_ * task_id * C8NUM, output_ptr_ + task_id * thread_stride_ * C8NUM, oc_res, conv_param_); return RET_OK; } int DeConvolutionCPUKernel::Init() { +#if defined(ENABLE_ARM32) || defined(ENABLE_AVX) || defined(ENABLE_SSE) + row_tile_ = C4NUM; +#else + row_tile_ = C12NUM; +#endif matmul_param_ = new (std::nothrow) MatMulParameter(); if (matmul_param_ == nullptr) { MS_LOG(ERROR) << "Memory allocation failed"; @@ -174,7 +187,6 @@ void DeConvolutionCPUKernel::FreeRunBuf() { ctx_->allocator->Free(pack_input_); pack_input_ = nullptr; } - return; } int DeConvolutionCPUKernel::InitRunBuf() { @@ -185,25 +197,15 @@ int DeConvolutionCPUKernel::InitRunBuf() { return RET_NULL_PTR; } -#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) - tmp_buffer_ = - reinterpret_cast(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->col_8_ * sizeof(float))); -#else - tmp_buffer_ = - reinterpret_cast(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->col_8_ * sizeof(float))); -#endif + tmp_buffer_ = reinterpret_cast( + ctx_->allocator->Malloc(matmul_param_->row_align_ * matmul_param_->col_8_ * sizeof(float))); if (tmp_buffer_ == nullptr) { MS_LOG(ERROR) << "Conv1x1 Malloc tmp_buffer_ error!"; return RET_NULL_PTR; } -#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) - pack_input_ = - reinterpret_cast(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float))); -#else - pack_input_ = - reinterpret_cast(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float))); -#endif + pack_input_ = reinterpret_cast( + ctx_->allocator->Malloc(matmul_param_->row_align_ * matmul_param_->deep_ * sizeof(float))); if (pack_input_ == nullptr) { MS_LOG(ERROR) << "deconv Malloc pack_input_ error!"; return RET_ERROR; @@ -254,6 +256,17 @@ kernel::InnerKernel *CpuDeConvFp32KernelCreator(const std::vector(op_parameter); kernel::InnerKernel 
*kernel = nullptr; if (conv_param->group_ == 1) { +#ifdef ENABLE_AVX + if ((conv_param->stride_h_ > 1 || conv_param->stride_w_ > 1) && + (conv_param->dilation_w_ == 1 && conv_param->dilation_h_ == 1) && + (conv_param->kernel_w_ / conv_param->stride_w_ > 2 || conv_param->kernel_h_ / conv_param->stride_h_ > 2)) { + kernel = new (std::nothrow) kernel::DeConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, + static_cast(ctx)); + } else { + kernel = new (std::nothrow) + kernel::DeConvolutionCPUKernel(op_parameter, inputs, outputs, static_cast(ctx)); + } +#else if ((conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1) && (conv_param->dilation_w_ == 1 && conv_param->dilation_h_ == 1)) { kernel = new (std::nothrow) kernel::DeConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, @@ -262,6 +275,7 @@ kernel::InnerKernel *CpuDeConvFp32KernelCreator(const std::vector(ctx)); } +#endif } else if (conv_param->group_ == conv_param->input_channel_ && conv_param->group_ == conv_param->output_channel_) { kernel = new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(op_parameter, inputs, outputs, static_cast(ctx)); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.h index ec90d37c846..83f10cd2b81 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.h @@ -54,12 +54,14 @@ class DeConvolutionCPUKernel : public ConvolutionBaseCPUKernel { int output_plane_ = 0; int thread_count_ = 1; int thread_stride_ = 0; + int row_tile_ = 0; float *weight_ptr_ = nullptr; float *pack_input_ = nullptr; float *pack_output_ = nullptr; float *tmp_buffer_ = nullptr; float *input_ptr_ = nullptr; float *output_ptr_ = nullptr; + float *bias_ptr = nullptr; }; } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_H_