From 4fc284454f185b83975701613d2037855d491352 Mon Sep 17 00:00:00 2001 From: yangruoqi713 Date: Fri, 19 Mar 2021 17:30:26 +0800 Subject: [PATCH] [MSLITE][DEVELOP] optimize cpu fp16 op: lstm, gru --- mindspore/lite/nnacl/fp16/gru_fp16.c | 113 +++---- mindspore/lite/nnacl/fp16/gru_fp16.h | 2 +- mindspore/lite/nnacl/fp16/lstm_fp16.c | 154 +++++++--- mindspore/lite/nnacl/fp16/lstm_fp16.h | 8 +- mindspore/lite/nnacl/fp32/gru_fp32.c | 113 +++---- mindspore/lite/nnacl/fp32/gru_fp32.h | 4 +- mindspore/lite/nnacl/fp32/lstm_fp32.c | 62 +--- mindspore/lite/nnacl/fp32/lstm_fp32.h | 2 +- mindspore/lite/nnacl/gru_parameter.h | 7 +- mindspore/lite/nnacl/lstm_parameter.h | 3 - .../src/runtime/kernel/arm/fp16/gru_fp16.cc | 271 +++++++++-------- .../src/runtime/kernel/arm/fp16/gru_fp16.h | 11 +- .../src/runtime/kernel/arm/fp16/lstm_fp16.cc | 283 ++++++++++-------- .../src/runtime/kernel/arm/fp16/lstm_fp16.h | 12 +- .../src/runtime/kernel/arm/fp32/gru_fp32.cc | 191 +++++++----- .../src/runtime/kernel/arm/fp32/gru_fp32.h | 10 +- .../src/runtime/kernel/arm/fp32/lstm_fp32.cc | 26 +- .../src/runtime/kernel/arm/fp32/lstm_fp32.h | 3 +- 18 files changed, 706 insertions(+), 569 deletions(-) diff --git a/mindspore/lite/nnacl/fp16/gru_fp16.c b/mindspore/lite/nnacl/fp16/gru_fp16.c index 5fb6ff5730f..f781c3cb0cc 100644 --- a/mindspore/lite/nnacl/fp16/gru_fp16.c +++ b/mindspore/lite/nnacl/fp16/gru_fp16.c @@ -20,40 +20,22 @@ #include "nnacl/fp16/arithmetic_fp16.h" #include "nnacl/fp16/matmul_fp16.h" -void UpdateGruInputGateFp16(float16_t *gate_buffer, const float16_t *input, const float16_t *weight, - const float16_t *bias, int row, int deep, int col, int col_align, bool is_vec) { - for (int i = 0; i < 3; i++) { - const float16_t *weight_i = weight + deep * col * i; - const float16_t *bias_i = bias + col_align * i; - float16_t *gate = gate_buffer + row * col * i; - LstmMatMulFp16(gate, input, weight_i, bias_i, row, deep, col, is_vec); - } -} - -void GruStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_weight, - const float16_t *state_weight, const float16_t *bias, float16_t *hidden_state, - float16_t *gate_buffer, float16_t *matmul_buffer[2], const GruParameter *gru_param) { +void GruStepUnitFp16(float16_t *output, float16_t *update_gate, float16_t *reset_gate, float16_t *hidden_buffer, + const float16_t *state_weight, const float16_t *state_bias, float16_t *hidden_state, + float16_t *buffer[4], const GruParameter *gru_param) { + float16_t *packed_state = buffer[2]; + float16_t *state_gate = buffer[3]; bool is_vec = gru_param->batch_ == 1; - // input * weight - if (is_vec) { - UpdateGruInputGateFp16(gate_buffer, input, input_weight, bias, gru_param->batch_, gru_param->input_size_, - gru_param->hidden_size_, gru_param->col_align_, is_vec); - } else { - // pack input for matmul - RowMajor2Col16MajorFp16(input, matmul_buffer[0], gru_param->batch_, gru_param->input_size_, false); - UpdateGruInputGateFp16(gate_buffer, matmul_buffer[0], input_weight, bias, gru_param->batch_, gru_param->input_size_, - gru_param->hidden_size_, gru_param->col_align_, is_vec); - } const float16_t *state_update_weight = state_weight; const float16_t *state_reset_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_; const float16_t *state_hidden_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_ * 2; - float16_t *state_update_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 3; - float16_t *state_reset_gate = gate_buffer + gru_param->batch_ * 
gru_param->hidden_size_ * 4; - float16_t *state_hidden_buffer = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 5; - const float16_t *state_update_bias = bias + gru_param->hidden_size_ * 3; - const float16_t *state_reset_bias = bias + gru_param->hidden_size_ * 4; - const float16_t *state_hidden_bias = bias + gru_param->hidden_size_ * 5; + float16_t *state_update_gate = state_gate; + float16_t *state_reset_gate = state_gate + gru_param->batch_ * gru_param->hidden_size_; + float16_t *state_hidden_buffer = state_gate + gru_param->batch_ * gru_param->hidden_size_ * 2; + const float16_t *state_update_bias = state_bias; + const float16_t *state_reset_bias = state_bias + gru_param->hidden_size_; + const float16_t *state_hidden_bias = state_bias + gru_param->hidden_size_ * 2; // state * weight if (is_vec) { @@ -62,17 +44,15 @@ void GruStepUnitFp16(float16_t *output, const float16_t *input, const float16_t LstmMatMulFp16(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, is_vec); } else { - RowMajor2Col16MajorFp16(hidden_state, matmul_buffer[1], gru_param->batch_, gru_param->hidden_size_, false); - LstmMatMulFp16(state_reset_gate, matmul_buffer[1], state_reset_weight, state_reset_bias, gru_param->batch_, + RowMajor2Col16MajorFp16(hidden_state, packed_state, gru_param->batch_, gru_param->hidden_size_, false); + LstmMatMulFp16(state_reset_gate, packed_state, state_reset_weight, state_reset_bias, gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, is_vec); - LstmMatMulFp16(state_update_gate, matmul_buffer[1], state_update_weight, state_update_bias, gru_param->batch_, + LstmMatMulFp16(state_update_gate, packed_state, state_update_weight, state_update_bias, gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, is_vec); } - - ElementAddFp16(gate_buffer, state_update_gate, gate_buffer, gru_param->batch_ * gru_param->hidden_size_ * 2); - float16_t *update_gate = gate_buffer; - float16_t *reset_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_; - float16_t *hidden_buffer = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 2; + ElementAddFp16(update_gate, state_update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_); + ElementAddFp16(reset_gate, state_update_gate + gru_param->batch_ * gru_param->hidden_size_, reset_gate, + gru_param->batch_ * gru_param->hidden_size_); // update reset_gate SigmoidFp16(reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_); @@ -85,8 +65,8 @@ void GruStepUnitFp16(float16_t *output, const float16_t *input, const float16_t LstmMatMulFp16(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, is_vec); } else { - RowMajor2Col16MajorFp16(reset_gate, matmul_buffer[1], gru_param->batch_, gru_param->hidden_size_, false); - LstmMatMulFp16(state_hidden_buffer, matmul_buffer[1], state_hidden_weight, state_hidden_bias, gru_param->batch_, + RowMajor2Col16MajorFp16(reset_gate, packed_state, gru_param->batch_, gru_param->hidden_size_, false); + LstmMatMulFp16(state_hidden_buffer, packed_state, state_hidden_weight, state_hidden_bias, gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, is_vec); } ElementAddFp16(hidden_buffer, state_hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_); @@ -106,16 +86,41 @@ void GruStepUnitFp16(float16_t *output, const float16_t *input, const float16_t 
memcpy(output, hidden_state, gru_param->batch_ * gru_param->hidden_size_ * sizeof(float16_t)); } +void GruUnidirectionalFp16(float16_t *output, const float16_t *packed_input, const float16_t *weight_g, + const float16_t *weight_r, const float16_t *input_bias, const float16_t *state_bias, + float16_t *hidden_state, float16_t *buffer[4], const GruParameter *gru_param, + bool is_backward) { + float16_t *gate = buffer[1]; + for (int i = 0; i < 3; i++) { + const float16_t *weight_loop = weight_g + gru_param->input_size_ * gru_param->input_col_align_ * i; + const float16_t *bias_loop = input_bias + gru_param->input_col_align_ * i; + float16_t *gate_loop = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * i; + MatMulFp16(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, gru_param->input_size_, + gru_param->seq_len_ * gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, OutType_Nhwc); + } + + float16_t *update_gate = gate; + float16_t *reset_gate = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_; + float16_t *hidden_buffer = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * 2; + for (int t = 0; t < gru_param->seq_len_; t++) { + int real_t = is_backward ? gru_param->seq_len_ - t - 1 : t; + float16_t *update_gate_t = update_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float16_t *reset_gate_t = reset_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float16_t *hidden_buffer_t = hidden_buffer + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float16_t *output_ptr = output + real_t * gru_param->output_step_; + GruStepUnitFp16(output_ptr, update_gate_t, reset_gate_t, hidden_buffer_t, weight_r, state_bias, hidden_state, + buffer, gru_param); + } +} + void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r, - const float16_t *bias, float16_t *hidden_state, float16_t *gate_buffer, float16_t *matmul_buffer[2], + const float16_t *input_bias, const float16_t *state_bias, float16_t *hidden_state, float16_t *buffer[4], int check_seq_len, const GruParameter *gru_param) { // forward - for (int t = 0; t < check_seq_len; t++) { - const float16_t *input_ptr = input + t * gru_param->input_step_; - float16_t *output_ptr = output + t * gru_param->output_step_; - GruStepUnitFp16(output_ptr, input_ptr, weight_g, weight_r, bias, hidden_state, gate_buffer, matmul_buffer, - gru_param); - } + float16_t *packed_input = buffer[0]; + RowMajor2Col16MajorFp16(input, packed_input, gru_param->seq_len_ * gru_param->batch_, gru_param->input_size_, false); + GruUnidirectionalFp16(output, packed_input, weight_g, weight_r, input_bias, state_bias, hidden_state, buffer, + gru_param, false); // zero out extra fw outputs for (int t = check_seq_len; t < gru_param->seq_len_; t++) { float16_t *output_ptr = output + t * gru_param->output_step_; @@ -126,17 +131,15 @@ void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_ // backward if (gru_param->bidirectional_) { - const float16_t *backward_weight_g = weight_g + 3 * gru_param->col_align_ * gru_param->input_size_; - const float16_t *backward_weight_r = weight_r + 3 * gru_param->col_align_ * gru_param->hidden_size_; - const float16_t *backward_bias = bias + 6 * gru_param->hidden_size_; + const float16_t *backward_weight_g = weight_g + 3 * gru_param->input_col_align_ * gru_param->input_size_; + const float16_t *backward_weight_r = weight_r + 3 * gru_param->state_col_align_ * 
gru_param->hidden_size_; + const float16_t *backward_input_bias = input_bias + 3 * gru_param->input_col_align_; + const float16_t *backward_state_bias = state_bias + 3 * gru_param->state_col_align_; float16_t *backward_output = output + gru_param->batch_ * gru_param->hidden_size_; float16_t *backward_hidden_state = hidden_state + gru_param->batch_ * gru_param->hidden_size_; - for (int t = check_seq_len - 1; t >= 0; t--) { - const float16_t *input_ptr = input + t * gru_param->input_step_; - float16_t *output_ptr = backward_output + t * gru_param->output_step_; - GruStepUnitFp16(output_ptr, input_ptr, backward_weight_g, backward_weight_r, backward_bias, backward_hidden_state, - gate_buffer, matmul_buffer, gru_param); - } + GruUnidirectionalFp16(backward_output, packed_input, backward_weight_g, backward_weight_r, backward_input_bias, + backward_state_bias, backward_hidden_state, buffer, gru_param, true); + // zero out extra bw outputs for (int t = gru_param->seq_len_ - 1; t >= check_seq_len; t--) { float16_t *output_ptr = backward_output + t * gru_param->output_step_; diff --git a/mindspore/lite/nnacl/fp16/gru_fp16.h b/mindspore/lite/nnacl/fp16/gru_fp16.h index 2171227b72d..4f4485748d3 100644 --- a/mindspore/lite/nnacl/fp16/gru_fp16.h +++ b/mindspore/lite/nnacl/fp16/gru_fp16.h @@ -21,7 +21,7 @@ extern "C" { #endif void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r, - const float16_t *bias, float16_t *hidden_state, float16_t *gate_buffer, float16_t *matmul_buffer[2], + const float16_t *input_bias, const float16_t *state_bias, float16_t *hidden_state, float16_t *buffer[4], int check_seq_len, const GruParameter *gru_param); #ifdef __cplusplus } diff --git a/mindspore/lite/nnacl/fp16/lstm_fp16.c b/mindspore/lite/nnacl/fp16/lstm_fp16.c index d23baeb01d4..a1a2e7bf5cd 100644 --- a/mindspore/lite/nnacl/fp16/lstm_fp16.c +++ b/mindspore/lite/nnacl/fp16/lstm_fp16.c @@ -20,6 +20,7 @@ #include "nnacl/fp16/activation_fp16.h" #include "nnacl/fp16/arithmetic_fp16.h" #include "nnacl/fp16/matmul_fp16.h" +#include "nnacl/fp16/cast_fp16.h" void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align) { for (int i = 0; i < batch; i++) { @@ -37,6 +38,43 @@ void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int dee } } +void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, + bool is_bidirectional) { + int unidirectional_batch = is_bidirectional ? batch / 2 : batch; + for (int i = 0; i < unidirectional_batch; i++) { + const float *src_batch = src + i * col; + float16_t *dst_batch = dst + i * col_align; + Float32ToFloat16(src_batch, dst_batch, col); + } + if (is_bidirectional) { + const float *backward_src = src + batch * col; + float16_t *backward_dst = dst + unidirectional_batch * col_align; + for (int i = 0; i < unidirectional_batch; i++) { + const float *backward_src_batch = backward_src + i * col; + float16_t *backward_dst_batch = backward_dst + i * col_align; + Float32ToFloat16(backward_src_batch, backward_dst_batch, col); + } + } +} + +void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional) { + int unidirectional_batch = is_bidirectional ? 
batch / 2 : batch; + for (int i = 0; i < unidirectional_batch; i++) { + const float16_t *src_batch = src + i * col; + float16_t *dst_batch = dst + i * col_align; + memcpy(dst_batch, src_batch, col * sizeof(float16_t)); + } + if (is_bidirectional) { + const float16_t *backward_src = src + batch * col; + float16_t *backward_dst = dst + unidirectional_batch * col_align; + for (int i = 0; i < unidirectional_batch; i++) { + const float16_t *backward_src_batch = backward_src + i * col; + float16_t *backward_dst_batch = backward_dst + i * col_align; + memcpy(backward_dst_batch, backward_src_batch, col * sizeof(float16_t)); + } + } +} + // input: [row, inner_size]; weight: [col, inner_size]; output: [row, col] void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols, int inner_size) { @@ -149,40 +187,32 @@ void UpdateLstmGateFp16(float16_t *gate_buffer, const float16_t *input, const fl } } -void LstmStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_weight, - const float16_t *state_weight, const float16_t *bias, float16_t *hidden_state, - float16_t *cell_state, float16_t *gate_buffer, float16_t *state_buffer[2], - float16_t *matmul_buffer[2], const LstmParameter *lstm_param) { +void LstmStepUnitFp16(float16_t *output, float16_t *input_gate, float16_t *forget_gate, float16_t *cell_gate, + float16_t *output_gate, const float16_t *state_weight, const float16_t *state_bias, + float16_t *hidden_state, float16_t *cell_state, float16_t *buffer[6], + const LstmParameter *lstm_param) { + float16_t *packed_state = buffer[2]; + float16_t *state_gate = buffer[3]; + float16_t *cell_buffer = buffer[4]; + float16_t *hidden_buffer = buffer[5]; bool is_vec = lstm_param->batch_ == 1; - // input * weight - if (is_vec) { - UpdateLstmGateFp16(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_, - lstm_param->hidden_size_, lstm_param->col_align_, is_vec); - } else { - // pack input for matmul - RowMajor2Col16MajorFp16(input, matmul_buffer[0], lstm_param->batch_, lstm_param->input_size_, false); - UpdateLstmGateFp16(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_, - lstm_param->hidden_size_, lstm_param->col_align_, is_vec); - } - - // state * weight - float16_t *state_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 4; - const float16_t *state_bias = bias + lstm_param->col_align_ * 4; if (is_vec) { UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_, - lstm_param->hidden_size_, lstm_param->col_align_, is_vec); + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); } else { // pack state for matmul - RowMajor2Col16MajorFp16(hidden_state, matmul_buffer[1], lstm_param->batch_, lstm_param->hidden_size_, false); - UpdateLstmGateFp16(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_, - lstm_param->hidden_size_, lstm_param->hidden_size_, lstm_param->col_align_, is_vec); + RowMajor2Col16MajorFp16(hidden_state, packed_state, lstm_param->batch_, lstm_param->hidden_size_, false); + UpdateLstmGateFp16(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); } - ElementAddFp16(gate_buffer, state_gate, gate_buffer, 4 * lstm_param->batch_ * lstm_param->hidden_size_); + ElementAddFp16(input_gate, state_gate, input_gate, lstm_param->batch_ * 
lstm_param->hidden_size_); + ElementAddFp16(forget_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 2, forget_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + ElementAddFp16(cell_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 3, cell_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + ElementAddFp16(output_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_, output_gate, + lstm_param->batch_ * lstm_param->hidden_size_); - float16_t *input_gate = gate_buffer; - float16_t *forget_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 2; - float16_t *cell_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 3; - float16_t *output_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_; // update input_gate SigmoidFp16(input_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_); @@ -192,50 +222,76 @@ void LstmStepUnitFp16(float16_t *output, const float16_t *input, const float16_t // update cell_gate TanhFp16(cell_gate, cell_gate, lstm_param->batch_ * lstm_param->hidden_size_); // update cell state - UpdataStateFp16(cell_state, forget_gate, input_gate, cell_gate, state_buffer[0], lstm_param->batch_, + UpdataStateFp16(cell_state, forget_gate, input_gate, cell_gate, cell_buffer, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->zoneout_cell_); // update output_gate SigmoidFp16(output_gate, output_gate, lstm_param->batch_ * lstm_param->hidden_size_); // update output - UpdataOutputFp16(cell_state, output_gate, hidden_state, state_buffer[1], lstm_param->batch_, lstm_param->hidden_size_, + UpdataOutputFp16(cell_state, output_gate, hidden_state, hidden_buffer, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->zoneout_hidden_); memcpy(output, hidden_state, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t)); if (!(lstm_param->zoneout_cell_ >= -FLT_EPSILON && lstm_param->zoneout_cell_ <= FLT_EPSILON)) { - memcpy(cell_state, state_buffer[0], lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t)); + memcpy(cell_state, cell_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t)); } if (!(lstm_param->zoneout_hidden_ >= -FLT_EPSILON && lstm_param->zoneout_hidden_ <= FLT_EPSILON)) { - memcpy(hidden_state, state_buffer[1], lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t)); + memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t)); + } +} + +void LstmUnidirectionalFp16(float16_t *output, const float16_t *packed_input, const float16_t *weight_i, + const float16_t *weight_h, const float16_t *input_bias, const float16_t *state_bias, + float16_t *hidden_state, float16_t *cell_state, float16_t *buffer[6], + const LstmParameter *lstm_param, bool is_backward) { + float16_t *gate = buffer[1]; + for (int i = 0; i < 4; i++) { + const float16_t *weight_loop = weight_i + lstm_param->input_size_ * lstm_param->input_col_align_ * i; + const float16_t *bias_loop = input_bias + lstm_param->input_col_align_ * i; + float16_t *gate_loop = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * i; + MatMulFp16(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, + lstm_param->seq_len_ * lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, + OutType_Nhwc); + } + + float16_t *input_gate = gate; + float16_t *forget_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 2; + float16_t 
*cell_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 3; + float16_t *output_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_; + for (int t = 0; t < lstm_param->seq_len_; t++) { + int real_t = is_backward ? lstm_param->seq_len_ - t - 1 : t; + float16_t *input_gate_t = input_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float16_t *forget_gate_t = forget_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float16_t *cell_gate_t = cell_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float16_t *output_gate_t = output_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float16_t *output_ptr = output + real_t * lstm_param->output_step_; + LstmStepUnitFp16(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, + hidden_state, cell_state, buffer, lstm_param); } } void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h, - const float16_t *bias, float16_t *hidden_state, float16_t *cell_state, float16_t *gate_buffer, - float16_t *state_buffer[2], float16_t *matmul_buffer[2], const LstmParameter *lstm_param) { + const float16_t *input_bias, const float16_t *state_bias, float16_t *hidden_state, float16_t *cell_state, + float16_t *buffer[6], const LstmParameter *lstm_param) { // forward - for (int t = 0; t < lstm_param->seq_len_; t++) { - const float16_t *input_ptr = input + t * lstm_param->input_step_; - float16_t *output_ptr = output + t * lstm_param->output_step_; - LstmStepUnitFp16(output_ptr, input_ptr, weight_i, weight_h, bias, hidden_state, cell_state, gate_buffer, - state_buffer, matmul_buffer, lstm_param); - } + float16_t *packed_input = buffer[0]; + RowMajor2Col16MajorFp16(input, packed_input, lstm_param->seq_len_ * lstm_param->batch_, lstm_param->input_size_, + false); + LstmUnidirectionalFp16(output, packed_input, weight_i, weight_h, input_bias, state_bias, hidden_state, cell_state, + buffer, lstm_param, false); // backward if (lstm_param->bidirectional_) { - const float16_t *backward_weight_i = weight_i + 4 * lstm_param->col_align_ * lstm_param->input_size_; - const float16_t *backward_weight_h = weight_h + 4 * lstm_param->col_align_ * lstm_param->hidden_size_; - const float16_t *backward_bias = bias + 8 * lstm_param->col_align_; + const float16_t *backward_weight_i = weight_i + 4 * lstm_param->input_col_align_ * lstm_param->input_size_; + const float16_t *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->hidden_size_; + const float16_t *backward_input_bias = input_bias + 4 * lstm_param->input_col_align_; + const float16_t *backward_state_bias = state_bias + 4 * lstm_param->state_col_align_; float16_t *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_; float16_t *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; float16_t *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_; - for (int t = lstm_param->seq_len_ - 1; t >= 0; t--) { - const float16_t *input_ptr = input + t * lstm_param->input_step_; - float16_t *output_ptr = backward_output + t * lstm_param->output_step_; - LstmStepUnitFp16(output_ptr, input_ptr, backward_weight_i, backward_weight_h, backward_bias, - backward_hidden_state, backward_cell_state, gate_buffer, state_buffer, matmul_buffer, - lstm_param); - } + + LstmUnidirectionalFp16(backward_output, packed_input, backward_weight_i, 
backward_weight_h, backward_input_bias, + backward_state_bias, backward_hidden_state, backward_cell_state, buffer, lstm_param, true); } } diff --git a/mindspore/lite/nnacl/fp16/lstm_fp16.h b/mindspore/lite/nnacl/fp16/lstm_fp16.h index e3aae4326be..fff951d8adf 100644 --- a/mindspore/lite/nnacl/fp16/lstm_fp16.h +++ b/mindspore/lite/nnacl/fp16/lstm_fp16.h @@ -25,6 +25,10 @@ void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int d void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align); +void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional); + +void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional); + void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, int col, bool is_vec); @@ -36,8 +40,8 @@ void ElementMulAccFp16(const float16_t *input0, const float16_t *input1, float16 int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float16_t *output, const int element_size); void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h, - const float16_t *bias, float16_t *hidden_state, float16_t *cell_state, float16_t *gate_buffer, - float16_t *state_buffer[2], float16_t *matmul_buffer[2], const LstmParameter *lstm_param); + const float16_t *input_bias, const float16_t *state_bias, float16_t *hidden_state, float16_t *cell_state, + float16_t *buffer[6], const LstmParameter *lstm_param); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp32/gru_fp32.c b/mindspore/lite/nnacl/fp32/gru_fp32.c index 9ee85761f74..9d1e5991ff0 100644 --- a/mindspore/lite/nnacl/fp32/gru_fp32.c +++ b/mindspore/lite/nnacl/fp32/gru_fp32.c @@ -20,40 +20,21 @@ #include "nnacl/fp32/arithmetic_fp32.h" #include "nnacl/fp32/matmul_fp32.h" -void UpdateGruInputGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, - int deep, int col, int col_align, bool is_vec) { - for (int i = 0; i < 3; i++) { - const float *weight_i = weight + deep * col * i; - const float *bias_i = bias + col_align * i; - float *gate = gate_buffer + row * col * i; - LstmMatMul(gate, input, weight_i, bias_i, row, deep, col, is_vec); - } -} - -void GruStepUnit(float *output, const float *input, const float *input_weight, const float *state_weight, - const float *bias, float *hidden_state, float *gate_buffer, float *matmul_buffer[2], - const GruParameter *gru_param) { +void GruStepUnit(float *output, float *update_gate, float *reset_gate, float *hidden_buffer, const float *state_weight, + const float *state_bias, float *hidden_state, float *buffer[4], const GruParameter *gru_param) { + float *packed_state = buffer[2]; + float *state_gate = buffer[3]; bool is_vec = gru_param->batch_ == 1; - // input * weight - if (is_vec) { - UpdateGruInputGate(gate_buffer, input, input_weight, bias, gru_param->batch_, gru_param->input_size_, - gru_param->hidden_size_, gru_param->col_align_, is_vec); - } else { - // pack input for matmul - PackLstmInput(input, matmul_buffer[0], gru_param->batch_, gru_param->input_size_); - UpdateGruInputGate(gate_buffer, matmul_buffer[0], input_weight, bias, gru_param->batch_, gru_param->input_size_, - gru_param->hidden_size_, gru_param->col_align_, is_vec); - } const float *state_update_weight = state_weight; const float *state_reset_weight = state_weight + 
gru_param->hidden_size_ * gru_param->hidden_size_; const float *state_hidden_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_ * 2; - float *state_update_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 3; - float *state_reset_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 4; - float *state_hidden_buffer = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 5; - const float *state_update_bias = bias + gru_param->hidden_size_ * 3; - const float *state_reset_bias = bias + gru_param->hidden_size_ * 4; - const float *state_hidden_bias = bias + gru_param->hidden_size_ * 5; + float *state_update_gate = state_gate; + float *state_reset_gate = state_gate + gru_param->batch_ * gru_param->hidden_size_; + float *state_hidden_buffer = state_gate + gru_param->batch_ * gru_param->hidden_size_ * 2; + const float *state_update_bias = state_bias; + const float *state_reset_bias = state_bias + gru_param->hidden_size_; + const float *state_hidden_bias = state_bias + gru_param->hidden_size_ * 2; // state * weight if (is_vec) { @@ -62,16 +43,15 @@ void GruStepUnit(float *output, const float *input, const float *input_weight, c LstmMatMul(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, is_vec); } else { - PackLstmInput(hidden_state, matmul_buffer[1], gru_param->batch_, gru_param->hidden_size_); - LstmMatMul(state_reset_gate, matmul_buffer[1], state_reset_weight, state_reset_bias, gru_param->batch_, + PackLstmInput(hidden_state, packed_state, gru_param->batch_, gru_param->hidden_size_); + LstmMatMul(state_reset_gate, packed_state, state_reset_weight, state_reset_bias, gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, is_vec); - LstmMatMul(state_update_gate, matmul_buffer[1], state_update_weight, state_update_bias, gru_param->batch_, + LstmMatMul(state_update_gate, packed_state, state_update_weight, state_update_bias, gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, is_vec); } - ElementAdd(gate_buffer, state_update_gate, gate_buffer, gru_param->batch_ * gru_param->hidden_size_ * 2); - float *update_gate = gate_buffer; - float *reset_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_; - float *hidden_buffer = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 2; + ElementAdd(update_gate, state_update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_); + ElementAdd(reset_gate, state_update_gate + gru_param->batch_ * gru_param->hidden_size_, reset_gate, + gru_param->batch_ * gru_param->hidden_size_); // update reset_gate Sigmoid(reset_gate, gru_param->batch_ * gru_param->hidden_size_, reset_gate); @@ -83,8 +63,8 @@ void GruStepUnit(float *output, const float *input, const float *input_weight, c LstmMatMul(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, is_vec); } else { - PackLstmInput(reset_gate, matmul_buffer[1], gru_param->batch_, gru_param->hidden_size_); - LstmMatMul(state_hidden_buffer, matmul_buffer[1], state_hidden_weight, state_hidden_bias, gru_param->batch_, + PackLstmInput(reset_gate, packed_state, gru_param->batch_, gru_param->hidden_size_); + LstmMatMul(state_hidden_buffer, packed_state, state_hidden_weight, state_hidden_bias, gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, is_vec); } ElementAdd(hidden_buffer, state_hidden_buffer, hidden_buffer, 
gru_param->batch_ * gru_param->hidden_size_); @@ -104,15 +84,41 @@ void GruStepUnit(float *output, const float *input, const float *input_weight, c memcpy(output, hidden_state, gru_param->batch_ * gru_param->hidden_size_ * sizeof(float)); } -void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias, - float *hidden_state, float *gate_buffer, float *matmul_buffer[2], int check_seq_len, +void GruUnidirectional(float *output, const float *packed_input, const float *weight_g, const float *weight_r, + const float *input_bias, const float *state_bias, float *hidden_state, float *buffer[4], + const GruParameter *gru_param, bool is_backward) { + float *gate = buffer[1]; + for (int i = 0; i < 3; i++) { + const float *weight_loop = weight_g + gru_param->input_size_ * gru_param->input_col_align_ * i; + const float *bias_loop = input_bias + gru_param->input_col_align_ * i; + float *gate_loop = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * i; + MatMulOpt(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, gru_param->input_size_, + gru_param->seq_len_ * gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, OutType_Nhwc); + } + + float *update_gate = gate; + float *reset_gate = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_; + float *hidden_buffer = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * 2; + for (int t = 0; t < gru_param->seq_len_; t++) { + int real_t = is_backward ? gru_param->seq_len_ - t - 1 : t; + float *update_gate_t = update_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float *reset_gate_t = reset_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float *hidden_buffer_t = hidden_buffer + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float *output_ptr = output + real_t * gru_param->output_step_; + GruStepUnit(output_ptr, update_gate_t, reset_gate_t, hidden_buffer_t, weight_r, state_bias, hidden_state, buffer, + gru_param); + } +} + +void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *input_bias, + const float *state_bias, float *hidden_state, float *buffer[4], int check_seq_len, const GruParameter *gru_param) { // forward - for (int t = 0; t < check_seq_len; t++) { - const float *input_ptr = input + t * gru_param->input_step_; - float *output_ptr = output + t * gru_param->output_step_; - GruStepUnit(output_ptr, input_ptr, weight_g, weight_r, bias, hidden_state, gate_buffer, matmul_buffer, gru_param); - } + float *packed_input = buffer[0]; + PackLstmInput(input, packed_input, gru_param->seq_len_ * gru_param->batch_, gru_param->input_size_); + GruUnidirectional(output, packed_input, weight_g, weight_r, input_bias, state_bias, hidden_state, buffer, gru_param, + false); + // zero out extra fw outputs for (int t = check_seq_len; t < gru_param->seq_len_; t++) { float *output_ptr = output + t * gru_param->output_step_; @@ -123,17 +129,16 @@ void Gru(float *output, const float *input, const float *weight_g, const float * // backward if (gru_param->bidirectional_) { - const float *backward_weight_g = weight_g + 3 * gru_param->col_align_ * gru_param->input_size_; - const float *backward_weight_r = weight_r + 3 * gru_param->col_align_ * gru_param->hidden_size_; - const float *backward_bias = bias + 6 * gru_param->hidden_size_; + const float *backward_weight_g = weight_g + 3 * gru_param->input_col_align_ * gru_param->input_size_; + const float 
*backward_weight_r = weight_r + 3 * gru_param->state_col_align_ * gru_param->hidden_size_; + const float *backward_input_bias = input_bias + 3 * gru_param->input_col_align_; + const float *backward_state_bias = state_bias + 3 * gru_param->state_col_align_; float *backward_output = output + gru_param->batch_ * gru_param->hidden_size_; float *backward_hidden_state = hidden_state + gru_param->batch_ * gru_param->hidden_size_; - for (int t = check_seq_len - 1; t >= 0; t--) { - const float *input_ptr = input + t * gru_param->input_step_; - float *output_ptr = backward_output + t * gru_param->output_step_; - GruStepUnit(output_ptr, input_ptr, backward_weight_g, backward_weight_r, backward_bias, backward_hidden_state, - gate_buffer, matmul_buffer, gru_param); - } + + GruUnidirectional(backward_output, packed_input, backward_weight_g, backward_weight_r, backward_input_bias, + backward_state_bias, backward_hidden_state, buffer, gru_param, true); + // zero out extra bw outputs for (int t = gru_param->seq_len_ - 1; t >= check_seq_len; t--) { float *output_ptr = backward_output + t * gru_param->output_step_; diff --git a/mindspore/lite/nnacl/fp32/gru_fp32.h b/mindspore/lite/nnacl/fp32/gru_fp32.h index 69ddd23bf07..3333eafd9cf 100644 --- a/mindspore/lite/nnacl/fp32/gru_fp32.h +++ b/mindspore/lite/nnacl/fp32/gru_fp32.h @@ -20,8 +20,8 @@ #ifdef __cplusplus extern "C" { #endif -void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias, - float *hidden_state, float *gate_buffer, float *matmul_buffer[2], int check_seq_len, +void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *input_bias, + const float *state_bias, float *hidden_state, float *buffer[4], int check_seq_len, const GruParameter *gru_parm); #ifdef __cplusplus } diff --git a/mindspore/lite/nnacl/fp32/lstm_fp32.c b/mindspore/lite/nnacl/fp32/lstm_fp32.c index 969f88fb0de..445321646c1 100644 --- a/mindspore/lite/nnacl/fp32/lstm_fp32.c +++ b/mindspore/lite/nnacl/fp32/lstm_fp32.c @@ -63,37 +63,6 @@ void PackLstmInput(const float *src, float *dst, int row, int deep) { #endif } -// input: [row, inner_size]; weight: [col, inner_size]; output: [row, col] -void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size) { - for (int r = 0; r < rows; r++) { - for (int c = 0; c < cols; c++) { - float res = 0; - const float *input_col = input + r * inner_size; - const float *weight_col = weight + c * inner_size; - int index = 0; -#ifdef ENABLE_ARM - float32x4_t out = vdupq_n_f32(0.0f); - for (; index <= inner_size - 4; index += 4) { - float32x4_t in_0 = vld1q_f32(input_col + index); - float32x4_t in_1 = vld1q_f32(weight_col + index); - out = vmlaq_f32(out, in_1, in_0); - } -#ifdef ENABLE_ARM64 - res += vaddvq_f32(out); -#else - float32x2_t add2 = vadd_f32(vget_low_f32(out), vget_high_f32(out)); - float32x2_t add4 = vpadd_f32(add2, add2); - res += vget_lane_f32(add4, 0); -#endif -#endif - for (; index < inner_size; index++) { - res += input_col[index] * weight_col[index]; - } - output[r * cols + c] += res; - } - } -} - void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) { if (is_vec) { MatVecMulFp32(a, b, c, bias, ActType_No, deep, col); @@ -182,7 +151,11 @@ void UpdateLstmGate(float *gate_buffer, const float *input, const float *weight, void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate, const float 
*state_weight, const float *state_bias, float *hidden_state, float *cell_state, - float *state_gate, float *state_buffer[2], float *packed_state, const LstmParameter *lstm_param) { + float *buffer[6], const LstmParameter *lstm_param) { + float *packed_state = buffer[2]; + float *state_gate = buffer[3]; + float *cell_buffer = buffer[4]; + float *hidden_buffer = buffer[5]; bool is_vec = lstm_param->batch_ == 1; // state * weight if (is_vec) { @@ -211,31 +184,29 @@ void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *c // update cell_gate Tanh(cell_gate, lstm_param->batch_ * lstm_param->hidden_size_, cell_gate); // update cell state - UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer[0], lstm_param->batch_, - lstm_param->hidden_size_, lstm_param->zoneout_cell_); + UpdataState(cell_state, forget_gate, input_gate, cell_gate, cell_buffer, lstm_param->batch_, lstm_param->hidden_size_, + lstm_param->zoneout_cell_); // update output_gate Sigmoid(output_gate, lstm_param->batch_ * lstm_param->hidden_size_, output_gate); // update output - UpdataOutput(cell_state, output_gate, hidden_state, state_buffer[1], lstm_param->batch_, lstm_param->hidden_size_, + UpdataOutput(cell_state, output_gate, hidden_state, hidden_buffer, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->zoneout_hidden_); memcpy(output, hidden_state, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); if (!(lstm_param->zoneout_cell_ >= -FLT_EPSILON && lstm_param->zoneout_cell_ <= FLT_EPSILON)) { - memcpy(cell_state, state_buffer[0], lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); + memcpy(cell_state, cell_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); } if (!(lstm_param->zoneout_hidden_ >= -FLT_EPSILON && lstm_param->zoneout_hidden_ <= FLT_EPSILON)) { - memcpy(hidden_state, state_buffer[1], lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); + memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); } } void LstmUnidirectional(float *output, const float *packed_input, const float *weight_i, const float *weight_h, const float *input_bias, const float *state_bias, float *hidden_state, float *cell_state, - float *state_buffer[2], float *buffer[4], const LstmParameter *lstm_param, bool is_backward) { + float *buffer[6], const LstmParameter *lstm_param, bool is_backward) { float *gate = buffer[1]; - float *packed_state = buffer[2]; - float *state_gate = buffer[3]; for (int i = 0; i < 4; i++) { const float *weight_loop = weight_i + lstm_param->input_size_ * lstm_param->input_col_align_ * i; const float *bias_loop = input_bias + lstm_param->input_col_align_ * i; @@ -257,18 +228,18 @@ void LstmUnidirectional(float *output, const float *packed_input, const float *w float *output_gate_t = output_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; float *output_ptr = output + real_t * lstm_param->output_step_; LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, - hidden_state, cell_state, state_gate, state_buffer, packed_state, lstm_param); + hidden_state, cell_state, buffer, lstm_param); } } void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias, - const float *state_bias, float *hidden_state, float *cell_state, float *state_buffer[2], float *buffer[4], + const float *state_bias, float *hidden_state, float *cell_state, float *buffer[6], const LstmParameter 
*lstm_param) { // forward float *packed_input = buffer[0]; PackLstmInput(input, packed_input, lstm_param->seq_len_ * lstm_param->batch_, lstm_param->input_size_); - LstmUnidirectional(output, packed_input, weight_i, weight_h, input_bias, state_bias, hidden_state, cell_state, - state_buffer, buffer, lstm_param, false); + LstmUnidirectional(output, packed_input, weight_i, weight_h, input_bias, state_bias, hidden_state, cell_state, buffer, + lstm_param, false); // backward if (lstm_param->bidirectional_) { @@ -281,7 +252,6 @@ void Lstm(float *output, const float *input, const float *weight_i, const float float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_; LstmUnidirectional(backward_output, packed_input, backward_weight_i, backward_weight_h, backward_input_bias, - backward_state_bias, backward_hidden_state, backward_cell_state, state_buffer, buffer, - lstm_param, true); + backward_state_bias, backward_hidden_state, backward_cell_state, buffer, lstm_param, true); } } diff --git a/mindspore/lite/nnacl/fp32/lstm_fp32.h b/mindspore/lite/nnacl/fp32/lstm_fp32.h index 3a7142e6c6c..8a9d8276dc2 100644 --- a/mindspore/lite/nnacl/fp32/lstm_fp32.h +++ b/mindspore/lite/nnacl/fp32/lstm_fp32.h @@ -34,7 +34,7 @@ void ElementMulAcc(const float *input0, const float *input1, float *output, int int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size); void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias, - const float *state_bias, float *hidden_state, float *cell_state, float *state_buffer[2], float *buffer[4], + const float *state_bias, float *hidden_state, float *cell_state, float *buffer[6], const LstmParameter *lstm_param); #ifdef __cplusplus } diff --git a/mindspore/lite/nnacl/gru_parameter.h b/mindspore/lite/nnacl/gru_parameter.h index 29ebdbdc128..fdea2c29868 100644 --- a/mindspore/lite/nnacl/gru_parameter.h +++ b/mindspore/lite/nnacl/gru_parameter.h @@ -27,11 +27,12 @@ typedef struct GruParameter { int seq_len_; int batch_; // other parameter - int input_step_; int output_step_; bool bidirectional_; - int col_align_; - int row_align_; + int input_row_align_; + int input_col_align_; + int state_row_align_; + int state_col_align_; } GruParameter; #endif // MINDSPORE_LITE_NNACL_GRU_PARAMETER_H_ diff --git a/mindspore/lite/nnacl/lstm_parameter.h b/mindspore/lite/nnacl/lstm_parameter.h index d29c880f87f..b400a2437c2 100644 --- a/mindspore/lite/nnacl/lstm_parameter.h +++ b/mindspore/lite/nnacl/lstm_parameter.h @@ -27,7 +27,6 @@ typedef struct LstmParameter { int seq_len_; int batch_; // other parameter - int input_step_; int output_step_; bool bidirectional_; float zoneout_cell_; @@ -36,8 +35,6 @@ typedef struct LstmParameter { int input_col_align_; int state_row_align_; int state_col_align_; - int col_align_; - int row_align_; } LstmParameter; #endif // MINDSPORE_LITE_NNACL_LSTM_PARAMETER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.cc index 0045f98c3d1..779bcac1b8e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.cc @@ -30,33 +30,31 @@ using mindspore::schema::PrimitiveType_GRU; namespace mindspore::kernel { void GruFp16CPUKernel::FreeTmpBuffer() { - if (!is_vec_ || in_tensors_[1]->data_type() == kNumberTypeFloat32) { - if (weight_g_ptr_ != nullptr) { - free(weight_g_ptr_); - weight_g_ptr_ = nullptr; - } + if 
(weight_g_ptr_ != nullptr) { + free(weight_g_ptr_); + weight_g_ptr_ = nullptr; } - if (!is_vec_ || in_tensors_[2]->data_type() == kNumberTypeFloat32) { - if (weight_r_ptr_ != nullptr) { - free(weight_r_ptr_); - weight_r_ptr_ = nullptr; - } + if (input_bias_ != nullptr) { + free(input_bias_); + input_bias_ = nullptr; } - if (!is_vec_ || in_tensors_[3]->data_type() == kNumberTypeFloat32) { - if (bias_ptr_ != nullptr) { - free(bias_ptr_); - bias_ptr_ = nullptr; - } + if (weight_r_ptr_ != nullptr) { + free(weight_r_ptr_); + weight_r_ptr_ = nullptr; + } + if (state_bias_ != nullptr) { + free(state_bias_); + state_bias_ = nullptr; } } void GruFp16CPUKernel::FreeRunBuffer() { - context_->allocator->Free(gate_buffer_); + context_->allocator->Free(buffer_[0]); + context_->allocator->Free(buffer_[1]); if (!is_vec_) { - for (int i = 0; i < 2; i++) { - context_->allocator->Free(matmul_buffer_[i]); - } + context_->allocator->Free(buffer_[2]); } + context_->allocator->Free(buffer_[3]); } int GruFp16CPUKernel::InitParam() { @@ -71,105 +69,120 @@ int GruFp16CPUKernel::InitParam() { MS_ASSERT(weight_g != nullptr); std::vector<int> w_shape = weight_g->shape(); gru_param_->hidden_size_ = w_shape.at(1) / 3; - - gru_param_->input_step_ = gru_param_->batch_ * gru_param_->input_size_; + weight_batch_ = gru_param_->bidirectional_ ? 6 : 3; gru_param_->output_step_ = gru_param_->bidirectional_ ? 2 * gru_param_->batch_ * gru_param_->hidden_size_ : gru_param_->batch_ * gru_param_->hidden_size_; + gru_param_->input_row_align_ = UP_ROUND(gru_param_->seq_len_ * gru_param_->batch_, C16NUM); + gru_param_->input_col_align_ = UP_ROUND(gru_param_->hidden_size_, C8NUM); is_vec_ = gru_param_->batch_ == 1; - gru_param_->row_align_ = is_vec_ ? gru_param_->batch_ : UP_ROUND(gru_param_->batch_, C16NUM); - gru_param_->col_align_ = is_vec_ ? gru_param_->hidden_size_ : UP_ROUND(gru_param_->hidden_size_, C8NUM); + gru_param_->state_row_align_ = is_vec_ ? gru_param_->batch_ : UP_ROUND(gru_param_->batch_, C16NUM); + gru_param_->state_col_align_ = is_vec_ ? gru_param_->hidden_size_ : UP_ROUND(gru_param_->hidden_size_, C8NUM); return RET_OK; } -int GruFp16CPUKernel::InitWeight(const lite::Tensor *tensor, float16_t *ptr, int deep) { - auto weight_batch = gru_param_->bidirectional_ ? 6 : 3; - if (tensor->data_type() == kNumberTypeFloat32) { - auto weight_data = reinterpret_cast<float *>(tensor->data_c()); - is_vec_ ? Float32ToFloat16(weight_data, ptr, tensor->ElementsNum()) - : PackLstmWeightFp32ToFp16(ptr, weight_data, weight_batch, deep, gru_param_->hidden_size_, - gru_param_->col_align_); - } else if (tensor->data_type() == kNumberTypeFloat16) { - auto weight_data = reinterpret_cast<float16_t *>(tensor->data_c()); - if (is_vec_) { - ptr = weight_data; - } else { - PackLstmWeightFp16(ptr, weight_data, weight_batch, deep, gru_param_->hidden_size_, gru_param_->col_align_); - } - } else { - MS_LOG(ERROR) << "Unsupported data type of weight tensor for lstm."; - return RET_ERROR; - } - return RET_OK; -} - -int GruFp16CPUKernel::InitWeightBias() { - auto weight_batch = gru_param_->bidirectional_ ?
6 : 3; +int GruFp16CPUKernel::InitInputWeightBias() { // malloc and init input * weight right matrix buffer + // input -- row: seq_len * batch; col: input_size + // weight -- row: hidden_size; col: input_size, need transpose + // result -- row: seq_len * batch; col: hidden_size auto weight_g = in_tensors_.at(1); MS_ASSERT(weight_g != nullptr); - if (!is_vec_ || weight_g->data_type() == kNumberTypeFloat32) { - weight_g_ptr_ = reinterpret_cast<float16_t *>( - malloc(weight_batch * gru_param_->col_align_ * gru_param_->input_size_ * sizeof(float16_t))); - if (weight_g_ptr_ == nullptr) { - MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_g_ptr_ error."; - return RET_ERROR; - } + weight_g_ptr_ = reinterpret_cast<float16_t *>( + malloc(weight_batch_ * gru_param_->input_col_align_ * gru_param_->input_size_ * sizeof(float16_t))); + if (weight_g_ptr_ == nullptr) { + MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_g_ptr_ error."; + return RET_ERROR; } - auto ret = InitWeight(weight_g, weight_g_ptr_, gru_param_->input_size_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "GruFp16CPUKernel init weight_g failed."; + if (weight_g->data_type() == kNumberTypeFloat32) { + PackLstmWeightFp32ToFp16(weight_g_ptr_, reinterpret_cast<float *>(weight_g->data_c()), weight_batch_, + gru_param_->input_size_, gru_param_->hidden_size_, gru_param_->input_col_align_); + } else if (weight_g->data_type() == kNumberTypeFloat16) { + PackLstmWeightFp16(weight_g_ptr_, reinterpret_cast<float16_t *>(weight_g->data_c()), weight_batch_, + gru_param_->input_size_, gru_param_->hidden_size_, gru_param_->input_col_align_); + } else { + MS_LOG(ERROR) << "Unsupported data type of weight_g tensor for gru."; return RET_ERROR; } - // malloc and init state * weight right matrix buffer - auto weight_r = in_tensors_.at(2); - MS_ASSERT(weight_r != nullptr); - if (!is_vec_ || weight_r->data_type() == kNumberTypeFloat32) { - weight_r_ptr_ = reinterpret_cast<float16_t *>( - malloc(weight_batch * gru_param_->col_align_ * gru_param_->hidden_size_ * sizeof(float16_t))); - if (weight_r_ptr_ == nullptr) { - MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_r_ptr_ error."; - return RET_ERROR; - } - } - ret = InitWeight(weight_r, weight_r_ptr_, gru_param_->hidden_size_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "GruFp16CPUKernel init weight_r failed."; - return RET_ERROR; - } - - int bias_batch = gru_param_->bidirectional_ ?
12 : 6; + // input bias auto bias = in_tensors_.at(3); MS_ASSERT(bias != nullptr); - if (!is_vec_ || bias->data_type() == kNumberTypeFloat32) { - bias_ptr_ = reinterpret_cast<float16_t *>(malloc(bias_batch * gru_param_->col_align_ * sizeof(float16_t))); - if (bias_ptr_ == nullptr) { - MS_LOG(ERROR) << "GruFp16CPUKernel malloc bias_ptr_ error."; + input_bias_ = reinterpret_cast<float16_t *>(malloc(weight_batch_ * gru_param_->input_col_align_ * sizeof(float16_t))); + if (input_bias_ == nullptr) { + MS_LOG(ERROR) << "GruFp16CPUKernel malloc input_bias_ error."; + return RET_ERROR; + } + memset(input_bias_, 0, weight_batch_ * gru_param_->input_col_align_ * sizeof(float16_t)); + if (bias->data_type() == kNumberTypeFloat32) { + PackLstmBiasFp32ToFp16(input_bias_, reinterpret_cast<float *>(bias->data_c()), weight_batch_, + gru_param_->hidden_size_, gru_param_->input_col_align_, gru_param_->bidirectional_); + } else if (bias->data_type() == kNumberTypeFloat16) { + PackLstmBiasFp16(input_bias_, reinterpret_cast<float16_t *>(bias->data_c()), weight_batch_, + gru_param_->hidden_size_, gru_param_->input_col_align_, gru_param_->bidirectional_); + } else { + MS_LOG(ERROR) << "Unsupported data type of bias tensor for gru."; + return RET_ERROR; + } + return RET_OK; +} + +int GruFp16CPUKernel::InitStateWeightBias() { + // malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times. + // state -- row: batch; col: hidden_size + // weight -- row: hidden_size; col: hidden_size, need transpose + // result -- row: batch; col: hidden_size + auto weight_r = in_tensors_.at(2); + MS_ASSERT(weight_r != nullptr); + weight_r_ptr_ = reinterpret_cast<float16_t *>( + malloc(weight_batch_ * gru_param_->state_col_align_ * gru_param_->hidden_size_ * sizeof(float16_t))); + if (weight_r_ptr_ == nullptr) { + MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_r_ptr_ error."; + return RET_ERROR; + } + + if (!is_vec_) { + if (weight_r->data_type() == kNumberTypeFloat32) { + PackLstmWeightFp32ToFp16(weight_r_ptr_, reinterpret_cast<float *>(weight_r->data_c()), weight_batch_, + gru_param_->hidden_size_, gru_param_->hidden_size_, gru_param_->state_col_align_); + } else if (weight_r->data_type() == kNumberTypeFloat16) { + PackLstmWeightFp16(weight_r_ptr_, reinterpret_cast<float16_t *>(weight_r->data_c()), weight_batch_, + gru_param_->hidden_size_, gru_param_->hidden_size_, gru_param_->state_col_align_); + } else { + MS_LOG(ERROR) << "Unsupported data type of weight_r tensor for gru."; return RET_ERROR; } - memset(bias_ptr_, 0, bias_batch * gru_param_->col_align_ * sizeof(float16_t)); - } if (bias->data_type() == kNumberTypeFloat32) { - auto bias_data = reinterpret_cast<float *>(bias->data_c()); - for (int i = 0; i < bias_batch; i++) { - auto src_batch = bias_data + i * gru_param_->hidden_size_; - auto dst_batch = bias_ptr_ + i * gru_param_->col_align_; - Float32ToFloat16(src_batch, dst_batch, gru_param_->hidden_size_); - } - } else if (bias->data_type() == kNumberTypeFloat16) { - auto bias_data = reinterpret_cast<float16_t *>(bias->data_c()); - if (is_vec_) { - bias_ptr_ = bias_data; - } else { - for (int i = 0; i < bias_batch; i++) { - auto src_batch = bias_data + i * gru_param_->hidden_size_; - auto dst_batch = bias_ptr_ + i * gru_param_->col_align_; - memcpy(dst_batch, src_batch, gru_param_->hidden_size_ * sizeof(float16_t)); - } - } } else { - MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm."; + } else { + if (weight_r->data_type() == kNumberTypeFloat32) { + Float32ToFloat16(reinterpret_cast<float *>(weight_r->data_c()), weight_r_ptr_, weight_r->ElementsNum()); + } else if (weight_r->data_type() ==
kNumberTypeFloat16) { + memcpy(weight_r_ptr_, reinterpret_cast<float16_t *>(weight_r->data_c()), weight_r->ElementsNum()); + } else { + MS_LOG(ERROR) << "Unsupported data type of weight_r tensor for gru."; + return RET_ERROR; + } + } + + // state bias + auto bias = in_tensors_.at(3); + MS_ASSERT(bias != nullptr); + state_bias_ = reinterpret_cast<float16_t *>(malloc(weight_batch_ * gru_param_->state_col_align_ * sizeof(float16_t))); + if (state_bias_ == nullptr) { + MS_LOG(ERROR) << "GruFp16CPUKernel malloc state_bias_ error."; + return RET_ERROR; + } + memset(state_bias_, 0, weight_batch_ * gru_param_->state_col_align_ * sizeof(float16_t)); + if (bias->data_type() == kNumberTypeFloat32) { + auto state_bias_data = reinterpret_cast<float *>(bias->data_c()) + 3 * gru_param_->hidden_size_; + PackLstmBiasFp32ToFp16(state_bias_, state_bias_data, weight_batch_, gru_param_->hidden_size_, + gru_param_->state_col_align_, gru_param_->bidirectional_); + } else if (bias->data_type() == kNumberTypeFloat16) { + auto state_bias_data = reinterpret_cast<float16_t *>(bias->data_c()) + 3 * gru_param_->hidden_size_; + PackLstmBiasFp16(state_bias_, state_bias_data, weight_batch_, gru_param_->hidden_size_, + gru_param_->state_col_align_, gru_param_->bidirectional_); + } else { + MS_LOG(ERROR) << "Unsupported data type of bias tensor for gru."; return RET_ERROR; } return RET_OK; @@ -190,9 +203,16 @@ int GruFp16CPUKernel::ReSize() { } FreeTmpBuffer(); - ret = InitWeightBias(); + ret = InitInputWeightBias(); if (ret != RET_OK) { - MS_LOG(ERROR) << "GruFp16CPUKernel InitWeightBias error."; + MS_LOG(ERROR) << "GruFp16CPUKernel InitInputWeightBias error."; + FreeTmpBuffer(); + return RET_ERROR; + } + + ret = InitStateWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "GruFp16CPUKernel InitStateWeightBias error."; FreeTmpBuffer(); return RET_ERROR; } @@ -200,26 +220,36 @@ } int GruFp16CPUKernel::MallocRunBuffer() { - if (!is_vec_) { - matmul_buffer_[0] = reinterpret_cast<float16_t *>( - context_->allocator->Malloc(3 * gru_param_->row_align_ * gru_param_->input_size_ * sizeof(float16_t))); - if (matmul_buffer_[0] == nullptr) { - MS_LOG(ERROR) << "GruFp16CPUKernel malloc input * weight left matirx error."; - return RET_ERROR; - } + for (int i = 0; i < 4; i++) { + buffer_[i] = nullptr; + } + buffer_[0] = reinterpret_cast<float16_t *>( + context_->allocator->Malloc(gru_param_->input_row_align_ * gru_param_->input_size_ * sizeof(float16_t))); + if (buffer_[0] == nullptr) { + MS_LOG(ERROR) << "GruCPUKernel malloc input * weight left matirx error."; + return RET_ERROR; + } - matmul_buffer_[1] = reinterpret_cast<float16_t *>( - context_->allocator->Malloc(3 * gru_param_->row_align_ * gru_param_->hidden_size_ * sizeof(float16_t))); - if (matmul_buffer_[1] == nullptr) { - MS_LOG(ERROR) << "GruFp16CPUKernel malloc state * weight left matirx error."; + buffer_[1] = reinterpret_cast<float16_t *>(context_->allocator->Malloc(3 * gru_param_->seq_len_ * gru_param_->batch_ * + gru_param_->hidden_size_ * sizeof(float16_t))); + if (buffer_[1] == nullptr) { + MS_LOG(ERROR) << "GruCPUKernel malloc input * weight result matirx error."; + return RET_ERROR; + } + + if (!is_vec_) { + buffer_[2] = reinterpret_cast<float16_t *>( + context_->allocator->Malloc(gru_param_->state_row_align_ * gru_param_->hidden_size_ * sizeof(float16_t))); + if (buffer_[2] == nullptr) { + MS_LOG(ERROR) << "GruCPUKernel malloc state * weight left matirx error."; return RET_ERROR; } } - gate_buffer_ = reinterpret_cast<float16_t *>( - context_->allocator->Malloc(4 * gru_param_->batch_ * gru_param_->hidden_size_ * sizeof(float16_t))); - if (gate_buffer_ ==
nullptr) { - MS_LOG(ERROR) << "GruFp16CPUKernel malloc gate_buffer error."; + buffer_[3] = reinterpret_cast( + context_->allocator->Malloc(3 * gru_param_->batch_ * gru_param_->hidden_size_ * sizeof(float16_t))); + if (buffer_[3] == nullptr) { + MS_LOG(ERROR) << "GruCPUKernel malloc state gate buffer error."; return RET_ERROR; } return RET_OK; @@ -255,11 +285,10 @@ int GruFp16CPUKernel::Run() { } MS_ASSERT(weight_g_ptr_ != nullptr); MS_ASSERT(weight_r_ptr_ != nullptr); - MS_ASSERT(bias_ptr_ != nullptr); - MS_ASSERT(gate_buffer_ != nullptr); - GruFp16(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_, - reinterpret_cast(output_hidden_state->data_c()), gate_buffer_, matmul_buffer_, check_seq_len, - gru_param_); + MS_ASSERT(input_bias_ != nullptr); + MS_ASSERT(state_bias_ != nullptr); + GruFp16(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, input_bias_, state_bias_, + reinterpret_cast(output_hidden_state->data_c()), buffer_, check_seq_len, gru_param_); FreeRunBuffer(); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.h index a11bada3496..eb2bcd164fa 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.h @@ -38,15 +38,16 @@ class GruFp16CPUKernel : public LiteKernel { void FreeTmpBuffer(); void FreeRunBuffer(); int InitParam(); - int InitWeight(const lite::Tensor *tensor, float16_t *ptr, int deep); - int InitWeightBias(); + int InitInputWeightBias(); + int InitStateWeightBias(); int MallocRunBuffer(); - float16_t *gate_buffer_ = nullptr; float16_t *weight_g_ptr_ = nullptr; float16_t *weight_r_ptr_ = nullptr; - float16_t *bias_ptr_ = nullptr; - float16_t *matmul_buffer_[2]; + float16_t *input_bias_ = nullptr; + float16_t *state_bias_ = nullptr; + float16_t *buffer_[4]; + int weight_batch_ = 0; bool is_vec_ = false; GruParameter *gru_param_ = nullptr; }; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.cc index 9cac5b77fbc..4f03669ec9b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.cc @@ -31,35 +31,36 @@ using mindspore::schema::PrimitiveType_LSTM; namespace mindspore::kernel { void LstmFp16CPUKernel::FreeTmpBuffer() { - if (!is_vec_ || in_tensors_[1]->data_type() == kNumberTypeFloat32) { - if (weight_i_ptr_ != nullptr) { - free(weight_i_ptr_); - weight_i_ptr_ = nullptr; - } + if (weight_i_ptr_ != nullptr) { + free(weight_i_ptr_); + weight_i_ptr_ = nullptr; } - if (!is_vec_ || in_tensors_[2]->data_type() == kNumberTypeFloat32) { - if (weight_h_ptr_ != nullptr) { - free(weight_h_ptr_); - weight_h_ptr_ = nullptr; - } + if (input_bias_ != nullptr) { + free(input_bias_); + input_bias_ = nullptr; } - if (!is_vec_ || in_tensors_[3]->data_type() == kNumberTypeFloat32) { - if (bias_ptr_ != nullptr) { - free(bias_ptr_); - bias_ptr_ = nullptr; - } + if (weight_h_ptr_ != nullptr) { + free(weight_h_ptr_); + weight_h_ptr_ = nullptr; + } + if (state_bias_ != nullptr) { + free(state_bias_); + state_bias_ = nullptr; } } void LstmFp16CPUKernel::FreeRunBuffer() { - context_->allocator->Free(gate_buffer_); - for (int i = 0; i < 2; i++) { - context_->allocator->Free(state_buffer_[i]); - } + context_->allocator->Free(buffer_[0]); + context_->allocator->Free(buffer_[1]); if (!is_vec_) { - for (int i = 0; i < 2; i++) { - context_->allocator->Free(matmul_buffer_[i]); - } + 
context_->allocator->Free(buffer_[2]); + } + context_->allocator->Free(buffer_[3]); + if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) { + context_->allocator->Free(buffer_[4]); + } + if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) { + context_->allocator->Free(buffer_[5]); } } @@ -76,102 +77,119 @@ int LstmFp16CPUKernel::InitParam() { std::vector w_shape = weight_i->shape(); lstm_param_->hidden_size_ = w_shape.at(1) / 4; - lstm_param_->input_step_ = lstm_param_->batch_ * lstm_param_->input_size_; lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ : lstm_param_->batch_ * lstm_param_->hidden_size_; + weight_batch_ = lstm_param_->bidirectional_ ? 8 : 4; + lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, C16NUM); + lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, C8NUM); is_vec_ = lstm_param_->batch_ == 1; - lstm_param_->row_align_ = is_vec_ ? lstm_param_->batch_ : UP_ROUND(lstm_param_->batch_, C16NUM); - lstm_param_->col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, C8NUM); + lstm_param_->state_row_align_ = is_vec_ ? lstm_param_->batch_ : UP_ROUND(lstm_param_->batch_, C16NUM); + lstm_param_->state_col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, C8NUM); return RET_OK; } -int LstmFp16CPUKernel::InitWeight(const lite::Tensor *tensor, float16_t *ptr, int deep) { - auto weight_batch = lstm_param_->bidirectional_ ? 8 : 4; - if (tensor->data_type() == kNumberTypeFloat32) { - auto weight_data = reinterpret_cast(tensor->data_c()); - is_vec_ ? Float32ToFloat16(weight_data, ptr, tensor->ElementsNum()) - : PackLstmWeightFp32ToFp16(ptr, weight_data, weight_batch, deep, lstm_param_->hidden_size_, - lstm_param_->col_align_); - } else if (tensor->data_type() == kNumberTypeFloat16) { - auto weight_data = reinterpret_cast(tensor->data_c()); - if (is_vec_) { - ptr = weight_data; - } else { - PackLstmWeightFp16(ptr, weight_data, weight_batch, deep, lstm_param_->hidden_size_, lstm_param_->col_align_); - } - } else { - MS_LOG(ERROR) << "Unsupported data type of weight tensor for lstm."; - return RET_ERROR; - } - return RET_OK; -} - -int LstmFp16CPUKernel::InitWeightBias() { - auto weight_batch = lstm_param_->bidirectional_ ? 
8 : 4; +int LstmFp16CPUKernel::InitInputWeightBias() { // malloc and init input * weight right matrix buffer + // input -- row: seq_len * batch; col: input_size + // weight -- row: hidden_size; col: input_size, need transpose + // result -- row: seq_len * batch; col: hidden_size auto weight_i = in_tensors_.at(1); MS_ASSERT(weight_i != nullptr); - if (!is_vec_ || weight_i->data_type() == kNumberTypeFloat32) { - weight_i_ptr_ = reinterpret_cast( - malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->input_size_ * sizeof(float16_t))); - if (weight_i_ptr_ == nullptr) { - MS_LOG(ERROR) << "LstmFp16CPUKernel malloc weight_i_ptr_ error."; - return RET_ERROR; - } + weight_i_ptr_ = reinterpret_cast( + malloc(weight_batch_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float16_t))); + if (weight_i_ptr_ == nullptr) { + MS_LOG(ERROR) << "LstmFp16CPUKernel malloc weight_i_ptr_ error."; + return RET_ERROR; } - auto ret = InitWeight(weight_i, weight_i_ptr_, lstm_param_->input_size_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "LstmFp16CPUKernel init weight_i failed."; + if (weight_i->data_type() == kNumberTypeFloat32) { + PackLstmWeightFp32ToFp16(weight_i_ptr_, reinterpret_cast(weight_i->data_c()), weight_batch_, + lstm_param_->input_size_, lstm_param_->hidden_size_, lstm_param_->input_col_align_); + } else if (weight_i->data_type() == kNumberTypeFloat16) { + PackLstmWeightFp16(weight_i_ptr_, reinterpret_cast(weight_i->data_c()), weight_batch_, + lstm_param_->input_size_, lstm_param_->hidden_size_, lstm_param_->input_col_align_); + } else { + MS_LOG(ERROR) << "Unsupported data type of weight_i tensor for lstm."; return RET_ERROR; } - // malloc and init state * weight right matrix buffer - auto weight_h = in_tensors_.at(2); - MS_ASSERT(weight_h != nullptr); - if (!is_vec_ || weight_h->data_type() == kNumberTypeFloat32) { - weight_h_ptr_ = reinterpret_cast( - malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->hidden_size_ * sizeof(float16_t))); - if (weight_h_ptr_ == nullptr) { - MS_LOG(ERROR) << "LstmFp16CPUKernel malloc weight_h_ptr_ error."; - return RET_ERROR; - } - } - ret = InitWeight(weight_h, weight_h_ptr_, lstm_param_->hidden_size_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "LstmFp16CPUKernel init weight_h failed."; - return RET_ERROR; - } - - int bias_batch = lstm_param_->bidirectional_ ? 
16 : 8; + // input bias auto bias = in_tensors_.at(3); MS_ASSERT(bias != nullptr); - if (!is_vec_ || bias->data_type() == kNumberTypeFloat32) { - bias_ptr_ = reinterpret_cast(malloc(bias_batch * lstm_param_->col_align_ * sizeof(float16_t))); - if (bias_ptr_ == nullptr) { - MS_LOG(ERROR) << "LstmFp16CPUKernel malloc bias_ptr_ error."; + input_bias_ = + reinterpret_cast(malloc(weight_batch_ * lstm_param_->input_col_align_ * sizeof(float16_t))); + if (input_bias_ == nullptr) { + MS_LOG(ERROR) << "LstmFp16CPUKernel malloc input_bias_ error."; + return RET_ERROR; + } + memset(input_bias_, 0, weight_batch_ * lstm_param_->input_col_align_ * sizeof(float16_t)); + if (bias->data_type() == kNumberTypeFloat32) { + PackLstmBiasFp32ToFp16(input_bias_, reinterpret_cast(bias->data_c()), weight_batch_, + lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_); + } else if (bias->data_type() == kNumberTypeFloat16) { + PackLstmBiasFp16(input_bias_, reinterpret_cast(bias->data_c()), weight_batch_, + lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_); + } else { + MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm."; + return RET_ERROR; + } + return RET_OK; +} + +int LstmFp16CPUKernel::InitStateWeightBias() { + // malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times. + // state -- row: batch; col: hidden_size + // weight -- row: hidden_size; col: hidden_size, need transpose + // result -- row: batch; col: hidden_size + auto weight_h = in_tensors_.at(2); + MS_ASSERT(weight_h != nullptr); + weight_h_ptr_ = reinterpret_cast( + malloc(weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->hidden_size_ * sizeof(float16_t))); + if (weight_h_ptr_ == nullptr) { + MS_LOG(ERROR) << "LstmFp16CPUKernel malloc weight_h_ptr_ error."; + return RET_ERROR; + } + + if (!is_vec_) { + if (weight_h->data_type() == kNumberTypeFloat32) { + PackLstmWeightFp32ToFp16(weight_h_ptr_, reinterpret_cast(weight_h->data_c()), weight_batch_, + lstm_param_->hidden_size_, lstm_param_->hidden_size_, lstm_param_->state_col_align_); + } else if (weight_h->data_type() == kNumberTypeFloat16) { + PackLstmWeightFp16(weight_h_ptr_, reinterpret_cast(weight_h->data_c()), weight_batch_, + lstm_param_->hidden_size_, lstm_param_->hidden_size_, lstm_param_->state_col_align_); + } else { + MS_LOG(ERROR) << "Unsupported data type of weight_h tensor for lstm."; return RET_ERROR; } - memset(bias_ptr_, 0, bias_batch * lstm_param_->col_align_ * sizeof(float16_t)); - } - if (bias->data_type() == kNumberTypeFloat32) { - auto bias_data = reinterpret_cast(bias->data_c()); - for (int i = 0; i < bias_batch; i++) { - auto src_batch = bias_data + i * lstm_param_->hidden_size_; - auto dst_batch = bias_ptr_ + i * lstm_param_->col_align_; - Float32ToFloat16(src_batch, dst_batch, lstm_param_->hidden_size_); - } - } else if (bias->data_type() == kNumberTypeFloat16) { - auto bias_data = reinterpret_cast(bias->data_c()); - if (is_vec_) { - bias_ptr_ = bias_data; + } else { + if (weight_h->data_type() == kNumberTypeFloat32) { + Float32ToFloat16(reinterpret_cast(weight_h->data_c()), weight_h_ptr_, weight_h->ElementsNum()); + } else if (weight_h->data_type() == kNumberTypeFloat16) { + memcpy(weight_h_ptr_, reinterpret_cast(weight_h->data_c()), weight_h->ElementsNum()); } else { - for (int i = 0; i < bias_batch; i++) { - auto src_batch = bias_data + i * lstm_param_->hidden_size_; - auto dst_batch = bias_ptr_ + i * lstm_param_->col_align_; - 
memcpy(dst_batch, src_batch, lstm_param_->hidden_size_ * sizeof(float16_t)); - } + MS_LOG(ERROR) << "Unsupported data type of weight_h tensor for lstm."; + return RET_ERROR; } + } + + // state bias + auto bias = in_tensors_.at(3); + MS_ASSERT(bias != nullptr); + state_bias_ = + reinterpret_cast(malloc(weight_batch_ * lstm_param_->state_col_align_ * sizeof(float16_t))); + if (state_bias_ == nullptr) { + MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_bias_ error."; + return RET_ERROR; + } + memset(state_bias_, 0, weight_batch_ * lstm_param_->state_col_align_ * sizeof(float16_t)); + if (bias->data_type() == kNumberTypeFloat32) { + auto state_bias_data = reinterpret_cast(bias->data_c()) + 4 * lstm_param_->hidden_size_; + PackLstmBiasFp32ToFp16(state_bias_, state_bias_data, weight_batch_, lstm_param_->hidden_size_, + lstm_param_->state_col_align_, lstm_param_->bidirectional_); + } else if (bias->data_type() == kNumberTypeFloat16) { + auto state_bias_data = reinterpret_cast(bias->data_c()) + 4 * lstm_param_->hidden_size_; + PackLstmBiasFp16(state_bias_, state_bias_data, weight_batch_, lstm_param_->hidden_size_, + lstm_param_->state_col_align_, lstm_param_->bidirectional_); } else { MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm."; return RET_ERROR; @@ -194,9 +212,16 @@ int LstmFp16CPUKernel::ReSize() { } FreeTmpBuffer(); - ret = InitWeightBias(); + ret = InitInputWeightBias(); if (ret != RET_OK) { - MS_LOG(ERROR) << "Lstm fp16 InitWeightBias error."; + MS_LOG(ERROR) << "Lstm fp16 InitInputWeightBias error."; + FreeTmpBuffer(); + return RET_ERROR; + } + + ret = InitStateWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Lstm fp16 InitStateWeightBias error."; FreeTmpBuffer(); return RET_ERROR; } @@ -204,42 +229,51 @@ int LstmFp16CPUKernel::ReSize() { } int LstmFp16CPUKernel::MallocRunBuffer() { - if (!is_vec_) { - matmul_buffer_[0] = reinterpret_cast( - context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->input_size_ * sizeof(float16_t))); - if (matmul_buffer_[0] == nullptr) { - MS_LOG(ERROR) << "LstmFp16CPUKernel malloc input * weight left matirx error."; - return RET_ERROR; - } + for (int i = 0; i < 6; i++) { + buffer_[i] = nullptr; + } + buffer_[0] = reinterpret_cast( + context_->allocator->Malloc(lstm_param_->input_row_align_ * lstm_param_->input_size_ * sizeof(float16_t))); + if (buffer_[0] == nullptr) { + MS_LOG(ERROR) << "LstmFp16CPUKernel malloc input * weight left matirx error."; + return RET_ERROR; + } - matmul_buffer_[1] = reinterpret_cast( - context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->hidden_size_ * sizeof(float16_t))); - if (matmul_buffer_[1] == nullptr) { + buffer_[1] = reinterpret_cast(context_->allocator->Malloc( + 4 * lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t))); + if (buffer_[1] == nullptr) { + MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state * weight left matirx error."; + return RET_ERROR; + } + + if (!is_vec_) { + buffer_[2] = reinterpret_cast( + context_->allocator->Malloc(lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * sizeof(float16_t))); + if (buffer_[2] == nullptr) { MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state * weight left matirx error."; return RET_ERROR; } } - gate_buffer_ = reinterpret_cast( - context_->allocator->Malloc(8 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t))); - if (gate_buffer_ == nullptr) { - MS_LOG(ERROR) << "LstmFp16CPUKernel malloc gate_buffer error."; + buffer_[3] = reinterpret_cast( + 
context_->allocator->Malloc(4 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t))); + if (buffer_[3] == nullptr) { + MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state gate buffer error."; return RET_ERROR; } - state_buffer_[0] = nullptr; - state_buffer_[1] = nullptr; + if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) { int buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t); - state_buffer_[0] = reinterpret_cast(context_->allocator->Malloc(buffer_size)); - if (state_buffer_[0] == nullptr) { + buffer_[4] = reinterpret_cast(context_->allocator->Malloc(buffer_size)); + if (buffer_[4] == nullptr) { MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_buffer for cell error."; return RET_ERROR; } } if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) { int buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t); - state_buffer_[1] = reinterpret_cast(context_->allocator->Malloc(buffer_size)); - if (state_buffer_[1] == nullptr) { + buffer_[5] = reinterpret_cast(context_->allocator->Malloc(buffer_size)); + if (buffer_[5] == nullptr) { MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_buffer for hidden error."; return RET_ERROR; } @@ -273,12 +307,11 @@ int LstmFp16CPUKernel::Run() { } MS_ASSERT(weight_i_ptr_); MS_ASSERT(weight_h_ptr_); - MS_ASSERT(bias_ptr_); - MS_ASSERT(gate_buffer_); - LstmFp16(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_, + MS_ASSERT(input_bias_); + MS_ASSERT(state_bias_); + LstmFp16(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_, reinterpret_cast(output_hidden_state->data_c()), - reinterpret_cast(output_cell_state->data_c()), gate_buffer_, state_buffer_, matmul_buffer_, - lstm_param_); + reinterpret_cast(output_cell_state->data_c()), buffer_, lstm_param_); FreeRunBuffer(); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.h index afa6ceae50f..5eb0b5d0ee1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.h @@ -40,16 +40,16 @@ class LstmFp16CPUKernel : public LiteKernel { void FreeTmpBuffer(); void FreeRunBuffer(); int InitParam(); - int InitWeight(const lite::Tensor *tensor, float16_t *ptr, int deep); - int InitWeightBias(); + int InitInputWeightBias(); + int InitStateWeightBias(); int MallocRunBuffer(); - float16_t *gate_buffer_ = nullptr; - float16_t *state_buffer_[2]; float16_t *weight_i_ptr_ = nullptr; float16_t *weight_h_ptr_ = nullptr; - float16_t *bias_ptr_ = nullptr; - float16_t *matmul_buffer_[2]; + float16_t *input_bias_ = nullptr; + float16_t *state_bias_ = nullptr; + float16_t *buffer_[6]; + int weight_batch_ = 0; bool is_vec_ = false; LstmParameter *lstm_param_ = nullptr; }; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc index fc4bfece90d..7aeffe8c9f5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.cc @@ -29,29 +29,33 @@ using mindspore::schema::PrimitiveType_GRU; namespace mindspore::kernel { void GruCPUKernel::FreeTmpBuffer() { + if (weight_g_ptr_ != nullptr) { + free(weight_g_ptr_); + weight_g_ptr_ = nullptr; + } + if (input_bias_ != nullptr) { + free(input_bias_); + input_bias_ = nullptr; + } if (!is_vec_) { - if (weight_g_ptr_ != 
nullptr) { - free(weight_g_ptr_); - weight_g_ptr_ = nullptr; - } if (weight_r_ptr_ != nullptr) { free(weight_r_ptr_); weight_r_ptr_ = nullptr; } - if (bias_ptr_ != nullptr) { - free(bias_ptr_); - bias_ptr_ = nullptr; - } + } + if (state_bias_ != nullptr) { + free(state_bias_); + state_bias_ = nullptr; } } void GruCPUKernel::FreeRunBuffer() { - context_->allocator->Free(gate_buffer_); + context_->allocator->Free(buffer_[0]); + context_->allocator->Free(buffer_[1]); if (!is_vec_) { - for (int i = 0; i < 2; i++) { - context_->allocator->Free(matmul_buffer_[i]); - } + context_->allocator->Free(buffer_[2]); } + context_->allocator->Free(buffer_[3]); } int GruCPUKernel::InitParam() { @@ -67,9 +71,9 @@ int GruCPUKernel::InitParam() { std::vector w_shape = weight_g->shape(); gru_param_->hidden_size_ = w_shape.at(1) / 3; - gru_param_->input_step_ = gru_param_->batch_ * gru_param_->input_size_; gru_param_->output_step_ = gru_param_->bidirectional_ ? 2 * gru_param_->batch_ * gru_param_->hidden_size_ : gru_param_->batch_ * gru_param_->hidden_size_; + weight_batch_ = gru_param_->bidirectional_ ? 6 : 3; #ifdef ENABLE_AVX row_tile_ = C6NUM; @@ -84,56 +88,75 @@ int GruCPUKernel::InitParam() { row_tile_ = C12NUM; col_tile_ = C8NUM; #endif + gru_param_->input_row_align_ = UP_ROUND(gru_param_->seq_len_ * gru_param_->batch_, row_tile_); + gru_param_->input_col_align_ = UP_ROUND(gru_param_->hidden_size_, col_tile_); + is_vec_ = gru_param_->batch_ == 1; - gru_param_->row_align_ = is_vec_ ? 1 : UP_ROUND(gru_param_->batch_, row_tile_); - gru_param_->col_align_ = is_vec_ ? gru_param_->hidden_size_ : UP_ROUND(gru_param_->hidden_size_, col_tile_); + gru_param_->state_row_align_ = is_vec_ ? 1 : UP_ROUND(gru_param_->batch_, row_tile_); + gru_param_->state_col_align_ = is_vec_ ? gru_param_->hidden_size_ : UP_ROUND(gru_param_->hidden_size_, col_tile_); return RET_OK; } -int GruCPUKernel::InitWeightBias() { - auto weight_batch = gru_param_->bidirectional_ ? 
6 : 3; - if (!is_vec_) { - // malloc and init input * weight right matrix buffer - auto weight_g = in_tensors_.at(1); - MS_ASSERT(weight_g != nullptr); - weight_g_ptr_ = reinterpret_cast( - malloc(weight_batch * gru_param_->col_align_ * gru_param_->input_size_ * sizeof(float))); - if (weight_g_ptr_ == nullptr) { - MS_LOG(ERROR) << "GruCPUKernel malloc weight_g_ptr_ error."; - return RET_ERROR; - } - auto weight_i_data = reinterpret_cast(weight_g->data_c()); - PackLstmWeight(weight_g_ptr_, weight_i_data, weight_batch, gru_param_->input_size_, gru_param_->hidden_size_, - gru_param_->col_align_); +int GruCPUKernel::InitInputWeightBias() { + // malloc and init input * weight right matrix buffer + // input -- row: seq_len * batch; col: input_size + // weight -- row: hidden_size; col: input_size, need transpose + // result -- row: seq_len * batch; col: hidden_size + auto weight_g = in_tensors_.at(1); + MS_ASSERT(weight_g != nullptr); + weight_g_ptr_ = reinterpret_cast( + malloc(weight_batch_ * gru_param_->input_col_align_ * gru_param_->input_size_ * sizeof(float))); + if (weight_g_ptr_ == nullptr) { + MS_LOG(ERROR) << "GruCPUKernel malloc weight_g_ptr_ error."; + return RET_ERROR; + } + auto weight_g_data = reinterpret_cast(weight_g->data_c()); + PackLstmWeight(weight_g_ptr_, weight_g_data, weight_batch_, gru_param_->input_size_, gru_param_->hidden_size_, + gru_param_->input_col_align_); - // malloc and init state * weight right matrix buffer - auto weight_r = in_tensors_.at(2); - MS_ASSERT(weight_r != nullptr); + // input bias + input_bias_ = reinterpret_cast(malloc(weight_batch_ * gru_param_->input_col_align_ * sizeof(float))); + if (input_bias_ == nullptr) { + MS_LOG(ERROR) << "GruCPUKernel malloc input_bias_ error."; + return RET_ERROR; + } + memset(input_bias_, 0, weight_batch_ * gru_param_->input_col_align_ * sizeof(float)); + PackLstmBias(input_bias_, reinterpret_cast(in_tensors_.at(3)->data_c()), weight_batch_, + gru_param_->hidden_size_, gru_param_->input_col_align_, gru_param_->bidirectional_); + return RET_OK; +} + +int GruCPUKernel::InitStateWeightBias() { + // malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times. + // state -- row: batch; col: hidden_size + // weight -- row: hidden_size; col: hidden_size, need transpose + // result -- row: batch; col: hidden_size + auto weight_r = in_tensors_.at(2); + MS_ASSERT(weight_r != nullptr); + auto weight_r_data = reinterpret_cast(weight_r->data_c()); + if (!is_vec_) { weight_r_ptr_ = reinterpret_cast( - malloc(weight_batch * gru_param_->col_align_ * gru_param_->hidden_size_ * sizeof(float))); + malloc(weight_batch_ * gru_param_->state_col_align_ * gru_param_->hidden_size_ * sizeof(float))); if (weight_r_ptr_ == nullptr) { MS_LOG(ERROR) << "GruCPUKernel malloc weight_r_ptr_ error."; return RET_ERROR; } - auto weight_r_data = reinterpret_cast(weight_r->data_c()); - PackLstmWeight(weight_r_ptr_, weight_r_data, weight_batch, gru_param_->hidden_size_, gru_param_->hidden_size_, - gru_param_->col_align_); - - // init bias - int bias_batch = gru_param_->bidirectional_ ? 
16 : 8; - bias_ptr_ = reinterpret_cast(malloc(bias_batch * gru_param_->col_align_ * sizeof(float))); - if (bias_ptr_ == nullptr) { - MS_LOG(ERROR) << "GruCPUKernel malloc bias_ptr_ error."; - return RET_ERROR; - } - memset(bias_ptr_, 0, bias_batch * gru_param_->col_align_ * sizeof(float)); - auto bias_data = reinterpret_cast(in_tensors_.at(3)->data_c()); - for (int i = 0; i < bias_batch; i++) { - auto src_batch = bias_data + i * gru_param_->hidden_size_; - auto dst_batch = bias_ptr_ + i * gru_param_->col_align_; - memcpy(dst_batch, src_batch, gru_param_->hidden_size_ * sizeof(float)); - } + PackLstmWeight(weight_r_ptr_, weight_r_data, weight_batch_, gru_param_->hidden_size_, gru_param_->hidden_size_, + gru_param_->state_col_align_); + } else { + weight_r_ptr_ = weight_r_data; } + + // state bias + state_bias_ = reinterpret_cast(malloc(weight_batch_ * gru_param_->state_col_align_ * sizeof(float))); + if (state_bias_ == nullptr) { + MS_LOG(ERROR) << "GruCPUKernel malloc state_bias_ error."; + return RET_ERROR; + } + memset(state_bias_, 0, weight_batch_ * gru_param_->state_col_align_ * sizeof(float)); + auto state_bias = reinterpret_cast(in_tensors_.at(3)->data_c()) + 3 * gru_param_->hidden_size_; + PackLstmBias(state_bias_, state_bias, weight_batch_, gru_param_->hidden_size_, gru_param_->state_col_align_, + gru_param_->bidirectional_); return RET_OK; } @@ -152,9 +175,16 @@ int GruCPUKernel::ReSize() { } FreeTmpBuffer(); - ret = InitWeightBias(); + ret = InitInputWeightBias(); if (ret != RET_OK) { - MS_LOG(ERROR) << "GruCPUKernel InitWeightBias error."; + MS_LOG(ERROR) << "GruCPUKernel InitInputWeightBias error."; + FreeTmpBuffer(); + return RET_ERROR; + } + + ret = InitStateWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "GruCPUKernel InitStateWeightBias error."; FreeTmpBuffer(); return RET_ERROR; } @@ -162,25 +192,36 @@ int GruCPUKernel::ReSize() { } int GruCPUKernel::MallocRunBuffer() { - if (!is_vec_) { - matmul_buffer_[0] = reinterpret_cast( - context_->allocator->Malloc(3 * gru_param_->row_align_ * gru_param_->input_size_ * sizeof(float))); - if (matmul_buffer_[0] == nullptr) { - MS_LOG(ERROR) << "GruCPUKernel malloc input * weight left matirx error."; - return RET_ERROR; - } + for (int i = 0; i < 4; i++) { + buffer_[i] = nullptr; + } + buffer_[0] = reinterpret_cast( + context_->allocator->Malloc(gru_param_->input_row_align_ * gru_param_->input_size_ * sizeof(float))); + if (buffer_[0] == nullptr) { + MS_LOG(ERROR) << "GruCPUKernel malloc input * weight left matirx error."; + return RET_ERROR; + } - matmul_buffer_[1] = reinterpret_cast( - context_->allocator->Malloc(3 * gru_param_->row_align_ * gru_param_->hidden_size_ * sizeof(float))); - if (matmul_buffer_[1] == nullptr) { + buffer_[1] = reinterpret_cast(context_->allocator->Malloc(3 * gru_param_->seq_len_ * gru_param_->batch_ * + gru_param_->hidden_size_ * sizeof(float))); + if (buffer_[1] == nullptr) { + MS_LOG(ERROR) << "GruCPUKernel malloc input * weight result matirx error."; + return RET_ERROR; + } + + if (!is_vec_) { + buffer_[2] = reinterpret_cast( + context_->allocator->Malloc(gru_param_->state_row_align_ * gru_param_->hidden_size_ * sizeof(float))); + if (buffer_[2] == nullptr) { MS_LOG(ERROR) << "GruCPUKernel malloc state * weight left matirx error."; return RET_ERROR; } } - gate_buffer_ = reinterpret_cast( - context_->allocator->Malloc(6 * gru_param_->batch_ * gru_param_->hidden_size_ * sizeof(float))); - if (gate_buffer_ == nullptr) { - MS_LOG(ERROR) << "GruCPUKernel malloc gate_buffer error."; + + buffer_[3] = 
reinterpret_cast( + context_->allocator->Malloc(3 * gru_param_->batch_ * gru_param_->hidden_size_ * sizeof(float))); + if (buffer_[3] == nullptr) { + MS_LOG(ERROR) << "GruCPUKernel malloc state gate buffer error."; return RET_ERROR; } return RET_OK; @@ -215,18 +256,12 @@ int GruCPUKernel::Run() { return RET_ERROR; } - if (is_vec_) { - weight_g_ptr_ = reinterpret_cast(in_tensors_[1]->data_c()); - weight_r_ptr_ = reinterpret_cast(in_tensors_[2]->data_c()); - bias_ptr_ = reinterpret_cast(in_tensors_[3]->data_c()); - } MS_ASSERT(weight_g_ptr_ != nullptr); MS_ASSERT(weight_r_ptr_ != nullptr); - MS_ASSERT(bias_ptr_ != nullptr); - MS_ASSERT(gate_buffer_ != nullptr); - Gru(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_, - reinterpret_cast(output_hidden_state->data_c()), gate_buffer_, matmul_buffer_, check_seq_len, - gru_param_); + MS_ASSERT(input_bias_ != nullptr); + MS_ASSERT(state_bias_ != nullptr); + Gru(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, input_bias_, state_bias_, + reinterpret_cast(output_hidden_state->data_c()), buffer_, check_seq_len, gru_param_); FreeRunBuffer(); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h index 741294a7d3f..53a1c6a7a8f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gru_fp32.h @@ -39,15 +39,17 @@ class GruCPUKernel : public LiteKernel { void FreeRunBuffer(); int InitParam(); int MallocRunBuffer(); - int InitWeightBias(); + int InitInputWeightBias(); + int InitStateWeightBias(); - float *gate_buffer_ = nullptr; float *weight_g_ptr_ = nullptr; float *weight_r_ptr_ = nullptr; - float *bias_ptr_ = nullptr; - float *matmul_buffer_[2]; + float *input_bias_ = nullptr; + float *state_bias_ = nullptr; + float *buffer_[4]; int row_tile_ = 0; int col_tile_ = 0; + int weight_batch_ = 0; bool is_vec_ = false; GruParameter *gru_param_ = nullptr; }; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc index f35f6f4f3aa..32292c85bc3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc @@ -52,15 +52,18 @@ void LstmCPUKernel::FreeTmpBuffer() { } void LstmCPUKernel::FreeRunBuffer() { - for (int i = 0; i < 2; i++) { - context_->allocator->Free(state_buffer_[i]); - } context_->allocator->Free(buffer_[0]); context_->allocator->Free(buffer_[1]); if (!is_vec_) { context_->allocator->Free(buffer_[2]); } context_->allocator->Free(buffer_[3]); + if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) { + context_->allocator->Free(buffer_[4]); + } + if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) { + context_->allocator->Free(buffer_[5]); + } } int LstmCPUKernel::InitInputWeightBias() { @@ -197,7 +200,7 @@ int LstmCPUKernel::ReSize() { } int LstmCPUKernel::MallocRunBuffer() { - for (int i = 0; i < 4; i++) { + for (int i = 0; i < 6; i++) { buffer_[i] = nullptr; } buffer_[0] = reinterpret_cast( @@ -216,7 +219,7 @@ int LstmCPUKernel::MallocRunBuffer() { if (!is_vec_) { buffer_[2] = reinterpret_cast( - context_->allocator->Malloc(4 * lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * sizeof(float))); + context_->allocator->Malloc(lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * sizeof(float))); if (buffer_[2] == nullptr) { MS_LOG(ERROR) << 
"LstmCPUKernel malloc state * weight left matirx error."; return RET_ERROR; @@ -229,20 +232,19 @@ int LstmCPUKernel::MallocRunBuffer() { MS_LOG(ERROR) << "LstmCPUKernel malloc state gate buffer error."; return RET_ERROR; } - state_buffer_[0] = nullptr; - state_buffer_[1] = nullptr; + if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) { auto buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float); - state_buffer_[0] = reinterpret_cast(context_->allocator->Malloc(buffer_size)); - if (state_buffer_[0] == nullptr) { + buffer_[4] = reinterpret_cast(context_->allocator->Malloc(buffer_size)); + if (buffer_[4] == nullptr) { MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer for cell error."; return RET_ERROR; } } if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) { auto buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float); - state_buffer_[1] = reinterpret_cast(context_->allocator->Malloc(buffer_size)); - if (state_buffer_[1] == nullptr) { + buffer_[5] = reinterpret_cast(context_->allocator->Malloc(buffer_size)); + if (buffer_[5] == nullptr) { MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer for hidden error."; return RET_ERROR; } @@ -281,7 +283,7 @@ int LstmCPUKernel::Run() { MS_ASSERT(state_bias_); Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_, reinterpret_cast(output_hidden_state->data_c()), reinterpret_cast(output_cell_state->data_c()), - state_buffer_, buffer_, lstm_param_); + buffer_, lstm_param_); FreeRunBuffer(); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h index 53179c598bc..4261db6c3fe 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h @@ -44,12 +44,11 @@ class LstmCPUKernel : public LiteKernel { int InitInputWeightBias(); int InitStateWeightBias(); - float *state_buffer_[2]; float *weight_i_ptr_ = nullptr; float *weight_h_ptr_ = nullptr; float *input_bias_ = nullptr; float *state_bias_ = nullptr; - float *buffer_[4]; + float *buffer_[6]; int row_tile_ = 0; int col_tile_ = 0; int weight_batch_ = 0;