lzk 2021-06-24 20:30:01 -07:00
parent 49f012ad74
commit a0ad5777bb
8 changed files with 261 additions and 111 deletions

View File

@@ -960,17 +960,9 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
}
#elif ENABLE_AVX
if (out_type == OutType_C8) {
MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
} else {
MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type));
}
MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type));
#elif ENABLE_SSE
if (out_type == OutType_C8) {
MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
} else {
MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
}
MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
#else
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
#endif
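Every branch above computes the same matmul; a minimal scalar sketch of the OutType_Nhwc case, assuming unpacked row-major a and a deep x col b (unlike the packed layouts the optimized kernels require) and handling only the ReLU-family activations:

void MatMulRefNhwc(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep,
                   int row, int col, int stride) {
  for (int r = 0; r < row; ++r) {
    for (int j = 0; j < col; ++j) {
      float acc = (bias != NULL) ? bias[j] : 0.0f;
      for (int d = 0; d < deep; ++d) {
        acc += a[r * deep + d] * b[d * col + j];
      }
      if (act_type == ActType_Relu || act_type == ActType_Relu6) {
        acc = acc > 0.0f ? acc : 0.0f;  // ReLU lower bound
      }
      if (act_type == ActType_Relu6) {
        acc = acc < 6.0f ? acc : 6.0f;  // ReLU6 upper bound
      }
      c[r * stride + j] = acc;  // one output row every `stride` floats
    }
  }
}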

View File

@@ -73,11 +73,14 @@ void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float
void MatmulFloatNeon32Opt12x4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
int row, int col, int stride, int write_mode);
#elif ENABLE_SSE
void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, int stride, size_t writeNhwc, size_t WriteWino);
void DeconvMatmulFloatSse(const float *a, const float *b, float *c, int depth, int row, int col);
void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, int stride, int write_mode);
#ifdef ENABLE_AVX
typedef void (*DeconvAvxKernel)(const float *src, const float *weight, float *dst, int col, int row, int depth,
int stride);
void DeconvMatmulFloatAvx(const float *a, const float *b, float *c, int depth, int row, int col, int kernel_plane);
void DeconvAvxRowXColKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride);
void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, size_t act_type, size_t depth,
size_t row, size_t col, size_t stride, size_t write_mode);
typedef void (*MatVecMulKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,

View File

@@ -293,7 +293,26 @@ void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int
}
}
}
return;
}
void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel) {
// pack weight NHWC to C24HWN24; prefer 24-wide tiles, fall back to C16HWN16 when fewer than 24 output channels remain, then C8HWN8 when fewer than 16 remain
int oc_block = 0;
int oc_block_num = UP_DIV(channel, C8NUM);
for (int i = 0; i < oc_block_num; i += oc_block) {
oc_block = MSMIN(C3NUM, oc_block_num - i);  // at most C3NUM blocks of C8NUM, i.e. a C24 tile
int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM);
for (int p = 0; p < plane; ++p) {
int index_plane = i * C8NUM + p * channel;
for (int b = 0; b < batch; ++b) {
int index_batch = index_plane + b * plane * channel;
for (int oc = 0; oc < oc_remainder; ++oc) {
dst[oc] = src[index_batch + oc];
}
dst += oc_block * C8NUM;
}
}
}
}
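The tile selection above can be checked in isolation; a standalone sketch with UP_DIV and MSMIN expanded inline and a hypothetical channel = 35, which prints a 24-wide tile (24 valid channels) followed by a 16-wide tile (11 valid channels):

#include <stdio.h>

int main(void) {
  const int c8 = 8, c3 = 3;  // C8NUM, C3NUM
  int channel = 35;          // hypothetical channel count
  int oc_block_num = (channel + c8 - 1) / c8;  // UP_DIV(channel, C8NUM)
  for (int i = 0, oc_block = 0; i < oc_block_num; i += oc_block) {
    oc_block = (c3 < oc_block_num - i) ? c3 : oc_block_num - i;  // MSMIN(C3NUM, ...)
    int remainder = channel - i * c8;
    int oc_remainder = (c8 * oc_block < remainder) ? c8 * oc_block : remainder;
    printf("tile width %d, valid channels %d\n", oc_block * c8, oc_remainder);
  }
  return 0;
}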
void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) {

View File

@@ -40,6 +40,7 @@ void PackNHWCXToNHWCFp32(const void *src, void *dst, int batch, int plane, int c
void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel);
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel);
void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel);

View File

@@ -16,6 +16,7 @@
#ifdef ENABLE_SSE
#include <x86intrin.h>
#include "nnacl/fp32/matmul_fp32.h"
#include "nnacl/op_base.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/intrinsics/sse/sse_common.h"
@@ -204,9 +205,7 @@ void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *
}
}
void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, int stride, size_t writeNhwc, size_t WriteWino) {
size_t DstWinoSteps = stride * C8NUM, WriteWinoSteps = stride * col;
void DeconvMatmulFloatSse(const float *a, const float *b, float *c, int depth, int row, int col) {
for (int col_tmp = col; col_tmp > 0; col_tmp -= C8NUM) {
const float *srca_d = a;
float *dst = c;
@@ -231,63 +230,181 @@ void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bia
dst7 = _mm_add_ps(dst7, tmp3), dst8 = _mm_add_ps(dst8, tmp4);
srcb_d += C8NUM, srca_d += C4NUM;
}
if (bias != NULL) {
DoBiasBlock8(bias, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8);
}
ActBlock8(&dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, act_type);
if (WriteWino != 0) { // WriteWino
_mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2);
dst += WriteWinoSteps;
_mm_storeu_ps(dst, dst3), _mm_storeu_ps(dst + 4, dst4);
dst += WriteWinoSteps;
_mm_storeu_ps(dst, dst5), _mm_storeu_ps(dst + 4, dst6);
dst += WriteWinoSteps;
_mm_storeu_ps(dst, dst7), _mm_storeu_ps(dst + 4, dst8);
dst += WriteWinoSteps;
} else if (writeNhwc == 0) { // WriteC8
_mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2);
_mm_storeu_ps(dst + 8, dst3), _mm_storeu_ps(dst + 12, dst4);
_mm_storeu_ps(dst + 16, dst5), _mm_storeu_ps(dst + 20, dst6);
_mm_storeu_ps(dst + 24, dst7), _mm_storeu_ps(dst + 28, dst8);
dst += 32;
c = dst;
} else {
switch (col) {
case 1: // write1
WriteCol1(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r);
case 2: // write2
WriteCol2(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, r);
case 3: // write3
WriteCol3(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r);
case 4: // write4
WriteCol4(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r);
case 5:  // write5
WriteCol5(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r);
case 6: // write6
WriteCol6(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r);
case 7: // write7
WriteCol7(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r);
default: // write8
WriteCol8(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 0, r);
}
}
if (r <= C4NUM) { // WriteEnd
break;
}
_mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2);
_mm_storeu_ps(dst + 8, dst3), _mm_storeu_ps(dst + 12, dst4);
_mm_storeu_ps(dst + 16, dst5), _mm_storeu_ps(dst + 20, dst6);
_mm_storeu_ps(dst + 24, dst7), _mm_storeu_ps(dst + 28, dst8);
dst += 32;
c = dst;
}
b += depth * C8NUM;
bias += (bias != NULL) ? C8NUM : 0;
if (WriteWino != 0) {
c += DstWinoSteps;
} else if (writeNhwc != 0) {
c += C8NUM;
}
if (col_tmp <= C8NUM) {
break;
}
}
#ifdef ENABLE_AVX
void DeconvAvx4X8Kernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) {
__m256 res1 = _mm256_setzero_ps();
__m256 res4 = _mm256_setzero_ps();
__m256 res7 = _mm256_setzero_ps();
__m256 res10 = _mm256_setzero_ps();
for (int d = 0; d < depth; ++d) {
__m256 w0 = _mm256_loadu_ps(weight);
__m256 tmp = _mm256_set1_ps(*src);
res1 = _mm256_fmadd_ps(tmp, w0, res1);
tmp = _mm256_set1_ps(*(src + 1));
res4 = _mm256_fmadd_ps(tmp, w0, res4);
tmp = _mm256_set1_ps(*(src + 2));
res7 = _mm256_fmadd_ps(tmp, w0, res7);
tmp = _mm256_set1_ps(*(src + 3));
res10 = _mm256_fmadd_ps(tmp, w0, res10);
weight += C8NUM;
src += C4NUM;
}
// write
_mm256_storeu_ps(dst, res1);
_mm256_storeu_ps(dst + C8NUM, res4);
_mm256_storeu_ps(dst + C16NUM, res7);
_mm256_storeu_ps(dst + C24NUM, res10);
}
void DeconvAvx4X16Kernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) {
__m256 res1 = _mm256_setzero_ps();
__m256 res2 = _mm256_setzero_ps();
__m256 res4 = _mm256_setzero_ps();
__m256 res5 = _mm256_setzero_ps();
__m256 res7 = _mm256_setzero_ps();
__m256 res8 = _mm256_setzero_ps();
__m256 res10 = _mm256_setzero_ps();
__m256 res11 = _mm256_setzero_ps();
for (int d = 0; d < depth; ++d) {
__m256 w0 = _mm256_loadu_ps(weight);
__m256 w1 = _mm256_loadu_ps(weight + C8NUM);
weight += C16NUM;
__m256 tmp = _mm256_set1_ps(*src);
res1 = _mm256_fmadd_ps(tmp, w0, res1);
res2 = _mm256_fmadd_ps(tmp, w1, res2);
tmp = _mm256_set1_ps(*(src + 1));
res4 = _mm256_fmadd_ps(tmp, w0, res4);
res5 = _mm256_fmadd_ps(tmp, w1, res5);
tmp = _mm256_set1_ps(*(src + 2));
res7 = _mm256_fmadd_ps(tmp, w0, res7);
res8 = _mm256_fmadd_ps(tmp, w1, res8);
tmp = _mm256_set1_ps(*(src + 3));
res10 = _mm256_fmadd_ps(tmp, w0, res10);
res11 = _mm256_fmadd_ps(tmp, w1, res11);
src += C4NUM;
}
// write
_mm256_storeu_ps(dst, res1);
_mm256_storeu_ps(dst + C8NUM, res4);
_mm256_storeu_ps(dst + C16NUM, res7);
_mm256_storeu_ps(dst + C24NUM, res10);
_mm256_storeu_ps(dst + stride, res2);
_mm256_storeu_ps(dst + stride + C8NUM, res5);
_mm256_storeu_ps(dst + stride + C16NUM, res8);
_mm256_storeu_ps(dst + stride + C24NUM, res11);
}
void DeconvAvx4X24Kernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) {
__m256 res1 = _mm256_setzero_ps();
__m256 res2 = _mm256_setzero_ps();
__m256 res3 = _mm256_setzero_ps();
__m256 res4 = _mm256_setzero_ps();
__m256 res5 = _mm256_setzero_ps();
__m256 res6 = _mm256_setzero_ps();
__m256 res7 = _mm256_setzero_ps();
__m256 res8 = _mm256_setzero_ps();
__m256 res9 = _mm256_setzero_ps();
__m256 res10 = _mm256_setzero_ps();
__m256 res11 = _mm256_setzero_ps();
__m256 res12 = _mm256_setzero_ps();
for (int d = 0; d < depth; ++d) {
__m256 w0 = _mm256_loadu_ps(weight);
__m256 w1 = _mm256_loadu_ps(weight + C8NUM);
__m256 w2 = _mm256_loadu_ps(weight + C16NUM);
__m256 tmp = _mm256_set1_ps(*src);
res1 = _mm256_fmadd_ps(tmp, w0, res1);
res2 = _mm256_fmadd_ps(tmp, w1, res2);
res3 = _mm256_fmadd_ps(tmp, w2, res3);
tmp = _mm256_set1_ps(*(src + 1));
res4 = _mm256_fmadd_ps(tmp, w0, res4);
res5 = _mm256_fmadd_ps(tmp, w1, res5);
res6 = _mm256_fmadd_ps(tmp, w2, res6);
tmp = _mm256_set1_ps(*(src + 2));
res7 = _mm256_fmadd_ps(tmp, w0, res7);
res8 = _mm256_fmadd_ps(tmp, w1, res8);
res9 = _mm256_fmadd_ps(tmp, w2, res9);
tmp = _mm256_set1_ps(*(src + 3));
res10 = _mm256_fmadd_ps(tmp, w0, res10);
res11 = _mm256_fmadd_ps(tmp, w1, res11);
res12 = _mm256_fmadd_ps(tmp, w2, res12);
weight += C24NUM;
src += C4NUM;
}
// write
_mm256_storeu_ps(dst, res1);
_mm256_storeu_ps(dst + C8NUM, res4);
_mm256_storeu_ps(dst + C16NUM, res7);
_mm256_storeu_ps(dst + C24NUM, res10);
_mm256_storeu_ps(dst + stride, res2);
_mm256_storeu_ps(dst + stride + C8NUM, res5);
_mm256_storeu_ps(dst + stride + C16NUM, res8);
_mm256_storeu_ps(dst + stride + C24NUM, res11);
_mm256_storeu_ps(dst + 2 * stride, res3);
_mm256_storeu_ps(dst + 2 * stride + C8NUM, res6);
_mm256_storeu_ps(dst + 2 * stride + C16NUM, res9);
_mm256_storeu_ps(dst + 2 * stride + C24NUM, res12);
}
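All three kernels unroll the same four-row inner product over one, two, or three C8 column blocks; a scalar reference under the same layout assumptions (src yields C4NUM values per depth step, weight yields col_blocks * C8NUM values per depth step, consecutive column blocks land `stride` floats apart in dst), usable as a test oracle for the intrinsics:

void DeconvRefKernel(const float *src, const float *weight, float *dst, int col_blocks, int depth, int stride) {
  for (int cb = 0; cb < col_blocks; ++cb) {   // col_blocks in {1, 2, 3}
    for (int r = 0; r < 4; ++r) {             // C4NUM rows
      for (int lane = 0; lane < 8; ++lane) {  // C8NUM lanes
        float acc = 0.0f;
        for (int d = 0; d < depth; ++d) {
          acc += src[d * 4 + r] * weight[d * col_blocks * 8 + cb * 8 + lane];
        }
        dst[cb * stride + r * 8 + lane] = acc;
      }
    }
  }
}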
void DeconvMatmulFloatAvx(const float *a, const float *b, float *c, int depth, int row, int col, int plane) {
int col_num = 0;
int col_block = UP_DIV(col / plane, C8NUM);
DeconvAvxKernel kernel[3] = {DeconvAvx4X8Kernel, DeconvAvx4X16Kernel, DeconvAvx4X24Kernel};
for (int col_tmp = 0; col_tmp < col_block; col_tmp += col_num) {
col_num = MSMIN(C3NUM, col_block - col_tmp);
for (int p = 0; p < plane; ++p) {
for (int r = 0; r < row; r += C4NUM) {
kernel[col_num - 1](a + r * depth, b + (col_tmp * plane + p * col_num) * C8NUM * depth,
c + (col_tmp * plane + p) * C8NUM * row + r * C8NUM, col_num, C4NUM, depth,
row * C8NUM * plane);
}
}
}
}
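The offsets above follow the [tile][plane][depth][tile_width] weight layout produced by PackNHWCToCXHWNXFp32 and a [c8_block][plane][row][C8NUM] output layout; spelled out for the tile starting at C8 block col_tmp and plane p:

// weight offset = col_tmp * plane * C8NUM * depth   (tiles already emitted)
//               + p * col_num * C8NUM * depth       (planes inside this tile)
// output offset = (col_tmp * plane + p) * C8NUM * row + r * C8NUM,
// and consecutive C8 blocks of one tile land row * C8NUM * plane floats apart (the kernel stride).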
void DeconvAvxRowXColKernel(const float *src, const float *weight, float *dst, int col, int row, int depth,
int stride) {
__m256 res[C12NUM];
__m256 w[C3NUM];
for (int i = 0; i < C12NUM; ++i) {
res[i] = _mm256_setzero_ps();
}
for (int d = 0; d < depth; ++d) {
for (int c = 0; c < col; ++c) {
w[c] = _mm256_loadu_ps(weight + c * C8NUM);
}
weight += col * C8NUM;
for (int r = 0; r < row; ++r) {  // rows within the C4NUM tile
__m256 tmp = _mm256_set1_ps(*src);
for (int c = 0; c < col; ++c) {  // up to C3NUM C8-wide column blocks
res[r * col + c] = _mm256_fmadd_ps(tmp, w[c], res[r * col + c]);
}
src += 1;
}
}
// write
for (int i = 0; i < col; ++i) {
for (int j = 0; j < row; ++j) {
_mm256_storeu_ps(dst + j * C8NUM, res[j * col + i]);
}
dst += stride;
}
}
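DeconvAvxRowXColKernel handles the residual tiles where row < C4NUM or fewer than C3NUM column blocks remain; on a full tile it matches the unrolled kernels, so with hypothetical buffers src, weight, and dst these two calls store identical results:

DeconvAvx4X24Kernel(src, weight, dst, /*col=*/3, /*row=*/4, depth, stride);
DeconvAvxRowXColKernel(src, weight, dst, /*col=*/3, /*row=*/4, depth, stride);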
#endif
#endif

View File

@@ -26,11 +26,13 @@
#endif
#define C2NUM 2
#define C3NUM 3
#define C4NUM 4
#define C6NUM 6
#define C8NUM 8
#define C12NUM 12
#define C16NUM 16
#define C24NUM 24
#define C32NUM 32
#define TILE_NUM 8

View File

@@ -31,10 +31,8 @@ DeConvolutionCPUKernel::~DeConvolutionCPUKernel() {
delete matmul_param_;
matmul_param_ = nullptr;
}
if (weight_ptr_ != nullptr) {
free(weight_ptr_);
weight_ptr_ = nullptr;
}
FreeAlignedData(reinterpret_cast<void **>(&weight_ptr_));
FreeAlignedData(reinterpret_cast<void **>(&bias_ptr));
}
int DeConvolutionCPUKernel::ReSize() {
@@ -58,34 +56,39 @@ int DeConvolutionCPUKernel::InitWeightBias() {
auto output_channel = weight_tensor->Channel();
auto kernel_h_ = weight_tensor->Height();
auto kernel_w_ = weight_tensor->Width();
bias_data_ = malloc(UP_ROUND(output_channel, C8NUM) * sizeof(float));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "deconv malloc bias_data_ error!";
int output_aligned_size = UP_ROUND(output_channel, C8NUM);
bias_ptr = reinterpret_cast<float *>(MallocAlignedData(C32NUM, output_aligned_size * sizeof(float)));
if (bias_ptr == nullptr) {
MS_LOG(ERROR) << "deconv malloc bias_ptr error!";
return RET_ERROR;
}
memset(bias_data_, 0, UP_ROUND(output_channel, C8NUM) * sizeof(float));
if (in_tensors_.size() == 3) {
if (in_tensors_.at(kBiasIndex)->shape().size() == 1 &&
memset(bias_ptr, 0, output_aligned_size * sizeof(float));
if (in_tensors_.size() == DIMENSION_3D) {
if (in_tensors_.at(kBiasIndex)->shape().size() == DIMENSION_1D &&
in_tensors_.at(kBiasIndex)->DimensionSize(0) == output_channel) {
MS_ASSERT(in_tensors_.at(kBiasIndex)->data_c() != nullptr);
memcpy(bias_data_, in_tensors_.at(kBiasIndex)->data_c(), output_channel * sizeof(float));
memcpy(bias_ptr, in_tensors_.at(kBiasIndex)->data_c(), output_channel * sizeof(float));
} else {
MS_LOG(ERROR) << "unsupported bias shape for deconv!";
return RET_ERROR;
}
}
size_t weight_pack_size = input_channel * kernel_w_ * kernel_h_ * UP_ROUND(output_channel, C8NUM) * sizeof(float);
weight_ptr_ = reinterpret_cast<float *>(malloc(weight_pack_size));
size_t weight_pack_size = input_channel * kernel_w_ * kernel_h_ * output_aligned_size * sizeof(float);
weight_ptr_ = reinterpret_cast<float *>(MallocAlignedData(C32NUM, weight_pack_size));
if (weight_ptr_ == nullptr) {
MS_LOG(ERROR) << "deconv malloc weight_ptr_ error!";
return RET_ERROR;
}
memset(weight_ptr_, 0, weight_pack_size);
MS_ASSERT(in_tensors_.at(kWeightIndex)->data_c() != nullptr);
#ifdef ENABLE_AVX
PackNHWCToCXHWNXFp32(reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->data_c()), weight_ptr_, input_channel,
kernel_w_ * kernel_h_, output_channel);
#else
PackNHWCToC8HWN8Fp32(reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->data_c()), weight_ptr_, input_channel,
kernel_w_ * kernel_h_, output_channel);
#endif
return RET_OK;
}
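MallocAlignedData/FreeAlignedData provide the C32NUM (32-byte) alignment that 256-bit AVX loads favor; a portable sketch of the contract assumed here, not the actual nnacl implementation:

#include <stdlib.h>

void *MallocAlignedDataSketch(size_t alignment, size_t size) {
  void *ptr = NULL;  // alignment must be a power of two and a multiple of sizeof(void *)
  return (posix_memalign(&ptr, alignment, size) == 0) ? ptr : NULL;
}

void FreeAlignedDataSketch(void **ptr) {
  if (ptr != NULL && *ptr != NULL) {
    free(*ptr);
    *ptr = NULL;  // match the null-out behavior the destructor relies on
  }
}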
@@ -97,12 +100,15 @@ int DeConvolutionCPUKernel::InitParam() {
matmul_param_->row_ = input_plane_;
matmul_param_->deep_ = conv_param_->input_channel_;
matmul_param_->col_ = conv_param_->output_channel_ * kernel_plane_;
matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM);
matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM);
matmul_param_->row_align_ = UP_ROUND(matmul_param_->row_, row_tile_);
matmul_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * kernel_plane_;
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(conv_param_->output_channel_, C8NUM));
#ifdef ENABLE_AVX
thread_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C8NUM * C3NUM), thread_count_) * C3NUM;
#else
thread_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C8NUM), thread_count_);
#endif
return RET_OK;
}
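The AVX branch rounds each thread's stride up to whole C3NUM groups of C8 blocks, so threads hand the kernels complete C24 tiles; worked through for hypothetical sizes output_channel_ = 100, thread_num_ = 4:

// thread_count_ = MSMIN(4, UP_DIV(100, 8)) = MSMIN(4, 13) = 4
// AVX:  thread_stride_ = UP_DIV(UP_DIV(100, 24), 4) * 3 = UP_DIV(5, 4) * 3 = 6 C8 blocks per thread
// else: thread_stride_ = UP_DIV(UP_DIV(100, 8), 4) = UP_DIV(13, 4) = 4 C8 blocks per thread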
@@ -125,26 +131,33 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) {
if (oc <= 0 || oc_res <= 0) {
return RET_OK;
}
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_4_;
MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_4_, oc * C8NUM * kernel_plane_,
matmul_param_->col_, OutType_C8);
auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_align_;
#ifdef ENABLE_AVX
DeconvMatmulFloatAvx(
pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, tmp_buffer,
matmul_param_->deep_, matmul_param_->row_align_, oc * C8NUM * kernel_plane_, kernel_plane_);
#elif ENABLE_SSE
DeconvMatmulFloatSse(pack_input_,
weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
tmp_buffer, matmul_param_->deep_, matmul_param_->row_align_, oc * C8NUM * kernel_plane_);
#else
auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_12_;
MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_12_, oc * C8NUM * kernel_plane_,
matmul_param_->col_, OutType_C8);
tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_align_,
oc * C8NUM * kernel_plane_, matmul_param_->col_, OutType_C8);
#endif
DeConvPostFp32C8(tmp_buffer, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_,
reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM,
reinterpret_cast<float *>(bias_ptr) + thread_stride_ * task_id * C8NUM,
output_ptr_ + task_id * thread_stride_ * C8NUM, oc_res, conv_param_);
return RET_OK;
}
int DeConvolutionCPUKernel::Init() {
#if defined(ENABLE_ARM32) || defined(ENABLE_AVX) || defined(ENABLE_SSE)
row_tile_ = C4NUM;
#else
row_tile_ = C12NUM;
#endif
matmul_param_ = new (std::nothrow) MatMulParameter();
if (matmul_param_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
@@ -174,7 +187,6 @@ void DeConvolutionCPUKernel::FreeRunBuf() {
ctx_->allocator->Free(pack_input_);
pack_input_ = nullptr;
}
return;
}
int DeConvolutionCPUKernel::InitRunBuf() {
@@ -185,25 +197,15 @@ int DeConvolutionCPUKernel::InitRunBuf() {
return RET_NULL_PTR;
}
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
tmp_buffer_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->col_8_ * sizeof(float)));
#else
tmp_buffer_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->col_8_ * sizeof(float)));
#endif
tmp_buffer_ = reinterpret_cast<float *>(
ctx_->allocator->Malloc(matmul_param_->row_align_ * matmul_param_->col_8_ * sizeof(float)));
if (tmp_buffer_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc tmp_buffer_ error!";
return RET_NULL_PTR;
}
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
pack_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float)));
#else
pack_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float)));
#endif
pack_input_ = reinterpret_cast<float *>(
ctx_->allocator->Malloc(matmul_param_->row_align_ * matmul_param_->deep_ * sizeof(float)));
if (pack_input_ == nullptr) {
MS_LOG(ERROR) << "deconv Malloc pack_input_ error!";
return RET_ERROR;
@@ -254,6 +256,17 @@ kernel::InnerKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
kernel::InnerKernel *kernel = nullptr;
if (conv_param->group_ == 1) {
#ifdef ENABLE_AVX
if ((conv_param->stride_h_ > 1 || conv_param->stride_w_ > 1) &&
(conv_param->dilation_w_ == 1 && conv_param->dilation_h_ == 1) &&
(conv_param->kernel_w_ / conv_param->stride_w_ > 2 || conv_param->kernel_h_ / conv_param->stride_h_ > 2)) {
kernel = new (std::nothrow) kernel::DeConvolutionWinogradCPUKernel(op_parameter, inputs, outputs,
static_cast<const lite::InnerContext *>(ctx));
} else {
kernel = new (std::nothrow)
kernel::DeConvolutionCPUKernel(op_parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
}
#else
if ((conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1) &&
(conv_param->dilation_w_ == 1 && conv_param->dilation_h_ == 1)) {
kernel = new (std::nothrow) kernel::DeConvolutionWinogradCPUKernel(op_parameter, inputs, outputs,
@@ -262,6 +275,7 @@ kernel::InnerKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *
kernel = new (std::nothrow)
kernel::DeConvolutionCPUKernel(op_parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
}
#endif
} else if (conv_param->group_ == conv_param->input_channel_ && conv_param->group_ == conv_param->output_channel_) {
kernel = new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(op_parameter, inputs, outputs,
static_cast<const lite::InnerContext *>(ctx));
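The AVX branch of the creator above prefers the new GEMM-based deconv kernel unless the filter is large relative to the stride; with hypothetical shapes:

// kernel 4x4, stride 2: 4 / 2 == 2, not > 2 -> DeConvolutionCPUKernel (AVX deconv matmul path)
// kernel 8x8, stride 2: 8 / 2 == 4 > 2      -> DeConvolutionWinogradCPUKernel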

View File

@@ -54,12 +54,14 @@ class DeConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
int output_plane_ = 0;
int thread_count_ = 1;
int thread_stride_ = 0;
int row_tile_ = 0;
float *weight_ptr_ = nullptr;
float *pack_input_ = nullptr;
float *pack_output_ = nullptr;
float *tmp_buffer_ = nullptr;
float *input_ptr_ = nullptr;
float *output_ptr_ = nullptr;
float *bias_ptr = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_H_