!4545 [MS][LITE] fix bug of arm cpu fp16 conv op

Merge pull request !4545 from yangruoqi713/lite
This commit is contained in:
mindspore-ci-bot 2020-08-16 23:03:57 +08:00 committed by Gitee
commit da9dc49d6c
8 changed files with 63 additions and 35 deletions

View File

@ -92,9 +92,8 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() {
return RET_ERROR;
}
memset(weight_ptr_, 0, matmul_param_->deep_ * matmul_param_->col_8_ * sizeof(float16_t));
RowMajor2Col8MajorFp16(reinterpret_cast<float16_t *>(execute_weight_), weight_ptr_, matmul_param_->col_,
matmul_param_->deep_);
ColMajor2Row8MajorFp16(reinterpret_cast<float16_t *>(execute_weight_), weight_ptr_, matmul_param_->deep_,
matmul_param_->col_);
return RET_OK;
}
@ -159,7 +158,7 @@ void Convolution1x1FP16CPUKernel::Pre1x1Trans(float16_t *src_input, float16_t *s
input_ptr_ = src_input;
}
RowMajor2Col8MajorFp16(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
RowMajor2Col16MajorFp16(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
return;
}

View File

@ -35,15 +35,12 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
matmul_param_ = new MatMulParameter();
}
~Convolution1x1FP16CPUKernel() override {
if (fp16_weight_ != nullptr) {
free(fp16_weight_);
}
if (input_ptr_ != nullptr) {
free(input_ptr_);
}
if (weight_ptr_ != nullptr) {
free(weight_ptr_);
}
if (pack_input_ != nullptr) {
free(pack_input_);
}
delete matmul_param_;
}

View File

@ -255,7 +255,7 @@ int Convolution3x3FP16CPUKernel::Run() {
bool relu6 = conv_param_->is_relu6_;
for (int batch = 0; batch < conv_param_->output_batch_; batch++) {
int tmp_out_batch_offset =
batch * oc8 * C8NUM * out_w_block * out_h_block * conv_param_->output_unit_ * conv_param_->output_unit_;
batch * oc8 * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM;
int ro_batch_size = batch * conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_;
const float16_t *batch_tmp_out = tmp_out_ + tmp_out_batch_offset;
float16_t *batch_out = execute_output_ + ro_batch_size;
@ -265,7 +265,7 @@ int Convolution3x3FP16CPUKernel::Run() {
int oc8_block = c / C8NUM;
int oc8_res = c % C8NUM;
int src_offset = oc8_block * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM +
C8NUM * (h * out_w_block * conv_param_->output_unit_ + w) + oc8_res;
C8NUM * (h * out_w_block * C4NUM + w) + oc8_res;
int dst_offset = (h * conv_param_->output_w_ + w) * conv_param_->output_channel_ + c;
(batch_out + dst_offset)[0] = (batch_tmp_out + src_offset)[0];
if (relu) {

View File

@ -47,7 +47,7 @@ int ConvolutionBaseFP16CPUKernel::GetExecuteFilter() {
if (weight_data_type == kNumberTypeFloat32) {
float *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data());
size_t fp16_weight_size = conv_param_->input_channel_ * conv_param_->output_channel_ * conv_param_->kernel_h_ *
conv_param_->input_w_ * sizeof(float16_t);
conv_param_->kernel_w_ * sizeof(float16_t);
fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
if (fp16_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_weight_ failed.";

View File

@ -219,8 +219,8 @@ int ConvolutionFP16CPUKernel::Run() {
return RET_ERROR;
}
ConvolutionBaseFP16CPUKernel::FreeTmpBuffer();
ConvolutionBaseFP16CPUKernel::IfCastOutput();
ConvolutionBaseFP16CPUKernel::FreeTmpBuffer();
return RET_OK;
}

View File

@ -170,7 +170,7 @@ int DeConvolutionFp16CPUKernel::Run() {
ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
RowMajor2Col8MajorFp16(execute_input_, pack_input_, input_plane_, conv_param_->input_channel_);
RowMajor2Col16MajorFp16(execute_input_, pack_input_, input_plane_, conv_param_->input_channel_);
int error_code = LiteBackendParallelLaunch(DeConvFp16Run, this, thread_count_);
if (error_code != RET_OK) {

View File

@ -15,27 +15,57 @@
*/
#include "nnacl/fp16/matmul_fp16.h"
void ColMajor2Row8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int cd8 = c / 8;
int cm8 = c % 8;
dst_ptr[cd8 * 8 * row + r * 8 + cm8] = src_ptr[c * row + r];
}
}
}
void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
int deep, int row, int col, int stride, bool write_nhwc) {
int row_16 = UP_ROUND(row, C16NUM);
int col_8 = UP_ROUND(col, C8NUM);
/* col16-major * row8-major => row16x8-major */
if (write_nhwc) return;
for (int r = 0; r < row_16; r++) {
for (int c = 0; c < col_8; c++) {
int r16div = r / C16NUM, r16mod = r % C16NUM;
int c8div = c / C8NUM, c8mod = c % C8NUM;
size_t ci = c8div * row_16 * C8NUM + r * C8NUM + c8mod;
float16_t value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod;
size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
value = value + a[ai] * b[bi];
if (write_nhwc) {
/* col16-major * row8-major => col-major */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r16div = r / C16NUM, r16mod = r % C16NUM;
int c8div = c / C8NUM, c8mod = c % C8NUM;
size_t ci = r * stride + c;
float value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod;
size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
value = value + a[ai] * b[bi];
}
if (bias != NULL) value += bias[c];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
dst[ci] = value;
}
}
} else {
/* col16-major * row8-major => row16x8-major */
for (int r = 0; r < row_16; r++) {
for (int c = 0; c < col_8; c++) {
int r16div = r / C16NUM, r16mod = r % C16NUM;
int c8div = c / C8NUM, c8mod = c % C8NUM;
size_t ci = c8div * row_16 * C8NUM + r * C8NUM + c8mod;
float16_t value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod;
size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
value = value + a[ai] * b[bi];
}
if (bias != NULL) value += bias[col];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
dst[ci] = value;
}
if (bias != NULL) value += bias[col];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
dst[ci] = value;
}
}
return;
@ -43,12 +73,12 @@ void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const fl
void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
int depth, int row, int col, int stride, bool write_nhwc) {
// MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc);
MatMul16x8(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc);
MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc);
// MatMul16x8(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc);
return;
}
void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
void RowMajor2Col16MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
size_t row16 = row / C16NUM * C16NUM;
size_t col8 = col / C8NUM * C8NUM;
float16_t *src_r = src_ptr;

View File

@ -32,7 +32,9 @@ extern "C" {
void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
int depth, int row, int col, int stride, bool write_nhwc);
void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
void ColMajor2Row8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
void RowMajor2Col16MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc);