forked from mindspore-Ecosystem/mindspore
winograd op thread cut opt
parent f0142dce53, commit d56ce6ff05
@@ -200,6 +200,7 @@ mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/instance_norm_fp16
 mindspore/mindspore/lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::init_global_variable
 mindspore/mindspore/lite/src/train/train_loop.cc:mindspore::lite::TrainLoop::Train
 mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/conv_winograd_fp32.c:ConvWinogardFp32
+mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/conv_winograd_fp32.c:ConvWinogardFp32CutByBatch
 mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc:mindspore::opt::MatchAdd5Pattern
 mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/conv_fp32_nchwx_avx512.c:conv2d_compute_fp32_nchwx_avx512
 mindspore/mindspore/lite/src/litert/kernel/cpu/control/tensorlist_setitem.cc:mindspore::kernel::TensorListSetItemCPUKernel::Run
@@ -51,11 +51,139 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
   float *col_buffer = buffer_list[3] + task_id * tile_num * in_channel;
   // step 1 : filter transform (pre-processed offline)
   // step 2 : input transform (online)
 
+  int block_per_thread = UP_DIV(output_tile_count, conv_param->thread_num_);
+  int start_index = block_per_thread * task_id * tile_num;
+  if (start_index >= output_count) {
+    return;
+  }
+  int end_index = MSMIN(start_index + block_per_thread * tile_num, output_count);
+
   for (int b = 0; b < conv_param->input_batch_; b++) {
     int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_;
     int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_w_ * conv_param->output_h_;
-    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
-      int out_tile_index = thread_id * tile_num;
+    for (int out_tile_index = start_index; out_tile_index < end_index; out_tile_index += tile_num) {
+      int cal_num = output_count - out_tile_index;
+      cal_num = cal_num > tile_num ? tile_num : cal_num;
+      if (cal_num <= 0) {
+        return;
+      }
+
+#ifdef ENABLE_ARM64
+      // Optimize input transform. Only valid for arm64, the tile num is 12, the channel_tile is 4.
+      // For arm32, the tile_num is 4.
+      // For x86_sse, the tile_num is 4, the channel_tile is 4.
+      // For avx, the tile_num is 6, the channel_tile is 8.
+      // N = input_unit, M = tile_num
+      // The function(InputTransformNxNStep, InputTransform4x4PackM) needs to be rewritten.
+      bool fused_pack =
+        (cal_num == tile_num) && (trans_func.in_step_func_ != NULL) && (trans_func.in_pack_func_ != NULL);
+      if (fused_pack) {
+        float *opt_trans_input =
+          buffer_list[4] + task_id * tile_num * input_unit_square * UP_ROUND(in_channel, channel_pack_tile);
+        WinogradInputTransformOptStep(input_data + in_batch_offset, opt_trans_input, tmp_data, cal_num, out_tile_index,
+                                      out_w_block, conv_param, trans_func.in_step_func_);
+
+        for (int w_index = 0; w_index < input_unit; w_index++) {
+          float *src_w = opt_trans_input + w_index * input_unit * tile_num * channel_pack_tile;
+          for (int c = 0; c < UP_DIV(in_channel, channel_pack_tile); c++) {
+            int real_c = in_channel - c * channel_pack_tile;
+            real_c = real_c > channel_pack_tile ? channel_pack_tile : real_c;
+            float *src_c = src_w + c * input_unit_square * tile_num * channel_pack_tile;
+            float *dst_c = trans_input + c * tile_num * channel_pack_tile;
+            trans_func.in_pack_func_(src_c, dst_c, channel_pack_tile, in_channel * tile_num, real_c);
+          }
+
+          for (int h_index = 0; h_index < input_unit; h_index++) {
+            const float *gemm_input = trans_input + h_index * tile_num * in_channel;
+            int point_index = h_index * input_unit + w_index;
+            const float *gemm_weight = trans_weight + point_index * in_channel * oc_tile * col_tile;
+            MatMulOpt(gemm_input, gemm_weight, gemm_out + point_index * C8NUM, NULL, 0, in_channel, cal_num,
+                      oc8 * C8NUM, input_unit_square, OutType_TileC8);
+          }
+        }
+      } else {
+#endif
+        WinogradInputTransform(input_data + in_batch_offset, trans_input, tmp_data, cal_num, out_tile_index,
+                               out_w_block, conv_param, trans_func.in_func_);
+        // step 3 : gemm
+        float *src_ptr = trans_input;
+        float *dst_ptr = gemm_out;
+        float *tmp_col_ptr = col_buffer;
+        for (int i = 0; i < input_unit_square; ++i) {
+#ifdef ENABLE_AVX
+          RowMajor2Col6Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
+#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
+          RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
+#else
+          RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
+#endif
+          MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc_tile * col_tile, dst_ptr + i * C8NUM, NULL, 0,
+                    in_channel, cal_num, oc8 * C8NUM, input_unit_square, 2);
+        }
+#ifdef ENABLE_ARM64
+      }
+#endif
+
+      // step 4 : output transform
+      float *output_ptr = output_data + out_batch_offset;
+      if (conv_param->out_format_ != Format_NC4HW4) {  // nc4hw4
+        WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
+                                    trans_func.out_func_);
+      } else {
+#if defined(ENABLE_AVX) || defined(ENABLE_ARM64)
+        WinogradOutputNC4HW4Transform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
+                                      trans_func.out_func_);
+#else
+        WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
+                                    trans_func.out_func_);
+#endif
+      }
+    }
+  }
+}
+
+// fp32 conv winograd
+void ConvWinogardFp32CutByBatch(const float *input_data, const float *trans_weight, const float *bias_data,
+                                float *output_data, TmpBufferAddress *buffer_list, int task_id,
+                                const ConvParameter *conv_param, TransFuncList trans_func) {
+  if (conv_param->output_unit_ == 0) {
+    return;
+  }
+  int in_channel = conv_param->input_channel_;
+  int input_unit = conv_param->input_unit_;
+  int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_);
+  int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_);
+  int output_count = out_w_block * out_h_block;
+  const int tile_num = C12NUM;
+#ifdef ENABLE_AVX
+  const int col_tile = C16NUM;
+  const int channel_pack_tile = C8NUM;
+#else
+  const int col_tile = C8NUM;
+  const int channel_pack_tile = C4NUM;
+#endif
+  int oc_tile = UP_DIV(conv_param->output_channel_, col_tile);
+  int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
+  int input_unit_square = input_unit * input_unit;
+
+  float *trans_input = buffer_list[0] + task_id * tile_num * input_unit_square * in_channel;
+  float *gemm_out = buffer_list[1] + task_id * tile_num * input_unit_square * oc8 * C8NUM;
+  float *tmp_data = buffer_list[2] + task_id * input_unit_square * channel_pack_tile;
+  float *col_buffer = buffer_list[3] + task_id * tile_num * in_channel;
+  // step 1 : filter transform (pre-processed offline)
+  // step 2 : input transform (online)
+
+  int block_batch_per_thread = UP_DIV(conv_param->input_batch_, conv_param->thread_num_);
+  int start_batch = block_batch_per_thread * task_id;
+  int end_batch = MSMIN(conv_param->input_batch_, (start_batch + block_batch_per_thread));
+
+  for (int b = start_batch; b < end_batch; b++) {
+    int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_;
+    int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_w_ * conv_param->output_h_;
+
+    for (int out_tile_index = 0; out_tile_index < output_count; out_tile_index += tile_num) {
       int cal_num = output_count - out_tile_index;
       cal_num = cal_num > tile_num ? tile_num : cal_num;
       if (cal_num <= 0) {
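Both new work splits above hand each thread a contiguous block of work computed with UP_DIV: ConvWinogardFp32 now cuts the range of output-tile blocks across threads (replacing the old strided for (thread_id ...) loop), while ConvWinogardFp32CutByBatch cuts whole input batches across threads. A minimal standalone sketch of that partitioning scheme, assuming UP_DIV(a, b) expands to (a + b - 1) / b and MSMIN to the usual minimum; the helper below is illustrative and not part of the patch:

/* Sketch: give thread `task_id` a contiguous range [start, end) of units,
 * where every thread owns UP_DIV(total_blocks, thread_num) blocks of
 * `block_size` units each, clamped to `total_units`. */
static void CutRangeForThread(int total_blocks, int total_units, int block_size,
                              int thread_num, int task_id, int *start, int *end) {
  int blocks_per_thread = (total_blocks + thread_num - 1) / thread_num;  /* UP_DIV */
  *start = blocks_per_thread * task_id * block_size;
  *end = *start + blocks_per_thread * block_size;
  if (*end > total_units) {
    *end = total_units;  /* MSMIN */
  }
  if (*start > *end) {
    *start = *end;  /* this thread has nothing to do, mirrors the early return */
  }
}

For example, with output_count = 100 output tiles and tile_num = 12, output_tile_count is UP_DIV(100, 12) = 9 blocks; on 4 threads block_per_thread = 3, so thread 0 handles tiles [0, 36), thread 1 [36, 72), thread 2 [72, 100), and thread 3 starts past the end and returns immediately. In the batch variant, 8 batches on 4 threads give each thread two consecutive batches.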
@@ -37,6 +37,11 @@ extern "C" {
 void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *bias_data, float *output_data,
                       TmpBufferAddress *buffer_list, int task_id, const ConvParameter *conv_param,
                       TransFuncList trans_func);
+
+void ConvWinogardFp32CutByBatch(const float *input_data, const float *trans_weight, const float *bias_data,
+                                float *output_data, TmpBufferAddress *buffer_list, int task_id,
+                                const ConvParameter *conv_param, TransFuncList trans_func);
+
 #ifdef __cplusplus
 }
 #endif
@@ -91,6 +91,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
   bool is_repack_ = false;
   void *origin_weight_;  // do not free
   void *origin_bias_;    // do not free
+  bool use_batch_cut_flag_ = false;
 };
 }  // namespace mindspore::kernel
 
@@ -28,7 +28,7 @@ using mindspore::lite::RET_INFER_INVALID;
 using mindspore::lite::RET_OK;
 
 namespace mindspore::kernel {
-#define CONV_MIN_CALC_BLOCK C32NUM
+#define CONV_MIN_CALC_BLOCK C1NUM
 #ifdef ENABLE_AVX
 #define OC_BLOCK C16NUM
 #elif defined(ENABLE_ARM32)
@@ -126,7 +126,7 @@ int ConvolutionCPUKernel::UpdateThreadNumProcess(int32_t kernel_type, int64_t pe
   const int cal_num = C12NUM;
 #endif
 
-  conv_param_->thread_num_ = MSMIN(UP_DIV(UP_DIV(output_hw, cal_num), CONV_MIN_CALC_BLOCK), conv_param_->thread_num_);
+  conv_param_->thread_num_ = MSMIN(UP_DIV(UP_DIV(output_hw, cal_num), CONV_MIN_CALC_BLOCK), op_parameter_->thread_num_);
   thread_count_ = conv_param_->thread_num_;
   return RET_OK;
 }
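Together with the CONV_MIN_CALC_BLOCK change above (C32NUM to C1NUM), this hunk caps the thread count by the amount of available work, and it now compares against the user-configured op_parameter_->thread_num_ rather than conv_param_->thread_num_, which may already have been reduced by an earlier resize. A rough sketch of the capping rule, assuming CxNUM expands to the integer x and UP_DIV rounds up; the function name is illustrative:

/* Sketch: limit the thread count to the number of work blocks.
 * `output_hw` output positions are processed in tiles of `cal_num`,
 * and `min_calc_block` tile blocks are the minimum work per thread. */
static int EffectiveThreadNum(int output_hw, int cal_num, int min_calc_block, int configured_threads) {
  int tile_blocks = (output_hw + cal_num - 1) / cal_num;                  /* UP_DIV(output_hw, cal_num) */
  int max_threads = (tile_blocks + min_calc_block - 1) / min_calc_block;  /* UP_DIV(..., CONV_MIN_CALC_BLOCK) */
  return max_threads < configured_threads ? max_threads : configured_threads;  /* MSMIN */
}

With min_calc_block = 1, a 14x14 output (output_hw = 196) can use up to UP_DIV(196, 12) = 17 threads, whereas the old value of 32 forced a single thread for any output smaller than 32 * 12 = 384 positions.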
@@ -65,7 +65,6 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
   float *packed_input_ = nullptr;
   float *col_major_input_ = nullptr;
   bool output_need_align_ = false;
-  bool use_batch_cut_flag_ = false;
 };
 }  // namespace mindspore::kernel
 
@@ -25,6 +25,7 @@ using mindspore::lite::RET_NULL_PTR;
 using mindspore::lite::RET_OK;
 
 namespace mindspore::kernel {
+#define CONV_MIN_CALC_BLOCK C1NUM
 int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_data, float *matrix_g,
                                                           const float *matrix_gt, int oc_block) {
   if (oc_block == 0) {
@@ -141,6 +142,24 @@ int ConvolutionWinogradCPUKernel::Prepare() {
   return RET_OK;
 }
 
+int ConvolutionWinogradCPUKernel::UpdateThreadNumProcess(int32_t kernel_type, int64_t per_unit_load_num,
+                                                         int64_t per_unit_store_num, int64_t unit_num) {
+  if (conv_param_->input_batch_ % conv_param_->thread_num_ == 0) {
+    use_batch_cut_flag_ = true;
+    return RET_OK;
+  } else {
+    use_batch_cut_flag_ = false;
+  }
+
+  auto output_hw = conv_param_->output_h_ * conv_param_->output_w_;
+  const int tile_num = C12NUM;
+
+  conv_param_->thread_num_ =
+    MSMIN(UP_DIV(UP_DIV(output_hw, tile_num), CONV_MIN_CALC_BLOCK), op_parameter_->thread_num_);
+  thread_count_ = conv_param_->thread_num_;
+  return RET_OK;
+}
+
 int ConvolutionWinogradCPUKernel::ReSize() {
   auto ret = ConvolutionBaseCPUKernel::CheckResizeValid();
   if (ret != RET_OK) {
@@ -152,6 +171,9 @@ int ConvolutionWinogradCPUKernel::ReSize() {
     MS_LOG(ERROR) << "conv base init failed.";
     return ret;
   }
+  if (UpdateThreadNumPass(TC_PTYPE(type_), 0, 0, 0) != RET_OK) {
+    return RET_ERROR;
+  }
   ret = ConfigInputOutput();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ConfigInputOutput failed.";
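The decision added here drives the dispatch in RunImpl (next hunk): when the batch count divides evenly among the configured threads, the kernel keeps the full thread count and switches to batch cutting; otherwise it falls back to tile cutting and caps the thread count the same way the plain convolution kernel does. A compressed sketch of that decision, assuming C12NUM is 12 and CONV_MIN_CALC_BLOCK (C1NUM) is 1; the struct and function below are illustrative only:

#include <stdbool.h>

/* Sketch: choose between batch-cut and tile-cut threading for the winograd kernel. */
typedef struct {
  int input_batch;
  int output_hw;   /* output_h * output_w */
  int thread_num;  /* configured thread count; may be reduced for tile cut */
  bool use_batch_cut;
} WinogradThreadPlan;

static void PlanWinogradThreads(WinogradThreadPlan *plan) {
  const int tile_num = 12;       /* C12NUM */
  const int min_calc_block = 1;  /* CONV_MIN_CALC_BLOCK */
  if (plan->input_batch % plan->thread_num == 0) {
    plan->use_batch_cut = true;  /* every thread gets input_batch / thread_num whole batches */
    return;
  }
  plan->use_batch_cut = false;
  int tile_blocks = (plan->output_hw + tile_num - 1) / tile_num;          /* UP_DIV */
  int max_threads = (tile_blocks + min_calc_block - 1) / min_calc_block;  /* UP_DIV */
  if (max_threads < plan->thread_num) {
    plan->thread_num = max_threads;  /* MSMIN cap */
  }
}

For example, input_batch = 8 with 4 threads takes the batch-cut path (two batches per thread), while the common inference case of input_batch = 1 with 4 threads falls back to tile cutting.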
@@ -169,9 +191,17 @@ int ConvolutionWinogradCPUKernel::RunImpl(int task_id) {
   CHECK_NULL_RETURN(out_tensors_.front());
   auto output_data = reinterpret_cast<float *>(out_tensors_.front()->data());
   CHECK_NULL_RETURN(output_data);
-  ConvWinogardFp32(ori_input_data, reinterpret_cast<float *>(packed_weight_),
-                   reinterpret_cast<const float *>(bias_data_), output_data, tmp_buffer_address_list_, task_id,
-                   conv_param_, trans_func_);
+  if (use_batch_cut_flag_) {
+    ConvWinogardFp32CutByBatch(ori_input_data, reinterpret_cast<float *>(packed_weight_),
+                               reinterpret_cast<const float *>(bias_data_), output_data, tmp_buffer_address_list_,
+                               task_id, conv_param_, trans_func_);
+  } else {
+    ConvWinogardFp32(ori_input_data, reinterpret_cast<float *>(packed_weight_),
+                     reinterpret_cast<const float *>(bias_data_), output_data, tmp_buffer_address_list_, task_id,
+                     conv_param_, trans_func_);
+  }
 
   return RET_OK;
 }
@@ -45,6 +45,8 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
  private:
   int MallocWeightBiasData() override;
   void PackWeight() override;
+  int UpdateThreadNumProcess(int32_t kernel_type, int64_t per_unit_load_num, int64_t per_unit_store_num,
+                             int64_t unit_num) override;
   void FreeTmpBuffer() {
     if (trans_input_ != nullptr) {
       ctx_->allocator->Free(trans_input_);