!39988 [MSLITE][CPU] conv 1x1 slidewindows, support avx512

Merge pull request !39988 from Greatpan/conv_sw1x1_avx512_support
i-robot 2022-08-09 12:44:32 +00:00 committed by Gitee
commit 32946f08b7
6 changed files with 309 additions and 4 deletions
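
In short: the patch adds a ConvolutionSW1x1CPUKernel wrapper that routes fp32 1x1 convolutions (stride 1, no padding) through the MatMul kernel family, so that on x86 they can run on the AVX512 matmul backend when it is compiled in and the CPU supports it.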

src/litert/kernel/cpu/fp32/convolution_delegate_fp32.cc

@@ -24,6 +24,7 @@
#include "src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_x86_fp32.h"
#include "src/litert/kernel/cpu/base/group_convolution_creator.h"
#include "src/litert/kernel/cpu/fp32/group_convolution_fp32.h"
#include "src/litert/kernel/cpu/fp32/convolution_sw_1x1_fp32.h"
#include "nnacl/base/conv_common_base.h"
#include "schema/model_generated.h"
#include "include/errorcode.h"
@@ -127,6 +128,8 @@ int ConvolutionDelegateCPUKernel::Prepare() {
MS_LOG(ERROR) << "Get weight and bias failed.";
return ret;
}
input_const_ = in_tensors_[kInputIndex]->IsConst() && !op_parameter_->is_train_session_;
weight_const_ = in_tensors_[kWeightIndex]->IsConst() && !op_parameter_->is_train_session_;
if (!InferShapeDone()) {
return RET_OK;
}
@@ -196,6 +199,16 @@ kernel::LiteKernel *ConvolutionDelegateCPUKernel::CpuConvFp32NC4KernelSelect() {
return nullptr;
}
bool ConvolutionDelegateCPUKernel::CheckAvxUseSW1x1Conv(const ConvParameter *conv_param) {
if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) {
if (conv_param->pad_d_ == 0 && conv_param->pad_l_ == 0 && conv_param->pad_r_ == 0 && conv_param->pad_u_ == 0 &&
conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) {
return true;
}
}
return false;
}
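For intuition on why this check is sufficient: with a 1x1 kernel, stride 1, and zero padding, each output pixel is a plain dot product over the input channels, so the whole convolution collapses into one matrix multiply. A minimal standalone sketch (illustration only, not part of the patch; assumes NHWC layout and weights stored [Cout][Cin]):

// Sketch: a 1x1/stride-1/unpadded conv over one NHWC batch is a (H*W x Cin) * (Cin x Cout) matmul.
void Conv1x1AsMatmul(const float *in, const float *weight, float *out, int hw, int cin, int cout) {
  for (int p = 0; p < hw; ++p) {         // each spatial position is one matmul row
    for (int oc = 0; oc < cout; ++oc) {  // each output channel is one matmul column
      float acc = 0.0f;
      for (int ic = 0; ic < cin; ++ic) {
        acc += in[p * cin + ic] * weight[oc * cin + ic];  // weight is [Cout][Cin], i.e. B transposed
      }
      out[p * cout + oc] = acc;
    }
  }
}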
bool ConvolutionDelegateCPUKernel::CheckAvxUseSWConv(const ConvParameter *conv_param) {
if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) {
if (conv_param->pad_d_ == 0 && conv_param->pad_l_ == 0 && conv_param->pad_r_ == 0 && conv_param->pad_u_ == 0 &&
@@ -214,6 +227,31 @@ bool ConvolutionDelegateCPUKernel::CheckAvxUseSWConv(const ConvParameter *conv_p
return false;
}
kernel::LiteKernel *ConvolutionDelegateCPUKernel::CreateConv1x1MatmulKernel() {
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_);
matmul_param_ = new (std::nothrow) MatMulParameter;
if (matmul_param_ == nullptr) {
MS_LOG(WARNING) << "Memory allocation failed, Create Conv1x1 Matmul Kernel failed.";
return nullptr;
}
matmul_param_->row_ = conv_param->output_h_ * conv_param->output_w_;
matmul_param_->col_ = conv_param->output_channel_;
matmul_param_->deep_ = conv_param->input_channel_;
matmul_param_->batch = conv_param->input_batch_;
matmul_param_->op_parameter_ = conv_param->op_parameter_;
matmul_param_->act_type_ = conv_param->act_type_;
matmul_param_->a_transpose_ = false;
matmul_param_->b_transpose_ = true;
matmul_param_->a_const_ = input_const_;
matmul_param_->b_const_ = weight_const_;
auto kernel = new (std::nothrow) kernel::ConvolutionSW1x1CPUKernel(
reinterpret_cast<OpParameter *>(matmul_param_), in_tensors_, out_tensors_,
static_cast<const lite::InnerContext *>(this->ms_context_), origin_weight_, origin_bias_);
return kernel;
}
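The conv-to-matmul mapping above, with illustrative numbers (not from the patch): a 1x56x56x64 NHWC input and 128 filters of shape 1x1x64 give

// row_  = output_h * output_w = 56 * 56 = 3136
// col_  = output_channel      = 128
// deep_ = input_channel       = 64
// b_transpose_ = true: the weight tensor is [Cout, 1, 1, Cin] = [col_ x deep_], i.e. already B^T.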
kernel::LiteKernel *ConvolutionDelegateCPUKernel::CpuConvFp32NHWCKernelSelect() {
kernel::LiteKernel *kernel = nullptr;
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_);
@@ -226,6 +264,10 @@ kernel::LiteKernel *ConvolutionDelegateCPUKernel::CpuConvFp32NHWCKernelSelect()
}
#ifdef ENABLE_AVX
if (kernel == nullptr && CheckAvxUseSW1x1Conv(conv_param)) {
kernel = CreateConv1x1MatmulKernel();
}
if (kernel == nullptr && CheckAvxUseSWConv(conv_param)) {
kernel = new (std::nothrow) kernel::ConvolutionSWAVXCPUKernel(
op_parameter_, in_tensors_, out_tensors_, static_cast<const lite::InnerContext *>(this->ms_context_),

src/litert/kernel/cpu/fp32/convolution_delegate_fp32.h

@@ -19,6 +19,7 @@
#include <vector>
#include "src/litert/lite_kernel.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/op_base.h"
using mindspore::lite::InnerContext;
@@ -30,6 +31,9 @@ class ConvolutionDelegateCPUKernel : public LiteKernel {
: LiteKernel(parameter, inputs, outputs, ctx) {}
~ConvolutionDelegateCPUKernel() override {
FreeCopiedData();
if (matmul_param_ != nullptr) {
matmul_param_ = nullptr; // matmul_param_ is passed to conv_kernel_ as its op_parameter and freed there
}
if (conv_kernel_ != nullptr) {
op_parameter_ = nullptr; // op_parameter will be freed in conv_kernel
delete conv_kernel_;
@@ -84,6 +88,8 @@ class ConvolutionDelegateCPUKernel : public LiteKernel {
kernel::LiteKernel *CpuConvFp32KernelSelect();
kernel::LiteKernel *CpuConvFp32NC4KernelSelect();
kernel::LiteKernel *CpuConvFp32NHWCKernelSelect();
kernel::LiteKernel *CreateConv1x1MatmulKernel();
bool CheckAvxUseSW1x1Conv(const ConvParameter *conv_param);
bool CheckAvxUseSWConv(const ConvParameter *conv_param);
// If the inferShape process can't complete in the Init stage, initialization of weight and bias will be done at
// runtime via the Resize() API. However, data of const tensors (weight and bias) doesn't exist anymore at the runtime stage. Thus,
@@ -117,10 +123,13 @@ class ConvolutionDelegateCPUKernel : public LiteKernel {
protected:
kernel::LiteKernel *conv_kernel_{nullptr};
MatMulParameter *matmul_param_{nullptr};
float *origin_weight_{nullptr};
float *origin_bias_{nullptr};
bool need_free_weight_{false};
bool need_free_bias_{false};
bool input_const_{false};
bool weight_const_{false};
};
} // namespace mindspore::kernel

src/litert/kernel/cpu/fp32/convolution_sw_1x1_fp32.cc

@@ -0,0 +1,46 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/litert/kernel/cpu/fp32/convolution_sw_1x1_fp32.h"
#include "src/litert/kernel_registry.h"
using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
int ConvolutionSW1x1CPUKernel::Prepare() {
CHECK_NULL_RETURN(matmul_base_);
matmul_base_->set_name(name_);
matmul_base_->set_workspace(workspace());
matmul_base_->SetConv1x1OriginWeight(origin_weight_);
matmul_base_->SetConv1x1OriginBias(origin_bias_);
return matmul_base_->Conv1x1Prepare();
}
int ConvolutionSW1x1CPUKernel::ReSize() {
CHECK_NULL_RETURN(matmul_base_);
matmul_base_->set_workspace(workspace());
return matmul_base_->Conv1x1ReSize();
}
int ConvolutionSW1x1CPUKernel::Run() {
CHECK_NULL_RETURN(matmul_base_);
matmul_base_->set_workspace(workspace());
return matmul_base_->Run();
}
} // namespace mindspore::kernel
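
Note the delegation pattern above: the wrapper performs no compute of its own. Each entry point first re-publishes the wrapper's framework-assigned workspace to the inner matmul kernel (the scheduler sizes and attaches workspace on the wrapper, via the workspace_size() forwarding in the header below), then hands off to the matmul kernel's Conv1x1Prepare, Conv1x1ReSize, or Run.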

src/litert/kernel/cpu/fp32/convolution_sw_1x1_fp32.h

@@ -0,0 +1,150 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CONVOLUTION_SW_1X1_FP32_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CONVOLUTION_SW_1X1_FP32_H_
#include <vector>
#include "include/context.h"
#include "include/errorcode.h"
#include "nnacl/intrinsics/ms_simd_cpu_info.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/fp32/matmul_fp32.h"
#include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h"
#if defined(ENABLE_AVX512)
#include "src/litert/kernel/cpu/fp32/matmul_fp32_avx512.h"
#endif
#if defined(ENABLE_AVX)
#include "src/litert/kernel/cpu/fp32/matmul_fp32_avx.h"
#endif
#if defined(ENABLE_SSE)
#include "src/litert/kernel/cpu/fp32/matmul_fp32_sse.h"
#endif
#if defined(ENABLE_ARM32)
#include "src/litert/kernel/cpu/fp32/matmul_fp32_arm32.h"
#endif
#if defined(ENABLE_ARM64)
#include "src/litert/kernel/cpu/fp32/matmul_fp32_arm64.h"
#endif
namespace mindspore::kernel {
class ConvolutionSW1x1CPUKernel : public LiteKernel {
public:
ConvolutionSW1x1CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const mindspore::lite::InnerContext *ctx,
float *origin_weight, float *origin_bias)
: LiteKernel(parameter, inputs, outputs, ctx), origin_weight_(origin_weight), origin_bias_(origin_bias) {
#if defined(ENABLE_AVX512)
if (matmul_base_ == nullptr) {
AVX512_HARDWARE_SELF_AWARENESS_BEGIN
matmul_base_ = new (std::nothrow) MatmulFp32AVX512CPUKernel(parameter, inputs, outputs, ctx);
AVX512_HARDWARE_SELF_AWARENESS_END
}
#endif
#if defined(ENABLE_AVX)
if (matmul_base_ == nullptr) {
matmul_base_ = new (std::nothrow) MatmulFp32AVXCPUKernel(parameter, inputs, outputs, ctx);
}
#endif
#if defined(ENABLE_SSE)
if (matmul_base_ == nullptr) {
matmul_base_ = new (std::nothrow) MatmulFp32SSECPUKernel(parameter, inputs, outputs, ctx);
}
#endif
#if defined(ENABLE_ARM64)
if (matmul_base_ == nullptr) {
matmul_base_ = new (std::nothrow) MatmulFp32ARM64CPUKernel(parameter, inputs, outputs, ctx);
}
#elif defined(ENABLE_ARM32)
if (matmul_base_ == nullptr) {
matmul_base_ = new (std::nothrow) MatmulFp32ARM32CPUKernel(parameter, inputs, outputs, ctx);
}
#endif
if (matmul_base_ == nullptr) {
matmul_base_ = new (std::nothrow) MatmulFp32BaseCPUKernel(parameter, inputs, outputs, ctx);
}
}
~ConvolutionSW1x1CPUKernel() {
if (matmul_base_ != nullptr) {
op_parameter_ = nullptr; // op_parameter will be freed in LiteKernel
matmul_base_->ws_allocated_ = this->ws_allocated_;
delete matmul_base_;
matmul_base_ = nullptr;
}
}
int Prepare() override;
int ReSize() override;
int Run() override;
void set_in_tensors(const std::vector<lite::Tensor *> &in_tensors) override {
this->in_tensors_ = in_tensors;
if (matmul_base_ != nullptr) {
matmul_base_->set_in_tensors(in_tensors);
}
}
void set_in_tensor(lite::Tensor *in_tensor, size_t index) override {
MS_ASSERT(index < in_tensors_.size());
this->in_tensors_[index] = in_tensor;
if (matmul_base_ != nullptr) {
matmul_base_->set_in_tensor(in_tensor, index);
}
}
void set_out_tensors(const std::vector<lite::Tensor *> &out_tensors) override {
this->out_tensors_ = out_tensors;
if (matmul_base_ != nullptr) {
matmul_base_->set_out_tensors(out_tensors);
}
}
void set_out_tensor(lite::Tensor *out_tensor, size_t index) override {
MS_ASSERT(index < out_tensors_.size());
this->out_tensors_[index] = out_tensor;
if (matmul_base_ != nullptr) {
matmul_base_->set_out_tensor(out_tensor, index);
}
}
// Train API
int Train() override {
(void)LiteKernel::Train();
return matmul_base_->Train();
}
void SetTrainable(bool trainable) override {
LiteKernel::SetTrainable(trainable);
return matmul_base_->SetTrainable(trainable);
}
size_t workspace_size() override {
(void)LiteKernel::workspace_size();
return matmul_base_->workspace_size();
}
private:
MatmulFp32BaseCPUKernel *matmul_base_ = nullptr;
float *origin_weight_ = nullptr;
float *origin_bias_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CONVOLUTION_SW_1X1_FP32_H_
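
A note on the constructor cascade in this header: backends are tried widest-first, and each #if block only runs when every previous attempt left matmul_base_ null, ending at the portable MatmulFp32BaseCPUKernel. The AVX512_HARDWARE_SELF_AWARENESS_BEGIN/END guard (from nnacl/intrinsics/ms_simd_cpu_info.h) appears to gate the AVX512 kernel on a runtime CPU-capability check, so a binary built with ENABLE_AVX512 can still fall back to the AVX kernel on hardware without AVX512 (a reading inferred from the macro and include names, not verified here).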

src/litert/kernel/cpu/fp32/matmul_fp32_base.cc

@@ -233,8 +233,10 @@ int MatmulFp32BaseCPUKernel::PackMatrixB() {
}
int MatmulFp32BaseCPUKernel::PackMatrixBImpl() {
- auto src_ptr =
- matrix_b_.has_origin ? matrix_b_.origin_ptr : reinterpret_cast<float *>(in_tensors_[SECOND_INPUT]->data());
+ auto src_ptr = matrix_b_.has_origin
+ ? matrix_b_.origin_ptr
+ : (conv1x1_origin_weight_ != nullptr ? conv1x1_origin_weight_
+ : reinterpret_cast<float *>(in_tensors_[SECOND_INPUT]->data()));
MS_CHECK_TRUE_MSG(src_ptr != nullptr, RET_ERROR, "matrix-b source ptr is a nullptr.");
MS_CHECK_TRUE_MSG(matrix_b_.pack_ptr != nullptr, RET_ERROR, "matrix-b pack ptr is a nullptr.");
MS_CHECK_TRUE_MSG(matrix_b_pack_fun_ != nullptr, RET_ERROR, "matrix-b func is a nullptr.");
@@ -273,7 +275,10 @@ int MatmulFp32BaseCPUKernel::PackBiasMatrix() {
MS_LOG(ERROR) << "bias_tensor invalid";
return RET_ERROR;
}
- auto bias_src = matrix_c_.has_origin ? matrix_c_.origin_ptr : reinterpret_cast<float *>(bias_tensor->data());
+ auto bias_src =
+ matrix_c_.has_origin
+ ? matrix_c_.origin_ptr
+ : (conv1x1_origin_bias_ != nullptr ? conv1x1_origin_bias_ : reinterpret_cast<float *>(bias_tensor->data()));
MS_CHECK_TRUE_MSG(bias_src != nullptr, RET_ERROR, "matrix-c is a nullptr.");
auto bias_num = bias_tensor->ElementsNum();
MS_CHECK_TRUE_MSG(bias_num > 0 && params_->col_align_ >= bias_num, RET_ERROR, "matrix-c is invalid.");
@@ -320,7 +325,8 @@ int MatmulFp32BaseCPUKernel::Prepare() {
MS_CHECK_TRUE_MSG(in_tensors_[SECOND_INPUT]->data_type() == kNumberTypeFloat32, RET_ERROR,
"matrix-b's data type is invalid.");
if (in_tensors_.size() == FOURTH_INPUT) {
- MS_CHECK_TRUE_MSG(in_tensors_[THIRD_INPUT]->IsConst(), RET_ERROR, "matrix-c must be const when existing.");
+ MS_CHECK_TRUE_MSG(in_tensors_[THIRD_INPUT]->IsConst() || (conv1x1_origin_bias_ != nullptr), RET_ERROR,
+ "matrix-c must be const when existing.");
MS_CHECK_TRUE_MSG(in_tensors_[THIRD_INPUT]->data_type() == kNumberTypeFloat32, RET_ERROR,
"matrix-c's data type is invalid.");
}
@@ -439,6 +445,41 @@ int MatmulFp32BaseCPUKernel::MatmulPrepare() {
return MatmulReSize();
}
int MatmulFp32BaseCPUKernel::Conv1x1Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
CHECK_LESS_RETURN(out_tensors_.size(), 1);
if (params_->a_const_ || InferShapeDone()) {
auto input = in_tensors_.at(0);
params_->row_ = in_tensors_.at(0)->Batch() * input->Height() * input->Width();
params_->deep_ = input->Channel();
}
if (params_->b_const_ || InferShapeDone()) {
auto weight = in_tensors_.at(1);
params_->col_ = weight->Batch();
params_->deep_ = weight->Channel();
}
params_->batch = 1;
a_offset_.resize(params_->batch, 0);
b_offset_.resize(params_->batch, 0);
a_batch_ = 1;
b_batch_ = 1;
params_->a_transpose_ = false;
params_->b_transpose_ = true;
auto ret = MatmulFp32BaseCPUKernel::Prepare();
if (ret != RET_OK) {
return ret;
}
if (!InferShapeDone()) {
return RET_OK;
}
return Conv1x1ReSize();
}
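Note how Conv1x1Prepare flattens batching: row_ absorbs Batch * Height * Width while batch, a_batch_, and b_batch_ are pinned to 1, so an N-batch 1x1 convolution runs as a single tall (N*H*W) x Cin matmul rather than N smaller ones; a_offset_ and b_offset_ accordingly collapse to single zero entries.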
int MatmulFp32BaseCPUKernel::ReSize() {
auto ret = InitParameter();
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "Init parameters failed.");
@@ -551,6 +592,17 @@ int MatmulFp32BaseCPUKernel::FullConnectionReSize() {
return MatmulFp32BaseCPUKernel::ReSize();
}
int MatmulFp32BaseCPUKernel::Conv1x1ReSize() {
auto input = in_tensors_.at(0);
params_->row_ = in_tensors_.at(0)->Batch() * input->Height() * input->Width();
params_->deep_ = input->Channel();
auto weight = in_tensors_.at(1);
params_->col_ = weight->Batch();
return MatmulFp32BaseCPUKernel::ReSize();
}
bool MatmulFp32BaseCPUKernel::CheckRow1OptimalConditions() {
return params_->row_ == 1 && !(SupportMulBatchCuttingByRow() && (a_batch_ > 1 && b_batch_ == 1));
}

src/litert/kernel/cpu/fp32/matmul_fp32_base.h

@@ -44,9 +44,13 @@ class MatmulFp32BaseCPUKernel : public LiteKernel {
int Prepare() override;
int FullConnectionPrepare();
int MatmulPrepare();
void SetConv1x1OriginWeight(float *conv1x1_origin_weight) { conv1x1_origin_weight_ = conv1x1_origin_weight; }
void SetConv1x1OriginBias(float *conv1x1_origin_bias) { conv1x1_origin_bias_ = conv1x1_origin_bias; }
int Conv1x1Prepare();
int ReSize() override;
int FullConnectionReSize();
int MatmulReSize();
int Conv1x1ReSize();
int Run() override;
using ParallelRun = int (MatmulFp32BaseCPUKernel::*)(int task_id) const;
@@ -112,6 +116,8 @@ class MatmulFp32BaseCPUKernel : public LiteKernel {
bool pack_opt_{false}; // indicate whether packing can be multi-threads, currently, only support in ARM64 && packA.
MatrixPackFun matrix_a_pack_fun_ = nullptr;
MatrixPackFun matrix_b_pack_fun_ = nullptr;
float *conv1x1_origin_weight_ = nullptr;
float *conv1x1_origin_bias_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_MATMUL_FP32_BASE_H_