diff --git a/mindspore/lite/nnacl/fp32/lstm_fp32.c b/mindspore/lite/nnacl/fp32/lstm_fp32.c index 7c88d8707e6..cd1f12007cb 100644 --- a/mindspore/lite/nnacl/fp32/lstm_fp32.c +++ b/mindspore/lite/nnacl/fp32/lstm_fp32.c @@ -19,20 +19,7 @@ #include #include "nnacl/fp32/activation_fp32.h" #include "nnacl/fp32/arithmetic_fp32.h" -#include "nnacl/fp32/mul_fp32.h" - -void InitGate(float *gate_buffer, const float *bias, const LstmParameter *lstm_parm) { - int gate_offest = 0; - for (int l = 0; l < 4; l++) { - int batch_offest = gate_offest; - int bias_offest = l * lstm_parm->hidden_size_; - for (int b = 0; b < lstm_parm->batch_; b++) { - memcpy(gate_buffer + batch_offest, bias + bias_offest, lstm_parm->hidden_size_ * sizeof(float)); - batch_offest += lstm_parm->hidden_size_; - } - gate_offest += lstm_parm->batch_ * lstm_parm->hidden_size_; - } -} +#include "nnacl/fp32/matmul_fp32.h" // input: [row, inner_size]; weight: [col, inner_size]; output: [row, col] void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size) { @@ -134,106 +121,131 @@ void UpdataOutput(const float *cell_state, const float *output_gate, float *hidd } } -void LstmStepUnit(float *output, const float *input, const float *input_input_weight, const float *input_forget_weight, - const float *input_cell_weight, const float *input_output_weight, const float *state_input_weight, - const float *state_forget_weight, const float *state_cell_weight, const float *state_output_weight, - const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, - const LstmParameter *lstm_parm) { - InitGate(gate_buffer, bias, lstm_parm); +void LstmMatmul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) { + if (is_vec) { + memcpy(c, bias, col * sizeof(float)); + MatMulAcc(c, a, b, row, col, deep); + } else { + MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); + } +} + +void PackLstmInput(float *dst, const float *src, int row, int deep) { +#ifdef ENABLE_AVX + RowMajor2Col6Major(src, dst, row, deep); +#elif defined(ENABLE_SSE) + RowMajor2Col4Major(src, dst, row, deep); +#else + RowMajor2Col12Major(src, dst, row, deep); +#endif +} + +void UpdateGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep, + int col, int col_align, bool is_vec) { + const float *input_weight = weight; + const float *forget_weight = weight + deep * col * 2; + const float *cell_weight = weight + deep * col * 3; + const float *output_weight = weight + deep * col; + + const float *input_bias = bias; + const float *forget_bias = bias + col_align * 2; + const float *cell_bias = bias + col_align * 3; + const float *output_bias = bias + col_align; float *input_gate = gate_buffer; - float *forget_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 2; - float *cell_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 3; - float *output_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 1; + float *forget_gate = gate_buffer + row * col * 2; + float *cell_gate = gate_buffer + row * col * 3; + float *output_gate = gate_buffer + row * col; + LstmMatmul(input_gate, input, input_weight, input_bias, row, deep, col, is_vec); + LstmMatmul(forget_gate, input, forget_weight, forget_bias, row, deep, col, is_vec); + LstmMatmul(cell_gate, input, cell_weight, cell_bias, row, deep, col, is_vec); + LstmMatmul(output_gate, input, output_weight, output_bias, row, deep, col, is_vec); +} + +void LstmStepUnit(float *output, const float *input, const float *input_weight, const float *state_weight, + const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, + float *matmul_buffer[2], const LstmParameter *lstm_param) { + bool is_vec = lstm_param->batch_ == 1; // input * weight - MatMulAcc(input_gate, input, input_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_, lstm_parm->input_size_); - MatMulAcc(forget_gate, input, input_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->input_size_); - MatMulAcc(cell_gate, input, input_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_, lstm_parm->input_size_); - MatMulAcc(output_gate, input, input_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->input_size_); + if (is_vec) { + UpdateGate(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_, + lstm_param->hidden_size_, lstm_param->col_align_, is_vec); + } else { + // pack input for matmul + PackLstmInput(matmul_buffer[0], input, lstm_param->batch_, lstm_param->input_size_); + UpdateGate(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_, + lstm_param->hidden_size_, lstm_param->col_align_, is_vec); + } // state * weight - MatMulAcc(input_gate, hidden_state, state_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->hidden_size_); - MatMulAcc(forget_gate, hidden_state, state_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->hidden_size_); - MatMulAcc(cell_gate, hidden_state, state_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->hidden_size_); - MatMulAcc(output_gate, hidden_state, state_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->hidden_size_); + float *state_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 4; + const float *state_bias = bias + lstm_param->col_align_ * 4; + if (is_vec) { + UpdateGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_, + lstm_param->hidden_size_, lstm_param->col_align_, is_vec); + } else { + // pack state for matmul + PackLstmInput(matmul_buffer[1], hidden_state, lstm_param->batch_, lstm_param->hidden_size_); + UpdateGate(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_, + lstm_param->hidden_size_, lstm_param->col_align_, is_vec); + } + ElementAdd(gate_buffer, state_gate, gate_buffer, 4 * lstm_param->batch_ * lstm_param->hidden_size_); + float *input_gate = gate_buffer; + float *forget_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 2; + float *cell_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 3; + float *output_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_; // update input_gate - Sigmoid(input_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, input_gate); + Sigmoid(input_gate, lstm_param->batch_ * lstm_param->hidden_size_, input_gate); // update forget_gate - Sigmoid(forget_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, forget_gate); + Sigmoid(forget_gate, lstm_param->batch_ * lstm_param->hidden_size_, forget_gate); // update cell_gate - Tanh(cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, cell_gate); + Tanh(cell_gate, lstm_param->batch_ * lstm_param->hidden_size_, cell_gate); // update cell state - UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->smooth_); + UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_param->batch_, + lstm_param->hidden_size_, lstm_param->smooth_); // update output_gate - Sigmoid(output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, output_gate); + Sigmoid(output_gate, lstm_param->batch_ * lstm_param->hidden_size_, output_gate); // update output - UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->smooth_); - memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); + UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_param->batch_, lstm_param->hidden_size_, + lstm_param->smooth_); + memcpy(output, hidden_state, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); - if (!(lstm_parm->smooth_ >= -FLT_EPSILON && lstm_parm->smooth_ <= FLT_EPSILON)) { - memcpy(cell_state, state_buffer, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); - memcpy(hidden_state, state_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_, - lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); + if (!(lstm_param->smooth_ >= -FLT_EPSILON && lstm_param->smooth_ <= FLT_EPSILON)) { + memcpy(cell_state, state_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); + memcpy(hidden_state, state_buffer + lstm_param->batch_ * lstm_param->hidden_size_, + lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); } } void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, - float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, - const LstmParameter *lstm_parm) { + float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, float *matmul_buffer[2], + const LstmParameter *lstm_param) { // forward - const float *input_input_weight = weight_i; - const float *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2; - const float *input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 3; - const float *input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 1; - - const float *state_input_weight = weight_h; - const float *state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 2; - const float *state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 3; - const float *state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 1; - - for (int t = 0; t < lstm_parm->seq_len_; t++) { - const float *input_ptr = input + t * lstm_parm->input_step_; - float *output_ptr = output + t * lstm_parm->output_step_; - LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, input_output_weight, - state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, bias, hidden_state, - cell_state, gate_buffer, state_buffer, lstm_parm); + for (int t = 0; t < lstm_param->seq_len_; t++) { + const float *input_ptr = input + t * lstm_param->input_step_; + float *output_ptr = output + t * lstm_param->output_step_; + LstmStepUnit(output_ptr, input_ptr, weight_i, weight_h, bias, hidden_state, cell_state, gate_buffer, state_buffer, + matmul_buffer, lstm_param); } // backward - if (lstm_parm->bidirectional_) { - input_input_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 4; - input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 6; - input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 7; - input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 5; - - state_input_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 4; - state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 6; - state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 7; - state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 5; - - float *backward_output = output + lstm_parm->batch_ * lstm_parm->hidden_size_; - const float *backward_bias = bias + 4 * lstm_parm->hidden_size_; - float *backward_cell_state = cell_state + lstm_parm->batch_ * lstm_parm->hidden_size_; - float *backward_hidden_state = hidden_state + lstm_parm->batch_ * lstm_parm->hidden_size_; - for (int t = lstm_parm->seq_len_ - 1; t >= 0; t--) { - const float *input_ptr = input + t * lstm_parm->input_step_; - float *output_ptr = backward_output + t * lstm_parm->output_step_; - LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, - input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, - backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, state_buffer, lstm_parm); + if (lstm_param->bidirectional_) { + const float *backward_weight_i = weight_i + 4 * lstm_param->col_align_ * lstm_param->input_size_; + const float *backward_weight_h = weight_h + 4 * lstm_param->col_align_ * lstm_param->hidden_size_; + const float *backward_bias = bias + 8 * lstm_param->hidden_size_; + float *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_; + float *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; + float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_; + for (int t = lstm_param->seq_len_ - 1; t >= 0; t--) { + const float *input_ptr = input + t * lstm_param->input_step_; + float *output_ptr = backward_output + t * lstm_param->output_step_; + LstmStepUnit(output_ptr, input_ptr, backward_weight_i, backward_weight_h, backward_bias, backward_hidden_state, + backward_cell_state, gate_buffer, state_buffer, matmul_buffer, lstm_param); } } } diff --git a/mindspore/lite/nnacl/fp32/lstm_fp32.h b/mindspore/lite/nnacl/fp32/lstm_fp32.h index 207fd2bd24a..709a62a6fa8 100644 --- a/mindspore/lite/nnacl/fp32/lstm_fp32.h +++ b/mindspore/lite/nnacl/fp32/lstm_fp32.h @@ -28,7 +28,7 @@ void ElementMulAcc(const float *input0, const float *input1, float *output, int int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size); void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, - float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, + float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, float *matmul_buffer[2], const LstmParameter *lstm_parm); #ifdef __cplusplus } diff --git a/mindspore/lite/nnacl/lstm_parameter.h b/mindspore/lite/nnacl/lstm_parameter.h index e9bc1621d80..8a04e3a81a3 100644 --- a/mindspore/lite/nnacl/lstm_parameter.h +++ b/mindspore/lite/nnacl/lstm_parameter.h @@ -34,6 +34,8 @@ typedef struct LstmParameter { // output_hidden = old_hidden * smooth + new_hidden * (1 - smooth) // output_cell = old_cell * smooth + new_cell * (1 - smooth) float smooth_; + int col_align_; + int row_align_; } LstmParameter; #endif // MINDSPORE_LITE_NNACL_LSTM_PARAMETER_H_ diff --git a/mindspore/lite/src/common/graph_util.cc b/mindspore/lite/src/common/graph_util.cc index d43fb7405a1..59c8015e2bb 100644 --- a/mindspore/lite/src/common/graph_util.cc +++ b/mindspore/lite/src/common/graph_util.cc @@ -84,9 +84,8 @@ std::vector GetLinkedPostNodeIdx(const lite::Model *model, const size_t bool IsPackedOp(schema::PrimitiveType op_type) { static std::vector packed_ops = { - schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, - schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D, - schema::PrimitiveType_MatMul, schema::PrimitiveType_Lstm}; + schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DepthwiseConv2D, + schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_MatMul}; return IsContain(packed_ops, op_type); } } // namespace lite diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc index 3c2c2f4fc1f..828dd14f040 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc @@ -20,35 +20,104 @@ #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "include/errorcode.h" +#include "nnacl/fp32/matmul_fp32.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_MEMORY_FAILED; using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_Lstm; namespace mindspore::kernel { void LstmCPUKernel::FreeTmpBuffer() { - if (gate_buffer_ != nullptr) { - free(gate_buffer_); - gate_buffer_ = nullptr; + if (!is_vec_) { + if (weight_i_ptr_ != nullptr) { + free(weight_i_ptr_); + weight_i_ptr_ = nullptr; + } + if (weight_h_ptr_ != nullptr) { + free(weight_h_ptr_); + weight_h_ptr_ = nullptr; + } + if (bias_ptr_ != nullptr) { + free(bias_ptr_); + bias_ptr_ = nullptr; + } } - if (state_buffer_ != nullptr) { - free(state_buffer_); - state_buffer_ = nullptr; +} + +void LstmCPUKernel::FreeRunBuffer() { + context_->allocator->Free(gate_buffer_); + context_->allocator->Free(state_buffer_); + if (!is_vec_) { + for (int i = 0; i < 2; i++) { + context_->allocator->Free(matmul_buffer_[i]); + } } - if (weight_i_ptr_ != nullptr) { - free(weight_i_ptr_); - weight_i_ptr_ = nullptr; +} + +int InitRightMatrix(float *dst, const float *src, int batch, int deep, int col, int col_align, bool is_vec) { + for (int i = 0; i < batch; i++) { + auto src_batch = src + i * col * deep; + auto dst_batch = dst + i * col_align * deep; +#ifdef ENABLE_AVX + RowMajor2Col16Major(src_batch, dst_batch, col, deep); +#elif defined(ENABLE_ARM32) + RowMajor2Col4Major(src_batch, dst_batch, col, deep); +#else + RowMajor2Col8Major(src_batch, dst_batch, col, deep); +#endif } - if (weight_h_ptr_ != nullptr) { - free(weight_h_ptr_); - weight_h_ptr_ = nullptr; - } - if (bias_ptr_ != nullptr) { - free(bias_ptr_); - bias_ptr_ = nullptr; + return RET_OK; +} + +int LstmCPUKernel::InitWeightBias() { + auto weight_batch = lstm_param_->bidirectional_ ? 8 : 4; + + if (!is_vec_) { + // malloc and init input * weight right matrix buffer + auto weight_i = in_tensors_.at(1); + MS_ASSERT(weight_i != nullptr); + weight_i_ptr_ = reinterpret_cast( + malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->input_size_ * sizeof(float))); + if (weight_i_ptr_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error."; + return RET_ERROR; + } + auto weight_i_data = reinterpret_cast(weight_i->data_c()); + InitRightMatrix(weight_i_ptr_, weight_i_data, weight_batch, lstm_param_->input_size_, lstm_param_->hidden_size_, + lstm_param_->col_align_, is_vec_); + + // malloc and init state * weight right matrix buffer + auto weight_h = in_tensors_.at(2); + MS_ASSERT(weight_h != nullptr); + weight_h_ptr_ = reinterpret_cast( + malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->hidden_size_ * sizeof(float))); + if (weight_h_ptr_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ptr_ error."; + return RET_ERROR; + } + auto weight_h_data = reinterpret_cast(weight_h->data_c()); + InitRightMatrix(weight_h_ptr_, weight_h_data, weight_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_, + lstm_param_->col_align_, is_vec_); + + // init bias + int bias_batch = lstm_param_->bidirectional_ ? 16 : 8; + bias_ptr_ = reinterpret_cast(malloc(bias_batch * lstm_param_->col_align_ * sizeof(float))); + if (bias_ptr_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc bias_ptr_ error."; + return RET_ERROR; + } + memset(bias_ptr_, 0, bias_batch * lstm_param_->col_align_ * sizeof(float)); + auto bias_data = reinterpret_cast(in_tensors_.at(3)->data_c()); + for (int i = 0; i < bias_batch; i++) { + auto src_batch = bias_data + i * lstm_param_->hidden_size_; + auto dst_batch = bias_ptr_ + i * lstm_param_->col_align_; + memcpy(dst_batch, src_batch, lstm_param_->hidden_size_ * sizeof(float)); + } } + return RET_OK; } int LstmCPUKernel::InitParam() { @@ -67,80 +136,27 @@ int LstmCPUKernel::InitParam() { lstm_param_->input_step_ = lstm_param_->batch_ * lstm_param_->input_size_; lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ : lstm_param_->batch_ * lstm_param_->hidden_size_; - return RET_OK; -} -int LstmCPUKernel::InitBuffer() { - gate_buffer_ = reinterpret_cast(malloc(4 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float))); - if (gate_buffer_ == nullptr) { - MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error."; - return RET_ERROR; - } - if (!(lstm_param_->smooth_ >= -FLT_EPSILON && lstm_param_->smooth_ <= FLT_EPSILON)) { - int buffer_size = 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float); - state_buffer_ = reinterpret_cast(malloc(buffer_size)); - if (state_buffer_ == nullptr) { - MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer error."; - return RET_ERROR; - } - } - return RET_OK; -} - -int LstmCPUKernel::InitWeightBias() { - // copy weight_i and weight_h - auto weight_i = in_tensors_.at(1); - MS_ASSERT(weight_i != nullptr); - weight_i_ptr_ = reinterpret_cast(malloc(weight_i->ElementsNum() * sizeof(float))); - if (weight_i_ptr_ == nullptr) { - MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error."; - return RET_ERROR; - } - memcpy(weight_i_ptr_, weight_i->data_c(), weight_i->ElementsNum() * sizeof(float)); - - auto weight_h = in_tensors_.at(2); - MS_ASSERT(weight_h != nullptr); - weight_h_ptr_ = reinterpret_cast(malloc(weight_h->ElementsNum() * sizeof(float))); - if (weight_h_ptr_ == nullptr) { - MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ error."; - return RET_ERROR; - } - memcpy(weight_h_ptr_, weight_h->data_c(), weight_h->ElementsNum() * sizeof(float)); - - std::vector w_shape = weight_i->shape(); - auto hidden_size = w_shape.at(1) / 4; - // init bias - int bias_num = lstm_param_->bidirectional_ ? 2 * 4 * hidden_size : 4 * hidden_size; - bias_ptr_ = reinterpret_cast(malloc(bias_num * sizeof(float))); - if (bias_ptr_ == nullptr) { - MS_LOG(ERROR) << "LstmCPUKernel malloc bias_ptr_ error."; - return RET_ERROR; - } - - auto bias_data = reinterpret_cast(in_tensors_.at(3)->data_c()); - const int state_bias_offset = 4 * hidden_size; - for (int i = 0; i < state_bias_offset; i++) { - bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset]; - } - if (lstm_param_->bidirectional_) { - bias_data += 4 * hidden_size * 2; - auto backward_bias = bias_ptr_ + 4 * hidden_size; - for (int i = 0; i < state_bias_offset; i++) { - backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset]; - } - } +#ifdef ENABLE_AVX + row_tile_ = C6NUM; + col_tile_ = C16NUM; +#elif defined(ENABLE_ARM32) + row_tile_ = C12NUM; + col_tile_ = C4NUM; +#elif defined(ENABLE_SSE) + row_tile_ = C4NUM; + col_tile_ = C8NUM; +#else + row_tile_ = C12NUM; + col_tile_ = C8NUM; +#endif + is_vec_ = lstm_param_->batch_ == 1; + lstm_param_->row_align_ = is_vec_ ? 1 : UP_ROUND(lstm_param_->batch_, row_tile_); + lstm_param_->col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, col_tile_); return RET_OK; } int LstmCPUKernel::Init() { - FreeTmpBuffer(); - auto ret = InitWeightBias(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error."; - FreeTmpBuffer(); - return RET_ERROR; - } - if (!InferShapeDone()) { return RET_OK; } @@ -154,15 +170,50 @@ int LstmCPUKernel::ReSize() { return RET_ERROR; } - ret = InitBuffer(); + FreeTmpBuffer(); + ret = InitWeightBias(); if (ret != RET_OK) { - MS_LOG(ERROR) << "LstmCPUKernel InitBuffer error."; + MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error."; FreeTmpBuffer(); return RET_ERROR; } return RET_OK; } +int LstmCPUKernel::MallocRunBuffer() { + if (!is_vec_) { + matmul_buffer_[0] = reinterpret_cast( + context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->input_size_ * sizeof(float))); + if (matmul_buffer_[0] == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc input * weight left matirx error."; + return RET_ERROR; + } + + matmul_buffer_[1] = reinterpret_cast( + context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->hidden_size_ * sizeof(float))); + if (matmul_buffer_[1] == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc state * weight left matirx error."; + return RET_ERROR; + } + } + + gate_buffer_ = reinterpret_cast( + context_->allocator->Malloc(8 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float))); + if (gate_buffer_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error."; + return RET_ERROR; + } + if (!(lstm_param_->smooth_ >= -FLT_EPSILON && lstm_param_->smooth_ <= FLT_EPSILON)) { + int buffer_size = 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float); + state_buffer_ = reinterpret_cast(context_->allocator->Malloc(buffer_size)); + if (state_buffer_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer error."; + return RET_ERROR; + } + } + return RET_OK; +} + int LstmCPUKernel::Run() { auto input = in_tensors_.at(kInputIndex); MS_ASSERT(input != nullptr); @@ -182,13 +233,26 @@ int LstmCPUKernel::Run() { auto output_cell_state = out_tensors_[2]; memcpy(output_cell_state->data_c(), cell_state->data_c(), cell_state->ElementsNum() * sizeof(float)); + auto ret = MallocRunBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "LstmCPUKernel InitRunBuffer error."; + return RET_ERROR; + } + + if (is_vec_) { + weight_i_ptr_ = reinterpret_cast(in_tensors_[1]->data_c()); + weight_h_ptr_ = reinterpret_cast(in_tensors_[2]->data_c()); + bias_ptr_ = reinterpret_cast(in_tensors_[3]->data_c()); + } + MS_ASSERT(weight_h_ptr_); MS_ASSERT(weight_i_ptr_); MS_ASSERT(bias_ptr_); MS_ASSERT(gate_buffer_); Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_, reinterpret_cast(output_hidden_state->data_c()), reinterpret_cast(output_cell_state->data_c()), - gate_buffer_, state_buffer_, lstm_param_); + gate_buffer_, state_buffer_, matmul_buffer_, lstm_param_); + FreeRunBuffer(); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h index 0980762ca4b..d0f91671b6e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h @@ -39,8 +39,9 @@ class LstmCPUKernel : public LiteKernel { private: void FreeTmpBuffer(); + void FreeRunBuffer(); int InitParam(); - int InitBuffer(); + int MallocRunBuffer(); int InitWeightBias(); float *gate_buffer_ = nullptr; @@ -48,6 +49,10 @@ class LstmCPUKernel : public LiteKernel { float *weight_i_ptr_ = nullptr; float *weight_h_ptr_ = nullptr; float *bias_ptr_ = nullptr; + float *matmul_buffer_[2]; + int row_tile_ = 0; + int col_tile_ = 0; + bool is_vec_ = false; LstmParameter *lstm_param_ = nullptr; }; } // namespace mindspore::kernel