forked from mindspore-Ecosystem/mindspore
conv winograd op refactor
This commit is contained in:
parent
4a616f1b27
commit
c749dd57a1
|
@ -64,3 +64,5 @@
|
|||
"mindspore/mindspore/lite/python/src/pybind_module.cc" "syntaxError"
|
||||
"mindspore/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.cc" "knownConditionTrueFalse"
|
||||
"mindspore/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.cc" "shadowVariable"
|
||||
"mindspore/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.cc" "knownConditionTrueFalse"
|
||||
"mindspore/mindspore/lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.cc" "shadowVariable"
|
||||
|
|
|
@ -32,6 +32,7 @@ endif()
|
|||
if(NOT PLATFORM_ARM64)
|
||||
set(KERNEL_SRC_ARM64_FILE ${CMAKE_CURRENT_SOURCE_DIR}/fp32/convolution_im2col_arm64_fp32.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fp32/matmul_fp32_arm64.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fp32/convolution_winograd_arm64_fp32.cc
|
||||
)
|
||||
list(REMOVE_ITEM KERNEL_SRC ${KERNEL_SRC_ARM64_FILE})
|
||||
endif()
|
||||
|
@ -39,6 +40,7 @@ endif()
|
|||
if(NOT PLATFORM_ARM32)
|
||||
set(KERNEL_SRC_ARM32_FILE ${CMAKE_CURRENT_SOURCE_DIR}/fp32/convolution_im2col_arm32_fp32.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fp32/matmul_fp32_arm32.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fp32/convolution_winograd_arm32_fp32.cc
|
||||
)
|
||||
list(REMOVE_ITEM KERNEL_SRC ${KERNEL_SRC_ARM32_FILE})
|
||||
endif()
|
||||
|
@ -46,6 +48,7 @@ endif()
|
|||
if(NOT("${X86_64_SIMD}" STREQUAL "sse" OR "${X86_64_SIMD}" STREQUAL "avx" OR "${X86_64_SIMD}" STREQUAL "avx512"))
|
||||
set(KERNEL_SRC_SSE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/fp32/convolution_im2col_sse_fp32.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fp32/matmul_fp32_sse.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fp32/convolution_winograd_sse_fp32.cc
|
||||
)
|
||||
list(REMOVE_ITEM KERNEL_SRC ${KERNEL_SRC_SSE_FILE})
|
||||
endif()
|
||||
|
@ -54,6 +57,7 @@ if(NOT("${X86_64_SIMD}" STREQUAL "avx" OR "${X86_64_SIMD}" STREQUAL "avx512"))
|
|||
set(KERNEL_SRC_AVX_FILE ${CMAKE_CURRENT_SOURCE_DIR}/fp32/convolution_im2col_avx_fp32.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fp32/matmul_fp32_avx.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fp32/convolution_slidewindows_avx_fp32.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fp32/convolution_winograd_avx_fp32.cc
|
||||
)
|
||||
list(REMOVE_ITEM KERNEL_SRC ${KERNEL_SRC_AVX_FILE})
|
||||
endif()
|
||||
|
|
|
@ -258,9 +258,9 @@ kernel::LiteKernel *ConvolutionDelegateCPUKernel::CpuConvFp32NHWCKernelSelect()
|
|||
|
||||
int out_unit;
|
||||
if (CheckIfUseWinograd(&out_unit, conv_param)) {
|
||||
kernel = new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(
|
||||
op_parameter_, in_tensors_, out_tensors_, static_cast<const lite::InnerContext *>(this->ms_context_), out_unit,
|
||||
origin_weight_, origin_bias_);
|
||||
kernel = CreateConvolutionWinogradCPUKernel(op_parameter_, in_tensors_, out_tensors_,
|
||||
static_cast<const lite::InnerContext *>(this->ms_context_), out_unit,
|
||||
origin_weight_, origin_bias_);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/litert/kernel/cpu/fp32/convolution_winograd_arm32_fp32.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
void ConvolutionWinogradARM32CPUKernel::InitGlobalVariable() {
|
||||
oc_block_ = C8NUM;
|
||||
tmp_data_tile_ = C4NUM;
|
||||
tile_num_ = C12NUM;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_ARM32_FP32_CONVOLUTION_WINOGRAD_FP32_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_ARM32_FP32_CONVOLUTION_WINOGRAD_FP32_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ConvolutionWinogradARM32CPUKernel : public ConvolutionWinogradBaseCPUKernel {
|
||||
public:
|
||||
ConvolutionWinogradARM32CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
int output_unit, float *origin_weight, float *origin_bias)
|
||||
: ConvolutionWinogradBaseCPUKernel(parameter, inputs, outputs, ctx, output_unit, origin_weight, origin_bias) {}
|
||||
~ConvolutionWinogradARM32CPUKernel() override {}
|
||||
void InitGlobalVariable() override;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_ARM32_FP32_CONVOLUTION_WINOGRAD_FP32_H_
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/litert/kernel/cpu/fp32/convolution_winograd_arm64_fp32.h"
|
||||
#include "nnacl/fp32/winograd_utils.h"
|
||||
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_MEMORY_FAILED;
|
||||
using mindspore::lite::RET_NULL_PTR;
|
||||
using mindspore::lite::RET_OK;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
void ConvolutionWinogradARM64CPUKernel::InitGlobalVariable() {
|
||||
oc_block_ = C8NUM;
|
||||
tmp_data_tile_ = C4NUM;
|
||||
tile_num_ = C12NUM;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradARM64CPUKernel::ConfigInputOutput() {
|
||||
trans_func_.in_func_ = GetInputTransFunc(input_unit_);
|
||||
if (trans_func_.in_func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "in_func_ is null.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
trans_func_.in_step_func_ = GetInputTransStepFunc(input_unit_);
|
||||
if (trans_func_.in_step_func_ == nullptr) {
|
||||
MS_LOG(DEBUG) << "in_step_func_ is null.";
|
||||
}
|
||||
trans_func_.in_pack_func_ = GetInputTransPackFunc(input_unit_);
|
||||
if (trans_func_.in_pack_func_ == nullptr) {
|
||||
MS_LOG(DEBUG) << "in_pack_func_ is null.";
|
||||
}
|
||||
|
||||
trans_func_.out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_);
|
||||
if (trans_func_.out_func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "out_func_ is null.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_ARM64_FP32_CONVOLUTION_WINOGRAD_FP32_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_ARM64_FP32_CONVOLUTION_WINOGRAD_FP32_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ConvolutionWinogradARM64CPUKernel : public ConvolutionWinogradBaseCPUKernel {
|
||||
public:
|
||||
ConvolutionWinogradARM64CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
int output_unit, float *origin_weight, float *origin_bias)
|
||||
: ConvolutionWinogradBaseCPUKernel(parameter, inputs, outputs, ctx, output_unit, origin_weight, origin_bias) {}
|
||||
~ConvolutionWinogradARM64CPUKernel() override {}
|
||||
void InitGlobalVariable() override;
|
||||
int ConfigInputOutput() override;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_ARM64_FP32_CONVOLUTION_WINOGRAD_FP32_H_
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/litert/kernel/cpu/fp32/convolution_winograd_avx_fp32.h"
|
||||
#include "nnacl/fp32/conv_winograd_fp32.h"
|
||||
#include "nnacl/pack.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_MEMORY_FAILED;
|
||||
using mindspore::lite::RET_NULL_PTR;
|
||||
using mindspore::lite::RET_OK;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
void ConvolutionWinogradAVXCPUKernel::InitGlobalVariable() {
|
||||
oc_block_ = C16NUM;
|
||||
tmp_data_tile_ = C8NUM;
|
||||
tile_num_ = C12NUM;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_AVX_FP32_CONVOLUTION_WINOGRAD_FP32_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_AVX_FP32_CONVOLUTION_WINOGRAD_FP32_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ConvolutionWinogradAVXCPUKernel : public ConvolutionWinogradBaseCPUKernel {
|
||||
public:
|
||||
ConvolutionWinogradAVXCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
int output_unit, float *origin_weight, float *origin_bias)
|
||||
: ConvolutionWinogradBaseCPUKernel(parameter, inputs, outputs, ctx, output_unit, origin_weight, origin_bias) {}
|
||||
~ConvolutionWinogradAVXCPUKernel() override {}
|
||||
void InitGlobalVariable() override;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_AVX_FP32_CONVOLUTION_WINOGRAD_FP32_H_
|
|
@ -0,0 +1,295 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.h"
|
||||
#include "nnacl/fp32/conv_winograd_fp32.h"
|
||||
#include "nnacl/pack.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_MEMORY_FAILED;
|
||||
using mindspore::lite::RET_NULL_PTR;
|
||||
using mindspore::lite::RET_OK;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
#define CONV_MIN_CALC_BLOCK C1NUM
|
||||
void ConvolutionWinogradBaseCPUKernel::InitGlobalVariable() {
|
||||
oc_block_ = C8NUM;
|
||||
tmp_data_tile_ = C4NUM;
|
||||
tile_num_ = C12NUM;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradBaseCPUKernel::WinogradFilterTransform(const float *weight_data, float *matrix_g,
|
||||
const float *matrix_gt, int oc_block) {
|
||||
if (oc_block == 0) {
|
||||
MS_LOG(ERROR) << "Divide by zero";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
return WinogradWeightTransform(weight_data, reinterpret_cast<float *>(packed_weight_), matrix_g, matrix_gt, oc_block,
|
||||
input_unit_, kernel_unit_, conv_param_->input_channel_, conv_param_->output_channel_,
|
||||
true);
|
||||
}
|
||||
|
||||
int ConvolutionWinogradBaseCPUKernel::InitTmpBuffer() {
|
||||
MS_ASSERT(ctx_->allocator != nullptr);
|
||||
size_t tile_buffer_size =
|
||||
thread_count_ * tile_num_ * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float);
|
||||
trans_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size));
|
||||
if (trans_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc trans_input_ failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
|
||||
int oc8 = UP_ROUND(conv_param_->output_channel_, C8NUM);
|
||||
gemm_out_ = reinterpret_cast<float *>(
|
||||
ctx_->allocator->Malloc(thread_count_ * tile_num_ * input_unit_ * input_unit_ * oc8 * sizeof(float)));
|
||||
if (gemm_out_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc gemm_out_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
tmp_data_ = reinterpret_cast<float *>(
|
||||
ctx_->allocator->Malloc(thread_count_ * tmp_data_tile_ * input_unit_ * input_unit_ * sizeof(float)));
|
||||
if (tmp_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc tmp_data_ failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
|
||||
col_buffer_ = reinterpret_cast<float *>(
|
||||
ctx_->allocator->Malloc(thread_count_ * tile_num_ * conv_param_->input_channel_ * sizeof(float)));
|
||||
if (col_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc col_buffer_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
opt_input_trans_ = reinterpret_cast<float *>(
|
||||
ctx_->allocator->Malloc(thread_count_ * tile_num_ * input_unit_ * input_unit_ *
|
||||
UP_ROUND(conv_param_->input_channel_, tmp_data_tile_) * sizeof(float)));
|
||||
if (opt_input_trans_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc opt_input_trans_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
tmp_buffer_address_list_[C0NUM] = trans_input_;
|
||||
tmp_buffer_address_list_[C1NUM] = gemm_out_;
|
||||
tmp_buffer_address_list_[C2NUM] = tmp_data_;
|
||||
tmp_buffer_address_list_[C3NUM] = col_buffer_;
|
||||
tmp_buffer_address_list_[C4NUM] = opt_input_trans_;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradBaseCPUKernel::ConfigInputOutput() {
|
||||
trans_func_.in_func_ = GetInputTransFunc(input_unit_);
|
||||
if (trans_func_.in_func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "in_func_ is null.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
trans_func_.out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_);
|
||||
if (trans_func_.out_func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "out_func_ is null.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradBaseCPUKernel::Prepare() {
|
||||
CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
|
||||
CHECK_LESS_RETURN(out_tensors_.size(), 1);
|
||||
|
||||
InitGlobalVariable();
|
||||
kernel_unit_ = conv_param_->kernel_h_;
|
||||
input_unit_ = output_unit_ + kernel_unit_ - 1;
|
||||
conv_param_->input_unit_ = input_unit_;
|
||||
conv_param_->output_unit_ = output_unit_;
|
||||
if (op_parameter_->is_train_session_) {
|
||||
auto filter_tensor = in_tensors_.at(kWeightIndex);
|
||||
CHECK_NULL_RETURN(filter_tensor);
|
||||
int in_channel = filter_tensor->Channel();
|
||||
int out_channel = filter_tensor->Batch();
|
||||
auto trans_matrix_data_size =
|
||||
input_unit_ * input_unit_ * in_channel * UP_ROUND(out_channel, oc_block_) * sizeof(float);
|
||||
set_workspace_size(trans_matrix_data_size);
|
||||
}
|
||||
auto ret = InitConvWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init weight bias failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradBaseCPUKernel::UpdateThreadNumProcess(int32_t kernel_type, int64_t per_unit_load_num,
|
||||
int64_t per_unit_store_num, int64_t unit_num) {
|
||||
if (conv_param_->input_batch_ % conv_param_->thread_num_ == 0) {
|
||||
use_batch_cut_flag_ = true;
|
||||
return RET_OK;
|
||||
} else {
|
||||
use_batch_cut_flag_ = false;
|
||||
}
|
||||
|
||||
auto output_hw = conv_param_->output_h_ * conv_param_->output_w_;
|
||||
const int tile_num = C12NUM;
|
||||
|
||||
conv_param_->thread_num_ =
|
||||
MSMIN(UP_DIV(UP_DIV(output_hw, tile_num), CONV_MIN_CALC_BLOCK), op_parameter_->thread_num_);
|
||||
thread_count_ = conv_param_->thread_num_;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradBaseCPUKernel::ReSize() {
|
||||
auto ret = ConvolutionBaseCPUKernel::CheckResizeValid();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Resize is invalid.";
|
||||
return ret;
|
||||
}
|
||||
ret = ConvolutionBaseCPUKernel::Prepare();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "conv base init failed.";
|
||||
return ret;
|
||||
}
|
||||
if (UpdateThreadNumPass(TC_PTYPE(type_), 0, 0, 0) != RET_OK) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
ret = ConfigInputOutput();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConfigInputOutput failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
conv_param_->out_format_ = out_tensors_[0]->format();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradBaseCPUKernel::RunImpl(int task_id) {
|
||||
auto input_tensor = in_tensors_.at(kInputIndex);
|
||||
CHECK_NULL_RETURN(input_tensor);
|
||||
auto ori_input_data = reinterpret_cast<float *>(input_tensor->data());
|
||||
CHECK_NULL_RETURN(ori_input_data);
|
||||
CHECK_NULL_RETURN(out_tensors_.front());
|
||||
auto output_data = reinterpret_cast<float *>(out_tensors_.front()->data());
|
||||
CHECK_NULL_RETURN(output_data);
|
||||
|
||||
if (use_batch_cut_flag_) {
|
||||
ConvWinogardFp32CutByBatch(ori_input_data, reinterpret_cast<float *>(packed_weight_),
|
||||
reinterpret_cast<const float *>(bias_data_), output_data, tmp_buffer_address_list_,
|
||||
task_id, conv_param_, trans_func_);
|
||||
} else {
|
||||
ConvWinogardFp32(ori_input_data, reinterpret_cast<float *>(packed_weight_),
|
||||
reinterpret_cast<const float *>(bias_data_), output_data, tmp_buffer_address_list_, task_id,
|
||||
conv_param_, trans_func_);
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradImpl(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
|
||||
auto conv = reinterpret_cast<ConvolutionWinogradBaseCPUKernel *>(cdata);
|
||||
auto error_code = conv->RunImpl(task_id);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvolutionWinograd Run error task_id[" << task_id << "] error_code[" << error_code << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradBaseCPUKernel::Run() {
|
||||
auto ret = InitTmpBuffer();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init tmp buffer failed.";
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (RepackWeight() != RET_OK) {
|
||||
MS_LOG(ERROR) << "Repack weight failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ret = ParallelLaunch(this->ms_context_, ConvolutionWinogradImpl, this, thread_count_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "conv winograd error error_code[" << ret << "]";
|
||||
}
|
||||
|
||||
FreeTmpBuffer();
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradBaseCPUKernel::MallocWeightBiasData() {
|
||||
auto filter_tensor = in_tensors_.at(kWeightIndex);
|
||||
int in_channel = filter_tensor->Channel();
|
||||
if (in_channel < 0) {
|
||||
MS_LOG(ERROR) << "get channel from filter tensor failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int out_channel = filter_tensor->Batch();
|
||||
if (out_channel < 0) {
|
||||
MS_LOG(ERROR) << "get batch from filter tensor failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
conv_param_->input_channel_ = in_channel;
|
||||
conv_param_->output_channel_ = out_channel;
|
||||
|
||||
// set data
|
||||
auto trans_matrix_data_size =
|
||||
input_unit_ * input_unit_ * in_channel * UP_ROUND(out_channel, oc_block_) * sizeof(float);
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
if (packed_weight_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, trans_matrix_data_size);
|
||||
packed_weight_ = lite::PackWeightManager::GetInstance()->GetPackData(in_tensors_[1]->data(),
|
||||
trans_matrix_data_size, &weight_is_packed_);
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc matrix_buffer failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float matrix_a[64];
|
||||
float matrix_at[64];
|
||||
float matrix_b[64];
|
||||
float matrix_bt[64];
|
||||
float coef = 1.0f;
|
||||
if (input_unit_ == CONV_INPUT_UNIT_SIZE) {
|
||||
coef = 0.5f;
|
||||
}
|
||||
auto ret =
|
||||
CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g_, matrix_gt_, coef, output_unit_, kernel_unit_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "get matrix g from CookToomFilter failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
// init bias
|
||||
size_t new_bias_size = UP_ROUND(out_channel, C4NUM) * sizeof(float);
|
||||
if (bias_data_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, new_bias_size);
|
||||
bias_data_ = malloc(new_bias_size);
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc bias_data_ failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
}
|
||||
memset(bias_data_, 0, new_bias_size);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void ConvolutionWinogradBaseCPUKernel::PackWeight() {
|
||||
auto weight_tensor = in_tensors_.at(kWeightIndex);
|
||||
void *origin_weight = (op_parameter_->is_train_session_) ? weight_tensor->data() : origin_weight_;
|
||||
MS_ASSERT(origin_weight != nullptr);
|
||||
WinogradFilterTransform(reinterpret_cast<float *>(origin_weight), matrix_g_, matrix_gt_, oc_block_);
|
||||
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,92 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_BASE_FP32_CONVOLUTION_WINOGRAD_FP32_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_BASE_FP32_CONVOLUTION_WINOGRAD_FP32_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/litert/lite_kernel.h"
|
||||
#include "nnacl/fp32/winograd_transform.h"
|
||||
#include "nnacl/base/minimal_filtering_generator.h"
|
||||
#include "nnacl/fp32/conv_winograd_fp32.h"
|
||||
#include "src/litert/kernel/cpu/base/convolution_base.h"
|
||||
|
||||
#define CONV_INPUT_UNIT_SIZE 8
|
||||
namespace mindspore::kernel {
|
||||
class ConvolutionWinogradBaseCPUKernel : public ConvolutionBaseCPUKernel {
|
||||
public:
|
||||
ConvolutionWinogradBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
int output_unit, float *origin_weight, float *origin_bias)
|
||||
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, origin_weight, origin_bias),
|
||||
output_unit_(output_unit) {}
|
||||
~ConvolutionWinogradBaseCPUKernel() override {}
|
||||
virtual void InitGlobalVariable();
|
||||
int Prepare() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int RunImpl(int task_id);
|
||||
int InitTmpBuffer();
|
||||
virtual int ConfigInputOutput();
|
||||
int WinogradFilterTransform(const float *weight_data, float *matrix_g, const float *matrix_gt, int oc_block);
|
||||
|
||||
private:
|
||||
int MallocWeightBiasData() override;
|
||||
void PackWeight() override;
|
||||
int UpdateThreadNumProcess(int32_t kernel_type, int64_t per_unit_load_num, int64_t per_unit_store_num,
|
||||
int64_t unit_num) override;
|
||||
void FreeTmpBuffer() {
|
||||
if (trans_input_ != nullptr) {
|
||||
ctx_->allocator->Free(trans_input_);
|
||||
trans_input_ = nullptr;
|
||||
}
|
||||
if (tmp_data_ != nullptr) {
|
||||
ctx_->allocator->Free(tmp_data_);
|
||||
tmp_data_ = nullptr;
|
||||
}
|
||||
if (gemm_out_ != nullptr) {
|
||||
ctx_->allocator->Free(gemm_out_);
|
||||
gemm_out_ = nullptr;
|
||||
}
|
||||
if (col_buffer_ != nullptr) {
|
||||
ctx_->allocator->Free(col_buffer_);
|
||||
col_buffer_ = nullptr;
|
||||
}
|
||||
if (opt_input_trans_ != nullptr) {
|
||||
ctx_->allocator->Free(opt_input_trans_);
|
||||
opt_input_trans_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
int kernel_unit_{0};
|
||||
int input_unit_{0};
|
||||
int output_unit_{0};
|
||||
int oc_block_{0};
|
||||
int tile_num_{0};
|
||||
int tmp_data_tile_{0};
|
||||
float *tmp_data_ = nullptr;
|
||||
float *trans_input_ = nullptr;
|
||||
float *gemm_out_ = nullptr;
|
||||
float *col_buffer_ = nullptr;
|
||||
float *opt_input_trans_ = nullptr;
|
||||
float matrix_g_[64];
|
||||
float matrix_gt_[64];
|
||||
TmpBufferAddress tmp_buffer_address_list_[5] = {nullptr};
|
||||
TransFuncList trans_func_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_BASE_FP32_CONVOLUTION_WINOGRAD_FP32_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -15,290 +15,61 @@
|
|||
*/
|
||||
|
||||
#include "src/litert/kernel/cpu/fp32/convolution_winograd_fp32.h"
|
||||
#include "nnacl/fp32/conv_winograd_fp32.h"
|
||||
#include "nnacl/pack.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.h"
|
||||
#if defined(ENABLE_AVX)
|
||||
#include "src/litert/kernel/cpu/fp32/convolution_winograd_avx_fp32.h"
|
||||
#endif
|
||||
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_MEMORY_FAILED;
|
||||
using mindspore::lite::RET_NULL_PTR;
|
||||
using mindspore::lite::RET_OK;
|
||||
#if defined(ENABLE_SSE)
|
||||
#include "src/litert/kernel/cpu/fp32/convolution_winograd_sse_fp32.h"
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_ARM32)
|
||||
#include "src/litert/kernel/cpu/fp32/convolution_winograd_arm32_fp32.h"
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_ARM64)
|
||||
#include "src/litert/kernel/cpu/fp32/convolution_winograd_arm64_fp32.h"
|
||||
#endif
|
||||
#include "nnacl/intrinsics/ms_simd_cpu_info.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
#define CONV_MIN_CALC_BLOCK C1NUM
|
||||
int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_data, float *matrix_g,
|
||||
const float *matrix_gt, int oc_block) {
|
||||
if (oc_block == 0) {
|
||||
MS_LOG(ERROR) << "Divide by zero";
|
||||
return RET_ERROR;
|
||||
}
|
||||
LiteKernel *CreateConvolutionWinogradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs,
|
||||
const lite::InnerContext *ctx, int out_unit, float *origin_weight,
|
||||
float *origin_bias) {
|
||||
LiteKernel *kernel = nullptr;
|
||||
|
||||
return WinogradWeightTransform(weight_data, reinterpret_cast<float *>(packed_weight_), matrix_g, matrix_gt, oc_block,
|
||||
input_unit_, kernel_unit_, conv_param_->input_channel_, conv_param_->output_channel_,
|
||||
true);
|
||||
}
|
||||
|
||||
int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
|
||||
MS_ASSERT(ctx_->allocator != nullptr);
|
||||
size_t tile_buffer_size =
|
||||
thread_count_ * tile_num_ * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float);
|
||||
trans_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size));
|
||||
if (trans_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc trans_input_ failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
|
||||
int oc8 = UP_ROUND(conv_param_->output_channel_, C8NUM);
|
||||
gemm_out_ = reinterpret_cast<float *>(
|
||||
ctx_->allocator->Malloc(thread_count_ * tile_num_ * input_unit_ * input_unit_ * oc8 * sizeof(float)));
|
||||
if (gemm_out_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc gemm_out_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
tmp_data_ = reinterpret_cast<float *>(
|
||||
ctx_->allocator->Malloc(thread_count_ * tmp_data_tile_ * input_unit_ * input_unit_ * sizeof(float)));
|
||||
if (tmp_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc tmp_data_ failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
|
||||
col_buffer_ = reinterpret_cast<float *>(
|
||||
ctx_->allocator->Malloc(thread_count_ * tile_num_ * conv_param_->input_channel_ * sizeof(float)));
|
||||
if (col_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc col_buffer_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
opt_input_trans_ = reinterpret_cast<float *>(
|
||||
ctx_->allocator->Malloc(thread_count_ * tile_num_ * input_unit_ * input_unit_ *
|
||||
UP_ROUND(conv_param_->input_channel_, tmp_data_tile_) * sizeof(float)));
|
||||
if (opt_input_trans_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc opt_input_trans_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
tmp_buffer_address_list_[0] = trans_input_;
|
||||
tmp_buffer_address_list_[1] = gemm_out_;
|
||||
tmp_buffer_address_list_[2] = tmp_data_;
|
||||
tmp_buffer_address_list_[3] = col_buffer_;
|
||||
tmp_buffer_address_list_[4] = opt_input_trans_;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradCPUKernel::ConfigInputOutput() {
|
||||
trans_func_.in_func_ = GetInputTransFunc(input_unit_);
|
||||
if (trans_func_.in_func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "in_func_ is null.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
#ifdef ENABLE_ARM64
|
||||
trans_func_.in_step_func_ = GetInputTransStepFunc(input_unit_);
|
||||
if (trans_func_.in_step_func_ == nullptr) {
|
||||
MS_LOG(DEBUG) << "in_step_func_ is null.";
|
||||
}
|
||||
trans_func_.in_pack_func_ = GetInputTransPackFunc(input_unit_);
|
||||
if (trans_func_.in_pack_func_ == nullptr) {
|
||||
MS_LOG(DEBUG) << "in_pack_func_ is null.";
|
||||
#if defined(ENABLE_AVX)
|
||||
if (kernel == nullptr) {
|
||||
kernel = new (std::nothrow)
|
||||
kernel::ConvolutionWinogradAVXCPUKernel(parameter, inputs, outputs, ctx, out_unit, origin_weight, origin_bias);
|
||||
}
|
||||
#endif
|
||||
trans_func_.out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_);
|
||||
if (trans_func_.out_func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "out_func_ is null.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradCPUKernel::Prepare() {
|
||||
CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
|
||||
CHECK_LESS_RETURN(out_tensors_.size(), 1);
|
||||
tile_num_ = C12NUM;
|
||||
#ifdef ENABLE_AVX
|
||||
oc_block_ = C16NUM;
|
||||
tmp_data_tile_ = C8NUM;
|
||||
#else
|
||||
oc_block_ = C8NUM;
|
||||
tmp_data_tile_ = C4NUM;
|
||||
#if defined(ENABLE_SSE)
|
||||
if (kernel == nullptr) {
|
||||
kernel = new (std::nothrow)
|
||||
kernel::ConvolutionWinogradSSECPUKernel(parameter, inputs, outputs, ctx, out_unit, origin_weight, origin_bias);
|
||||
}
|
||||
#endif
|
||||
kernel_unit_ = conv_param_->kernel_h_;
|
||||
input_unit_ = output_unit_ + kernel_unit_ - 1;
|
||||
conv_param_->input_unit_ = input_unit_;
|
||||
conv_param_->output_unit_ = output_unit_;
|
||||
if (op_parameter_->is_train_session_) {
|
||||
auto filter_tensor = in_tensors_.at(kWeightIndex);
|
||||
CHECK_NULL_RETURN(filter_tensor);
|
||||
int in_channel = filter_tensor->Channel();
|
||||
int out_channel = filter_tensor->Batch();
|
||||
auto trans_matrix_data_size =
|
||||
input_unit_ * input_unit_ * in_channel * UP_ROUND(out_channel, oc_block_) * sizeof(float);
|
||||
set_workspace_size(trans_matrix_data_size);
|
||||
}
|
||||
auto ret = InitConvWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init weight bias failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradCPUKernel::UpdateThreadNumProcess(int32_t kernel_type, int64_t per_unit_load_num,
|
||||
int64_t per_unit_store_num, int64_t unit_num) {
|
||||
if (conv_param_->input_batch_ % conv_param_->thread_num_ == 0) {
|
||||
use_batch_cut_flag_ = true;
|
||||
return RET_OK;
|
||||
} else {
|
||||
use_batch_cut_flag_ = false;
|
||||
#if defined(ENABLE_ARM64)
|
||||
if (kernel == nullptr) {
|
||||
kernel = new (std::nothrow)
|
||||
kernel::ConvolutionWinogradARM64CPUKernel(parameter, inputs, outputs, ctx, out_unit, origin_weight, origin_bias);
|
||||
}
|
||||
#elif defined(ENABLE_ARM32)
|
||||
if (kernel == nullptr) {
|
||||
kernel = new (std::nothrow)
|
||||
kernel::ConvolutionWinogradARM32CPUKernel(parameter, inputs, outputs, ctx, out_unit, origin_weight, origin_bias);
|
||||
}
|
||||
#endif
|
||||
|
||||
auto output_hw = conv_param_->output_h_ * conv_param_->output_w_;
|
||||
const int tile_num = C12NUM;
|
||||
|
||||
conv_param_->thread_num_ =
|
||||
MSMIN(UP_DIV(UP_DIV(output_hw, tile_num), CONV_MIN_CALC_BLOCK), op_parameter_->thread_num_);
|
||||
thread_count_ = conv_param_->thread_num_;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradCPUKernel::ReSize() {
|
||||
auto ret = ConvolutionBaseCPUKernel::CheckResizeValid();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Resize is invalid.";
|
||||
return ret;
|
||||
if (kernel == nullptr) {
|
||||
kernel = new (std::nothrow)
|
||||
kernel::ConvolutionWinogradBaseCPUKernel(parameter, inputs, outputs, ctx, out_unit, origin_weight, origin_bias);
|
||||
}
|
||||
ret = ConvolutionBaseCPUKernel::Prepare();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "conv base init failed.";
|
||||
return ret;
|
||||
}
|
||||
if (UpdateThreadNumPass(TC_PTYPE(type_), 0, 0, 0) != RET_OK) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
ret = ConfigInputOutput();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConfigInputOutput failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
conv_param_->out_format_ = out_tensors_[0]->format();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradCPUKernel::RunImpl(int task_id) {
|
||||
auto input_tensor = in_tensors_.at(kInputIndex);
|
||||
CHECK_NULL_RETURN(input_tensor);
|
||||
auto ori_input_data = reinterpret_cast<float *>(input_tensor->data());
|
||||
CHECK_NULL_RETURN(ori_input_data);
|
||||
CHECK_NULL_RETURN(out_tensors_.front());
|
||||
auto output_data = reinterpret_cast<float *>(out_tensors_.front()->data());
|
||||
CHECK_NULL_RETURN(output_data);
|
||||
|
||||
if (use_batch_cut_flag_) {
|
||||
ConvWinogardFp32CutByBatch(ori_input_data, reinterpret_cast<float *>(packed_weight_),
|
||||
reinterpret_cast<const float *>(bias_data_), output_data, tmp_buffer_address_list_,
|
||||
task_id, conv_param_, trans_func_);
|
||||
} else {
|
||||
ConvWinogardFp32(ori_input_data, reinterpret_cast<float *>(packed_weight_),
|
||||
reinterpret_cast<const float *>(bias_data_), output_data, tmp_buffer_address_list_, task_id,
|
||||
conv_param_, trans_func_);
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradImpl(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
|
||||
auto conv = reinterpret_cast<ConvolutionWinogradCPUKernel *>(cdata);
|
||||
auto error_code = conv->RunImpl(task_id);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvolutionWinograd Run error task_id[" << task_id << "] error_code[" << error_code << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradCPUKernel::Run() {
|
||||
auto ret = InitTmpBuffer();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init tmp buffer failed.";
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (RepackWeight() != RET_OK) {
|
||||
MS_LOG(ERROR) << "Repack weight failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ret = ParallelLaunch(this->ms_context_, ConvolutionWinogradImpl, this, thread_count_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "conv winograd error error_code[" << ret << "]";
|
||||
}
|
||||
|
||||
FreeTmpBuffer();
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ConvolutionWinogradCPUKernel::MallocWeightBiasData() {
|
||||
auto filter_tensor = in_tensors_.at(kWeightIndex);
|
||||
int in_channel = filter_tensor->Channel();
|
||||
if (in_channel < 0) {
|
||||
MS_LOG(ERROR) << "get channel from filter tensor failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int out_channel = filter_tensor->Batch();
|
||||
if (out_channel < 0) {
|
||||
MS_LOG(ERROR) << "get batch from filter tensor failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
conv_param_->input_channel_ = in_channel;
|
||||
conv_param_->output_channel_ = out_channel;
|
||||
|
||||
// set data
|
||||
auto trans_matrix_data_size =
|
||||
input_unit_ * input_unit_ * in_channel * UP_ROUND(out_channel, oc_block_) * sizeof(float);
|
||||
if (!op_parameter_->is_train_session_) {
|
||||
if (packed_weight_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, trans_matrix_data_size);
|
||||
packed_weight_ = lite::PackWeightManager::GetInstance()->GetPackData(in_tensors_[1]->data(),
|
||||
trans_matrix_data_size, &weight_is_packed_);
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc matrix_buffer failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float matrix_a[64];
|
||||
float matrix_at[64];
|
||||
float matrix_b[64];
|
||||
float matrix_bt[64];
|
||||
float coef = 1.0f;
|
||||
if (input_unit_ == CONV_INPUT_UNIT_SIZE) {
|
||||
coef = 0.5f;
|
||||
}
|
||||
auto ret =
|
||||
CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g_, matrix_gt_, coef, output_unit_, kernel_unit_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "get matrix g from CookToomFilter failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
// init bias
|
||||
size_t new_bias_size = UP_ROUND(out_channel, C4NUM) * sizeof(float);
|
||||
if (bias_data_ == nullptr) {
|
||||
CHECK_LESS_RETURN(MAX_MALLOC_SIZE, new_bias_size);
|
||||
bias_data_ = malloc(new_bias_size);
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc bias_data_ failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
}
|
||||
memset(bias_data_, 0, new_bias_size);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void ConvolutionWinogradCPUKernel::PackWeight() {
|
||||
auto weight_tensor = in_tensors_.at(kWeightIndex);
|
||||
void *origin_weight = (op_parameter_->is_train_session_) ? weight_tensor->data() : origin_weight_;
|
||||
MS_ASSERT(origin_weight != nullptr);
|
||||
WinogradFilterTransform(reinterpret_cast<float *>(origin_weight), matrix_g_, matrix_gt_, oc_block_);
|
||||
return kernel;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -19,72 +19,13 @@
|
|||
|
||||
#include <vector>
|
||||
#include "src/litert/lite_kernel.h"
|
||||
#include "nnacl/fp32/winograd_transform.h"
|
||||
#include "nnacl/base/minimal_filtering_generator.h"
|
||||
#include "nnacl/fp32/conv_winograd_fp32.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "src/litert/kernel/cpu/base/convolution_base.h"
|
||||
|
||||
#define CONV_INPUT_UNIT_SIZE 8
|
||||
namespace mindspore::kernel {
|
||||
class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
|
||||
public:
|
||||
ConvolutionWinogradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
int output_unit, float *origin_weight, float *origin_bias)
|
||||
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, origin_weight, origin_bias),
|
||||
output_unit_(output_unit) {}
|
||||
~ConvolutionWinogradCPUKernel() override {}
|
||||
int Prepare() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int RunImpl(int task_id);
|
||||
int InitTmpBuffer();
|
||||
int ConfigInputOutput();
|
||||
int WinogradFilterTransform(const float *weight_data, float *matrix_g, const float *matrix_gt, int oc_block);
|
||||
|
||||
private:
|
||||
int MallocWeightBiasData() override;
|
||||
void PackWeight() override;
|
||||
int UpdateThreadNumProcess(int32_t kernel_type, int64_t per_unit_load_num, int64_t per_unit_store_num,
|
||||
int64_t unit_num) override;
|
||||
void FreeTmpBuffer() {
|
||||
if (trans_input_ != nullptr) {
|
||||
ctx_->allocator->Free(trans_input_);
|
||||
trans_input_ = nullptr;
|
||||
}
|
||||
if (tmp_data_ != nullptr) {
|
||||
ctx_->allocator->Free(tmp_data_);
|
||||
tmp_data_ = nullptr;
|
||||
}
|
||||
if (gemm_out_ != nullptr) {
|
||||
ctx_->allocator->Free(gemm_out_);
|
||||
gemm_out_ = nullptr;
|
||||
}
|
||||
if (col_buffer_ != nullptr) {
|
||||
ctx_->allocator->Free(col_buffer_);
|
||||
col_buffer_ = nullptr;
|
||||
}
|
||||
if (opt_input_trans_ != nullptr) {
|
||||
ctx_->allocator->Free(opt_input_trans_);
|
||||
opt_input_trans_ = nullptr;
|
||||
}
|
||||
}
|
||||
int kernel_unit_{0};
|
||||
int input_unit_{0};
|
||||
int output_unit_{0};
|
||||
int oc_block_{0};
|
||||
int tile_num_{0};
|
||||
int tmp_data_tile_{0};
|
||||
float *tmp_data_ = nullptr;
|
||||
float *trans_input_ = nullptr;
|
||||
float *gemm_out_ = nullptr;
|
||||
float *col_buffer_ = nullptr;
|
||||
float *opt_input_trans_ = nullptr;
|
||||
float matrix_g_[64];
|
||||
float matrix_gt_[64];
|
||||
TmpBufferAddress tmp_buffer_address_list_[5] = {nullptr};
|
||||
TransFuncList trans_func_;
|
||||
};
|
||||
|
||||
LiteKernel *CreateConvolutionWinogradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs,
|
||||
const lite::InnerContext *ctx, int out_unit, float *origin_weight,
|
||||
float *origin_bias);
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CONVOLUTION_WINOGRAD_FP32_H_
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/litert/kernel/cpu/fp32/convolution_winograd_sse_fp32.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
void ConvolutionWinogradSSECPUKernel::InitGlobalVariable() {
|
||||
oc_block_ = C8NUM;
|
||||
tmp_data_tile_ = C4NUM;
|
||||
tile_num_ = C12NUM;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_SSE_FP32_CONVOLUTION_WINOGRAD_FP32_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_SSE_FP32_CONVOLUTION_WINOGRAD_FP32_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ConvolutionWinogradSSECPUKernel : public ConvolutionWinogradBaseCPUKernel {
|
||||
public:
|
||||
ConvolutionWinogradSSECPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
int output_unit, float *origin_weight, float *origin_bias)
|
||||
: ConvolutionWinogradBaseCPUKernel(parameter, inputs, outputs, ctx, output_unit, origin_weight, origin_bias) {}
|
||||
~ConvolutionWinogradSSECPUKernel() override {}
|
||||
void InitGlobalVariable() override;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_SSE_FP32_CONVOLUTION_WINOGRAD_FP32_H_
|
Loading…
Reference in New Issue