Merge pull request !4515 from fuzhiye/tmp
This commit is contained in:
mindspore-ci-bot 2020-08-15 19:13:03 +08:00 committed by Gitee
commit 9efd84fbb2
29 changed files with 5925 additions and 146 deletions

View File

@ -0,0 +1,97 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;
namespace mindspore::kernel {
int Convolution1x1FP16CPUKernel::Init() {
auto ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init failed.";
return ret;
}
return RET_OK;
}
int Convolution1x1FP16CPUKernel::ReSize() {
if (fp16_out_ != nullptr) {
free(fp16_out_);
}
if (fp16_input_ != nullptr) {
free(fp16_input_);
}
if (nhwc4_input_ != nullptr) {
free(nhwc4_input_);
}
auto ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init failed.";
return ret;
}
return RET_OK;
}
int Convolution1x1FP16CPUKernel::RunImpl(int task_id) {
// Conv1x1Fp16(reinterpret_cast<float16_t *>(nhwc4_input_), transformed_filter_addr_,
// reinterpret_cast<float16_t *>(bias_data_), fp16_out_, tile_buffer_, block_unit_buffer_,
// tmp_dst_buffer_, tmp_out_, task_id, conv_param_);
return RET_OK;
}
int Convolution1x1Fp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto conv = reinterpret_cast<Convolution1x1FP16CPUKernel *>(cdata);
auto error_code = conv->RunImpl(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Convolution1x1 Fp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int Convolution1x1FP16CPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
int error_code = LiteBackendParallelLaunch(Convolution1x1Fp16Impl, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv1x1 fp16 error error_code[" << error_code << "]";
return RET_ERROR;
}
ConvolutionBaseFP16CPUKernel::IfCastOutput();
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_1x1_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_1x1_FP16_H_
#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"
namespace mindspore::kernel {
class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
public:
Convolution1x1FP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
const lite::Primitive *primitive)
: ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~Convolution1x1FP16CPUKernel() override {
if (fp16_input_ != nullptr) {
free(fp16_input_);
}
if (fp16_weight_ != nullptr) {
free(fp16_weight_);
}
if (fp16_out_ != nullptr) {
free(fp16_out_);
}
}
int Init() override;
int ReSize() override;
int Run() override;
int RunImpl(int task_id);
private:
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_1x1_FP16_H_

View File

@ -52,8 +52,6 @@ void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvPara
int Convolution3x3FP16CPUKernel::InitWeightBias() {
auto input_channel = conv_param_->input_channel_;
int output_channel = conv_param_->output_channel_;
int kernel_h = conv_param_->kernel_h_;
int kernel_w = conv_param_->kernel_w_;
int iC4 = UP_DIV(input_channel, C4NUM);
int oC8 = UP_DIV(output_channel, C8NUM);
// init weight
@ -64,18 +62,8 @@ int Convolution3x3FP16CPUKernel::InitWeightBias() {
return RET_ERROR;
}
memset(transformed_filter_addr_, 0, transformed_size);
float *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data());
size_t fp16_weight_size = input_channel * output_channel * kernel_h * kernel_w * sizeof(float16_t);
fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
if (fp16_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
return RET_ERROR;
}
memset(fp16_weight_, 0, fp16_weight_size);
for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
fp16_weight_[i] = (float16_t)origin_weight[i];
}
ProcessFilterFp16(fp16_weight_, transformed_filter_addr_, conv_param_);
ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
ProcessFilterFp16(execute_weight_, transformed_filter_addr_, conv_param_);
// init bias
size_t new_bias_size = oC8 * C8NUM * sizeof(float16_t);
@ -183,10 +171,6 @@ void Convolution3x3FP16CPUKernel::ConfigInputOutput() {
}
int Convolution3x3FP16CPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
set_need_reinit();
return RET_OK;
}
auto ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init failed.";
@ -244,8 +228,8 @@ int Convolution3x3FP16CPUKernel::ReSize() {
int Convolution3x3FP16CPUKernel::RunImpl(int task_id) {
Conv3x3Fp16(reinterpret_cast<float16_t *>(nhwc4_input_), transformed_filter_addr_,
reinterpret_cast<float16_t *>(bias_data_), fp16_out_, tile_buffer_, block_unit_buffer_, tmp_dst_buffer_,
tmp_out_, task_id, conv_param_);
reinterpret_cast<float16_t *>(bias_data_), execute_output_, tile_buffer_, block_unit_buffer_,
tmp_dst_buffer_, tmp_out_, task_id, conv_param_);
return RET_OK;
}
@ -265,16 +249,13 @@ int Convolution3x3FP16CPUKernel::Run() {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
auto input_tensor = in_tensors_.at(kInputIndex);
auto input_ele_num = input_tensor->ElementsNum();
auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
int in_batch = conv_param_->input_batch_;
int in_h = conv_param_->input_h_;
int in_w = conv_param_->input_w_;
int in_channel = conv_param_->input_channel_;
convert_func_(reinterpret_cast<void *>(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
convert_func_(reinterpret_cast<void *>(execute_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
int error_code = LiteBackendParallelLaunch(Convolution3x3Fp16Impl, this, thread_count_);
if (error_code != RET_OK) {
@ -294,7 +275,7 @@ int Convolution3x3FP16CPUKernel::Run() {
batch * oc8 * C8NUM * out_w_block * out_h_block * conv_param_->output_unit_ * conv_param_->output_unit_;
int ro_batch_size = batch * conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_;
const float16_t *batch_tmp_out = tmp_out_ + tmp_out_batch_offset;
float16_t *batch_out = fp16_out_ + ro_batch_size;
float16_t *batch_out = execute_output_ + ro_batch_size;
for (int h = 0; h < conv_param_->output_h_; h++) {
for (int w = 0; w < conv_param_->output_w_; w++) {
for (int c = 0; c < conv_param_->output_channel_; c++) {
@ -315,11 +296,7 @@ int Convolution3x3FP16CPUKernel::Run() {
}
}
// cast fp16 out to fp32 data
auto out_tensor = out_tensors_.at(kOutputIndex);
auto out_ele_num = out_tensor->ElementsNum();
auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
ConvolutionBaseFP16CPUKernel::IfCastOutput();
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -20,16 +20,16 @@
#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"
namespace mindspore::kernel {
class Convolution3x3FP16CPUKernel : public ConvolutionBaseCPUKernel {
class Convolution3x3FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
public:
Convolution3x3FP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
const lite::Primitive *primitive)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
: ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~Convolution3x3FP16CPUKernel() override {
if (fp16_input_ != nullptr) {
free(fp16_input_);
@ -66,9 +66,6 @@ class Convolution3x3FP16CPUKernel : public ConvolutionBaseCPUKernel {
void ConfigInputOutput();
private:
float16_t *fp16_input_;
float16_t *fp16_weight_;
float16_t *fp16_out_;
float16_t *transformed_filter_addr_;
float16_t *tile_buffer_;
float16_t *block_unit_buffer_;

View File

@ -0,0 +1,86 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
namespace mindspore::kernel {
int ConvolutionBaseFP16CPUKernel::GetExecuteTensor() {
// ===================input====================//
auto input_tensor = in_tensors_.at(kInputIndex);
auto input_data_type = input_tensor->data_type();
MS_ASSERT(input_data_type == kNumberTypeFloat32 || input_data_type == kNumberTypeFloat16);
if (input_data_type == kNumberTypeFloat32) {
auto input_ele_num = input_tensor->ElementsNum();
auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
execute_input_ = fp16_input_;
} else {
auto ori_input_data = reinterpret_cast<float16_t *>(input_tensor->Data());
execute_input_ = ori_input_data;
}
// ==================output====================//
auto out_tensor = out_tensors_.at(kOutputIndex);
auto out_data_type = out_tensor->data_type();
MS_ASSERT(out_data_type == kNumberTypeFloat32 || out_data_type == kNumberTypeFloat16);
out_data_type_ = out_data_type;
if (out_data_type == kNumberTypeFloat32) {
execute_output_ = fp16_out_;
} else {
auto out_ptr = reinterpret_cast<float16_t *>(out_tensor->Data());
execute_output_ = out_ptr;
}
return RET_OK;
}
int ConvolutionBaseFP16CPUKernel::GetExecuteFilter() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto weight_data_type = weight_tensor->data_type();
MS_ASSERT(weight_data_type == kNumberTypeFloat32 || weight_data_type == kNumberTypeFloat16);
if (weight_data_type == kNumberTypeFloat32) {
float *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data());
size_t fp16_weight_size = conv_param_->input_channel_ * conv_param_->output_channel_ * conv_param_->kernel_h_ *
conv_param_->input_w_ * sizeof(float16_t);
fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
if (fp16_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
return RET_ERROR;
}
for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
fp16_weight_[i] = (float16_t)origin_weight[i];
}
execute_weight_ = fp16_weight_;
} else {
auto *origin_weight = reinterpret_cast<float16_t *>(in_tensors_.at(kWeightIndex)->Data());
execute_weight_ = origin_weight;
}
return RET_OK;
}
void ConvolutionBaseFP16CPUKernel::IfCastOutput() {
if (out_data_type_ == kNumberTypeFloat32) {
auto out_tensor = out_tensors_.at(kOutputIndex);
auto out_ele_num = out_tensor->ElementsNum();
auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
}
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_BASE_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_BASE_FP16_H_
#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"
namespace mindspore::kernel {
class ConvolutionBaseFP16CPUKernel : public ConvolutionBaseCPUKernel {
public:
ConvolutionBaseFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
const lite::Primitive *primitive)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~ConvolutionBaseFP16CPUKernel() override = default;
int Init() override { return RET_OK; }
int ReSize() override { return RET_OK; }
int Run() override { return RET_OK; }
int RunImpl(int task_id) { return RET_OK; }
virtual int GetExecuteTensor();
virtual int GetExecuteFilter();
virtual void IfCastOutput();
protected:
float16_t *fp16_input_ = nullptr;
float16_t *fp16_weight_ = nullptr;
float16_t *fp16_out_ = nullptr;
float16_t *execute_input_;
float16_t *execute_weight_;
float16_t *execute_output_;
TypeId out_data_type_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_BASE_FP16_H_

View File

@ -102,10 +102,6 @@ int ConvolutionDepthwiseFp16CPUKernel::InitWeightBias() {
}
int ConvolutionDepthwiseFp16CPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
set_need_reinit();
return RET_OK;
}
// conv base init
auto ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {

View File

@ -46,24 +46,14 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane;
// init weight
float *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data());
size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t);
fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
if (fp16_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
return RET_ERROR;
}
for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
fp16_weight_[i] = (float16_t)origin_weight[i];
}
ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t)));
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc packed_weight_ failed.";
return RET_ERROR;
}
memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t));
PackWeightFp16(fp16_weight_, conv_param_, packed_weight_);
PackWeightFp16(execute_weight_, conv_param_, packed_weight_);
// init bias
bias_data_ = malloc(oc8 * C8NUM * sizeof(float16_t));
@ -157,10 +147,6 @@ void ConvolutionFP16CPUKernel::ConfigInputOutput() {
}
int ConvolutionFP16CPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
set_need_reinit();
return RET_OK;
}
auto ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init fail!ret: " << ret;
@ -212,7 +198,7 @@ int ConvolutionFP16CPUKernel::ReSize() {
int ConvolutionFP16CPUKernel::RunImpl(int task_id) {
ConvFp16(reinterpret_cast<float16_t *>(nhwc4_input_), packed_input_, packed_weight_,
reinterpret_cast<float16_t *>(bias_data_), tmp_output_block_, fp16_out_, task_id, conv_param_);
reinterpret_cast<float16_t *>(bias_data_), tmp_output_block_, execute_output_, task_id, conv_param_);
return RET_OK;
}
@ -232,16 +218,13 @@ int ConvolutionFP16CPUKernel::Run() {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
auto input_tensor = in_tensors_.at(kInputIndex);
auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
auto input_ele_num = input_tensor->ElementsNum();
Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
int in_batch = conv_param_->input_batch_;
int in_h = conv_param_->input_h_;
int in_w = conv_param_->input_w_;
int in_channel = conv_param_->input_channel_;
convert_func_(reinterpret_cast<void *>(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
convert_func_(reinterpret_cast<void *>(execute_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
int error_code = LiteBackendParallelLaunch(ConvolutionFp16Impl, this, thread_count_);
if (error_code != RET_OK) {
@ -249,11 +232,7 @@ int ConvolutionFP16CPUKernel::Run() {
return RET_ERROR;
}
// cast fp16 out to fp32 data
auto out_tensor = out_tensors_.at(kOutputIndex);
auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
auto out_ele_num = out_tensor->ElementsNum();
Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
ConvolutionBaseFP16CPUKernel::IfCastOutput();
return RET_OK;
}

View File

@ -20,15 +20,15 @@
#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
namespace mindspore::kernel {
class ConvolutionFP16CPUKernel : public ConvolutionBaseCPUKernel {
class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
public:
ConvolutionFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
const lite::Primitive *primitive)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
: ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~ConvolutionFP16CPUKernel() override {
if (fp16_input_ != nullptr) {
free(fp16_input_);
@ -59,9 +59,6 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseCPUKernel {
void ConfigInputOutput();
private:
float16_t *fp16_input_;
float16_t *fp16_weight_;
float16_t *fp16_out_;
float16_t *packed_input_;
float16_t *packed_weight_;
float16_t *tmp_output_block_;

View File

@ -39,23 +39,13 @@ int ConvolutionSWFP16CPUKernel::ProcessFilter() {
int out_channel = conv_param_->output_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
auto *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data());
size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t);
fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
if (fp16_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
return RET_ERROR;
}
// cast origin fp32 weight data to fp16 data
for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
fp16_weight_[i] = (float16_t)origin_weight[i];
}
ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
for (int oc = 0; oc < out_channel; ++oc) {
int src_oc_offset = oc * kernel_h * kernel_w * in_channel;
int dst_oc_offset = oc * kernel_h * kernel_w * ic4 * C4NUM;
for (int i = 0; i < kernel_h * kernel_w; ++i) {
const float16_t *src = fp16_weight_ + src_oc_offset + i * in_channel;
const float16_t *src = execute_weight_ + src_oc_offset + i * in_channel;
float16_t *dst = packed_weight_ + dst_oc_offset + i * ic4 * C4NUM;
memcpy(dst, src, in_channel * sizeof(float16_t));
}
@ -162,10 +152,6 @@ void ConvolutionSWFP16CPUKernel::ConfigInputOutput() {
}
int ConvolutionSWFP16CPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
set_need_reinit();
return RET_OK;
}
auto ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init fail!ret: " << ret;
@ -222,7 +208,7 @@ int ConvolutionSWFP16CPUKernel::ReSize() {
int ConvolutionSWFP16CPUKernel::RunImpl(int task_id) {
ConvSWFp16(reinterpret_cast<float16_t *>(nhwc4_input_), packed_weight_, reinterpret_cast<float16_t *>(bias_data_),
tmp_output_block_, fp16_out_, task_id, conv_param_, slidingWindow_param_);
tmp_output_block_, execute_output_, task_id, conv_param_, slidingWindow_param_);
return RET_OK;
}
@ -242,16 +228,13 @@ int ConvolutionSWFP16CPUKernel::Run() {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
auto input_tensor = in_tensors_.at(kInputIndex);
auto input_ele_num = input_tensor->ElementsNum();
auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
int in_batch = conv_param_->input_batch_;
int in_h = conv_param_->input_h_;
int in_w = conv_param_->input_w_;
int in_channel = conv_param_->input_channel_;
convert_func_(reinterpret_cast<void *>(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
convert_func_(reinterpret_cast<void *>(execute_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
int error_code = LiteBackendParallelLaunch(ConvolutionSWFp16Impl, this, thread_count_);
if (error_code != RET_OK) {
@ -259,18 +242,14 @@ int ConvolutionSWFP16CPUKernel::Run() {
return RET_ERROR;
}
// cast fp16 out to fp32 data
auto out_tensor = out_tensors_.at(kOutputIndex);
auto out_ele_num = out_tensor->ElementsNum();
auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
// output nhwc4
int oc4_res = conv_param_->output_channel_ % C4NUM;
if (oc4_res != 0) {
PackNHWC4ToNHWCFp16(reinterpret_cast<const void *>(tmp_output_block_), reinterpret_cast<void *>(fp16_out_),
PackNHWC4ToNHWCFp16(reinterpret_cast<const void *>(tmp_output_block_), reinterpret_cast<void *>(execute_output_),
conv_param_->output_batch_, conv_param_->output_h_ * conv_param_->output_w_,
conv_param_->output_channel_);
}
Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
ConvolutionBaseFP16CPUKernel::IfCastOutput();
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -19,15 +19,15 @@
#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
namespace mindspore::kernel {
class ConvolutionSWFP16CPUKernel : public ConvolutionBaseCPUKernel {
class ConvolutionSWFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
public:
ConvolutionSWFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
const lite::Primitive *primitive)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
: ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~ConvolutionSWFP16CPUKernel() override {
if (fp16_input_ != nullptr) {
free(fp16_input_);
@ -57,9 +57,6 @@ class ConvolutionSWFP16CPUKernel : public ConvolutionBaseCPUKernel {
int ProcessFilter();
private:
float16_t *fp16_input_;
float16_t *fp16_weight_;
float16_t *fp16_out_;
float16_t *packed_weight_;
float16_t *tmp_output_block_;
SlidingWindowParam *slidingWindow_param_;

View File

@ -0,0 +1,409 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h"
#include "src/runtime/kernel/arm/fp16/matrix_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/common_func.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/winograd_utils_fp16.h"
#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;
namespace mindspore::kernel {
void WinogradFilterTransformFp16(const float16_t *weight_data, Matrix *trans_weight, int kernel_unit, int input_unit,
ConvParameter *conv_param, int oc_block) {
// original weight format : ohwi
auto channel_in = conv_param->input_channel_;
auto channel_out = conv_param->output_channel_;
int input_unit_square = input_unit * input_unit;
// generate matrix_G && matrix_GT
auto matrix_g = TransformMatrixGenerator(input_unit, kernel_unit);
auto matrix_gt = TransformMatrixGenerator(kernel_unit, input_unit);
ChooseMatrixG(matrix_g, matrix_gt);
auto matrix_g_data = reinterpret_cast<float *>(matrix_g->GetData());
auto matrix_gt_data = reinterpret_cast<float *>(matrix_gt->GetData());
auto matrix_g_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit * kernel_unit * sizeof(float16_t)));
auto matrix_gt_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit * kernel_unit * sizeof(float16_t)));
Float32ToFloat16(matrix_g_data, matrix_g_data_fp16, input_unit * kernel_unit);
Float32ToFloat16(matrix_gt_data, matrix_gt_data_fp16, input_unit * kernel_unit);
// trans_filter = G*g*GT (g represents weight_data)
// separate into two steps ===> tmp = G*g ===> out = tmp * GT
auto tmp_weight_data = reinterpret_cast<float16_t *>(malloc(kernel_unit * kernel_unit * sizeof(float16_t)));
auto tmp_data = reinterpret_cast<float16_t *>(malloc(input_unit * kernel_unit * sizeof(float16_t)));
auto trans_out_data = reinterpret_cast<float16_t *>(malloc(input_unit * input_unit * sizeof(float16_t)));
bool row = true;
auto trans_weight_data = reinterpret_cast<float16_t *>(trans_weight->GetData());
std::vector<int> strides = trans_weight->GetStride();
int kernel_plane_stride = channel_in;
if (oc_block == 0) {
MS_LOG(ERROR) << "Divide by zero";
return;
}
for (int i = 0; i < channel_out; i++) {
int out_c_block = i / oc_block;
int out_c_res = i % oc_block;
int input_oz_offset = i * kernel_unit * kernel_unit * channel_in;
int output_oz_offset = out_c_block * strides[1] * input_unit * input_unit + out_c_res;
for (int j = 0; j < channel_in; j++) {
int ic4_block = j / C4NUM;
int ic4_res = j % C4NUM;
int input_iz_offset = input_oz_offset + j;
int output_iz_offset = output_oz_offset + ic4_block * strides[2] + ic4_res * strides[3];
for (int k = 0; k < kernel_unit * kernel_unit; k++) {
int input_xy_offset = input_iz_offset + k * kernel_plane_stride;
tmp_weight_data[k] = *(weight_data + input_xy_offset);
}
// now we only support row-major matrix-multiply
// tmp = G * g
MatrixMultiplyFp16(matrix_g_data_fp16, tmp_weight_data, tmp_data, input_unit, kernel_unit, kernel_unit, row);
// out = tmp * GT
MatrixMultiplyFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit, kernel_unit, input_unit, row);
for (int z = 0; z < input_unit_square; z++) {
int output_xy_offset = output_iz_offset + z * strides[1];
*(trans_weight_data + output_xy_offset) = trans_out_data[z];
}
}
}
free(tmp_weight_data);
free(tmp_data);
free(trans_out_data);
free(matrix_g_data_fp16);
free(matrix_gt_data_fp16);
delete matrix_g;
delete matrix_gt;
}
int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
int output_channel = conv_param_->output_channel_;
int oc_block, oc_block_num;
oc_block = C8NUM;
oc_block_num = UP_DIV(output_channel, C8NUM);
// init weight
auto ret = MallocFilterMatrix(oc_block, oc_block_num);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Malloc filter matrix failed.";
return RET_ERROR;
}
ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
WinogradFilterTransformFp16(execute_weight_, trans_weight_, kernel_unit_, input_unit_, conv_param_, oc_block);
// init bias
bias_data_ = malloc(oc_block_num * oc_block * sizeof(float16_t));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_data_ failed.";
return RET_ERROR;
}
memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float16_t));
auto fp16_bias_data = reinterpret_cast<float16_t *>(bias_data_);
if (in_tensors_.size() == kInputSize2) {
auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->Data());
for (int i = 0; i < output_channel; ++i) {
fp16_bias_data[i] = (float16_t)ori_bias[i];
}
} else {
MS_ASSERT(inputs_.size() == kInputSize1);
}
return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::MallocFilterMatrix(int oc_block, int oc_block_num) {
int channel_in = conv_param_->input_channel_;
int ic4 = UP_DIV(channel_in, BLOCK);
// set data
auto trans_matrix_data_size = input_unit_ * input_unit_ * ic4 * C4NUM * oc_block_num * oc_block * sizeof(float);
auto matrix_buffer = malloc(trans_matrix_data_size);
if (matrix_buffer == nullptr) {
MS_LOG(ERROR) << "malloc matrix_buffer failed.";
return RET_ERROR;
}
memset(matrix_buffer, 0, trans_matrix_data_size);
trans_weight_ = new Matrix();
trans_weight_->SetData(matrix_buffer);
trans_weight_->SetNDim(5);
std::vector<int> shapes;
std::vector<int> strides;
// set shape
shapes.push_back(input_unit_ * input_unit_);
shapes.push_back(oc_block_num);
shapes.push_back(ic4);
shapes.push_back(C4NUM);
shapes.push_back(oc_block);
// set stride
for (int i = 0; i < 4; i++) {
int stride = 1;
for (int j = i + 1; j < 5; j++) {
stride *= shapes[j];
}
strides.push_back(stride);
}
trans_weight_->SetShape(shapes);
trans_weight_->SetStride(strides);
return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
int cal_num = 16;
int channel_in = conv_param_->input_channel_;
int channel_out = conv_param_->output_channel_;
int output_h = conv_param_->output_h_;
int output_w = conv_param_->output_w_;
int ic4 = UP_DIV(channel_in, C4NUM);
int oc8 = UP_DIV(channel_out, C8NUM);
/*=============================fp16_input_============================*/
size_t fp16_input_size = conv_param_->input_channel_ * conv_param_->input_batch_ * conv_param_->input_h_ *
conv_param_->input_w_ * sizeof(float16_t);
fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size));
if (fp16_input_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_input_ failed.";
return RET_ERROR;
}
/*=============================trans_input_============================*/
size_t tile_buffer_size = thread_count_ * cal_num * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float16_t);
trans_input_ = reinterpret_cast<float16_t *>(malloc(tile_buffer_size));
if (trans_input_ == nullptr) {
MS_LOG(ERROR) << "malloc trans_input_ failed.";
return RET_ERROR;
}
memset(trans_input_, 0, tile_buffer_size);
/*=============================gemm_out_============================*/
gemm_out_ = reinterpret_cast<float16_t *>(
malloc(thread_count_ * cal_num * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float16_t)));
if (gemm_out_ == nullptr) {
MS_LOG(ERROR) << "malloc gemm_out_ failed.";
return RET_ERROR;
}
/*=============================tmp_out_data_============================*/
int out_w_block = UP_DIV(output_w, output_unit_);
int out_h_block = UP_DIV(output_h, output_unit_);
tmp_out_data_ = reinterpret_cast<float16_t *>(malloc(conv_param_->output_batch_ * out_w_block * out_h_block *
output_unit_ * output_unit_ * oc8 * C8NUM * sizeof(float16_t)));
if (tmp_out_data_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_out_data_ failed.";
return RET_ERROR;
}
/*=============================fp16_out_============================*/
size_t fp16_output_size = conv_param_->output_channel_ * conv_param_->output_batch_ * conv_param_->output_h_ *
conv_param_->output_w_ * sizeof(float16_t);
fp16_out_ = reinterpret_cast<float16_t *>(malloc(fp16_output_size));
if (fp16_out_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_out_ failed.";
return RET_ERROR;
}
/*=============================tmp_data_============================*/
tmp_data_ =
reinterpret_cast<float16_t *>(malloc(thread_count_ * C4NUM * input_unit_ * input_unit_ * sizeof(float16_t)));
if (tmp_data_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_data_ failed.";
return RET_ERROR;
}
memset(tmp_data_, 0, C4NUM * input_unit_ * input_unit_ * sizeof(float16_t));
tmp_buffer_address_list_[0] = trans_input_;
tmp_buffer_address_list_[1] = gemm_out_;
tmp_buffer_address_list_[2] = tmp_out_data_;
tmp_buffer_address_list_[3] = tmp_data_;
/*=============================nhwc4_input_============================*/
size_t nhwc4_input_size =
ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
nhwc4_input_ = malloc(nhwc4_input_size);
if (nhwc4_input_ == nullptr) {
MS_LOG(ERROR) << "malloc nhwc4_input_ failed.";
return RET_ERROR;
}
memset(nhwc4_input_, 0, nhwc4_input_size);
return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::ConfigInputOutput() {
auto input_tensor = in_tensors_.at(kInputIndex);
auto ret = CheckLayout(input_tensor);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Check layout failed.";
return RET_ERROR;
}
auto output_tensor = out_tensors_.at(kOutputIndex);
output_tensor->SetFormat(schema::Format_NHWC);
// choose input transformer function (4x4 unit or 8x8 unit)
input_trans_func_ = GetInputTransFuncFp16(input_unit_);
if (input_trans_func_ == nullptr) {
MS_LOG(ERROR) << "Get input_trans_func failed.";
return RET_ERROR;
}
output_trans_func_ = GetOutputTransFuncFp16(input_unit_, output_unit_);
if (output_trans_func_ == nullptr) {
MS_LOG(ERROR) << "Get output_trans_func_ failed.";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::Init() {
auto ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init failed.";
return RET_ERROR;
}
kernel_unit_ = conv_param_->kernel_h_;
input_unit_ = output_unit_ + kernel_unit_ - 1;
conv_param_->input_unit_ = input_unit_;
conv_param_->output_unit_ = output_unit_;
ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init weight bias failed.";
return RET_ERROR;
}
// malloc tmp buffer
ret = InitTmpBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init tmp buffer failed.";
return RET_ERROR;
}
ret = ConfigInputOutput();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConfigInputOutput failed.";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::ReSize() {
if (tmp_data_ != nullptr) {
free(tmp_data_);
}
if (trans_input_ != nullptr) {
free(trans_input_);
}
if (gemm_out_ != nullptr) {
free(gemm_out_);
}
if (tmp_out_data_ != nullptr) {
free(tmp_out_data_);
}
if (nhwc4_input_ != nullptr) {
free(nhwc4_input_);
}
if (fp16_input_ != nullptr) {
free(fp16_input_);
}
if (fp16_out_ != nullptr) {
free(fp16_out_);
}
auto ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init failed.";
return RET_ERROR;
}
kernel_unit_ = conv_param_->kernel_h_;
input_unit_ = output_unit_ + kernel_unit_ - 1;
conv_param_->input_unit_ = input_unit_;
conv_param_->output_unit_ = output_unit_;
ret = InitTmpBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init tmp buffer failed.";
return RET_ERROR;
}
ret = ConfigInputOutput();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConfigInputOutput failed.";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::RunImpl(int task_id) {
ConvWinogardFp16(reinterpret_cast<float16_t *>(nhwc4_input_), reinterpret_cast<float16_t *>(trans_weight_->GetData()),
reinterpret_cast<const float16_t *>(bias_data_), tmp_buffer_address_list_, task_id, conv_param_,
input_trans_func_, output_trans_func_);
return RET_OK;
}
int ConvolutionWinogradFp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto conv = reinterpret_cast<ConvolutionWinogradFP16CPUKernel *>(cdata);
auto error_code = conv->RunImpl(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "ConvolutionWinograd Fp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionWinogradFP16CPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
int in_batch = conv_param_->input_batch_;
int in_h = conv_param_->input_h_;
int in_w = conv_param_->input_w_;
int in_channel = conv_param_->input_channel_;
convert_func_(execute_input_, nhwc4_input_, in_batch, in_h * in_w, in_channel);
int error_code = LiteBackendParallelLaunch(ConvolutionWinogradFp16Impl, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv winograd error error_code[" << error_code << "]";
return RET_ERROR;
}
// get real output
UnPackWinogradOutputFp16(tmp_out_data_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
int output_num =
conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->output_batch_;
if (conv_param_->is_relu_) {
ReluFp16(execute_output_, execute_output_, output_num);
} else if (conv_param_->is_relu6_) {
Relu6Fp16(execute_output_, execute_output_, output_num);
} else {
// do nothing
}
ConvolutionBaseFP16CPUKernel::IfCastOutput();
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,87 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_WINOGRAD_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_WINOGRAD_FP16_H_
#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
#include "src/runtime/kernel/arm/fp16/matrix_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/winograd_utils_fp16.h"
#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"
namespace mindspore::kernel {
class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
public:
ConvolutionWinogradFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
const lite::Primitive *primitive)
: ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~ConvolutionWinogradFP16CPUKernel() override {
if (fp16_input_ != nullptr) {
free(fp16_input_);
}
if (fp16_weight_ != nullptr) {
free(fp16_weight_);
}
if (fp16_out_ != nullptr) {
free(fp16_out_);
}
if (tmp_data_ != nullptr) {
free(tmp_data_);
}
if (trans_input_ != nullptr) {
free(trans_input_);
}
if (gemm_out_ != nullptr) {
free(gemm_out_);
}
if (tmp_out_data_ != nullptr) {
free(tmp_out_data_);
}
delete trans_weight_;
}
int Init() override;
int ReSize() override;
int Run() override;
int RunImpl(int task_id);
int InitWeightBias();
int MallocFilterMatrix(int oc_block, int oc_block_num);
int InitTmpBuffer();
int ConfigInputOutput();
private:
int kernel_unit_;
int input_unit_;
int output_unit_;
float16_t *tmp_data_;
float16_t *trans_input_;
float16_t *gemm_out_;
float16_t *tmp_out_data_;
Matrix *trans_weight_;
InputTransformUnitFp16Func input_trans_func_;
OutputTransformUnitFp16Func output_trans_func_;
TmpBufferAddressFp16 tmp_buffer_address_list_[4];
};
void WinogradFilterTransformFp16(const float16_t *weight_data, Matrix *trans_weight, int kernel_unit, int input_unit,
ConvParameter *conv_param, int oc_block);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_WINOGRAD_FP16_H_

View File

@ -115,10 +115,6 @@ int DeconvolutionDepthwiseFp16CPUKernel::InitWeightBias() {
}
int DeconvolutionDepthwiseFp16CPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
set_need_reinit();
return RET_OK;
}
sliding_ = new SlidingWindowParam;
InitSlideParam();
// conv base init

View File

@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp16/matrix_fp16.h"
namespace mindspore::kernel {
void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n,
bool row) {
// row-major implementation
int count = 0;
for (int h = 0; h < m; h++) {
int h_offset = h * k;
for (int w = 0; w < n; w++) {
float16_t res = 0;
for (int i = 0; i < k; i++) {
res += *(matrix_a + h_offset + i) * *(matrix_b + w + i * n);
}
*(matrix_c + count) = res;
count++;
}
}
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,27 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_MATRIX_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_MATRIX_FP16_H_
#include "src/runtime/kernel/arm/base/matrix.h"
namespace mindspore::kernel {
void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n,
bool row);
}
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_MATRIX_FP16_H_

View File

@ -53,10 +53,6 @@ int PoolingFp16CPUKernel::InitBuffer() {
}
int PoolingFp16CPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
set_need_reinit();
return RET_OK;
}
auto ret = PoolingBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "PoolingBase Init failed.";

View File

@ -329,10 +329,9 @@ int ConvolutionWinogradCPUKernel::RunImpl(int task_id) {
MS_LOG(ERROR) << "gemm_func is nullptr.";
return RET_ERROR;
}
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
ConvWinogardFp32(reinterpret_cast<float *>(nhwc4_input_), reinterpret_cast<float *>(trans_weight_->GetData()),
reinterpret_cast<const float *>(bias_data_), output_addr, tmp_buffer_address_list_, task_id,
conv_param_, input_trans_func_, output_trans_func_, gemm_func_);
reinterpret_cast<const float *>(bias_data_), tmp_buffer_address_list_, task_id, conv_param_,
input_trans_func_, output_trans_func_, gemm_func_);
return RET_OK;
}

View File

@ -16,9 +16,7 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CAST_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CAST_FP16_H_
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/op_base.h"
#include "nnacl/fp32/cast.h"
#ifdef __cplusplus

View File

@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp16/common_func.h"
void ReluFp16(float16_t *data, float16_t *dst, int ele_num) {
int eight_block = UP_DIV(ele_num, C8NUM);
for (int i = 0; i < eight_block - 1; i++) {
int index = i * C8NUM;
#ifdef ENABLE_NEON
float16x8_t relu_data = vld1q_f16(data + index);
float16x8_t zero_data = vdupq_n_f16(0);
relu_data = vmaxq_f16(relu_data, zero_data);
vst1q_f16(dst + index, relu_data);
#else
data[index] = data[index] < 0 ? 0 : data[index];
data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2];
data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3];
#endif
}
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
data[j] = data[j] < 0 ? 0 : data[j];
}
}
void Relu6Fp16(float16_t *data, float16_t *dst, int ele_num) {
int eight_block = UP_DIV(ele_num, C8NUM);
for (int i = 0; i < eight_block - 1; i++) {
int index = i * C8NUM;
#ifdef ENABLE_NEON
float16x8_t relu6_data = vld1q_f16(data + index);
float16x8_t zero_data = vdupq_n_f16(0);
float16x8_t six_data = vdupq_n_f16(6);
relu6_data = vmaxq_f16(relu6_data, zero_data);
relu6_data = vminq_f16(relu6_data, six_data);
vst1q_f16(dst + index, relu6_data);
#else
for (int j = 0; j < C8NUM; ++j) {
data[index + j] = data[index + j] < 0 ? 0 : data[index + j];
data[index + j] = data[index + j] > 6 ? 6 : data[index + j];
}
#endif
}
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
data[j] = data[j] < 0 ? 0 : data[j];
data[j] = data[j] > 6 ? 6 : data[j];
}
}

