forked from mindspore-Ecosystem/mindspore
1. add fp32 convolution sliding-window kernel
2. enable multi-threaded execution for int8 convolution
parent 23c4c072ce
commit 2b606e87b4

src/runtime/kernel/arm/fp32/convolution.cc

@@ -15,6 +15,7 @@
  */
 #include "src/runtime/kernel/arm/fp32/convolution.h"
+#include "src/runtime/kernel/arm/fp32/convolution_slidewindow.h"
 #include "src/runtime/kernel/arm/fp32/convolution_1x1.h"
 #include "src/runtime/kernel/arm/fp32/convolution_3x3.h"
 #include "src/runtime/kernel/arm/fp32/convolution_winograd.h"

@@ -230,6 +231,19 @@ int ConvolutionCPUKernel::Run() {
   return RET_OK;
 }
 
+bool CheckIfUseSlideWindow(ConvParameter *conv_param) {
+  int in_channel = conv_param->input_channel_;
+  int out_h = conv_param->output_h_;
+  int out_w = conv_param->output_w_;
+  int out_channel = conv_param->output_channel_;
+  int ic4 = UP_DIV(in_channel, C4NUM);
+  int oc4 = UP_DIV(out_channel, C4NUM);
+  if (out_h * out_w <= 32 || ic4 < 4 || oc4 < 4) {
+    return true;
+  }
+  return false;
+}
+
 kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                              const std::vector<lite::tensor::Tensor *> &outputs,
                                              OpParameter *opParameter, const Context *ctx,
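
For orientation, a standalone sketch of the heuristic above, assuming nnacl's UP_DIV rounds up and C4NUM is 4; the shapes in main() are illustrative only.

#include <cstdio>

// Assumed equivalents of nnacl's C4NUM and UP_DIV.
constexpr int kC4Num = 4;
constexpr int UpDiv(int x, int y) { return (x + y - 1) / y; }

// Mirrors CheckIfUseSlideWindow: prefer the sliding-window kernel when the
// output plane is tiny or the channels are thin, i.e. when im2col packing
// overhead would dominate the arithmetic.
bool UseSlideWindow(int out_h, int out_w, int in_channel, int out_channel) {
  int ic4 = UpDiv(in_channel, kC4Num);
  int oc4 = UpDiv(out_channel, kC4Num);
  return out_h * out_w <= 32 || ic4 < 4 || oc4 < 4;
}

int main() {
  printf("%d\n", UseSlideWindow(4, 4, 64, 64));    // 1: only 16 output pixels
  printf("%d\n", UseSlideWindow(56, 56, 64, 64));  // 0: enough work for im2col+GEMM
  printf("%d\n", UseSlideWindow(56, 56, 8, 64));   // 1: ic4 = 2 < 4
  return 0;
}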

@@ -252,6 +266,8 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
   InputTransformUnitFunc input_trans_func = nullptr;
   OutputTransformUnitFunc output_trans_func = nullptr;
   CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
+  bool use_sw = CheckIfUseSlideWindow(conv_param);
+
   kernel::LiteKernel *kernel;
   if (kernel_h == 1 && kernel_w == 1) {
     kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx, primitive);

@@ -260,6 +276,8 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
   } else if (use_winograd) {
     kernel =
       new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit);
+  } else if (use_sw) {
+    kernel = new (std::nothrow) kernel::ConvolutionSWCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   } else {
     kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   }
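
Summarizing the resulting dispatch order as a standalone sketch (the enum and function names are hypothetical; the lines elided between the two hunks presumably also handle the 3x3 kernel, since convolution_3x3.h is included above):

enum class ConvKernelKind { k1x1, kWinograd, kSlideWindow, kCommon };

// Priority mirrors the if/else chain: 1x1 first, then winograd, then the new
// sliding-window path, with the im2col+GEMM kernel as the fallback.
ConvKernelKind SelectConvKernel(int kernel_h, int kernel_w, bool use_winograd, bool use_sw) {
  if (kernel_h == 1 && kernel_w == 1) return ConvKernelKind::k1x1;
  if (use_winograd) return ConvKernelKind::kWinograd;
  if (use_sw) return ConvKernelKind::kSlideWindow;
  return ConvKernelKind::kCommon;
}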

src/runtime/kernel/arm/fp32/convolution_depthwise.cc

@@ -24,8 +24,8 @@
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
-using mindspore::lite::RET_OK;
+using mindspore::lite::RET_INFER_INVALID;
+using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_DepthwiseConv2D;
 
 namespace mindspore::kernel {

@@ -96,7 +96,7 @@ int ConvolutionDepthwiseCPUKernel::Init() {
 
   // init sliding window param
   sliding_ = new SlidingWindowParam;
-  InitSlidingParam(sliding_, conv_param_, C4NUM);
+  InitSlidingParamConvDw(sliding_, conv_param_, C4NUM);
 
   auto ret = InitWeightBias();
   if (ret != 0) {

@@ -122,7 +122,7 @@ int ConvolutionDepthwiseCPUKernel::ReSize() {
 
   // init sliding window param
   sliding_ = new SlidingWindowParam;
-  InitSlidingParam(sliding_, conv_param_, C4NUM);
+  InitSlidingParamConvDw(sliding_, conv_param_, C4NUM);
 
   auto ret = InitBuffer();
   if (ret != 0) {

src/runtime/kernel/arm/fp32/convolution_slidewindow.cc (new file)

@@ -0,0 +1,211 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp32/convolution_slidewindow.h"
#include "src/runtime/kernel/arm/nnacl/common_func.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"

namespace mindspore::kernel {
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;

int ConvolutionSWCPUKernel::InitWeightBias() {
  int kernel_h = conv_param_->kernel_h_;
  int kernel_w = conv_param_->kernel_w_;
  int in_channel = conv_param_->input_channel_;
  int out_channel = conv_param_->output_channel_;
  int ic4 = UP_DIV(in_channel, C4NUM);
  int kernel_plane = kernel_h * kernel_w;
  int oc_block = C4NUM;
  int oc_block_num = UP_DIV(out_channel, C4NUM);
  int pack_weight_size = oc_block_num * oc_block * ic4 * C4NUM * kernel_plane;

  // ==================================init weight======================================//
  auto origin_weight = reinterpret_cast<float *>(inputs_.at(kWeightIndex)->Data());
  packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
  if (packed_weight_ == nullptr) {
    MS_LOG(ERROR) << "malloc packed weight failed.";
    return RET_ERROR;
  }
  memset(packed_weight_, 0, pack_weight_size * sizeof(float));
  for (int oc = 0; oc < out_channel; ++oc) {
    int src_oc_offset = oc * kernel_h * kernel_w * in_channel;
    int dst_oc_offset = oc * kernel_h * kernel_w * ic4 * C4NUM;
    for (int i = 0; i < kernel_h * kernel_w; ++i) {
      const float *src = origin_weight + src_oc_offset + i * in_channel;
      float *dst = packed_weight_ + dst_oc_offset + i * ic4 * C4NUM;
      memcpy(dst, src, in_channel * sizeof(float));
    }
  }

  // ====================================init bias====================================== //
  bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * oc_block * sizeof(float)));
  if (bias_data_ == nullptr) {
    MS_LOG(ERROR) << "malloc bias failed.";
    return RET_ERROR;
  }
  memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float));
  if (inputs_.size() == kInputSize2) {
    auto ori_bias = reinterpret_cast<float *>(inputs_.at(kBiasIndex)->Data());
    memcpy(bias_data_, ori_bias, out_channel * sizeof(float));
  } else {
    MS_ASSERT(inputs_.size() == kInputSize1);
  }
  return RET_OK;
}
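
The loop above copies each output channel's kernel_h * kernel_w * in_channel weights into slots whose channel extent is padded to ic4 * C4NUM, so vector loads never straddle a ragged channel edge. A toy per-channel version of the same packing, assuming C4NUM == 4:

#include <cstring>
#include <vector>

// Pack one output channel's KHxKW x in_channel weights into KHxKW x (ic4*4)
// slots, zero-filling the padded tail, as InitWeightBias does per 'oc'.
std::vector<float> PackOcWeights(const float *src, int kernel_plane, int in_channel) {
  const int kC4 = 4;
  int ic4 = (in_channel + kC4 - 1) / kC4;
  std::vector<float> dst(kernel_plane * ic4 * kC4, 0.0f);  // zeros are the padding
  for (int i = 0; i < kernel_plane; ++i) {
    std::memcpy(&dst[i * ic4 * kC4], src + i * in_channel, in_channel * sizeof(float));
  }
  return dst;
}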

int ConvolutionSWCPUKernel::InitTmpBuffer() {
  int in_channel = conv_param_->input_channel_;
  int ic4 = UP_DIV(in_channel, C4NUM);
  int out_channel = conv_param_->output_channel_;
  int oc4 = UP_DIV(out_channel, C4NUM);

  /*=============================nhwc4_input_============================*/
  size_t nhwc4_input_size =
    ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float);
  nhwc4_input_ = malloc(nhwc4_input_size);
  if (nhwc4_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc nhwc4 input failed.";
    return RET_ERROR;
  }
  memset(nhwc4_input_, 0, nhwc4_input_size);

  /*=============================tmp_output_block_============================*/
  tmp_output_block_ = reinterpret_cast<float *>(
    malloc(conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * oc4 * C4NUM * sizeof(float)));
  if (tmp_output_block_ == nullptr) {
    MS_LOG(ERROR) << "malloc tmp output block failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
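
Worked sizing example (hypothetical 1x16x16x3 input and 1x16x16x10 output): the input buffer pads 3 channels up to 4, the output scratch pads 10 up to 12, and the padded output is later trimmed back by the NHWC4-to-NHWC pack in ConvSWFp32.

#include <cstdio>

int main() {
  // Hypothetical shapes: N=1, H=W=16, C_in=3, C_out=10; C4NUM == 4 assumed.
  int ic4 = (3 + 3) / 4;                      // 1
  int oc4 = (10 + 3) / 4;                     // 3
  int in_floats = 1 * 16 * 16 * ic4 * 4;      // nhwc4_input_: 1024 floats
  int out_floats = 1 * 16 * 16 * oc4 * 4;     // tmp_output_block_: 3072 floats
  printf("%d %d\n", in_floats, out_floats);
  return 0;
}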

void ConvolutionSWCPUKernel::ConfigInputOutput() {
  // set output format
  auto output_tensor = outputs_.at(kOutputIndex);
  output_tensor->SetFormat(schema::Format_NHWC);

  // select trans func for input
  auto input_tensor = inputs_.at(kInputIndex);
  auto ret = CheckLayout(input_tensor);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Check layout failed.";
    return;
  }
}

int ConvolutionSWCPUKernel::Init() {
  if (context_->infer_shape_interrupt_ && !context_->running_) {
    SetNeedReInit();
    return RET_OK;
  }
  auto ret = ConvolutionBaseCPUKernel::Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConvolutionBase init failed.";
    return RET_ERROR;
  }
  ret = InitWeightBias();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init weight bias failed.";
    return RET_ERROR;
  }
  // init tmp input, output
  ret = InitTmpBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init tmp buffer failed.";
    return RET_ERROR;
  }
  // init sliding window param
  slidingWindow_param_ = new SlidingWindowParam;
  InitSlidingParamConv(slidingWindow_param_, conv_param_, C4NUM);

  // config input output
  ConfigInputOutput();
  return RET_OK;
}

int ConvolutionSWCPUKernel::ReSize() {
  if (tmp_output_block_ != nullptr) {
    free(tmp_output_block_);
  }
  if (nhwc4_input_ != nullptr) {
    free(nhwc4_input_);
  }
  delete slidingWindow_param_;

  auto ret = ConvolutionBaseCPUKernel::Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConvolutionBase init failed.";
    return RET_ERROR;
  }
  // init tmp input, output
  ret = InitTmpBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init tmp buffer failed.";
    return RET_ERROR;
  }
  // init sliding window param
  slidingWindow_param_ = new SlidingWindowParam;
  InitSlidingParamConv(slidingWindow_param_, conv_param_, C4NUM);
  return RET_OK;
}

int ConvolutionSWCPUKernel::RunImpl(int task_id) {
  auto output_addr = reinterpret_cast<float *>(outputs_.at(kOutputIndex)->Data());
  ConvSWFp32(reinterpret_cast<float *>(nhwc4_input_), packed_weight_, reinterpret_cast<float *>(bias_data_),
             tmp_output_block_, output_addr, task_id, conv_param_, slidingWindow_param_);
  return RET_OK;
}

int ConvolutionSWImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
  auto conv = reinterpret_cast<ConvolutionSWCPUKernel *>(cdata);
  auto error_code = conv->RunImpl(task_id);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "Convolution Sliding Window Run error task_id[" << task_id << "] error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

int ConvolutionSWCPUKernel::Run() {
  auto prepare_ret = Prepare();
  if (prepare_ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare failed, ret: " << prepare_ret;
    return prepare_ret;
  }
  auto input_tensor = inputs_.at(kInputIndex);
  auto ori_input_data = input_tensor->Data();
  int in_batch = conv_param_->input_batch_;
  int in_h = conv_param_->input_h_;
  int in_w = conv_param_->input_w_;
  int in_channel = conv_param_->input_channel_;
  convert_func_(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel);

  int error_code = LiteBackendParallelLaunch(ConvolutionSWImpl, this, thread_count_);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "conv error error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}
}  // namespace mindspore::kernel
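
Run() hands ConvolutionSWImpl to LiteBackendParallelLaunch, which calls it once per task_id in [0, thread_count_). A minimal sketch of that contract using std::thread; the real runtime uses its own thread pool, so this is an assumed equivalent, not the MindSpore API:

#include <thread>
#include <vector>

using TaskFunc = int (*)(int task_id, void *cdata);

// Assumed semantics: run func(task_id, cdata) for every task_id in
// [0, thread_num) and report the first non-zero return code.
int ParallelLaunch(TaskFunc func, void *cdata, int thread_num) {
  std::vector<std::thread> workers;
  std::vector<int> rc(thread_num, 0);
  for (int t = 0; t < thread_num; ++t) {
    workers.emplace_back([=, &rc] { rc[t] = func(t, cdata); });
  }
  for (auto &w : workers) w.join();
  for (int r : rc) {
    if (r != 0) return r;
  }
  return 0;
}

No locking is needed on the output tensor because ConvSWFp32 strides output-channel blocks by thread_num, so tasks write disjoint channel ranges.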

src/runtime/kernel/arm/fp32/convolution_slidewindow.h (new file)

@@ -0,0 +1,59 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_SLIDEWINDOW_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_SLIDEWINDOW_H_

#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/nnacl/op_base.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "src/runtime/kernel/arm/nnacl/fp32/conv.h"
#include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h"

namespace mindspore::kernel {
class ConvolutionSWCPUKernel : public ConvolutionBaseCPUKernel {
 public:
  ConvolutionSWCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                         const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
                         const lite::Primitive *primitive)
      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}

  ~ConvolutionSWCPUKernel() override {
    if (packed_weight_ != nullptr) {
      free(packed_weight_);
    }
    if (tmp_output_block_ != nullptr) {
      free(tmp_output_block_);
    }
    delete slidingWindow_param_;
  }

  int Init() override;
  int ReSize() override;
  int Run() override;
  int RunImpl(int task_id);
  int InitWeightBias();
  int InitTmpBuffer();
  void ConfigInputOutput();

 private:
  float *packed_weight_;
  float *tmp_output_block_;
  SlidingWindowParam *slidingWindow_param_;
};
}  // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_SLIDEWINDOW_H_

src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc

@@ -39,7 +39,7 @@ int DeconvolutionDepthwiseCPUKernel::InitSlideParam() {
 
   // init sliding window param
   sliding_ = new SlidingWindowParam;
-  InitSlidingParam(sliding_, conv_param_, C4NUM);
+  InitSlidingParamConvDw(sliding_, conv_param_, C4NUM);
   return RET_OK;
 }

src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc

@@ -90,7 +90,7 @@ int ConvolutionDepthwiseInt8CPUKernel::Init() {
 
   // init sliding window param
   sliding = new SlidingWindowParam;
-  InitSlidingParam(sliding, conv_param_, C4NUM);
+  InitSlidingParamConvDw(sliding, conv_param_, C4NUM);
 
   // init quant param
   ConvolutionBaseCPUKernel::SetQuantParam();

@@ -119,7 +119,7 @@ int ConvolutionDepthwiseInt8CPUKernel::ReSize() {
   ConvolutionBaseCPUKernel::Init();
 
   // init sliding window param
-  InitSlidingParam(sliding, conv_param_, C4NUM);
+  InitSlidingParamConvDw(sliding, conv_param_, C4NUM);
 
   // init quant param
   ConvolutionBaseCPUKernel::SetQuantParam();

src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc

@@ -67,7 +67,7 @@ int DeconvolutionDepthwiseInt8CPUKernel::InitSlideParam() {
   conv_param_->output_channel_ = inputs_.front()->shape().at(kNHWC_C);
 
   // init sliding window param
-  InitSlidingParam(sliding, conv_param_, C4NUM);
+  InitSlidingParamConvDw(sliding, conv_param_, C4NUM);
 
   sliding->in_h_step_ = conv_param_->input_w_ * C4NUM;
   sliding->in_sh_step_ = conv_param_->input_w_ * C4NUM * conv_param_->stride_h_;  // stride H

src/runtime/kernel/arm/nnacl/conv_parameter.h

@@ -55,4 +55,24 @@ typedef struct ConvParameter {
   bool is_relu6_;
 } ConvParameter;
 
+typedef struct SlidingWindowParam {
+  int left_;
+  int right_;
+  int top_;
+  int bottom_;
+  int c_block_;
+  int block_channel_;
+  int ic4_channel_;
+  int out_step_;
+  int out_h_step_;
+  int in_step_;
+  int in_h_step_;
+  int in_sh_step_;  // stride H
+  int in_sw_step_;  // stride W
+  int in_kh_step_;  // kernel H
+  int in_kw_step_;  // kernel W
+  int kernel_step_;
+} SlidingWindowParam;
+
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CONV_PARAMETER_H_
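
All the *_step_ fields are precomputed element strides over channel-blocked NHWC4-style buffers. For orientation, a hedged sketch of the addressing they encode (illustrative helper, not part of the patch):

// Offset of element (h, w, c) in a plane whose rows advance by in_h_step_
// (= width * channel_block) floats; channel_block is ic4_channel_ on the
// conv input side and block_channel_ on the output side.
inline int OffsetHWC(int h, int w, int c, int width, int channel_block) {
  return (h * width + w) * channel_block + c;
}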

src/runtime/kernel/arm/nnacl/fp32/conv.c

@@ -18,6 +18,177 @@
#include <string.h>
#include "nnacl/winograd_transform.h"

void SWBorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
                   int in_kh_step, int in_kw_step, int kernel_h, int kernel_w, int ic4, bool is_relu, bool is_relu6) {
  for (int c = 0; c < C4NUM; c++) {
    dst[c] = 0;
  }
  const float *weight_oc = weight;
  for (int oc = 0; oc < C4NUM; ++oc) {
    const float *weight_kh = weight_oc;
    const float *src_kh = src;
    for (int kh = 0; kh < height; kh++) {
      const float *src_kw = src_kh;
      const float *weight_kw = weight_kh;
      for (int kw = 0; kw < width; kw++) {
        const float *src_ic4 = src_kw;
        const float *weight_ic4 = weight_kw;
        for (int ic = 0; ic < ic4; ++ic) {
          for (int c = 0; c < C4NUM; c++) {
            dst[oc] += src_ic4[c] * weight_ic4[c];
          }
          src_ic4 += C4NUM;
          weight_ic4 += C4NUM;
        }  // ic4 loop
        src_kw += in_kw_step;
        weight_kw += ic4 * C4NUM;
      }  // kernel_w loop
      src_kh += in_kh_step;
      weight_kh += kernel_w * ic4 * C4NUM;
    }  // kernel_h loop
    dst[oc] += bias[oc];
    dst[oc] = (is_relu) ? (MSMAX(0, dst[oc])) : (dst[oc]);
    dst[oc] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[oc]))) : (dst[oc]);
    weight_oc += kernel_h * kernel_w * ic4 * C4NUM;
  }  // oc loop
}

void SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
              int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
  int ic4 = sliding->ic4_channel_ / C4NUM;
  float *dst_h = dst + top * sliding->out_h_step_;
  for (int oh = top; oh < bottom; oh++) {
    int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
    int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
    int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
    const float *src_h = src + ih * sliding->in_h_step_;

    float *dst_kernel = dst_h + left * sliding->block_channel_;
    for (int ow = left; ow < right; ow++) {
      int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
      int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
      int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
      const float *src_w = src_h + iw * sliding->ic4_channel_;

      const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
      const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * sliding->ic4_channel_;

      SWBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
                    sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_, ic4,
                    conv_param->is_relu_, conv_param->is_relu6_);

      dst_kernel += sliding->block_channel_;
    }  // width loop
    dst_h += sliding->out_h_step_;
  }  // height loop
}
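
The start_kh/end_kh and start_kw/end_kw clamps above shrink the kernel window to its in-bounds taps; weight_kernel is advanced past the skipped rows and columns so weights stay aligned with pixels. A worked check with hypothetical numbers:

#include <cstdio>

static int UpDiv(int x, int y) { return (x + y - 1) / y; }
static int Max(int a, int b) { return a > b ? a : b; }
static int Min(int a, int b) { return a < b ? a : b; }

int main() {
  // Top border row: oh = 0, stride_h = 1, pad_h = 1, dilation_h = 1,
  // kernel_h = 3, input_h = 8 — mirrors SWBorder's clamping arithmetic.
  int ih = 0 * 1 - 1;                     // -1: one row above the image
  int start_kh = Max(0, UpDiv(-ih, 1));   // 1: skip the out-of-bounds tap
  int end_kh = Min(3, UpDiv(8 - ih, 1));  // 3: remaining taps are in bounds
  printf("kernel rows [%d, %d)\n", start_kh, end_kh);  // [1, 3)
  return 0;
}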

void SWCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width, int kernel_h,
              int kernel_w, int out_h_step, int block_channel, int ic4, int in_sh_step, int in_sw_step, int in_kh_step,
              int in_kw_step, bool is_relu, bool is_relu6) {
  float *dst_h = dst;
  const float *src_h = src;
  for (int oh = 0; oh < height; oh++) {
    float *dst_w = dst_h;
    const float *src_w = src_h;
    for (int ow = 0; ow < width; ow++) {
      const float *weight_oc = weight;
      for (int c = 0; c < C4NUM; c++) {
        dst_w[c] = 0;
      }

      for (int oc = 0; oc < C4NUM; oc++) {
        const float *weight_kh = weight_oc;
        const float *src_kh = src_w;
        for (int kh = 0; kh < kernel_h; kh++) {
          const float *src_kw = src_kh;
          const float *weight_kw = weight_kh;
          for (int kw = 0; kw < kernel_w; kw++) {
            const float *src_ic4 = src_kw;
            const float *weight_ic4 = weight_kw;
            for (int ic = 0; ic < ic4; ++ic) {
              for (int c = 0; c < C4NUM; c++) {
                dst_w[oc] += src_ic4[c] * weight_ic4[c];
              }
              src_ic4 += C4NUM;
              weight_ic4 += C4NUM;
            }  // ic4 loop
            src_kw += in_kw_step;
            weight_kw += ic4 * C4NUM;
          }  // kernel_w loop
          src_kh += in_kh_step;
          weight_kh += kernel_w * ic4 * C4NUM;
        }  // kernel_h loop
        // add bias, apply relu
        dst_w[oc] += bias[oc];
        dst_w[oc] = (is_relu) ? (MSMAX(0, dst_w[oc])) : (dst_w[oc]);
        dst_w[oc] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[oc]))) : (dst_w[oc]);
        weight_oc += kernel_h * kernel_w * ic4 * C4NUM;
      }  // oc block

      dst_w += block_channel;
      src_w += in_sw_step;
    }  // dst_width loop
    dst_h += out_h_step;
    src_h += in_sh_step;
  }  // dst_height loop
}

// fp32 sliding window
void ConvSWFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *tmp_out_block,
                float *output_data, int task_id, ConvParameter *conv_param, SlidingWindowParam *slidingWindow_param) {
  int ic4 = slidingWindow_param->ic4_channel_ / C4NUM;
  int ic4_res = conv_param->input_channel_ % C4NUM;
  const float *src = input_data;
  float *dst;
  if (ic4_res == 0) {
    dst = output_data;
  } else {
    dst = tmp_out_block;
  }

  for (int b = 0; b < conv_param->output_batch_; b++) {
    for (int oc = task_id; oc < slidingWindow_param->c_block_; oc += conv_param->thread_num_) {
      const float *src_data = src;
      float *dst_data = dst + oc * C4NUM;
      const float *weight = packed_weight + oc * slidingWindow_param->kernel_step_;
      const float *bias = bias_data + oc * C4NUM;
      SWBorder(dst_data, src_data, weight, bias, 0, slidingWindow_param->top_, 0, conv_param->output_w_, conv_param,
               slidingWindow_param);
      SWBorder(dst_data, src_data, weight, bias, slidingWindow_param->bottom_, conv_param->output_h_, 0,
               conv_param->output_w_, conv_param, slidingWindow_param);
      SWBorder(dst_data, src_data, weight, bias, slidingWindow_param->top_, slidingWindow_param->bottom_, 0,
               slidingWindow_param->left_, conv_param, slidingWindow_param);
      SWBorder(dst_data, src_data, weight, bias, slidingWindow_param->top_, slidingWindow_param->bottom_,
               slidingWindow_param->right_, conv_param->output_w_, conv_param, slidingWindow_param);

      if (slidingWindow_param->right_ > slidingWindow_param->left_ &&
          slidingWindow_param->bottom_ > slidingWindow_param->top_) {
        int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_h_;
        int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_w_;
        const float *in_t =
          src_data + in_h_start * slidingWindow_param->in_h_step_ + in_w_start * slidingWindow_param->ic4_channel_;
        float *out_t = dst_data + slidingWindow_param->top_ * slidingWindow_param->out_h_step_ +
                       slidingWindow_param->left_ * slidingWindow_param->block_channel_;
        SWCenter(out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
                 slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_, conv_param->kernel_w_,
                 slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_, ic4,
                 slidingWindow_param->in_sh_step_, slidingWindow_param->in_sw_step_, slidingWindow_param->in_kh_step_,
                 slidingWindow_param->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_);
      }
    }  // output C4 loop
    src += slidingWindow_param->in_step_;
    dst += slidingWindow_param->out_step_;
  }  // batch loop
  // output nhwc4
  if (ic4_res != 0) {
    PackNHWC4ToNHWCFp32(tmp_out_block, output_data, conv_param->output_batch_,
                        conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_);
  }
}

// fp32 conv common
void ConvFp32(float *input_data, float *packed_input, float *packed_weight, const float *bias_data,
              float *tmp_out_block, float *output_data, int task_id, ConvParameter *conv_param,
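
ConvSWFp32 therefore splits each output plane into five regions: four pad-affected strips handled by SWBorder with per-pixel clamping, and the pad-free interior handled by SWCenter with no bounds checks in the hot loops. Threads partition work by output-channel block (oc starts at task_id and strides by thread_num_), so concurrent tasks touch disjoint channels. Schematically (region bounds from SlidingWindowParam; illustrative):

//   +---------------------------+
//   |         top strip         |  rows [0, top_)          -> SWBorder
//   +------+-----------+--------+
//   | left | interior  | right  |  rows [top_, bottom_),
//   | strip| (no pad)  | strip  |  cols [left_, right_)    -> SWCenter
//   +------+-----------+--------+
//   |       bottom strip        |  rows [bottom_, out_h_)  -> SWBorder
//   +---------------------------+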

src/runtime/kernel/arm/nnacl/fp32/conv.h

@@ -26,12 +26,24 @@
 #include "nnacl/conv_parameter.h"
 #include "nnacl/fp32/strassen_matmul.h"
 #include "nnacl/winograd_utils.h"
+#include "nnacl/fp32/conv_depthwise.h"
 
 using TmpBufferAddress = float *;
 typedef void (*GEMM_FUNC_FP32)(float *output, const float *input, const float *weight, const float *bias, size_t step,
                                size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4,
                                size_t relu, size_t relu6);
 
+void SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
+              int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding);
+
+void SWCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width, int kernel_h,
+              int kernel_w, int out_h_step, int block_channel, int ic4, int in_sh_step, int in_sw_step, int in_kh_step,
+              int in_kw_step, bool is_relu, bool is_relu6);
+
+// fp32 sliding window
+void ConvSWFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *tmp_out_block,
+                float *output_data, int task_id, ConvParameter *conv_param, SlidingWindowParam *slidingWindow_param);
+
 // fp32 convolution common (im2col+gemm)
 void ConvFp32(float *input_data, float *packed_input, float *packed_weight, const float *bias_data,
               float *tmp_out_block, float *output_data, int task_id, ConvParameter *conv_param,

src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.c

@@ -47,9 +47,35 @@ void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_par
   sliding->bottom_ = bottom;
   sliding->c_block_ = UP_DIV(conv_param->output_channel_, block);
   sliding->block_channel_ = UP_DIV(conv_param->output_channel_, block) * block;
 
   sliding->out_step_ = conv_param->output_h_ * conv_param->output_w_ * sliding->block_channel_;
   sliding->out_h_step_ = conv_param->output_w_ * sliding->block_channel_;
 }
 
+void InitSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
+  InitSlidingParam(sliding, conv_param, block);
+  AppendSlidingParamConv(sliding, conv_param, block);
+}
+
+void AppendSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
+  int in_channel = conv_param->input_channel_;
+  int ic4 = UP_DIV(in_channel, C4NUM);
+  int ic4_channel = ic4 * C4NUM;
+  sliding->ic4_channel_ = ic4_channel;
+  sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * ic4_channel;  // for batch loop
+  sliding->in_h_step_ = conv_param->input_w_ * ic4_channel;
+  sliding->in_sh_step_ = conv_param->input_w_ * ic4_channel * conv_param->stride_h_;    // stride H
+  sliding->in_sw_step_ = ic4_channel * conv_param->stride_w_;                           // stride W
+  sliding->in_kh_step_ = conv_param->input_w_ * ic4_channel * conv_param->dilation_h_;  // kernel H
+  sliding->in_kw_step_ = ic4_channel * conv_param->dilation_w_;                         // kernel W
+  sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * ic4_channel * block;
+}
+
+void InitSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
+  InitSlidingParam(sliding, conv_param, block);
+  AppendSlidingParamConvDw(sliding, conv_param, block);
+}
+
+void AppendSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
+  sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->block_channel_;  // for batch loop
+  sliding->in_h_step_ = conv_param->input_w_ * sliding->block_channel_;
+  sliding->in_sh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->stride_h_;  // stride H
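
Plugging a hypothetical 3x3, stride-2, dilation-1 conv with a 1x32x32x16 input into AppendSlidingParamConv's formulas:

#include <cstdio>

int main() {
  // Hypothetical conv: kernel 3x3, stride 2, dilation 1, input 1x32x32x16.
  int input_w = 32, stride = 2, dilation = 1, kernel_h = 3, kernel_w = 3, block = 4;
  int ic4_channel = ((16 + 3) / 4) * 4;  // 16: already a multiple of C4NUM
  printf("in_h_step   = %d\n", input_w * ic4_channel);             // 512
  printf("in_sh_step  = %d\n", input_w * ic4_channel * stride);    // 1024
  printf("in_sw_step  = %d\n", ic4_channel * stride);              // 32
  printf("in_kh_step  = %d\n", input_w * ic4_channel * dilation);  // 512
  printf("in_kw_step  = %d\n", ic4_channel * dilation);            // 16
  printf("kernel_step = %d\n", kernel_w * kernel_h * ic4_channel * block);  // 576
  return 0;
}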

src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h

@@ -19,26 +19,25 @@
 
 #include "nnacl/conv_parameter.h"
 
-typedef struct SlidingWindowParam {
-  int left_;
-  int right_;
-  int top_;
-  int bottom_;
-  int c_block_;
-  int block_channel_;
-  int out_step_;
-  int out_h_step_;
-  int in_step_;
-  int in_h_step_;
-  int in_sh_step_;  // stride H
-  int in_sw_step_;  // stride W
-  int in_kh_step_;  // kernel H
-  int in_kw_step_;  // kernel W
-  int kernel_step_;
-} SlidingWindowParam;
 #ifndef ENABLE_ARM64
 void DepthwiseCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
                      int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step,
                      int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6);
 #endif
 
 void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block);
 
+void InitSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block);
+
+void AppendSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block);
+
+void InitSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block);
+
+void AppendSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block);
+
 void DepthwiseBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom,
                      int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding);
 
 void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                   const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);

src/runtime/kernel/arm/nnacl/int8/conv_int8.c

@@ -198,7 +198,7 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
     for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
       int start_index = thread_id * tile_n;
       int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
-      int32_t *tmp_input_sum = input_sum + thread_id * tile_n;
+      int32_t *tmp_input_sum = input_sum + task_id * tile_n;
       int8_t *gemm_input = packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset;
       // clear tmp buffer before compute
       memset(gemm_input, (int8_t)input_zp, unit_size * tile_n);

@@ -208,15 +208,16 @@
       int tmp_dst_offset = task_id * tile_n * conv_param->output_channel_;
       memset(tmp_dst + tmp_dst_offset, 0, tmp_dst_size);
 
-      Im2ColPackUnitInt8(input_data + in_batch_offset, gemm_input, real_cal_num, start_index, input_sum, conv_param);
+      Im2ColPackUnitInt8(input_data + in_batch_offset, gemm_input, real_cal_num, start_index, tmp_input_sum,
+                         conv_param);
       if (real_cal_num == tile_n) {
         int8_t *gemm_output = output_data + out_offset;
         IndirectGemmInt8(gemm_output, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane,
-                         out_channel, input_sum, conv_param);
+                         out_channel, tmp_input_sum, conv_param);
       } else {
         // res part
         IndirectGemmInt8(tmp_out, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane,
-                         out_channel, input_sum, conv_param);
+                         out_channel, tmp_input_sum, conv_param);
         memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel);
       }
     }

@@ -253,7 +254,7 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
       int start_index = thread_id * tile_n;
       int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
       // todo
-      int32_t *tmp_input_sum = input_sum + thread_id * tile_n;
+      int32_t *tmp_input_sum = input_sum + task_id * tile_n;
       int8_t *gemm_input = packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset;
       // clear tmp buffer before compute
       memset(gemm_input, (int8_t)input_zp, unit_size * tile_n);

@@ -263,15 +264,16 @@
       int tmp_dst_offset = task_id * tile_n * conv_param->output_channel_;
       memset(tmp_dst + tmp_dst_offset, 0, tmp_dst_size);
 
-      Im2ColPackUnitInt8Opt(input_data + in_batch_offset, gemm_input, real_cal_num, start_index, input_sum, conv_param);
+      Im2ColPackUnitInt8Opt(input_data + in_batch_offset, gemm_input, real_cal_num, start_index, tmp_input_sum,
+                            conv_param);
       if (real_cal_num == tile_n) {
         int8_t *gemm_output = output_data + out_offset;
         IndirectGemmInt8Opt(gemm_output, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4,
-                            kernel_plane, out_channel, input_sum, conv_param, gemm_func);
+                            kernel_plane, out_channel, tmp_input_sum, conv_param, gemm_func);
       } else {
         // res part
         IndirectGemmInt8Opt(tmp_out, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane,
-                            out_channel, input_sum, conv_param, gemm_func);
+                            out_channel, tmp_input_sum, conv_param, gemm_func);
         memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel);
       }
     }
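
The substance of the int8 change: the per-tile sums were previously written at input_sum + thread_id * tile_n (thread_id being the global tile index, which can exceed the task count) while the pack and GEMM calls read from the shared input_sum base, so concurrent tasks could clobber one another's sums. Indexing by task_id gives each task a private tile_n-sized slice, and threading tmp_input_sum through Im2ColPackUnitInt8/IndirectGemmInt8 keeps producer and consumer on that same slice. A minimal sketch of the pattern, assuming input_sum is sized thread_num * tile_n as the new indexing implies:

#include <cstdint>
#include <vector>

// Each of 'thread_num' concurrent tasks owns one tile_n-sized slice of the
// shared scratch vector; tiles are processed serially within a task, so the
// slice can be reused tile after tile without synchronization.
void RunTask(int task_id, int thread_num, int tile_n, int output_tile_count,
             std::vector<int32_t> *input_sum /* thread_num * tile_n */) {
  int32_t *tmp_input_sum = input_sum->data() + task_id * tile_n;
  for (int tile = task_id; tile < output_tile_count; tile += thread_num) {
    for (int i = 0; i < tile_n; ++i) {
      tmp_input_sum[i] = 0;  // stand-in for packing this tile's column sums
    }
    // ...feed the same tmp_input_sum pointer to the GEMM for this tile...
  }
}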

src/runtime/kernel/arm/nnacl/reshape_parameter.h

@@ -27,4 +27,3 @@ typedef struct ReshapeParameter {
 } ReshapeParameter;
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_RESHAHPE_PARAMETER_H_
-