forked from mindspore-Ecosystem/mindspore
[MS][LITE] arm cpu fp32 op: conv depthwise
This commit is contained in:
parent
dde257592b
commit
b4551670a9
|
@ -0,0 +1,117 @@
|
|||
#ifdef __aarch64__
|
||||
|
||||
.text
|
||||
.align 5
|
||||
.global ConvDwFp32Row
|
||||
#ifndef __APPLE__
|
||||
.type ConvDwFp32Row, %function
|
||||
#endif
|
||||
|
||||
// void ConvDwFp32Row(float* output_ptr, const float* input_ptr,const float* filter_ptr,
|
||||
// size_t num_pixels, size_t input_channel, size_t input_step)
|
||||
// x0: output_ptr, x1: input_ptr, x2: filter_ptr, x3: num_pixels,
|
||||
// x4: input_channel, x5: input_step
|
||||
//
|
||||
ConvDwFp32Row:
|
||||
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
|
||||
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
|
||||
// x19 ~ x29 should be also preserved
|
||||
// whereas our coding style do not permit such amount of parameters
|
||||
cmp x3, #0
|
||||
beq End
|
||||
|
||||
mov x9, x0
|
||||
mov x12, #4
|
||||
mul x5, x5, x12
|
||||
|
||||
LoopOutPixel:
|
||||
mov x6, x1
|
||||
mov x7, x2
|
||||
mov x8, x4
|
||||
|
||||
LoopInputDepth16In:
|
||||
cmp x8, #16
|
||||
blt L4
|
||||
sub x8, x8, #16
|
||||
|
||||
ld1 {v0.4s, v1.4s}, [x6], #32
|
||||
ld1 {v2.4s, v3.4s}, [x7], #32
|
||||
ld1 {v16.4s, v17.4s}, [x0], #32
|
||||
|
||||
cmp x8, #16
|
||||
blt LoopInputDepth16Out
|
||||
LoopInputDepth16:
|
||||
fmla v16.4s, v0.4s, v2.4s
|
||||
fmla v17.4s, v1.4s, v3.4s
|
||||
|
||||
st1 {v16.4s, v17.4s}, [x9], #32
|
||||
|
||||
ld1 {v4.4s, v5.4s}, [x6], #32
|
||||
ld1 {v6.4s, v7.4s}, [x7], #32
|
||||
ld1 {v18.4s, v19.4s}, [x0], #32
|
||||
|
||||
fmla v18.4s, v4.4s, v6.4s
|
||||
fmla v19.4s, v5.4s, v7.4s
|
||||
|
||||
st1 {v18.4s, v19.4s}, [x9], #32
|
||||
|
||||
ld1 {v0.4s, v1.4s}, [x6], #32
|
||||
ld1 {v2.4s, v3.4s}, [x7], #32
|
||||
ld1 {v16.4s, v17.4s}, [x0], #32
|
||||
|
||||
sub x8, x8, #16
|
||||
cmp x8, #16
|
||||
bge LoopInputDepth16
|
||||
|
||||
LoopInputDepth16Out:
|
||||
fmla v16.4s, v0.4s, v2.4s
|
||||
fmla v17.4s, v1.4s, v3.4s
|
||||
st1 {v16.4s, v17.4s}, [x9], #32
|
||||
|
||||
ld1 {v4.4s, v5.4s}, [x6], #32
|
||||
ld1 {v6.4s, v7.4s}, [x7], #32
|
||||
ld1 {v18.4s, v19.4s}, [x0], #32
|
||||
|
||||
fmla v18.4s, v4.4s, v6.4s
|
||||
fmla v19.4s, v5.4s, v7.4s
|
||||
|
||||
st1 {v18.4s, v19.4s}, [x9], #32
|
||||
|
||||
L4:
|
||||
cmp x8, #4
|
||||
blt L0
|
||||
|
||||
LoopInputDepth4:
|
||||
ld1 {v0.4s}, [x6], #16
|
||||
ld1 {v2.4s}, [x7], #16
|
||||
ld1 {v16.4s}, [x0], #16
|
||||
fmla v16.4s, v0.4s, v2.4s
|
||||
st1 {v16.4s}, [x9], #16
|
||||
sub x8, x8, #4
|
||||
cmp x8, #4
|
||||
bge LoopInputDepth4
|
||||
|
||||
L0:
|
||||
cmp x8, #0
|
||||
beq Loop16LineEnd
|
||||
|
||||
LoopInputDepth0:
|
||||
ldr s0, [x6], #4
|
||||
ldr s1, [x7], #4
|
||||
ldr s2, [x0], #4
|
||||
fmul s0, s0, s1
|
||||
fadd s2, s2, s0
|
||||
str s2, [x9], #4
|
||||
subs x8, x8, #1
|
||||
bne LoopInputDepth0
|
||||
|
||||
Loop16LineEnd:
|
||||
|
||||
subs x3, x3, #1
|
||||
add x1, x1, x5
|
||||
bne LoopOutPixel
|
||||
|
||||
End:
|
||||
ret
|
||||
|
||||
#endif
|
|
@ -41,7 +41,6 @@ float ShortToFloat32(uint16_t srcValue);
|
|||
|
||||
uint16_t Float32ToShort(float srcValue);
|
||||
|
||||
|
||||
#ifdef ENABLE_ARM
|
||||
void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
|
||||
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
|
||||
|
@ -63,6 +62,9 @@ void C4BiasAddRelu6(float *dst, const float *input, const float *bias, size_t oc
|
|||
void C4Relu(float *dst, const float *input, size_t oc, size_t plane_size, size_t stride);
|
||||
void C4Relu6(float *dst, const float *input, size_t oc, size_t plane_size, size_t stride);
|
||||
|
||||
void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels,
|
||||
size_t output_channel, size_t input_step);
|
||||
|
||||
void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
|
||||
size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6);
|
||||
|
||||
|
@ -72,10 +74,10 @@ void DeconvDwFp32Border(float *dst, const float *src, const float *weight, size_
|
|||
void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod,
|
||||
size_t plane_size, size_t stride, size_t relu_type);
|
||||
|
||||
void ConvSwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height,
|
||||
size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel,
|
||||
size_t ic4, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step,
|
||||
size_t relu, size_t relu6);
|
||||
void ConvSwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
|
||||
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t ic4,
|
||||
size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu,
|
||||
size_t relu6);
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -21,6 +21,70 @@
|
|||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#ifndef ENABLE_ARM64
|
||||
void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels,
|
||||
int output_channel, int input_step) {
|
||||
for (int i = 0; i < num_pixels; i++) {
|
||||
for (int c = 0; c < output_channel; c++) {
|
||||
*output_ptr++ += weight_ptr[c] * input_ptr[c];
|
||||
}
|
||||
input_ptr += input_step;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
|
||||
const ConvParameter *conv_param, int task_id) {
|
||||
int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
|
||||
int h_start = h_step * task_id;
|
||||
int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
|
||||
for (int b = 0; b < conv_param->output_batch_; b++) {
|
||||
const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
|
||||
float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
|
||||
for (int oh = h_start; oh < h_end; oh++) {
|
||||
float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_;
|
||||
|
||||
int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_));
|
||||
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_));
|
||||
|
||||
for (int ow = 0; ow < conv_param->output_w_; ow++) {
|
||||
memcpy(dst_data + ow * conv_param->output_channel_, bias_data, conv_param->output_channel_ * sizeof(float));
|
||||
}
|
||||
for (int kh = start_kh; kh < end_kh; kh++) {
|
||||
int ih = ih_origin + conv_param->dilation_w_ * kh;
|
||||
|
||||
const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_;
|
||||
const float *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_;
|
||||
|
||||
int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_;
|
||||
for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
|
||||
int out_w_start = MSMAX(
|
||||
0, (conv_param->pad_w_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
|
||||
int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_w_ -
|
||||
conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) /
|
||||
conv_param->stride_w_);
|
||||
|
||||
float *dst_w = dst_data + out_w_start * conv_param->output_channel_;
|
||||
int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_w_ + conv_param->dilation_w_ * kw;
|
||||
|
||||
const float *src_kw = src_kh + iw_origin * conv_param->input_channel_;
|
||||
int num_pixels = out_w_end - out_w_start;
|
||||
|
||||
ConvDwFp32Row(dst_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step);
|
||||
weight_kh += conv_param->output_channel_;
|
||||
}
|
||||
}
|
||||
if (conv_param->is_relu_) {
|
||||
ReluFp32(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
|
||||
}
|
||||
if (conv_param->is_relu6_) {
|
||||
Relu6Fp32(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
|
||||
int left = 0;
|
||||
int right = conv_param->output_w_;
|
||||
|
|
|
@ -29,6 +29,9 @@ void DepthwiseCenter(float *dst, const float *src, const float *weight, const fl
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
void ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
|
||||
const ConvParameter *conv_param, int task_id);
|
||||
|
||||
void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block);
|
||||
|
||||
void InitSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block);
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/fp32/convolution_depthwise.h"
|
||||
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
|
@ -30,89 +29,43 @@ using mindspore::schema::PrimitiveType_DepthwiseConv2D;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
ConvolutionDepthwiseCPUKernel::~ConvolutionDepthwiseCPUKernel() {
|
||||
if (sliding_ != nullptr) {
|
||||
delete sliding_;
|
||||
sliding_ = nullptr;
|
||||
}
|
||||
if (packed_weight_ != nullptr) {
|
||||
delete packed_weight_;
|
||||
packed_weight_ = nullptr;
|
||||
}
|
||||
FreeTmpBuffer();
|
||||
}
|
||||
|
||||
void ConvolutionDepthwiseCPUKernel::FreeTmpBuffer() {
|
||||
if (need_align_) {
|
||||
if (packed_input_ != nullptr) {
|
||||
delete packed_input_;
|
||||
packed_input_ = nullptr;
|
||||
}
|
||||
if (packed_output_ != nullptr) {
|
||||
delete packed_output_;
|
||||
packed_output_ = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int ConvolutionDepthwiseCPUKernel::InitWeightBias() {
|
||||
// init weight: o, h, w, i; o == group, i == 1
|
||||
auto weight_tensor = in_tensors_[kWeightIndex];
|
||||
auto origin_weight = reinterpret_cast<float *>(weight_tensor->Data());
|
||||
int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM);
|
||||
int pack_weight_size = C4NUM * OC4 * weight_tensor->Height() * weight_tensor->Width();
|
||||
int channel = weight_tensor->Batch();
|
||||
int pack_weight_size = weight_tensor->Batch() * weight_tensor->Height() * weight_tensor->Width();
|
||||
|
||||
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
PackNCHWToNC4HW4Fp32(origin_weight, packed_weight_, 1, weight_tensor->Height() * weight_tensor->Width(),
|
||||
weight_tensor->Batch());
|
||||
PackNCHWToNHWCFp32(origin_weight, packed_weight_, 1, weight_tensor->Height() * weight_tensor->Width(), channel);
|
||||
|
||||
bias_data_ = reinterpret_cast<float *>(malloc(C4NUM * OC4 * sizeof(float)));
|
||||
auto bias_tensor = in_tensors_[kBiasIndex];
|
||||
bias_data_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(bias_data_, 0, C4NUM * OC4 * sizeof(float));
|
||||
|
||||
memset(bias_data_, 0, channel * sizeof(float));
|
||||
if (in_tensors_.size() == kInputSize2) {
|
||||
auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->Data());
|
||||
memcpy(bias_data_, ori_bias, in_tensors_.at(kBiasIndex)->ElementsNum() * sizeof(float));
|
||||
auto ori_bias = reinterpret_cast<float *>(bias_tensor->Data());
|
||||
memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(float));
|
||||
}
|
||||
|
||||
conv_param_->thread_num_ = MSMIN(thread_count_, OC4);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionDepthwiseCPUKernel::InitBuffer() {
|
||||
if (conv_param_->input_channel_ % C4NUM != 0) {
|
||||
need_align_ = true;
|
||||
int IC4 = UP_DIV(conv_param_->input_channel_, C4NUM);
|
||||
int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * IC4;
|
||||
packed_input_ = reinterpret_cast<float *>(malloc(pack_input_size * sizeof(float)));
|
||||
if (packed_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM);
|
||||
int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * OC4;
|
||||
packed_output_ = reinterpret_cast<float *>(malloc(pack_output_size * sizeof(float)));
|
||||
if (packed_output_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionDepthwiseCPUKernel::Init() {
|
||||
sliding_ = new (std::nothrow) SlidingWindowParam;
|
||||
if (sliding_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new sliding window param failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto ret = InitWeightBias();
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "Convolution depthwise fp32 InitWeightBias failed.";
|
||||
|
@ -125,21 +78,13 @@ int ConvolutionDepthwiseCPUKernel::Init() {
|
|||
}
|
||||
|
||||
int ConvolutionDepthwiseCPUKernel::ReSize() {
|
||||
FreeTmpBuffer();
|
||||
ConvolutionBaseCPUKernel::Init();
|
||||
InitSlidingParamConvDw(sliding_, conv_param_, C4NUM);
|
||||
|
||||
auto ret = InitBuffer();
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "Convolution depthwise fp32 InitBuffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
conv_param_->thread_num_ = MSMIN(thread_count_, conv_param_->output_h_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionDepthwiseCPUKernel::Execute(int task_id) {
|
||||
ConvDwC4Fp32(packed_output_, packed_input_, packed_weight_, reinterpret_cast<float *>(bias_data_), conv_param_,
|
||||
sliding_, task_id);
|
||||
ConvDw(output_ptr_, input_ptr_, packed_weight_, reinterpret_cast<float *>(bias_data_), conv_param_, task_id);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -164,30 +109,16 @@ int ConvolutionDepthwiseCPUKernel::Run() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
auto input_tensor = in_tensors_.at(kInputIndex);
|
||||
auto input_addr = reinterpret_cast<float *>(input_tensor->Data());
|
||||
input_ptr_ = reinterpret_cast<float *>(input_tensor->Data());
|
||||
|
||||
if (need_align_) {
|
||||
PackNHWCToNHWC4Fp32(input_addr, packed_input_, conv_param_->input_batch_,
|
||||
conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);
|
||||
} else {
|
||||
packed_input_ = input_addr;
|
||||
}
|
||||
|
||||
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
|
||||
if (!need_align_) {
|
||||
packed_output_ = output_addr;
|
||||
}
|
||||
auto output_tensor = out_tensors_.at(kOutputIndex);
|
||||
output_ptr_ = reinterpret_cast<float *>(output_tensor->Data());
|
||||
|
||||
ret = LiteBackendParallelLaunch(ConvDwRun, this, conv_param_->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvDwRun error: error_code[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (need_align_) {
|
||||
PackNHWC4ToNHWCFp32(packed_output_, output_addr, conv_param_->output_batch_,
|
||||
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -35,17 +35,13 @@ class ConvolutionDepthwiseCPUKernel : public ConvolutionBaseCPUKernel {
|
|||
int ReSize() override;
|
||||
int Run() override;
|
||||
|
||||
int InitBuffer();
|
||||
int InitWeightBias();
|
||||
int Execute(int task_id);
|
||||
|
||||
private:
|
||||
void FreeTmpBuffer();
|
||||
SlidingWindowParam *sliding_ = nullptr;
|
||||
float *packed_weight_ = nullptr;
|
||||
float *packed_input_ = nullptr;
|
||||
float *packed_output_ = nullptr;
|
||||
bool need_align_ = false;
|
||||
float *input_ptr_ = nullptr;
|
||||
float *output_ptr_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -0,0 +1,196 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_INFER_INVALID;
|
||||
using mindspore::lite::RET_OK;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
ConvolutionDepthwiseSWCPUKernel::~ConvolutionDepthwiseSWCPUKernel() {
|
||||
if (sliding_ != nullptr) {
|
||||
delete sliding_;
|
||||
sliding_ = nullptr;
|
||||
}
|
||||
if (packed_weight_ != nullptr) {
|
||||
delete packed_weight_;
|
||||
packed_weight_ = nullptr;
|
||||
}
|
||||
FreeTmpBuffer();
|
||||
}
|
||||
|
||||
void ConvolutionDepthwiseSWCPUKernel::FreeTmpBuffer() {
|
||||
if (need_align_) {
|
||||
if (packed_input_ != nullptr) {
|
||||
delete packed_input_;
|
||||
packed_input_ = nullptr;
|
||||
}
|
||||
if (packed_output_ != nullptr) {
|
||||
delete packed_output_;
|
||||
packed_output_ = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int ConvolutionDepthwiseSWCPUKernel::InitWeightBias() {
|
||||
// init weight: o, h, w, i; o == group, i == 1
|
||||
auto weight_tensor = in_tensors_[kWeightIndex];
|
||||
auto origin_weight = reinterpret_cast<float *>(weight_tensor->Data());
|
||||
int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM);
|
||||
int pack_weight_size = C4NUM * OC4 * weight_tensor->Height() * weight_tensor->Width();
|
||||
|
||||
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
PackNCHWToNC4HW4Fp32(origin_weight, packed_weight_, 1, weight_tensor->Height() * weight_tensor->Width(),
|
||||
weight_tensor->Batch());
|
||||
|
||||
auto bias_tensor = in_tensors_[kBiasIndex];
|
||||
bias_data_ = reinterpret_cast<float *>(malloc(C4NUM * OC4 * sizeof(float)));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
memset(bias_data_, 0, C4NUM * OC4 * sizeof(float));
|
||||
if (in_tensors_.size() == kInputSize2) {
|
||||
auto ori_bias = reinterpret_cast<float *>(bias_tensor->Data());
|
||||
memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(float));
|
||||
}
|
||||
|
||||
conv_param_->thread_num_ = MSMIN(thread_count_, OC4);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionDepthwiseSWCPUKernel::InitBuffer() {
|
||||
if (conv_param_->input_channel_ % C4NUM != 0) {
|
||||
need_align_ = true;
|
||||
int IC4 = UP_DIV(conv_param_->input_channel_, C4NUM);
|
||||
int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * IC4;
|
||||
packed_input_ = reinterpret_cast<float *>(malloc(pack_input_size * sizeof(float)));
|
||||
if (packed_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM);
|
||||
int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * OC4;
|
||||
packed_output_ = reinterpret_cast<float *>(malloc(pack_output_size * sizeof(float)));
|
||||
if (packed_output_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionDepthwiseSWCPUKernel::Init() {
|
||||
sliding_ = new (std::nothrow) SlidingWindowParam;
|
||||
if (sliding_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new sliding window param failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto ret = InitWeightBias();
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "Convolution depthwise fp32 InitWeightBias failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int ConvolutionDepthwiseSWCPUKernel::ReSize() {
|
||||
FreeTmpBuffer();
|
||||
ConvolutionBaseCPUKernel::Init();
|
||||
InitSlidingParamConvDw(sliding_, conv_param_, C4NUM);
|
||||
conv_param_->thread_num_ = MSMIN(thread_count_, conv_param_->output_h_);
|
||||
|
||||
auto ret = InitBuffer();
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "Convolution depthwise fp32 InitBuffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionDepthwiseSWCPUKernel::Execute(int task_id) {
|
||||
ConvDwC4Fp32(packed_output_, packed_input_, packed_weight_, reinterpret_cast<float *>(bias_data_), conv_param_,
|
||||
sliding_, task_id);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvDwSWRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
|
||||
auto conv_dw = reinterpret_cast<ConvolutionDepthwiseSWCPUKernel *>(cdata);
|
||||
auto ret = conv_dw->Execute(task_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvolutionDepthwiseSWRun error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionDepthwiseSWCPUKernel::Run() {
|
||||
auto ret = Prepare();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare failed.";
|
||||
return ret;
|
||||
}
|
||||
if (conv_param_->input_channel_ != conv_param_->output_channel_) {
|
||||
MS_LOG(ERROR) << "Only support input channel equals output channel.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto input_tensor = in_tensors_.at(kInputIndex);
|
||||
auto input_ptr = reinterpret_cast<float *>(input_tensor->Data());
|
||||
|
||||
if (need_align_) {
|
||||
PackNHWCToNHWC4Fp32(input_ptr, packed_input_, conv_param_->input_batch_,
|
||||
conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);
|
||||
} else {
|
||||
packed_input_ = input_ptr;
|
||||
}
|
||||
|
||||
auto output_tensor = out_tensors_.at(kOutputIndex);
|
||||
auto output_ptr = reinterpret_cast<float *>(output_tensor->Data());
|
||||
|
||||
if (!need_align_) {
|
||||
packed_output_ = output_ptr;
|
||||
}
|
||||
|
||||
ret = LiteBackendParallelLaunch(ConvDwSWRun, this, conv_param_->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvDwSWRun error: error_code[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (need_align_) {
|
||||
PackNHWC4ToNHWCFp32(packed_output_, output_ptr, conv_param_->output_batch_,
|
||||
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/arm/base/convolution_base.h"
|
||||
#include "nnacl/fp32/conv_depthwise.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ConvolutionDepthwiseSWCPUKernel : public ConvolutionBaseCPUKernel {
|
||||
public:
|
||||
ConvolutionDepthwiseSWCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
~ConvolutionDepthwiseSWCPUKernel() override;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
|
||||
int InitBuffer();
|
||||
int InitWeightBias();
|
||||
int Execute(int task_id);
|
||||
|
||||
private:
|
||||
void FreeTmpBuffer();
|
||||
SlidingWindowParam *sliding_ = nullptr;
|
||||
float *packed_weight_ = nullptr;
|
||||
float *packed_input_ = nullptr;
|
||||
float *packed_output_ = nullptr;
|
||||
bool need_align_ = false;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_H_
|
Loading…
Reference in New Issue