conv im2col support avx512

This commit is contained in:
greatpan 2022-08-25 16:54:34 +08:00
parent ca134f7940
commit 247e24d8d4
14 changed files with 437 additions and 13 deletions

View File

@ -61,3 +61,4 @@
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental" "unreadVariable"
"mindspore/mindspore/lite/python/src/pybind_module.cc" "syntaxError"
"mindspore/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.cc" "knownConditionTrueFalse"
"mindspore/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.cc" "shadowVariable"

View File

@ -92,7 +92,9 @@ file(GLOB KERNEL_SRC
${NNACL_DIR}/experimental/*.c
)
set(KERNEL_AVX512_FILE ${NNACL_DIR}/fp32/matmul_avx512_fp32.c)
set(KERNEL_AVX512_FILE ${NNACL_DIR}/fp32/matmul_avx512_fp32.c
${NNACL_DIR}/fp32/conv_im2col_avx512_fp32.c
)
list(REMOVE_ITEM KERNEL_SRC ${KERNEL_AVX512_FILE})
set(KERNEL_AVX_FILE ${NNACL_DIR}/fp32/conv_sw_avx_fp32.c

View File

@ -0,0 +1,92 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/conv_im2col_avx512_fp32.h"
#include "nnacl/fp32/conv_im2col_fp32.h"
#include "nnacl/fp32/matmul_avx512_fp32.h"
#include "nnacl/intrinsics/ms_simd_avx512_instructions.h"
// fp32 conv common
void ConvIm2ColAVX512Fp32(const float *input_data, float *packed_input, const float *packed_weight,
const float *bias_data, float *output_data, int task_id, const ConvParameter *conv_param,
int cal_num) {
if (conv_param->thread_num_ == 0) {
return;
}
int output_hw = conv_param->output_h_ * conv_param->output_w_;
int out_channel_align = UP_ROUND(conv_param->output_channel_, C16NUM);
int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), conv_param->thread_num_);
int start_block = block_per_thread * task_id;
int start_hw = start_block * cal_num;
int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * cal_num);
if (start_hw >= end_hw) {
return;
}
int out_stride = out_channel_align * cal_num;
int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
packed_input += task_id * deep * cal_num;
size_t input_size = deep * cal_num * sizeof(float);
for (int b = 0; b < conv_param->input_batch_; b++) {
int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
int out_offset = b * out_channel_align * output_hw + start_hw * out_channel_align;
for (int i = start_hw; i < end_hw; i += cal_num, out_offset += out_stride) {
int real_cal_row = MSMIN(output_hw - i, cal_num);
memset(packed_input, 0, input_size);
Im2ColDataPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i);
float *gemm_output = output_data + out_offset;
MatMulAvx512Fp32(packed_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep,
out_channel_align, out_channel_align, real_cal_row);
}
}
}
// fp32 conv common
void ConvIm2ColAVX512Fp32CutByBatch(const float *input_data, float *packed_input, const float *packed_weight,
const float *bias_data, float *output_data, int task_id,
const ConvParameter *conv_param, int cal_num) {
if (conv_param->thread_num_ == 0) {
return;
}
int output_hw = conv_param->output_h_ * conv_param->output_w_;
int out_channel_align = UP_ROUND(conv_param->output_channel_, C16NUM);
int block_batch_per_thread = UP_DIV(conv_param->input_batch_, conv_param->thread_num_);
int start_batch = block_batch_per_thread * task_id;
int end_batch = MSMIN(conv_param->input_batch_, (start_batch + block_batch_per_thread));
int out_stride = out_channel_align * cal_num;
int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
packed_input += task_id * deep * cal_num;
size_t input_size = deep * cal_num * sizeof(float);
for (int b = start_batch; b < end_batch; b++) {
int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
int out_offset = b * out_channel_align * output_hw;
for (int i = 0; i < output_hw; i += cal_num, out_offset += out_stride) {
int real_cal_row = MSMIN(output_hw - i, cal_num);
memset(packed_input, 0, input_size);
Im2ColDataPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i);
float *gemm_output = output_data + out_offset;
MatMulAvx512Fp32(packed_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep,
out_channel_align, out_channel_align, real_cal_row);
}
}
}

View File

@ -0,0 +1,38 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_FP32_CONV_IM2COL_AVX512_H_
#define MINDSPORE_NNACL_FP32_CONV_IM2COL_AVX512_H_
#include "nnacl/conv_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
void ConvIm2ColAVX512Fp32(const float *input_data, float *packed_input, const float *packed_weight,
const float *bias_data, float *output_data, int task_id, const ConvParameter *conv_param,
int cal_num);
void ConvIm2ColAVX512Fp32CutByBatch(const float *input_data, float *packed_input, const float *packed_weight,
const float *bias_data, float *output_data, int task_id,
const ConvParameter *conv_param, int cal_num);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_FP32_CONV_IM2COL_AVX512_H_

View File

@ -0,0 +1,64 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/conv_im2col_fp32.h"
void Im2ColDataPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
int real_cal_num, int block_index) {
// input format : nhwc
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int kernel_plane = kernel_h * kernel_w;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
int out_w = conv_param->output_w_;
if (dilation_h == 0 || dilation_w == 0 || out_w == 0) {
return;
}
int in_channel = conv_param->input_channel_;
int in_w = conv_param->input_w_;
for (int i = 0; i < real_cal_num; i++) {
int block_start = block_index + i;
int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_;
int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_;
if (conv_param->input_h_ - input_h < 0 || in_w - input_w < 0) {
continue;
}
int input_stride = (input_h * in_w + input_w) * in_channel;
int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h));
int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
if (dilation_w == 1 && dilation_h == 1) {
for (int j = kh_s; j < kh_e; j++) {
int input_y_stride = j * in_w * in_channel + input_stride;
int input_x_stride = input_y_stride + kw_s * in_channel;
int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
memcpy(packed_input + input_plane_offset, input_data + input_x_stride,
(kw_e - kw_s) * in_channel * sizeof(float));
} // kernel_h loop
} else {
for (int j = kh_s; j < kh_e; j++) {
int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
for (int k = kw_s; k < kw_e; ++k) {
int input_x_stride = input_y_stride + k * dilation_w * in_channel;
int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float));
}
} // kernel_h loop
}
} // tile num loop
}

View File

@ -0,0 +1,33 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_FP32_CONV_IM2COL_H_
#define MINDSPORE_NNACL_FP32_CONV_IM2COL_H_
#include "nnacl/conv_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
void Im2ColDataPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
int real_cal_num, int block_index);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_FP32_CONV_IM2COL_H_

View File

@ -29,10 +29,6 @@ int ElementOptDiv(const float *in0, const float *in1, float *out, int size, cons
out[index] = in0[0] / in1[index];
}
} else {
if (in1[0] == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
SIMD_RUN_NO_SCALAR(ElementOptDivNum1, index, in0, in1, out, size);
for (; index < size; index++) {
out[index] = in0[index] / in1[0];

View File

@ -52,6 +52,7 @@
#define C56NUM 56
#define C64NUM 64
#define C128NUM 128
#define C150NUM 150
#define C256NUM 256
#define C1500NUM 1500
#define TILE_NUM 8

View File

@ -59,7 +59,9 @@ if(NOT("${X86_64_SIMD}" STREQUAL "avx" OR "${X86_64_SIMD}" STREQUAL "avx512"))
endif()
if(NOT("${X86_64_SIMD}" STREQUAL "avx512"))
set(KERNEL_SRC_AVX512_FILE ${CMAKE_CURRENT_SOURCE_DIR}/fp32/matmul_fp32_avx512.cc)
set(KERNEL_SRC_AVX512_FILE ${CMAKE_CURRENT_SOURCE_DIR}/fp32/convolution_im2col_avx512_fp32.cc
{CMAKE_CURRENT_SOURCE_DIR}/fp32/matmul_fp32_avx512.cc
)
list(REMOVE_ITEM KERNEL_SRC ${KERNEL_SRC_AVX512_FILE})
endif()

View File

@ -0,0 +1,119 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/litert/kernel/cpu/fp32/convolution_im2col_avx512_fp32.h"
#include "nnacl/fp32/conv_im2col_avx512_fp32.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_NULL_PTR;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
void ConvolutionIm2ColAVX512CPUKernel::InitGlobalVariable() {
oc_tile_ = C16NUM;
row_tile_ = C150NUM;
rowMajor2ColNMajorFunc = RowMajor2Col64Major;
}
int ConvolutionIm2ColAVX512CPUKernel::InitTmpBuffer() {
MS_ASSERT(ctx_->allocator != nullptr);
CHECK_NULL_RETURN(out_tensors_[0]);
CHECK_NULL_RETURN(out_tensors_[0]->MutableData());
int unit_size =
conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * row_tile_ * thread_count_;
if (packed_input_ != nullptr) {
ctx_->allocator->Free(packed_input_);
packed_input_ = nullptr;
}
packed_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(unit_size * sizeof(float)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "malloc packed input failed.";
return RET_ERROR;
}
if (conv_param_->output_channel_ % oc_tile_ != 0) {
output_need_align_ = true;
if (tmp_output_ != nullptr) {
ctx_->allocator->Free(tmp_output_);
}
// avx512 need to malloc dst aligned to C16NUM
int oc_algin = UP_ROUND(conv_param_->output_channel_, oc_tile_);
int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * oc_algin;
tmp_output_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(pack_output_size * static_cast<int>(sizeof(float))));
if (tmp_output_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp output data failed.";
return RET_NULL_PTR;
}
}
return RET_OK;
}
int ConvolutionIm2ColAVX512CPUKernel::RunImpl(int task_id) {
auto ori_input_data = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->data());
if (out_tensors_[0]->format() != NC4HW4) {
if (use_batch_cut_flag_) {
ConvIm2ColAVX512Fp32CutByBatch(ori_input_data, packed_input_, reinterpret_cast<float *>(packed_weight_),
reinterpret_cast<float *>(bias_data_), tmp_output_, task_id, conv_param_,
row_tile_);
} else {
ConvIm2ColAVX512Fp32(ori_input_data, packed_input_, reinterpret_cast<float *>(packed_weight_),
reinterpret_cast<float *>(bias_data_), tmp_output_, task_id, conv_param_, row_tile_);
}
} else {
MS_LOG(ERROR) << "ConvolutionIm2ColAVX512CPUKernel do not support NC4HW4 output-format's avx512 version";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionIm2ColAVX512CPUKernel::Run() {
auto ret = InitTmpBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init tmp buffer failed.";
FreeTmpBuffer();
return RET_ERROR;
}
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
if (!output_need_align_) {
tmp_output_ = output_addr;
}
if (RepackWeight() != RET_OK) {
FreeTmpBuffer();
MS_LOG(ERROR) << "Repack weight failed.";
return RET_ERROR;
}
ret = ParallelLaunch(this->ms_context_, ConvolutionIm2ColImpl, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "conv error error_code[" << ret << "]";
}
if (output_need_align_) {
PackNHWCXToNHWCFp32(tmp_output_, output_addr, conv_param_->output_batch_,
conv_param_->output_w_ * conv_param_->output_h_, conv_param_->output_channel_, oc_tile_);
} else {
tmp_output_ = nullptr;
}
FreeTmpBuffer();
return ret;
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CONVOLUTION_IM2COL_AVX512_FP32_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CONVOLUTION_IM2COL_AVX512_FP32_H_
#include <vector>
#include "src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.h"
namespace mindspore::kernel {
class ConvolutionIm2ColAVX512CPUKernel : public ConvolutionIm2ColBaseCPUKernel {
public:
ConvolutionIm2ColAVX512CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
float *origin_weight, float *origin_bias)
: ConvolutionIm2ColBaseCPUKernel(parameter, inputs, outputs, ctx, origin_weight, origin_bias) {}
~ConvolutionIm2ColAVX512CPUKernel() override {}
void InitGlobalVariable() override;
int InitTmpBuffer() override;
int RunImpl(int task_id) override;
int Run() override;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CONVOLUTION_IM2COL_FP32_H_

View File

@ -37,12 +37,20 @@ int ConvolutionIm2ColAVXCPUKernel::InitTmpBuffer() {
int unit_size =
conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * row_tile_ * thread_count_;
if (packed_input_ != nullptr) {
ctx_->allocator->Free(packed_input_);
packed_input_ = nullptr;
}
packed_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(unit_size * sizeof(float)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "malloc packed input failed.";
return RET_ERROR;
}
if (col_major_input_ != nullptr) {
ctx_->allocator->Free(col_major_input_);
col_major_input_ = nullptr;
}
col_major_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(unit_size * sizeof(float)));
if (col_major_input_ == nullptr) {
MS_LOG(ERROR) << "malloc col_major_input_ failed.";
@ -54,6 +62,10 @@ int ConvolutionIm2ColAVXCPUKernel::InitTmpBuffer() {
int oc_algin = UP_DIV(conv_param_->output_channel_, oc_tile_);
int pack_output_size =
conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * oc_tile_ * oc_algin;
if (tmp_output_ != nullptr) {
ctx_->allocator->Free(tmp_output_);
tmp_output_ = nullptr;
}
tmp_output_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(pack_output_size * sizeof(float)));
if (tmp_output_ == nullptr) {
MS_LOG(ERROR) << "Malloc tmp_output_ buffer is failed.";

View File

@ -25,6 +25,7 @@
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
using mindspore::lite::RET_NULL_PTR;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
@ -44,13 +45,20 @@ int ConvolutionIm2ColBaseCPUKernel::InitTmpBuffer() {
int unit_size =
conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * row_tile_ * thread_count_;
if (packed_input_ != nullptr) {
ctx_->allocator->Free(packed_input_);
packed_input_ = nullptr;
}
packed_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(unit_size * sizeof(float)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "malloc packed input failed.";
return RET_ERROR;
}
if (col_major_input_ != nullptr) {
ctx_->allocator->Free(col_major_input_);
col_major_input_ = nullptr;
}
col_major_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(unit_size * sizeof(float)));
if (col_major_input_ == nullptr) {
MS_LOG(ERROR) << "malloc col_major_input_ failed.";

View File

@ -16,6 +16,9 @@
#include "src/litert/kernel/cpu/fp32/convolution_im2col_fp32.h"
#include "src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.h"
#if defined(ENABLE_AVX512)
#include "src/litert/kernel/cpu/fp32/convolution_im2col_avx512_fp32.h"
#endif
#if defined(ENABLE_AVX)
#include "src/litert/kernel/cpu/fp32/convolution_im2col_avx_fp32.h"
#endif
@ -31,12 +34,22 @@
#if defined(ENABLE_ARM64)
#include "src/litert/kernel/cpu/fp32/convolution_im2col_arm64_fp32.h"
#endif
#include "nnacl/intrinsics/ms_simd_cpu_info.h"
namespace mindspore::kernel {
LiteKernel *CreateConvolutionIm2ColCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
float *origin_weight, float *origin_bias) {
LiteKernel *kernel = nullptr;
#if defined(ENABLE_AVX512)
if (kernel == nullptr && outputs.front()->format() != NC4HW4) {
AVX512_HARDWARE_SELF_AWARENESS_BEGIN;
kernel = new (std::nothrow)
kernel::ConvolutionIm2ColAVX512CPUKernel(parameter, inputs, outputs, ctx, origin_weight, origin_bias);
AVX512_HARDWARE_SELF_AWARENESS_END;
}
#endif
#if defined(ENABLE_AVX)
if (kernel == nullptr) {
kernel = new (std::nothrow)
@ -45,21 +58,25 @@ LiteKernel *CreateConvolutionIm2ColCPUKernel(OpParameter *parameter, const std::
#endif
#if defined(ENABLE_SSE)
if (kernel == nullptr) {
if (kernel == nullptr && outputs.front()->format() != NC4HW4) {
kernel = new (std::nothrow)
kernel::ConvolutionIm2ColSSECPUKernel(parameter, inputs, outputs, ctx, origin_weight, origin_bias);
}
#endif
#if defined(ENABLE_ARM64)
if (kernel == nullptr) {
kernel = new (std::nothrow)
kernel::ConvolutionIm2ColARM64CPUKernel(parameter, inputs, outputs, ctx, origin_weight, origin_bias);
}
#elif defined(ENABLE_ARM32)
if (kernel == nullptr && outputs.front()->format() != NC4HW4) {
kernel = new (std::nothrow)
kernel::ConvolutionIm2ColARM32CPUKernel(parameter, inputs, outputs, ctx, origin_weight, origin_bias);
}
#endif
if (kernel == nullptr) {
if (kernel == nullptr && outputs.front()->format() != NC4HW4) {
kernel = new (std::nothrow)
kernel::ConvolutionIm2ColBaseCPUKernel(parameter, inputs, outputs, ctx, origin_weight, origin_bias);
}