diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc new file mode 100644 index 00000000000..6732293016d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc @@ -0,0 +1,210 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp16/deconvolution_fp16.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_NULL_PTR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DeConv2D; + +namespace mindspore::kernel { +DeConvolutionFp16CPUKernel::~DeConvolutionFp16CPUKernel() { + if (matmul_param_ != nullptr) { + delete matmul_param_; + matmul_param_ = nullptr; + } + FreeParam(); + return; +} + +int DeConvolutionFp16CPUKernel::ReSize() { + FreeParam(); + InitParam(); + return RET_OK; +} + +void DeConvolutionFp16CPUKernel::FreeParam() { + if (tmp_buffer_ != nullptr) { + free(tmp_buffer_); + tmp_buffer_ = nullptr; + } + if (pack_input_ != nullptr) { + free(pack_input_); + pack_input_ = nullptr; + } + if (pack_output_ != nullptr) { + free(pack_output_); + pack_output_ = nullptr; + } + return; +} + +int DeConvolutionFp16CPUKernel::InitWeightBias() { + /* round bias up to C8NUM: PostFuncBiasReluC8Fp16 loads a full 8-lane vector per block */ + bias_data_ = malloc(UP_ROUND(conv_param_->output_channel_, C8NUM) * sizeof(float16_t)); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "deconv malloc bias_data_ error!"; + return RET_ERROR; + } + memset(bias_data_, 0, UP_ROUND(conv_param_->output_channel_, C8NUM) * sizeof(float16_t)); + if (in_tensors_.size() == 3) { + Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[2]->Data()), reinterpret_cast<float16_t *>(bias_data_), + conv_param_->output_channel_); + } + + size_t weight_pack_size = conv_param_->input_channel_ * conv_param_->kernel_w_ * conv_param_->kernel_h_ * + UP_ROUND(conv_param_->output_channel_, C8NUM) * sizeof(float16_t); + execute_weight_ = reinterpret_cast<float16_t *>(malloc(weight_pack_size)); + if (execute_weight_ == nullptr) { + MS_LOG(ERROR) << "deconv malloc execute_weight_ error!"; + return RET_ERROR; + } + memset(execute_weight_, 0, weight_pack_size); + PackNHWCFp32ToC8HWN8Fp16(reinterpret_cast<float *>(in_tensors_[1]->Data()), execute_weight_, + conv_param_->input_channel_, kernel_plane_, conv_param_->output_channel_); + return RET_OK; +}
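+ +/* Note (added for clarity; not in the original patch): execute_weight_ holds the + * deconv weight repacked to C8HWN8: output channels are grouped into blocks of + * C8NUM = 8, and one block is laid out as [kernel_plane][input_channel][8]. + * Viewing the weight tensor as NHWC = [input_channel, kh, kw, output_channel], + * PackNHWCFp32ToC8HWN8Fp16 (see pack_fp16.c below) maps element (ic, hw, oc) to + * dst[(oc / 8) * ic_num * kp * 8 + hw * ic_num * 8 + ic * 8 + oc % 8]. */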
+ +int DeConvolutionFp16CPUKernel::InitParam() { + input_plane_ = conv_param_->input_h_ * conv_param_->input_w_; + kernel_plane_ = conv_param_->kernel_w_ * conv_param_->kernel_h_; + output_plane_ = conv_param_->output_h_ * conv_param_->output_w_; + + matmul_param_->row_ = input_plane_; + matmul_param_->deep_ = conv_param_->input_channel_; + matmul_param_->col_ = conv_param_->output_channel_ * kernel_plane_; + row16_ = UP_ROUND(matmul_param_->row_, 16); + col8_ = UP_ROUND(matmul_param_->col_, 8); + + thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(conv_param_->output_channel_, C8NUM)); + thread_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C8NUM), thread_count_); + + pack_input_ = reinterpret_cast<float16_t *>(malloc(row16_ * matmul_param_->deep_ * sizeof(float16_t))); + if (pack_input_ == nullptr) { + MS_LOG(ERROR) << "deconv Malloc pack_input_ error!"; + return RET_ERROR; + } + + pack_output_ = reinterpret_cast<float16_t *>( + malloc(UP_ROUND(conv_param_->output_channel_, C8NUM) * output_plane_ * sizeof(float16_t))); + if (pack_output_ == nullptr) { + MS_LOG(ERROR) << "deconv Malloc pack_output_ error!"; + return RET_NULL_PTR; + } + + tmp_buffer_ = reinterpret_cast<float16_t *>(malloc(row16_ * col8_ * sizeof(float16_t))); + if (tmp_buffer_ == nullptr) { + MS_LOG(ERROR) << "deconv Malloc tmp_buffer_ error!"; + return RET_ERROR; + } + + return RET_OK; +} + +int DeConvFp16Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto deconv = reinterpret_cast<DeConvolutionFp16CPUKernel *>(cdata); + auto error_code = deconv->DoDeconv(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "DeConvFp16Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int DeConvolutionFp16CPUKernel::DoDeconv(int task_id) { + /* oc counts whole C8NUM-channel blocks for the matmul; oc_res counts the raw channels left to this task */ + int oc = MSMIN(thread_stride_, UP_DIV(conv_param_->output_channel_, C8NUM) - task_id * thread_stride_); + int oc_res = MSMIN(thread_stride_ * C8NUM, conv_param_->output_channel_ - task_id * thread_stride_ * C8NUM); + if (oc <= 0 || oc_res <= 0) { + return RET_OK; + } + + auto tmp_buf = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * row16_; + MatMulFp16(pack_input_, execute_weight_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, + tmp_buf, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_, oc * C8NUM * kernel_plane_, 0, + false); + DeConvPostFp16(tmp_buf, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_, + reinterpret_cast<float16_t *>(bias_data_) + task_id * thread_stride_ * C8NUM, + execute_output_ + task_id * thread_stride_ * C8NUM, oc_res, conv_param_); + return RET_OK; +}
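+ +/* Worked example (added for clarity; not in the original patch): with + * output_channel_ = 20 there are UP_DIV(20, 8) = 3 C8NUM blocks; thread_count_ = 2 + * gives thread_stride_ = UP_DIV(3, 2) = 2 blocks per task. Task 0 computes oc = 2 + * blocks (oc_res = 16 real channels) and task 1 computes oc = 1 block (oc_res = 4). + * The matmul runs over oc * C8NUM * kernel_plane_ padded columns; DeConvPostFp16 + * then writes back only the oc_res real channels. */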
error_code[" << error_code << "]"; + return RET_ERROR; + } + } + + ConvolutionBaseFP16CPUKernel::IfCastOutput(); + ConvolutionBaseFP16CPUKernel::FreeTmpBuffer(); + + return RET_OK; +} + +kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc, const lite::Primitive *primitive) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); + auto kernel = new (std::nothrow) DeConvolutionFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DeConv2D, CpuDeConvFp16KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.h new file mode 100644 index 00000000000..b2778318ad7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.h @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DECONVOLUTION_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DECONVOLUTION_H_ + +#include <arm_neon.h> +#include <vector> +#include "src/lite_kernel.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "schema/model_generated.h" +#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/deconv_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h" + +namespace mindspore::kernel { +class DeConvolutionFp16CPUKernel : public ConvolutionBaseFP16CPUKernel { + public: + DeConvolutionFp16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, + const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) { + matmul_param_ = new MatMulParameter(); + } + ~DeConvolutionFp16CPUKernel() override; + int Init() override; + int Run() override; + int ReSize() override; + + public: + int DoDeconv(int task_id); + + private: + void FreeParam(); + int InitParam(); + int InitWeightBias(); + + private: + MatMulParameter *matmul_param_ = nullptr; + int row16_; + int col8_; + int input_plane_; + int kernel_plane_; + int output_plane_; + int thread_count_; + int thread_stride_; + float16_t *pack_input_ = nullptr; + float16_t *pack_output_ = nullptr; + float16_t *tmp_buffer_ = nullptr; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DECONVOLUTION_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/MatmulFp16.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/MatmulFp16.S index b667bec931c..09f408004d9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/MatmulFp16.S +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/MatmulFp16.S @@ -24,10 +24,10 @@ MatmulFp16Neon64: st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 - mov w18, #32 // sizeof(float) * 8 + mov w18, #16 // sizeof(float16_t) * 8 mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float) * 8 * depth mov x11, x3 // bias flag - mov x18, #4 + mov x18, #2 // sizeof(float16_t) ldr x17, [sp] mul x17, x17, x18 diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/PostFuncBiasReluC8Fp16.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/PostFuncBiasReluC8Fp16.S new file mode 100644 index 00000000000..491392386d0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/PostFuncBiasReluC8Fp16.S @@ -0,0 +1,459 @@ +#ifdef __aarch64__ + + .text + .align 5 + //.p2align 5,,15 + .global PostFuncBiasReluC8Fp16 +#ifndef __APPLE__ + .type PostFuncBiasReluC8Fp16, %function +#endif + +//void PostFuncBiasReluC8Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc8div, size_t oc8mod, +// size_t plane_size, size_t stride, size_t relu_type); +// x0 dst x1 src x2 bias +// x3 oc8div x4 oc8mod x5 plane_size +// x6 stride x7 relu_type + +// v0 ~ v7 value +// v16 bias data +// x24 x25 write loop tmp buf +// v26 relu6 #6; v27 relu #0 +// w10 oc8 loop control +// w13 hw loop control + +PostFuncBiasReluC8Fp16: + movi v26.8h, #6 + scvtf v26.8h, v26.8h + dup v27.8h, wzr + mov w10, #0 + +Loop_C8: + cmp w10, w3 + beq Loop_C1 + mov x25, #2 // sizeof(float16_t) + mul x24, x10, x25 + add x25, x0, x24 + add w10, w10, #8 + mov w13, w5 + ld1 {v16.8h}, [x2], #16 + +Loop8x8: + cmp w13, #8 + blt Loop_4x8 + sub w13, w13, #8 + ld1 
{v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64 + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x1], #64 + + fadd v0.8h, v0.8h, v16.8h + fadd v1.8h, v1.8h, v16.8h + fadd v2.8h, v2.8h, v16.8h + fadd v3.8h, v3.8h, v16.8h + fadd v4.8h, v4.8h, v16.8h + fadd v5.8h, v5.8h, v16.8h + fadd v6.8h, v6.8h, v16.8h + fadd v7.8h, v7.8h, v16.8h + + cmp w7, #2 + beq Relu6_8x8 + cmp w7, #1 + beq Relu_8x8 + b Write_8x8 +Relu6_8x8: + fmin v0.8h, v0.8h, v26.8h + fmin v1.8h, v1.8h, v26.8h + fmin v2.8h, v2.8h, v26.8h + fmin v3.8h, v3.8h, v26.8h + fmin v4.8h, v4.8h, v26.8h + fmin v5.8h, v5.8h, v26.8h + fmin v6.8h, v6.8h, v26.8h + fmin v7.8h, v7.8h, v26.8h +Relu_8x8: + fmax v0.8h, v0.8h, v27.8h + fmax v1.8h, v1.8h, v27.8h + fmax v2.8h, v2.8h, v27.8h + fmax v3.8h, v3.8h, v27.8h + fmax v4.8h, v4.8h, v27.8h + fmax v5.8h, v5.8h, v27.8h + fmax v6.8h, v6.8h, v27.8h + fmax v7.8h, v7.8h, v27.8h +Write_8x8: + st1 {v0.8h}, [x25], x6 + st1 {v1.8h}, [x25], x6 + st1 {v2.8h}, [x25], x6 + st1 {v3.8h}, [x25], x6 + st1 {v4.8h}, [x25], x6 + st1 {v5.8h}, [x25], x6 + st1 {v6.8h}, [x25], x6 + st1 {v7.8h}, [x25], x6 + b Loop8x8 + +Loop_4x8: + cmp w13, #4 + blt Loop_1x8 + sub w13, w13, #4 + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64 + + fadd v0.8h, v0.8h, v16.8h + fadd v1.8h, v1.8h, v16.8h + fadd v2.8h, v2.8h, v16.8h + fadd v3.8h, v3.8h, v16.8h + + cmp w7, #2 + beq Relu6_4x8 + cmp w7, #1 + beq Relu_4x8 + b Write_4x8 +Relu6_4x8: + fmin v0.8h, v0.8h, v26.8h + fmin v1.8h, v1.8h, v26.8h + fmin v2.8h, v2.8h, v26.8h + fmin v3.8h, v3.8h, v26.8h +Relu_4x8: + fmax v0.8h, v0.8h, v27.8h + fmax v1.8h, v1.8h, v27.8h + fmax v2.8h, v2.8h, v27.8h + fmax v3.8h, v3.8h, v27.8h +Write_4x8: + st1 {v0.8h}, [x25], x6 + st1 {v1.8h}, [x25], x6 + st1 {v2.8h}, [x25], x6 + st1 {v3.8h}, [x25], x6 + +Loop_1x8: + cmp w7, #2 + beq Relu6_1x8 + cmp w7, #1 + beq Relu_1x8 + b Write_1x8 +Relu6_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.8h}, [x25], x6 + b Relu6_1x8 +Relu_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.8h}, [x25], x6 + b Relu_1x8 +Write_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.8h}, [x25], x6 + b Write_1x8 + + +Loop_C1: + cmp x4, #0 + beq End + mov w13, w5 + ld1 {v16.8h}, [x2], #16 + + cmp x4, #1 + beq Loop_C1_1 + cmp x4, #2 + beq Loop_C1_2 + cmp x4, #3 + beq Loop_C1_3 + cmp x4, #4 + beq Loop_C1_4 + cmp x4, #5 + beq Loop_C1_5 + cmp x4, #6 + beq Loop_C1_6 + cmp x4, #7 + beq Loop_C1_7 + +Loop_C1_1: + cmp w7, #2 + beq Loop_C1_1_Relu6 + cmp w7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x0], x6 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x0], x6 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.h}[0], [x0], x6 + b Loop_C1_1_Write + +Loop_C1_2: + add x24, x0, #2 + cmp w7, #2 + beq Loop_C1_2_Relu6 + cmp w7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, 
v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x0], x6 + st1 {v0.h}[1], [x24], x6 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x0], x6 + st1 {v0.h}[1], [x24], x6 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.h}[0], [x0], x6 + st1 {v0.h}[1], [x24], x6 + b Loop_C1_2_Write + + +Loop_C1_3: + add x24, x0, #2 + add x25, x0, #4 + cmp w7, #2 + beq Loop_C1_3_Relu6 + cmp w7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x0], x6 + st1 {v0.h}[1], [x24], x6 + st1 {v0.h}[2], [x25], x6 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x0], x6 + st1 {v0.h}[1], [x24], x6 + st1 {v0.h}[2], [x25], x6 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.h}[0], [x0], x6 + st1 {v0.h}[1], [x24], x6 + st1 {v0.h}[2], [x25], x6 + b Loop_C1_3_Write + +Loop_C1_4: + cmp w7, #2 + beq Loop_C1_4_Relu6 + cmp w7, #1 + beq Loop_C1_4_Relu + b Loop_C1_4_Write +Loop_C1_4_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x0], x6 + b Loop_C1_4_Relu6 +Loop_C1_4_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x0], x6 + b Loop_C1_4_Relu +Loop_C1_4_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.4h}, [x0], x6 + b Loop_C1_4_Write + +Loop_C1_5: + add x25, x0, #8 + cmp w7, #2 + beq Loop_C1_5_Relu6 + cmp w7, #1 + beq Loop_C1_5_Relu + b Loop_C1_5_Write +Loop_C1_5_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x0], x6 + st1 {v0.h}[4], [x25], x6 + b Loop_C1_5_Relu6 +Loop_C1_5_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x0], x6 + st1 {v0.h}[4], [x25], x6 + b Loop_C1_5_Relu +Loop_C1_5_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.4h}, [x0], x6 + st1 {v0.h}[4], [x25], x6 + b Loop_C1_5_Write + +Loop_C1_6: + add x23, x0, #8 + add x24, x0, #10 + cmp w7, #2 + beq Loop_C1_6_Relu6 + cmp w7, #1 + beq Loop_C1_6_Relu + b Loop_C1_6_Write +Loop_C1_6_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x0], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + b Loop_C1_6_Relu6 +Loop_C1_6_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x0], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + b Loop_C1_6_Relu +Loop_C1_6_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd 
v0.8h, v0.8h, v16.8h + st1 {v0.4h}, [x0], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + b Loop_C1_6_Write + +Loop_C1_7: + add x23, x0, #8 + add x24, x0, #10 + add x25, x0, #12 + cmp w7, #2 + beq Loop_C1_7_Relu6 + cmp w7, #1 + beq Loop_C1_7_Relu + b Loop_C1_7_Write +Loop_C1_7_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x0], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + st1 {v0.h}[6], [x25], x6 + b Loop_C1_7_Relu6 +Loop_C1_7_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x0], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + st1 {v0.h}[6], [x25], x6 + b Loop_C1_7_Relu +Loop_C1_7_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.4h}, [x0], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + st1 {v0.h}[6], [x25], x6 + b Loop_C1_7_Write + +End: + ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/deconv_fp16.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/deconv_fp16.c new file mode 100644 index 00000000000..050304a0ef7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/deconv_fp16.c @@ -0,0 +1,119 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16/deconv_fp16.h" + +void PostConvFuncCommFp16(float16_t *out_ptr, const float16_t *src_ptr_, const float16_t *bias_ptr, + size_t output_channel, size_t plane_size, size_t stride, bool is_relu, bool is_relu6, + int size) { + for (int oc = 0; oc < output_channel; oc++) { + int oc_div = oc / size, oc_mod = oc % size; + for (int hw = 0; hw < plane_size; hw++) { + int src_index = oc_div * size * plane_size + hw * size + oc_mod; + int dst_index = hw * stride + oc; + float16_t value = src_ptr_[src_index]; + if (bias_ptr != NULL) { + value = value + bias_ptr[oc]; + } + value = (is_relu || is_relu6) ? (MSMAX(0.f, value)) : (value); + value = (is_relu6) ? (MSMIN(6.f, value)) : (value); + out_ptr[dst_index] = value; + } + } + return; +} + +void PostConvFuncFp16C8(const float16_t *c8_out_ptr, float16_t *out_ptr, const float16_t *bias_ptr, + size_t output_channel, size_t plane_size, size_t stride, bool is_relu, bool is_relu6) { +#ifdef DEBUG_CODE + PostConvFuncCommFp16(out_ptr, c8_out_ptr, bias_ptr, output_channel, plane_size, stride, is_relu, is_relu6, C8NUM); +#else + size_t oc8mod = output_channel % C8NUM; + size_t oc8div = output_channel - oc8mod; + size_t stride_size = stride * sizeof(float16_t); + size_t relu_type = is_relu ? 1 : 0; + relu_type = is_relu6 ? 
2 : relu_type; + PostFuncBiasReluC8Fp16(out_ptr, c8_out_ptr, bias_ptr, oc8div, oc8mod, plane_size, stride_size, relu_type); +#endif + return; +} + +int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, float16_t *dst, int output_channel, + ConvParameter *conv_param) { + /* row8x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */ + size_t input_plane = conv_param->input_w_ * conv_param->input_h_; + size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + size_t output_plane = conv_param->output_w_ * conv_param->output_h_; + int oc8 = UP_ROUND(output_channel, C8NUM); + int in_plane16 = UP_ROUND(input_plane, 16); + int src_iw_stride = C8NUM; + int src_ih_stride = conv_param->input_w_ * C8NUM; + int src_kw_stride = in_plane16 * C8NUM; + int src_kh_stride = in_plane16 * conv_param->kernel_w_ * C8NUM; + int dst_oh_stride = conv_param->output_w_ * C8NUM; + int dst_ow_stride = C8NUM; + int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C8NUM; + int dst_kw_stride = conv_param->dilation_w_ * C8NUM; + + for (int c = 0; c < oc8; c += 8) { + float16_t *dst_ptr = tmp + c * output_plane; + const float16_t *src_ptr = src + c * in_plane16 * kernel_plane; + memset(dst_ptr, 0, output_plane * C8NUM * sizeof(float16_t)); + + for (int ih = 0; ih < conv_param->input_h_; ih++) { + for (int iw = 0; iw < conv_param->input_w_; iw++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_h_; + int ow = iw * conv_param->stride_w_ - conv_param->pad_w_; + + int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + for (int kh = kh_start; kh < kh_end; kh++) { + for (int kw = kw_start; kw < kw_end; kw++) { + int src_index = ih * src_ih_stride + iw * src_iw_stride + kh * src_kh_stride + kw * src_kw_stride; + int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride; + float16_t *tmp_dst = dst_ptr + dst_index; + const float16_t *tmp_src = src_ptr + src_index; +#ifdef DEBUG_CODE + for (int i = 0; i < C8NUM; i++) { + tmp_dst[i] += tmp_src[i]; + } +#else + asm volatile( + "mov x0, %[tmp_src] \n" + "mov x1, %[tmp_dst] \n" + + "ld1 {v0.8h}, [x0] \n" + "ld1 {v1.8h}, [x1] \n" + + "fadd v0.8h, v0.8h, v1.8h \n" + + "st1 {v0.8h}, [x1] \n" + + : + : [ tmp_src ] "r"(tmp_src), [ tmp_dst ] "r"(tmp_dst) + : "x0", "x1", "v0", "v1"); +#endif + } /*kw*/ + } /*kh*/ + } /*iw*/ + } /*ih*/ + } /*oc8*/ + + PostConvFuncFp16C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_, + conv_param->is_relu6_); + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/deconv_fp16.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/deconv_fp16.h new file mode 100644 index 00000000000..04909401811 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/deconv_fp16.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_DECONV_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_DECONV_FP16_H_ + +#include <arm_neon.h> +#include <string.h> +#include "nnacl/conv_parameter.h" +#include "nnacl/matmul_parameter.h" +#include "nnacl/fp16/matmul_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, float16_t *dst, int output_channel, + ConvParameter *conv_param); + +void PostConvFuncFp16C8(const float16_t *c8_out_ptr, float16_t *out_ptr, const float16_t *bias_ptr, + size_t output_channel, size_t plane_size, size_t stride, bool is_relu, bool is_relu6); + +void PostFuncBiasReluC8Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t stride, size_t relu_type); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_DECONV_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c index 460bdaf7b1d..bbb4360beb8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c @@ -15,10 +15,37 @@ */ #include "nnacl/fp16/matmul_fp16.h" +void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, bool write_nhwc) { + int row_16 = UP_ROUND(row, C16NUM); + int col_8 = UP_ROUND(col, C8NUM); + /* col16-major * row8-major => row16x8-major */ + if (write_nhwc) return; + for (int r = 0; r < row_16; r++) { + for (int c = 0; c < col_8; c++) { + int r16div = r / C16NUM, r16mod = r % C16NUM; + int c8div = c / C8NUM, c8mod = c % C8NUM; + size_t ci = c8div * row_16 * C8NUM + r * C8NUM + c8mod; + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod; + size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; + value = value + a[ai] * b[bi]; + } + if (bias != NULL) value += bias[c]; + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); + if (act_type != ActType_No) value = MSMAX(0.0f, value); + dst[ci] = value; + } + } + return; +} void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, int depth, int row, int col, int stride, bool write_nhwc) { - MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc); + // MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc); + MatMul16x8(a, b, c, bias, act_type, depth, row, col, stride, write_nhwc); + return; } void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h index 8156a11a1a4..237d6cb1416 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h +++ 
b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h @@ -33,10 +33,10 @@ void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, int depth, int row, int col, int stride, bool write_nhwc); void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col); -#ifdef __aarch64__ -void MatmulFp16Neon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, - int col, int stride, bool write_nhwc); -#endif + +void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc); + #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.c index b3872f6cc01..af5f9b92d76 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.c @@ -369,6 +369,21 @@ void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel) { } } +void PackNHWCFp32ToC8HWN8Fp16(float *src, float16_t *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int hw = 0; hw < plane; hw++) { + for (int c = 0; c < channel; c++) { + int c8div = c / C8NUM; + int c8mod = c % C8NUM; + int src_index = n * plane * channel + hw * channel + c; + int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; + dst[dst_index] = (float16_t)(src[src_index]); + } + } + } + return; +} + void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel) { int c8_channel = UP_DIV(channel, C8NUM) * C8NUM; + for (int b = 0; b < batch; b++) { diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h index 40d95be422c..f84cacafb00 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h @@ -59,6 +59,8 @@ void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane, void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWCFp32ToC8HWN8Fp16(float *src, float16_t *dst, int batch, int plane, int channel); + void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel); void PackNHWCToNHWC8Fp16(float16_t *src, float16_t *dst, int batch, int plane, int channel); diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.c index b48c18a6de1..e5f879e0c3f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.c @@ -120,7 +120,7 @@ void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) } void MatrixUnPackUnit(const void *src, void *dst, size_t row, size_t col, size_t src_stride, size_t dst_stride, - size_t data_lenth) { + size_t data_lenth) { size_t copy_size = col * data_lenth; size_t src_size = src_stride * data_lenth; size_t dst_size = dst_stride * data_lenth;
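The C8 block indexing that PostConvFuncCommFp16 relies on (src_index = oc_div * size * plane_size + hw * size + oc_mod, dst_index = hw * stride + oc) can be sanity-checked on the host with the minimal, self-contained C sketch below (plain float instead of float16_t and made-up sizes; an illustration only, not part of the patch):

#include <stdio.h>

#define C8NUM 8
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

int main(void) {
  enum { oc_num = 10, plane = 3, stride = oc_num };
  float src[UP_DIV(oc_num, C8NUM) * plane * C8NUM]; /* C8-blocked: [oc8][hw][8] */
  float dst[plane * oc_num];                        /* plane-major output slice */
  /* pack: each value encodes (oc, hw) so any misplacement is visible */
  for (int oc = 0; oc < oc_num; ++oc) {
    for (int hw = 0; hw < plane; ++hw) {
      src[(oc / C8NUM) * C8NUM * plane + hw * C8NUM + oc % C8NUM] = (float)(oc * 100 + hw);
    }
  }
  /* unpack with the same index arithmetic as PostConvFuncCommFp16 (no bias/relu) */
  for (int oc = 0; oc < oc_num; ++oc) {
    for (int hw = 0; hw < plane; ++hw) {
      dst[hw * stride + oc] = src[(oc / C8NUM) * C8NUM * plane + hw * C8NUM + oc % C8NUM];
    }
  }
  printf("%.0f %.0f\n", dst[0 * stride + 9], dst[2 * stride + 1]); /* expected: 900 102 */
  return 0;
}

Compiled with a plain C compiler it prints 900 and 102: channel 9 of pixel 0 and channel 1 of pixel 2 land exactly where dst_index says they should.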