[MSLITE][DEVELOP] optimize fp32 lstm, add parallel run
This commit is contained in:
parent
55907b79e0
commit
f8a232d2f7
|
@ -19,6 +19,13 @@
|
|||
#include "nnacl/fp32/activation_fp32.h"
|
||||
#include "nnacl/fp32/arithmetic_fp32.h"
|
||||
#include "nnacl/fp32/matmul_fp32.h"
|
||||
// Runs the GRU weight multiplication with no activation (ActType_No).
// When is_vec is set the call goes to MatVecMulFp32 (row is not used on that path);
// otherwise it goes to MatMulOpt with `col` passed as the output stride and NHWC output type.
// NOTE(review): c is presumably the destination and a/b the left/right operands - confirm against matmul_fp32.h.
void GruMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) {
  if (!is_vec) {
    MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc);
    return;
  }
  MatVecMulFp32(a, b, c, bias, ActType_No, deep, col);
}
|
||||
|
||||
void GruStepUnit(float *output, float *update_gate, float *reset_gate, float *hidden_buffer, const float *state_weight,
|
||||
const float *state_bias, float *hidden_state, float *buffer[4], const GruParameter *gru_param) {
|
||||
|
@ -38,16 +45,16 @@ void GruStepUnit(float *output, float *update_gate, float *reset_gate, float *hi
|
|||
|
||||
// state * weight
|
||||
if (is_vec) {
|
||||
LstmMatMul(state_reset_gate, hidden_state, state_reset_weight, state_reset_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
LstmMatMul(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
GruMatMul(state_reset_gate, hidden_state, state_reset_weight, state_reset_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
GruMatMul(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
} else {
|
||||
PackLstmInput(hidden_state, packed_state, gru_param->batch_, gru_param->hidden_size_);
|
||||
LstmMatMul(state_reset_gate, packed_state, state_reset_weight, state_reset_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
LstmMatMul(state_update_gate, packed_state, state_update_weight, state_update_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
GruMatMul(state_reset_gate, packed_state, state_reset_weight, state_reset_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
GruMatMul(state_update_gate, packed_state, state_update_weight, state_update_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
}
|
||||
ElementAdd(update_gate, state_update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_);
|
||||
ElementAdd(reset_gate, state_update_gate + gru_param->batch_ * gru_param->hidden_size_, reset_gate,
|
||||
|
@ -60,12 +67,12 @@ void GruStepUnit(float *output, float *update_gate, float *reset_gate, float *hi
|
|||
|
||||
ElementMul(hidden_state, reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_);
|
||||
if (is_vec) {
|
||||
LstmMatMul(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
GruMatMul(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
} else {
|
||||
PackLstmInput(reset_gate, packed_state, gru_param->batch_, gru_param->hidden_size_);
|
||||
LstmMatMul(state_hidden_buffer, packed_state, state_hidden_weight, state_hidden_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
GruMatMul(state_hidden_buffer, packed_state, state_hidden_weight, state_hidden_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
}
|
||||
ElementAdd(hidden_buffer, state_hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_);
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "nnacl/fp32/activation_fp32.h"
|
||||
#include "nnacl/fp32/arithmetic_fp32.h"
|
||||
#include "nnacl/fp32/matmul_fp32.h"
|
||||
#include "nnacl/fp32/pack_fp32.h"
|
||||
|
||||
void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align) {
|
||||
for (int i = 0; i < batch; i++) {
|
||||
|
@ -63,9 +64,21 @@ void PackLstmInput(const float *src, float *dst, int row, int deep) {
|
|||
#endif
|
||||
}
|
||||
|
||||
void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) {
|
||||
void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, int col_align,
|
||||
bool is_vec, float *packed_ptr) {
|
||||
if (is_vec) {
|
||||
#ifdef ENABLE_AVX
|
||||
bool need_packed = col % C8NUM;
|
||||
if (!need_packed) {
|
||||
packed_ptr = c;
|
||||
}
|
||||
MatVecMulAvxFp32(a, b, packed_ptr, bias, ActType_No, deep, col, col_align);
|
||||
if (need_packed) {
|
||||
PackNHWCXToNHWCFp32(packed_ptr, c, 1, row, col, C8NUM);
|
||||
}
|
||||
#else
|
||||
MatVecMulFp32(a, b, c, bias, ActType_No, deep, col);
|
||||
#endif
|
||||
} else {
|
||||
MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc);
|
||||
}
|
||||
|
@ -140,32 +153,42 @@ void UpdataOutput(const float *cell_state, const float *output_gate, float *hidd
|
|||
}
|
||||
|
||||
// Computes the four LSTM gate products by calling LstmMatMul once per gate.
//   gate_buffer : destination; gate i is written at offset row * col * i.
//   input       : left operand handed unchanged to every LstmMatMul call.
//   weight      : four consecutive per-gate weight blocks (stride described below).
//   bias        : four consecutive per-gate bias blocks, col_align floats apart.
//   packed_ptr  : scratch pointer forwarded to LstmMatMul.
// Per-gate weight stride: deep * col, except in ENABLE_AVX builds with is_vec set, where the
// stride is deep * col_align.
void UpdateLstmGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep,
                    int col, int col_align, bool is_vec, float *packed_ptr) {
  int weight_stride = deep * col;
#ifdef ENABLE_AVX
  if (is_vec) {
    weight_stride = deep * col_align;
  }
#endif
  for (int i = 0; i < 4; i++) {
    const float *weight_i = weight + weight_stride * i;
    const float *bias_i = bias + col_align * i;
    float *gate = gate_buffer + row * col * i;
    LstmMatMul(gate, input, weight_i, bias_i, row, deep, col, col_align, is_vec, packed_ptr);
  }
}
|
||||
|
||||
void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate,
|
||||
const float *state_weight, const float *state_bias, float *hidden_state, float *cell_state,
|
||||
float *buffer[6], const LstmParameter *lstm_param) {
|
||||
float *buffer[7], const LstmParameter *lstm_param) {
|
||||
float *packed_state = buffer[2];
|
||||
float *state_gate = buffer[3];
|
||||
float *cell_buffer = buffer[4];
|
||||
float *hidden_buffer = buffer[5];
|
||||
float *packed_output = buffer[6];
|
||||
bool is_vec = lstm_param->batch_ == 1;
|
||||
// state * weight
|
||||
if (is_vec) {
|
||||
UpdateLstmGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
|
||||
lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec);
|
||||
lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec, packed_output);
|
||||
} else {
|
||||
// pack state for matmul
|
||||
PackLstmInput(hidden_state, packed_state, lstm_param->batch_, lstm_param->hidden_size_);
|
||||
UpdateLstmGate(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
|
||||
lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec);
|
||||
lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec, packed_output);
|
||||
}
|
||||
ElementAdd(input_gate, state_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_);
|
||||
ElementAdd(forget_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 2, forget_gate,
|
||||
|
|
|
@ -27,12 +27,17 @@ void PackLstmBias(float *dst, const float *src, int batch, int col, int col_alig
|
|||
|
||||
void PackLstmInput(const float *src, float *dst, int row, int deep);
|
||||
|
||||
void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec);
|
||||
void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, int col_align,
|
||||
bool is_vec, float *packed_ptr);
|
||||
|
||||
void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size);
|
||||
|
||||
int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size);
|
||||
|
||||
void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate,
|
||||
const float *state_weight, const float *state_bias, float *hidden_state, float *cell_state,
|
||||
float *buffer[6], const LstmParameter *lstm_param);
|
||||
|
||||
void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias,
|
||||
const float *state_bias, float *hidden_state, float *cell_state, float *buffer[6],
|
||||
const LstmParameter *lstm_param);
|
||||
|
|
|
@ -39,7 +39,7 @@ void LstmCPUKernel::FreeTmpBuffer() {
|
|||
free(input_bias_);
|
||||
input_bias_ = nullptr;
|
||||
}
|
||||
if (!is_vec_) {
|
||||
if (!state_is_vec_) {
|
||||
if (weight_h_ptr_ != nullptr) {
|
||||
free(weight_h_ptr_);
|
||||
weight_h_ptr_ = nullptr;
|
||||
|
@ -54,7 +54,7 @@ void LstmCPUKernel::FreeTmpBuffer() {
|
|||
void LstmCPUKernel::FreeRunBuffer() {
|
||||
context_->allocator->Free(buffer_[0]);
|
||||
context_->allocator->Free(buffer_[1]);
|
||||
if (!is_vec_) {
|
||||
if (!state_is_vec_) {
|
||||
context_->allocator->Free(buffer_[2]);
|
||||
}
|
||||
context_->allocator->Free(buffer_[3]);
|
||||
|
@ -64,6 +64,9 @@ void LstmCPUKernel::FreeRunBuffer() {
|
|||
if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) {
|
||||
context_->allocator->Free(buffer_[5]);
|
||||
}
|
||||
if (output_need_packed_) {
|
||||
context_->allocator->Free(buffer_[6]);
|
||||
}
|
||||
}
|
||||
|
||||
int LstmCPUKernel::InitInputWeightBias() {
|
||||
|
@ -103,7 +106,7 @@ int LstmCPUKernel::InitStateWeightBias() {
|
|||
auto weight_h = in_tensors_.at(2);
|
||||
MS_ASSERT(weight_h != nullptr);
|
||||
auto weight_h_data = reinterpret_cast<float *>(weight_h->data_c());
|
||||
if (!is_vec_) {
|
||||
if (!state_is_vec_) {
|
||||
weight_h_ptr_ = reinterpret_cast<float *>(
|
||||
malloc(weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->hidden_size_ * sizeof(float)));
|
||||
if (weight_h_ptr_ == nullptr) {
|
||||
|
@ -113,7 +116,17 @@ int LstmCPUKernel::InitStateWeightBias() {
|
|||
PackLstmWeight(weight_h_ptr_, weight_h_data, weight_batch_, lstm_param_->hidden_size_, lstm_param_->hidden_size_,
|
||||
lstm_param_->state_col_align_);
|
||||
} else {
|
||||
#ifdef ENABLE_AVX
|
||||
weight_h_ptr_ = reinterpret_cast<float *>(
|
||||
malloc(weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->hidden_size_ * sizeof(float)));
|
||||
for (int i = 0; i < weight_batch_; i++) {
|
||||
const float *src_batch = weight_h_data + i * lstm_param_->hidden_size_ * lstm_param_->hidden_size_;
|
||||
float *dst_batch = weight_h_ptr_ + i * lstm_param_->state_col_align_ * lstm_param_->hidden_size_;
|
||||
RowMajor2Col32Major(src_batch, dst_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_);
|
||||
}
|
||||
#else
|
||||
weight_h_ptr_ = weight_h_data;
|
||||
#endif
|
||||
}
|
||||
|
||||
// state bias
|
||||
|
@ -145,6 +158,7 @@ int LstmCPUKernel::InitParam() {
|
|||
lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_
|
||||
: lstm_param_->batch_ * lstm_param_->hidden_size_;
|
||||
weight_batch_ = lstm_param_->bidirectional_ ? 8 : 4;
|
||||
state_is_vec_ = lstm_param_->batch_ == 1;
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
row_tile_ = C6NUM;
|
||||
|
@ -161,10 +175,25 @@ int LstmCPUKernel::InitParam() {
|
|||
#endif
|
||||
lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, row_tile_);
|
||||
lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, col_tile_);
|
||||
input_thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(lstm_param_->input_col_align_, col_tile_));
|
||||
input_thread_stride_ = UP_DIV(UP_DIV(lstm_param_->input_col_align_, col_tile_), input_thread_count_);
|
||||
|
||||
is_vec_ = lstm_param_->batch_ == 1;
|
||||
lstm_param_->state_row_align_ = is_vec_ ? 1 : UP_ROUND(lstm_param_->batch_, row_tile_);
|
||||
lstm_param_->state_col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, col_tile_);
|
||||
state_row_tile_ = row_tile_;
|
||||
state_col_tile_ = col_tile_;
|
||||
#ifdef ENABLE_AVX
|
||||
if (state_is_vec_) {
|
||||
state_row_tile_ = 1;
|
||||
state_col_tile_ = C8NUM;
|
||||
}
|
||||
#endif
|
||||
|
||||
lstm_param_->state_row_align_ = state_is_vec_ ? 1 : UP_ROUND(lstm_param_->batch_, state_row_tile_);
|
||||
#ifdef ENABLE_AVX
|
||||
lstm_param_->state_col_align_ = UP_ROUND(lstm_param_->hidden_size_, state_col_tile_);
|
||||
#else
|
||||
lstm_param_->state_col_align_ =
|
||||
state_is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, state_col_tile_);
|
||||
#endif
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -217,7 +246,7 @@ int LstmCPUKernel::MallocRunBuffer() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (!is_vec_) {
|
||||
if (!state_is_vec_) {
|
||||
buffer_[2] = reinterpret_cast<float *>(
|
||||
context_->allocator->Malloc(lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * sizeof(float)));
|
||||
if (buffer_[2] == nullptr) {
|
||||
|
@ -249,6 +278,100 @@ int LstmCPUKernel::MallocRunBuffer() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
#ifdef ENABLE_AVX
|
||||
if (state_is_vec_) { // vec matmul need to malloc dst
|
||||
output_need_packed_ = lstm_param_->hidden_size_ % state_col_tile_;
|
||||
if (output_need_packed_) {
|
||||
int out_channel = lstm_param_->hidden_size_;
|
||||
int oc_block_num = UP_DIV(out_channel, state_col_tile_);
|
||||
MS_ASSERT(context_->allocator != nullptr);
|
||||
buffer_[6] = reinterpret_cast<float *>(
|
||||
context_->allocator->Malloc(lstm_param_->batch_ * oc_block_num * state_col_tile_ * sizeof(float)));
|
||||
if (buffer_[6] == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc tmp output data failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmCPUKernel::InputWeightMatMul(int task_id) {
|
||||
int current_start_oc = task_id * input_thread_stride_ * col_tile_;
|
||||
int current_rest_oc = 0;
|
||||
current_rest_oc = lstm_param_->hidden_size_ - current_start_oc;
|
||||
int cur_oc = MSMIN(input_thread_stride_ * col_tile_, current_rest_oc);
|
||||
if (cur_oc <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
auto input = buffer_[0];
|
||||
auto b = weight_loop_ + current_start_oc * lstm_param_->input_size_;
|
||||
auto c = gate_loop_ + current_start_oc;
|
||||
auto bias = (bias_loop_ == nullptr) ? nullptr : bias_loop_ + current_start_oc;
|
||||
MatMulOpt(input, b, c, bias, ActType_No, lstm_param_->input_size_, lstm_param_->seq_len_ * lstm_param_->batch_,
|
||||
cur_oc, lstm_param_->hidden_size_, OutType_Nhwc);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmInputMulWeightRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
|
||||
auto kernel = reinterpret_cast<LstmCPUKernel *>(cdata);
|
||||
auto ret = kernel->InputWeightMatMul(task_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InputWeightMatMul error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmCPUKernel::LstmUnidirectional(float *output, const float *weight_i, const float *weight_h,
|
||||
const float *input_bias, const float *state_bias, float *hidden_state,
|
||||
float *cell_state, bool is_backward) {
|
||||
float *gate = buffer_[1];
|
||||
for (int i = 0; i < 4; i++) {
|
||||
weight_loop_ = weight_i + lstm_param_->input_size_ * lstm_param_->input_col_align_ * i;
|
||||
bias_loop_ = input_bias + lstm_param_->input_col_align_ * i;
|
||||
gate_loop_ = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * i;
|
||||
ParallelLaunch(this->context_, LstmInputMulWeightRun, this, input_thread_count_);
|
||||
}
|
||||
|
||||
float *input_gate = gate;
|
||||
float *forget_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * 2;
|
||||
float *cell_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * 3;
|
||||
float *output_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_;
|
||||
for (int t = 0; t < lstm_param_->seq_len_; t++) {
|
||||
int real_t = is_backward ? lstm_param_->seq_len_ - t - 1 : t;
|
||||
float *input_gate_t = input_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t;
|
||||
float *forget_gate_t = forget_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t;
|
||||
float *cell_gate_t = cell_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t;
|
||||
float *output_gate_t = output_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t;
|
||||
float *output_ptr = output + real_t * lstm_param_->output_step_;
|
||||
LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias,
|
||||
hidden_state, cell_state, buffer_, lstm_param_);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmCPUKernel::InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state) {
|
||||
// forward
|
||||
// buffer_[0] : store packed input
|
||||
PackLstmInput(input, buffer_[0], lstm_param_->seq_len_ * lstm_param_->batch_, lstm_param_->input_size_);
|
||||
LstmUnidirectional(output, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_, hidden_state, cell_state, false);
|
||||
|
||||
// backward
|
||||
if (lstm_param_->bidirectional_) {
|
||||
const float *backward_weight_i = weight_i_ptr_ + 4 * lstm_param_->input_col_align_ * lstm_param_->input_size_;
|
||||
const float *backward_weight_h = weight_h_ptr_ + 4 * lstm_param_->state_col_align_ * lstm_param_->hidden_size_;
|
||||
const float *backward_input_bias = input_bias_ + 4 * lstm_param_->input_col_align_;
|
||||
const float *backward_state_bias = state_bias_ + 4 * lstm_param_->state_col_align_;
|
||||
float *backward_output = output + lstm_param_->batch_ * lstm_param_->hidden_size_;
|
||||
float *backward_cell_state = cell_state + lstm_param_->batch_ * lstm_param_->hidden_size_;
|
||||
float *backward_hidden_state = hidden_state + lstm_param_->batch_ * lstm_param_->hidden_size_;
|
||||
|
||||
LstmUnidirectional(backward_output, backward_weight_i, backward_weight_h, backward_input_bias, backward_state_bias,
|
||||
backward_hidden_state, backward_cell_state, true);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -281,9 +404,8 @@ int LstmCPUKernel::Run() {
|
|||
MS_ASSERT(weight_i_ptr_);
|
||||
MS_ASSERT(input_bias_);
|
||||
MS_ASSERT(state_bias_);
|
||||
Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_,
|
||||
reinterpret_cast<float *>(output_hidden_state->data_c()), reinterpret_cast<float *>(output_cell_state->data_c()),
|
||||
buffer_, lstm_param_);
|
||||
InnerExecute(output_ptr, input_ptr, reinterpret_cast<float *>(output_hidden_state->data_c()),
|
||||
reinterpret_cast<float *>(output_cell_state->data_c()));
|
||||
FreeRunBuffer();
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -36,6 +36,8 @@ class LstmCPUKernel : public InnerKernel {
|
|||
int ReSize() override;
|
||||
int Run() override;
|
||||
|
||||
int InputWeightMatMul(int task_id);
|
||||
|
||||
private:
|
||||
void FreeTmpBuffer();
|
||||
void FreeRunBuffer();
|
||||
|
@ -44,15 +46,27 @@ class LstmCPUKernel : public InnerKernel {
|
|||
int InitInputWeightBias();
|
||||
int InitStateWeightBias();
|
||||
|
||||
int LstmUnidirectional(float *output, const float *weight_i, const float *weight_h, const float *input_bias,
|
||||
const float *state_bias, float *hidden_state, float *cell_state, bool is_backward);
|
||||
int InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state);
|
||||
const float *weight_loop_;
|
||||
const float *bias_loop_;
|
||||
float *gate_loop_;
|
||||
int input_thread_count_ = 0;
|
||||
int input_thread_stride_ = 0;
|
||||
|
||||
float *weight_i_ptr_ = nullptr;
|
||||
float *weight_h_ptr_ = nullptr;
|
||||
float *input_bias_ = nullptr;
|
||||
float *state_bias_ = nullptr;
|
||||
float *buffer_[6];
|
||||
float *buffer_[7];
|
||||
int row_tile_ = 0;
|
||||
int col_tile_ = 0;
|
||||
int state_row_tile_ = 0;
|
||||
int state_col_tile_ = 0;
|
||||
int weight_batch_ = 0;
|
||||
bool is_vec_ = false;
|
||||
bool state_is_vec_ = false;
|
||||
bool output_need_packed_ = false;
|
||||
LstmParameter *lstm_param_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
|
Loading…
Reference in New Issue