View File

@ -39,6 +39,8 @@ void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *w
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
#endif
void ReluFp16(float16_t *data, float16_t *dst, int ele_num);
void Relu6Fp16(float16_t *data, float16_t *dst, int ele_num);
#ifdef __cplusplus
}

View File

@ -32,12 +32,21 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh
#endif
#ifndef ENABLE_NEON
void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC4, size_t relu,
size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC8, size_t relu,
size_t relu6) {
if (!(mode && writeC8)) {
IndirectGemmFp16_16x8_common(output, input, weight, bias, step, ic4, output, offset, relu, relu6);
} else {
IndirectGemmFp16_16x8_c8(output, input, weight, bias, step, ic4, output, offset, mode, writeC8, relu, relu6);
}
}
void IndirectGemmFp16_16x8_common(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
size_t ic4, size_t oc8, size_t offset, size_t relu, size_t relu6) {
const int tile_n = 16;
for (int i = 0; i < out_channel; i++) {
int oc8_block = i / 8;
int oc8_res = i % 8;
int oc8_block = i / C8NUM;
int oc8_res = i % C8NUM;
int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res;
for (int k = 0; k < tile_n; k++) {
int input_tile_offset = k * C4NUM;
@ -72,32 +81,32 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh
}
}
void IndirectGemmFp16_16x8_tmp(float16_t *output, float16_t *input, float16_t *weight, const float16_t *bias,
size_t step, size_t ic4, size_t output_channel, size_t offset, size_t mode,
size_t writeC4, size_t relu, size_t relu6) {
void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC8,
size_t relu, size_t relu6) {
const int tile_num = 16;
if (mode) {
if (mode && writeC8) {
for (int i = 0; i < tile_num; i++) {
int input_tile_offset = i * C4NUM;
int output_tile_offset = i * output_channel * 36;
int output_tile_offset = i * output_channel * step;
for (int j = 0; j < output_channel; j++) {
int oc8_block = j / 8;
int oc8_res = j % 8;
int weight_oc_offset = oc8_block * 36 * ic4 * C4NUM * 8 + oc8_res;
int out_oc_offset = output_tile_offset + oc8_block * 36 * C8NUM + oc8_res;
int oc8_block = j / C8NUM;
int oc8_res = j % C8NUM;
int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res;
int out_oc_offset = output_tile_offset + oc8_block * step * C8NUM + oc8_res;
for (int n = 0; n < step; n++) {
int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * tile_num;
int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * 8;
int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C8NUM;
int output_kw_offset = out_oc_offset + n * C8NUM;
float16_t acc = 0;
for (int k = 0; k < ic4; k++) {
int input_ic4_offset = input_kw_offset + k * tile_num * C4NUM;
int weight_ic4_offset = weight_kw_offset + k * C4NUM * 8;
for (int m = 0; m < 4; m++) {
int weight_ic4_offset = weight_kw_offset + k * C4NUM * C8NUM;
for (int m = 0; m < C4NUM; m++) {
int input_ic_offset = input_ic4_offset + m;
int weight_ic_offset = weight_ic4_offset + m * 8;
int weight_ic_offset = weight_ic4_offset + m * C8NUM;
acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0];
}
}
@ -405,3 +414,91 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
}
}
}
// fp16 convolution winograd
void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data,
TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param,
InputTransformUnitFp16Func input_trans_func, OutputTransformUnitFp16Func output_trans_func) {
int thread_num = conv_param->thread_num_;
int input_unit = conv_param->input_unit_;
int in_batch = conv_param->input_batch_;
int in_channel = conv_param->input_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
int out_unit = conv_param->output_unit_;
int out_w_block = UP_DIV(conv_param->output_w_, out_unit);
int out_h_block = UP_DIV(conv_param->output_h_, out_unit);
int tile_num = 16;
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, tile_num);
int out_channel = conv_param->output_channel_;
int oc8 = UP_DIV(out_channel, C8NUM);
int input_unit_square = input_unit * input_unit;
size_t output_offset = oc8 * C8NUM * input_unit_square * sizeof(float16_t);
float16_t *trans_input = buffer_list[0];
float16_t *gemm_out = buffer_list[1];
float16_t *tmp_out_data = buffer_list[2];
float16_t *tmp_data = buffer_list[3];
int trans_input_offset = tile_num * input_unit_square * ic4 * C4NUM;
int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
int tmp_data_offset = input_unit_square * C4NUM;
// step 1 : filter transform (pre-processed offline)
// step 2 : input transform (online)
for (int b = 0; b < in_batch; b++) {
int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc8 * C8NUM;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
int out_tile_index = thread_id * TILE_NUM;
int cal_num = output_count - thread_id * tile_num;
cal_num = cal_num > tile_num ? tile_num : cal_num;
WinogradInputTransformFp16(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
input_trans_func);
// step 3 : gemm
IndirectGemmFp16_16x8(gemm_out + task_id * gemm_out_offset, trans_input + task_id * trans_input_offset,
trans_weight, NULL, input_unit_square, ic4, oc8 * C8NUM, output_offset, 1, 1, 0, 0);
// step 4 : output transform
WinogradOutputTransformFp16(gemm_out + task_id * gemm_out_offset, tmp_out_data + tmp_out_batch_offset, bias_data,
cal_num, out_tile_index, out_w_block, conv_param, output_trans_func);
}
}
}
void UnPackWinogradOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel,
int output_unit) {
int out_h_block_num = UP_DIV(height, output_unit);
int out_w_block_num = UP_DIV(width, output_unit);
int c8 = UP_DIV(channel, C8NUM);
for (int b = 0; b < batch; b++) {
int src_batch_offset = b * c8 * C8NUM * out_h_block_num * output_unit * out_w_block_num * output_unit;
int dst_batch_offset = b * height * width * channel;
for (int h = 0; h < height; h++) {
int src_h_offset = src_batch_offset + C8NUM * (h * out_w_block_num * output_unit);
int dst_h_offset = dst_batch_offset + h * width * channel;
for (int w = 0; w < width; w++) {
int src_w_offset = src_h_offset + w * C8NUM;
int dst_w_offset = dst_h_offset + w * channel;
for (int c = 0; c < c8 - 1; c++) {
int src_c8_offset = src_w_offset + c * C8NUM * out_w_block_num * out_h_block_num * output_unit * output_unit;
int dst_c8_offset = dst_w_offset + c * C8NUM;
#ifdef ENABLE_NEON
vst1q_f16(dst + dst_c8_offset, vld1q_f16(src + src_c8_offset));
#else
for (int i = 0; i < C8NUM; ++i) {
dst[dst_c8_offset + i] = src[src_c8_offset + i];
}
#endif
}
int c_res = channel - (c8 - 1) * C8NUM;
int src_c_res_offset = (c8 - 1) * C8NUM * out_w_block_num * out_h_block_num * output_unit * output_unit;
int dst_c_res_offset = (c8 - 1) * C8NUM;
for (int c = 0; c < c_res; c++) {
int src_c8_res_offset = src_w_offset + src_c_res_offset + c;
int dst_c8_res_offset = dst_w_offset + dst_c_res_offset + c;
dst[dst_c8_res_offset] = src[src_c8_res_offset];
}
}
}
}
}

