!9415 [MS][LITE][CPU]add int8 group conv

From: @fuzhiye
Reviewed-by: @zhang_xue_tong,@hangangqiang
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2020-12-03 16:57:12 +08:00 committed by Gitee
commit d2fde24794
8 changed files with 337 additions and 48 deletions

View File

@@ -221,7 +221,7 @@ void FreeMemoryFp16(const std::vector<kernel::LiteKernel *> &group_convs, const
}
}
lite::Tensor *CreateInputTensor(TypeId data_type, std::vector<int> in_shape, bool infered_flag) {
lite::Tensor *CreateInputTensorFp16(TypeId data_type, std::vector<int> in_shape, bool infered_flag) {
auto in_tensor = new (std::nothrow) lite::Tensor(data_type, in_shape, Format_NHWC, lite::Tensor::Category::VAR);
if (in_tensor == nullptr) {
MS_LOG(ERROR) << "new in_tensor failed.";
@@ -238,8 +238,8 @@ lite::Tensor *CreateInputTensor(TypeId data_type, std::vector<int> in_shape, boo
return in_tensor;
}
lite::Tensor *CreateFilterTensor(TypeId data_type, std::vector<int> filter_shape,
const std::vector<lite::Tensor *> &inputs, int copy_length, int index) {
lite::Tensor *CreateFilterTensorFp16(TypeId data_type, std::vector<int> filter_shape,
const std::vector<lite::Tensor *> &inputs, int copy_length, int index) {
auto filter_tensor =
new (std::nothrow) lite::Tensor(data_type, filter_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR);
if (filter_tensor == nullptr) {
@@ -263,8 +263,8 @@ lite::Tensor *CreateFilterTensor(TypeId data_type, std::vector<int> filter_shape
return filter_tensor;
}
lite::Tensor *CreateBiasTensor(TypeId data_type, std::vector<int> bias_shape, const std::vector<lite::Tensor *> &inputs,
int new_out_channel, int index) {
lite::Tensor *CreateBiasTensorFp16(TypeId data_type, std::vector<int> bias_shape,
const std::vector<lite::Tensor *> &inputs, int new_out_channel, int index) {
auto *origin_bias = inputs.at(kBiasIndex)->data_c();
auto bias_tensor =
new (std::nothrow) lite::Tensor(data_type, bias_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR);
@@ -289,8 +289,8 @@ lite::Tensor *CreateBiasTensor(TypeId data_type, std::vector<int> bias_shape, co
return bias_tensor;
}
lite::Tensor *CreateOutputTensor(std::vector<int> out_shape, const std::vector<lite::Tensor *> &outputs,
bool infered_flag, int index) {
lite::Tensor *CreateOutputTensorFp16(std::vector<int> out_shape, const std::vector<lite::Tensor *> &outputs,
bool infered_flag, int index) {
auto out_tensor = new (std::nothrow) lite::Tensor();
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "new tmp_out_tensor failed.";
@@ -356,7 +356,7 @@ kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector<lite::Tensor
return nullptr;
}
// create new input for each group
auto in_tensor = CreateInputTensor(mindspore::kNumberTypeFloat16, in_shape, infered_flag);
auto in_tensor = CreateInputTensorFp16(mindspore::kNumberTypeFloat16, in_shape, infered_flag);
if (in_tensor == nullptr) {
delete new_conv_parameter;
FreeMemoryFp16(group_convs, new_inputs, new_outputs);
@@ -367,7 +367,8 @@ kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector<lite::Tensor
// create new weight
int copy_length = conv_param->kernel_h_ * conv_param->kernel_w_ * new_in_channel * new_out_channel;
auto filter_tensor = CreateFilterTensor(inputs.at(kWeightIndex)->data_type(), filter_shape, inputs, copy_length, i);
auto filter_tensor =
CreateFilterTensorFp16(inputs.at(kWeightIndex)->data_type(), filter_shape, inputs, copy_length, i);
if (filter_tensor == nullptr) {
delete new_conv_parameter;
FreeMemoryFp16(group_convs, new_inputs, new_outputs);
@@ -378,7 +379,8 @@ kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector<lite::Tensor
// if has bias, create new bias
if (has_bias) {
auto bias_tensor = CreateBiasTensor(inputs.at(kBiasIndex)->data_type(), bias_shape, inputs, new_out_channel, i);
auto bias_tensor =
CreateBiasTensorFp16(inputs.at(kBiasIndex)->data_type(), bias_shape, inputs, new_out_channel, i);
if (bias_tensor == nullptr) {
delete new_conv_parameter;
FreeMemoryFp16(group_convs, new_inputs, new_outputs);
@@ -390,7 +392,7 @@ kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector<lite::Tensor
// create new output tensors
for (size_t j = 0; j < outputs.size(); ++j) {
auto out_tensor = CreateOutputTensor(out_shape, outputs, infered_flag, j);
auto out_tensor = CreateOutputTensorFp16(out_shape, outputs, infered_flag, j);
if (out_tensor == nullptr) {
delete new_conv_parameter;
FreeMemoryFp16(group_convs, new_inputs, new_outputs);

View File

@@ -168,8 +168,8 @@ ConvParameter *CreateNewConvParameter(ConvParameter *parameter) {
return conv_parameter;
}
void FreeMemoryFp32(const std::vector<kernel::LiteKernel *> &group_convs, const std::vector<lite::Tensor *> &new_inputs,
const std::vector<lite::Tensor *> &new_outputs) {
void FreeMemory(const std::vector<kernel::LiteKernel *> &group_convs, const std::vector<lite::Tensor *> &new_inputs,
const std::vector<lite::Tensor *> &new_outputs) {
for (auto sub_conv : group_convs) {
if (sub_conv != nullptr) {
delete sub_conv;
@@ -187,7 +187,7 @@ void FreeMemoryFp32(const std::vector<kernel::LiteKernel *> &group_convs, const
}
}
lite::Tensor *CreateInputTensorFp32(TypeId data_type, std::vector<int> in_shape, bool infered_flag) {
lite::Tensor *CreateInputTensor(TypeId data_type, std::vector<int> in_shape, bool infered_flag) {
auto in_tensor = new (std::nothrow) lite::Tensor(data_type, in_shape, Format_NHWC, lite::Tensor::Category::VAR);
if (in_tensor == nullptr) {
MS_LOG(ERROR) << "new in_tensor failed.";
@@ -247,8 +247,8 @@ lite::Tensor *CreateBiasTensorFp32(TypeId data_type, std::vector<int> bias_shape
return bias_tensor;
}
lite::Tensor *CreateOutputTensorFp32(std::vector<int> out_shape, const std::vector<lite::Tensor *> &outputs,
bool infered_flag, int index) {
lite::Tensor *CreateOutputTensor(std::vector<int> out_shape, const std::vector<lite::Tensor *> &outputs,
bool infered_flag, int index) {
auto out_tensor = new (std::nothrow) lite::Tensor();
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "new tmp_out_tensor failed.";
@@ -324,16 +324,16 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor
std::vector<lite::Tensor *> new_outputs;
auto new_conv_parameter = CreateNewConvParameter(conv_param);
if (new_conv_parameter == nullptr) {
FreeMemoryFp32(group_convs, new_inputs, new_outputs);
FreeMemory(group_convs, new_inputs, new_outputs);
MS_LOG(ERROR) << "Get new conv parameter failed.";
return nullptr;
}
// create new input for each group
auto in_tensor = CreateInputTensorFp32(inputs.front()->data_type(), in_shape, infered_flag);
auto in_tensor = CreateInputTensor(inputs.front()->data_type(), in_shape, infered_flag);
if (in_tensor == nullptr) {
delete new_conv_parameter;
FreeMemoryFp32(group_convs, new_inputs, new_outputs);
FreeMemory(group_convs, new_inputs, new_outputs);
MS_LOG(ERROR) << "create input tensor failed.";
return nullptr;
}
@@ -345,7 +345,7 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor
CreateFilterTensorFp32(inputs.at(kWeightIndex)->data_type(), filter_shape, inputs, copy_length, i);
if (filter_tensor == nullptr) {
delete new_conv_parameter;
FreeMemoryFp32(group_convs, new_inputs, new_outputs);
FreeMemory(group_convs, new_inputs, new_outputs);
MS_LOG(ERROR) << "create filter tensor failed.";
return nullptr;
}
@@ -357,7 +357,7 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor
CreateBiasTensorFp32(inputs.at(kBiasIndex)->data_type(), bias_shape, inputs, new_out_channel, i);
if (bias_tensor == nullptr) {
delete new_conv_parameter;
FreeMemoryFp32(group_convs, new_inputs, new_outputs);
FreeMemory(group_convs, new_inputs, new_outputs);
MS_LOG(ERROR) << "create bias_tensor failed.";
return nullptr;
}
@@ -366,10 +366,10 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor
// create new output tensor
for (size_t j = 0; j < outputs.size(); ++j) {
auto out_tensor = CreateOutputTensorFp32(out_shape, outputs, infered_flag, j);
auto out_tensor = CreateOutputTensor(out_shape, outputs, infered_flag, j);
if (out_tensor == nullptr) {
delete new_conv_parameter;
FreeMemoryFp32(group_convs, new_inputs, new_outputs);
FreeMemory(group_convs, new_inputs, new_outputs);
MS_LOG(ERROR) << "new out_tensor failed.";
return nullptr;
}

View File

@@ -61,6 +61,16 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
float *packed_input_ = nullptr;
float *col_major_input_ = nullptr;
};
void FreeMemory(const std::vector<kernel::LiteKernel *> &group_convs, const std::vector<lite::Tensor *> &new_inputs,
const std::vector<lite::Tensor *> &new_outputs);
ConvParameter *CreateNewConvParameter(ConvParameter *parameter);
lite::Tensor *CreateInputTensor(TypeId data_type, std::vector<int> in_shape, bool infered_flag);
lite::Tensor *CreateOutputTensor(std::vector<int> out_shape, const std::vector<lite::Tensor *> &outputs,
bool infered_flag, int index);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_H_

View File

@@ -28,6 +28,11 @@ using mindspore::schema::PrimitiveType_Conv2D;
namespace mindspore::kernel {
int GroupConvolutionCPUKernel::Init() {
for (int i = 0; i < group_num_; ++i) {
auto sub_conv = group_convs_.at(i);
if (sub_conv == nullptr) {
MS_LOG(ERROR) << "sub con " << i << " is null.";
return RET_ERROR;
}
auto ret = group_convs_.at(i)->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Sub kernel init failed.";
@@ -127,7 +132,7 @@ int GroupConvolutionCPUKernel::PreProcess() {
auto ret = output->MallocData();
if (ret != RET_OK) {
FreeSubKernel();
MS_LOG(ERROR) << "fp32 group conv out tensor malloc data failed.";
MS_LOG(ERROR) << "group conv out tensor malloc data failed.";
return ret;
}
}

View File

@@ -41,15 +41,17 @@ class GroupConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
int ReSize() override;
int Run() override;
int PreProcess() override;
void SeparateInput(int group_id);
void PostConcat(int group_id);
virtual void SeparateInput(int group_id);
virtual void PostConcat(int group_id);
void FreeSubKernel();
private:
protected:
std::vector<kernel::LiteKernel *> group_convs_;
const int group_num_;
private:
float *ori_in_data_ = nullptr; // do not free
float *ori_out_data_ = nullptr; // do not free
const int group_num_;
};
} // namespace mindspore::kernel

View File

@@ -20,8 +20,10 @@
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/base/layout_transform.h"
#include "src/runtime/kernel/arm/fp32/convolution_fp32.h"
#include "src/runtime/kernel/arm/int8/convolution_1x1_int8.h"
#include "src/runtime/kernel/arm/int8/convolution_3x3_int8.h"
#include "src/runtime/kernel/arm/int8/group_convolution_int8.h"
#include "src/runtime/runtime_api.h"
#ifdef ENABLE_ARM64
#include "src/runtime/kernel/arm/int8/opt_op_handler.h"
@@ -32,6 +34,7 @@ using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;
using mindspore::schema::Format::Format_NHWC;
namespace mindspore::kernel {
void ConvolutionInt8CPUKernel::CheckSupportOptimize() {
@@ -242,6 +245,166 @@ int ConvolutionInt8CPUKernel::Run() {
return RET_OK;
}
lite::Tensor *CreateFilterTensorInt8(TypeId data_type, std::vector<int> filter_shape,
const std::vector<lite::Tensor *> &inputs, int copy_length, int index) {
MS_ASSERT(data_type == kNumberTypeInt8);
auto filter_tensor =
new (std::nothrow) lite::Tensor(data_type, filter_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR);
if (filter_tensor == nullptr) {
MS_LOG(ERROR) << "new filter_tensor failed.";
return nullptr;
}
auto ret = filter_tensor->MallocData();
if (ret != RET_OK) {
delete filter_tensor;
MS_LOG(ERROR) << "filter_tensor malloc failed.";
return nullptr;
}
auto *origin_weight = reinterpret_cast<int8_t *>(inputs.at(kWeightIndex)->data_c());
memcpy(filter_tensor->data_c(), origin_weight + index * copy_length, copy_length * sizeof(int8_t));
return filter_tensor;
}
lite::Tensor *CreateBiasTensorInt8(TypeId data_type, std::vector<int> bias_shape,
const std::vector<lite::Tensor *> &inputs, int new_out_channel, int index) {
MS_ASSERT(data_type == kNumberTypeInt32);
auto *origin_bias = inputs.at(kBiasIndex)->data_c();
auto bias_tensor =
new (std::nothrow) lite::Tensor(data_type, bias_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR);
if (bias_tensor == nullptr) {
MS_LOG(ERROR) << "new bias_tensor failed.";
return nullptr;
}
auto ret = bias_tensor->MallocData();
if (ret != RET_OK) {
delete bias_tensor;
MS_LOG(ERROR) << "bias_tensor malloc failed.";
return nullptr;
}
auto bias_data = reinterpret_cast<int32_t *>(origin_bias);
memcpy(bias_tensor->data_c(), bias_data + index * new_out_channel, new_out_channel * sizeof(int32_t));
return bias_tensor;
}
kernel::LiteKernel *CpuConvInt8KernelSelect(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) {
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
kernel::LiteKernel *kernel = nullptr;
if (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_h_ == 1 &&
conv_param->stride_w_ == 1 && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1) {
#ifdef ENABLE_ARM64
if (mindspore::lite::IsSupportSDot()) {
kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
} else {
kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
}
#else
kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
#endif
} else if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) {
kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
} else {
kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
}
return kernel;
}
kernel::LiteKernel *CpuGroupConvInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive,
int group) {
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
std::vector<int> in_shape;
std::vector<int> out_shape;
int new_in_channel = inputs.at(kWeightIndex)->Channel();
int new_out_channel = 0;
if (group == 0) {
MS_LOG(ERROR) << "Divisor 'group' cannot be 0.";
return nullptr;
} else {
new_out_channel = inputs.at(kWeightIndex)->Batch() / group;
}
bool infered_flag = primitive != nullptr && primitive->infer_flag();
if (infered_flag) {
int batch = inputs.front()->Batch();
int in_h = inputs.front()->Height();
int in_w = inputs.front()->Width();
conv_param->input_channel_ = new_in_channel;
conv_param->output_channel_ = new_out_channel;
in_shape = {batch, in_h, in_w, new_in_channel};
out_shape = {batch, conv_param->output_h_, conv_param->output_w_, new_out_channel};
}
std::vector<int> filter_shape = {new_out_channel, conv_param->kernel_h_, conv_param->kernel_w_, new_in_channel};
std::vector<int> bias_shape = {new_out_channel};
// create sub kernels
std::vector<kernel::LiteKernel *> group_convs;
for (int i = 0; i < group; ++i) {
std::vector<lite::Tensor *> new_inputs;
std::vector<lite::Tensor *> new_outputs;
auto new_conv_parameter = CreateNewConvParameter(conv_param);
if (new_conv_parameter == nullptr) {
FreeMemory(group_convs, new_inputs, new_outputs);
MS_LOG(ERROR) << "Get new conv parameter failed.";
return nullptr;
}
// create new input for each group
auto input_data_type = inputs.front()->data_type();
MS_ASSERT(input_data_type == kNumberTypeInt8);
auto in_tensor = CreateInputTensor(input_data_type, in_shape, infered_flag);
if (in_tensor == nullptr) {
delete new_conv_parameter;
FreeMemory(group_convs, new_inputs, new_outputs);
MS_LOG(ERROR) << "create input tensor failed.";
return nullptr;
}
new_inputs.emplace_back(in_tensor);
// create new weight
int copy_length = conv_param->kernel_h_ * conv_param->kernel_w_ * new_in_channel * new_out_channel;
auto filter_tensor =
CreateFilterTensorInt8(inputs.at(kWeightIndex)->data_type(), filter_shape, inputs, copy_length, i);
if (filter_tensor == nullptr) {
delete new_conv_parameter;
FreeMemory(group_convs, new_inputs, new_outputs);
MS_LOG(ERROR) << "create filter tensor failed.";
return nullptr;
}
new_inputs.emplace_back(filter_tensor);
// if has bias, create new bias
if (inputs.size() == 3) {
auto bias_tensor =
CreateBiasTensorInt8(inputs.at(kBiasIndex)->data_type(), bias_shape, inputs, new_out_channel, i);
if (bias_tensor == nullptr) {
delete new_conv_parameter;
FreeMemory(group_convs, new_inputs, new_outputs);
MS_LOG(ERROR) << "create bias_tensor failed.";
return nullptr;
}
new_inputs.emplace_back(bias_tensor);
}
// create new output tensor
for (size_t j = 0; j < outputs.size(); ++j) {
auto out_tensor = CreateOutputTensor(out_shape, outputs, infered_flag, j);
if (out_tensor == nullptr) {
delete new_conv_parameter;
FreeMemory(group_convs, new_inputs, new_outputs);
MS_LOG(ERROR) << "new out_tensor failed.";
return nullptr;
}
new_outputs.emplace_back(out_tensor);
}
group_convs.emplace_back(CpuConvInt8KernelSelect(
new_inputs, new_outputs, reinterpret_cast<OpParameter *>(new_conv_parameter), ctx, primitive));
}
return new (std::nothrow)
GroupConvolutionInt8CPUKernel(op_parameter, inputs, outputs, ctx, primitive, group_convs, group);
}
kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const InnerContext *ctx, const kernel::KernelKey &desc,
@@ -249,27 +412,12 @@ kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::Tensor *> &
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
kernel::LiteKernel *kernel;
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
#ifdef ENABLE_ARM64
if (mindspore::lite::IsSupportSDot()) {
kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else {
kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
}
#else
kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
#endif
} else if (kernel_h == 1 && kernel_w == 1) {
kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
kernel::LiteKernel *kernel = nullptr;
if (conv_param->group_ == 1) {
kernel = CpuConvInt8KernelSelect(inputs, outputs, opParameter, ctx, primitive);
} else {
kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
MS_ASSERT(conv_param->group_ > 1);
kernel = CpuGroupConvInt8KernelCreator(inputs, outputs, opParameter, ctx, primitive, conv_param->group_);
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";

View File

@@ -0,0 +1,74 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/int8/group_convolution_int8.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;
namespace mindspore::kernel {
void GroupConvolutionInt8CPUKernel::SeparateInput(int group_id) {
int in_plane = conv_param_->input_h_ * conv_param_->input_w_;
int sub_in_channel = conv_param_->input_channel_;
int ori_in_channel = sub_in_channel * group_num_;
auto sub_in_data = reinterpret_cast<int8_t *>(group_convs_.at(group_id)->in_tensors().front()->data_c());
int8_t *src_ptr = ori_in_data_ + group_id * sub_in_channel;
int8_t *dst_ptr = sub_in_data;
for (int i = 0; i < in_plane; ++i) {
memcpy(dst_ptr, src_ptr, sub_in_channel * sizeof(int8_t));
src_ptr += ori_in_channel;
dst_ptr += sub_in_channel;
}
}
void GroupConvolutionInt8CPUKernel::PostConcat(int group_id) {
int out_plane = conv_param_->output_h_ * conv_param_->output_w_;
int sub_out_channel = conv_param_->output_channel_;
int ori_out_channel = sub_out_channel * group_num_;
auto sub_out_data = reinterpret_cast<int8_t *>(group_convs_.at(group_id)->out_tensors().front()->data_c());
int8_t *src_ptr = sub_out_data;
int8_t *dst_ptr = ori_out_data_ + group_id * sub_out_channel;
for (int i = 0; i < out_plane; ++i) {
memcpy(dst_ptr, src_ptr, sub_out_channel * sizeof(int8_t));
src_ptr += sub_out_channel;
dst_ptr += ori_out_channel;
}
}
int GroupConvolutionInt8CPUKernel::Run() {
ori_in_data_ = reinterpret_cast<int8_t *>(in_tensors().front()->data_c());
ori_out_data_ = reinterpret_cast<int8_t *>(out_tensors().front()->data_c());
for (int i = 0; i < group_num_; ++i) {
// first, separate the group conv input into per-group parts. This step must be done at runtime.
SeparateInput(i);
// sub kernels run
auto ret = group_convs_.at(i)->Run();
if (ret != RET_OK) {
MS_LOG(ERROR) << "sub kernel " << i << " execute failed.";
return ret;
}
// post process, concat all outputs of sub-kernels into one output
PostConcat(i);
}
return RET_OK;
}
} // namespace mindspore::kernel
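
The SeparateInput/PostConcat pair above slices and rebuilds the channel dimension of NHWC buffers. Below is a small self-contained sketch of that round trip, with made-up sizes; it is illustrative only and not part of the patch.

// Illustrative sketch: splitting an NHWC int8 buffer into per-group channel
// slices (as SeparateInput does) and writing the slices back to their channel
// offsets (as PostConcat does) reproduces the original buffer.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

int main() {
  const int plane = 4;        // input_h * input_w
  const int group_num = 2;
  const int sub_channel = 3;  // channels handled by each sub-kernel
  const int ori_channel = sub_channel * group_num;

  std::vector<int8_t> ori(plane * ori_channel), back(plane * ori_channel);
  for (size_t i = 0; i < ori.size(); ++i) ori[i] = static_cast<int8_t>(i);

  for (int g = 0; g < group_num; ++g) {
    std::vector<int8_t> sub(plane * sub_channel);
    for (int p = 0; p < plane; ++p) {
      // SeparateInput: copy this group's channel slice from every pixel.
      std::memcpy(&sub[p * sub_channel], &ori[p * ori_channel + g * sub_channel],
                  sub_channel * sizeof(int8_t));
    }
    for (int p = 0; p < plane; ++p) {
      // PostConcat: write the slice back to its channel offset in the output.
      std::memcpy(&back[p * ori_channel + g * sub_channel], &sub[p * sub_channel],
                  sub_channel * sizeof(int8_t));
    }
  }
  assert(back == ori);  // the round trip preserves the NHWC layout
  return 0;
}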

View File

@@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GROUP_CONVOLUTION_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GROUP_CONVOLUTION_INT8_H_
#include <utility>
#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/op_base.h"
#include "src/runtime/kernel/arm/fp32/group_convolution_fp32.h"
namespace mindspore::kernel {
class GroupConvolutionInt8CPUKernel : public GroupConvolutionCPUKernel {
public:
GroupConvolutionInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive,
std::vector<kernel::LiteKernel *> group_convs, const int group_num)
: GroupConvolutionCPUKernel(parameter, inputs, outputs, ctx, primitive, group_convs, group_num) {
} // opParameter's in/out channel in this kernel has been split per group; to get
// the real values, multiply in channel / out channel by the group number.
~GroupConvolutionInt8CPUKernel() override { GroupConvolutionCPUKernel::FreeSubKernel(); }
int Run() override;
void SeparateInput(int group_id) override;
void PostConcat(int group_id) override;
private:
int8_t *ori_in_data_ = nullptr; // do not free
int8_t *ori_out_data_ = nullptr; // do not free
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GROUP_CONVOLUTION_INT8_H_