forked from mindspore-Ecosystem/mindspore
parent 2d22cbc335
commit f902997ed5
@@ -23,10 +23,6 @@ Matrix *TransformMatrixGenerator(int m, int k) {
  auto aa = malloc(m * k * sizeof(float));
  matrix->SetData(aa);
  matrix->SetNum(m, k);
  // matrix->data_ = malloc(m * k * sizeof(float));
  // matrix->m_ = m;
  // matrix->k_ = k;
  // matrix->row_major_ = true;
  return matrix;
}
@@ -65,26 +65,6 @@ class Matrix {
  int n_dim_;
  bool row_major_;
};

// struct Matrix {
//   void *data_;
//   int *shape_;
//   int *stride_;
//   int m_;
//   int k_;
//   int n_dim_;
//   bool row_major_;
//   ~Matrix() {
//     if (data_ != nullptr) {
//       free(data_);
//     }
//     if (shape_ != nullptr) {
//       free(shape_);
//     }
//     if (shape_ != nullptr) {
//       free(stride_);
//     }
//   }
// };

Matrix *TransformMatrixGenerator(int m, int k);
@@ -16,6 +16,7 @@
#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h"
@@ -265,11 +266,9 @@ int Convolution3x3FP16CPUKernel::Run() {
    return RET_ERROR;
  }
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto input_ele_num = input_tensor->ElementsNum();
  auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
  auto input_element_num = input_tensor->ElementsNum();
  for (int i = 0; i < input_element_num; ++i) {
    fp16_input_[i] = (float16_t)ori_input_data[i];
  }
  Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);

  int in_batch = conv_param_->input_batch_;
  int in_h = conv_param_->input_h_;
@@ -285,12 +284,9 @@ int Convolution3x3FP16CPUKernel::Run() {
  // cast fp16 out to fp32 data
  auto out_tensor = out_tensors_.at(kOutputIndex);
  auto out_ele_num = out_tensor->ElementsNum();
  auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
  auto output_element_num = out_tensor->ElementsNum();

  for (int j = 0; j < output_element_num; ++j) {
    output_addr[j] = static_cast<float>(fp16_out_[j]);
  }
  Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
  return RET_OK;
}
} // namespace mindspore::kernel
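Note on the two hunks above: the hand-written per-element cast loops are replaced by the Float32ToFloat16 and Float16ToFloat32 helpers from cast_fp16.h. As a hedged, minimal sketch of what such element-wise cast helpers reduce to (the real nnacl routines may be NEON-vectorized and may differ in signature details; only the call sites above come from the source):

// Hedged sketch: assumed scalar semantics of the cast helpers used above.
#include <arm_neon.h>  // provides float16_t on ARM targets

void Float32ToFloat16(const float *input, float16_t *output, int number) {
  for (int i = 0; i < number; ++i) {
    output[i] = (float16_t)input[i];  // narrow each fp32 value to fp16
  }
}

void Float16ToFloat32(const float16_t *input, float *output, int number) {
  for (int i = 0; i < number; ++i) {
    output[i] = (float)input[i];  // widen each fp16 value back to fp32
  }
}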
@@ -15,8 +15,11 @@
 */

#include "src/runtime/kernel/arm/fp16/convolution_fp16.h"
#include <vector>
#include "src/runtime/kernel/arm/fp16/convolution_sw_fp16.h"
#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h"
#include "schema/model_generated.h"
@@ -231,10 +234,8 @@ int ConvolutionFP16CPUKernel::Run() {
  }
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
  auto input_element_num = input_tensor->ElementsNum();
  for (int i = 0; i < input_element_num; ++i) {
    fp16_input_[i] = (float16_t)ori_input_data[i];
  }
  auto input_ele_num = input_tensor->ElementsNum();
  Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);

  int in_batch = conv_param_->input_batch_;
  int in_h = conv_param_->input_h_;
@@ -251,10 +252,8 @@ int ConvolutionFP16CPUKernel::Run() {
  // cast fp16 out to fp32 data
  auto out_tensor = out_tensors_.at(kOutputIndex);
  auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
  auto output_element_num = out_tensor->ElementsNum();
  for (int j = 0; j < output_element_num; ++j) {
    output_addr[j] = static_cast<float>(fp16_out_[j]);
  }
  auto out_ele_num = out_tensor->ElementsNum();
  Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
  return RET_OK;
}
@@ -0,0 +1,269 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/runtime/kernel/arm/fp16/convolution_sw_fp16.h"
#include <vector>
#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h"
#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;

namespace mindspore::kernel {
int ConvolutionSWFP16CPUKernel::ProcessFilter() {
  int kernel_h = conv_param_->kernel_h_;
  int kernel_w = conv_param_->kernel_w_;
  int in_channel = conv_param_->input_channel_;
  int out_channel = conv_param_->output_channel_;
  int ic4 = UP_DIV(in_channel, C4NUM);

  auto *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->Data());
  size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t);
  fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
  if (fp16_weight_ == nullptr) {
    MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
    return RET_ERROR;
  }
  // cast origin fp32 weight data to fp16 data
  for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) {
    fp16_weight_[i] = (float16_t)origin_weight[i];
  }

  for (int oc = 0; oc < out_channel; ++oc) {
    int src_oc_offset = oc * kernel_h * kernel_w * in_channel;
    int dst_oc_offset = oc * kernel_h * kernel_w * ic4 * C4NUM;
    for (int i = 0; i < kernel_h * kernel_w; ++i) {
      const float16_t *src = fp16_weight_ + src_oc_offset + i * in_channel;
      float16_t *dst = packed_weight_ + dst_oc_offset + i * ic4 * C4NUM;
      memcpy(dst, src, in_channel * sizeof(float16_t));
    }
  }

  return RET_OK;
}

int ConvolutionSWFP16CPUKernel::InitWeightBias() {
  int kernel_h = conv_param_->kernel_h_;
  int kernel_w = conv_param_->kernel_w_;
  int in_channel = conv_param_->input_channel_;
  int out_channel = conv_param_->output_channel_;
  int oc4 = UP_DIV(out_channel, C4NUM);
  int ic4 = UP_DIV(in_channel, C4NUM);
  int kernel_plane = kernel_h * kernel_w;
  int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane;

  // init weight
  packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t)));
  if (packed_weight_ == nullptr) {
    MS_LOG(ERROR) << "malloc packed_weight_ failed.";
    return RET_ERROR;
  }
  memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t));
  auto ret = ProcessFilter();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Process filter failed.";
    return ret;
  }

  // init bias
  bias_data_ = malloc(oc4 * C4NUM * sizeof(float16_t));
  if (bias_data_ == nullptr) {
    MS_LOG(ERROR) << "malloc bias_data_ failed.";
    return RET_ERROR;
  }
  memset(bias_data_, 0, oc4 * C4NUM * sizeof(float16_t));
  auto fp16_bias_data = reinterpret_cast<float16_t *>(bias_data_);
  if (in_tensors_.size() == kInputSize2) {
    auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->Data());
    for (int i = 0; i < out_channel; ++i) {
      fp16_bias_data[i] = (float16_t)ori_bias[i];
    }
  } else {
    MS_ASSERT(in_tensor_.size() == kInputSize1);
  }
  return RET_OK;
}

int ConvolutionSWFP16CPUKernel::InitTmpBuffer() {
  int in_channel = conv_param_->input_channel_;
  int out_channel = conv_param_->output_channel_;
  int channel_block = UP_DIV(in_channel, C4NUM);
  int oc4 = UP_DIV(out_channel, C4NUM);

  /*=============================fp16_input_============================*/
  size_t fp16_input_size =
    in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t);
  fp16_input_ = reinterpret_cast<float16_t *>(malloc(fp16_input_size));
  if (fp16_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc fp16_input_ failed.";
    return RET_ERROR;
  }

  /*=============================nhwc4_input_============================*/
  size_t nhwc4_input_size = channel_block * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ *
                            conv_param_->input_w_ * sizeof(float16_t);
  nhwc4_input_ = malloc(nhwc4_input_size);
  if (nhwc4_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc nhwc4_input_ failed.";
    return RET_ERROR;
  }
  memset(nhwc4_input_, 0, nhwc4_input_size);

  /*=============================tmp_output_block_============================*/
  tmp_output_block_ = reinterpret_cast<float16_t *>(malloc(conv_param_->output_batch_ * conv_param_->output_h_ *
                                                           conv_param_->output_w_ * oc4 * C4NUM * sizeof(float16_t)));
  if (tmp_output_block_ == nullptr) {
    MS_LOG(ERROR) << "malloc tmp_output_block_ failed.";
    return RET_ERROR;
  }

  /*=============================fp16_out_============================*/
  size_t fp16_output_size =
    out_channel * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * sizeof(float16_t);
  fp16_out_ = reinterpret_cast<float16_t *>(malloc(fp16_output_size));
  if (fp16_out_ == nullptr) {
    MS_LOG(ERROR) << "malloc fp16_out_ failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

void ConvolutionSWFP16CPUKernel::ConfigInputOutput() {
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto input_format = input_tensor->GetFormat();
  schema::Format execute_format = schema::Format_NHWC4;
  convert_func_ = LayoutTransformFp16(input_format, execute_format);
  if (convert_func_ == nullptr) {
    MS_LOG(ERROR) << "layout convert func is nullptr.";
    return;
  }
}

int ConvolutionSWFP16CPUKernel::Init() {
  if (context_->infer_shape_interrupt_ && !context_->running_) {
    set_need_reinit();
    return RET_OK;
  }
  auto ret = ConvolutionBaseCPUKernel::Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConvolutionBase init fail!ret: " << ret;
    return ret;
  }
  ret = InitWeightBias();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init weight bias failed.";
    return RET_ERROR;
  }
  ret = InitTmpBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init tmp buffer failed.";
    return RET_ERROR;
  }
  ConfigInputOutput();

  // init sliding window param
  slidingWindow_param_ = new SlidingWindowParam;
  InitSlidingParamConv(slidingWindow_param_, conv_param_, C4NUM);
  return RET_OK;
}

int ConvolutionSWFP16CPUKernel::ReSize() {
  if (tmp_output_block_ != nullptr) {
    free(tmp_output_block_);
  }
  if (nhwc4_input_ != nullptr) {
    free(nhwc4_input_);
  }
  if (fp16_input_ != nullptr) {
    free(fp16_input_);
  }
  if (fp16_out_ != nullptr) {
    free(fp16_out_);
  }
  delete slidingWindow_param_;

  auto ret = ConvolutionBaseCPUKernel::Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConvolutionBase init failed.";
    return ret;
  }
  ret = InitTmpBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init tmp buffer failed.";
    return RET_ERROR;
  }
  // init sliding window param
  slidingWindow_param_ = new SlidingWindowParam;
  InitSlidingParamConv(slidingWindow_param_, conv_param_, C4NUM);
  return RET_OK;
}

int ConvolutionSWFP16CPUKernel::RunImpl(int task_id) {
  ConvSWFp16(reinterpret_cast<float16_t *>(nhwc4_input_), packed_weight_, reinterpret_cast<float16_t *>(bias_data_),
             tmp_output_block_, fp16_out_, task_id, conv_param_, slidingWindow_param_);
  return RET_OK;
}

int ConvolutionSWFp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
  auto conv = reinterpret_cast<ConvolutionSWFP16CPUKernel *>(cdata);
  auto error_code = conv->RunImpl(task_id);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "ConvolutionFp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

int ConvolutionSWFP16CPUKernel::Run() {
  auto ret = Prepare();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare failed.";
    return RET_ERROR;
  }
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto input_ele_num = input_tensor->ElementsNum();
  auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
  Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);

  int in_batch = conv_param_->input_batch_;
  int in_h = conv_param_->input_h_;
  int in_w = conv_param_->input_w_;
  int in_channel = conv_param_->input_channel_;
  convert_func_(reinterpret_cast<void *>(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel);

  int error_code = LiteBackendParallelLaunch(ConvolutionSWFp16Impl, this, thread_count_);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "conv fp16 error error_code[" << error_code << "]";
    return RET_ERROR;
  }

  // cast fp16 out to fp32 data
  auto out_tensor = out_tensors_.at(kOutputIndex);
  auto out_ele_num = out_tensor->ElementsNum();
  auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
  Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
  return RET_OK;
}
} // namespace mindspore::kernel
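The new kernel above sizes every buffer with channel blocking: ic4 = UP_DIV(in_channel, C4NUM) and oc4 = UP_DIV(out_channel, C4NUM). As a hedged aside, UP_DIV in nnacl is presumably the usual ceiling-division macro; a self-contained sketch of the sizing arithmetic (macro definition and numbers are illustrative, not taken from the diff):

// Hedged sketch: assumed definitions, mirroring how the buffers above are sized.
#include <cstdio>

#define C4NUM 4
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))  // ceiling division (assumed)

int main() {
  int in_channel = 3, out_channel = 10;
  int ic4 = UP_DIV(in_channel, C4NUM);   // 1 -> input padded to 4 channel slots per pixel
  int oc4 = UP_DIV(out_channel, C4NUM);  // 3 -> output padded to 12 channel slots per pixel
  // nhwc4_input_ therefore holds ic4 * C4NUM (= 4) values per pixel rather than in_channel (= 3),
  // and tmp_output_block_ holds oc4 * C4NUM (= 12) values per pixel.
  printf("ic4=%d oc4=%d\n", ic4, oc4);
  return 0;
}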
@@ -0,0 +1,69 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_SW_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_SW_FP16_H_

#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"

namespace mindspore::kernel {
class ConvolutionSWFP16CPUKernel : public ConvolutionBaseCPUKernel {
 public:
  ConvolutionSWFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                             const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
                             const lite::Primitive *primitive)
      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
  ~ConvolutionSWFP16CPUKernel() override {
    if (fp16_input_ != nullptr) {
      free(fp16_input_);
    }
    if (fp16_weight_ != nullptr) {
      free(fp16_weight_);
    }
    if (fp16_out_ != nullptr) {
      free(fp16_out_);
    }
    if (packed_weight_ != nullptr) {
      free(packed_weight_);
    }
    if (tmp_output_block_ != nullptr) {
      free(tmp_output_block_);
    }
    delete slidingWindow_param_;
  }

  int Init() override;
  int ReSize() override;
  int Run() override;
  int RunImpl(int task_id);
  int InitWeightBias();
  int InitTmpBuffer();
  void ConfigInputOutput();
  int ProcessFilter();

 private:
  float16_t *fp16_input_;
  float16_t *fp16_weight_;
  float16_t *fp16_out_;
  float16_t *packed_weight_;
  float16_t *tmp_output_block_;
  SlidingWindowParam *slidingWindow_param_;
};
} // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_SW_FP16_H_
@@ -259,6 +259,21 @@ int Convolution3x3CPUKernel::Run() {
    MS_LOG(ERROR) << "conv3x3 error error_code[" << error_code << "]";
    return RET_ERROR;
  }

  auto is_relu = conv_param_->is_relu_;
  auto is_relu6 = conv_param_->is_relu6_;
  auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
  PackNC4HW4ToNHWCFp32(nc4hw4_out_, output_addr, conv_param_->output_batch_,
                       conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
  int output_num =
    conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->output_batch_;
  if (is_relu) {
    ReluFp32(output_addr, output_addr, output_num);
  } else if (is_relu6) {
    Relu6Fp32(output_addr, output_addr, output_num);
  } else {
    // do nothing
  }
  return RET_OK;
}
} // namespace mindspore::kernel
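In the hunk above, activation is now applied to the unpacked NHWC output in the kernel's Run() via ReluFp32/Relu6Fp32 rather than being fused into the nnacl convolution routine. A hedged sketch of what such element-wise activation helpers presumably do (the nnacl versions are likely NEON-vectorized and the src/dst parameter order is assumed; the call sites above pass the same buffer for both):

// Hedged sketch: assumed scalar semantics of the activation helpers called above.
void ReluFp32(float *dst, const float *src, int ele_num) {
  for (int i = 0; i < ele_num; ++i) {
    dst[i] = src[i] < 0.0f ? 0.0f : src[i];  // clamp negatives to zero
  }
}

void Relu6Fp32(float *dst, const float *src, int ele_num) {
  for (int i = 0; i < ele_num; ++i) {
    float v = src[i] < 0.0f ? 0.0f : src[i];
    dst[i] = v > 6.0f ? 6.0f : v;  // clamp to the [0, 6] range
  }
}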
@@ -189,8 +189,8 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
  /*=============================tmp_out_data_============================*/
  int out_w_block = UP_DIV(output_w, output_unit_);
  int out_h_block = UP_DIV(output_h, output_unit_);
  tmp_out_data_ = reinterpret_cast<float *>(
    malloc(out_w_block * out_h_block * output_unit_ * output_unit_ * oc4 * C4NUM * sizeof(float)));
  tmp_out_data_ = reinterpret_cast<float *>(malloc(conv_param_->output_batch_ * out_w_block * out_h_block *
                                                   output_unit_ * output_unit_ * oc4 * C4NUM * sizeof(float)));
  if (tmp_out_data_ == nullptr) {
    MS_LOG(ERROR) << "malloc tmp_out_data_ failed.";
    return RET_ERROR;
@@ -365,6 +365,22 @@ int ConvolutionWinogradCPUKernel::Run() {
    MS_LOG(ERROR) << "conv winograd error error_code[" << error_code << "]";
    return RET_ERROR;
  }

  // get real output
  auto out_tensor = out_tensors_.front();
  auto out_data = reinterpret_cast<float *>(out_tensor->Data());
  UnPackWinogradOutput(tmp_out_data_, out_data, conv_param_->output_batch_, conv_param_->output_h_,
                       conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
  int output_num =
    conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->output_batch_;
  if (conv_param_->is_relu_) {
    ReluFp32(out_data, out_data, output_num);
  } else if (conv_param_->is_relu6_) {
    Relu6Fp32(out_data, out_data, output_num);
  } else {
    // do nothing
  }

  return RET_OK;
}
} // namespace mindspore::kernel
@@ -18,7 +18,6 @@
#include "nnacl/fp16/pack_fp16.h"
#include "nnacl/fp16/winograd_transform_fp16.h"

#ifdef __cplusplus
extern "C" {
#endif
@@ -112,6 +111,209 @@ void IndirectGemmFp16_16x8_tmp(float16_t *output, float16_t *input, float16_t *w
}
#endif

void SWBorderPixel(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int height,
                   int width, int in_kh_step, int in_kw_step, int kernel_h, int kernel_w, int ic, bool is_relu,
                   bool is_relu6) {
  int ic8 = ic / C8NUM;
  int ic8_res = ic8 % C8NUM;
  int ic4 = ic8_res / C4NUM;
  for (int c = 0; c < C4NUM; c++) {
    dst[c] = 0;
  }
  const float16_t *weight_oc = weight;
  for (int oc = 0; oc < C4NUM; ++oc) {
    const float16_t *weight_kh = weight_oc;
    const float16_t *src_kh = src;
    for (int kh = 0; kh < height; kh++) {
      const float16_t *src_kw = src_kh;
      const float16_t *weight_kw = weight_kh;
      for (int kw = 0; kw < width; kw++) {
        const float16_t *src_ic8 = src_kw;
        const float16_t *weight_ic8 = weight_kw;

        for (int rc = 0; rc < ic8; ++rc) {
          for (int c = 0; c < C8NUM; c++) {
            dst[oc] += src_ic8[c] * weight_ic8[c];
          }
          src_ic8 += C8NUM;
          weight_ic8 += C8NUM;
        }  // ic8 loop

        const float16_t *src_ic4 = src_ic8;
        const float16_t *weight_ic4 = weight_ic8;
        for (int rc = 0; rc < ic4; ++rc) {
          for (int c = 0; c < C4NUM; c++) {
            dst[oc] += src_ic4[c] * weight_ic4[c];
          }
          src_ic4 += C4NUM;
          weight_ic4 += C4NUM;
        }  // ic4 loop

        src_kw += in_kw_step;
        weight_kw += ic4 * C4NUM;
      }  // kernel_w loop
      src_kh += in_kh_step;
      weight_kh += kernel_w * ic4 * C4NUM;
    }  // kernel_h loop
    dst[oc] += bias[oc];
    dst[oc] = (is_relu) ? (MSMAX(0, dst[oc])) : (dst[oc]);
    dst[oc] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[oc]))) : (dst[oc]);
    weight_oc += kernel_h * kernel_w * ic4 * C4NUM;
  }  // oc loop
}

void SWBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top,
                  int bottom, int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
  float16_t *dst_h = dst + top * sliding->out_h_step_;
  for (int oh = top; oh < bottom; oh++) {
    int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
    int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
    int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
    const float16_t *src_h = src + ih * sliding->in_h_step_;

    float16_t *dst_kernel = dst_h + left * sliding->block_channel_;
    for (int ow = left; ow < right; ow++) {
      int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
      int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
      int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
      const float16_t *src_w = src_h + iw * sliding->ic4_channel_;

      const float16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
      const float16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * sliding->ic4_channel_;

      SWBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
                    sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_,
                    sliding->ic4_channel_, conv_param->is_relu_, conv_param->is_relu6_);

      dst_kernel += sliding->block_channel_;
    }  // width loop
    dst_h += sliding->out_h_step_;
  }  // height loop
}

void SWCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int height,
                  int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int ic, int in_sh_step,
                  int in_sw_step, int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) {
  int ic8 = ic / C8NUM;
  int ic8_res = ic % C8NUM;
  int ic4 = ic8_res / C4NUM;
  float16_t *dst_h = dst;
  const float16_t *src_h = src;
  for (int oh = 0; oh < height; oh++) {
    float16_t *dst_w = dst_h;
    const float16_t *src_w = src_h;
    for (int ow = 0; ow < width; ow++) {
      const float16_t *weight_oc = weight;
      for (int c = 0; c < C4NUM; c++) {
        dst_w[c] = 0;
      }

      for (int oc = 0; oc < C4NUM; oc++) {
        const float16_t *weight_kh = weight_oc;
        const float16_t *src_kh = src_w;
        for (int kh = 0; kh < kernel_h; kh++) {
          const float16_t *src_kw = src_kh;
          const float16_t *weight_kw = weight_kh;
          for (int kw = 0; kw < kernel_w; kw++) {
            const float16_t *src_ic8 = src_kw;
            const float16_t *weight_ic8 = weight_kw;

            for (int rc = 0; rc < ic8; ++rc) {
              for (int c = 0; c < C8NUM; c++) {
                dst_w[oc] += src_ic8[c] * weight_ic8[c];
              }

              src_ic8 += C8NUM;
              weight_ic8 += C8NUM;
            }  // ic8 loop

            const float16_t *src_ic4 = src_ic8;
            const float16_t *weight_ic4 = weight_ic8;
            for (int rc = 0; rc < ic4; ++rc) {
              for (int c = 0; c < C4NUM; c++) {
                dst_w[oc] += src_ic4[c] * weight_ic4[c];
              }

              src_ic4 += C4NUM;
              weight_ic4 += C4NUM;
            }  // ic4 loop

            src_kw += in_kw_step;
            weight_kw += ic4 * C4NUM;
          }  // kernel_w loop
          src_kh += in_kh_step;
          weight_kh += kernel_w * ic4 * C4NUM;
        }  // kernel_h loop
        // add bias and relu

        dst_w[oc] += bias[oc];
        dst_w[oc] = (is_relu) ? (MSMAX(0, dst_w[oc])) : (dst_w[oc]);
        dst_w[oc] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[oc]))) : (dst_w[oc]);
        weight_oc += kernel_h * kernel_w * ic4 * C4NUM;
      }  // oc block

      dst_w += block_channel;
      src_w += in_sw_step;
    }  // dst_width loop
    dst_h += out_h_step;
    src_h += in_sh_step;
  }  // dst_height loop
}

// fp16 conv sliding window
void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, const float16_t *bias_data,
                float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param,
                SlidingWindowParam *slidingWindow_param) {
  int oc4_res = conv_param->output_channel_ % C4NUM;
  const float16_t *src = input_data;
  float16_t *dst;
  if (oc4_res == 0) {
    dst = output_data;
  } else {
    dst = tmp_out_block;
  }

  for (int b = 0; b < conv_param->output_batch_; b++) {
    for (int oc = task_id; oc < slidingWindow_param->c_block_; oc += conv_param->thread_num_) {
      const float16_t *src_data = src;
      float16_t *dst_data = dst + oc * C4NUM;
      const float16_t *weight = packed_weight + oc * slidingWindow_param->kernel_step_;
      const float16_t *bias = bias_data + oc * C4NUM;
      SWBorderFp16(dst_data, src_data, weight, bias, 0, slidingWindow_param->top_, 0, conv_param->output_w_, conv_param,
                   slidingWindow_param);
      SWBorderFp16(dst_data, src_data, weight, bias, slidingWindow_param->bottom_, conv_param->output_h_, 0,
                   conv_param->output_w_, conv_param, slidingWindow_param);
      SWBorderFp16(dst_data, src_data, weight, bias, slidingWindow_param->top_, slidingWindow_param->bottom_, 0,
                   slidingWindow_param->left_, conv_param, slidingWindow_param);
      SWBorderFp16(dst_data, src_data, weight, bias, slidingWindow_param->top_, slidingWindow_param->bottom_,
                   slidingWindow_param->right_, conv_param->output_w_, conv_param, slidingWindow_param);

      if (slidingWindow_param->right_ > slidingWindow_param->left_ &&
          slidingWindow_param->bottom_ > slidingWindow_param->top_) {
        int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_h_;
        int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_w_;
        const float16_t *in_t =
          src_data + in_h_start * slidingWindow_param->in_h_step_ + in_w_start * slidingWindow_param->ic4_channel_;
        float16_t *out_t = dst_data + slidingWindow_param->top_ * slidingWindow_param->out_h_step_ +
                           slidingWindow_param->left_ * slidingWindow_param->block_channel_;
        SWCenterFp16(out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
                     slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_,
                     conv_param->kernel_w_, slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_,
                     slidingWindow_param->ic4_channel_, slidingWindow_param->in_sh_step_,
                     slidingWindow_param->in_sw_step_, slidingWindow_param->in_kh_step_,
                     slidingWindow_param->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_);
      }
    }  // output C4 loop
    src += slidingWindow_param->in_step_;
    dst += slidingWindow_param->out_step_;
  }  // batch loop
  // output nhwc4
  if (oc4_res != 0) {
    PackNHWC4ToNHWCFp16((const void *)tmp_out_block, (void *)output_data, conv_param->output_batch_,
                        conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_);
  }
}

// fp16 convolution common (im2col+gemm)
void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data,
              float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param) {
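The sliding-window path added above splits each output plane into four border strips (SWBorderFp16, which clips the kernel window per pixel) and an interior region (SWCenterFp16, where no clipping is needed). As a hedged, self-contained illustration of the per-row clipping arithmetic used in SWBorderFp16 (the UP_DIV definition and all parameter values are illustrative only):

// Hedged sketch of the start_kh/end_kh clipping done in SWBorderFp16, with made-up parameters.
#include <algorithm>
#include <cstdio>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

int main() {
  // Assumed example: 3x3 kernel, stride 1, pad 1, dilation 1, input height 8.
  int kernel_h = 3, stride_h = 1, pad_h = 1, dilation_h = 1, input_h = 8;
  for (int oh = 0; oh < 3; ++oh) {  // first few output rows (the top border)
    int ih = oh * stride_h - pad_h;                                    // top input row of the window
    int start_kh = std::max(0, UP_DIV(-ih, dilation_h));               // skip taps that fall above the input
    int end_kh = std::min(kernel_h, UP_DIV(input_h - ih, dilation_h)); // stop before taps below the input
    // oh=0 uses kernel rows [1, 3); oh>=1 uses the full [0, 3) window.
    printf("oh=%d -> kernel rows [%d, %d)\n", oh, start_kh, end_kh);
  }
  return 0;
}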
@@ -144,7 +346,7 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
  // we write 32 bytes per st1 instruction, after which the pointer in register will step 32B forward

  for (int b = 0; b < in_batch; b++) {
    int in_batch_offset = b * in_channel * in_h * in_w;
    int in_batch_offset = b * ic4 * C4NUM * in_h * in_w;
    int out_batch_offset = b * out_channel * out_h * out_w;
    int gemm_in_batch_offset = b * packed_input_size;
    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
@@ -172,7 +374,6 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data,
                 float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out,
                 int task_id, ConvParameter *conv_param) {
  // todo
  int thread_count = conv_param->thread_num_;
  int tile_num = 16;
  const int output_unit = 4;
@@ -195,6 +396,8 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
  int input_batch = conv_param->input_batch_;
  for (int batch = 0; batch < input_batch; batch++) {
    int in_batch_offset = batch * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
    int tmp_out_batch_offset = batch * oc8 * C8NUM * out_w_block * out_h_block * output_unit * output_unit;
    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
      int start_index = thread_id * tile_num;
      int real_cal_num = (output_count - start_index) < tile_num ? (output_count - start_index) : tile_num;
@@ -207,8 +410,8 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
                            tile_buffer + task_id * tile_buffer_offset, transed_weight, NULL, 36, ic4, oc8 * C8NUM,
                            oc8 * C8NUM * 36 * sizeof(float16_t), 1, 1, 0, 0);

      Conv3x3Fp16OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out, bias_data, start_index,
                                 real_cal_num, out_w_block, conv_param);
      Conv3x3Fp16OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset,
                                 bias_data, start_index, real_cal_num, out_w_block, conv_param);
    }
  }

@@ -217,7 +420,10 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
  bool relu = conv_param->is_relu_;
  bool relu6 = conv_param->is_relu6_;
  for (int batch = 0; batch < output_batch; batch++) {
    int batch_size = batch * output_channel * output_h * output_w;
    int tmp_out_batch_offset = batch * oc8 * C8NUM * out_w_block * out_h_block * output_unit * output_unit;
    int ro_batch_size = batch * output_channel * output_h * output_w;
    const float16_t *batch_tmp_out = tmp_out + tmp_out_batch_offset;
    float16_t *batch_out = output_data + ro_batch_size;
    for (int h = 0; h < output_h; h++) {
      for (int w = 0; w < output_w; w++) {
        for (int c = 0; c < output_channel; c++) {
@@ -226,12 +432,12 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
          int src_offset = oc8_block * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM +
                           C8NUM * (h * out_w_block * output_unit + w) + oc8_res;
          int dst_offset = (h * output_w + w) * output_channel + c;
          (output_data + dst_offset)[0] = (tmp_out + src_offset)[0];
          (batch_out + dst_offset)[0] = (batch_tmp_out + src_offset)[0];
          if (relu) {
            (output_data + dst_offset)[0] = (output_data + dst_offset)[0] < 0 ? 0 : (output_data + dst_offset)[0];
            (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] < 0 ? 0 : (batch_out + dst_offset)[0];
          } else if (relu6) {
            (output_data + dst_offset)[0] = (output_data + dst_offset)[0] < 0 ? 0 : (output_data + dst_offset)[0];
            (output_data + dst_offset)[0] = (output_data + dst_offset)[0] > 6 ? 6 : (output_data + dst_offset)[0];
            (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] < 0 ? 0 : (batch_out + dst_offset)[0];
            (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] > 6 ? 6 : (batch_out + dst_offset)[0];
          }
        }
      }
@@ -28,6 +28,18 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh
#ifdef __cplusplus
extern "C" {
#endif
void SWBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top,
                  int bottom, int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding);

void SWCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int height,
                  int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int ic, int in_sh_step,
                  int in_sw_step, int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6);

// fp16 sliding window
void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, const float16_t *bias_data,
                float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param,
                SlidingWindowParam *slidingWindow_param);

// fp16 convolution common (im2col+gemm)
void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data,
              float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param);
@@ -219,6 +219,24 @@ void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int c
  }
}

void PackNHWC4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  int ic_remainder_ = channel % C4NUM;
  if (ic_remainder_ != 0) {
    int nhwc_batch_unit_offset = channel * plane;
    for (int b = 0; b < batch; b++) {
      int batch_offset = b * c4 * C4NUM * plane;
      for (int i = 0; i < plane; i++) {
        memcpy((float16_t *)dst + b * nhwc_batch_unit_offset + i * channel,
               (float16_t *)src + batch_offset + i * c4 * C4NUM, channel * sizeof(float16_t));
      }
    }
  } else {
    size_t ori_input_size = batch * plane * channel * sizeof(float16_t);
    memcpy((float16_t *)dst, (float16_t *)src, ori_input_size);
  }
}

void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) {
  int nhwc4_batch_offset = 0;
  int ic4 = UP_DIV(channel, C4NUM);
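PackNHWC4ToNHWCFp16 above drops the per-pixel channel padding: when channel is not a multiple of C4NUM it copies only the first channel values out of each c4 * C4NUM slot, otherwise it falls back to one bulk memcpy. A hedged walk-through of the copy pattern with illustrative numbers:

// Hedged illustration of the copy offsets in PackNHWC4ToNHWCFp16 (not the library code).
#include <cstdio>

#define C4NUM 4
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

int main() {
  int batch = 1, plane = 2, channel = 3;  // e.g. a 1x2 image with 3 channels
  int c4 = UP_DIV(channel, C4NUM);        // 1 block -> 4 padded channel slots per pixel
  for (int b = 0; b < batch; b++) {
    for (int i = 0; i < plane; i++) {
      int src_idx = b * c4 * C4NUM * plane + i * c4 * C4NUM;  // padded source offset
      int dst_idx = b * channel * plane + i * channel;        // tight destination offset
      // Pixel 0: src[0..2] -> dst[0..2]; pixel 1: src[4..6] -> dst[3..5]; padding slots src[3], src[7] are skipped.
      printf("pixel %d: copy %d fp16 values from src[%d] to dst[%d]\n", i, channel, src_idx, dst_idx);
    }
  }
  return 0;
}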
@@ -41,6 +41,8 @@ void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int

void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);

void PackNHWC4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel);

void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);

void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);
@@ -217,7 +217,7 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
  size_t output_offset = out_channel * sizeof(float);

  for (int b = 0; b < in_batch; b++) {
    int in_batch_offset = b * in_channel * in_h * in_w;
    int in_batch_offset = b * ic4 * C4NUM * in_h * in_w;
    int out_batch_offset = b * out_channel * out_h * out_w;
    int gemm_in_batch_offset = b * packed_input_size;
    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
@@ -263,12 +263,9 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
  int output_count = out_w_block * out_h_block;
  int output_tile_count = UP_DIV(output_count, TILE_NUM);
  int out_channel = conv_param->output_channel_;
  int out_batch = conv_param->output_batch_;
  int oc4 = UP_DIV(out_channel, C4NUM);
  int input_unit_square = input_unit * input_unit;
  size_t output_offset = oc4 * C4NUM * input_unit_square * sizeof(float);
  bool is_relu = conv_param->is_relu_;
  bool is_relu6 = conv_param->is_relu6_;

  float *trans_input = buffer_list[0];
  float *gemm_out = buffer_list[1];
@@ -280,11 +277,13 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
  // step 1 : filter transform (pre-processed offline)
  // step 2 : input transform (online)
  for (int b = 0; b < in_batch; b++) {
    int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
    int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc4 * C4NUM;
    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
      int out_tile_index = thread_id * TILE_NUM;
      int cal_num = output_count - thread_id * TILE_NUM;
      cal_num = cal_num > TILE_NUM ? TILE_NUM : cal_num;
      WinogradInputTransform(input_data, trans_input + task_id * trans_input_offset,
      WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
                             tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
                             input_trans_func);
      // step 3 : gemm
@@ -292,21 +291,10 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
                          input_unit_square, ic4, oc4 * C4NUM, output_offset, 1, 1, 0, 0);

      // step 4 : output transform
      WinogradOutputTransform(gemm_out + task_id * gemm_out_offset, tmp_out_data, bias_data, cal_num, out_tile_index,
                              out_w_block, conv_param, output_trans_func);
      WinogradOutputTransform(gemm_out + task_id * gemm_out_offset, tmp_out_data + tmp_out_batch_offset, bias_data,
                              cal_num, out_tile_index, out_w_block, conv_param, output_trans_func);
    }
  }
  // get real output
  UnPackWinogradOutput(tmp_out_data, output_data, out_batch, conv_param->output_h_, conv_param->output_w_, out_channel,
                       out_unit);
  int output_num = out_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_;
  if (is_relu) {
    ReluFp32(output_data, output_data, output_num);
  } else if (is_relu6) {
    Relu6Fp32(output_data, output_data, output_num);
  } else {
    // do nothing
  }
}

void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel,
@@ -360,8 +348,6 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
  int output_count = out_w_block * out_h_block;
  int output_tile_count = UP_DIV(output_count, TILE_NUM);
  int input_unit_square = 4 * 4;
  bool is_relu = conv_param->is_relu_;
  bool is_relu6 = conv_param->is_relu6_;
  float *tile_buffer = buffer_list[0];
  float *block_unit_buffer = buffer_list[1];
  float *tmp_dst_buffer = buffer_list[2];
@@ -372,10 +358,13 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
  int input_batch = conv_param->input_batch_;
  for (int batch = 0; batch < input_batch; batch++) {
    int in_batch_offset = batch * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
    int nc4hw4_buffer_offset = batch * oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_;

    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
      int start_index = thread_id * TILE_NUM;
      int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
      Conv3x3Fp32InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset,
      Conv3x3Fp32InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
                                block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
                                out_w_block, conv_param);
@@ -383,17 +372,8 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
                          transed_weight, NULL, input_unit_square, ic4, oc4 * C4NUM,
                          oc4 * C4NUM * input_unit_square * sizeof(float), 1, 1, 0, 0);

      Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out, bias_data, start_index,
                                 real_cal_num, out_w_block, conv_param);
      Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out + nc4hw4_buffer_offset,
                                 bias_data, start_index, real_cal_num, out_w_block, conv_param);
    }
    PackNC4HW4ToNHWCFp32(nc4hw4_out, output_data, 1, conv_param->output_h_ * conv_param->output_w_, output_channel);
  }
  int output_num = output_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_;
  if (is_relu) {
    ReluFp32(output_data, output_data, output_num);
  } else if (is_relu6) {
    Relu6Fp32(output_data, output_data, output_num);
  } else {
    // do nothing
  }
}
@@ -1,6 +1,6 @@
hiai_model_0909_kd_rot_ps_softmax.tflite
hiai_chinese_english_recognize_model_float32.tflite
#hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite
hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite
hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite
hiai_cn_recognize_modify_padv2.tflite
hiai_model_normalize_object_scene_ps_20200519.tflite
|
Loading…
Reference in New Issue