[MSLITE][DEVELOP] optimize fp32 lstm, add parallel run

yangruoqi713 2021-06-30 10:11:04 +08:00
parent 55907b79e0
commit f8a232d2f7
5 changed files with 203 additions and 32 deletions

View File

@@ -19,6 +19,13 @@
#include "nnacl/fp32/activation_fp32.h"
#include "nnacl/fp32/arithmetic_fp32.h"
#include "nnacl/fp32/matmul_fp32.h"
+void GruMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) {
+if (is_vec) {
+MatVecMulFp32(a, b, c, bias, ActType_No, deep, col);
+} else {
+MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc);
+}
+}
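Editor's note: GruMatMul pins the GRU kernel to the original two-path dispatch, since LstmMatMul gains AVX-specific packing parameters later in this commit. The is_vec flag comes from batch == 1: the left operand then has a single row, so a matrix-vector kernel avoids input packing entirely. A naive reference for that path, hypothetical and for illustration only, assuming each output channel's weights are contiguous in b:

// Naive reference for the is_vec (batch == 1) path, assuming the weight
// matrix b stores one contiguous row of `deep` floats per output channel:
// c[j] = bias[j] + dot(a, b_row_j). Illustration only; the real kernel
// uses the SIMD MatVecMulFp32.
void NaiveMatVec(const float *a, const float *b, float *c, const float *bias, int deep, int col) {
  for (int j = 0; j < col; ++j) {
    float acc = (bias != nullptr) ? bias[j] : 0.0f;
    for (int k = 0; k < deep; ++k) {
      acc += a[k] * b[j * deep + k];
    }
    c[j] = acc;
  }
}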
void GruStepUnit(float *output, float *update_gate, float *reset_gate, float *hidden_buffer, const float *state_weight,
const float *state_bias, float *hidden_state, float *buffer[4], const GruParameter *gru_param) {
@@ -38,16 +45,16 @@ void GruStepUnit(float *output, float *update_gate, float *reset_gate, float *hi
// state * weight
if (is_vec) {
-LstmMatMul(state_reset_gate, hidden_state, state_reset_weight, state_reset_bias, gru_param->batch_,
-gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
-LstmMatMul(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_,
-gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+GruMatMul(state_reset_gate, hidden_state, state_reset_weight, state_reset_bias, gru_param->batch_,
+gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+GruMatMul(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_,
+gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
} else {
PackLstmInput(hidden_state, packed_state, gru_param->batch_, gru_param->hidden_size_);
-LstmMatMul(state_reset_gate, packed_state, state_reset_weight, state_reset_bias, gru_param->batch_,
-gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
-LstmMatMul(state_update_gate, packed_state, state_update_weight, state_update_bias, gru_param->batch_,
-gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+GruMatMul(state_reset_gate, packed_state, state_reset_weight, state_reset_bias, gru_param->batch_,
+gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+GruMatMul(state_update_gate, packed_state, state_update_weight, state_update_bias, gru_param->batch_,
+gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
}
ElementAdd(update_gate, state_update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_);
ElementAdd(reset_gate, state_update_gate + gru_param->batch_ * gru_param->hidden_size_, reset_gate,
@@ -60,12 +67,12 @@ void GruStepUnit(float *output, float *update_gate, float *reset_gate, float *hi
ElementMul(hidden_state, reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_);
if (is_vec) {
-LstmMatMul(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_,
-gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+GruMatMul(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_,
+gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
} else {
PackLstmInput(reset_gate, packed_state, gru_param->batch_, gru_param->hidden_size_);
-LstmMatMul(state_hidden_buffer, packed_state, state_hidden_weight, state_hidden_bias, gru_param->batch_,
-gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+GruMatMul(state_hidden_buffer, packed_state, state_hidden_weight, state_hidden_bias, gru_param->batch_,
+gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
}
ElementAdd(hidden_buffer, state_hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_);

View File

@@ -20,6 +20,7 @@
#include "nnacl/fp32/activation_fp32.h"
#include "nnacl/fp32/arithmetic_fp32.h"
#include "nnacl/fp32/matmul_fp32.h"
#include "nnacl/fp32/pack_fp32.h"
void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align) {
for (int i = 0; i < batch; i++) {
@@ -63,9 +64,21 @@ void PackLstmInput(const float *src, float *dst, int row, int deep) {
#endif
}
-void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) {
+void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, int col_align,
+bool is_vec, float *packed_ptr) {
if (is_vec) {
+#ifdef ENABLE_AVX
+bool need_packed = col % C8NUM;
+if (!need_packed) {
+packed_ptr = c;
+}
+MatVecMulAvxFp32(a, b, packed_ptr, bias, ActType_No, deep, col, col_align);
+if (need_packed) {
+PackNHWCXToNHWCFp32(packed_ptr, c, 1, row, col, C8NUM);
+}
+#else
MatVecMulFp32(a, b, c, bias, ActType_No, deep, col);
+#endif
} else {
MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc);
}
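Editor's note on the AVX branch above: MatVecMulAvxFp32 writes col_align floats per output row, where col_align is col rounded up to a multiple of C8NUM (8 floats). When col is already 8-aligned the result goes straight into c; otherwise it lands in the packed_ptr scratch buffer (buffer_[6], allocated in MallocRunBuffer below) and is compacted back to col floats per row. A minimal sketch of that compaction, assuming PackNHWCXToNHWCFp32 behaves as its arguments here suggest:

#include <cstring>

// Hedged sketch of the unpack step: `packed` holds `row` rows of
// `col_align` floats; only the first `col` of each row are real outputs,
// so copy those into the dense destination `c`.
void UnpackAlignedRows(const float *packed, float *c, int row, int col, int col_align) {
  for (int r = 0; r < row; ++r) {
    std::memcpy(c + r * col, packed + r * col_align, col * sizeof(float));
  }
}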
@@ -140,32 +153,42 @@ void UpdataOutput(const float *cell_state, const float *output_gate, float *hidd
}
void UpdateLstmGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep,
-int col, int col_align, bool is_vec) {
+int col, int col_align, bool is_vec, float *packed_ptr) {
for (int i = 0; i < 4; i++) {
-const float *weight_i = weight + deep * col * i;
+const float *weight_i;
+#ifdef ENABLE_AVX
+if (is_vec) {
+weight_i = weight + deep * col_align * i;
+} else {
+weight_i = weight + deep * col * i;
+}
+#else
+weight_i = weight + deep * col * i;
+#endif
const float *bias_i = bias + col_align * i;
float *gate = gate_buffer + row * col * i;
-LstmMatMul(gate, input, weight_i, bias_i, row, deep, col, is_vec);
+LstmMatMul(gate, input, weight_i, bias_i, row, deep, col, col_align, is_vec, packed_ptr);
}
}
void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate,
const float *state_weight, const float *state_bias, float *hidden_state, float *cell_state,
-float *buffer[6], const LstmParameter *lstm_param) {
+float *buffer[7], const LstmParameter *lstm_param) {
float *packed_state = buffer[2];
float *state_gate = buffer[3];
float *cell_buffer = buffer[4];
float *hidden_buffer = buffer[5];
+float *packed_output = buffer[6];
bool is_vec = lstm_param->batch_ == 1;
// state * weight
if (is_vec) {
UpdateLstmGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
-lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec);
+lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec, packed_output);
} else {
// pack state for matmul
PackLstmInput(hidden_state, packed_state, lstm_param->batch_, lstm_param->hidden_size_);
UpdateLstmGate(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
-lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec);
+lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec, packed_output);
}
ElementAdd(input_gate, state_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_);
ElementAdd(forget_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 2, forget_gate,

View File

@@ -27,12 +27,17 @@ void PackLstmBias(float *dst, const float *src, int batch, int col, int col_alig
void PackLstmInput(const float *src, float *dst, int row, int deep);
-void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec);
+void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, int col_align,
+bool is_vec, float *packed_ptr);
void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size);
int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size);
void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate,
const float *state_weight, const float *state_bias, float *hidden_state, float *cell_state,
float *buffer[6], const LstmParameter *lstm_param);
void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias,
const float *state_bias, float *hidden_state, float *cell_state, float *buffer[6],
const LstmParameter *lstm_param);

View File

@@ -39,7 +39,7 @@ void LstmCPUKernel::FreeTmpBuffer() {
free(input_bias_);
input_bias_ = nullptr;
}
-if (!is_vec_) {
+if (!state_is_vec_) {
if (weight_h_ptr_ != nullptr) {
free(weight_h_ptr_);
weight_h_ptr_ = nullptr;
@@ -54,7 +54,7 @@
void LstmCPUKernel::FreeRunBuffer() {
context_->allocator->Free(buffer_[0]);
context_->allocator->Free(buffer_[1]);
-if (!is_vec_) {
+if (!state_is_vec_) {
context_->allocator->Free(buffer_[2]);
}
context_->allocator->Free(buffer_[3]);
@@ -64,6 +64,9 @@
if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) {
context_->allocator->Free(buffer_[5]);
}
+if (output_need_packed_) {
+context_->allocator->Free(buffer_[6]);
+}
}
int LstmCPUKernel::InitInputWeightBias() {
@@ -103,7 +106,7 @@ int LstmCPUKernel::InitStateWeightBias() {
auto weight_h = in_tensors_.at(2);
MS_ASSERT(weight_h != nullptr);
auto weight_h_data = reinterpret_cast<float *>(weight_h->data_c());
-if (!is_vec_) {
+if (!state_is_vec_) {
weight_h_ptr_ = reinterpret_cast<float *>(
malloc(weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->hidden_size_ * sizeof(float)));
if (weight_h_ptr_ == nullptr) {
@@ -113,7 +116,17 @@ int LstmCPUKernel::InitStateWeightBias() {
PackLstmWeight(weight_h_ptr_, weight_h_data, weight_batch_, lstm_param_->hidden_size_, lstm_param_->hidden_size_,
lstm_param_->state_col_align_);
} else {
+#ifdef ENABLE_AVX
+weight_h_ptr_ = reinterpret_cast<float *>(
+malloc(weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->hidden_size_ * sizeof(float)));
+for (int i = 0; i < weight_batch_; i++) {
+const float *src_batch = weight_h_data + i * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
+float *dst_batch = weight_h_ptr_ + i * lstm_param_->state_col_align_ * lstm_param_->hidden_size_;
+RowMajor2Col32Major(src_batch, dst_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_);
+}
+#else
weight_h_ptr_ = weight_h_data;
+#endif
}
// state bias
@@ -145,6 +158,7 @@ int LstmCPUKernel::InitParam() {
lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_
: lstm_param_->batch_ * lstm_param_->hidden_size_;
weight_batch_ = lstm_param_->bidirectional_ ? 8 : 4;
+state_is_vec_ = lstm_param_->batch_ == 1;
#ifdef ENABLE_AVX
row_tile_ = C6NUM;
@@ -161,10 +175,25 @@ int LstmCPUKernel::InitParam() {
#endif
lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, row_tile_);
lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, col_tile_);
+input_thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(lstm_param_->input_col_align_, col_tile_));
+input_thread_stride_ = UP_DIV(UP_DIV(lstm_param_->input_col_align_, col_tile_), input_thread_count_);
-is_vec_ = lstm_param_->batch_ == 1;
-lstm_param_->state_row_align_ = is_vec_ ? 1 : UP_ROUND(lstm_param_->batch_, row_tile_);
-lstm_param_->state_col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, col_tile_);
+state_row_tile_ = row_tile_;
+state_col_tile_ = col_tile_;
+#ifdef ENABLE_AVX
+if (state_is_vec_) {
+state_row_tile_ = 1;
+state_col_tile_ = C8NUM;
+}
+#endif
+lstm_param_->state_row_align_ = state_is_vec_ ? 1 : UP_ROUND(lstm_param_->batch_, state_row_tile_);
+#ifdef ENABLE_AVX
+lstm_param_->state_col_align_ = UP_ROUND(lstm_param_->hidden_size_, state_col_tile_);
+#else
+lstm_param_->state_col_align_ =
+state_is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, state_col_tile_);
+#endif
return RET_OK;
}
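Editor's note: the alignment and thread-partition arithmetic above uses nnacl's rounding macros, where UP_DIV is ceiling division and UP_ROUND rounds up to a multiple. A standalone example with hypothetical sizes (hidden_size 100, a 16-column tile, 4 threads) shows the values that fall out:

#include <algorithm>
#include <cstdio>

// UP_DIV / UP_ROUND as defined in nnacl's op_base.h.
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
#define UP_ROUND(x, y) (((x) + (y) - (1)) / (y) * (y))

int main() {
  const int hidden_size = 100, col_tile = 16, thread_num = 4;
  const int col_align = UP_ROUND(hidden_size, col_tile);   // 112
  const int blocks = UP_DIV(col_align, col_tile);          // 7 blocks of 16 columns
  const int thread_count = std::min(thread_num, blocks);   // MSMIN -> 4 tasks
  const int thread_stride = UP_DIV(blocks, thread_count);  // 2 blocks per task
  std::printf("col_align=%d blocks=%d tasks=%d stride=%d\n",
              col_align, blocks, thread_count, thread_stride);
  return 0;
}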
@@ -217,7 +246,7 @@ int LstmCPUKernel::MallocRunBuffer() {
return RET_ERROR;
}
-if (!is_vec_) {
+if (!state_is_vec_) {
buffer_[2] = reinterpret_cast<float *>(
context_->allocator->Malloc(lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * sizeof(float)));
if (buffer_[2] == nullptr) {
@@ -249,6 +278,100 @@ int LstmCPUKernel::MallocRunBuffer() {
return RET_ERROR;
}
}
+#ifdef ENABLE_AVX
+if (state_is_vec_) { // vec matmul need to malloc dst
+output_need_packed_ = lstm_param_->hidden_size_ % state_col_tile_;
+if (output_need_packed_) {
+int out_channel = lstm_param_->hidden_size_;
+int oc_block_num = UP_DIV(out_channel, state_col_tile_);
+MS_ASSERT(context_->allocator != nullptr);
+buffer_[6] = reinterpret_cast<float *>(
+context_->allocator->Malloc(lstm_param_->batch_ * oc_block_num * state_col_tile_ * sizeof(float)));
+if (buffer_[6] == nullptr) {
+MS_LOG(ERROR) << "LstmCPUKernel malloc tmp output data failed.";
+return RET_ERROR;
+}
+}
+}
+#endif
return RET_OK;
}
+int LstmCPUKernel::InputWeightMatMul(int task_id) {
+int current_start_oc = task_id * input_thread_stride_ * col_tile_;
+int current_rest_oc = 0;
+current_rest_oc = lstm_param_->hidden_size_ - current_start_oc;
+int cur_oc = MSMIN(input_thread_stride_ * col_tile_, current_rest_oc);
+if (cur_oc <= 0) {
+return RET_OK;
+}
+auto input = buffer_[0];
+auto b = weight_loop_ + current_start_oc * lstm_param_->input_size_;
+auto c = gate_loop_ + current_start_oc;
+auto bias = (bias_loop_ == nullptr) ? nullptr : bias_loop_ + current_start_oc;
+MatMulOpt(input, b, c, bias, ActType_No, lstm_param_->input_size_, lstm_param_->seq_len_ * lstm_param_->batch_,
+cur_oc, lstm_param_->hidden_size_, OutType_Nhwc);
+return RET_OK;
+}
+int LstmInputMulWeightRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
+auto kernel = reinterpret_cast<LstmCPUKernel *>(cdata);
+auto ret = kernel->InputWeightMatMul(task_id);
+if (ret != RET_OK) {
+MS_LOG(ERROR) << "InputWeightMatMul error task_id[" << task_id << "] error_code[" << ret << "]";
+return RET_ERROR;
+}
+return RET_OK;
+}
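Editor's note: ParallelLaunch fans LstmInputMulWeightRun out across input_thread_count_ task ids. Each task owns a contiguous slab of input_thread_stride_ * col_tile_ output columns; the last slab is clipped to whatever remains of hidden_size_, and a task whose start lies past the end returns immediately, which is the cur_oc <= 0 guard above. A sketch of the resulting split, reusing the hypothetical sizes from the earlier example:

#include <algorithm>
#include <cstdio>

int main() {
  // hidden_size 100, 16-column tiles, stride of 2 tiles, 4 tasks (hypothetical).
  const int hidden_size = 100, col_tile = 16, stride = 2, tasks = 4;
  for (int task_id = 0; task_id < tasks; ++task_id) {
    const int start_oc = task_id * stride * col_tile;
    const int cur_oc = std::min(stride * col_tile, hidden_size - start_oc);
    if (cur_oc <= 0) continue;  // nothing left for this task
    std::printf("task %d -> columns [%d, %d)\n", task_id, start_oc, start_oc + cur_oc);
  }
  return 0;  // prints [0,32) [32,64) [64,96) [96,100)
}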
+int LstmCPUKernel::LstmUnidirectional(float *output, const float *weight_i, const float *weight_h,
+const float *input_bias, const float *state_bias, float *hidden_state,
+float *cell_state, bool is_backward) {
+float *gate = buffer_[1];
+for (int i = 0; i < 4; i++) {
+weight_loop_ = weight_i + lstm_param_->input_size_ * lstm_param_->input_col_align_ * i;
+bias_loop_ = input_bias + lstm_param_->input_col_align_ * i;
+gate_loop_ = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * i;
+ParallelLaunch(this->context_, LstmInputMulWeightRun, this, input_thread_count_);
+}
+float *input_gate = gate;
+float *forget_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * 2;
+float *cell_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * 3;
+float *output_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_;
+for (int t = 0; t < lstm_param_->seq_len_; t++) {
+int real_t = is_backward ? lstm_param_->seq_len_ - t - 1 : t;
+float *input_gate_t = input_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t;
+float *forget_gate_t = forget_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t;
+float *cell_gate_t = cell_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t;
+float *output_gate_t = output_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t;
+float *output_ptr = output + real_t * lstm_param_->output_step_;
+LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias,
+hidden_state, cell_state, buffer_, lstm_param_);
+}
+return RET_OK;
+}
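Editor's note: the pointer arithmetic above fixes the gate buffer layout: four slabs of seq_len * batch * hidden_size floats in input, output, forget, cell order, which suggests the packed weights keep the same block order. A small helper capturing that layout, hypothetical and for illustration only:

#include <cstddef>

// Gate slab layout inside buffer_[1], as implied by the offsets above:
// slab 0 = input gate, 1 = output gate, 2 = forget gate, 3 = cell gate.
enum LstmGate { kInputGate = 0, kOutputGate = 1, kForgetGate = 2, kCellGate = 3 };

inline float *GateSlab(float *gate_buffer, LstmGate g, int seq_len, int batch, int hidden_size) {
  return gate_buffer + static_cast<std::ptrdiff_t>(seq_len) * batch * hidden_size * g;
}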
+int LstmCPUKernel::InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state) {
+// forward
+// buffer_[0] : store packed input
+PackLstmInput(input, buffer_[0], lstm_param_->seq_len_ * lstm_param_->batch_, lstm_param_->input_size_);
+LstmUnidirectional(output, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_, hidden_state, cell_state, false);
+// backward
+if (lstm_param_->bidirectional_) {
+const float *backward_weight_i = weight_i_ptr_ + 4 * lstm_param_->input_col_align_ * lstm_param_->input_size_;
+const float *backward_weight_h = weight_h_ptr_ + 4 * lstm_param_->state_col_align_ * lstm_param_->hidden_size_;
+const float *backward_input_bias = input_bias_ + 4 * lstm_param_->input_col_align_;
+const float *backward_state_bias = state_bias_ + 4 * lstm_param_->state_col_align_;
+float *backward_output = output + lstm_param_->batch_ * lstm_param_->hidden_size_;
+float *backward_cell_state = cell_state + lstm_param_->batch_ * lstm_param_->hidden_size_;
+float *backward_hidden_state = hidden_state + lstm_param_->batch_ * lstm_param_->hidden_size_;
+LstmUnidirectional(backward_output, backward_weight_i, backward_weight_h, backward_input_bias, backward_state_bias,
+backward_hidden_state, backward_cell_state, true);
+}
+return RET_OK;
+}
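Editor's note: in the bidirectional case the packed weight and bias buffers hold eight gate blocks instead of four (weight_batch_ in InitParam), so the backward pass simply starts four aligned blocks in, and each time step's output stores the forward slice followed by the backward slice because output_step_ doubles. A hypothetical helper sketching the per-step output addressing:

// Per-time-step output addressing implied by output_step_ = 2 * batch *
// hidden_size for bidirectional models: forward and backward results of
// step t sit back to back.
inline float *DirectionSlice(float *output, int t, int batch, int hidden_size, bool backward) {
  const int output_step = 2 * batch * hidden_size;
  return output + t * output_step + (backward ? batch * hidden_size : 0);
}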
@@ -281,9 +404,8 @@ int LstmCPUKernel::Run() {
MS_ASSERT(weight_i_ptr_);
MS_ASSERT(input_bias_);
MS_ASSERT(state_bias_);
-Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_,
-reinterpret_cast<float *>(output_hidden_state->data_c()), reinterpret_cast<float *>(output_cell_state->data_c()),
-buffer_, lstm_param_);
+InnerExecute(output_ptr, input_ptr, reinterpret_cast<float *>(output_hidden_state->data_c()),
+reinterpret_cast<float *>(output_cell_state->data_c()));
FreeRunBuffer();
return RET_OK;
}

View File

@@ -36,6 +36,8 @@ class LstmCPUKernel : public InnerKernel {
int ReSize() override;
int Run() override;
+int InputWeightMatMul(int task_id);
private:
void FreeTmpBuffer();
void FreeRunBuffer();
@@ -44,15 +46,27 @@ class LstmCPUKernel : public InnerKernel {
int InitInputWeightBias();
int InitStateWeightBias();
+int LstmUnidirectional(float *output, const float *weight_i, const float *weight_h, const float *input_bias,
+const float *state_bias, float *hidden_state, float *cell_state, bool is_backward);
+int InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state);
+const float *weight_loop_;
+const float *bias_loop_;
+float *gate_loop_;
+int input_thread_count_ = 0;
+int input_thread_stride_ = 0;
float *weight_i_ptr_ = nullptr;
float *weight_h_ptr_ = nullptr;
float *input_bias_ = nullptr;
float *state_bias_ = nullptr;
-float *buffer_[6];
+float *buffer_[7];
int row_tile_ = 0;
int col_tile_ = 0;
+int state_row_tile_ = 0;
+int state_col_tile_ = 0;
int weight_batch_ = 0;
-bool is_vec_ = false;
+bool state_is_vec_ = false;
+bool output_need_packed_ = false;
LstmParameter *lstm_param_ = nullptr;
};
} // namespace mindspore::kernel