forked from mindspore-Ecosystem/mindspore
commit
9efd84fbb2
|
@ -0,0 +1,97 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
|
||||
#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_Conv2D;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int Convolution1x1FP16CPUKernel::Init() {
|
||||
auto ret = ConvolutionBaseCPUKernel::Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvolutionBase init failed.";
|
||||
return ret;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int Convolution1x1FP16CPUKernel::ReSize() {
|
||||
if (fp16_out_ != nullptr) {
|
||||
free(fp16_out_);
|
||||
}
|
||||
if (fp16_input_ != nullptr) {
|
||||
free(fp16_input_);
|
||||
}
|
||||
if (nhwc4_input_ != nullptr) {
|
||||
free(nhwc4_input_);
|
||||
}
|
||||
|
||||
auto ret = ConvolutionBaseCPUKernel::Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvolutionBase init failed.";
|
||||
return ret;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int Convolution1x1FP16CPUKernel::RunImpl(int task_id) {
|
||||
// Conv1x1Fp16(reinterpret_cast<float16_t *>(nhwc4_input_), transformed_filter_addr_,
|
||||
// reinterpret_cast<float16_t *>(bias_data_), fp16_out_, tile_buffer_, block_unit_buffer_,
|
||||
// tmp_dst_buffer_, tmp_out_, task_id, conv_param_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int Convolution1x1Fp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
|
||||
auto conv = reinterpret_cast<Convolution1x1FP16CPUKernel *>(cdata);
|
||||
auto error_code = conv->RunImpl(task_id);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convolution1x1 Fp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int Convolution1x1FP16CPUKernel::Run() {
|
||||
auto ret = Prepare();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
|
||||
|
||||
int error_code = LiteBackendParallelLaunch(Convolution1x1Fp16Impl, this, thread_count_);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "conv1x1 fp16 error error_code[" << error_code << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ConvolutionBaseFP16CPUKernel::IfCastOutput();
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_1x1_FP16_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_1x1_FP16_H_
|
||||
|
||||
#include <arm_neon.h>
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
|
||||
public:
|
||||
Convolution1x1FP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
|
||||
const lite::Primitive *primitive)
|
||||
: ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
~Convolution1x1FP16CPUKernel() override {
|
||||
if (fp16_input_ != nullptr) {
|
||||
free(fp16_input_);
|
||||
}
|
||||
if (fp16_weight_ != nullptr) {
|
||||
free(fp16_weight_);
|
||||
}
|
||||
if (fp16_out_ != nullptr) {
|
||||
free(fp16_out_);
|
||||
}
|
||||
}
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int RunImpl(int task_id);
|
||||
|
||||
private:
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_1x1_FP16_H_
|
|
@ -52,8 +52,6 @@ void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvPara
|
|||
int Convolution3x3FP16CPUKernel::InitWeightBias() {
|
||||
auto input_channel = conv_param_->input_channel_;
|
||||
int output_channel = conv_param_->output_channel_;
|
||||
int kernel_h = conv_param_->kernel_h_;
|
||||
int kernel_w = conv_param_->kernel_w_;
|
||||
int iC4 = UP_DIV(input_channel, C4NUM);
|
||||
int oC8 = UP_DIV(output_channel, C8NUM);
|
||||
// init weight
|
||||
|
@ -64,18 +62,8 @@ int Convolution3x3FP16CPUKernel::InitWeightBias() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
memset(transformed_filter_addr_, 0, transformed_size);
|
||||
float *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data());
|
||||
size_t fp16_weight_size = input_channel * output_channel * kernel_h * kernel_w * sizeof(float16_t);
|
||||
fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
|
||||
if (fp16_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(fp16_weight_, 0, fp16_weight_size);
|
||||
for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
|
||||
fp16_weight_[i] = (float16_t)origin_weight[i];
|
||||
}
|
||||
ProcessFilterFp16(fp16_weight_, transformed_filter_addr_, conv_param_);
|
||||
ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
|
||||
ProcessFilterFp16(execute_weight_, transformed_filter_addr_, conv_param_);
|
||||
|
||||
// init bias
|
||||
size_t new_bias_size = oC8 * C8NUM * sizeof(float16_t);
|
||||
|
@ -183,10 +171,6 @@ void Convolution3x3FP16CPUKernel::ConfigInputOutput() {
|
|||
}
|
||||
|
||||
int Convolution3x3FP16CPUKernel::Init() {
|
||||
if (context_->infer_shape_interrupt_ && !context_->running_) {
|
||||
set_need_reinit();
|
||||
return RET_OK;
|
||||
}
|
||||
auto ret = ConvolutionBaseCPUKernel::Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvolutionBase init failed.";
|
||||
|
@ -244,8 +228,8 @@ int Convolution3x3FP16CPUKernel::ReSize() {
|
|||
|
||||
int Convolution3x3FP16CPUKernel::RunImpl(int task_id) {
|
||||
Conv3x3Fp16(reinterpret_cast<float16_t *>(nhwc4_input_), transformed_filter_addr_,
|
||||
reinterpret_cast<float16_t *>(bias_data_), fp16_out_, tile_buffer_, block_unit_buffer_, tmp_dst_buffer_,
|
||||
tmp_out_, task_id, conv_param_);
|
||||
reinterpret_cast<float16_t *>(bias_data_), execute_output_, tile_buffer_, block_unit_buffer_,
|
||||
tmp_dst_buffer_, tmp_out_, task_id, conv_param_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -265,16 +249,13 @@ int Convolution3x3FP16CPUKernel::Run() {
|
|||
MS_LOG(ERROR) << "Prepare failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto input_tensor = in_tensors_.at(kInputIndex);
|
||||
auto input_ele_num = input_tensor->ElementsNum();
|
||||
auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
|
||||
Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
|
||||
ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
|
||||
|
||||
int in_batch = conv_param_->input_batch_;
|
||||
int in_h = conv_param_->input_h_;
|
||||
int in_w = conv_param_->input_w_;
|
||||
int in_channel = conv_param_->input_channel_;
|
||||
convert_func_(reinterpret_cast<void *>(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
|
||||
convert_func_(reinterpret_cast<void *>(execute_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
|
||||
|
||||
int error_code = LiteBackendParallelLaunch(Convolution3x3Fp16Impl, this, thread_count_);
|
||||
if (error_code != RET_OK) {
|
||||
|
@ -294,7 +275,7 @@ int Convolution3x3FP16CPUKernel::Run() {
|
|||
batch * oc8 * C8NUM * out_w_block * out_h_block * conv_param_->output_unit_ * conv_param_->output_unit_;
|
||||
int ro_batch_size = batch * conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_;
|
||||
const float16_t *batch_tmp_out = tmp_out_ + tmp_out_batch_offset;
|
||||
float16_t *batch_out = fp16_out_ + ro_batch_size;
|
||||
float16_t *batch_out = execute_output_ + ro_batch_size;
|
||||
for (int h = 0; h < conv_param_->output_h_; h++) {
|
||||
for (int w = 0; w < conv_param_->output_w_; w++) {
|
||||
for (int c = 0; c < conv_param_->output_channel_; c++) {
|
||||
|
@ -315,11 +296,7 @@ int Convolution3x3FP16CPUKernel::Run() {
|
|||
}
|
||||
}
|
||||
|
||||
// cast fp16 out to fp32 data
|
||||
auto out_tensor = out_tensors_.at(kOutputIndex);
|
||||
auto out_ele_num = out_tensor->ElementsNum();
|
||||
auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
|
||||
Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
|
||||
ConvolutionBaseFP16CPUKernel::IfCastOutput();
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -20,16 +20,16 @@
|
|||
#include <arm_neon.h>
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/arm/base/convolution_base.h"
|
||||
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class Convolution3x3FP16CPUKernel : public ConvolutionBaseCPUKernel {
|
||||
class Convolution3x3FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
|
||||
public:
|
||||
Convolution3x3FP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
|
||||
const lite::Primitive *primitive)
|
||||
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
: ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
~Convolution3x3FP16CPUKernel() override {
|
||||
if (fp16_input_ != nullptr) {
|
||||
free(fp16_input_);
|
||||
|
@ -66,9 +66,6 @@ class Convolution3x3FP16CPUKernel : public ConvolutionBaseCPUKernel {
|
|||
void ConfigInputOutput();
|
||||
|
||||
private:
|
||||
float16_t *fp16_input_;
|
||||
float16_t *fp16_weight_;
|
||||
float16_t *fp16_out_;
|
||||
float16_t *transformed_filter_addr_;
|
||||
float16_t *tile_buffer_;
|
||||
float16_t *block_unit_buffer_;
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_factory.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int ConvolutionBaseFP16CPUKernel::GetExecuteTensor() {
|
||||
// ===================input====================//
|
||||
auto input_tensor = in_tensors_.at(kInputIndex);
|
||||
auto input_data_type = input_tensor->data_type();
|
||||
MS_ASSERT(input_data_type == kNumberTypeFloat32 || input_data_type == kNumberTypeFloat16);
|
||||
if (input_data_type == kNumberTypeFloat32) {
|
||||
auto input_ele_num = input_tensor->ElementsNum();
|
||||
auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
|
||||
Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
|
||||
execute_input_ = fp16_input_;
|
||||
} else {
|
||||
auto ori_input_data = reinterpret_cast<float16_t *>(input_tensor->Data());
|
||||
execute_input_ = ori_input_data;
|
||||
}
|
||||
// ==================output====================//
|
||||
auto out_tensor = out_tensors_.at(kOutputIndex);
|
||||
auto out_data_type = out_tensor->data_type();
|
||||
MS_ASSERT(out_data_type == kNumberTypeFloat32 || out_data_type == kNumberTypeFloat16);
|
||||
out_data_type_ = out_data_type;
|
||||
if (out_data_type == kNumberTypeFloat32) {
|
||||
execute_output_ = fp16_out_;
|
||||
} else {
|
||||
auto out_ptr = reinterpret_cast<float16_t *>(out_tensor->Data());
|
||||
execute_output_ = out_ptr;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionBaseFP16CPUKernel::GetExecuteFilter() {
|
||||
auto weight_tensor = in_tensors_.at(kWeightIndex);
|
||||
auto weight_data_type = weight_tensor->data_type();
|
||||
MS_ASSERT(weight_data_type == kNumberTypeFloat32 || weight_data_type == kNumberTypeFloat16);
|
||||
if (weight_data_type == kNumberTypeFloat32) {
|
||||
float *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data());
|
||||
size_t fp16_weight_size = conv_param_->input_channel_ * conv_param_->output_channel_ * conv_param_->kernel_h_ *
|
||||
conv_param_->input_w_ * sizeof(float16_t);
|
||||
fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
|
||||
if (fp16_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
|
||||
fp16_weight_[i] = (float16_t)origin_weight[i];
|
||||
}
|
||||
execute_weight_ = fp16_weight_;
|
||||
} else {
|
||||
auto *origin_weight = reinterpret_cast<float16_t *>(in_tensors_.at(kWeightIndex)->Data());
|
||||
execute_weight_ = origin_weight;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void ConvolutionBaseFP16CPUKernel::IfCastOutput() {
|
||||
if (out_data_type_ == kNumberTypeFloat32) {
|
||||
auto out_tensor = out_tensors_.at(kOutputIndex);
|
||||
auto out_ele_num = out_tensor->ElementsNum();
|
||||
auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
|
||||
Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_BASE_FP16_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_BASE_FP16_H_
|
||||
|
||||
#include <arm_neon.h>
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/arm/base/convolution_base.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ConvolutionBaseFP16CPUKernel : public ConvolutionBaseCPUKernel {
|
||||
public:
|
||||
ConvolutionBaseFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
|
||||
const lite::Primitive *primitive)
|
||||
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
~ConvolutionBaseFP16CPUKernel() override = default;
|
||||
|
||||
int Init() override { return RET_OK; }
|
||||
int ReSize() override { return RET_OK; }
|
||||
int Run() override { return RET_OK; }
|
||||
int RunImpl(int task_id) { return RET_OK; }
|
||||
virtual int GetExecuteTensor();
|
||||
virtual int GetExecuteFilter();
|
||||
virtual void IfCastOutput();
|
||||
|
||||
protected:
|
||||
float16_t *fp16_input_ = nullptr;
|
||||
float16_t *fp16_weight_ = nullptr;
|
||||
float16_t *fp16_out_ = nullptr;
|
||||
float16_t *execute_input_;
|
||||
float16_t *execute_weight_;
|
||||
float16_t *execute_output_;
|
||||
TypeId out_data_type_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_BASE_FP16_H_
|
|
@ -102,10 +102,6 @@ int ConvolutionDepthwiseFp16CPUKernel::InitWeightBias() {
|
|||
}
|
||||
|
||||
int ConvolutionDepthwiseFp16CPUKernel::Init() {
|
||||
if (context_->infer_shape_interrupt_ && !context_->running_) {
|
||||
set_need_reinit();
|
||||
return RET_OK;
|
||||
}
|
||||
// conv base init
|
||||
auto ret = ConvolutionBaseCPUKernel::Init();
|
||||
if (ret != RET_OK) {
|
||||
|
|
|
@ -46,24 +46,14 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
|
|||
int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane;
|
||||
|
||||
// init weight
|
||||
float *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data());
|
||||
size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t);
|
||||
fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
|
||||
if (fp16_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
|
||||
fp16_weight_[i] = (float16_t)origin_weight[i];
|
||||
}
|
||||
|
||||
ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
|
||||
packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t)));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc packed_weight_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t));
|
||||
PackWeightFp16(fp16_weight_, conv_param_, packed_weight_);
|
||||
PackWeightFp16(execute_weight_, conv_param_, packed_weight_);
|
||||
|
||||
// init bias
|
||||
bias_data_ = malloc(oc8 * C8NUM * sizeof(float16_t));
|
||||
|
@ -157,10 +147,6 @@ void ConvolutionFP16CPUKernel::ConfigInputOutput() {
|
|||
}
|
||||
|
||||
int ConvolutionFP16CPUKernel::Init() {
|
||||
if (context_->infer_shape_interrupt_ && !context_->running_) {
|
||||
set_need_reinit();
|
||||
return RET_OK;
|
||||
}
|
||||
auto ret = ConvolutionBaseCPUKernel::Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvolutionBase init fail!ret: " << ret;
|
||||
|
@ -212,7 +198,7 @@ int ConvolutionFP16CPUKernel::ReSize() {
|
|||
|
||||
int ConvolutionFP16CPUKernel::RunImpl(int task_id) {
|
||||
ConvFp16(reinterpret_cast<float16_t *>(nhwc4_input_), packed_input_, packed_weight_,
|
||||
reinterpret_cast<float16_t *>(bias_data_), tmp_output_block_, fp16_out_, task_id, conv_param_);
|
||||
reinterpret_cast<float16_t *>(bias_data_), tmp_output_block_, execute_output_, task_id, conv_param_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -232,16 +218,13 @@ int ConvolutionFP16CPUKernel::Run() {
|
|||
MS_LOG(ERROR) << "Prepare failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto input_tensor = in_tensors_.at(kInputIndex);
|
||||
auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
|
||||
auto input_ele_num = input_tensor->ElementsNum();
|
||||
Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
|
||||
ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
|
||||
|
||||
int in_batch = conv_param_->input_batch_;
|
||||
int in_h = conv_param_->input_h_;
|
||||
int in_w = conv_param_->input_w_;
|
||||
int in_channel = conv_param_->input_channel_;
|
||||
convert_func_(reinterpret_cast<void *>(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
|
||||
convert_func_(reinterpret_cast<void *>(execute_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
|
||||
|
||||
int error_code = LiteBackendParallelLaunch(ConvolutionFp16Impl, this, thread_count_);
|
||||
if (error_code != RET_OK) {
|
||||
|
@ -249,11 +232,7 @@ int ConvolutionFP16CPUKernel::Run() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// cast fp16 out to fp32 data
|
||||
auto out_tensor = out_tensors_.at(kOutputIndex);
|
||||
auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
|
||||
auto out_ele_num = out_tensor->ElementsNum();
|
||||
Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
|
||||
ConvolutionBaseFP16CPUKernel::IfCastOutput();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,15 +20,15 @@
|
|||
#include <arm_neon.h>
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/arm/base/convolution_base.h"
|
||||
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ConvolutionFP16CPUKernel : public ConvolutionBaseCPUKernel {
|
||||
class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
|
||||
public:
|
||||
ConvolutionFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
|
||||
const lite::Primitive *primitive)
|
||||
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
: ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
~ConvolutionFP16CPUKernel() override {
|
||||
if (fp16_input_ != nullptr) {
|
||||
free(fp16_input_);
|
||||
|
@ -59,9 +59,6 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseCPUKernel {
|
|||
void ConfigInputOutput();
|
||||
|
||||
private:
|
||||
float16_t *fp16_input_;
|
||||
float16_t *fp16_weight_;
|
||||
float16_t *fp16_out_;
|
||||
float16_t *packed_input_;
|
||||
float16_t *packed_weight_;
|
||||
float16_t *tmp_output_block_;
|
||||
|
|
|
@ -39,23 +39,13 @@ int ConvolutionSWFP16CPUKernel::ProcessFilter() {
|
|||
int out_channel = conv_param_->output_channel_;
|
||||
int ic4 = UP_DIV(in_channel, C4NUM);
|
||||
|
||||
auto *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data());
|
||||
size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t);
|
||||
fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
|
||||
if (fp16_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// cast origin fp32 weight data to fp16 data
|
||||
for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
|
||||
fp16_weight_[i] = (float16_t)origin_weight[i];
|
||||
}
|
||||
ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
|
||||
|
||||
for (int oc = 0; oc < out_channel; ++oc) {
|
||||
int src_oc_offset = oc * kernel_h * kernel_w * in_channel;
|
||||
int dst_oc_offset = oc * kernel_h * kernel_w * ic4 * C4NUM;
|
||||
for (int i = 0; i < kernel_h * kernel_w; ++i) {
|
||||
const float16_t *src = fp16_weight_ + src_oc_offset + i * in_channel;
|
||||
const float16_t *src = execute_weight_ + src_oc_offset + i * in_channel;
|
||||
float16_t *dst = packed_weight_ + dst_oc_offset + i * ic4 * C4NUM;
|
||||
memcpy(dst, src, in_channel * sizeof(float16_t));
|
||||
}
|
||||
|
@ -162,10 +152,6 @@ void ConvolutionSWFP16CPUKernel::ConfigInputOutput() {
|
|||
}
|
||||
|
||||
int ConvolutionSWFP16CPUKernel::Init() {
|
||||
if (context_->infer_shape_interrupt_ && !context_->running_) {
|
||||
set_need_reinit();
|
||||
return RET_OK;
|
||||
}
|
||||
auto ret = ConvolutionBaseCPUKernel::Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvolutionBase init fail!ret: " << ret;
|
||||
|
@ -222,7 +208,7 @@ int ConvolutionSWFP16CPUKernel::ReSize() {
|
|||
|
||||
int ConvolutionSWFP16CPUKernel::RunImpl(int task_id) {
|
||||
ConvSWFp16(reinterpret_cast<float16_t *>(nhwc4_input_), packed_weight_, reinterpret_cast<float16_t *>(bias_data_),
|
||||
tmp_output_block_, fp16_out_, task_id, conv_param_, slidingWindow_param_);
|
||||
tmp_output_block_, execute_output_, task_id, conv_param_, slidingWindow_param_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -242,16 +228,13 @@ int ConvolutionSWFP16CPUKernel::Run() {
|
|||
MS_LOG(ERROR) << "Prepare failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto input_tensor = in_tensors_.at(kInputIndex);
|
||||
auto input_ele_num = input_tensor->ElementsNum();
|
||||
auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
|
||||
Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
|
||||
ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
|
||||
|
||||
int in_batch = conv_param_->input_batch_;
|
||||
int in_h = conv_param_->input_h_;
|
||||
int in_w = conv_param_->input_w_;
|
||||
int in_channel = conv_param_->input_channel_;
|
||||
convert_func_(reinterpret_cast<void *>(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
|
||||
convert_func_(reinterpret_cast<void *>(execute_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);
|
||||
|
||||
int error_code = LiteBackendParallelLaunch(ConvolutionSWFp16Impl, this, thread_count_);
|
||||
if (error_code != RET_OK) {
|
||||
|
@ -259,18 +242,14 @@ int ConvolutionSWFP16CPUKernel::Run() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// cast fp16 out to fp32 data
|
||||
auto out_tensor = out_tensors_.at(kOutputIndex);
|
||||
auto out_ele_num = out_tensor->ElementsNum();
|
||||
auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
|
||||
// output nhwc4
|
||||
int oc4_res = conv_param_->output_channel_ % C4NUM;
|
||||
if (oc4_res != 0) {
|
||||
PackNHWC4ToNHWCFp16(reinterpret_cast<const void *>(tmp_output_block_), reinterpret_cast<void *>(fp16_out_),
|
||||
PackNHWC4ToNHWCFp16(reinterpret_cast<const void *>(tmp_output_block_), reinterpret_cast<void *>(execute_output_),
|
||||
conv_param_->output_batch_, conv_param_->output_h_ * conv_param_->output_w_,
|
||||
conv_param_->output_channel_);
|
||||
}
|
||||
Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
|
||||
ConvolutionBaseFP16CPUKernel::IfCastOutput();
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -19,15 +19,15 @@
|
|||
#include <arm_neon.h>
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/arm/base/convolution_base.h"
|
||||
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ConvolutionSWFP16CPUKernel : public ConvolutionBaseCPUKernel {
|
||||
class ConvolutionSWFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
|
||||
public:
|
||||
ConvolutionSWFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
|
||||
const lite::Primitive *primitive)
|
||||
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
: ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
~ConvolutionSWFP16CPUKernel() override {
|
||||
if (fp16_input_ != nullptr) {
|
||||
free(fp16_input_);
|
||||
|
@ -57,9 +57,6 @@ class ConvolutionSWFP16CPUKernel : public ConvolutionBaseCPUKernel {
|
|||
int ProcessFilter();
|
||||
|
||||
private:
|
||||
float16_t *fp16_input_;
|
||||
float16_t *fp16_weight_;
|
||||
float16_t *fp16_out_;
|
||||
float16_t *packed_weight_;
|
||||
float16_t *tmp_output_block_;
|
||||
SlidingWindowParam *slidingWindow_param_;
|
||||
|
|
|
@ -0,0 +1,409 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h"
|
||||
#include "src/runtime/kernel/arm/fp16/matrix_fp16.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/fp16/common_func.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/fp16/winograd_utils_fp16.h"
|
||||
#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_Conv2D;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
void WinogradFilterTransformFp16(const float16_t *weight_data, Matrix *trans_weight, int kernel_unit, int input_unit,
|
||||
ConvParameter *conv_param, int oc_block) {
|
||||
// original weight format : ohwi
|
||||
auto channel_in = conv_param->input_channel_;
|
||||
auto channel_out = conv_param->output_channel_;
|
||||
int input_unit_square = input_unit * input_unit;
|
||||
|
||||
// generate matrix_G && matrix_GT
|
||||
auto matrix_g = TransformMatrixGenerator(input_unit, kernel_unit);
|
||||
auto matrix_gt = TransformMatrixGenerator(kernel_unit, input_unit);
|
||||
ChooseMatrixG(matrix_g, matrix_gt);
|
||||
auto matrix_g_data = reinterpret_cast<float *>(matrix_g->GetData());
|
||||
auto matrix_gt_data = reinterpret_cast<float *>(matrix_gt->GetData());
|
||||
auto matrix_g_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit * kernel_unit * sizeof(float16_t)));
|
||||
auto matrix_gt_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit * kernel_unit * sizeof(float16_t)));
|
||||
Float32ToFloat16(matrix_g_data, matrix_g_data_fp16, input_unit * kernel_unit);
|
||||
Float32ToFloat16(matrix_gt_data, matrix_gt_data_fp16, input_unit * kernel_unit);
|
||||
|
||||
// trans_filter = G*g*GT (g represents weight_data)
|
||||
// separate into two steps ===> tmp = G*g ===> out = tmp * GT
|
||||
auto tmp_weight_data = reinterpret_cast<float16_t *>(malloc(kernel_unit * kernel_unit * sizeof(float16_t)));
|
||||
auto tmp_data = reinterpret_cast<float16_t *>(malloc(input_unit * kernel_unit * sizeof(float16_t)));
|
||||
auto trans_out_data = reinterpret_cast<float16_t *>(malloc(input_unit * input_unit * sizeof(float16_t)));
|
||||
bool row = true;
|
||||
auto trans_weight_data = reinterpret_cast<float16_t *>(trans_weight->GetData());
|
||||
std::vector<int> strides = trans_weight->GetStride();
|
||||
|
||||
int kernel_plane_stride = channel_in;
|
||||
if (oc_block == 0) {
|
||||
MS_LOG(ERROR) << "Divide by zero";
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < channel_out; i++) {
|
||||
int out_c_block = i / oc_block;
|
||||
int out_c_res = i % oc_block;
|
||||
int input_oz_offset = i * kernel_unit * kernel_unit * channel_in;
|
||||
int output_oz_offset = out_c_block * strides[1] * input_unit * input_unit + out_c_res;
|
||||
for (int j = 0; j < channel_in; j++) {
|
||||
int ic4_block = j / C4NUM;
|
||||
int ic4_res = j % C4NUM;
|
||||
int input_iz_offset = input_oz_offset + j;
|
||||
int output_iz_offset = output_oz_offset + ic4_block * strides[2] + ic4_res * strides[3];
|
||||
for (int k = 0; k < kernel_unit * kernel_unit; k++) {
|
||||
int input_xy_offset = input_iz_offset + k * kernel_plane_stride;
|
||||
tmp_weight_data[k] = *(weight_data + input_xy_offset);
|
||||
}
|
||||
// now we only support row-major matrix-multiply
|
||||
// tmp = G * g
|
||||
MatrixMultiplyFp16(matrix_g_data_fp16, tmp_weight_data, tmp_data, input_unit, kernel_unit, kernel_unit, row);
|
||||
// out = tmp * GT
|
||||
MatrixMultiplyFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit, kernel_unit, input_unit, row);
|
||||
|
||||
for (int z = 0; z < input_unit_square; z++) {
|
||||
int output_xy_offset = output_iz_offset + z * strides[1];
|
||||
*(trans_weight_data + output_xy_offset) = trans_out_data[z];
|
||||
}
|
||||
}
|
||||
}
|
||||
free(tmp_weight_data);
|
||||
free(tmp_data);
|
||||
free(trans_out_data);
|
||||
free(matrix_g_data_fp16);
|
||||
free(matrix_gt_data_fp16);
|
||||
delete matrix_g;
|
||||
delete matrix_gt;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
|
||||
int output_channel = conv_param_->output_channel_;
|
||||
int oc_block, oc_block_num;
|
||||
oc_block = C8NUM;
|
||||
oc_block_num = UP_DIV(output_channel, C8NUM);
|
||||
|
||||
// init weight
|
||||
auto ret = MallocFilterMatrix(oc_block, oc_block_num);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Malloc filter matrix failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
|
||||
WinogradFilterTransformFp16(execute_weight_, trans_weight_, kernel_unit_, input_unit_, conv_param_, oc_block);
|
||||
|
||||
// init bias
|
||||
bias_data_ = malloc(oc_block_num * oc_block * sizeof(float16_t));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc bias_data_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float16_t));
|
||||
auto fp16_bias_data = reinterpret_cast<float16_t *>(bias_data_);
|
||||
if (in_tensors_.size() == kInputSize2) {
|
||||
auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->Data());
|
||||
for (int i = 0; i < output_channel; ++i) {
|
||||
fp16_bias_data[i] = (float16_t)ori_bias[i];
|
||||
}
|
||||
} else {
|
||||
MS_ASSERT(inputs_.size() == kInputSize1);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradFP16CPUKernel::MallocFilterMatrix(int oc_block, int oc_block_num) {
|
||||
int channel_in = conv_param_->input_channel_;
|
||||
int ic4 = UP_DIV(channel_in, BLOCK);
|
||||
|
||||
// set data
|
||||
auto trans_matrix_data_size = input_unit_ * input_unit_ * ic4 * C4NUM * oc_block_num * oc_block * sizeof(float);
|
||||
auto matrix_buffer = malloc(trans_matrix_data_size);
|
||||
if (matrix_buffer == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc matrix_buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(matrix_buffer, 0, trans_matrix_data_size);
|
||||
trans_weight_ = new Matrix();
|
||||
trans_weight_->SetData(matrix_buffer);
|
||||
trans_weight_->SetNDim(5);
|
||||
|
||||
std::vector<int> shapes;
|
||||
std::vector<int> strides;
|
||||
// set shape
|
||||
shapes.push_back(input_unit_ * input_unit_);
|
||||
shapes.push_back(oc_block_num);
|
||||
shapes.push_back(ic4);
|
||||
shapes.push_back(C4NUM);
|
||||
shapes.push_back(oc_block);
|
||||
// set stride
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int stride = 1;
|
||||
for (int j = i + 1; j < 5; j++) {
|
||||
stride *= shapes[j];
|
||||
}
|
||||
strides.push_back(stride);
|
||||
}
|
||||
trans_weight_->SetShape(shapes);
|
||||
trans_weight_->SetStride(strides);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
|
||||
int cal_num = 16;
|
||||
int channel_in = conv_param_->input_channel_;
|
||||
int channel_out = conv_param_->output_channel_;
|
||||
int output_h = conv_param_->output_h_;
|
||||
int output_w = conv_param_->output_w_;
|
||||
int ic4 = UP_DIV(channel_in, C4NUM);
|
||||
int oc8 = UP_DIV(channel_out, C8NUM);
|
||||
|
||||
/*=============================fp16_input_============================*/
|
||||
size_t fp16_input_size = conv_param_->input_channel_ * conv_param_->input_batch_ * conv_param_->input_h_ *
|
||||
conv_param_->input_w_ * sizeof(float16_t);
|
||||
fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size));
|
||||
if (fp16_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc fp16_input_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
/*=============================trans_input_============================*/
|
||||
size_t tile_buffer_size = thread_count_ * cal_num * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float16_t);
|
||||
trans_input_ = reinterpret_cast<float16_t *>(malloc(tile_buffer_size));
|
||||
if (trans_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc trans_input_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(trans_input_, 0, tile_buffer_size);
|
||||
|
||||
/*=============================gemm_out_============================*/
|
||||
gemm_out_ = reinterpret_cast<float16_t *>(
|
||||
malloc(thread_count_ * cal_num * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float16_t)));
|
||||
if (gemm_out_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc gemm_out_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
/*=============================tmp_out_data_============================*/
|
||||
int out_w_block = UP_DIV(output_w, output_unit_);
|
||||
int out_h_block = UP_DIV(output_h, output_unit_);
|
||||
tmp_out_data_ = reinterpret_cast<float16_t *>(malloc(conv_param_->output_batch_ * out_w_block * out_h_block *
|
||||
output_unit_ * output_unit_ * oc8 * C8NUM * sizeof(float16_t)));
|
||||
if (tmp_out_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc tmp_out_data_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
/*=============================fp16_out_============================*/
|
||||
size_t fp16_output_size = conv_param_->output_channel_ * conv_param_->output_batch_ * conv_param_->output_h_ *
|
||||
conv_param_->output_w_ * sizeof(float16_t);
|
||||
fp16_out_ = reinterpret_cast<float16_t *>(malloc(fp16_output_size));
|
||||
if (fp16_out_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc fp16_out_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
/*=============================tmp_data_============================*/
|
||||
tmp_data_ =
|
||||
reinterpret_cast<float16_t *>(malloc(thread_count_ * C4NUM * input_unit_ * input_unit_ * sizeof(float16_t)));
|
||||
if (tmp_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc tmp_data_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(tmp_data_, 0, C4NUM * input_unit_ * input_unit_ * sizeof(float16_t));
|
||||
|
||||
tmp_buffer_address_list_[0] = trans_input_;
|
||||
tmp_buffer_address_list_[1] = gemm_out_;
|
||||
tmp_buffer_address_list_[2] = tmp_out_data_;
|
||||
tmp_buffer_address_list_[3] = tmp_data_;
|
||||
|
||||
/*=============================nhwc4_input_============================*/
|
||||
size_t nhwc4_input_size =
|
||||
ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
|
||||
nhwc4_input_ = malloc(nhwc4_input_size);
|
||||
if (nhwc4_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc nhwc4_input_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(nhwc4_input_, 0, nhwc4_input_size);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradFP16CPUKernel::ConfigInputOutput() {
|
||||
auto input_tensor = in_tensors_.at(kInputIndex);
|
||||
auto ret = CheckLayout(input_tensor);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Check layout failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto output_tensor = out_tensors_.at(kOutputIndex);
|
||||
output_tensor->SetFormat(schema::Format_NHWC);
|
||||
|
||||
// choose input transformer function (4x4 unit or 8x8 unit)
|
||||
input_trans_func_ = GetInputTransFuncFp16(input_unit_);
|
||||
if (input_trans_func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Get input_trans_func failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
output_trans_func_ = GetOutputTransFuncFp16(input_unit_, output_unit_);
|
||||
if (output_trans_func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Get output_trans_func_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradFP16CPUKernel::Init() {
|
||||
auto ret = ConvolutionBaseCPUKernel::Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvolutionBase init failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
kernel_unit_ = conv_param_->kernel_h_;
|
||||
input_unit_ = output_unit_ + kernel_unit_ - 1;
|
||||
conv_param_->input_unit_ = input_unit_;
|
||||
conv_param_->output_unit_ = output_unit_;
|
||||
|
||||
ret = InitWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init weight bias failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// malloc tmp buffer
|
||||
ret = InitTmpBuffer();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init tmp buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
ret = ConfigInputOutput();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConfigInputOutput failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradFP16CPUKernel::ReSize() {
|
||||
if (tmp_data_ != nullptr) {
|
||||
free(tmp_data_);
|
||||
}
|
||||
if (trans_input_ != nullptr) {
|
||||
free(trans_input_);
|
||||
}
|
||||
if (gemm_out_ != nullptr) {
|
||||
free(gemm_out_);
|
||||
}
|
||||
if (tmp_out_data_ != nullptr) {
|
||||
free(tmp_out_data_);
|
||||
}
|
||||
if (nhwc4_input_ != nullptr) {
|
||||
free(nhwc4_input_);
|
||||
}
|
||||
if (fp16_input_ != nullptr) {
|
||||
free(fp16_input_);
|
||||
}
|
||||
if (fp16_out_ != nullptr) {
|
||||
free(fp16_out_);
|
||||
}
|
||||
|
||||
auto ret = ConvolutionBaseCPUKernel::Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvolutionBase init failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
kernel_unit_ = conv_param_->kernel_h_;
|
||||
input_unit_ = output_unit_ + kernel_unit_ - 1;
|
||||
conv_param_->input_unit_ = input_unit_;
|
||||
conv_param_->output_unit_ = output_unit_;
|
||||
|
||||
ret = InitTmpBuffer();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init tmp buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
ret = ConfigInputOutput();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConfigInputOutput failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradFP16CPUKernel::RunImpl(int task_id) {
|
||||
ConvWinogardFp16(reinterpret_cast<float16_t *>(nhwc4_input_), reinterpret_cast<float16_t *>(trans_weight_->GetData()),
|
||||
reinterpret_cast<const float16_t *>(bias_data_), tmp_buffer_address_list_, task_id, conv_param_,
|
||||
input_trans_func_, output_trans_func_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradFp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
|
||||
auto conv = reinterpret_cast<ConvolutionWinogradFP16CPUKernel *>(cdata);
|
||||
auto error_code = conv->RunImpl(task_id);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvolutionWinograd Fp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradFP16CPUKernel::Run() {
|
||||
auto prepare_ret = Prepare();
|
||||
if (prepare_ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
|
||||
return prepare_ret;
|
||||
}
|
||||
|
||||
ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
|
||||
|
||||
int in_batch = conv_param_->input_batch_;
|
||||
int in_h = conv_param_->input_h_;
|
||||
int in_w = conv_param_->input_w_;
|
||||
int in_channel = conv_param_->input_channel_;
|
||||
convert_func_(execute_input_, nhwc4_input_, in_batch, in_h * in_w, in_channel);
|
||||
|
||||
int error_code = LiteBackendParallelLaunch(ConvolutionWinogradFp16Impl, this, thread_count_);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "conv winograd error error_code[" << error_code << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// get real output
|
||||
UnPackWinogradOutputFp16(tmp_out_data_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
|
||||
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
|
||||
int output_num =
|
||||
conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->output_batch_;
|
||||
if (conv_param_->is_relu_) {
|
||||
ReluFp16(execute_output_, execute_output_, output_num);
|
||||
} else if (conv_param_->is_relu6_) {
|
||||
Relu6Fp16(execute_output_, execute_output_, output_num);
|
||||
} else {
|
||||
// do nothing
|
||||
}
|
||||
ConvolutionBaseFP16CPUKernel::IfCastOutput();
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,87 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_WINOGRAD_FP16_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_WINOGRAD_FP16_H_
|
||||
|
||||
#include <arm_neon.h>
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
|
||||
#include "src/runtime/kernel/arm/fp16/matrix_fp16.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/fp16/winograd_utils_fp16.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
|
||||
public:
|
||||
ConvolutionWinogradFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
|
||||
const lite::Primitive *primitive)
|
||||
: ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
~ConvolutionWinogradFP16CPUKernel() override {
|
||||
if (fp16_input_ != nullptr) {
|
||||
free(fp16_input_);
|
||||
}
|
||||
if (fp16_weight_ != nullptr) {
|
||||
free(fp16_weight_);
|
||||
}
|
||||
if (fp16_out_ != nullptr) {
|
||||
free(fp16_out_);
|
||||
}
|
||||
if (tmp_data_ != nullptr) {
|
||||
free(tmp_data_);
|
||||
}
|
||||
if (trans_input_ != nullptr) {
|
||||
free(trans_input_);
|
||||
}
|
||||
if (gemm_out_ != nullptr) {
|
||||
free(gemm_out_);
|
||||
}
|
||||
if (tmp_out_data_ != nullptr) {
|
||||
free(tmp_out_data_);
|
||||
}
|
||||
delete trans_weight_;
|
||||
}
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int RunImpl(int task_id);
|
||||
int InitWeightBias();
|
||||
int MallocFilterMatrix(int oc_block, int oc_block_num);
|
||||
int InitTmpBuffer();
|
||||
int ConfigInputOutput();
|
||||
|
||||
private:
|
||||
int kernel_unit_;
|
||||
int input_unit_;
|
||||
int output_unit_;
|
||||
float16_t *tmp_data_;
|
||||
float16_t *trans_input_;
|
||||
float16_t *gemm_out_;
|
||||
float16_t *tmp_out_data_;
|
||||
Matrix *trans_weight_;
|
||||
InputTransformUnitFp16Func input_trans_func_;
|
||||
OutputTransformUnitFp16Func output_trans_func_;
|
||||
TmpBufferAddressFp16 tmp_buffer_address_list_[4];
|
||||
};
|
||||
void WinogradFilterTransformFp16(const float16_t *weight_data, Matrix *trans_weight, int kernel_unit, int input_unit,
|
||||
ConvParameter *conv_param, int oc_block);
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_WINOGRAD_FP16_H_
|
|
@ -115,10 +115,6 @@ int DeconvolutionDepthwiseFp16CPUKernel::InitWeightBias() {
|
|||
}
|
||||
|
||||
int DeconvolutionDepthwiseFp16CPUKernel::Init() {
|
||||
if (context_->infer_shape_interrupt_ && !context_->running_) {
|
||||
set_need_reinit();
|
||||
return RET_OK;
|
||||
}
|
||||
sliding_ = new SlidingWindowParam;
|
||||
InitSlideParam();
|
||||
// conv base init
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/runtime/kernel/arm/fp16/matrix_fp16.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n,
|
||||
bool row) {
|
||||
// row-major implementation
|
||||
int count = 0;
|
||||
for (int h = 0; h < m; h++) {
|
||||
int h_offset = h * k;
|
||||
for (int w = 0; w < n; w++) {
|
||||
float16_t res = 0;
|
||||
for (int i = 0; i < k; i++) {
|
||||
res += *(matrix_a + h_offset + i) * *(matrix_b + w + i * n);
|
||||
}
|
||||
*(matrix_c + count) = res;
|
||||
count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_MATRIX_FP16_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_MATRIX_FP16_H_
|
||||
|
||||
#include "src/runtime/kernel/arm/base/matrix.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n,
|
||||
bool row);
|
||||
}
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_MATRIX_FP16_H_
|
|
@ -53,10 +53,6 @@ int PoolingFp16CPUKernel::InitBuffer() {
|
|||
}
|
||||
|
||||
int PoolingFp16CPUKernel::Init() {
|
||||
if (context_->infer_shape_interrupt_ && !context_->running_) {
|
||||
set_need_reinit();
|
||||
return RET_OK;
|
||||
}
|
||||
auto ret = PoolingBaseCPUKernel::Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "PoolingBase Init failed.";
|
||||
|
|
|
@ -329,10 +329,9 @@ int ConvolutionWinogradCPUKernel::RunImpl(int task_id) {
|
|||
MS_LOG(ERROR) << "gemm_func is nullptr.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
|
||||
ConvWinogardFp32(reinterpret_cast<float *>(nhwc4_input_), reinterpret_cast<float *>(trans_weight_->GetData()),
|
||||
reinterpret_cast<const float *>(bias_data_), output_addr, tmp_buffer_address_list_, task_id,
|
||||
conv_param_, input_trans_func_, output_trans_func_, gemm_func_);
|
||||
reinterpret_cast<const float *>(bias_data_), tmp_buffer_address_list_, task_id, conv_param_,
|
||||
input_trans_func_, output_trans_func_, gemm_func_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -16,9 +16,7 @@
|
|||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CAST_FP16_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CAST_FP16_H_
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/fp32/cast.h"
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "nnacl/fp16/common_func.h"
|
||||
|
||||
void ReluFp16(float16_t *data, float16_t *dst, int ele_num) {
|
||||
int eight_block = UP_DIV(ele_num, C8NUM);
|
||||
for (int i = 0; i < eight_block - 1; i++) {
|
||||
int index = i * C8NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t relu_data = vld1q_f16(data + index);
|
||||
float16x8_t zero_data = vdupq_n_f16(0);
|
||||
relu_data = vmaxq_f16(relu_data, zero_data);
|
||||
vst1q_f16(dst + index, relu_data);
|
||||
#else
|
||||
data[index] = data[index] < 0 ? 0 : data[index];
|
||||
data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
|
||||
data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2];
|
||||
data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3];
|
||||
#endif
|
||||
}
|
||||
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
|
||||
data[j] = data[j] < 0 ? 0 : data[j];
|
||||
}
|
||||
}
|
||||
|
||||
void Relu6Fp16(float16_t *data, float16_t *dst, int ele_num) {
|
||||
int eight_block = UP_DIV(ele_num, C8NUM);
|
||||
for (int i = 0; i < eight_block - 1; i++) {
|
||||
int index = i * C8NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t relu6_data = vld1q_f16(data + index);
|
||||
float16x8_t zero_data = vdupq_n_f16(0);
|
||||
float16x8_t six_data = vdupq_n_f16(6);
|
||||
relu6_data = vmaxq_f16(relu6_data, zero_data);
|
||||
relu6_data = vminq_f16(relu6_data, six_data);
|
||||
vst1q_f16(dst + index, relu6_data);
|
||||
#else
|
||||
for (int j = 0; j < C8NUM; ++j) {
|
||||
data[index + j] = data[index + j] < 0 ? 0 : data[index + j];
|
||||
data[index + j] = data[index + j] > 6 ? 6 : data[index + j];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
|
||||
data[j] = data[j] < 0 ? 0 : data[j];
|
||||
data[j] = data[j] > 6 ? 6 : data[j];
|
||||
}
|
||||
}
|
|
@ -39,6 +39,8 @@ void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *w
|
|||
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
|
||||
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
|
||||
#endif
|
||||
void ReluFp16(float16_t *data, float16_t *dst, int ele_num);
|
||||
void Relu6Fp16(float16_t *data, float16_t *dst, int ele_num);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -32,12 +32,21 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh
|
|||
#endif
|
||||
#ifndef ENABLE_NEON
|
||||
void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
|
||||
size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC4, size_t relu,
|
||||
size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC8, size_t relu,
|
||||
size_t relu6) {
|
||||
if (!(mode && writeC8)) {
|
||||
IndirectGemmFp16_16x8_common(output, input, weight, bias, step, ic4, output, offset, relu, relu6);
|
||||
} else {
|
||||
IndirectGemmFp16_16x8_c8(output, input, weight, bias, step, ic4, output, offset, mode, writeC8, relu, relu6);
|
||||
}
|
||||
}
|
||||
|
||||
void IndirectGemmFp16_16x8_common(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
|
||||
size_t ic4, size_t oc8, size_t offset, size_t relu, size_t relu6) {
|
||||
const int tile_n = 16;
|
||||
for (int i = 0; i < out_channel; i++) {
|
||||
int oc8_block = i / 8;
|
||||
int oc8_res = i % 8;
|
||||
int oc8_block = i / C8NUM;
|
||||
int oc8_res = i % C8NUM;
|
||||
int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res;
|
||||
for (int k = 0; k < tile_n; k++) {
|
||||
int input_tile_offset = k * C4NUM;
|
||||
|
@ -72,32 +81,32 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh
|
|||
}
|
||||
}
|
||||
|
||||
void IndirectGemmFp16_16x8_tmp(float16_t *output, float16_t *input, float16_t *weight, const float16_t *bias,
|
||||
size_t step, size_t ic4, size_t output_channel, size_t offset, size_t mode,
|
||||
size_t writeC4, size_t relu, size_t relu6) {
|
||||
void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
|
||||
size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC8,
|
||||
size_t relu, size_t relu6) {
|
||||
const int tile_num = 16;
|
||||
if (mode) {
|
||||
if (mode && writeC8) {
|
||||
for (int i = 0; i < tile_num; i++) {
|
||||
int input_tile_offset = i * C4NUM;
|
||||
int output_tile_offset = i * output_channel * 36;
|
||||
int output_tile_offset = i * output_channel * step;
|
||||
for (int j = 0; j < output_channel; j++) {
|
||||
int oc8_block = j / 8;
|
||||
int oc8_res = j % 8;
|
||||
int weight_oc_offset = oc8_block * 36 * ic4 * C4NUM * 8 + oc8_res;
|
||||
int out_oc_offset = output_tile_offset + oc8_block * 36 * C8NUM + oc8_res;
|
||||
int oc8_block = j / C8NUM;
|
||||
int oc8_res = j % C8NUM;
|
||||
int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res;
|
||||
int out_oc_offset = output_tile_offset + oc8_block * step * C8NUM + oc8_res;
|
||||
|
||||
for (int n = 0; n < step; n++) {
|
||||
int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * tile_num;
|
||||
int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * 8;
|
||||
int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C8NUM;
|
||||
int output_kw_offset = out_oc_offset + n * C8NUM;
|
||||
float16_t acc = 0;
|
||||
|
||||
for (int k = 0; k < ic4; k++) {
|
||||
int input_ic4_offset = input_kw_offset + k * tile_num * C4NUM;
|
||||
int weight_ic4_offset = weight_kw_offset + k * C4NUM * 8;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
int weight_ic4_offset = weight_kw_offset + k * C4NUM * C8NUM;
|
||||
for (int m = 0; m < C4NUM; m++) {
|
||||
int input_ic_offset = input_ic4_offset + m;
|
||||
int weight_ic_offset = weight_ic4_offset + m * 8;
|
||||
int weight_ic_offset = weight_ic4_offset + m * C8NUM;
|
||||
acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0];
|
||||
}
|
||||
}
|
||||
|
@ -405,3 +414,91 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fp16 convolution winograd
|
||||
void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data,
|
||||
TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param,
|
||||
InputTransformUnitFp16Func input_trans_func, OutputTransformUnitFp16Func output_trans_func) {
|
||||
int thread_num = conv_param->thread_num_;
|
||||
int input_unit = conv_param->input_unit_;
|
||||
int in_batch = conv_param->input_batch_;
|
||||
int in_channel = conv_param->input_channel_;
|
||||
int ic4 = UP_DIV(in_channel, C4NUM);
|
||||
int out_unit = conv_param->output_unit_;
|
||||
int out_w_block = UP_DIV(conv_param->output_w_, out_unit);
|
||||
int out_h_block = UP_DIV(conv_param->output_h_, out_unit);
|
||||
int tile_num = 16;
|
||||
int output_count = out_w_block * out_h_block;
|
||||
int output_tile_count = UP_DIV(output_count, tile_num);
|
||||
int out_channel = conv_param->output_channel_;
|
||||
int oc8 = UP_DIV(out_channel, C8NUM);
|
||||
int input_unit_square = input_unit * input_unit;
|
||||
size_t output_offset = oc8 * C8NUM * input_unit_square * sizeof(float16_t);
|
||||
|
||||
float16_t *trans_input = buffer_list[0];
|
||||
float16_t *gemm_out = buffer_list[1];
|
||||
float16_t *tmp_out_data = buffer_list[2];
|
||||
float16_t *tmp_data = buffer_list[3];
|
||||
int trans_input_offset = tile_num * input_unit_square * ic4 * C4NUM;
|
||||
int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
|
||||
int tmp_data_offset = input_unit_square * C4NUM;
|
||||
// step 1 : filter transform (pre-processed offline)
|
||||
// step 2 : input transform (online)
|
||||
for (int b = 0; b < in_batch; b++) {
|
||||
int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
|
||||
int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc8 * C8NUM;
|
||||
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
|
||||
int out_tile_index = thread_id * TILE_NUM;
|
||||
int cal_num = output_count - thread_id * tile_num;
|
||||
cal_num = cal_num > tile_num ? tile_num : cal_num;
|
||||
WinogradInputTransformFp16(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
|
||||
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
|
||||
input_trans_func);
|
||||
// step 3 : gemm
|
||||
IndirectGemmFp16_16x8(gemm_out + task_id * gemm_out_offset, trans_input + task_id * trans_input_offset,
|
||||
trans_weight, NULL, input_unit_square, ic4, oc8 * C8NUM, output_offset, 1, 1, 0, 0);
|
||||
|
||||
// step 4 : output transform
|
||||
WinogradOutputTransformFp16(gemm_out + task_id * gemm_out_offset, tmp_out_data + tmp_out_batch_offset, bias_data,
|
||||
cal_num, out_tile_index, out_w_block, conv_param, output_trans_func);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void UnPackWinogradOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel,
|
||||
int output_unit) {
|
||||
int out_h_block_num = UP_DIV(height, output_unit);
|
||||
int out_w_block_num = UP_DIV(width, output_unit);
|
||||
int c8 = UP_DIV(channel, C8NUM);
|
||||
for (int b = 0; b < batch; b++) {
|
||||
int src_batch_offset = b * c8 * C8NUM * out_h_block_num * output_unit * out_w_block_num * output_unit;
|
||||
int dst_batch_offset = b * height * width * channel;
|
||||
for (int h = 0; h < height; h++) {
|
||||
int src_h_offset = src_batch_offset + C8NUM * (h * out_w_block_num * output_unit);
|
||||
int dst_h_offset = dst_batch_offset + h * width * channel;
|
||||
for (int w = 0; w < width; w++) {
|
||||
int src_w_offset = src_h_offset + w * C8NUM;
|
||||
int dst_w_offset = dst_h_offset + w * channel;
|
||||
for (int c = 0; c < c8 - 1; c++) {
|
||||
int src_c8_offset = src_w_offset + c * C8NUM * out_w_block_num * out_h_block_num * output_unit * output_unit;
|
||||
int dst_c8_offset = dst_w_offset + c * C8NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
vst1q_f16(dst + dst_c8_offset, vld1q_f16(src + src_c8_offset));
|
||||
#else
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
dst[dst_c8_offset + i] = src[src_c8_offset + i];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
int c_res = channel - (c8 - 1) * C8NUM;
|
||||
int src_c_res_offset = (c8 - 1) * C8NUM * out_w_block_num * out_h_block_num * output_unit * output_unit;
|
||||
int dst_c_res_offset = (c8 - 1) * C8NUM;
|
||||
for (int c = 0; c < c_res; c++) {
|
||||
int src_c8_res_offset = src_w_offset + src_c_res_offset + c;
|
||||
int dst_c8_res_offset = dst_w_offset + dst_c_res_offset + c;
|
||||
dst[dst_c8_res_offset] = src[src_c8_res_offset];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,11 +18,22 @@
|
|||
|
||||
#include <arm_neon.h>
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/fp16/winograd_utils_fp16.h"
|
||||
#include "nnacl/fp16/winograd_transform_fp16.h"
|
||||
|
||||
typedef float16_t *TmpBufferAddressFp16;
|
||||
|
||||
#ifndef ENABLE_NEON
|
||||
void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
|
||||
size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu,
|
||||
size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC8, size_t relu,
|
||||
size_t relu6);
|
||||
|
||||
void IndirectGemmFp16_16x8_common(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
|
||||
size_t ic4, size_t oc8, size_t offset, size_t relu, size_t relu6);
|
||||
|
||||
void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
|
||||
size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC8, size_t relu,
|
||||
size_t relu6);
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -48,6 +59,14 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
|
|||
void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data,
|
||||
float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out,
|
||||
int task_id, ConvParameter *conv_param);
|
||||
|
||||
// fp16 convolution winograd
|
||||
void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data,
|
||||
TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param,
|
||||
InputTransformUnitFp16Func input_trans_func, OutputTransformUnitFp16Func output_trans_func);
|
||||
|
||||
void UnPackWinogradOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel,
|
||||
int output_unit);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -534,3 +534,95 @@ void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data,
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fp16 common winograd
|
||||
void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num,
|
||||
int out_tile_index, int out_w_block_num, ConvParameter *conv_param,
|
||||
InputTransformUnitFp16Func input_trans_func) {
|
||||
int tile_num = 16;
|
||||
int input_unit = conv_param->input_unit_;
|
||||
int output_unit = conv_param->output_unit_;
|
||||
int in_channel = conv_param->input_channel_;
|
||||
int ic4 = UP_DIV(in_channel, C4NUM);
|
||||
int pad_h = conv_param->pad_h_;
|
||||
int pad_w = conv_param->pad_w_;
|
||||
int input_h = conv_param->input_h_;
|
||||
int input_w = conv_param->input_w_;
|
||||
if (out_w_block_num == 0) {
|
||||
return;
|
||||
}
|
||||
for (int c = 0; c < cal_num; c++) { // actual tiled number
|
||||
int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w;
|
||||
int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h;
|
||||
int interval_x_s = src_x_s > 0 ? 0 : -src_x_s;
|
||||
int interval_y_s = src_y_s > 0 ? 0 : -src_y_s;
|
||||
int src_x_e = src_x_s + input_unit;
|
||||
int src_y_e = src_y_s + input_unit;
|
||||
int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s);
|
||||
int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s);
|
||||
|
||||
int src_plane_offset = ic4 * C4NUM * (src_y_s * input_w + src_x_s);
|
||||
int dst_plane_offset = c * C4NUM;
|
||||
for (int ic = 0; ic < ic4; ic++) {
|
||||
// clear tmp buffer
|
||||
memset(tmp_data, 0, input_unit * input_unit * C4NUM * sizeof(float16_t));
|
||||
|
||||
// get real input block with padding
|
||||
int src_ic4_offset = src_plane_offset + ic * C4NUM;
|
||||
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
|
||||
int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * ic4 * C4NUM;
|
||||
int dst_y_offset = interval * input_unit * C4NUM + interval_x_s * C4NUM;
|
||||
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
|
||||
int src_x_offset = src_y_offset + j * ic4 * C4NUM;
|
||||
int dst_x_offset = dst_y_offset + j * C4NUM;
|
||||
float16_t *src_addr = input_data + src_x_offset;
|
||||
float16_t *dst_addr = tmp_data + dst_x_offset;
|
||||
#ifdef ENABLE_NEON
|
||||
vst1_f16(dst_addr, vld1_f16(src_addr));
|
||||
#else
|
||||
for (int k = 0; k < C4NUM; k++) {
|
||||
dst_addr[k] = src_addr[k];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
// input transform
|
||||
int dst_ic4_offset = dst_plane_offset + ic * tile_num * C4NUM;
|
||||
size_t dst_step = ic4 * C4NUM * tile_num;
|
||||
float16_t *trans_input_ptr = trans_input + dst_ic4_offset;
|
||||
input_trans_func(tmp_data, trans_input_ptr, C4NUM, dst_step);
|
||||
}
|
||||
out_tile_index++;
|
||||
} // cal_tile_num loop
|
||||
}
|
||||
|
||||
void WinogradOutputTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
|
||||
int cal_num, int out_tile_index, int output_unit_num, ConvParameter *conv_param,
|
||||
OutputTransformUnitFp16Func output_trans_func) {
|
||||
int output_unit = conv_param->output_unit_;
|
||||
int output_w = conv_param->output_w_;
|
||||
int output_unit_block = UP_DIV(output_w, output_unit);
|
||||
int output_channel = conv_param->output_channel_;
|
||||
int oc8 = UP_DIV(output_channel, C8NUM);
|
||||
int input_unit = conv_param->input_unit_;
|
||||
if (output_unit_num == 0) {
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < cal_num; i++) {
|
||||
int dst_x_s = out_tile_index % output_unit_num;
|
||||
int dst_y_s = out_tile_index / output_unit_num;
|
||||
int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit;
|
||||
int dst_tile_offset = C4NUM * output_unit * (dst_x_s + dst_y_s * output_unit_block * output_unit);
|
||||
|
||||
for (int j = 0; j < oc8; j++) {
|
||||
int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM;
|
||||
int dst_oc8_offset =
|
||||
dst_tile_offset + j * C8NUM * output_unit_block * output_unit_block * output_unit * output_unit;
|
||||
const float16_t *src_ptr = gemm_out + src_oc8_offset;
|
||||
const float16_t *bias_ptr = bias_data + j * C8NUM;
|
||||
float16_t *dst_ptr = tmp_out_data + dst_oc8_offset;
|
||||
output_trans_func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_unit_block * output_unit);
|
||||
}
|
||||
out_tile_index++;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <string.h>
|
||||
#include "nnacl/fp16/pack_fp16.h"
|
||||
#include "nnacl/fp16/conv_fp16.h"
|
||||
#include "nnacl/fp16/winograd_utils_fp16.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
@ -39,6 +40,15 @@ void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data
|
|||
|
||||
void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data,
|
||||
int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param);
|
||||
|
||||
// fp16 common winograd
|
||||
void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num,
|
||||
int out_tile_index, int out_w_block_num, ConvParameter *conv_param,
|
||||
InputTransformUnitFp16Func input_trans_func);
|
||||
|
||||
void WinogradOutputTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
|
||||
int cal_num, int out_tile_index, int output_unit_num, ConvParameter *conv_param,
|
||||
OutputTransformUnitFp16Func output_trans_func);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_WINOGRAD_UTILS_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_WINOGRAD_UTILS_H_
|
||||
|
||||
#include <arm_neon.h>
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
typedef void (*InputTransformUnitFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
|
||||
typedef void (*OutputTransformUnitFp16Func)(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
|
||||
int src_step, int dst_step);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
|
||||
|
||||
void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
|
||||
|
||||
void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
|
||||
int src_step, int dst_step);
|
||||
|
||||
void OutputTransform4x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
|
||||
int src_step, int dst_step);
|
||||
|
||||
void OutputTransform8x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
|
||||
int src_step, int dst_step);
|
||||
|
||||
void OutputTransform8x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
|
||||
int src_step, int dst_step);
|
||||
|
||||
void OutputTransform8x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
|
||||
int src_step, int dst_step);
|
||||
|
||||
void OutputTransform8x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
|
||||
int src_step, int dst_step);
|
||||
|
||||
void OutputTransform8x6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
|
||||
int src_step, int dst_step);
|
||||
|
||||
void OutputTransform8x7UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
|
||||
int src_step, int dst_step);
|
||||
|
||||
InputTransformUnitFp16Func GetInputTransFuncFp16(int input_unit);
|
||||
|
||||
OutputTransformUnitFp16Func GetOutputTransFuncFp16(int input_unit, int output_unit);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_WINOGRAD_UTILS_H_
|
|
@ -243,10 +243,9 @@ int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output
|
|||
}
|
||||
|
||||
// fp32 conv winograd
|
||||
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data,
|
||||
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param,
|
||||
InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func,
|
||||
GEMM_FUNC_FP32 gemm_func) {
|
||||
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
|
||||
int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func,
|
||||
OutputTransformUnitFunc output_trans_func, GEMM_FUNC_FP32 gemm_func) {
|
||||
int thread_num = conv_param->thread_num_;
|
||||
int input_unit = conv_param->input_unit_;
|
||||
int in_batch = conv_param->input_batch_;
|
||||
|
|
|
@ -57,10 +57,9 @@ int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output
|
|||
StrassenMatMulParameter matmul_param);
|
||||
|
||||
// fp32 convolution winograd
|
||||
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data,
|
||||
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param,
|
||||
InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func,
|
||||
GEMM_FUNC_FP32 gemm_func);
|
||||
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
|
||||
int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func,
|
||||
OutputTransformUnitFunc output_trans_func, GEMM_FUNC_FP32 gemm_func);
|
||||
|
||||
void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, int output_unit);
|
||||
|
||||
|
|
Loading…
Reference in New Issue