From 7c3eed1b1d1a2ebd2e8b6a31a00f38638918e003 Mon Sep 17 00:00:00 2001 From: yangruoqi713 Date: Fri, 5 Feb 2021 16:46:48 +0800 Subject: [PATCH] [MSLITE][Develop] optimize arm cpu fp16 op lstm and gru: use matmul fp16 --- mindspore/lite/nnacl/fp16/gru_fp16.c | 167 +++++++------ mindspore/lite/nnacl/fp16/gru_fp16.h | 4 +- mindspore/lite/nnacl/fp16/lstm_fp16.c | 186 +++++++------- mindspore/lite/nnacl/fp16/lstm_fp16.h | 9 +- mindspore/lite/nnacl/fp32/gru_fp32.c | 6 +- mindspore/lite/nnacl/fp32/lstm_fp32.c | 8 +- mindspore/lite/nnacl/fp32/lstm_fp32.h | 4 +- .../src/runtime/kernel/arm/fp16/gru_fp16.cc | 225 +++++++++++------ .../src/runtime/kernel/arm/fp16/gru_fp16.h | 10 +- .../src/runtime/kernel/arm/fp16/lstm_fp16.cc | 229 ++++++++++++------ .../src/runtime/kernel/arm/fp16/lstm_fp16.h | 6 +- mindspore/lite/test/run_benchmark_nets.sh | 2 +- 12 files changed, 517 insertions(+), 339 deletions(-) diff --git a/mindspore/lite/nnacl/fp16/gru_fp16.c b/mindspore/lite/nnacl/fp16/gru_fp16.c index 028a4c3263f..5fb6ff5730f 100644 --- a/mindspore/lite/nnacl/fp16/gru_fp16.c +++ b/mindspore/lite/nnacl/fp16/gru_fp16.c @@ -18,120 +18,129 @@ #include "nnacl/fp16/lstm_fp16.h" #include "nnacl/fp16/activation_fp16.h" #include "nnacl/fp16/arithmetic_fp16.h" +#include "nnacl/fp16/matmul_fp16.h" -void InitGruGateFp16(float16_t *gate_buffer, const float16_t *bias, const GruParameter *gru_parm) { - int gate_offest = 0; - for (int l = 0; l < 3; l++) { - int batch_offest = gate_offest; - int bias_offest = l * gru_parm->hidden_size_; - for (int b = 0; b < gru_parm->batch_; b++) { - memcpy(gate_buffer + batch_offest, bias + bias_offest, gru_parm->hidden_size_ * sizeof(float16_t)); - batch_offest += gru_parm->hidden_size_; - } - gate_offest += gru_parm->batch_ * gru_parm->hidden_size_; +void UpdateGruInputGateFp16(float16_t *gate_buffer, const float16_t *input, const float16_t *weight, + const float16_t *bias, int row, int deep, int col, int col_align, bool is_vec) { + for (int i = 0; i < 3; i++) { + const float16_t *weight_i = weight + deep * col * i; + const float16_t *bias_i = bias + col_align * i; + float16_t *gate = gate_buffer + row * col * i; + LstmMatMulFp16(gate, input, weight_i, bias_i, row, deep, col, is_vec); } } -void GruStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_reset_weight, - const float16_t *input_update_weight, const float16_t *input_hidden_weight, - const float16_t *state_reset_weight, const float16_t *state_update_weight, - const float16_t *state_hidden_weight, const float16_t *bias, float16_t *hidden_state, - float16_t *gate_buffer, const GruParameter *gru_parm) { - InitGruGateFp16(gate_buffer, bias, gru_parm); - - float16_t *update_gate = gate_buffer; - float16_t *reset_gate = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_; - float16_t *hidden_buffer = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_ * 2; - +void GruStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_weight, + const float16_t *state_weight, const float16_t *bias, float16_t *hidden_state, + float16_t *gate_buffer, float16_t *matmul_buffer[2], const GruParameter *gru_param) { + bool is_vec = gru_param->batch_ == 1; // input * weight - MatMulAccFp16(reset_gate, input, input_reset_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_); - MatMulAccFp16(update_gate, input, input_update_weight, gru_parm->batch_, gru_parm->hidden_size_, - gru_parm->input_size_); - MatMulAccFp16(hidden_buffer, input, input_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_, - gru_parm->input_size_); + if (is_vec) { + UpdateGruInputGateFp16(gate_buffer, input, input_weight, bias, gru_param->batch_, gru_param->input_size_, + gru_param->hidden_size_, gru_param->col_align_, is_vec); + } else { + // pack input for matmul + RowMajor2Col16MajorFp16(input, matmul_buffer[0], gru_param->batch_, gru_param->input_size_, false); + UpdateGruInputGateFp16(gate_buffer, matmul_buffer[0], input_weight, bias, gru_param->batch_, gru_param->input_size_, + gru_param->hidden_size_, gru_param->col_align_, is_vec); + } + + const float16_t *state_update_weight = state_weight; + const float16_t *state_reset_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_; + const float16_t *state_hidden_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_ * 2; + float16_t *state_update_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 3; + float16_t *state_reset_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 4; + float16_t *state_hidden_buffer = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 5; + const float16_t *state_update_bias = bias + gru_param->hidden_size_ * 3; + const float16_t *state_reset_bias = bias + gru_param->hidden_size_ * 4; + const float16_t *state_hidden_bias = bias + gru_param->hidden_size_ * 5; // state * weight - MatMulAccFp16(reset_gate, hidden_state, state_reset_weight, gru_parm->batch_, gru_parm->hidden_size_, - gru_parm->hidden_size_); - MatMulAccFp16(update_gate, hidden_state, state_update_weight, gru_parm->batch_, gru_parm->hidden_size_, - gru_parm->hidden_size_); + if (is_vec) { + LstmMatMulFp16(state_reset_gate, hidden_state, state_reset_weight, state_reset_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + LstmMatMulFp16(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } else { + RowMajor2Col16MajorFp16(hidden_state, matmul_buffer[1], gru_param->batch_, gru_param->hidden_size_, false); + LstmMatMulFp16(state_reset_gate, matmul_buffer[1], state_reset_weight, state_reset_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + LstmMatMulFp16(state_update_gate, matmul_buffer[1], state_update_weight, state_update_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } + + ElementAddFp16(gate_buffer, state_update_gate, gate_buffer, gru_param->batch_ * gru_param->hidden_size_ * 2); + float16_t *update_gate = gate_buffer; + float16_t *reset_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_; + float16_t *hidden_buffer = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 2; // update reset_gate - SigmoidFp16(reset_gate, reset_gate, gru_parm->batch_ * gru_parm->hidden_size_); + SigmoidFp16(reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_); // update update_gate - SigmoidFp16(update_gate, update_gate, gru_parm->batch_ * gru_parm->hidden_size_); + SigmoidFp16(update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_); - ElementMulFp16(hidden_state, reset_gate, reset_gate, gru_parm->batch_ * gru_parm->hidden_size_); - MatMulAccFp16(hidden_buffer, reset_gate, state_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_, - gru_parm->hidden_size_); + ElementMulFp16(hidden_state, reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_); + if (is_vec) { + LstmMatMulFp16(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } else { + RowMajor2Col16MajorFp16(reset_gate, matmul_buffer[1], gru_param->batch_, gru_param->hidden_size_, false); + LstmMatMulFp16(state_hidden_buffer, matmul_buffer[1], state_hidden_weight, state_hidden_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } + ElementAddFp16(hidden_buffer, state_hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_); - TanhFp16(hidden_buffer, hidden_buffer, gru_parm->batch_ * gru_parm->hidden_size_); + TanhFp16(hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_); - ElementMulFp16(update_gate, hidden_state, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_); + ElementMulFp16(update_gate, hidden_state, hidden_state, gru_param->batch_ * gru_param->hidden_size_); ArithmeticParameter parameter; parameter.in_elements_num0_ = 1; - parameter.in_elements_num1_ = gru_parm->batch_ * gru_parm->hidden_size_; + parameter.in_elements_num1_ = gru_param->batch_ * gru_param->hidden_size_; float16_t one = 1.0f; - ElementOptSubFp16(&one, update_gate, update_gate, gru_parm->batch_ * gru_parm->hidden_size_, ¶meter); + ElementOptSubFp16(&one, update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_, ¶meter); - ElementMulAccFp16(update_gate, hidden_buffer, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_); + ElementMulAccFp16(update_gate, hidden_buffer, hidden_state, gru_param->batch_ * gru_param->hidden_size_); - memcpy(output, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_ * sizeof(float16_t)); + memcpy(output, hidden_state, gru_param->batch_ * gru_param->hidden_size_ * sizeof(float16_t)); } void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r, - const float16_t *bias, float16_t *hidden_state, float16_t *gate_buffer, int check_seq_len, - const GruParameter *gru_parm) { + const float16_t *bias, float16_t *hidden_state, float16_t *gate_buffer, float16_t *matmul_buffer[2], + int check_seq_len, const GruParameter *gru_param) { // forward - const float16_t *input_update_weight = weight_g; - const float16_t *input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_; - const float16_t *input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 2; - - const float16_t *state_update_weight = weight_r; - const float16_t *state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_; - const float16_t *state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 2; - for (int t = 0; t < check_seq_len; t++) { - const float16_t *input_ptr = input + t * gru_parm->input_step_; - float16_t *output_ptr = output + t * gru_parm->output_step_; - GruStepUnitFp16(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight, - state_reset_weight, state_update_weight, state_hidden_weight, bias, hidden_state, gate_buffer, - gru_parm); + const float16_t *input_ptr = input + t * gru_param->input_step_; + float16_t *output_ptr = output + t * gru_param->output_step_; + GruStepUnitFp16(output_ptr, input_ptr, weight_g, weight_r, bias, hidden_state, gate_buffer, matmul_buffer, + gru_param); } // zero out extra fw outputs - for (int t = check_seq_len; t < gru_parm->seq_len_; t++) { - float16_t *output_ptr = output + t * gru_parm->output_step_; - for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) { + for (int t = check_seq_len; t < gru_param->seq_len_; t++) { + float16_t *output_ptr = output + t * gru_param->output_step_; + for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) { output_ptr[i] = 0.0f; } } // backward - if (gru_parm->bidirectional_) { - input_update_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 3; - input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 4; - input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 5; - - state_update_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 3; - state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 4; - state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 5; - - float16_t *backward_output = output + gru_parm->batch_ * gru_parm->hidden_size_; - const float16_t *backward_bias = bias + 3 * gru_parm->hidden_size_; - float16_t *backward_hidden_state = hidden_state + gru_parm->batch_ * gru_parm->hidden_size_; + if (gru_param->bidirectional_) { + const float16_t *backward_weight_g = weight_g + 3 * gru_param->col_align_ * gru_param->input_size_; + const float16_t *backward_weight_r = weight_r + 3 * gru_param->col_align_ * gru_param->hidden_size_; + const float16_t *backward_bias = bias + 6 * gru_param->hidden_size_; + float16_t *backward_output = output + gru_param->batch_ * gru_param->hidden_size_; + float16_t *backward_hidden_state = hidden_state + gru_param->batch_ * gru_param->hidden_size_; for (int t = check_seq_len - 1; t >= 0; t--) { - const float16_t *input_ptr = input + t * gru_parm->input_step_; - float16_t *output_ptr = backward_output + t * gru_parm->output_step_; - GruStepUnitFp16(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight, - state_reset_weight, state_update_weight, state_hidden_weight, backward_bias, - backward_hidden_state, gate_buffer, gru_parm); + const float16_t *input_ptr = input + t * gru_param->input_step_; + float16_t *output_ptr = backward_output + t * gru_param->output_step_; + GruStepUnitFp16(output_ptr, input_ptr, backward_weight_g, backward_weight_r, backward_bias, backward_hidden_state, + gate_buffer, matmul_buffer, gru_param); } // zero out extra bw outputs - for (int t = gru_parm->seq_len_ - 1; t >= check_seq_len; t--) { - float16_t *output_ptr = backward_output + t * gru_parm->output_step_; - for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) { + for (int t = gru_param->seq_len_ - 1; t >= check_seq_len; t--) { + float16_t *output_ptr = backward_output + t * gru_param->output_step_; + for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) { output_ptr[i] = 0.0f; } } diff --git a/mindspore/lite/nnacl/fp16/gru_fp16.h b/mindspore/lite/nnacl/fp16/gru_fp16.h index 4d23cc0e966..2171227b72d 100644 --- a/mindspore/lite/nnacl/fp16/gru_fp16.h +++ b/mindspore/lite/nnacl/fp16/gru_fp16.h @@ -21,8 +21,8 @@ extern "C" { #endif void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r, - const float16_t *bias, float16_t *hidden_state, float16_t *gate_buffer, int check_seq_len, - const GruParameter *gru_parm); + const float16_t *bias, float16_t *hidden_state, float16_t *gate_buffer, float16_t *matmul_buffer[2], + int check_seq_len, const GruParameter *gru_param); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp16/lstm_fp16.c b/mindspore/lite/nnacl/fp16/lstm_fp16.c index dc7effd265a..8ffd6e174d3 100644 --- a/mindspore/lite/nnacl/fp16/lstm_fp16.c +++ b/mindspore/lite/nnacl/fp16/lstm_fp16.c @@ -18,17 +18,21 @@ #include #include "nnacl/fp16/activation_fp16.h" #include "nnacl/fp16/arithmetic_fp16.h" +#include "nnacl/fp16/matmul_fp16.h" -void InitGateFp16(float16_t *gate_buffer, const float16_t *bias, const LstmParameter *lstm_parm) { - int gate_offest = 0; - for (int l = 0; l < 4; l++) { - int batch_offest = gate_offest; - int bias_offest = l * lstm_parm->hidden_size_; - for (int b = 0; b < lstm_parm->batch_; b++) { - memcpy(gate_buffer + batch_offest, bias + bias_offest, lstm_parm->hidden_size_ * sizeof(float16_t)); - batch_offest += lstm_parm->hidden_size_; - } - gate_offest += lstm_parm->batch_ * lstm_parm->hidden_size_; +void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align) { + for (int i = 0; i < batch; i++) { + const float *src_batch = src + i * col * deep; + float16_t *dst_batch = dst + i * col_align * deep; + RowMajor2Col8MajorFp16(src_batch, dst_batch, col, deep, true); + } +} + +void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align) { + for (int i = 0; i < batch; i++) { + const float16_t *src_batch = src + i * col * deep; + float16_t *dst_batch = dst + i * col_align * deep; + RowMajor2Col8MajorFp16(src_batch, dst_batch, col, deep, false); } } @@ -125,111 +129,111 @@ void UpdataOutputFp16(const float16_t *cell_state, float16_t *output_gate, float } } -void LstmStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_input_weight, - const float16_t *input_forget_weight, const float16_t *input_cell_weight, - const float16_t *input_output_weight, const float16_t *state_input_weight, - const float16_t *state_forget_weight, const float16_t *state_cell_weight, - const float16_t *state_output_weight, const float16_t *bias, float16_t *hidden_state, +void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, + int col, bool is_vec) { + if (is_vec) { + memcpy(c, bias, col * sizeof(float16_t)); + MatMulAccFp16(c, a, b, row, col, deep); + } else { + MatMulFp16(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); + } +} + +void UpdateLstmGateFp16(float16_t *gate_buffer, const float16_t *input, const float16_t *weight, const float16_t *bias, + int row, int deep, int col, int col_align, bool is_vec) { + for (int i = 0; i < 4; i++) { + const float16_t *weight_i = weight + deep * col * i; + const float16_t *bias_i = bias + col_align * i; + float16_t *gate = gate_buffer + row * col * i; + LstmMatMulFp16(gate, input, weight_i, bias_i, row, deep, col, is_vec); + } +} + +void LstmStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_weight, + const float16_t *state_weight, const float16_t *bias, float16_t *hidden_state, float16_t *cell_state, float16_t *gate_buffer, float16_t *state_buffer, - const LstmParameter *lstm_parm) { - InitGateFp16(gate_buffer, bias, lstm_parm); - - float16_t *input_gate = gate_buffer; - float16_t *forget_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 2; - float16_t *cell_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 3; - float16_t *output_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 1; - + float16_t *matmul_buffer[2], const LstmParameter *lstm_param) { + bool is_vec = lstm_param->batch_ == 1; // input * weight - MatMulAccFp16(input_gate, input, input_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->input_size_); - MatMulAccFp16(forget_gate, input, input_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->input_size_); - MatMulAccFp16(cell_gate, input, input_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->input_size_); - MatMulAccFp16(output_gate, input, input_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->input_size_); + if (is_vec) { + UpdateLstmGateFp16(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_, + lstm_param->hidden_size_, lstm_param->col_align_, is_vec); + } else { + // pack input for matmul + RowMajor2Col16MajorFp16(input, matmul_buffer[0], lstm_param->batch_, lstm_param->input_size_, false); + UpdateLstmGateFp16(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_, + lstm_param->hidden_size_, lstm_param->col_align_, is_vec); + } // state * weight - MatMulAccFp16(input_gate, hidden_state, state_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->hidden_size_); - MatMulAccFp16(forget_gate, hidden_state, state_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->hidden_size_); - MatMulAccFp16(cell_gate, hidden_state, state_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->hidden_size_); - MatMulAccFp16(output_gate, hidden_state, state_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->hidden_size_); + float16_t *state_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 4; + const float16_t *state_bias = bias + lstm_param->col_align_ * 4; + if (is_vec) { + UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_, + lstm_param->hidden_size_, lstm_param->col_align_, is_vec); + } else { + // pack state for matmul + RowMajor2Col16MajorFp16(hidden_state, matmul_buffer[1], lstm_param->batch_, lstm_param->hidden_size_, false); + UpdateLstmGateFp16(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_, + lstm_param->hidden_size_, lstm_param->hidden_size_, lstm_param->col_align_, is_vec); + } + ElementAddFp16(gate_buffer, state_gate, gate_buffer, 4 * lstm_param->batch_ * lstm_param->hidden_size_); + float16_t *input_gate = gate_buffer; + float16_t *forget_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 2; + float16_t *cell_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 3; + float16_t *output_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_; // update input_gate - SigmoidFp16(input_gate, input_gate, lstm_parm->batch_ * lstm_parm->hidden_size_); + SigmoidFp16(input_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_); // update forget_gate - SigmoidFp16(forget_gate, forget_gate, lstm_parm->batch_ * lstm_parm->hidden_size_); + SigmoidFp16(forget_gate, forget_gate, lstm_param->batch_ * lstm_param->hidden_size_); // update cell_gate - TanhFp16(cell_gate, cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_); + TanhFp16(cell_gate, cell_gate, lstm_param->batch_ * lstm_param->hidden_size_); // update cell state - UpdataStateFp16(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_parm->batch_, - lstm_parm->hidden_size_, lstm_parm->smooth_); + UpdataStateFp16(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_param->batch_, + lstm_param->hidden_size_, lstm_param->smooth_); // update output_gate - SigmoidFp16(output_gate, output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_); + SigmoidFp16(output_gate, output_gate, lstm_param->batch_ * lstm_param->hidden_size_); // update output - UpdataOutputFp16(cell_state, output_gate, hidden_state, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_, - lstm_parm->smooth_); - memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float16_t)); + UpdataOutputFp16(cell_state, output_gate, hidden_state, state_buffer, lstm_param->batch_, lstm_param->hidden_size_, + lstm_param->smooth_); + memcpy(output, hidden_state, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t)); - if (!(lstm_parm->smooth_ >= -FLT_EPSILON && lstm_parm->smooth_ <= FLT_EPSILON)) { - memcpy(cell_state, state_buffer, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float16_t)); - memcpy(hidden_state, state_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_, - lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float16_t)); + if (!(lstm_param->smooth_ >= -FLT_EPSILON && lstm_param->smooth_ <= FLT_EPSILON)) { + memcpy(cell_state, state_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t)); + memcpy(hidden_state, state_buffer + lstm_param->batch_ * lstm_param->hidden_size_, + lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t)); } } void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h, const float16_t *bias, float16_t *hidden_state, float16_t *cell_state, float16_t *gate_buffer, - float16_t *state_buffer, const LstmParameter *lstm_parm) { + float16_t *state_buffer, float16_t *matmul_buffer[2], const LstmParameter *lstm_param) { // forward - const float16_t *input_input_weight = weight_i; - const float16_t *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2; - const float16_t *input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 3; - const float16_t *input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 1; - - const float16_t *state_input_weight = weight_h; - const float16_t *state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 2; - const float16_t *state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 3; - const float16_t *state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 1; - - for (int t = 0; t < lstm_parm->seq_len_; t++) { - const float16_t *input_ptr = input + t * lstm_parm->input_step_; - float16_t *output_ptr = output + t * lstm_parm->output_step_; - LstmStepUnitFp16(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, - input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, - state_output_weight, bias, hidden_state, cell_state, gate_buffer, state_buffer, lstm_parm); + for (int t = 0; t < lstm_param->seq_len_; t++) { + const float16_t *input_ptr = input + t * lstm_param->input_step_; + float16_t *output_ptr = output + t * lstm_param->output_step_; + LstmStepUnitFp16(output_ptr, input_ptr, weight_i, weight_h, bias, hidden_state, cell_state, gate_buffer, + state_buffer, matmul_buffer, lstm_param); } // backward - if (lstm_parm->bidirectional_) { - input_input_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 4; - input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 6; - input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 7; - input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 5; - - state_input_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 4; - state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 6; - state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 7; - state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 5; - - float16_t *backward_output = output + lstm_parm->batch_ * lstm_parm->hidden_size_; - const float16_t *backward_bias = bias + 4 * lstm_parm->hidden_size_; - float16_t *backward_cell_state = cell_state + lstm_parm->batch_ * lstm_parm->hidden_size_; - float16_t *backward_hidden_state = hidden_state + lstm_parm->batch_ * lstm_parm->hidden_size_; - for (int t = lstm_parm->seq_len_ - 1; t >= 0; t--) { - const float16_t *input_ptr = input + t * lstm_parm->input_step_; - float16_t *output_ptr = backward_output + t * lstm_parm->output_step_; - LstmStepUnitFp16(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, - input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, - state_output_weight, backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, - state_buffer, lstm_parm); + if (lstm_param->bidirectional_) { + const float16_t *backward_weight_i = weight_i + 4 * lstm_param->col_align_ * lstm_param->input_size_; + const float16_t *backward_weight_h = weight_h + 4 * lstm_param->col_align_ * lstm_param->hidden_size_; + const float16_t *backward_bias = bias + 8 * lstm_param->col_align_; + float16_t *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_; + float16_t *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; + float16_t *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_; + for (int t = lstm_param->seq_len_ - 1; t >= 0; t--) { + const float16_t *input_ptr = input + t * lstm_param->input_step_; + float16_t *output_ptr = backward_output + t * lstm_param->output_step_; + LstmStepUnitFp16(output_ptr, input_ptr, backward_weight_i, backward_weight_h, backward_bias, + backward_hidden_state, backward_cell_state, gate_buffer, state_buffer, matmul_buffer, + lstm_param); } } } diff --git a/mindspore/lite/nnacl/fp16/lstm_fp16.h b/mindspore/lite/nnacl/fp16/lstm_fp16.h index d047402e425..b5bf7ad2a9c 100644 --- a/mindspore/lite/nnacl/fp16/lstm_fp16.h +++ b/mindspore/lite/nnacl/fp16/lstm_fp16.h @@ -21,6 +21,13 @@ #ifdef __cplusplus extern "C" { #endif +void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align); + +void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align); + +void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, + int col, bool is_vec); + void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols, int inner_size); @@ -30,7 +37,7 @@ int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float1 void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h, const float16_t *bias, float16_t *hidden_state, float16_t *cell_state, float16_t *gate_buffer, - float16_t *state_buffer, const LstmParameter *lstm_parm); + float16_t *state_buffer, float16_t *matmul_buffer[2], const LstmParameter *lstm_param); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp32/gru_fp32.c b/mindspore/lite/nnacl/fp32/gru_fp32.c index 8e147e495fc..9ee85761f74 100644 --- a/mindspore/lite/nnacl/fp32/gru_fp32.c +++ b/mindspore/lite/nnacl/fp32/gru_fp32.c @@ -40,7 +40,7 @@ void GruStepUnit(float *output, const float *input, const float *input_weight, c gru_param->hidden_size_, gru_param->col_align_, is_vec); } else { // pack input for matmul - PackLstmInput(matmul_buffer[0], input, gru_param->batch_, gru_param->input_size_); + PackLstmInput(input, matmul_buffer[0], gru_param->batch_, gru_param->input_size_); UpdateGruInputGate(gate_buffer, matmul_buffer[0], input_weight, bias, gru_param->batch_, gru_param->input_size_, gru_param->hidden_size_, gru_param->col_align_, is_vec); } @@ -62,7 +62,7 @@ void GruStepUnit(float *output, const float *input, const float *input_weight, c LstmMatMul(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, is_vec); } else { - PackLstmInput(matmul_buffer[1], hidden_state, gru_param->batch_, gru_param->hidden_size_); + PackLstmInput(hidden_state, matmul_buffer[1], gru_param->batch_, gru_param->hidden_size_); LstmMatMul(state_reset_gate, matmul_buffer[1], state_reset_weight, state_reset_bias, gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, is_vec); LstmMatMul(state_update_gate, matmul_buffer[1], state_update_weight, state_update_bias, gru_param->batch_, @@ -83,7 +83,7 @@ void GruStepUnit(float *output, const float *input, const float *input_weight, c LstmMatMul(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, is_vec); } else { - PackLstmInput(matmul_buffer[1], reset_gate, gru_param->batch_, gru_param->hidden_size_); + PackLstmInput(reset_gate, matmul_buffer[1], gru_param->batch_, gru_param->hidden_size_); LstmMatMul(state_hidden_buffer, matmul_buffer[1], state_hidden_weight, state_hidden_bias, gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, is_vec); } diff --git a/mindspore/lite/nnacl/fp32/lstm_fp32.c b/mindspore/lite/nnacl/fp32/lstm_fp32.c index 13c416f59ec..85bbe7a6488 100644 --- a/mindspore/lite/nnacl/fp32/lstm_fp32.c +++ b/mindspore/lite/nnacl/fp32/lstm_fp32.c @@ -35,7 +35,7 @@ void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, } } -void PackLstmInput(float *dst, const float *src, int row, int deep) { +void PackLstmInput(const float *src, float *dst, int row, int deep) { #ifdef ENABLE_AVX RowMajor2Col6Major(src, dst, row, deep); #elif defined(ENABLE_SSE) @@ -174,7 +174,7 @@ void LstmStepUnit(float *output, const float *input, const float *input_weight, lstm_param->hidden_size_, lstm_param->col_align_, is_vec); } else { // pack input for matmul - PackLstmInput(matmul_buffer[0], input, lstm_param->batch_, lstm_param->input_size_); + PackLstmInput(input, matmul_buffer[0], lstm_param->batch_, lstm_param->input_size_); UpdateLstmGate(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_, lstm_param->col_align_, is_vec); } @@ -187,7 +187,7 @@ void LstmStepUnit(float *output, const float *input, const float *input_weight, lstm_param->hidden_size_, lstm_param->col_align_, is_vec); } else { // pack state for matmul - PackLstmInput(matmul_buffer[1], hidden_state, lstm_param->batch_, lstm_param->hidden_size_); + PackLstmInput(hidden_state, matmul_buffer[1], lstm_param->batch_, lstm_param->hidden_size_); UpdateLstmGate(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, lstm_param->col_align_, is_vec); } @@ -238,7 +238,7 @@ void Lstm(float *output, const float *input, const float *weight_i, const float if (lstm_param->bidirectional_) { const float *backward_weight_i = weight_i + 4 * lstm_param->col_align_ * lstm_param->input_size_; const float *backward_weight_h = weight_h + 4 * lstm_param->col_align_ * lstm_param->hidden_size_; - const float *backward_bias = bias + 8 * lstm_param->hidden_size_; + const float *backward_bias = bias + 8 * lstm_param->col_align_; float *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_; float *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_; diff --git a/mindspore/lite/nnacl/fp32/lstm_fp32.h b/mindspore/lite/nnacl/fp32/lstm_fp32.h index 5e2d3d6176a..2bc060dd8ca 100644 --- a/mindspore/lite/nnacl/fp32/lstm_fp32.h +++ b/mindspore/lite/nnacl/fp32/lstm_fp32.h @@ -23,7 +23,7 @@ extern "C" { #endif void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align); -void PackLstmInput(float *dst, const float *src, int row, int deep); +void PackLstmInput(const float *src, float *dst, int row, int deep); void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec); @@ -33,7 +33,7 @@ int ElementOptMulAcc(const float *input0, const float input1, float *output, con void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, float *matmul_buffer[2], - const LstmParameter *lstm_parm); + const LstmParameter *lstm_param); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.cc index 60040138977..c391993d686 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.cc @@ -19,6 +19,8 @@ #include "src/kernel_registry.h" #include "include/errorcode.h" #include "nnacl/fp16/gru_fp16.h" +#include "nnacl/fp16/cast_fp16.h" +#include "nnacl/fp16/lstm_fp16.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; @@ -28,21 +30,32 @@ using mindspore::schema::PrimitiveType_Gru; namespace mindspore::kernel { void GruFp16CPUKernel::FreeTmpBuffer() { - if (gate_buffer_ != nullptr) { - free(gate_buffer_); - gate_buffer_ = nullptr; + if (!is_vec_ || in_tensors_[1]->data_type() == kNumberTypeFloat32) { + if (weight_g_ptr_ != nullptr) { + free(weight_g_ptr_); + weight_g_ptr_ = nullptr; + } } - if (bias_ptr_ != nullptr) { - free(bias_ptr_); - bias_ptr_ = nullptr; + if (!is_vec_ || in_tensors_[2]->data_type() == kNumberTypeFloat32) { + if (weight_r_ptr_ != nullptr) { + free(weight_r_ptr_); + weight_r_ptr_ = nullptr; + } } - if (weight_g_ptr_ != nullptr) { - free(weight_g_ptr_); - weight_g_ptr_ = nullptr; + if (!is_vec_ || in_tensors_[3]->data_type() == kNumberTypeFloat32) { + if (bias_ptr_ != nullptr) { + free(bias_ptr_); + bias_ptr_ = nullptr; + } } - if (weight_r_ptr_ != nullptr) { - free(weight_r_ptr_); - weight_r_ptr_ = nullptr; +} + +void GruFp16CPUKernel::FreeRunBuffer() { + context_->allocator->Free(gate_buffer_); + if (!is_vec_) { + for (int i = 0; i < 2; i++) { + context_->allocator->Free(matmul_buffer_[i]); + } } } @@ -50,75 +63,115 @@ int GruFp16CPUKernel::InitParam() { auto input = in_tensors_.front(); MS_ASSERT(input != nullptr); std::vector in_shape = input->shape(); - gru_parm_->seq_len_ = in_shape.at(0); - gru_parm_->batch_ = in_shape.at(1); - gru_parm_->input_size_ = in_shape.at(2); + gru_param_->seq_len_ = in_shape.at(0); + gru_param_->batch_ = in_shape.at(1); + gru_param_->input_size_ = in_shape.at(2); auto weight_g = in_tensors_.at(1); MS_ASSERT(weight_g != nullptr); std::vector w_shape = weight_g->shape(); - gru_parm_->hidden_size_ = w_shape.at(1) / 3; + gru_param_->hidden_size_ = w_shape.at(1) / 3; - gru_parm_->input_step_ = gru_parm_->batch_ * gru_parm_->input_size_; - gru_parm_->output_step_ = gru_parm_->bidirectional_ ? 2 * gru_parm_->batch_ * gru_parm_->hidden_size_ - : gru_parm_->batch_ * gru_parm_->hidden_size_; + gru_param_->input_step_ = gru_param_->batch_ * gru_param_->input_size_; + gru_param_->output_step_ = gru_param_->bidirectional_ ? 2 * gru_param_->batch_ * gru_param_->hidden_size_ + : gru_param_->batch_ * gru_param_->hidden_size_; + + is_vec_ = gru_param_->batch_ == 1; + gru_param_->row_align_ = is_vec_ ? gru_param_->batch_ : UP_ROUND(gru_param_->batch_, C16NUM); + gru_param_->col_align_ = is_vec_ ? gru_param_->hidden_size_ : UP_ROUND(gru_param_->hidden_size_, C8NUM); return RET_OK; } -int GruFp16CPUKernel::InitBuffer() { - gate_buffer_ = - reinterpret_cast(malloc(3 * gru_parm_->batch_ * gru_parm_->hidden_size_ * sizeof(float16_t))); - if (gate_buffer_ == nullptr) { - MS_LOG(ERROR) << "GruFp16CPUKernel malloc gate_buffer error."; +int GruFp16CPUKernel::InitWeight(const lite::Tensor *tensor, float16_t *ptr, int deep) { + auto weight_batch = gru_param_->bidirectional_ ? 6 : 3; + if (tensor->data_type() == kNumberTypeFloat32) { + auto weight_data = reinterpret_cast(tensor->data_c()); + is_vec_ ? Float32ToFloat16(weight_data, ptr, tensor->ElementsNum()) + : PackLstmWeightFp32ToFp16(ptr, weight_data, weight_batch, deep, gru_param_->hidden_size_, + gru_param_->col_align_); + } else if (tensor->data_type() == kNumberTypeFloat16) { + auto weight_data = reinterpret_cast(tensor->data_c()); + if (is_vec_) { + ptr = weight_data; + } else { + PackLstmWeightFp16(ptr, weight_data, weight_batch, deep, gru_param_->hidden_size_, gru_param_->col_align_); + } + } else { + MS_LOG(ERROR) << "Unsupported data type of weight tensor for lstm."; return RET_ERROR; } return RET_OK; } int GruFp16CPUKernel::InitWeightBias() { - auto weight_gate = in_tensors_.at(1); - MS_ASSERT(weight_gate != nullptr); - weight_g_ptr_ = reinterpret_cast(malloc(weight_gate->ElementsNum() * sizeof(float16_t))); - if (weight_g_ptr_ == nullptr) { - MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_g_ptr_ error."; - return RET_ERROR; - } - auto weight_g_data = reinterpret_cast(weight_gate->data_c()); - for (size_t i = 0; i < weight_gate->ElementsNum(); i++) { - weight_g_ptr_[i] = (float16_t)weight_g_data[i]; - } - - auto weight_recu = in_tensors_.at(2); - MS_ASSERT(weight_recu != nullptr); - weight_r_ptr_ = reinterpret_cast(malloc(weight_recu->ElementsNum() * sizeof(float16_t))); - if (weight_r_ptr_ == nullptr) { - MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_r_ptr_ error."; - return RET_ERROR; - } - auto weight_r_data = reinterpret_cast(weight_recu->data_c()); - for (size_t i = 0; i < weight_recu->ElementsNum(); i++) { - weight_r_ptr_[i] = (float16_t)weight_r_data[i]; - } - - int bias_num = gru_parm_->bidirectional_ ? 2 * 3 * gru_parm_->hidden_size_ : 3 * gru_parm_->hidden_size_; - bias_ptr_ = reinterpret_cast(malloc(bias_num * sizeof(float16_t))); - if (bias_ptr_ == nullptr) { - MS_LOG(ERROR) << "GruFp16CPUKernel malloc bias_ptr_ error."; - return RET_ERROR; - } - - auto bias_data = reinterpret_cast(in_tensors_.at(3)->data_c()); - const int state_bias_offset = 3 * gru_parm_->hidden_size_; - for (int i = 0; i < state_bias_offset; i++) { - bias_ptr_[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]); - } - if (gru_parm_->bidirectional_) { - bias_data += 3 * gru_parm_->hidden_size_ * 2; - auto backward_bias = bias_ptr_ + 3 * gru_parm_->hidden_size_; - for (int i = 0; i < state_bias_offset; i++) { - backward_bias[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]); + auto weight_batch = gru_param_->bidirectional_ ? 6 : 3; + // malloc and init input * weight right matrix buffer + auto weight_g = in_tensors_.at(1); + MS_ASSERT(weight_g != nullptr); + if (!is_vec_ || weight_g->data_type() == kNumberTypeFloat32) { + weight_g_ptr_ = reinterpret_cast( + malloc(weight_batch * gru_param_->col_align_ * gru_param_->input_size_ * sizeof(float16_t))); + if (weight_g_ptr_ == nullptr) { + MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_g_ptr_ error."; + return RET_ERROR; } } + auto ret = InitWeight(weight_g, weight_g_ptr_, gru_param_->input_size_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "GruFp16CPUKernel init weight_g failed."; + return RET_ERROR; + } + + // malloc and init state * weight right matrix buffer + auto weight_r = in_tensors_.at(2); + MS_ASSERT(weight_r != nullptr); + if (!is_vec_ || weight_r->data_type() == kNumberTypeFloat32) { + weight_r_ptr_ = reinterpret_cast( + malloc(weight_batch * gru_param_->col_align_ * gru_param_->hidden_size_ * sizeof(float16_t))); + if (weight_r_ptr_ == nullptr) { + MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_r_ptr_ error."; + return RET_ERROR; + } + } + ret = InitWeight(weight_r, weight_r_ptr_, gru_param_->hidden_size_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "GruFp16CPUKernel init weight_r failed."; + return RET_ERROR; + } + + int bias_batch = gru_param_->bidirectional_ ? 12 : 6; + auto bias = in_tensors_.at(3); + MS_ASSERT(bias != nullptr); + if (!is_vec_ || bias->data_type() == kNumberTypeFloat32) { + bias_ptr_ = reinterpret_cast(malloc(bias_batch * gru_param_->col_align_ * sizeof(float16_t))); + if (bias_ptr_ == nullptr) { + MS_LOG(ERROR) << "GruFp16CPUKernel malloc bias_ptr_ error."; + return RET_ERROR; + } + memset(bias_ptr_, 0, bias_batch * gru_param_->col_align_ * sizeof(float16_t)); + } + if (bias->data_type() == kNumberTypeFloat32) { + auto bias_data = reinterpret_cast(bias->data_c()); + for (int i = 0; i < bias_batch; i++) { + auto src_batch = bias_data + i * gru_param_->hidden_size_; + auto dst_batch = bias_ptr_ + i * gru_param_->col_align_; + Float32ToFloat16(src_batch, dst_batch, gru_param_->hidden_size_); + } + } else if (bias->data_type() == kNumberTypeFloat16) { + auto bias_data = reinterpret_cast(bias->data_c()); + if (is_vec_) { + bias_ptr_ = bias_data; + } else { + for (int i = 0; i < bias_batch; i++) { + auto src_batch = bias_data + i * gru_param_->hidden_size_; + auto dst_batch = bias_ptr_ + i * gru_param_->col_align_; + memcpy(dst_batch, src_batch, gru_param_->hidden_size_ * sizeof(float16_t)); + } + } + } else { + MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm."; + return RET_ERROR; + } return RET_OK; } @@ -130,24 +183,43 @@ int GruFp16CPUKernel::Init() { } int GruFp16CPUKernel::ReSize() { - FreeTmpBuffer(); auto ret = InitParam(); if (ret != RET_OK) { MS_LOG(ERROR) << "GruFp16CPUKernel InitParam error."; return RET_ERROR; } + FreeTmpBuffer(); ret = InitWeightBias(); if (ret != RET_OK) { MS_LOG(ERROR) << "GruFp16CPUKernel InitWeightBias error."; FreeTmpBuffer(); return RET_ERROR; } + return RET_OK; +} - ret = InitBuffer(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "GruFp16CPUKernel InitBuffer error."; - FreeTmpBuffer(); +int GruFp16CPUKernel::MallocRunBuffer() { + if (!is_vec_) { + matmul_buffer_[0] = reinterpret_cast( + context_->allocator->Malloc(3 * gru_param_->row_align_ * gru_param_->input_size_ * sizeof(float16_t))); + if (matmul_buffer_[0] == nullptr) { + MS_LOG(ERROR) << "GruFp16CPUKernel malloc input * weight left matirx error."; + return RET_ERROR; + } + + matmul_buffer_[1] = reinterpret_cast( + context_->allocator->Malloc(3 * gru_param_->row_align_ * gru_param_->hidden_size_ * sizeof(float16_t))); + if (matmul_buffer_[1] == nullptr) { + MS_LOG(ERROR) << "GruFp16CPUKernel malloc state * weight left matirx error."; + return RET_ERROR; + } + } + + gate_buffer_ = reinterpret_cast( + context_->allocator->Malloc(4 * gru_param_->batch_ * gru_param_->hidden_size_ * sizeof(float16_t))); + if (gate_buffer_ == nullptr) { + MS_LOG(ERROR) << "GruFp16CPUKernel malloc gate_buffer error."; return RET_ERROR; } return RET_OK; @@ -166,22 +238,29 @@ int GruFp16CPUKernel::Run() { MS_ASSERT(output_ptr); auto output_hidden_state = out_tensors_[1]; memcpy(output_hidden_state->data_c(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float16_t)); - int check_seq_len = gru_parm_->seq_len_; + int check_seq_len = gru_param_->seq_len_; if (in_tensors_.size() == 6) { auto seq_len = reinterpret_cast(in_tensors_.at(5)->data_c()); - if (!std::equal(seq_len + 1, seq_len + gru_parm_->batch_, seq_len)) { + if (!std::equal(seq_len + 1, seq_len + gru_param_->batch_, seq_len)) { MS_LOG(ERROR) << "different batch seq_len is currently not supported"; return RET_ERROR; } check_seq_len = MSMIN(check_seq_len, MSMAX(0, seq_len[0])); } + auto ret = MallocRunBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "GruFp16CPUKernel MallocRunBuffer error."; + return RET_ERROR; + } MS_ASSERT(weight_g_ptr_ != nullptr); MS_ASSERT(weight_r_ptr_ != nullptr); MS_ASSERT(bias_ptr_ != nullptr); MS_ASSERT(gate_buffer_ != nullptr); GruFp16(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_, - reinterpret_cast(output_hidden_state->data_c()), gate_buffer_, check_seq_len, gru_parm_); + reinterpret_cast(output_hidden_state->data_c()), gate_buffer_, matmul_buffer_, check_seq_len, + gru_param_); + FreeRunBuffer(); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.h index ca470893796..8529c80a5b4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.h @@ -26,7 +26,7 @@ class GruFp16CPUKernel : public LiteKernel { const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) : LiteKernel(parameter, inputs, outputs, ctx, primitive) { - gru_parm_ = reinterpret_cast(op_parameter_); + gru_param_ = reinterpret_cast(op_parameter_); } ~GruFp16CPUKernel() override { FreeTmpBuffer(); } @@ -37,15 +37,19 @@ class GruFp16CPUKernel : public LiteKernel { private: void FreeTmpBuffer(); + void FreeRunBuffer(); int InitParam(); - int InitBuffer(); + int InitWeight(const lite::Tensor *tensor, float16_t *ptr, int deep); int InitWeightBias(); + int MallocRunBuffer(); float16_t *gate_buffer_ = nullptr; float16_t *weight_g_ptr_ = nullptr; float16_t *weight_r_ptr_ = nullptr; float16_t *bias_ptr_ = nullptr; - GruParameter *gru_parm_ = nullptr; + float16_t *matmul_buffer_[2]; + bool is_vec_ = false; + GruParameter *gru_param_ = nullptr; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.cc index 81a08807817..669fb8db58e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.cc @@ -20,6 +20,7 @@ #include "src/kernel_registry.h" #include "include/errorcode.h" #include "nnacl/fp16/lstm_fp16.h" +#include "nnacl/fp16/cast_fp16.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; @@ -29,25 +30,33 @@ using mindspore::schema::PrimitiveType_Lstm; namespace mindspore::kernel { void LstmFp16CPUKernel::FreeTmpBuffer() { - if (gate_buffer_ != nullptr) { - free(gate_buffer_); - gate_buffer_ = nullptr; + if (!is_vec_ || in_tensors_[1]->data_type() == kNumberTypeFloat32) { + if (weight_i_ptr_ != nullptr) { + free(weight_i_ptr_); + weight_i_ptr_ = nullptr; + } } - if (state_buffer_ != nullptr) { - free(state_buffer_); - state_buffer_ = nullptr; + if (!is_vec_ || in_tensors_[2]->data_type() == kNumberTypeFloat32) { + if (weight_h_ptr_ != nullptr) { + free(weight_h_ptr_); + weight_h_ptr_ = nullptr; + } } - if (weight_i_ptr_ != nullptr) { - free(weight_i_ptr_); - weight_i_ptr_ = nullptr; + if (!is_vec_ || in_tensors_[3]->data_type() == kNumberTypeFloat32) { + if (bias_ptr_ != nullptr) { + free(bias_ptr_); + bias_ptr_ = nullptr; + } } - if (weight_h_ptr_ != nullptr) { - free(weight_h_ptr_); - weight_h_ptr_ = nullptr; - } - if (bias_ptr_ != nullptr) { - free(bias_ptr_); - bias_ptr_ = nullptr; +} + +void LstmFp16CPUKernel::FreeRunBuffer() { + context_->allocator->Free(gate_buffer_); + context_->allocator->Free(state_buffer_); + if (!is_vec_) { + for (int i = 0; i < 2; i++) { + context_->allocator->Free(matmul_buffer_[i]); + } } } @@ -67,87 +76,107 @@ int LstmFp16CPUKernel::InitParam() { lstm_param_->input_step_ = lstm_param_->batch_ * lstm_param_->input_size_; lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ : lstm_param_->batch_ * lstm_param_->hidden_size_; + + is_vec_ = lstm_param_->batch_ == 1; + lstm_param_->row_align_ = is_vec_ ? lstm_param_->batch_ : UP_ROUND(lstm_param_->batch_, C16NUM); + lstm_param_->col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, C8NUM); return RET_OK; } -int LstmFp16CPUKernel::InitBuffer() { - gate_buffer_ = - reinterpret_cast(malloc(4 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t))); - if (gate_buffer_ == nullptr) { - MS_LOG(ERROR) << "Lstm fp16 malloc gate_buffer error."; - return RET_ERROR; - } - if (!(lstm_param_->smooth_ >= -FLT_EPSILON && lstm_param_->smooth_ <= FLT_EPSILON)) { - int buffer_size = 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t); - state_buffer_ = reinterpret_cast(malloc(buffer_size)); - if (state_buffer_ == nullptr) { - MS_LOG(ERROR) << "Lstm fp16 malloc state_buffer error."; - return RET_ERROR; +int LstmFp16CPUKernel::InitWeight(const lite::Tensor *tensor, float16_t *ptr, int deep) { + auto weight_batch = lstm_param_->bidirectional_ ? 8 : 4; + if (tensor->data_type() == kNumberTypeFloat32) { + auto weight_data = reinterpret_cast(tensor->data_c()); + is_vec_ ? Float32ToFloat16(weight_data, ptr, tensor->ElementsNum()) + : PackLstmWeightFp32ToFp16(ptr, weight_data, weight_batch, deep, lstm_param_->hidden_size_, + lstm_param_->col_align_); + } else if (tensor->data_type() == kNumberTypeFloat16) { + auto weight_data = reinterpret_cast(tensor->data_c()); + if (is_vec_) { + ptr = weight_data; + } else { + PackLstmWeightFp16(ptr, weight_data, weight_batch, deep, lstm_param_->hidden_size_, lstm_param_->col_align_); } + } else { + MS_LOG(ERROR) << "Unsupported data type of weight tensor for lstm."; + return RET_ERROR; } return RET_OK; } int LstmFp16CPUKernel::InitWeightBias() { - // copy weight_i and weight_h + auto weight_batch = lstm_param_->bidirectional_ ? 8 : 4; + // malloc and init input * weight right matrix buffer auto weight_i = in_tensors_.at(1); MS_ASSERT(weight_i != nullptr); - weight_i_ptr_ = reinterpret_cast(malloc(weight_i->ElementsNum() * sizeof(float16_t))); - if (weight_i_ptr_ == nullptr) { - MS_LOG(ERROR) << "Lstm fp16 malloc weight_i_ptr_ error."; + if (!is_vec_ || weight_i->data_type() == kNumberTypeFloat32) { + weight_i_ptr_ = reinterpret_cast( + malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->input_size_ * sizeof(float16_t))); + if (weight_i_ptr_ == nullptr) { + MS_LOG(ERROR) << "LstmFp16CPUKernel malloc weight_i_ptr_ error."; + return RET_ERROR; + } + } + auto ret = InitWeight(weight_i, weight_i_ptr_, lstm_param_->input_size_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "LstmFp16CPUKernel init weight_i failed."; return RET_ERROR; } - auto weight_i_data = reinterpret_cast(weight_i->data_c()); - for (size_t i = 0; i < weight_i->ElementsNum(); i++) { - weight_i_ptr_[i] = (float16_t)weight_i_data[i]; - } + // malloc and init state * weight right matrix buffer auto weight_h = in_tensors_.at(2); MS_ASSERT(weight_h != nullptr); - weight_h_ptr_ = reinterpret_cast(malloc(weight_h->ElementsNum() * sizeof(float16_t))); - if (weight_h_ptr_ == nullptr) { - MS_LOG(ERROR) << "Lstm fp16 malloc weight_h_ error."; - return RET_ERROR; - } - auto weight_h_data = reinterpret_cast(weight_h->data_c()); - for (size_t i = 0; i < weight_h->ElementsNum(); i++) { - weight_h_ptr_[i] = (float16_t)weight_h_data[i]; - } - - std::vector w_shape = weight_i->shape(); - auto hidden_size = w_shape.at(1) / 4; - // init bias - int bias_num = lstm_param_->bidirectional_ ? 2 * 4 * hidden_size : 4 * hidden_size; - bias_ptr_ = reinterpret_cast(malloc(bias_num * sizeof(float16_t))); - if (bias_ptr_ == nullptr) { - MS_LOG(ERROR) << "Lstm fp16 malloc bias_ptr_ error."; - return RET_ERROR; - } - - auto bias_data = reinterpret_cast(in_tensors_.at(3)->data_c()); - const int state_bias_offset = 4 * hidden_size; - for (int i = 0; i < state_bias_offset; i++) { - bias_ptr_[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]); - } - if (lstm_param_->bidirectional_) { - bias_data += 4 * hidden_size * 2; - auto backward_bias = bias_ptr_ + 4 * hidden_size; - for (int i = 0; i < state_bias_offset; i++) { - backward_bias[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]); + if (!is_vec_ || weight_h->data_type() == kNumberTypeFloat32) { + weight_h_ptr_ = reinterpret_cast( + malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->hidden_size_ * sizeof(float16_t))); + if (weight_h_ptr_ == nullptr) { + MS_LOG(ERROR) << "LstmFp16CPUKernel malloc weight_h_ptr_ error."; + return RET_ERROR; } } + ret = InitWeight(weight_h, weight_h_ptr_, lstm_param_->hidden_size_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "LstmFp16CPUKernel init weight_h failed."; + return RET_ERROR; + } + + int bias_batch = lstm_param_->bidirectional_ ? 16 : 8; + auto bias = in_tensors_.at(3); + MS_ASSERT(bias != nullptr); + if (!is_vec_ || bias->data_type() == kNumberTypeFloat32) { + bias_ptr_ = reinterpret_cast(malloc(bias_batch * lstm_param_->col_align_ * sizeof(float16_t))); + if (bias_ptr_ == nullptr) { + MS_LOG(ERROR) << "LstmFp16CPUKernel malloc bias_ptr_ error."; + return RET_ERROR; + } + memset(bias_ptr_, 0, bias_batch * lstm_param_->col_align_ * sizeof(float16_t)); + } + if (bias->data_type() == kNumberTypeFloat32) { + auto bias_data = reinterpret_cast(bias->data_c()); + for (int i = 0; i < bias_batch; i++) { + auto src_batch = bias_data + i * lstm_param_->hidden_size_; + auto dst_batch = bias_ptr_ + i * lstm_param_->col_align_; + Float32ToFloat16(src_batch, dst_batch, lstm_param_->hidden_size_); + } + } else if (bias->data_type() == kNumberTypeFloat16) { + auto bias_data = reinterpret_cast(bias->data_c()); + if (is_vec_) { + bias_ptr_ = bias_data; + } else { + for (int i = 0; i < bias_batch; i++) { + auto src_batch = bias_data + i * lstm_param_->hidden_size_; + auto dst_batch = bias_ptr_ + i * lstm_param_->col_align_; + memcpy(dst_batch, src_batch, lstm_param_->hidden_size_ * sizeof(float16_t)); + } + } + } else { + MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm."; + return RET_ERROR; + } return RET_OK; } int LstmFp16CPUKernel::Init() { - FreeTmpBuffer(); - auto ret = InitWeightBias(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Lstm fp16 InitWeightBias error."; - FreeTmpBuffer(); - return RET_ERROR; - } - if (!InferShapeDone()) { return RET_OK; } @@ -161,15 +190,50 @@ int LstmFp16CPUKernel::ReSize() { return RET_ERROR; } - ret = InitBuffer(); + FreeTmpBuffer(); + ret = InitWeightBias(); if (ret != RET_OK) { - MS_LOG(ERROR) << "Lstm fp16 InitBuffer error."; + MS_LOG(ERROR) << "Lstm fp16 InitWeightBias error."; FreeTmpBuffer(); return RET_ERROR; } return RET_OK; } +int LstmFp16CPUKernel::MallocRunBuffer() { + if (!is_vec_) { + matmul_buffer_[0] = reinterpret_cast( + context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->input_size_ * sizeof(float16_t))); + if (matmul_buffer_[0] == nullptr) { + MS_LOG(ERROR) << "LstmFp16CPUKernel malloc input * weight left matirx error."; + return RET_ERROR; + } + + matmul_buffer_[1] = reinterpret_cast( + context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->hidden_size_ * sizeof(float16_t))); + if (matmul_buffer_[1] == nullptr) { + MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state * weight left matirx error."; + return RET_ERROR; + } + } + + gate_buffer_ = reinterpret_cast( + context_->allocator->Malloc(8 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t))); + if (gate_buffer_ == nullptr) { + MS_LOG(ERROR) << "LstmFp16CPUKernel malloc gate_buffer error."; + return RET_ERROR; + } + if (!(lstm_param_->smooth_ >= -FLT_EPSILON && lstm_param_->smooth_ <= FLT_EPSILON)) { + int buffer_size = 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t); + state_buffer_ = reinterpret_cast(context_->allocator->Malloc(buffer_size)); + if (state_buffer_ == nullptr) { + MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_buffer error."; + return RET_ERROR; + } + } + return RET_OK; +} + int LstmFp16CPUKernel::Run() { auto input = in_tensors_.at(kInputIndex); MS_ASSERT(input != nullptr); @@ -189,13 +253,20 @@ int LstmFp16CPUKernel::Run() { auto output_cell_state = out_tensors_[2]; memcpy(output_cell_state->data_c(), cell_state->data_c(), cell_state->ElementsNum() * sizeof(float16_t)); - MS_ASSERT(weight_h_ptr_); + auto ret = MallocRunBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "LstmFp16CPUKernel MallocRunBuffer error."; + return RET_ERROR; + } MS_ASSERT(weight_i_ptr_); + MS_ASSERT(weight_h_ptr_); MS_ASSERT(bias_ptr_); MS_ASSERT(gate_buffer_); LstmFp16(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_, reinterpret_cast(output_hidden_state->data_c()), - reinterpret_cast(output_cell_state->data_c()), gate_buffer_, state_buffer_, lstm_param_); + reinterpret_cast(output_cell_state->data_c()), gate_buffer_, state_buffer_, matmul_buffer_, + lstm_param_); + FreeRunBuffer(); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.h index 4527213e376..0cc7e69ed07 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.h @@ -39,15 +39,19 @@ class LstmFp16CPUKernel : public LiteKernel { private: void FreeTmpBuffer(); + void FreeRunBuffer(); int InitParam(); - int InitBuffer(); + int InitWeight(const lite::Tensor *tensor, float16_t *ptr, int deep); int InitWeightBias(); + int MallocRunBuffer(); float16_t *gate_buffer_ = nullptr; float16_t *state_buffer_ = nullptr; float16_t *weight_i_ptr_ = nullptr; float16_t *weight_h_ptr_ = nullptr; float16_t *bias_ptr_ = nullptr; + float16_t *matmul_buffer_[2]; + bool is_vec_ = false; LstmParameter *lstm_param_ = nullptr; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/test/run_benchmark_nets.sh b/mindspore/lite/test/run_benchmark_nets.sh index 512d27abed1..209a5029d65 100755 --- a/mindspore/lite/test/run_benchmark_nets.sh +++ b/mindspore/lite/test/run_benchmark_nets.sh @@ -1548,7 +1548,7 @@ function Run_arm64() { echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test' >> adb_run_cmd.txt if [[ $accuracy_limit == "-1" ]]; then - echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --inputShapes='${input_shapes} >> adb_run_cmd.txt + echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --enableFp16=true --inputShapes='${input_shapes} >> adb_run_cmd.txt else echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold='${accuracy_limit} ' --inputShapes='${input_shapes} >> adb_run_cmd.txt fi