!7606 [MS][LITE][CPU] add group conv

Merge pull request !7606 from fuzhiye/tmp
mindspore-ci-bot 2020-10-22 14:23:24 +08:00 committed by Gitee
commit 038a41d973
6 changed files with 335 additions and 29 deletions

File: src/runtime/kernel/arm/base/convolution_base.cc

@@ -34,10 +34,6 @@ ConvolutionBaseCPUKernel::~ConvolutionBaseCPUKernel() {
free(bias_data_);
bias_data_ = nullptr;
}
if (nhwc4_input_ != nullptr) {
free(nhwc4_input_);
nhwc4_input_ = nullptr;
}
}
void ConvolutionBaseCPUKernel::FreeQuantParam() {
@@ -112,18 +108,6 @@ int ConvolutionBaseCPUKernel::CheckResizeValid() {
return RET_OK;
}
int ConvolutionBaseCPUKernel::CheckLayout(lite::Tensor *input_tensor) {
auto data_type = input_tensor->data_type();
auto input_format = input_tensor->GetFormat();
schema::Format execute_format = schema::Format::Format_NHWC4;
convert_func_ = LayoutTransform(data_type, input_format, execute_format);
if (convert_func_ == nullptr) {
MS_LOG(ERROR) << "layout convert func is nullptr.";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionBaseCPUKernel::SetIfPerChannel() {
auto filter_tensor = in_tensors_.at(kWeightIndex);
auto input_channel = filter_tensor->Channel();

File: src/runtime/kernel/arm/base/convolution_base.h

@@ -48,7 +48,6 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
int Init() override;
int ReSize() override { return 0; }
int Run() override { return 0; }
virtual int CheckLayout(lite::Tensor *input_tensor);
int SetIfAsymmetric();
int SetIfPerChannel();
int MallocQuantParam();
@@ -61,14 +60,12 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
void FreeQuantParam();
protected:
int tile_num_;
void *bias_data_ = nullptr;
void *nhwc4_input_ = nullptr;
const InnerContext *ctx_;
int thread_count_;
ConvParameter *conv_param_;
ConvQuantArg *conv_quant_arg_;
LayoutConvertor convert_func_ = nullptr;
int tile_num_;
int thread_count_;
};
} // namespace mindspore::kernel

File: src/runtime/kernel/arm/fp16/convolution_fp16.cc

@@ -61,6 +61,10 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
}
memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t));
RowMajor2Col8MajorFp16(execute_weight_, packed_weight_, out_channel, in_channel * kernel_plane, false);
if (fp16_weight_ != nullptr) {
free(fp16_weight_);
fp16_weight_ = nullptr;
}
// init bias
bias_data_ = malloc(oc8 * C8NUM * sizeof(float16_t));

File: src/runtime/kernel/arm/fp32/convolution.cc

@@ -17,6 +17,7 @@
#include "src/runtime/kernel/arm/fp32/convolution.h"
#include "src/runtime/kernel/arm/fp32/convolution_1x1.h"
#include "src/runtime/kernel/arm/fp32/convolution_winograd.h"
#include "src/runtime/kernel/arm/fp32/group_convolution.h"
#include "nnacl/fp32/conv.h"
#include "nnacl/common_func.h"
#include "schema/model_generated.h"
@@ -31,6 +32,7 @@ using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;
using mindspore::schema::Format::Format_NHWC;
namespace mindspore::kernel {
int ConvolutionCPUKernel::InitWeightBias() {
@@ -157,6 +159,108 @@ int ConvolutionCPUKernel::Run() {
return RET_OK;
}
kernel::LiteKernel *CpuConvFp32KernelSelect(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive,
bool use_winograd, int out_unit) {
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) {
return new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
} else if (use_winograd) {
return new (std::nothrow)
kernel::ConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, ctx, primitive, out_unit);
} else {
return new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
}
}
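
For illustration, not part of this commit: every branch of CpuConvFp32KernelSelect uses new (std::nothrow), which yields nullptr on allocation failure, and the group creator below emplaces the result without a check. A defensive call-site sketch, with Kernel, MakeKernel, and AppendChecked as hypothetical stand-ins:

#include <iostream>
#include <new>
#include <vector>

struct Kernel {};  // stand-in for kernel::LiteKernel

// stand-in for CpuConvFp32KernelSelect; new (std::nothrow) returns nullptr on failure
Kernel *MakeKernel() { return new (std::nothrow) Kernel(); }

bool AppendChecked(std::vector<Kernel *> *group_convs) {
  Kernel *sub_kernel = MakeKernel();
  if (sub_kernel == nullptr) {
    std::cerr << "create sub conv kernel failed" << std::endl;  // stand-in for MS_LOG(ERROR)
    return false;  // the caller should also free the tensors already created for this group
  }
  group_convs->push_back(sub_kernel);
  return true;
}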
kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive,
int group) {
std::vector<kernel::LiteKernel *> group_convs;
std::vector<int> in_shape;
std::vector<int> filter_shape;
std::vector<int> bias_shape;
std::vector<int> out_shape;
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
int out_channel = inputs.at(kWeightIndex)->Batch();
int new_in_channel = inputs.at(kWeightIndex)->Channel();
int new_out_channel = out_channel / group;
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int input_num = inputs.size();
int output_num = outputs.size();
bool has_bias = input_num == 3;
bool use_winograd = false;
int out_unit;
if (primitive != nullptr && primitive->GetInferFlag()) {
int batch = inputs.front()->Batch();
int in_h = inputs.front()->Height();
int in_w = inputs.front()->Width();
conv_param->input_channel_ = new_in_channel;
conv_param->output_channel_ = new_out_channel;
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param);
in_shape = {batch, in_h, in_w, new_in_channel};
out_shape = {batch, conv_param->output_h_, conv_param->output_w_, new_out_channel};
}
filter_shape = {new_out_channel, kernel_h, kernel_w, new_in_channel};
bias_shape = {new_out_channel};
auto *origin_weight = reinterpret_cast<float *>(inputs.at(kWeightIndex)->data_c());
auto *origin_bias = has_bias ? reinterpret_cast<float *>(inputs.at(kBiasIndex)->data_c()) : nullptr;
for (int i = 0; i < group; ++i) {
std::vector<lite::Tensor *> new_inputs;
std::vector<lite::Tensor *> new_outputs;
// get new input for each group
auto in_tensor =
new (std::nothrow) lite::Tensor(inputs.front()->data_type(), in_shape, Format_NHWC, lite::Tensor::Category::VAR);
if (primitive != nullptr && primitive->GetInferFlag()) {
in_tensor->MallocData();
}
new_inputs.emplace_back(in_tensor);
// new weight
auto filter_tensor = new (std::nothrow)
lite::Tensor(inputs.at(kWeightIndex)->data_type(), filter_shape, Format_NHWC, lite::Tensor::Category::CONST);
filter_tensor->MallocData();
int copy_length = kernel_h * kernel_w * new_in_channel * new_out_channel;
memcpy(filter_tensor->data_c(), origin_weight + i * copy_length, copy_length * sizeof(float));
new_inputs.emplace_back(filter_tensor);
// if there is a bias, set the new bias for this group
if (has_bias) {
auto bias_tensor = new (std::nothrow)
lite::Tensor(inputs.at(kBiasIndex)->data_type(), bias_shape, Format_NHWC, lite::Tensor::Category::CONST);
bias_tensor->MallocData();
memcpy(bias_tensor->data_c(), origin_bias + i * new_out_channel, new_out_channel * sizeof(float));
new_inputs.emplace_back(bias_tensor);
}
// set new output tensor
for (int j = 0; j < output_num; ++j) {
auto tmp_out_tensor = new (std::nothrow) lite::Tensor();
tmp_out_tensor->set_data_type(outputs.at(j)->data_type());
tmp_out_tensor->SetFormat(outputs.at(j)->GetFormat());
if (primitive != nullptr && primitive->GetInferFlag()) {
tmp_out_tensor->set_shape(out_shape);
tmp_out_tensor->MallocData();
}
new_outputs.emplace_back(tmp_out_tensor);
}
group_convs.emplace_back(
CpuConvFp32KernelSelect(new_inputs, new_outputs, op_parameter, ctx, primitive, use_winograd, out_unit));
}
// sub kernels and group conv kernel share the same op_parameter struct
return new (std::nothrow)
GroupConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive, group_convs, group);
}
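
For illustration, not part of this commit: the memcpy above relies on the weight layout [out_channel, kernel_h, kernel_w, in_channel_per_group], where groups are contiguous along the output-channel axis, so each group's filters form one contiguous block. A minimal sketch of the same arithmetic, with SliceGroupWeight as a hypothetical helper:

#include <cstring>
#include <vector>

// Copy group g's filters out of a weight buffer laid out as
// [out_channel, kernel_h, kernel_w, in_channel_per_group]; the slice is a
// single contiguous block, mirroring copy_length in the creator above.
std::vector<float> SliceGroupWeight(const float *origin_weight, int out_channel,
                                    int kernel_h, int kernel_w,
                                    int in_channel_per_group, int group, int g) {
  const int new_out_channel = out_channel / group;
  const int copy_length = kernel_h * kernel_w * in_channel_per_group * new_out_channel;
  std::vector<float> sub_weight(copy_length);
  std::memcpy(sub_weight.data(), origin_weight + g * copy_length, copy_length * sizeof(float));
  return sub_weight;
}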
kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
const InnerContext *ctx, const kernel::KernelKey &desc,
@@ -164,8 +268,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
MS_ASSERT(op_parameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int group = conv_param->group_;
bool use_winograd = false;
int out_unit;
if (primitive != nullptr && primitive->GetInferFlag()) {
@@ -192,14 +295,12 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
}
kernel::LiteKernel *kernel;
if (kernel_h == 1 && kernel_w == 1) {
kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
} else if (use_winograd) {
kernel =
new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, ctx, primitive, out_unit);
if (group == 1) {
kernel = CpuConvFp32KernelSelect(inputs, outputs, op_parameter, ctx, primitive, use_winograd, out_unit);
} else {
kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
kernel = CpuGroupConvFp32KernelCreator(inputs, outputs, op_parameter, ctx, primitive, group);
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {

File: src/runtime/kernel/arm/fp32/group_convolution.cc (new file)

@@ -0,0 +1,150 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/group_convolution.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;
namespace mindspore::kernel {
int GroupConvolutionCPUKernel::Init() {
for (int i = 0; i < group_num_; ++i) {
auto ret = group_convs_[i]->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Sub kernel init failed.";
return ret;
}
}
// if infer shape is done, the resize func will be invoked in the sub-kernels
return RET_OK;
}
int GroupConvolutionCPUKernel::ReSize() {
for (int i = 0; i < group_num_; ++i) {
auto ret = group_convs_[i]->ReSize();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Sub kernel resize failed.";
return RET_ERROR;
}
}
conv_param_->input_channel_ /= group_num_;
conv_param_->output_channel_ /= group_num_;
return RET_OK;
}
int GroupConvolutionCPUKernel::PreProcess() {
if (!InferShapeDone()) {
auto ret = (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_);
if (ret != 0) {
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(false);
MS_LOG(ERROR) << "InferShape fail!";
return ret;
}
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(true);
ret = ReSize();
if (ret != 0) {
MS_LOG(ERROR) << "ReSize fail!ret: " << ret;
return ret;
}
// if the infershape func is called at runtime stage, we should malloc memory and set shape info for the input
// and output tensors of the sub-kernels here.
std::vector<int> in_shape;
std::vector<int> out_shape;
for (int i = 0; i < group_num_; ++i) {
// in
int in_batch = conv_param_->input_batch_;
int in_h = conv_param_->input_h_;
int in_w = conv_param_->input_w_;
int in_c = conv_param_->input_channel_;
in_shape = {in_batch, in_h, in_w, in_c};
auto sub_kernel_in_tensor = group_convs_[i]->in_tensors().front();
sub_kernel_in_tensor->set_shape(in_shape);
sub_kernel_in_tensor->MallocData();
// out
int out_batch = conv_param_->output_batch_;
int out_h = conv_param_->output_h_;
int out_w = conv_param_->output_w_;
int out_c = conv_param_->output_channel_;
out_shape = {out_batch, out_h, out_w, out_c};
auto sub_kernel_out_tensors = group_convs_[i]->out_tensors();
for (auto tensor : sub_kernel_out_tensors) {
tensor->set_shape(out_shape);
tensor->MallocData();
}
}
}
auto outputs = this->out_tensors();
for (auto *output : outputs) {
MS_ASSERT(output != nullptr);
output->MallocData();
}
return RET_OK;
}
void GroupConvolutionCPUKernel::SeparateInput(int group_id) {
int in_h = conv_param_->input_h_;
int in_w = conv_param_->input_w_;
int in_plane = in_h * in_w;
int sub_in_channel = conv_param_->input_channel_;
int ori_in_channel = sub_in_channel * group_num_;
auto sub_in_data = reinterpret_cast<float *>(group_convs_[group_id]->in_tensors().front()->data_c());
float *src_ptr = ori_in_data_ + group_id * sub_in_channel;
float *dst_ptr = sub_in_data;
for (int i = 0; i < in_plane; ++i) {
memcpy(dst_ptr, src_ptr, sub_in_channel * sizeof(float));
src_ptr += ori_in_channel;
dst_ptr += sub_in_channel;
}
}
void GroupConvolutionCPUKernel::PostConcat(int group_id) {
int out_h = conv_param_->output_h_;
int out_w = conv_param_->output_w_;
int out_plane = out_h * out_w;
int sub_out_channel = conv_param_->output_channel_;
int ori_out_channel = sub_out_channel * group_num_;
auto sub_out_data = reinterpret_cast<float *>(group_convs_[group_id]->out_tensors().front()->data_c());
float *src_ptr = sub_out_data;
float *dst_ptr = ori_out_data_ + group_id * sub_out_channel;
for (int i = 0; i < out_plane; ++i) {
memcpy(dst_ptr, src_ptr, sub_out_channel * sizeof(float));
src_ptr += sub_out_channel;
dst_ptr += ori_out_channel;
}
}
int GroupConvolutionCPUKernel::Run() {
ori_in_data_ = reinterpret_cast<float *>(in_tensors().front()->data_c());
ori_out_data_ = reinterpret_cast<float *>(out_tensors().front()->data_c());
for (int i = 0; i < group_num_; ++i) {
// first, separate the group conv input into per-group slices; this step must happen at runtime
SeparateInput(i);
// run the sub-kernel for this group
group_convs_[i]->Run();
// post process: concat the outputs of all sub-kernels into the single group conv output
PostConcat(i);
}
return RET_OK;
}
} // namespace mindspore::kernel
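
For illustration, not part of this commit: SeparateInput and PostConcat above are a strided channel split and its inverse over an NHWC plane. A self-contained sketch of the same stride pattern, with SplitChannels and ConcatChannels as hypothetical names:

#include <cstring>

// Gather group g's channels from an NHWC buffer: at each of the plane = H * W
// positions, copy sub_c contiguous channels starting at offset g * sub_c
// within the original ori_c = sub_c * group channels.
void SplitChannels(const float *src, float *dst, int plane, int sub_c, int group, int g) {
  const int ori_c = sub_c * group;
  const float *src_ptr = src + g * sub_c;
  for (int i = 0; i < plane; ++i) {
    std::memcpy(dst, src_ptr, sub_c * sizeof(float));
    src_ptr += ori_c;  // jump to the same channel slice at the next pixel
    dst += sub_c;
  }
}

// Inverse: scatter a group's dense output back into the shared NHWC output.
void ConcatChannels(const float *src, float *dst, int plane, int sub_c, int group, int g) {
  const int ori_c = sub_c * group;
  float *dst_ptr = dst + g * sub_c;
  for (int i = 0; i < plane; ++i) {
    std::memcpy(dst_ptr, src, sub_c * sizeof(float));
    src += sub_c;
    dst_ptr += ori_c;
  }
}

Splitting all groups and then concatenating them back round-trips the buffer, which makes a convenient unit-test invariant for this pair.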

File: src/runtime/kernel/arm/fp32/group_convolution.h (new file)

@@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GROUP_CONVOLUTION_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GROUP_CONVOLUTION_H_
#include <utility>
#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/op_base.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "nnacl/fp32/conv.h"
namespace mindspore::kernel {
class GroupConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
public:
GroupConvolutionCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive, std::vector<kernel::LiteKernel *> group_convs,
const int group_num)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive),
group_convs_(std::move(group_convs)),
group_num_(group_num) {} // The in/out channel fields of opParameter have been divided by the group count
// in this kernel; to recover the real values, multiply them by group_num.
~GroupConvolutionCPUKernel() override {
for (auto sub_conv : group_convs_) {
// free sub conv input tensors / output tensors manually
auto sub_in_tensors = sub_conv->in_tensors();
auto sub_in_tensor_num = sub_in_tensors.size();
for (size_t i = 0; i < sub_in_tensor_num; ++i) {
delete sub_in_tensors[i];
}
auto sub_out_tensors = sub_conv->out_tensors();
auto sub_out_tensor_num = sub_out_tensors.size();
for (size_t i = 0; i < sub_out_tensor_num; ++i) {
delete sub_out_tensors[i];
}
delete sub_conv;
}
};
int Init() override;
int ReSize() override;
int Run() override;
int PreProcess() override;
void SeparateInput(int group_id);
void PostConcat(int group_id);
private:
std::vector<kernel::LiteKernel *> group_convs_;
float *ori_in_data_ = nullptr; // do not free
float *ori_out_data_ = nullptr; // do not free
const int group_num_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GROUP_CONVOLUTION_H_
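
For illustration, not part of this commit: per the constructor comment above, the ConvParameter held by this kernel stores per-group channel counts, so the model-level values are recovered by multiplying with the group count. A sketch with RealInChannel and RealOutChannel as hypothetical helper names:

// Hypothetical helpers, not part of the PR: recover model-level channel
// counts from the per-group values stored in this kernel's ConvParameter.
inline int RealInChannel(int per_group_in_channel, int group_num) {
  return per_group_in_channel * group_num;
}
inline int RealOutChannel(int per_group_out_channel, int group_num) {
  return per_group_out_channel * group_num;
}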