!4449 Add fp16 conv sliding window

Merge pull request !4449 from fuzhiye/tmp
mindspore-ci-bot 2020-08-14 18:29:30 +08:00 committed by Gitee
commit fe834fb15d
14 changed files with 644 additions and 86 deletions

View File

@@ -23,10 +23,6 @@ Matrix *TransformMatrixGenerator(int m, int k) {
auto aa = malloc(m * k * sizeof(float));
matrix->SetData(aa);
matrix->SetNum(m, k);
// matrix->data_ = malloc(m * k * sizeof(float));
// matrix->m_ = m;
// matrix->k_ = k;
// matrix->row_major_ = true;
return matrix;
}
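Note: the generator above hands the malloc result straight to SetData without a null check. A minimal hardened sketch, illustrative only (the Matrix construction sits above this hunk, and std::nothrow plus the needed includes are assumptions):

Matrix *TransformMatrixGenerator(int m, int k) {
  auto matrix = new (std::nothrow) Matrix();  // construction is above the shown hunk; nothrow is an assumption
  if (matrix == nullptr) {
    return nullptr;
  }
  auto data = malloc(m * k * sizeof(float));
  if (data == nullptr) {  // avoid handing SetData a null buffer
    delete matrix;
    return nullptr;
  }
  matrix->SetData(data);
  matrix->SetNum(m, k);
  return matrix;
}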

View File

@@ -65,26 +65,6 @@ class Matrix {
int n_dim_;
bool row_major_;
};
// struct Matrix {
// void *data_;
// int *shape_;
// int *stride_;
// int m_;
// int k_;
// int n_dim_;
// bool row_major_;
// ~Matrix() {
// if (data_ != nullptr) {
// free(data_);
// }
// if (shape_ != nullptr) {
// free(shape_);
// }
// if (shape_ != nullptr) {
// free(stride_);
// }
// }
//};
Matrix *TransformMatrixGenerator(int m, int k);
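The commented-out struct removed above guarded free(stride_) behind a second shape_ != nullptr check. If the Matrix class keeps owning these buffers, a destructor along these lines avoids that slip (a sketch, assuming data_, shape_ and stride_ are malloc-allocated and owned by the object):

~Matrix() {
  if (data_ != nullptr) {
    free(data_);
    data_ = nullptr;
  }
  if (shape_ != nullptr) {
    free(shape_);
    shape_ = nullptr;
  }
  if (stride_ != nullptr) {  // check stride_, not shape_, before freeing it
    free(stride_);
    stride_ = nullptr;
  }
}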

View File

@@ -16,6 +16,7 @@
#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h"
@@ -265,11 +266,9 @@ int Convolution3x3FP16CPUKernel::Run() {
return RET_ERROR;
}
auto input_tensor = in_tensors_.at(kInputIndex);
auto input_ele_num = input_tensor->ElementsNum();
auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
auto input_element_num = input_tensor->ElementsNum();
for (int i = 0; i < input_element_num; ++i) {
fp16_input_[i] = (float16_t)ori_input_data[i];
}
Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
int in_batch = conv_param_->input_batch_;
int in_h = conv_param_->input_h_;
@@ -285,12 +284,9 @@ int Convolution3x3FP16CPUKernel::Run() {
// cast fp16 out to fp32 data
auto out_tensor = out_tensors_.at(kOutputIndex);
auto out_ele_num = out_tensor->ElementsNum();
auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
auto output_element_num = out_tensor->ElementsNum();
for (int j = 0; j < output_element_num; ++j) {
output_addr[j] = static_cast<float>(fp16_out_[j]);
}
Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
return RET_OK;
}
} // namespace mindspore::kernel
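The removed cast loops are replaced by the helpers declared in nnacl/fp16/cast_fp16.h. Their bodies are not part of this diff; a plain scalar sketch consistent with the call sites above (source, destination, element count; the shipped versions may be NEON-vectorized) is:

// Sketch only: scalar equivalents matching the call sites (src, dst, count).
void Float32ToFloat16(const float *input, float16_t *output, int number) {
  for (int i = 0; i < number; ++i) {
    output[i] = (float16_t)input[i];
  }
}
void Float16ToFloat32(const float16_t *input, float *output, int number) {
  for (int i = 0; i < number; ++i) {
    output[i] = (float)input[i];
  }
}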

View File

@@ -15,8 +15,11 @@
*/
#include "src/runtime/kernel/arm/fp16/convolution_fp16.h"
#include <vector>
#include "src/runtime/kernel/arm/fp16/convolution_sw_fp16.h"
#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h"
#include "schema/model_generated.h"
@@ -231,10 +234,8 @@ int ConvolutionFP16CPUKernel::Run() {
}
auto input_tensor = in_tensors_.at(kInputIndex);
auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
auto input_element_num = input_tensor->ElementsNum();
for (int i = 0; i < input_element_num; ++i) {
fp16_input_[i] = (float16_t)ori_input_data[i];
}
auto input_ele_num = input_tensor->ElementsNum();
Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
int in_batch = conv_param_->input_batch_;
int in_h = conv_param_->input_h_;
@@ -251,10 +252,8 @@ int ConvolutionFP16CPUKernel::Run() {
// cast fp16 out to fp32 data
auto out_tensor = out_tensors_.at(kOutputIndex);
auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
auto output_element_num = out_tensor->ElementsNum();
for (int j = 0; j < output_element_num; ++j) {
output_addr[j] = static_cast<float>(fp16_out_[j]);
}
auto out_ele_num = out_tensor->ElementsNum();
Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
return RET_OK;
}

View File

@@ -0,0 +1,269 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp16/convolution_sw_fp16.h"
#include <vector>
#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h"
#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;
namespace mindspore::kernel {
int ConvolutionSWFP16CPUKernel::ProcessFilter() {
int kernel_h = conv_param_->kernel_h_;
int kernel_w = conv_param_->kernel_w_;
int in_channel = conv_param_->input_channel_;
int out_channel = conv_param_->output_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
auto *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data());
size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t);
fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
if (fp16_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
return RET_ERROR;
}
// cast origin fp32 weight data to fp16 data
for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
fp16_weight_[i] = (float16_t)origin_weight[i];
}
for (int oc = 0; oc < out_channel; ++oc) {
int src_oc_offset = oc * kernel_h * kernel_w * in_channel;
int dst_oc_offset = oc * kernel_h * kernel_w * ic4 * C4NUM;
for (int i = 0; i < kernel_h * kernel_w; ++i) {
const float16_t *src = fp16_weight_ + src_oc_offset + i * in_channel;
float16_t *dst = packed_weight_ + dst_oc_offset + i * ic4 * C4NUM;
memcpy(dst, src, in_channel * sizeof(float16_t));
}
}
return RET_OK;
}
int ConvolutionSWFP16CPUKernel::InitWeightBias() {
int kernel_h = conv_param_->kernel_h_;
int kernel_w = conv_param_->kernel_w_;
int in_channel = conv_param_->input_channel_;
int out_channel = conv_param_->output_channel_;
int oc4 = UP_DIV(out_channel, C4NUM);
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane;
// init weight
packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t)));
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc packed_weight_ failed.";
return RET_ERROR;
}
memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t));
auto ret = ProcessFilter();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Process filter failed.";
return ret;
}
// init bias
bias_data_ = malloc(oc4 * C4NUM * sizeof(float16_t));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_data_ failed.";
return RET_ERROR;
}
memset(bias_data_, 0, oc4 * C4NUM * sizeof(float16_t));
auto fp16_bias_data = reinterpret_cast<float16_t *>(bias_data_);
if (in_tensors_.size() == kInputSize2) {
auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->Data());
for (int i = 0; i < out_channel; ++i) {
fp16_bias_data[i] = (float16_t)ori_bias[i];
}
} else {
MS_ASSERT(in_tensors_.size() == kInputSize1);
}
return RET_OK;
}
int ConvolutionSWFP16CPUKernel::InitTmpBuffer() {
int in_channel = conv_param_->input_channel_;
int out_channel = conv_param_->output_channel_;
int channel_block = UP_DIV(in_channel, C4NUM);
int oc4 = UP_DIV(out_channel, C4NUM);
/*=============================fp16_input_============================*/
size_t fp16_input_size =
in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size));
if (fp16_input_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_input_ failed.";
return RET_ERROR;
}
/*=============================nhwc4_input_============================*/
size_t nhwc4_input_size = channel_block * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ *
conv_param_->input_w_ * sizeof(float16_t);
nhwc4_input_ = malloc(nhwc4_input_size);
if (nhwc4_input_ == nullptr) {
MS_LOG(ERROR) << "malloc nhwc4_input_ failed.";
return RET_ERROR;
}
memset(nhwc4_input_, 0, nhwc4_input_size);
/*=============================tmp_output_block_============================*/
tmp_output_block_ = reinterpret_cast<float16_t *>(malloc(conv_param_->output_batch_ * conv_param_->output_h_ *
conv_param_->output_w_ * oc4 * C4NUM * sizeof(float16_t)));
if (tmp_output_block_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_output_block_ failed.";
return RET_ERROR;
}
/*=============================fp16_out_============================*/
size_t fp16_output_size =
out_channel * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * sizeof(float16_t);
fp16_out_ = reinterpret_cast<float16_t *>(malloc(fp16_output_size));
if (fp16_out_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_out_ failed.";
return RET_ERROR;
}
return RET_OK;
}
void ConvolutionSWFP16CPUKernel::ConfigInputOutput() {
auto input_tensor = in_tensors_.at(kInputIndex);
auto input_format = input_tensor->GetFormat();
schema::Format execute_format = schema::Format_NHWC4;
convert_func_ = LayoutTransformFp16(input_format, execute_format);
if (convert_func_ == nullptr) {
MS_LOG(ERROR) << "layout convert func is nullptr.";
return;
}
}
int ConvolutionSWFP16CPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
set_need_reinit();
return RET_OK;
}
auto ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init fail!ret: " << ret;
return ret;
}
ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init weight bias failed.";
return RET_ERROR;
}
ret = InitTmpBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init tmp buffer failed.";
return RET_ERROR;
}
ConfigInputOutput();
// init sliding window param
slidingWindow_param_ = new SlidingWindowParam;
InitSlidingParamConv(slidingWindow_param_, conv_param_, C4NUM);
return RET_OK;
}
int ConvolutionSWFP16CPUKernel::ReSize() {
if (tmp_output_block_ != nullptr) {
free(tmp_output_block_);
}
if (nhwc4_input_ != nullptr) {
free(nhwc4_input_);
}
if (fp16_input_ != nullptr) {
free(fp16_input_);
}
if (fp16_out_ != nullptr) {
free(fp16_out_);
}
delete slidingWindow_param_;
auto ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init failed.";
return ret;
}
ret = InitTmpBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init tmp buffer failed.";
return RET_ERROR;
}
// init sliding window param
slidingWindow_param_ = new SlidingWindowParam;
InitSlidingParamConv(slidingWindow_param_, conv_param_, C4NUM);
return RET_OK;
}
int ConvolutionSWFP16CPUKernel::RunImpl(int task_id) {
ConvSWFp16(reinterpret_cast<float16_t *>(nhwc4_input_), packed_weight_, reinterpret_cast<float16_t *>(bias_data_),
tmp_output_block_, fp16_out_, task_id, conv_param_, slidingWindow_param_);
return RET_OK;
}
int ConvolutionSWFp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto conv = reinterpret_cast<ConvolutionSWFP16CPUKernel *>(cdata);
auto error_code = conv->RunImpl(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "ConvolutionFp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionSWFP16CPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
auto input_tensor = in_tensors_.at(kInputIndex);
auto input_ele_num = input_tensor->ElementsNum();
auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
int in_batch = conv_param_->input_batch_;
int in_h = conv_param_->input_h_;
int in_w = conv_param_->input_w_;
int in_channel = conv_param_->input_channel_;
convert_func_(reinterpret_cast<void *>(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
int error_code = LiteBackendParallelLaunch(ConvolutionSWFp16Impl, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv fp16 error error_code[" << error_code << "]";
return RET_ERROR;
}
// cast fp16 out to fp32 data
auto out_tensor = out_tensors_.at(kOutputIndex);
auto out_ele_num = out_tensor->ElementsNum();
auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
return RET_OK;
}
} // namespace mindspore::kernel
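For orientation, the buffer sizes in InitTmpBuffer all come from rounding channel counts up to blocks of C4NUM. A small self-contained check with made-up shapes (the UP_DIV definition is the usual nnacl one and is assumed here, as is a stride-1, same-padding convolution so output H/W match the input):

#include <cstdio>
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))  // assumed nnacl definition
constexpr int C4NUM = 4;
int main() {
  int batch = 1, in_h = 56, in_w = 56, in_c = 3, out_c = 17;  // made-up shape
  int ic4 = UP_DIV(in_c, C4NUM);   // 1 -> input channels padded from 3 to 4
  int oc4 = UP_DIV(out_c, C4NUM);  // 5 -> temporary output padded from 17 to 20
  size_t nhwc4_elems = (size_t)batch * in_h * in_w * ic4 * C4NUM;
  size_t tmp_out_elems = (size_t)batch * in_h * in_w * oc4 * C4NUM;  // stride 1, same padding assumed
  printf("nhwc4_input_: %zu fp16 values, tmp_output_block_: %zu fp16 values\n",
         nhwc4_elems, tmp_out_elems);
  return 0;
}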

View File

@@ -0,0 +1,69 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_SW_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_SW_FP16_H_
#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
namespace mindspore::kernel {
class ConvolutionSWFP16CPUKernel : public ConvolutionBaseCPUKernel {
public:
ConvolutionSWFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
const lite::Primitive *primitive)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~ConvolutionSWFP16CPUKernel() override {
if (fp16_input_ != nullptr) {
free(fp16_input_);
}
if (fp16_weight_ != nullptr) {
free(fp16_weight_);
}
if (fp16_out_ != nullptr) {
free(fp16_out_);
}
if (packed_weight_ != nullptr) {
free(packed_weight_);
}
if (tmp_output_block_ != nullptr) {
free(tmp_output_block_);
}
delete slidingWindow_param_;
}
int Init() override;
int ReSize() override;
int Run() override;
int RunImpl(int task_id);
int InitWeightBias();
int InitTmpBuffer();
void ConfigInputOutput();
int ProcessFilter();
private:
float16_t *fp16_input_;
float16_t *fp16_weight_;
float16_t *fp16_out_;
float16_t *packed_weight_;
float16_t *tmp_output_block_;
SlidingWindowParam *slidingWindow_param_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_SW_FP16_H_
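The destructor above frees several raw pointers that the class declares without initializers; if Init() returns early (for example on the infer_shape_interrupt_ path) they would be freed uninitialized. One hedged way to harden this, not part of the commit, is default member initializers:

// Sketch only: default-initialize the buffers the destructor frees.
float16_t *fp16_input_ = nullptr;
float16_t *fp16_weight_ = nullptr;
float16_t *fp16_out_ = nullptr;
float16_t *packed_weight_ = nullptr;
float16_t *tmp_output_block_ = nullptr;
SlidingWindowParam *slidingWindow_param_ = nullptr;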

View File

@@ -259,6 +259,21 @@ int Convolution3x3CPUKernel::Run() {
MS_LOG(ERROR) << "conv3x3 error error_code[" << error_code << "]";
return RET_ERROR;
}
auto is_relu = conv_param_->is_relu_;
auto is_relu6 = conv_param_->is_relu6_;
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
PackNC4HW4ToNHWCFp32(nc4hw4_out_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
int output_num =
conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->output_batch_;
if (is_relu) {
ReluFp32(output_addr, output_addr, output_num);
} else if (is_relu6) {
Relu6Fp32(output_addr, output_addr, output_num);
} else {
// do nothing
}
return RET_OK;
}
} // namespace mindspore::kernel
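ReluFp32 and Relu6Fp32 are existing nnacl helpers whose bodies are not shown in this diff. In-place scalar sketches of how they are used above (both call-site arguments point at the same buffer; the real helpers may be vectorized and their parameter order may differ):

// Sketch only: what the activation pass above amounts to, done in place.
static void ReluInplaceSketch(float *data, int ele_num) {
  for (int i = 0; i < ele_num; ++i) {
    data[i] = data[i] < 0.0f ? 0.0f : data[i];
  }
}
static void Relu6InplaceSketch(float *data, int ele_num) {
  for (int i = 0; i < ele_num; ++i) {
    float v = data[i] < 0.0f ? 0.0f : data[i];
    data[i] = v > 6.0f ? 6.0f : v;
  }
}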

View File

@@ -189,8 +189,8 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
/*=============================tmp_out_data_============================*/
int out_w_block = UP_DIV(output_w, output_unit_);
int out_h_block = UP_DIV(output_h, output_unit_);
tmp_out_data_ = reinterpret_cast<float *>(
malloc(out_w_block * out_h_block * output_unit_ * output_unit_ * oc4 * C4NUM * sizeof(float)));
tmp_out_data_ = reinterpret_cast<float *>(malloc(conv_param_->output_batch_ * out_w_block * out_h_block *
output_unit_ * output_unit_ * oc4 * C4NUM * sizeof(float)));
if (tmp_out_data_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_out_data_ failed.";
return RET_ERROR;
@@ -365,6 +365,22 @@ int ConvolutionWinogradCPUKernel::Run() {
MS_LOG(ERROR) << "conv winograd error error_code[" << error_code << "]";
return RET_ERROR;
}
// get real output
auto out_tensor = out_tensors_.front();
auto out_data = reinterpret_cast<float *>(out_tensor->Data());
UnPackWinogradOutput(tmp_out_data_, out_data, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
int output_num =
conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->output_batch_;
if (conv_param_->is_relu_) {
ReluFp32(out_data, out_data, output_num);
} else if (conv_param_->is_relu6_) {
Relu6Fp32(out_data, out_data, output_num);
} else {
// do nothing
}
return RET_OK;
}
} // namespace mindspore::kernel
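UnPackWinogradOutput crops the batched, tile-aligned NC4HW4 buffer sized above down to the real NHWC output. A sketch of that unpacking under the assumed layout (plane padded to out_h_block*unit by out_w_block*unit, channels padded to oc4*C4NUM; UP_DIV and C4NUM as in nnacl; the shipped routine may differ):

// Sketch only, assumed layout; the shipped UnPackWinogradOutput may differ.
void UnPackWinogradOutputSketch(const float *src, float *dst, int batch, int height,
                                int width, int channel, int output_unit) {
  int tile_w = UP_DIV(width, output_unit) * output_unit;
  int tile_h = UP_DIV(height, output_unit) * output_unit;
  int c4 = UP_DIV(channel, C4NUM);
  for (int b = 0; b < batch; ++b) {
    const float *src_b = src + b * c4 * C4NUM * tile_h * tile_w;
    float *dst_b = dst + b * height * width * channel;
    for (int h = 0; h < height; ++h) {
      for (int w = 0; w < width; ++w) {
        for (int c = 0; c < channel; ++c) {
          int src_idx = (c / C4NUM) * C4NUM * tile_h * tile_w + (h * tile_w + w) * C4NUM + c % C4NUM;
          dst_b[(h * width + w) * channel + c] = src_b[src_idx];
        }
      }
    }
  }
}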

View File

@@ -18,7 +18,6 @@
#include "nnacl/fp16/pack_fp16.h"
#include "nnacl/fp16/winograd_transform_fp16.h"
#ifdef __cplusplus
extern "C" {
#endif
@@ -112,6 +111,209 @@ void IndirectGemmFp16_16x8_tmp(float16_t *output, float16_t *input, float16_t *w
}
#endif
void SWBorderPixel(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int height,
int width, int in_kh_step, int in_kw_step, int kernel_h, int kernel_w, int ic, bool is_relu,
bool is_relu6) {
int ic8 = ic / C8NUM;
int ic8_res = ic % C8NUM;
int ic4 = ic8_res / C4NUM;
for (int c = 0; c < C4NUM; c++) {
dst[c] = 0;
}
const float16_t *weight_oc = weight;
for (int oc = 0; oc < C4NUM; ++oc) {
const float16_t *weight_kh = weight_oc;
const float16_t *src_kh = src;
for (int kh = 0; kh < height; kh++) {
const float16_t *src_kw = src_kh;
const float16_t *weight_kw = weight_kh;
for (int kw = 0; kw < width; kw++) {
const float16_t *src_ic8 = src_kw;
const float16_t *weight_ic8 = weight_kw;
for (int rc = 0; rc < ic8; ++rc) {
for (int c = 0; c < C8NUM; c++) {
dst[oc] += src_ic8[c] * weight_ic8[c];
}
src_ic8 += C8NUM;
weight_ic8 += C8NUM;
} // ic8 loop
const float16_t *src_ic4 = src_ic8;
const float16_t *weight_ic4 = weight_ic8;
for (int rc = 0; rc < ic4; ++rc) {
for (int c = 0; c < C4NUM; c++) {
dst[oc] += src_ic4[c] * weight_ic4[c];
}
src_ic4 += C4NUM;
weight_ic4 += C4NUM;
} // ic4 loop
src_kw += in_kw_step;
weight_kw += ic4 * C4NUM;
} // kernel_w loop
src_kh += in_kh_step;
weight_kh += kernel_w * ic4 * C4NUM;
} // kernel_h loop
dst[oc] += bias[oc];
dst[oc] = (is_relu) ? (MSMAX(0, dst[oc])) : (dst[oc]);
dst[oc] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[oc]))) : (dst[oc]);
weight_oc += kernel_h * kernel_w * ic4 * C4NUM;
} // oc loop
}
void SWBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top,
int bottom, int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
float16_t *dst_h = dst + top * sliding->out_h_step_;
for (int oh = top; oh < bottom; oh++) {
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
const float16_t *src_h = src + ih * sliding->in_h_step_;
float16_t *dst_kernel = dst_h + left * sliding->block_channel_;
for (int ow = left; ow < right; ow++) {
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
const float16_t *src_w = src_h + iw * sliding->ic4_channel_;
const float16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
const float16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * sliding->ic4_channel_;
SWBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_,
sliding->ic4_channel_, conv_param->is_relu_, conv_param->is_relu6_);
dst_kernel += sliding->block_channel_;
} // width loop
dst_h += sliding->out_h_step_;
} // height loop
}
void SWCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int height,
int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int ic, int in_sh_step,
int in_sw_step, int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) {
int ic8 = ic / C8NUM;
int ic8_res = ic % C8NUM;
int ic4 = ic8_res / C4NUM;
float16_t *dst_h = dst;
const float16_t *src_h = src;
for (int oh = 0; oh < height; oh++) {
float16_t *dst_w = dst_h;
const float16_t *src_w = src_h;
for (int ow = 0; ow < width; ow++) {
const float16_t *weight_oc = weight;
for (int c = 0; c < C4NUM; c++) {
dst_w[c] = 0;
}
for (int oc = 0; oc < C4NUM; oc++) {
const float16_t *weight_kh = weight_oc;
const float16_t *src_kh = src_w;
for (int kh = 0; kh < kernel_h; kh++) {
const float16_t *src_kw = src_kh;
const float16_t *weight_kw = weight_kh;
for (int kw = 0; kw < kernel_w; kw++) {
const float16_t *src_ic8 = src_kw;
const float16_t *weight_ic8 = weight_kw;
for (int rc = 0; rc < ic8; ++rc) {
for (int c = 0; c < C8NUM; c++) {
dst_w[oc] += src_ic8[c] * weight_ic8[c];
}
src_ic8 += C8NUM;
weight_ic8 += C8NUM;
} // ic8 loop
const float16_t *src_ic4 = src_ic8;
const float16_t *weight_ic4 = weight_ic8;
for (int rc = 0; rc < ic4; ++rc) {
for (int c = 0; c < C4NUM; c++) {
dst_w[oc] += src_ic4[c] * weight_ic4[c];
}
src_ic4 += C4NUM;
weight_ic4 += C4NUM;
} // ic4 loop
src_kw += in_kw_step;
weight_kw += ic4 * C4NUM;
} // kernel_w loop
src_kh += in_kh_step;
weight_kh += kernel_w * ic4 * C4NUM;
} // kernel_h loop
// add bias and relu
dst_w[oc] += bias[oc];
dst_w[oc] = (is_relu) ? (MSMAX(0, dst_w[oc])) : (dst_w[oc]);
dst_w[oc] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[oc]))) : (dst_w[oc]);
weight_oc += kernel_h * kernel_w * ic4 * C4NUM;
} // oc block
dst_w += block_channel;
src_w += in_sw_step;
} // dst_width loop
dst_h += out_h_step;
src_h += in_sh_step;
} // dst_height loop
}
// fp16 conv sliding window
void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, const float16_t *bias_data,
float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param,
SlidingWindowParam *slidingWindow_param) {
int oc4_res = conv_param->output_channel_ % C4NUM;
const float16_t *src = input_data;
float16_t *dst;
if (oc4_res == 0) {
dst = output_data;
} else {
dst = tmp_out_block;
}
for (int b = 0; b < conv_param->output_batch_; b++) {
for (int oc = task_id; oc < slidingWindow_param->c_block_; oc += conv_param->thread_num_) {
const float16_t *src_data = src;
float16_t *dst_data = dst + oc * C4NUM;
const float16_t *weight = packed_weight + oc * slidingWindow_param->kernel_step_;
const float16_t *bias = bias_data + oc * C4NUM;
SWBorderFp16(dst_data, src_data, weight, bias, 0, slidingWindow_param->top_, 0, conv_param->output_w_, conv_param,
slidingWindow_param);
SWBorderFp16(dst_data, src_data, weight, bias, slidingWindow_param->bottom_, conv_param->output_h_, 0,
conv_param->output_w_, conv_param, slidingWindow_param);
SWBorderFp16(dst_data, src_data, weight, bias, slidingWindow_param->top_, slidingWindow_param->bottom_, 0,
slidingWindow_param->left_, conv_param, slidingWindow_param);
SWBorderFp16(dst_data, src_data, weight, bias, slidingWindow_param->top_, slidingWindow_param->bottom_,
slidingWindow_param->right_, conv_param->output_w_, conv_param, slidingWindow_param);
if (slidingWindow_param->right_ > slidingWindow_param->left_ &&
slidingWindow_param->bottom_ > slidingWindow_param->top_) {
int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_w_;
const float16_t *in_t =
src_data + in_h_start * slidingWindow_param->in_h_step_ + in_w_start * slidingWindow_param->ic4_channel_;
float16_t *out_t = dst_data + slidingWindow_param->top_ * slidingWindow_param->out_h_step_ +
slidingWindow_param->left_ * slidingWindow_param->block_channel_;
SWCenterFp16(out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_,
conv_param->kernel_w_, slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_,
slidingWindow_param->ic4_channel_, slidingWindow_param->in_sh_step_,
slidingWindow_param->in_sw_step_, slidingWindow_param->in_kh_step_,
slidingWindow_param->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_);
}
} // output C4 loop
src += slidingWindow_param->in_step_;
dst += slidingWindow_param->out_step_;
} // batch loop
// output nhwc4
if (oc4_res != 0) {
PackNHWC4ToNHWCFp16((const void *)tmp_out_block, (void *)output_data, conv_param->output_batch_,
conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_);
}
}
// fp16 convolution common (im2col+gemm)
void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data,
float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param) {
@@ -144,7 +346,7 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
// we write 32 bytes per st1 instruction, after which the pointer in register will step 32B forward
for (int b = 0; b < in_batch; b++) {
int in_batch_offset = b * in_channel * in_h * in_w;
int in_batch_offset = b * ic4 * C4NUM * in_h * in_w;
int out_batch_offset = b * out_channel * out_h * out_w;
int gemm_in_batch_offset = b * packed_input_size;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
@@ -172,7 +374,6 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data,
float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out,
int task_id, ConvParameter *conv_param) {
// todo
int thread_count = conv_param->thread_num_;
int tile_num = 16;
const int output_unit = 4;
@@ -195,6 +396,8 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
int input_batch = conv_param->input_batch_;
for (int batch = 0; batch < input_batch; batch++) {
int in_batch_offset = batch * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
int tmp_out_batch_offset = batch * oc8 * C8NUM * out_w_block * out_h_block * output_unit * output_unit;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
int start_index = thread_id * tile_num;
int real_cal_num = (output_count - start_index) < tile_num ? (output_count - start_index) : tile_num;
@@ -207,8 +410,8 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
tile_buffer + task_id * tile_buffer_offset, transed_weight, NULL, 36, ic4, oc8 * C8NUM,
oc8 * C8NUM * 36 * sizeof(float16_t), 1, 1, 0, 0);
Conv3x3Fp16OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out, bias_data, start_index,
real_cal_num, out_w_block, conv_param);
Conv3x3Fp16OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset,
bias_data, start_index, real_cal_num, out_w_block, conv_param);
}
}
@@ -217,7 +420,10 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
bool relu = conv_param->is_relu_;
bool relu6 = conv_param->is_relu6_;
for (int batch = 0; batch < output_batch; batch++) {
int batch_size = batch * output_channel * output_h * output_w;
int tmp_out_batch_offset = batch * oc8 * C8NUM * out_w_block * out_h_block * output_unit * output_unit;
int ro_batch_size = batch * output_channel * output_h * output_w;
const float16_t *batch_tmp_out = tmp_out + tmp_out_batch_offset;
float16_t *batch_out = output_data + ro_batch_size;
for (int h = 0; h < output_h; h++) {
for (int w = 0; w < output_w; w++) {
for (int c = 0; c < output_channel; c++) {
@@ -226,12 +432,12 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
int src_offset = oc8_block * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM +
C8NUM * (h * out_w_block * output_unit + w) + oc8_res;
int dst_offset = (h * output_w + w) * output_channel + c;
(output_data + dst_offset)[0] = (tmp_out + src_offset)[0];
(batch_out + dst_offset)[0] = (batch_tmp_out + src_offset)[0];
if (relu) {
(output_data + dst_offset)[0] = (output_data + dst_offset)[0] < 0 ? 0 : (output_data + dst_offset)[0];
(batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] < 0 ? 0 : (batch_out + dst_offset)[0];
} else if (relu6) {
(output_data + dst_offset)[0] = (output_data + dst_offset)[0] < 0 ? 0 : (output_data + dst_offset)[0];
(output_data + dst_offset)[0] = (output_data + dst_offset)[0] > 6 ? 6 : (output_data + dst_offset)[0];
(batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] < 0 ? 0 : (batch_out + dst_offset)[0];
(batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] > 6 ? 6 : (batch_out + dst_offset)[0];
}
}
}
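For reference, ConvSWFp16 above splits each output channel block into four padded border strips handled by SWBorderFp16 and an interior handled by SWCenterFp16. A tiny self-contained illustration of that split with made-up bounds (the real top_/bottom_/left_/right_ come from InitSlidingParamConv):

#include <cstdio>
int main() {
  // Made-up bounds for an 8x8 output; real values come from InitSlidingParamConv.
  int out_h = 8, out_w = 8, top = 1, bottom = 7, left = 1, right = 7;
  printf("top border:    rows [0,%d)  x cols [0,%d)  -> SWBorderFp16\n", top, out_w);
  printf("bottom border: rows [%d,%d) x cols [0,%d)  -> SWBorderFp16\n", bottom, out_h, out_w);
  printf("left border:   rows [%d,%d) x cols [0,%d)  -> SWBorderFp16\n", top, bottom, left);
  printf("right border:  rows [%d,%d) x cols [%d,%d) -> SWBorderFp16\n", top, bottom, right, out_w);
  printf("center:        rows [%d,%d) x cols [%d,%d) -> SWCenterFp16\n", top, bottom, left, right);
  return 0;
}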

View File

@@ -28,6 +28,18 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh
#ifdef __cplusplus
extern "C" {
#endif
void SWBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top,
int bottom, int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding);
void SWCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int height,
int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int ic, int in_sh_step,
int in_sw_step, int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6);
// fp16 sliding window
void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, const float16_t *bias_data,
float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param,
SlidingWindowParam *slidingWindow_param);
// fp16 convolution common (im2col+gemm)
void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data,
float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param);

View File

@@ -219,6 +219,24 @@ void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int c
}
}
void PackNHWC4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
int ic_remainder_ = channel % C4NUM;
if (ic_remainder_ != 0) {
int nhwc_batch_unit_offset = channel * plane;
for (int b = 0; b < batch; b++) {
int batch_offset = b * c4 * C4NUM * plane;
for (int i = 0; i < plane; i++) {
memcpy((float16_t *)dst + b * nhwc_batch_unit_offset + i * channel,
(float16_t *)src + batch_offset + i * c4 * C4NUM, channel * sizeof(float16_t));
}
}
} else {
size_t ori_input_size = batch * plane * channel * sizeof(float16_t);
memcpy((float16_t *)dst, (float16_t *)src, ori_input_size);
}
}
void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) {
int nhwc4_batch_offset = 0;
int ic4 = UP_DIV(channel, C4NUM);
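The new PackNHWC4ToNHWCFp16 above strips the per-pixel channel padding that the NHWC4 layout adds. A short usage sketch with made-up values (assumes nnacl/fp16/pack_fp16.h is included; c4 = UP_DIV(6, 4) = 2, so each pixel carries 8 slots and the last two are padding):

void UnpadExampleSketch() {
  float16_t nhwc4[2 * 8] = {
    1, 2, 3, 4, 5, 6, 0, 0,     // pixel 0: channels 0..5, two padding slots
    7, 8, 9, 10, 11, 12, 0, 0,  // pixel 1
  };
  float16_t nhwc[2 * 6];
  PackNHWC4ToNHWCFp16(nhwc4, nhwc, /*batch=*/1, /*plane=*/2, /*channel=*/6);
  // nhwc now holds 1..12 contiguously, without the padding slots.
}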

View File

@@ -41,6 +41,8 @@ void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int
void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWC4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);

View File

@@ -217,7 +217,7 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
size_t output_offset = out_channel * sizeof(float);
for (int b = 0; b < in_batch; b++) {
int in_batch_offset = b * in_channel * in_h * in_w;
int in_batch_offset = b * ic4 * C4NUM * in_h * in_w;
int out_batch_offset = b * out_channel * out_h * out_w;
int gemm_in_batch_offset = b * packed_input_size;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
@@ -263,12 +263,9 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, TILE_NUM);
int out_channel = conv_param->output_channel_;
int out_batch = conv_param->output_batch_;
int oc4 = UP_DIV(out_channel, C4NUM);
int input_unit_square = input_unit * input_unit;
size_t output_offset = oc4 * C4NUM * input_unit_square * sizeof(float);
bool is_relu = conv_param->is_relu_;
bool is_relu6 = conv_param->is_relu6_;
float *trans_input = buffer_list[0];
float *gemm_out = buffer_list[1];
@@ -280,11 +277,13 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
// step 1 : filter transform (pre-processed offline)
// step 2 : input transform (online)
for (int b = 0; b < in_batch; b++) {
int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc4 * C4NUM;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
int out_tile_index = thread_id * TILE_NUM;
int cal_num = output_count - thread_id * TILE_NUM;
cal_num = cal_num > TILE_NUM ? TILE_NUM : cal_num;
WinogradInputTransform(input_data, trans_input + task_id * trans_input_offset,
WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
input_trans_func);
// step 3 : gemm
@@ -292,21 +291,10 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
input_unit_square, ic4, oc4 * C4NUM, output_offset, 1, 1, 0, 0);
// step 4 : output transform
WinogradOutputTransform(gemm_out + task_id * gemm_out_offset, tmp_out_data, bias_data, cal_num, out_tile_index,
out_w_block, conv_param, output_trans_func);
WinogradOutputTransform(gemm_out + task_id * gemm_out_offset, tmp_out_data + tmp_out_batch_offset, bias_data,
cal_num, out_tile_index, out_w_block, conv_param, output_trans_func);
}
}
// get real output
UnPackWinogradOutput(tmp_out_data, output_data, out_batch, conv_param->output_h_, conv_param->output_w_, out_channel,
out_unit);
int output_num = out_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_;
if (is_relu) {
ReluFp32(output_data, output_data, output_num);
} else if (is_relu6) {
Relu6Fp32(output_data, output_data, output_num);
} else {
// do nothing
}
}
void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel,
@@ -360,8 +348,6 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, TILE_NUM);
int input_unit_square = 4 * 4;
bool is_relu = conv_param->is_relu_;
bool is_relu6 = conv_param->is_relu6_;
float *tile_buffer = buffer_list[0];
float *block_unit_buffer = buffer_list[1];
float *tmp_dst_buffer = buffer_list[2];
@@ -372,10 +358,13 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
int input_batch = conv_param->input_batch_;
for (int batch = 0; batch < input_batch; batch++) {
int in_batch_offset = batch * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
int nc4hw4_buffer_offset = batch * oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
int start_index = thread_id * TILE_NUM;
int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
Conv3x3Fp32InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset,
Conv3x3Fp32InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
out_w_block, conv_param);
@@ -383,17 +372,8 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
transed_weight, NULL, input_unit_square, ic4, oc4 * C4NUM,
oc4 * C4NUM * input_unit_square * sizeof(float), 1, 1, 0, 0);
Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out, bias_data, start_index,
real_cal_num, out_w_block, conv_param);
Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out + nc4hw4_buffer_offset,
bias_data, start_index, real_cal_num, out_w_block, conv_param);
}
PackNC4HW4ToNHWCFp32(nc4hw4_out, output_data, 1, conv_param->output_h_ * conv_param->output_w_, output_channel);
}
int output_num = output_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_;
if (is_relu) {
ReluFp32(output_data, output_data, output_num);
} else if (is_relu6) {
Relu6Fp32(output_data, output_data, output_num);
} else {
// do nothing
}
}
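The in_batch_offset change above (mirrored in the fp16 path) reflects that the input is already packed to NHWC4, so the per-batch stride is ic4 * C4NUM * H * W rather than channel * H * W. A small self-contained check with a made-up shape (UP_DIV as usually defined in nnacl, assumed here):

#include <cstdio>
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))  // assumed nnacl definition
constexpr int C4NUM = 4;
int main() {
  int in_h = 28, in_w = 28, in_channel = 3;  // made-up NHWC4-packed input shape
  int ic4 = UP_DIV(in_channel, C4NUM);
  int old_offset = in_channel * in_h * in_w;   // 2352: ignores the channel padding
  int new_offset = ic4 * C4NUM * in_h * in_w;  // 3136: matches the packed per-batch stride
  printf("per-batch offset: old %d, corrected %d\n", old_offset, new_offset);
  return 0;
}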

View File

@@ -1,6 +1,6 @@
hiai_model_0909_kd_rot_ps_softmax.tflite
hiai_chinese_english_recognize_model_float32.tflite
#hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite
hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite
hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite
hiai_cn_recognize_modify_padv2.tflite
hiai_model_normalize_object_scene_ps_20200519.tflite