View File

@ -18,11 +18,22 @@
#include <arm_neon.h>
#include "nnacl/conv_parameter.h"
#include "nnacl/fp16/winograd_utils_fp16.h"
#include "nnacl/fp16/winograd_transform_fp16.h"
typedef float16_t *TmpBufferAddressFp16;
#ifndef ENABLE_NEON
void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu,
size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC8, size_t relu,
size_t relu6);
void IndirectGemmFp16_16x8_common(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
size_t ic4, size_t oc8, size_t offset, size_t relu, size_t relu6);
void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC8, size_t relu,
size_t relu6);
#endif
#ifdef __cplusplus
@ -48,6 +59,14 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data,
float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out,
int task_id, ConvParameter *conv_param);
// fp16 convolution winograd
void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data,
TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param,
InputTransformUnitFp16Func input_trans_func, OutputTransformUnitFp16Func output_trans_func);
void UnPackWinogradOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel,
int output_unit);
#ifdef __cplusplus
}
#endif

View File

@ -534,3 +534,95 @@ void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data,
}
}
}
// fp16 common winograd
void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num,
int out_tile_index, int out_w_block_num, ConvParameter *conv_param,
InputTransformUnitFp16Func input_trans_func) {
int tile_num = 16;
int input_unit = conv_param->input_unit_;
int output_unit = conv_param->output_unit_;
int in_channel = conv_param->input_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_w_;
int input_h = conv_param->input_h_;
int input_w = conv_param->input_w_;
if (out_w_block_num == 0) {
return;
}
for (int c = 0; c < cal_num; c++) { // actual tiled number
int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w;
int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h;
int interval_x_s = src_x_s > 0 ? 0 : -src_x_s;
int interval_y_s = src_y_s > 0 ? 0 : -src_y_s;
int src_x_e = src_x_s + input_unit;
int src_y_e = src_y_s + input_unit;
int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s);
int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s);
int src_plane_offset = ic4 * C4NUM * (src_y_s * input_w + src_x_s);
int dst_plane_offset = c * C4NUM;
for (int ic = 0; ic < ic4; ic++) {
// clear tmp buffer
memset(tmp_data, 0, input_unit * input_unit * C4NUM * sizeof(float16_t));
// get real input block with padding
int src_ic4_offset = src_plane_offset + ic * C4NUM;
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * ic4 * C4NUM;
int dst_y_offset = interval * input_unit * C4NUM + interval_x_s * C4NUM;
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
int src_x_offset = src_y_offset + j * ic4 * C4NUM;
int dst_x_offset = dst_y_offset + j * C4NUM;
float16_t *src_addr = input_data + src_x_offset;
float16_t *dst_addr = tmp_data + dst_x_offset;
#ifdef ENABLE_NEON
vst1_f16(dst_addr, vld1_f16(src_addr));
#else
for (int k = 0; k < C4NUM; k++) {
dst_addr[k] = src_addr[k];
}
#endif
}
}
// input transform
int dst_ic4_offset = dst_plane_offset + ic * tile_num * C4NUM;
size_t dst_step = ic4 * C4NUM * tile_num;
float16_t *trans_input_ptr = trans_input + dst_ic4_offset;
input_trans_func(tmp_data, trans_input_ptr, C4NUM, dst_step);
}
out_tile_index++;
} // cal_tile_num loop
}
void WinogradOutputTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
int cal_num, int out_tile_index, int output_unit_num, ConvParameter *conv_param,
OutputTransformUnitFp16Func output_trans_func) {
int output_unit = conv_param->output_unit_;
int output_w = conv_param->output_w_;
int output_unit_block = UP_DIV(output_w, output_unit);
int output_channel = conv_param->output_channel_;
int oc8 = UP_DIV(output_channel, C8NUM);
int input_unit = conv_param->input_unit_;
if (output_unit_num == 0) {
return;
}
for (int i = 0; i < cal_num; i++) {
int dst_x_s = out_tile_index % output_unit_num;
int dst_y_s = out_tile_index / output_unit_num;
int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit;
int dst_tile_offset = C4NUM * output_unit * (dst_x_s + dst_y_s * output_unit_block * output_unit);
for (int j = 0; j < oc8; j++) {
int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM;
int dst_oc8_offset =
dst_tile_offset + j * C8NUM * output_unit_block * output_unit_block * output_unit * output_unit;
const float16_t *src_ptr = gemm_out + src_oc8_offset;
const float16_t *bias_ptr = bias_data + j * C8NUM;
float16_t *dst_ptr = tmp_out_data + dst_oc8_offset;
output_trans_func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_unit_block * output_unit);
}
out_tile_index++;
}
}

