forked from mindspore-Ecosystem/mindspore
conv im2col support avx512
This commit is contained in:
parent
ca134f7940
commit
247e24d8d4
@@ -61,3 +61,4 @@
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental" "unreadVariable"
"mindspore/mindspore/lite/python/src/pybind_module.cc" "syntaxError"
"mindspore/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.cc" "knownConditionTrueFalse"
"mindspore/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.cc" "shadowVariable"
@@ -92,7 +92,9 @@ file(GLOB KERNEL_SRC
    ${NNACL_DIR}/experimental/*.c
)

set(KERNEL_AVX512_FILE ${NNACL_DIR}/fp32/matmul_avx512_fp32.c)
set(KERNEL_AVX512_FILE ${NNACL_DIR}/fp32/matmul_avx512_fp32.c
    ${NNACL_DIR}/fp32/conv_im2col_avx512_fp32.c
)
list(REMOVE_ITEM KERNEL_SRC ${KERNEL_AVX512_FILE})

set(KERNEL_AVX_FILE ${NNACL_DIR}/fp32/conv_sw_avx_fp32.c
@@ -0,0 +1,92 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/fp32/conv_im2col_avx512_fp32.h"
#include "nnacl/fp32/conv_im2col_fp32.h"
#include "nnacl/fp32/matmul_avx512_fp32.h"
#include "nnacl/intrinsics/ms_simd_avx512_instructions.h"

// fp32 conv common
void ConvIm2ColAVX512Fp32(const float *input_data, float *packed_input, const float *packed_weight,
                          const float *bias_data, float *output_data, int task_id, const ConvParameter *conv_param,
                          int cal_num) {
  if (conv_param->thread_num_ == 0) {
    return;
  }
  int output_hw = conv_param->output_h_ * conv_param->output_w_;
  int out_channel_align = UP_ROUND(conv_param->output_channel_, C16NUM);

  int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), conv_param->thread_num_);
  int start_block = block_per_thread * task_id;
  int start_hw = start_block * cal_num;
  int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * cal_num);
  if (start_hw >= end_hw) {
    return;
  }
  int out_stride = out_channel_align * cal_num;
  int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
  packed_input += task_id * deep * cal_num;
  size_t input_size = deep * cal_num * sizeof(float);

  for (int b = 0; b < conv_param->input_batch_; b++) {
    int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
    int out_offset = b * out_channel_align * output_hw + start_hw * out_channel_align;
    for (int i = start_hw; i < end_hw; i += cal_num, out_offset += out_stride) {
      int real_cal_row = MSMIN(output_hw - i, cal_num);
      memset(packed_input, 0, input_size);
      Im2ColDataPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i);

      float *gemm_output = output_data + out_offset;
      MatMulAvx512Fp32(packed_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep,
                       out_channel_align, out_channel_align, real_cal_row);
    }
  }
}

// fp32 conv common
void ConvIm2ColAVX512Fp32CutByBatch(const float *input_data, float *packed_input, const float *packed_weight,
                                    const float *bias_data, float *output_data, int task_id,
                                    const ConvParameter *conv_param, int cal_num) {
  if (conv_param->thread_num_ == 0) {
    return;
  }
  int output_hw = conv_param->output_h_ * conv_param->output_w_;
  int out_channel_align = UP_ROUND(conv_param->output_channel_, C16NUM);

  int block_batch_per_thread = UP_DIV(conv_param->input_batch_, conv_param->thread_num_);
  int start_batch = block_batch_per_thread * task_id;
  int end_batch = MSMIN(conv_param->input_batch_, (start_batch + block_batch_per_thread));

  int out_stride = out_channel_align * cal_num;
  int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
  packed_input += task_id * deep * cal_num;

  size_t input_size = deep * cal_num * sizeof(float);

  for (int b = start_batch; b < end_batch; b++) {
    int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
    int out_offset = b * out_channel_align * output_hw;
    for (int i = 0; i < output_hw; i += cal_num, out_offset += out_stride) {
      int real_cal_row = MSMIN(output_hw - i, cal_num);
      memset(packed_input, 0, input_size);
      Im2ColDataPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i);

      float *gemm_output = output_data + out_offset;
      MatMulAvx512Fp32(packed_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep,
                       out_channel_align, out_channel_align, real_cal_row);
    }
  }
}
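The two entry points above differ only in how they split work across threads: ConvIm2ColAVX512Fp32 tiles the output rows (output_hw) in blocks of cal_num, while the CutByBatch variant hands whole batches to each thread. Below is a minimal standalone sketch of the row-block split; the UP_DIV/UP_ROUND/MSMIN definitions and the sample sizes are assumptions for illustration, not copied from this commit.

/* Standalone sketch of the per-thread work split used by ConvIm2ColAVX512Fp32.
 * The macro definitions below are assumptions matching common nnacl usage. */
#include <stdio.h>

#define UP_DIV(x, y)   (((x) + (y) - 1) / (y))   /* ceil(x / y) */
#define UP_ROUND(x, y) (UP_DIV((x), (y)) * (y))  /* round x up to a multiple of y */
#define MSMIN(a, b)    ((a) < (b) ? (a) : (b))

int main(void) {
  int output_hw = 1000;  /* output_h_ * output_w_ (example value) */
  int cal_num = 150;     /* row tile, C150NUM for the AVX512 kernel */
  int thread_num = 4;

  int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), thread_num);
  for (int task_id = 0; task_id < thread_num; task_id++) {
    int start_block = block_per_thread * task_id;
    int start_hw = start_block * cal_num;
    int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * cal_num);
    if (start_hw >= end_hw) continue;  /* this thread has no rows to process */
    printf("task %d handles output rows [%d, %d)\n", task_id, start_hw, end_hw);
  }
  /* Output channels are padded the same way: UP_ROUND(out_channel, 16) for AVX512. */
  printf("out_channel 37 padded to %d\n", UP_ROUND(37, 16));
  return 0;
}

With these example sizes, four threads end up with the ranges [0, 300), [300, 600), [600, 900) and [900, 1000), and a 37-channel output is padded to 48 columns in the packed GEMM destination.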
@@ -0,0 +1,38 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_NNACL_FP32_CONV_IM2COL_AVX512_H_
#define MINDSPORE_NNACL_FP32_CONV_IM2COL_AVX512_H_

#include "nnacl/conv_parameter.h"

#ifdef __cplusplus
extern "C" {
#endif

void ConvIm2ColAVX512Fp32(const float *input_data, float *packed_input, const float *packed_weight,
                          const float *bias_data, float *output_data, int task_id, const ConvParameter *conv_param,
                          int cal_num);

void ConvIm2ColAVX512Fp32CutByBatch(const float *input_data, float *packed_input, const float *packed_weight,
                                    const float *bias_data, float *output_data, int task_id,
                                    const ConvParameter *conv_param, int cal_num);

#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_NNACL_FP32_CONV_IM2COL_AVX512_H_
@@ -0,0 +1,64 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/fp32/conv_im2col_fp32.h"

void Im2ColDataPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
                            int real_cal_num, int block_index) {
  // input format : nhwc
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  int kernel_plane = kernel_h * kernel_w;
  int dilation_h = conv_param->dilation_h_;
  int dilation_w = conv_param->dilation_w_;
  int out_w = conv_param->output_w_;
  if (dilation_h == 0 || dilation_w == 0 || out_w == 0) {
    return;
  }
  int in_channel = conv_param->input_channel_;
  int in_w = conv_param->input_w_;
  for (int i = 0; i < real_cal_num; i++) {
    int block_start = block_index + i;
    int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_;
    int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_;
    if (conv_param->input_h_ - input_h < 0 || in_w - input_w < 0) {
      continue;
    }
    int input_stride = (input_h * in_w + input_w) * in_channel;
    int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
    int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h));
    int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
    int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
    if (dilation_w == 1 && dilation_h == 1) {
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * in_w * in_channel + input_stride;
        int input_x_stride = input_y_stride + kw_s * in_channel;
        int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
        memcpy(packed_input + input_plane_offset, input_data + input_x_stride,
               (kw_e - kw_s) * in_channel * sizeof(float));
      }  // kernel_h loop
    } else {
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
        for (int k = kw_s; k < kw_e; ++k) {
          int input_x_stride = input_y_stride + k * dilation_w * in_channel;
          int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
          memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float));
        }
      }  // kernel_h loop
    }
  }  // tile num loop
}
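Im2ColDataPackUnitFp32 clips each kernel window against the input borders and copies only the valid rows and columns; the zero padding comes from the memset done by the callers in conv_im2col_avx512_fp32.c. A small worked example of that clipping follows, with made-up convolution parameters and the usual nnacl macro definitions assumed for illustration.

/* Illustration of the window clipping done in Im2ColDataPackUnitFp32 for a single
 * output position. Parameter values are hypothetical; the macros are assumed to
 * follow the usual nnacl definitions. */
#include <stdio.h>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(a, b)  ((a) < (b) ? (a) : (b))
#define MSMAX(a, b)  ((a) > (b) ? (a) : (b))

int main(void) {
  /* Hypothetical 3x3 convolution, stride 1, pad 1, dilation 1, 8x8 input, 8x8 output. */
  int kernel_h = 3, kernel_w = 3, stride_h = 1, stride_w = 1;
  int pad_u = 1, pad_l = 1, dilation_h = 1, dilation_w = 1;
  int in_h = 8, in_w = 8, out_w = 8;

  int block_start = 0;  /* first output pixel: row 0, column 0 */
  int input_h = block_start / out_w * stride_h - pad_u;  /* -1: window starts above the image */
  int input_w = block_start % out_w * stride_w - pad_l;  /* -1: window starts left of the image */

  /* Clip the kernel window to the valid input region, exactly as the packing loop does. */
  int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
  int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
  int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
  int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));

  /* For the top-left output pixel only rows/cols 1..2 of the 3x3 window hit real input;
   * the padded positions stay zero because packed_input was memset beforehand. */
  printf("kh in [%d, %d), kw in [%d, %d)\n", kh_s, kh_e, kw_s, kw_e);
  return 0;
}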
@@ -0,0 +1,33 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_NNACL_FP32_CONV_IM2COL_H_
#define MINDSPORE_NNACL_FP32_CONV_IM2COL_H_

#include "nnacl/conv_parameter.h"

#ifdef __cplusplus
extern "C" {
#endif

void Im2ColDataPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
                            int real_cal_num, int block_index);

#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_NNACL_FP32_CONV_IM2COL_H_
@@ -29,10 +29,6 @@ int ElementOptDiv(const float *in0, const float *in1, float *out, int size, cons
      out[index] = in0[0] / in1[index];
    }
  } else {
    if (in1[0] == 0) {
      return NNACL_ERRCODE_DIVISOR_ZERO;
    }

    SIMD_RUN_NO_SCALAR(ElementOptDivNum1, index, in0, in1, out, size);
    for (; index < size; index++) {
      out[index] = in0[index] / in1[0];
@@ -52,6 +52,7 @@
#define C56NUM 56
#define C64NUM 64
#define C128NUM 128
#define C150NUM 150
#define C256NUM 256
#define C1500NUM 1500
#define TILE_NUM 8
@@ -59,7 +59,9 @@ if(NOT("${X86_64_SIMD}" STREQUAL "avx" OR "${X86_64_SIMD}" STREQUAL "avx512"))
endif()

if(NOT("${X86_64_SIMD}" STREQUAL "avx512"))
    set(KERNEL_SRC_AVX512_FILE ${CMAKE_CURRENT_SOURCE_DIR}/fp32/matmul_fp32_avx512.cc)
    set(KERNEL_SRC_AVX512_FILE ${CMAKE_CURRENT_SOURCE_DIR}/fp32/convolution_im2col_avx512_fp32.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/fp32/matmul_fp32_avx512.cc
    )
    list(REMOVE_ITEM KERNEL_SRC ${KERNEL_SRC_AVX512_FILE})
endif()
@@ -0,0 +1,119 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/litert/kernel/cpu/fp32/convolution_im2col_avx512_fp32.h"
#include "nnacl/fp32/conv_im2col_avx512_fp32.h"

using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_NULL_PTR;
using mindspore::lite::RET_OK;

namespace mindspore::kernel {
void ConvolutionIm2ColAVX512CPUKernel::InitGlobalVariable() {
  oc_tile_ = C16NUM;
  row_tile_ = C150NUM;

  rowMajor2ColNMajorFunc = RowMajor2Col64Major;
}

int ConvolutionIm2ColAVX512CPUKernel::InitTmpBuffer() {
  MS_ASSERT(ctx_->allocator != nullptr);
  CHECK_NULL_RETURN(out_tensors_[0]);
  CHECK_NULL_RETURN(out_tensors_[0]->MutableData());

  int unit_size =
    conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * row_tile_ * thread_count_;

  if (packed_input_ != nullptr) {
    ctx_->allocator->Free(packed_input_);
    packed_input_ = nullptr;
  }
  packed_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(unit_size * sizeof(float)));
  if (packed_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc packed input failed.";
    return RET_ERROR;
  }

  if (conv_param_->output_channel_ % oc_tile_ != 0) {
    output_need_align_ = true;
    if (tmp_output_ != nullptr) {
      ctx_->allocator->Free(tmp_output_);
    }

    // avx512 need to malloc dst aligned to C16NUM
    int oc_algin = UP_ROUND(conv_param_->output_channel_, oc_tile_);
    int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * oc_algin;
    tmp_output_ =
      reinterpret_cast<float *>(ctx_->allocator->Malloc(pack_output_size * static_cast<int>(sizeof(float))));
    if (tmp_output_ == nullptr) {
      MS_LOG(ERROR) << "malloc tmp output data failed.";
      return RET_NULL_PTR;
    }
  }

  return RET_OK;
}

int ConvolutionIm2ColAVX512CPUKernel::RunImpl(int task_id) {
  auto ori_input_data = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->data());
  if (out_tensors_[0]->format() != NC4HW4) {
    if (use_batch_cut_flag_) {
      ConvIm2ColAVX512Fp32CutByBatch(ori_input_data, packed_input_, reinterpret_cast<float *>(packed_weight_),
                                     reinterpret_cast<float *>(bias_data_), tmp_output_, task_id, conv_param_,
                                     row_tile_);
    } else {
      ConvIm2ColAVX512Fp32(ori_input_data, packed_input_, reinterpret_cast<float *>(packed_weight_),
                           reinterpret_cast<float *>(bias_data_), tmp_output_, task_id, conv_param_, row_tile_);
    }
  } else {
    MS_LOG(ERROR) << "ConvolutionIm2ColAVX512CPUKernel do not support NC4HW4 output-format's avx512 version";
    return RET_ERROR;
  }
  return RET_OK;
}

int ConvolutionIm2ColAVX512CPUKernel::Run() {
  auto ret = InitTmpBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init tmp buffer failed.";
    FreeTmpBuffer();
    return RET_ERROR;
  }
  auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
  if (!output_need_align_) {
    tmp_output_ = output_addr;
  }
  if (RepackWeight() != RET_OK) {
    FreeTmpBuffer();
    MS_LOG(ERROR) << "Repack weight failed.";
    return RET_ERROR;
  }
  ret = ParallelLaunch(this->ms_context_, ConvolutionIm2ColImpl, this, thread_count_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "conv error error_code[" << ret << "]";
  }

  if (output_need_align_) {
    PackNHWCXToNHWCFp32(tmp_output_, output_addr, conv_param_->output_batch_,
                        conv_param_->output_w_ * conv_param_->output_h_, conv_param_->output_channel_, oc_tile_);
  } else {
    tmp_output_ = nullptr;
  }

  FreeTmpBuffer();
  return ret;
}
}  // namespace mindspore::kernel
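InitTmpBuffer() and Run() only need the extra tmp_output_ buffer when the output channel count is not a multiple of oc_tile_ (16): the AVX512 GEMM writes NHWC rows whose channel dimension is rounded up to 16, and the padded rows are trimmed back into the real output tensor afterwards. The sketch below only illustrates that trim step; the real kernel calls nnacl's PackNHWCXToNHWCFp32, and trim_aligned_channels/example here are hypothetical helper names invented for the example.

/* Naive reference for the alignment handling in InitTmpBuffer()/Run(), under the
 * assumptions stated above. Not the nnacl implementation. */
#include <string.h>

#define UP_ROUND(x, y) ((((x) + (y) - 1) / (y)) * (y))

/* Copy `plane` rows of `channel` floats out of a buffer whose rows are padded
 * to `channel_align` floats (the trailing floats of each row are GEMM padding). */
static void trim_aligned_channels(const float *aligned, float *dst, int plane, int channel,
                                  int channel_align) {
  for (int p = 0; p < plane; p++) {
    memcpy(dst + p * channel, aligned + p * channel_align, channel * sizeof(float));
  }
}

void example(const float *tmp_output, float *output, int batch, int out_h, int out_w,
             int out_channel) {
  int oc_align = UP_ROUND(out_channel, 16);  /* oc_tile_ = C16NUM for AVX512 */
  int plane = batch * out_h * out_w;
  if (out_channel == oc_align) return;       /* output_need_align_ == false: GEMM wrote directly */
  trim_aligned_channels(tmp_output, output, plane, out_channel, oc_align);
}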
@@ -0,0 +1,39 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CONVOLUTION_IM2COL_AVX512_FP32_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CONVOLUTION_IM2COL_AVX512_FP32_H_

#include <vector>
#include "src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.h"

namespace mindspore::kernel {
class ConvolutionIm2ColAVX512CPUKernel : public ConvolutionIm2ColBaseCPUKernel {
 public:
  ConvolutionIm2ColAVX512CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                   const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                                   float *origin_weight, float *origin_bias)
      : ConvolutionIm2ColBaseCPUKernel(parameter, inputs, outputs, ctx, origin_weight, origin_bias) {}
  ~ConvolutionIm2ColAVX512CPUKernel() override {}

  void InitGlobalVariable() override;
  int InitTmpBuffer() override;
  int RunImpl(int task_id) override;
  int Run() override;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CONVOLUTION_IM2COL_AVX512_FP32_H_
@@ -37,12 +37,20 @@ int ConvolutionIm2ColAVXCPUKernel::InitTmpBuffer() {
  int unit_size =
    conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * row_tile_ * thread_count_;

  if (packed_input_ != nullptr) {
    ctx_->allocator->Free(packed_input_);
    packed_input_ = nullptr;
  }
  packed_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(unit_size * sizeof(float)));
  if (packed_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc packed input failed.";
    return RET_ERROR;
  }

  if (col_major_input_ != nullptr) {
    ctx_->allocator->Free(col_major_input_);
    col_major_input_ = nullptr;
  }
  col_major_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(unit_size * sizeof(float)));
  if (col_major_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc col_major_input_ failed.";
@@ -54,6 +62,10 @@ int ConvolutionIm2ColAVXCPUKernel::InitTmpBuffer() {
    int oc_algin = UP_DIV(conv_param_->output_channel_, oc_tile_);
    int pack_output_size =
      conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * oc_tile_ * oc_algin;
    if (tmp_output_ != nullptr) {
      ctx_->allocator->Free(tmp_output_);
      tmp_output_ = nullptr;
    }
    tmp_output_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(pack_output_size * sizeof(float)));
    if (tmp_output_ == nullptr) {
      MS_LOG(ERROR) << "Malloc tmp_output_ buffer is failed.";
@@ -25,6 +25,7 @@

using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
using mindspore::lite::RET_NULL_PTR;
using mindspore::lite::RET_OK;

namespace mindspore::kernel {
@@ -44,13 +45,20 @@ int ConvolutionIm2ColBaseCPUKernel::InitTmpBuffer() {

  int unit_size =
    conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * row_tile_ * thread_count_;

  if (packed_input_ != nullptr) {
    ctx_->allocator->Free(packed_input_);
    packed_input_ = nullptr;
  }
  packed_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(unit_size * sizeof(float)));
  if (packed_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc packed input failed.";
    return RET_ERROR;
  }

  if (col_major_input_ != nullptr) {
    ctx_->allocator->Free(col_major_input_);
    col_major_input_ = nullptr;
  }
  col_major_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(unit_size * sizeof(float)));
  if (col_major_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc col_major_input_ failed.";
@@ -16,6 +16,9 @@

#include "src/litert/kernel/cpu/fp32/convolution_im2col_fp32.h"
#include "src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.h"
#if defined(ENABLE_AVX512)
#include "src/litert/kernel/cpu/fp32/convolution_im2col_avx512_fp32.h"
#endif
#if defined(ENABLE_AVX)
#include "src/litert/kernel/cpu/fp32/convolution_im2col_avx_fp32.h"
#endif
@@ -31,12 +34,22 @@
#if defined(ENABLE_ARM64)
#include "src/litert/kernel/cpu/fp32/convolution_im2col_arm64_fp32.h"
#endif
#include "nnacl/intrinsics/ms_simd_cpu_info.h"

namespace mindspore::kernel {
LiteKernel *CreateConvolutionIm2ColCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                             const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                                             float *origin_weight, float *origin_bias) {
  LiteKernel *kernel = nullptr;
#if defined(ENABLE_AVX512)
  if (kernel == nullptr && outputs.front()->format() != NC4HW4) {
    AVX512_HARDWARE_SELF_AWARENESS_BEGIN;
    kernel = new (std::nothrow)
      kernel::ConvolutionIm2ColAVX512CPUKernel(parameter, inputs, outputs, ctx, origin_weight, origin_bias);
    AVX512_HARDWARE_SELF_AWARENESS_END;
  }
#endif

#if defined(ENABLE_AVX)
  if (kernel == nullptr) {
    kernel = new (std::nothrow)
@@ -45,21 +58,25 @@ LiteKernel *CreateConvolutionIm2ColCPUKernel(OpParameter *parameter, const std::
#endif

#if defined(ENABLE_SSE)
  if (kernel == nullptr) {
  if (kernel == nullptr && outputs.front()->format() != NC4HW4) {
    kernel = new (std::nothrow)
      kernel::ConvolutionIm2ColSSECPUKernel(parameter, inputs, outputs, ctx, origin_weight, origin_bias);
  }
#endif

#if defined(ENABLE_ARM64)
  kernel = new (std::nothrow)
    kernel::ConvolutionIm2ColARM64CPUKernel(parameter, inputs, outputs, ctx, origin_weight, origin_bias);
  if (kernel == nullptr) {
    kernel = new (std::nothrow)
      kernel::ConvolutionIm2ColARM64CPUKernel(parameter, inputs, outputs, ctx, origin_weight, origin_bias);
  }
#elif defined(ENABLE_ARM32)
  kernel = new (std::nothrow)
    kernel::ConvolutionIm2ColARM32CPUKernel(parameter, inputs, outputs, ctx, origin_weight, origin_bias);
  if (kernel == nullptr && outputs.front()->format() != NC4HW4) {
    kernel = new (std::nothrow)
      kernel::ConvolutionIm2ColARM32CPUKernel(parameter, inputs, outputs, ctx, origin_weight, origin_bias);
  }
#endif

  if (kernel == nullptr) {
  if (kernel == nullptr && outputs.front()->format() != NC4HW4) {
    kernel = new (std::nothrow)
      kernel::ConvolutionIm2ColBaseCPUKernel(parameter, inputs, outputs, ctx, origin_weight, origin_bias);
  }