!18846 [ms][lite][cpu] deconv optimize
Merge pull request !18846 from liuzhongkai/deconv1
This commit is contained in:
commit
593bf34d35
|
@ -960,17 +960,9 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
|
|||
MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
|
||||
}
|
||||
#elif ENABLE_AVX
|
||||
if (out_type == OutType_C8) {
|
||||
MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
|
||||
} else {
|
||||
MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type));
|
||||
}
|
||||
MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type));
|
||||
#elif ENABLE_SSE
|
||||
if (out_type == OutType_C8) {
|
||||
MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
|
||||
} else {
|
||||
MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
|
||||
}
|
||||
MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
|
||||
#else
|
||||
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
|
||||
#endif
|
||||
|
|
|
@ -73,11 +73,14 @@ void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float
|
|||
void MatmulFloatNeon32Opt12x4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
|
||||
int row, int col, int stride, int write_mode);
|
||||
#elif ENABLE_SSE
|
||||
void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
|
||||
int col, int stride, size_t writeNhwc, size_t WriteWino);
|
||||
void DeconvMatmulFloatSse(const float *a, const float *b, float *c, int depth, int row, int col);
|
||||
void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
|
||||
int col, int stride, int write_mode);
|
||||
#ifdef ENABLE_AVX
|
||||
typedef void (*DeconvAvxKernel)(const float *src, const float *weight, float *dst, int col, int row, int depth,
|
||||
int stride);
|
||||
void DeconvMatmulFloatAvx(const float *a, const float *b, float *c, int depth, int row, int col, int kernel_plane);
|
||||
void DeconvAvxColXRowKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride);
|
||||
void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, size_t act_type, size_t depth,
|
||||
size_t row, size_t col, size_t stride, size_t write_mode);
|
||||
typedef void (*MatVecMulKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
|
||||
|
|
|
@ -293,7 +293,26 @@ void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int
|
|||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel) {
|
||||
// pack weight NHWC to C24HWN24 (Priority 24)=>C16HWN16 (Not satisfied 24)=>C8HWN8 (Not satisfied 16)
|
||||
int oc_block = 0;
|
||||
int oc_block_num = UP_DIV(channel, C8NUM);
|
||||
for (int i = 0; i < oc_block_num; i += oc_block) {
|
||||
oc_block = MSMIN(C3NUM, oc_block_num - i); // max_tile = 4
|
||||
int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM);
|
||||
for (int p = 0; p < plane; ++p) {
|
||||
int index_plane = i * C8NUM + p * channel;
|
||||
for (int b = 0; b < batch; ++b) {
|
||||
int index_batch = index_plane + b * plane * channel;
|
||||
for (int oc = 0; oc < oc_remainder; ++oc) {
|
||||
dst[oc] = src[index_batch + oc];
|
||||
}
|
||||
dst += oc_block * C8NUM;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) {
|
||||
|
|
|
@ -40,6 +40,7 @@ void PackNHWCXToNHWCFp32(const void *src, void *dst, int batch, int plane, int c
|
|||
void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel);
|
||||
void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel);
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#ifdef ENABLE_SSE
|
||||
#include <x86intrin.h>
|
||||
#include "nnacl/fp32/matmul_fp32.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/matmul_parameter.h"
|
||||
#include "nnacl/intrinsics/sse/sse_common.h"
|
||||
|
@ -204,9 +205,7 @@ void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *
|
|||
}
|
||||
}
|
||||
|
||||
void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
|
||||
int col, int stride, size_t writeNhwc, size_t WriteWino) {
|
||||
size_t DstWinoSteps = stride * C8NUM, WriteWinoSteps = stride * col;
|
||||
void DeconvMatmulFloatSse(const float *a, const float *b, float *c, int depth, int row, int col) {
|
||||
for (int col_tmp = col; col_tmp > 0; col_tmp -= C8NUM) {
|
||||
const float *srca_d = a;
|
||||
float *dst = c;
|
||||
|
@ -231,63 +230,181 @@ void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bia
|
|||
dst7 = _mm_add_ps(dst7, tmp3), dst8 = _mm_add_ps(dst8, tmp4);
|
||||
srcb_d += C8NUM, srca_d += C4NUM;
|
||||
}
|
||||
|
||||
if (bias != NULL) {
|
||||
DoBiasBlock8(bias, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8);
|
||||
}
|
||||
|
||||
ActBlock8(&dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, act_type);
|
||||
|
||||
if (WriteWino != 0) { // WriteWino
|
||||
_mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2);
|
||||
dst += WriteWinoSteps;
|
||||
_mm_storeu_ps(dst, dst3), _mm_storeu_ps(dst + 4, dst4);
|
||||
dst += WriteWinoSteps;
|
||||
_mm_storeu_ps(dst, dst5), _mm_storeu_ps(dst + 4, dst6);
|
||||
dst += WriteWinoSteps;
|
||||
_mm_storeu_ps(dst, dst7), _mm_storeu_ps(dst + 4, dst8);
|
||||
dst += WriteWinoSteps;
|
||||
} else if (writeNhwc == 0) { // WriteC8
|
||||
_mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2);
|
||||
_mm_storeu_ps(dst + 8, dst3), _mm_storeu_ps(dst + 12, dst4);
|
||||
_mm_storeu_ps(dst + 16, dst5), _mm_storeu_ps(dst + 20, dst6);
|
||||
_mm_storeu_ps(dst + 24, dst7), _mm_storeu_ps(dst + 28, dst8);
|
||||
dst += 32;
|
||||
c = dst;
|
||||
} else {
|
||||
switch (col) {
|
||||
case 1: // write1
|
||||
WriteCol1(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r);
|
||||
case 2: // write2
|
||||
WriteCol2(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, r);
|
||||
case 3: // write3
|
||||
WriteCol3(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r);
|
||||
case 4: // write4
|
||||
WriteCol4(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r);
|
||||
case 5: // // write
|
||||
WriteCol5(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r);
|
||||
case 6: // write6
|
||||
WriteCol6(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r);
|
||||
case 7: // write7
|
||||
WriteCol7(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r);
|
||||
default: // write8
|
||||
WriteCol8(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r);
|
||||
}
|
||||
}
|
||||
if (r <= C4NUM) { // WriteEnd
|
||||
break;
|
||||
}
|
||||
_mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2);
|
||||
_mm_storeu_ps(dst + 8, dst3), _mm_storeu_ps(dst + 12, dst4);
|
||||
_mm_storeu_ps(dst + 16, dst5), _mm_storeu_ps(dst + 20, dst6);
|
||||
_mm_storeu_ps(dst + 24, dst7), _mm_storeu_ps(dst + 28, dst8);
|
||||
dst += 32;
|
||||
c = dst;
|
||||
}
|
||||
b += depth * C8NUM;
|
||||
bias += (bias != NULL) ? C8NUM : 0;
|
||||
if (WriteWino != 0) {
|
||||
c += DstWinoSteps;
|
||||
} else if (writeNhwc != 0) {
|
||||
c += C8NUM;
|
||||
}
|
||||
if (col_tmp <= C8NUM) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
void DeconvAvx4X8Kernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) {
|
||||
__m256 res1 = _mm256_setzero_ps();
|
||||
__m256 res4 = _mm256_setzero_ps();
|
||||
__m256 res7 = _mm256_setzero_ps();
|
||||
__m256 res10 = _mm256_setzero_ps();
|
||||
|
||||
for (int d = 0; d < depth; ++d) {
|
||||
__m256 w0 = _mm256_loadu_ps(weight);
|
||||
__m256 tmp = _mm256_set1_ps(*src);
|
||||
res1 = _mm256_fmadd_ps(tmp, w0, res1);
|
||||
tmp = _mm256_set1_ps(*(src + 1));
|
||||
res4 = _mm256_fmadd_ps(tmp, w0, res4);
|
||||
tmp = _mm256_set1_ps(*(src + 2));
|
||||
res7 = _mm256_fmadd_ps(tmp, w0, res7);
|
||||
tmp = _mm256_set1_ps(*(src + 3));
|
||||
res10 = _mm256_fmadd_ps(tmp, w0, res10);
|
||||
weight += C8NUM;
|
||||
src += C4NUM;
|
||||
}
|
||||
// write
|
||||
_mm256_storeu_ps(dst, res1);
|
||||
_mm256_storeu_ps(dst + C8NUM, res4);
|
||||
_mm256_storeu_ps(dst + C16NUM, res7);
|
||||
_mm256_storeu_ps(dst + C24NUM, res10);
|
||||
}
|
||||
|
||||
void DeconvAvx4X16Kernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) {
|
||||
__m256 res1 = _mm256_setzero_ps();
|
||||
__m256 res2 = _mm256_setzero_ps();
|
||||
__m256 res4 = _mm256_setzero_ps();
|
||||
__m256 res5 = _mm256_setzero_ps();
|
||||
__m256 res7 = _mm256_setzero_ps();
|
||||
__m256 res8 = _mm256_setzero_ps();
|
||||
__m256 res10 = _mm256_setzero_ps();
|
||||
__m256 res11 = _mm256_setzero_ps();
|
||||
|
||||
for (int d = 0; d < depth; ++d) {
|
||||
__m256 w0 = _mm256_loadu_ps(weight);
|
||||
__m256 w1 = _mm256_loadu_ps(weight + C8NUM);
|
||||
weight += C16NUM;
|
||||
__m256 tmp = _mm256_set1_ps(*src);
|
||||
res1 = _mm256_fmadd_ps(tmp, w0, res1);
|
||||
res2 = _mm256_fmadd_ps(tmp, w1, res2);
|
||||
tmp = _mm256_set1_ps(*(src + 1));
|
||||
res4 = _mm256_fmadd_ps(tmp, w0, res4);
|
||||
res5 = _mm256_fmadd_ps(tmp, w1, res5);
|
||||
tmp = _mm256_set1_ps(*(src + 2));
|
||||
res7 = _mm256_fmadd_ps(tmp, w0, res7);
|
||||
res8 = _mm256_fmadd_ps(tmp, w1, res8);
|
||||
tmp = _mm256_set1_ps(*(src + 3));
|
||||
res10 = _mm256_fmadd_ps(tmp, w0, res10);
|
||||
res11 = _mm256_fmadd_ps(tmp, w1, res11);
|
||||
src += C4NUM;
|
||||
}
|
||||
// write
|
||||
_mm256_storeu_ps(dst, res1);
|
||||
_mm256_storeu_ps(dst + C8NUM, res4);
|
||||
_mm256_storeu_ps(dst + C16NUM, res7);
|
||||
_mm256_storeu_ps(dst + C24NUM, res10);
|
||||
|
||||
_mm256_storeu_ps(dst + stride, res2);
|
||||
_mm256_storeu_ps(dst + stride + C8NUM, res5);
|
||||
_mm256_storeu_ps(dst + stride + C16NUM, res8);
|
||||
_mm256_storeu_ps(dst + stride + C24NUM, res11);
|
||||
}
|
||||
|
||||
void DeconvAvx4X24Kernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) {
|
||||
__m256 res1 = _mm256_setzero_ps();
|
||||
__m256 res2 = _mm256_setzero_ps();
|
||||
__m256 res3 = _mm256_setzero_ps();
|
||||
__m256 res4 = _mm256_setzero_ps();
|
||||
__m256 res5 = _mm256_setzero_ps();
|
||||
__m256 res6 = _mm256_setzero_ps();
|
||||
__m256 res7 = _mm256_setzero_ps();
|
||||
__m256 res8 = _mm256_setzero_ps();
|
||||
__m256 res9 = _mm256_setzero_ps();
|
||||
__m256 res10 = _mm256_setzero_ps();
|
||||
__m256 res11 = _mm256_setzero_ps();
|
||||
__m256 res12 = _mm256_setzero_ps();
|
||||
|
||||
for (int d = 0; d < depth; ++d) {
|
||||
__m256 w0 = _mm256_loadu_ps(weight);
|
||||
__m256 w1 = _mm256_loadu_ps(weight + C8NUM);
|
||||
__m256 w2 = _mm256_loadu_ps(weight + C16NUM);
|
||||
__m256 tmp = _mm256_set1_ps(*src);
|
||||
res1 = _mm256_fmadd_ps(tmp, w0, res1);
|
||||
res2 = _mm256_fmadd_ps(tmp, w1, res2);
|
||||
res3 = _mm256_fmadd_ps(tmp, w2, res3);
|
||||
tmp = _mm256_set1_ps(*(src + 1));
|
||||
res4 = _mm256_fmadd_ps(tmp, w0, res4);
|
||||
res5 = _mm256_fmadd_ps(tmp, w1, res5);
|
||||
res6 = _mm256_fmadd_ps(tmp, w2, res6);
|
||||
tmp = _mm256_set1_ps(*(src + 2));
|
||||
res7 = _mm256_fmadd_ps(tmp, w0, res7);
|
||||
res8 = _mm256_fmadd_ps(tmp, w1, res8);
|
||||
res9 = _mm256_fmadd_ps(tmp, w2, res9);
|
||||
tmp = _mm256_set1_ps(*(src + 3));
|
||||
res10 = _mm256_fmadd_ps(tmp, w0, res10);
|
||||
res11 = _mm256_fmadd_ps(tmp, w1, res11);
|
||||
res12 = _mm256_fmadd_ps(tmp, w2, res12);
|
||||
weight += C24NUM;
|
||||
src += C4NUM;
|
||||
}
|
||||
// write
|
||||
_mm256_storeu_ps(dst, res1);
|
||||
_mm256_storeu_ps(dst + C8NUM, res4);
|
||||
_mm256_storeu_ps(dst + C16NUM, res7);
|
||||
_mm256_storeu_ps(dst + C24NUM, res10);
|
||||
|
||||
_mm256_storeu_ps(dst + stride, res2);
|
||||
_mm256_storeu_ps(dst + stride + C8NUM, res5);
|
||||
_mm256_storeu_ps(dst + stride + C16NUM, res8);
|
||||
_mm256_storeu_ps(dst + stride + C24NUM, res11);
|
||||
|
||||
_mm256_storeu_ps(dst + 2 * stride, res3);
|
||||
_mm256_storeu_ps(dst + 2 * stride + C8NUM, res6);
|
||||
_mm256_storeu_ps(dst + 2 * stride + C16NUM, res9);
|
||||
_mm256_storeu_ps(dst + 2 * stride + C24NUM, res12);
|
||||
}
|
||||
|
||||
void DeconvMatmulFloatAvx(const float *a, const float *b, float *c, int depth, int row, int col, int plane) {
|
||||
int col_num = 0;
|
||||
int col_block = UP_DIV(col / plane, C8NUM);
|
||||
DeconvAvxKernel kernel[3] = {DeconvAvx4X8Kernel, DeconvAvx4X16Kernel, DeconvAvx4X24Kernel};
|
||||
for (int col_tmp = 0; col_tmp < col_block; col_tmp += col_num) {
|
||||
col_num = MSMIN(C3NUM, col_block - col_tmp);
|
||||
for (int p = 0; p < plane; ++p) {
|
||||
for (int r = 0; r < row; r += C4NUM) {
|
||||
kernel[col_num - 1](a + r * depth, b + (col_tmp * plane + p * col_num) * C8NUM * depth,
|
||||
c + (col_tmp * plane + p) * C8NUM * row + r * C8NUM, col_num, C4NUM, depth,
|
||||
row * C8NUM * plane);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DeconvAvxRowXColKernel(const float *src, const float *weight, float *dst, int col, int row, int depth,
|
||||
int stride) {
|
||||
__m256 res[C12NUM];
|
||||
__m256 w[C3NUM];
|
||||
for (int i = 0; i < C12NUM; ++i) {
|
||||
res[i] = _mm256_setzero_ps();
|
||||
}
|
||||
for (int d = 0; d < depth; ++d) {
|
||||
for (int c = 0; c < col; ++c) {
|
||||
w[c] = _mm256_loadu_ps(weight + c * C8NUM);
|
||||
}
|
||||
weight += col * C8NUM;
|
||||
for (int r = 0; r < row; ++r) { // C4NUm
|
||||
__m256 tmp = _mm256_set1_ps(*src);
|
||||
for (int c = 0; c < col; ++c) { // 3 * C8NUM
|
||||
res[r * col + c] = _mm256_fmadd_ps(tmp, w[c], res[r * col + c]);
|
||||
}
|
||||
src += 1;
|
||||
}
|
||||
}
|
||||
// write
|
||||
for (int i = 0; i < col; ++i) {
|
||||
for (int j = 0; j < row; ++j) {
|
||||
_mm256_storeu_ps(dst + j * C8NUM, res[j * col + i]);
|
||||
}
|
||||
dst += stride;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
|
|
@ -26,11 +26,13 @@
|
|||
#endif
|
||||
|
||||
#define C2NUM 2
|
||||
#define C3NUM 3
|
||||
#define C4NUM 4
|
||||
#define C6NUM 6
|
||||
#define C8NUM 8
|
||||
#define C12NUM 12
|
||||
#define C16NUM 16
|
||||
#define C24NUM 24
|
||||
#define C32NUM 32
|
||||
#define TILE_NUM 8
|
||||
|
||||
|
|
|
@ -31,10 +31,8 @@ DeConvolutionCPUKernel::~DeConvolutionCPUKernel() {
|
|||
delete matmul_param_;
|
||||
matmul_param_ = nullptr;
|
||||
}
|
||||
if (weight_ptr_ != nullptr) {
|
||||
free(weight_ptr_);
|
||||
weight_ptr_ = nullptr;
|
||||
}
|
||||
FreeAlignedData(reinterpret_cast<void **>(&weight_ptr_));
|
||||
FreeAlignedData(reinterpret_cast<void **>(&bias_ptr));
|
||||
}
|
||||
|
||||
int DeConvolutionCPUKernel::ReSize() {
|
||||
|
@ -58,34 +56,39 @@ int DeConvolutionCPUKernel::InitWeightBias() {
|
|||
auto output_channel = weight_tensor->Channel();
|
||||
auto kernel_h_ = weight_tensor->Height();
|
||||
auto kernel_w_ = weight_tensor->Width();
|
||||
|
||||
bias_data_ = malloc(UP_ROUND(output_channel, C8NUM) * sizeof(float));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "deconv malloc bias_data_ error!";
|
||||
int output_aligned_size = UP_ROUND(output_channel, C8NUM);
|
||||
bias_ptr = reinterpret_cast<float *>(MallocAlignedData(C32NUM, output_aligned_size * sizeof(float)));
|
||||
if (bias_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "deconv malloc bias_ptr error!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(bias_data_, 0, UP_ROUND(output_channel, C8NUM) * sizeof(float));
|
||||
if (in_tensors_.size() == 3) {
|
||||
if (in_tensors_.at(kBiasIndex)->shape().size() == 1 &&
|
||||
memset(bias_ptr, 0, output_aligned_size * sizeof(float));
|
||||
if (in_tensors_.size() == DIMENSION_3D) {
|
||||
if (in_tensors_.at(kBiasIndex)->shape().size() == DIMENSION_1D &&
|
||||
in_tensors_.at(kBiasIndex)->DimensionSize(0) == output_channel) {
|
||||
MS_ASSERT(in_tensors_.at(kBiasIndex)->data_c() != nullptr);
|
||||
memcpy(bias_data_, in_tensors_.at(kBiasIndex)->data_c(), output_channel * sizeof(float));
|
||||
memcpy(bias_ptr, in_tensors_.at(kBiasIndex)->data_c(), output_channel * sizeof(float));
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported bias shape for deconv!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
size_t weight_pack_size = input_channel * kernel_w_ * kernel_h_ * UP_ROUND(output_channel, C8NUM) * sizeof(float);
|
||||
weight_ptr_ = reinterpret_cast<float *>(malloc(weight_pack_size));
|
||||
size_t weight_pack_size = input_channel * kernel_w_ * kernel_h_ * output_aligned_size * sizeof(float);
|
||||
weight_ptr_ = reinterpret_cast<float *>(MallocAlignedData(C32NUM, weight_pack_size));
|
||||
if (weight_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "deconv malloc weight_ptr_ error!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(weight_ptr_, 0, weight_pack_size);
|
||||
MS_ASSERT(in_tensors_.at(kWeightIndex)->data_c() != nullptr);
|
||||
#ifdef ENABLE_AVX
|
||||
PackNHWCToCXHWNXFp32(reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->data_c()), weight_ptr_, input_channel,
|
||||
kernel_w_ * kernel_h_, output_channel);
|
||||
#else
|
||||
PackNHWCToC8HWN8Fp32(reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->data_c()), weight_ptr_, input_channel,
|
||||
kernel_w_ * kernel_h_, output_channel);
|
||||
#endif
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -97,12 +100,15 @@ int DeConvolutionCPUKernel::InitParam() {
|
|||
matmul_param_->row_ = input_plane_;
|
||||
matmul_param_->deep_ = conv_param_->input_channel_;
|
||||
matmul_param_->col_ = conv_param_->output_channel_ * kernel_plane_;
|
||||
matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM);
|
||||
matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM);
|
||||
matmul_param_->row_align_ = UP_ROUND(matmul_param_->row_, row_tile_);
|
||||
matmul_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * kernel_plane_;
|
||||
|
||||
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(conv_param_->output_channel_, C8NUM));
|
||||
#ifdef ENABLE_AVX
|
||||
thread_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C8NUM * C3NUM), thread_count_) * C3NUM;
|
||||
#else
|
||||
thread_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C8NUM), thread_count_);
|
||||
#endif
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -125,26 +131,33 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) {
|
|||
if (oc <= 0 || oc_res <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
|
||||
auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_4_;
|
||||
MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
|
||||
tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_4_, oc * C8NUM * kernel_plane_,
|
||||
matmul_param_->col_, OutType_C8);
|
||||
auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_align_;
|
||||
#ifdef ENABLE_AVX
|
||||
DeconvMatmulFloatAvx(
|
||||
pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, tmp_buffer,
|
||||
matmul_param_->deep_, matmul_param_->row_align_, oc * C8NUM * kernel_plane_, kernel_plane_);
|
||||
#elif ENABLE_SSE
|
||||
DeconvMatmulFloatSse(pack_input_,
|
||||
weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
|
||||
tmp_buffer, matmul_param_->deep_, matmul_param_->row_align_, oc * C8NUM * kernel_plane_);
|
||||
#else
|
||||
auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_12_;
|
||||
MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
|
||||
tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_12_, oc * C8NUM * kernel_plane_,
|
||||
matmul_param_->col_, OutType_C8);
|
||||
tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_align_,
|
||||
oc * C8NUM * kernel_plane_, matmul_param_->col_, OutType_C8);
|
||||
#endif
|
||||
|
||||
DeConvPostFp32C8(tmp_buffer, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_,
|
||||
reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM,
|
||||
reinterpret_cast<float *>(bias_ptr) + thread_stride_ * task_id * C8NUM,
|
||||
output_ptr_ + task_id * thread_stride_ * C8NUM, oc_res, conv_param_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int DeConvolutionCPUKernel::Init() {
|
||||
#if defined(ENABLE_ARM32) || defined(ENABLE_AVX) || defined(ENABLE_SSE)
|
||||
row_tile_ = C4NUM;
|
||||
#else
|
||||
row_tile_ = C12NUM;
|
||||
#endif
|
||||
matmul_param_ = new (std::nothrow) MatMulParameter();
|
||||
if (matmul_param_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Memory allocation failed";
|
||||
|
@ -174,7 +187,6 @@ void DeConvolutionCPUKernel::FreeRunBuf() {
|
|||
ctx_->allocator->Free(pack_input_);
|
||||
pack_input_ = nullptr;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
int DeConvolutionCPUKernel::InitRunBuf() {
|
||||
|
@ -185,25 +197,15 @@ int DeConvolutionCPUKernel::InitRunBuf() {
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
|
||||
tmp_buffer_ =
|
||||
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->col_8_ * sizeof(float)));
|
||||
#else
|
||||
tmp_buffer_ =
|
||||
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->col_8_ * sizeof(float)));
|
||||
#endif
|
||||
tmp_buffer_ = reinterpret_cast<float *>(
|
||||
ctx_->allocator->Malloc(matmul_param_->row_align_ * matmul_param_->col_8_ * sizeof(float)));
|
||||
if (tmp_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Conv1x1 Malloc tmp_buffer_ error!";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
|
||||
pack_input_ =
|
||||
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float)));
|
||||
#else
|
||||
pack_input_ =
|
||||
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float)));
|
||||
#endif
|
||||
pack_input_ = reinterpret_cast<float *>(
|
||||
ctx_->allocator->Malloc(matmul_param_->row_align_ * matmul_param_->deep_ * sizeof(float)));
|
||||
if (pack_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "deconv Malloc pack_input_ error!";
|
||||
return RET_ERROR;
|
||||
|
@ -254,6 +256,17 @@ kernel::InnerKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *
|
|||
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
|
||||
kernel::InnerKernel *kernel = nullptr;
|
||||
if (conv_param->group_ == 1) {
|
||||
#ifdef ENABLE_AVX
|
||||
if ((conv_param->stride_h_ > 1 || conv_param->stride_w_ > 1) &&
|
||||
(conv_param->dilation_w_ == 1 && conv_param->dilation_h_ == 1) &&
|
||||
(conv_param->kernel_w_ / conv_param->stride_w_ > 2 || conv_param->kernel_h_ / conv_param->stride_h_ > 2)) {
|
||||
kernel = new (std::nothrow) kernel::DeConvolutionWinogradCPUKernel(op_parameter, inputs, outputs,
|
||||
static_cast<const lite::InnerContext *>(ctx));
|
||||
} else {
|
||||
kernel = new (std::nothrow)
|
||||
kernel::DeConvolutionCPUKernel(op_parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
|
||||
}
|
||||
#else
|
||||
if ((conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1) &&
|
||||
(conv_param->dilation_w_ == 1 && conv_param->dilation_h_ == 1)) {
|
||||
kernel = new (std::nothrow) kernel::DeConvolutionWinogradCPUKernel(op_parameter, inputs, outputs,
|
||||
|
@ -262,6 +275,7 @@ kernel::InnerKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *
|
|||
kernel = new (std::nothrow)
|
||||
kernel::DeConvolutionCPUKernel(op_parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
|
||||
}
|
||||
#endif
|
||||
} else if (conv_param->group_ == conv_param->input_channel_ && conv_param->group_ == conv_param->output_channel_) {
|
||||
kernel = new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(op_parameter, inputs, outputs,
|
||||
static_cast<const lite::InnerContext *>(ctx));
|
||||
|
|
|
@ -54,12 +54,14 @@ class DeConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
|
|||
int output_plane_ = 0;
|
||||
int thread_count_ = 1;
|
||||
int thread_stride_ = 0;
|
||||
int row_tile_ = 0;
|
||||
float *weight_ptr_ = nullptr;
|
||||
float *pack_input_ = nullptr;
|
||||
float *pack_output_ = nullptr;
|
||||
float *tmp_buffer_ = nullptr;
|
||||
float *input_ptr_ = nullptr;
|
||||
float *output_ptr_ = nullptr;
|
||||
float *bias_ptr = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_H_
|
||||
|
|
Loading…
Reference in New Issue