View File

@ -21,6 +21,7 @@
#include <string.h>
#include "nnacl/fp16/pack_fp16.h"
#include "nnacl/fp16/conv_fp16.h"
#include "nnacl/fp16/winograd_utils_fp16.h"
#ifdef __cplusplus
extern "C" {
@ -39,6 +40,15 @@ void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data
void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data,
int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param);
// fp16 common winograd
void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num,
int out_tile_index, int out_w_block_num, ConvParameter *conv_param,
InputTransformUnitFp16Func input_trans_func);
void WinogradOutputTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
int cal_num, int out_tile_index, int output_unit_num, ConvParameter *conv_param,
OutputTransformUnitFp16Func output_trans_func);
#ifdef __cplusplus
}
#endif

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,67 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_WINOGRAD_UTILS_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_WINOGRAD_UTILS_H_
#include <arm_neon.h>
#include "nnacl/conv_parameter.h"
#include "nnacl/op_base.h"
typedef void (*InputTransformUnitFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
typedef void (*OutputTransformUnitFp16Func)(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
#ifdef __cplusplus
extern "C" {
#endif
void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
void OutputTransform4x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
void OutputTransform8x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
void OutputTransform8x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
void OutputTransform8x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
void OutputTransform8x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
void OutputTransform8x6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
void OutputTransform8x7UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step);
InputTransformUnitFp16Func GetInputTransFuncFp16(int input_unit);
OutputTransformUnitFp16Func GetOutputTransFuncFp16(int input_unit, int output_unit);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_WINOGRAD_UTILS_H_

View File

@ -243,10 +243,9 @@ int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output
}
// fp32 conv winograd
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param,
InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func,
GEMM_FUNC_FP32 gemm_func) {
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func,
OutputTransformUnitFunc output_trans_func, GEMM_FUNC_FP32 gemm_func) {
int thread_num = conv_param->thread_num_;
int input_unit = conv_param->input_unit_;
int in_batch = conv_param->input_batch_;

View File

@ -57,10 +57,9 @@ int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output
StrassenMatMulParameter matmul_param);
// fp32 convolution winograd
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param,
InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func,
GEMM_FUNC_FP32 gemm_func);
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func,
OutputTransformUnitFunc output_trans_func, GEMM_FUNC_FP32 gemm_func);
void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, int output_unit);