[MS][LITE] arm cpu fp32 op: conv depthwise

yangruoqi713 2020-08-23 12:03:05 +08:00
parent dde257592b
commit b4551670a9
8 changed files with 455 additions and 94 deletions

@@ -0,0 +1,117 @@
#ifdef __aarch64__
.text
.align 5
.global ConvDwFp32Row
#ifndef __APPLE__
.type ConvDwFp32Row, %function
#endif
// void ConvDwFp32Row(float* output_ptr, const float* input_ptr, const float* filter_ptr,
// size_t num_pixels, size_t input_channel, size_t input_step)
// x0: output_ptr, x1: input_ptr, x2: filter_ptr, x3: num_pixels,
// x4: input_channel, x5: input_step
//
ConvDwFp32Row:
// Registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// Registers x19 ~ x29 should also be preserved.
// Our coding style does not permit that many parameters, so this routine only needs caller-saved registers.
cmp x3, #0
beq End
mov x9, x0 // x9: write pointer; x0 keeps loading the bias-initialized outputs
mov x12, #4
mul x5, x5, x12 // input_step: floats -> bytes
LoopOutPixel:
mov x6, x1
mov x7, x2
mov x8, x4
LoopInputDepth16In:
cmp x8, #16
blt L4
sub x8, x8, #16
ld1 {v0.4s, v1.4s}, [x6], #32
ld1 {v2.4s, v3.4s}, [x7], #32
ld1 {v16.4s, v17.4s}, [x0], #32
cmp x8, #16
blt LoopInputDepth16Out
LoopInputDepth16: // steady state: 16 channels per iteration; loads for the next tile overlap the FMAs
fmla v16.4s, v0.4s, v2.4s
fmla v17.4s, v1.4s, v3.4s
st1 {v16.4s, v17.4s}, [x9], #32
ld1 {v4.4s, v5.4s}, [x6], #32
ld1 {v6.4s, v7.4s}, [x7], #32
ld1 {v18.4s, v19.4s}, [x0], #32
fmla v18.4s, v4.4s, v6.4s
fmla v19.4s, v5.4s, v7.4s
st1 {v18.4s, v19.4s}, [x9], #32
ld1 {v0.4s, v1.4s}, [x6], #32
ld1 {v2.4s, v3.4s}, [x7], #32
ld1 {v16.4s, v17.4s}, [x0], #32
sub x8, x8, #16
cmp x8, #16
bge LoopInputDepth16
LoopInputDepth16Out: // drain: process the final 16 channels (first half already loaded)
fmla v16.4s, v0.4s, v2.4s
fmla v17.4s, v1.4s, v3.4s
st1 {v16.4s, v17.4s}, [x9], #32
ld1 {v4.4s, v5.4s}, [x6], #32
ld1 {v6.4s, v7.4s}, [x7], #32
ld1 {v18.4s, v19.4s}, [x0], #32
fmla v18.4s, v4.4s, v6.4s
fmla v19.4s, v5.4s, v7.4s
st1 {v18.4s, v19.4s}, [x9], #32
L4: // tail: 4 channels at a time
cmp x8, #4
blt L0
LoopInputDepth4:
ld1 {v0.4s}, [x6], #16
ld1 {v2.4s}, [x7], #16
ld1 {v16.4s}, [x0], #16
fmla v16.4s, v0.4s, v2.4s
st1 {v16.4s}, [x9], #16
sub x8, x8, #4
cmp x8, #4
bge LoopInputDepth4
L0: // tail: one channel at a time
cmp x8, #0
beq Loop16LineEnd
LoopInputDepth0:
ldr s0, [x6], #4
ldr s1, [x7], #4
ldr s2, [x0], #4
fmul s0, s0, s1
fadd s2, s2, s0
str s2, [x9], #4
subs x8, x8, #1
bne LoopInputDepth0
Loop16LineEnd:
subs x3, x3, #1 // one output pixel done
add x1, x1, x5 // advance input_ptr by input_step bytes
bne LoopOutPixel
End:
ret
#endif
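
For reference, the semantics of ConvDwFp32Row can be exercised from C with a small harness; this is a minimal sketch, assuming the size_t prototype declared in the common header below, with made-up buffer sizes and values. In the real kernel the output row is pre-filled with the bias before this routine accumulates into it.

#include <stddef.h>
#include <stdio.h>

void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr,
                   size_t num_pixels, size_t input_channel, size_t input_step);

int main(void) {
  // 2 output pixels, 3 channels; input_step = stride_w * input_channel = 3.
  float out[6] = {0};                      // bias-initialized in the real kernel
  const float in[6] = {1, 2, 3, 4, 5, 6};  // NHWC input row
  const float w[3] = {10, 20, 30};         // per-channel weights of one kernel tap
  ConvDwFp32Row(out, in, w, 2, 3, 3);
  for (int i = 0; i < 6; ++i) printf("%g ", out[i]);  // prints: 10 40 90 40 100 180
  return 0;
}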

@@ -41,7 +41,6 @@ float ShortToFloat32(uint16_t srcValue);
uint16_t Float32ToShort(float srcValue);
#ifdef ENABLE_ARM
void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
@@ -63,6 +62,9 @@ void C4BiasAddRelu6(float *dst, const float *input, const float *bias, size_t oc
void C4Relu(float *dst, const float *input, size_t oc, size_t plane_size, size_t stride);
void C4Relu6(float *dst, const float *input, size_t oc, size_t plane_size, size_t stride);
void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels,
size_t output_channel, size_t input_step);
void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6);
@@ -72,10 +74,10 @@ void DeconvDwFp32Border(float *dst, const float *src, const float *weight, size_
void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod,
size_t plane_size, size_t stride, size_t relu_type);
-void ConvSwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height,
-size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel,
-size_t ic4, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step,
-size_t relu, size_t relu6);
+void ConvSwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
+size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t ic4,
+size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu,
+size_t relu6);
#endif
#ifdef __cplusplus

@@ -21,6 +21,70 @@
#include <arm_neon.h>
#endif
#ifndef ENABLE_ARM64
// Generic C fallback for platforms without the AArch64 kernel; parameter types
// match the size_t prototype declared in the common header, avoiding conflicting declarations.
void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels,
size_t output_channel, size_t input_step) {
for (size_t i = 0; i < num_pixels; i++) {
for (size_t c = 0; c < output_channel; c++) {
*output_ptr++ += weight_ptr[c] * input_ptr[c];
}
input_ptr += input_step;
}
}
#endif
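
The AArch64 kernel above walks the channel dimension in tiles of 16, then 4, then 1. A plain-C mirror of that tiling (an illustrative sketch, not code from this commit) makes the structure explicit:

#include <stddef.h>

// Assumed mirror of the asm tiling: accumulate one kernel tap into one output pixel.
static void conv_dw_row_tiled(float *out, const float *in, const float *w, size_t channel) {
  size_t c = 0;
  for (; c + 16 <= channel; c += 16) {  // main loop: 16 floats, as v16/v17 + v18/v19 in the asm
    for (size_t k = 0; k < 16; ++k) out[c + k] += in[c + k] * w[c + k];
  }
  for (; c + 4 <= channel; c += 4) {    // 4-wide tail, as the L4 loop
    for (size_t k = 0; k < 4; ++k) out[c + k] += in[c + k] * w[c + k];
  }
  for (; c < channel; ++c) {            // scalar remainder, as the L0 loop
    out[c] += in[c] * w[c];
  }
}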
void ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
const ConvParameter *conv_param, int task_id) {
int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
int h_start = h_step * task_id;
int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
for (int b = 0; b < conv_param->output_batch_; b++) {
const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
for (int oh = h_start; oh < h_end; oh++) {
float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_;
int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_h_;
int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_));
for (int ow = 0; ow < conv_param->output_w_; ow++) {
memcpy(dst_data + ow * conv_param->output_channel_, bias_data, conv_param->output_channel_ * sizeof(float));
}
for (int kh = start_kh; kh < end_kh; kh++) {
int ih = ih_origin + conv_param->dilation_h_ * kh;
const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_;
const float *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_;
int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_;
for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
int out_w_start = MSMAX(
0, (conv_param->pad_w_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_w_ -
conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) /
conv_param->stride_w_);
float *dst_w = dst_data + out_w_start * conv_param->output_channel_;
int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_w_ + conv_param->dilation_w_ * kw;
const float *src_kw = src_kh + iw_origin * conv_param->input_channel_;
int num_pixels = out_w_end - out_w_start;
ConvDwFp32Row(dst_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step);
weight_kh += conv_param->output_channel_;
}
}
if (conv_param->is_relu_) {
ReluFp32(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
}
if (conv_param->is_relu6_) {
Relu6Fp32(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
}
}
}
}
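
The out_w_start / out_w_end arithmetic above clips each kernel tap to the output columns whose input column iw = ow * stride_w - pad_w + dilation_w * kw lands inside the image. It can be sanity-checked in isolation (parameter values are illustrative):

#include <stdio.h>

static int up_div(int a, int b) { return (a + b - 1) / b; }
static int ms_min(int a, int b) { return a < b ? a : b; }
static int ms_max(int a, int b) { return a > b ? a : b; }

int main(void) {
  int input_w = 5, output_w = 5, kernel_w = 3, stride_w = 1, pad_w = 1, dilation_w = 1;
  for (int kw = 0; kw < kernel_w; ++kw) {
    int start = ms_max(0, (pad_w - dilation_w * kw + stride_w - 1) / stride_w);
    int end = ms_min(output_w, up_div(input_w + pad_w - dilation_w * kw, stride_w));
    printf("kw=%d: ow in [%d, %d)\n", kw, start, end);  // [1,5), [0,5), [0,4)
  }
  return 0;
}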
void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
int left = 0;
int right = conv_param->output_w_;

@@ -29,6 +29,9 @@ void DepthwiseCenter(float *dst, const float *src, const float *weight, const fl
extern "C" {
#endif
void ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
const ConvParameter *conv_param, int task_id);
void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block);
void InitSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block);

@@ -15,7 +15,6 @@
*/
#include "src/runtime/kernel/arm/fp32/convolution_depthwise.h"
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
@@ -30,89 +29,43 @@ using mindspore::schema::PrimitiveType_DepthwiseConv2D;
namespace mindspore::kernel {
ConvolutionDepthwiseCPUKernel::~ConvolutionDepthwiseCPUKernel() {
-if (sliding_ != nullptr) {
-delete sliding_;
-sliding_ = nullptr;
-}
if (packed_weight_ != nullptr) {
delete packed_weight_;
packed_weight_ = nullptr;
}
-FreeTmpBuffer();
}
-void ConvolutionDepthwiseCPUKernel::FreeTmpBuffer() {
-if (need_align_) {
-if (packed_input_ != nullptr) {
-delete packed_input_;
-packed_input_ = nullptr;
-}
-if (packed_output_ != nullptr) {
-delete packed_output_;
-packed_output_ = nullptr;
-}
-}
-}
int ConvolutionDepthwiseCPUKernel::InitWeightBias() {
// init weight: o, h, w, i; o == group, i == 1
auto weight_tensor = in_tensors_[kWeightIndex];
auto origin_weight = reinterpret_cast<float *>(weight_tensor->Data());
-int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM);
-int pack_weight_size = C4NUM * OC4 * weight_tensor->Height() * weight_tensor->Width();
+int channel = weight_tensor->Batch();
+int pack_weight_size = weight_tensor->Batch() * weight_tensor->Height() * weight_tensor->Width();
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
-PackNCHWToNC4HW4Fp32(origin_weight, packed_weight_, 1, weight_tensor->Height() * weight_tensor->Width(),
-weight_tensor->Batch());
+PackNCHWToNHWCFp32(origin_weight, packed_weight_, 1, weight_tensor->Height() * weight_tensor->Width(), channel);
-bias_data_ = reinterpret_cast<float *>(malloc(C4NUM * OC4 * sizeof(float)));
+auto bias_tensor = in_tensors_[kBiasIndex];
+bias_data_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
-memset(bias_data_, 0, C4NUM * OC4 * sizeof(float));
+memset(bias_data_, 0, channel * sizeof(float));
if (in_tensors_.size() == kInputSize2) {
-auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->Data());
-memcpy(bias_data_, ori_bias, in_tensors_.at(kBiasIndex)->ElementsNum() * sizeof(float));
+auto ori_bias = reinterpret_cast<float *>(bias_tensor->Data());
+memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(float));
}
-conv_param_->thread_num_ = MSMIN(thread_count_, OC4);
return RET_OK;
}
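
PackNCHWToNHWCFp32 above re-lays the per-group weights from [channel][H*W] to [H*W][channel], so that ConvDw can read one contiguous channel vector per kernel tap. A hypothetical reference for the transpose it is assumed to perform (not the nnacl implementation):

// Assumed semantics of PackNCHWToNHWCFp32 for batch == 1.
static void pack_nchw_to_nhwc_ref(const float *src, float *dst, int plane, int channel) {
  for (int c = 0; c < channel; ++c) {
    for (int p = 0; p < plane; ++p) {
      dst[p * channel + c] = src[c * plane + p];
    }
  }
}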
-int ConvolutionDepthwiseCPUKernel::InitBuffer() {
-if (conv_param_->input_channel_ % C4NUM != 0) {
-need_align_ = true;
-int IC4 = UP_DIV(conv_param_->input_channel_, C4NUM);
-int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * IC4;
-packed_input_ = reinterpret_cast<float *>(malloc(pack_input_size * sizeof(float)));
-if (packed_input_ == nullptr) {
-MS_LOG(ERROR) << "Malloc buffer failed.";
-return RET_ERROR;
-}
-int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM);
-int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * OC4;
-packed_output_ = reinterpret_cast<float *>(malloc(pack_output_size * sizeof(float)));
-if (packed_output_ == nullptr) {
-MS_LOG(ERROR) << "Malloc buffer failed.";
-return RET_ERROR;
-}
-}
-return RET_OK;
-}
int ConvolutionDepthwiseCPUKernel::Init() {
-sliding_ = new (std::nothrow) SlidingWindowParam;
-if (sliding_ == nullptr) {
-MS_LOG(ERROR) << "new sliding window param failed.";
-return RET_ERROR;
-}
auto ret = InitWeightBias();
if (ret != 0) {
MS_LOG(ERROR) << "Convolution depthwise fp32 InitWeightBias failed.";
@@ -125,21 +78,13 @@ int ConvolutionDepthwiseCPUKernel::Init() {
}
int ConvolutionDepthwiseCPUKernel::ReSize() {
-FreeTmpBuffer();
ConvolutionBaseCPUKernel::Init();
-InitSlidingParamConvDw(sliding_, conv_param_, C4NUM);
-auto ret = InitBuffer();
-if (ret != 0) {
-MS_LOG(ERROR) << "Convolution depthwise fp32 InitBuffer failed.";
-return RET_ERROR;
-}
+conv_param_->thread_num_ = MSMIN(thread_count_, conv_param_->output_h_);
return RET_OK;
}
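
ReSize now caps thread_num_ at output_h_, and ConvDw (in conv_depthwise.c above) hands each task a contiguous block of output rows. A standalone sketch of that split with illustrative numbers (UP_DIV and MSMIN redefined here so the snippet is self-contained):

#include <stdio.h>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
  int output_h = 7, thread_count = 3;              // illustrative values
  int thread_num = MSMIN(thread_count, output_h);  // as chosen in ReSize
  int h_step = UP_DIV(output_h, thread_num);
  for (int task_id = 0; task_id < thread_num; ++task_id) {
    int h_start = h_step * task_id;
    int h_end = MSMIN(h_start + h_step, output_h);
    printf("task %d: rows [%d, %d)\n", task_id, h_start, h_end);  // [0,3) [3,6) [6,7)
  }
  return 0;
}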
int ConvolutionDepthwiseCPUKernel::Execute(int task_id) {
-ConvDwC4Fp32(packed_output_, packed_input_, packed_weight_, reinterpret_cast<float *>(bias_data_), conv_param_,
-sliding_, task_id);
+ConvDw(output_ptr_, input_ptr_, packed_weight_, reinterpret_cast<float *>(bias_data_), conv_param_, task_id);
return RET_OK;
}
@@ -164,30 +109,16 @@ int ConvolutionDepthwiseCPUKernel::Run() {
return RET_ERROR;
}
auto input_tensor = in_tensors_.at(kInputIndex);
-auto input_addr = reinterpret_cast<float *>(input_tensor->Data());
+input_ptr_ = reinterpret_cast<float *>(input_tensor->Data());
-if (need_align_) {
-PackNHWCToNHWC4Fp32(input_addr, packed_input_, conv_param_->input_batch_,
-conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);
-} else {
-packed_input_ = input_addr;
-}
-auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
-if (!need_align_) {
-packed_output_ = output_addr;
-}
+auto output_tensor = out_tensors_.at(kOutputIndex);
+output_ptr_ = reinterpret_cast<float *>(output_tensor->Data());
ret = LiteBackendParallelLaunch(ConvDwRun, this, conv_param_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvDwRun error: error_code[" << ret << "]";
return RET_ERROR;
}
-if (need_align_) {
-PackNHWC4ToNHWCFp32(packed_output_, output_addr, conv_param_->output_batch_,
-conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
-}
return RET_OK;
}

@@ -35,17 +35,13 @@ class ConvolutionDepthwiseCPUKernel : public ConvolutionBaseCPUKernel {
int ReSize() override;
int Run() override;
-int InitBuffer();
int InitWeightBias();
int Execute(int task_id);
private:
-void FreeTmpBuffer();
-SlidingWindowParam *sliding_ = nullptr;
float *packed_weight_ = nullptr;
-float *packed_input_ = nullptr;
-float *packed_output_ = nullptr;
-bool need_align_ = false;
+float *input_ptr_ = nullptr;
+float *output_ptr_ = nullptr;
};
} // namespace mindspore::kernel

@@ -0,0 +1,196 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
ConvolutionDepthwiseSWCPUKernel::~ConvolutionDepthwiseSWCPUKernel() {
if (sliding_ != nullptr) {
delete sliding_;
sliding_ = nullptr;
}
if (packed_weight_ != nullptr) {
free(packed_weight_);  // allocated with malloc in InitWeightBias
packed_weight_ = nullptr;
}
FreeTmpBuffer();
}
void ConvolutionDepthwiseSWCPUKernel::FreeTmpBuffer() {
if (need_align_) {
if (packed_input_ != nullptr) {
free(packed_input_);  // allocated with malloc in InitBuffer
packed_input_ = nullptr;
}
if (packed_output_ != nullptr) {
free(packed_output_);  // allocated with malloc in InitBuffer
packed_output_ = nullptr;
}
}
}
int ConvolutionDepthwiseSWCPUKernel::InitWeightBias() {
// init weight: o, h, w, i; o == group, i == 1
auto weight_tensor = in_tensors_[kWeightIndex];
auto origin_weight = reinterpret_cast<float *>(weight_tensor->Data());
int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM);
int pack_weight_size = C4NUM * OC4 * weight_tensor->Height() * weight_tensor->Width();
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
PackNCHWToNC4HW4Fp32(origin_weight, packed_weight_, 1, weight_tensor->Height() * weight_tensor->Width(),
weight_tensor->Batch());
auto bias_tensor = in_tensors_[kBiasIndex];
bias_data_ = reinterpret_cast<float *>(malloc(C4NUM * OC4 * sizeof(float)));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
memset(bias_data_, 0, C4NUM * OC4 * sizeof(float));
if (in_tensors_.size() == kInputSize2) {
auto ori_bias = reinterpret_cast<float *>(bias_tensor->Data());
memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(float));
}
conv_param_->thread_num_ = MSMIN(thread_count_, OC4);
return RET_OK;
}
int ConvolutionDepthwiseSWCPUKernel::InitBuffer() {
if (conv_param_->input_channel_ % C4NUM != 0) {
need_align_ = true;
int IC4 = UP_DIV(conv_param_->input_channel_, C4NUM);
int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * IC4;
packed_input_ = reinterpret_cast<float *>(malloc(pack_input_size * sizeof(float)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM);
int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * OC4;
packed_output_ = reinterpret_cast<float *>(malloc(pack_output_size * sizeof(float)));
if (packed_output_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
}
return RET_OK;
}
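
The sliding-window path keeps the C4 layout: when input_channel is not a multiple of C4NUM, buffers are repacked with the channel dimension padded up to UP_DIV(C, 4) * 4 lanes. A hypothetical reference for what PackNHWCToNHWC4Fp32 is assumed to do (not the nnacl implementation):

// Assumed semantics: copy NHWC to NHWC4, zero-filling the padded channel lanes.
static void pack_nhwc_to_nhwc4_ref(const float *src, float *dst, int plane, int channel) {
  int c4 = ((channel + 3) / 4) * 4;
  for (int p = 0; p < plane; ++p) {
    for (int c = 0; c < channel; ++c) dst[p * c4 + c] = src[p * channel + c];
    for (int c = channel; c < c4; ++c) dst[p * c4 + c] = 0.0f;
  }
}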
int ConvolutionDepthwiseSWCPUKernel::Init() {
sliding_ = new (std::nothrow) SlidingWindowParam;
if (sliding_ == nullptr) {
MS_LOG(ERROR) << "new sliding window param failed.";
return RET_ERROR;
}
auto ret = InitWeightBias();
if (ret != 0) {
MS_LOG(ERROR) << "Convolution depthwise fp32 InitWeightBias failed.";
return RET_ERROR;
}
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int ConvolutionDepthwiseSWCPUKernel::ReSize() {
FreeTmpBuffer();
ConvolutionBaseCPUKernel::Init();
InitSlidingParamConvDw(sliding_, conv_param_, C4NUM);
conv_param_->thread_num_ = MSMIN(thread_count_, conv_param_->output_h_);
auto ret = InitBuffer();
if (ret != 0) {
MS_LOG(ERROR) << "Convolution depthwise fp32 InitBuffer failed.";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionDepthwiseSWCPUKernel::Execute(int task_id) {
ConvDwC4Fp32(packed_output_, packed_input_, packed_weight_, reinterpret_cast<float *>(bias_data_), conv_param_,
sliding_, task_id);
return RET_OK;
}
int ConvDwSWRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto conv_dw = reinterpret_cast<ConvolutionDepthwiseSWCPUKernel *>(cdata);
auto ret = conv_dw->Execute(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionDepthwiseSWRun error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionDepthwiseSWCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare failed.";
return ret;
}
if (conv_param_->input_channel_ != conv_param_->output_channel_) {
MS_LOG(ERROR) << "Only support input channel equals output channel.";
return RET_ERROR;
}
auto input_tensor = in_tensors_.at(kInputIndex);
auto input_ptr = reinterpret_cast<float *>(input_tensor->Data());
if (need_align_) {
PackNHWCToNHWC4Fp32(input_ptr, packed_input_, conv_param_->input_batch_,
conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);
} else {
packed_input_ = input_ptr;
}
auto output_tensor = out_tensors_.at(kOutputIndex);
auto output_ptr = reinterpret_cast<float *>(output_tensor->Data());
if (!need_align_) {
packed_output_ = output_ptr;
}
ret = LiteBackendParallelLaunch(ConvDwSWRun, this, conv_param_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvDwSWRun error: error_code[" << ret << "]";
return RET_ERROR;
}
if (need_align_) {
PackNHWC4ToNHWCFp32(packed_output_, output_ptr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
}
return RET_OK;
}
} // namespace mindspore::kernel

@@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "nnacl/fp32/conv_depthwise.h"
namespace mindspore::kernel {
class ConvolutionDepthwiseSWCPUKernel : public ConvolutionBaseCPUKernel {
public:
ConvolutionDepthwiseSWCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~ConvolutionDepthwiseSWCPUKernel() override;
int Init() override;
int ReSize() override;
int Run() override;
int InitBuffer();
int InitWeightBias();
int Execute(int task_id);
private:
void FreeTmpBuffer();
SlidingWindowParam *sliding_ = nullptr;
float *packed_weight_ = nullptr;
float *packed_input_ = nullptr;
float *packed_output_ = nullptr;
bool need_align_ = false;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_H_