forked from mindspore-Ecosystem/mindspore
!12280 [MSLITE][Develop] optimize arm cpu fp16 op lstm: use matmul
From: @yangruoqi713 Reviewed-by: @zhanghaibo5 Signed-off-by: @zhanghaibo5
This commit is contained in:
commit
347961ee6f
|
@ -18,120 +18,129 @@
|
|||
#include "nnacl/fp16/lstm_fp16.h"
|
||||
#include "nnacl/fp16/activation_fp16.h"
|
||||
#include "nnacl/fp16/arithmetic_fp16.h"
|
||||
#include "nnacl/fp16/matmul_fp16.h"
|
||||
|
||||
void InitGruGateFp16(float16_t *gate_buffer, const float16_t *bias, const GruParameter *gru_parm) {
|
||||
int gate_offest = 0;
|
||||
for (int l = 0; l < 3; l++) {
|
||||
int batch_offest = gate_offest;
|
||||
int bias_offest = l * gru_parm->hidden_size_;
|
||||
for (int b = 0; b < gru_parm->batch_; b++) {
|
||||
memcpy(gate_buffer + batch_offest, bias + bias_offest, gru_parm->hidden_size_ * sizeof(float16_t));
|
||||
batch_offest += gru_parm->hidden_size_;
|
||||
}
|
||||
gate_offest += gru_parm->batch_ * gru_parm->hidden_size_;
|
||||
void UpdateGruInputGateFp16(float16_t *gate_buffer, const float16_t *input, const float16_t *weight,
|
||||
const float16_t *bias, int row, int deep, int col, int col_align, bool is_vec) {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
const float16_t *weight_i = weight + deep * col * i;
|
||||
const float16_t *bias_i = bias + col_align * i;
|
||||
float16_t *gate = gate_buffer + row * col * i;
|
||||
LstmMatMulFp16(gate, input, weight_i, bias_i, row, deep, col, is_vec);
|
||||
}
|
||||
}
|
||||
|
||||
void GruStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_reset_weight,
|
||||
const float16_t *input_update_weight, const float16_t *input_hidden_weight,
|
||||
const float16_t *state_reset_weight, const float16_t *state_update_weight,
|
||||
const float16_t *state_hidden_weight, const float16_t *bias, float16_t *hidden_state,
|
||||
float16_t *gate_buffer, const GruParameter *gru_parm) {
|
||||
InitGruGateFp16(gate_buffer, bias, gru_parm);
|
||||
|
||||
float16_t *update_gate = gate_buffer;
|
||||
float16_t *reset_gate = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_;
|
||||
float16_t *hidden_buffer = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_ * 2;
|
||||
|
||||
void GruStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_weight,
|
||||
const float16_t *state_weight, const float16_t *bias, float16_t *hidden_state,
|
||||
float16_t *gate_buffer, float16_t *matmul_buffer[2], const GruParameter *gru_param) {
|
||||
bool is_vec = gru_param->batch_ == 1;
|
||||
// input * weight
|
||||
MatMulAccFp16(reset_gate, input, input_reset_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
|
||||
MatMulAccFp16(update_gate, input, input_update_weight, gru_parm->batch_, gru_parm->hidden_size_,
|
||||
gru_parm->input_size_);
|
||||
MatMulAccFp16(hidden_buffer, input, input_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_,
|
||||
gru_parm->input_size_);
|
||||
if (is_vec) {
|
||||
UpdateGruInputGateFp16(gate_buffer, input, input_weight, bias, gru_param->batch_, gru_param->input_size_,
|
||||
gru_param->hidden_size_, gru_param->col_align_, is_vec);
|
||||
} else {
|
||||
// pack input for matmul
|
||||
RowMajor2Col16MajorFp16(input, matmul_buffer[0], gru_param->batch_, gru_param->input_size_, false);
|
||||
UpdateGruInputGateFp16(gate_buffer, matmul_buffer[0], input_weight, bias, gru_param->batch_, gru_param->input_size_,
|
||||
gru_param->hidden_size_, gru_param->col_align_, is_vec);
|
||||
}
|
||||
|
||||
const float16_t *state_update_weight = state_weight;
|
||||
const float16_t *state_reset_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_;
|
||||
const float16_t *state_hidden_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_ * 2;
|
||||
float16_t *state_update_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 3;
|
||||
float16_t *state_reset_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 4;
|
||||
float16_t *state_hidden_buffer = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 5;
|
||||
const float16_t *state_update_bias = bias + gru_param->hidden_size_ * 3;
|
||||
const float16_t *state_reset_bias = bias + gru_param->hidden_size_ * 4;
|
||||
const float16_t *state_hidden_bias = bias + gru_param->hidden_size_ * 5;
|
||||
|
||||
// state * weight
|
||||
MatMulAccFp16(reset_gate, hidden_state, state_reset_weight, gru_parm->batch_, gru_parm->hidden_size_,
|
||||
gru_parm->hidden_size_);
|
||||
MatMulAccFp16(update_gate, hidden_state, state_update_weight, gru_parm->batch_, gru_parm->hidden_size_,
|
||||
gru_parm->hidden_size_);
|
||||
if (is_vec) {
|
||||
LstmMatMulFp16(state_reset_gate, hidden_state, state_reset_weight, state_reset_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
LstmMatMulFp16(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
} else {
|
||||
RowMajor2Col16MajorFp16(hidden_state, matmul_buffer[1], gru_param->batch_, gru_param->hidden_size_, false);
|
||||
LstmMatMulFp16(state_reset_gate, matmul_buffer[1], state_reset_weight, state_reset_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
LstmMatMulFp16(state_update_gate, matmul_buffer[1], state_update_weight, state_update_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
}
|
||||
|
||||
ElementAddFp16(gate_buffer, state_update_gate, gate_buffer, gru_param->batch_ * gru_param->hidden_size_ * 2);
|
||||
float16_t *update_gate = gate_buffer;
|
||||
float16_t *reset_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_;
|
||||
float16_t *hidden_buffer = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 2;
|
||||
|
||||
// update reset_gate
|
||||
SigmoidFp16(reset_gate, reset_gate, gru_parm->batch_ * gru_parm->hidden_size_);
|
||||
SigmoidFp16(reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_);
|
||||
|
||||
// update update_gate
|
||||
SigmoidFp16(update_gate, update_gate, gru_parm->batch_ * gru_parm->hidden_size_);
|
||||
SigmoidFp16(update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_);
|
||||
|
||||
ElementMulFp16(hidden_state, reset_gate, reset_gate, gru_parm->batch_ * gru_parm->hidden_size_);
|
||||
MatMulAccFp16(hidden_buffer, reset_gate, state_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_,
|
||||
gru_parm->hidden_size_);
|
||||
ElementMulFp16(hidden_state, reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_);
|
||||
if (is_vec) {
|
||||
LstmMatMulFp16(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
} else {
|
||||
RowMajor2Col16MajorFp16(reset_gate, matmul_buffer[1], gru_param->batch_, gru_param->hidden_size_, false);
|
||||
LstmMatMulFp16(state_hidden_buffer, matmul_buffer[1], state_hidden_weight, state_hidden_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
}
|
||||
ElementAddFp16(hidden_buffer, state_hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_);
|
||||
|
||||
TanhFp16(hidden_buffer, hidden_buffer, gru_parm->batch_ * gru_parm->hidden_size_);
|
||||
TanhFp16(hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_);
|
||||
|
||||
ElementMulFp16(update_gate, hidden_state, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);
|
||||
ElementMulFp16(update_gate, hidden_state, hidden_state, gru_param->batch_ * gru_param->hidden_size_);
|
||||
|
||||
ArithmeticParameter parameter;
|
||||
parameter.in_elements_num0_ = 1;
|
||||
parameter.in_elements_num1_ = gru_parm->batch_ * gru_parm->hidden_size_;
|
||||
parameter.in_elements_num1_ = gru_param->batch_ * gru_param->hidden_size_;
|
||||
float16_t one = 1.0f;
|
||||
ElementOptSubFp16(&one, update_gate, update_gate, gru_parm->batch_ * gru_parm->hidden_size_, ¶meter);
|
||||
ElementOptSubFp16(&one, update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_, ¶meter);
|
||||
|
||||
ElementMulAccFp16(update_gate, hidden_buffer, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);
|
||||
ElementMulAccFp16(update_gate, hidden_buffer, hidden_state, gru_param->batch_ * gru_param->hidden_size_);
|
||||
|
||||
memcpy(output, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_ * sizeof(float16_t));
|
||||
memcpy(output, hidden_state, gru_param->batch_ * gru_param->hidden_size_ * sizeof(float16_t));
|
||||
}
|
||||
|
||||
void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r,
|
||||
const float16_t *bias, float16_t *hidden_state, float16_t *gate_buffer, int check_seq_len,
|
||||
const GruParameter *gru_parm) {
|
||||
const float16_t *bias, float16_t *hidden_state, float16_t *gate_buffer, float16_t *matmul_buffer[2],
|
||||
int check_seq_len, const GruParameter *gru_param) {
|
||||
// forward
|
||||
const float16_t *input_update_weight = weight_g;
|
||||
const float16_t *input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_;
|
||||
const float16_t *input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 2;
|
||||
|
||||
const float16_t *state_update_weight = weight_r;
|
||||
const float16_t *state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_;
|
||||
const float16_t *state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 2;
|
||||
|
||||
for (int t = 0; t < check_seq_len; t++) {
|
||||
const float16_t *input_ptr = input + t * gru_parm->input_step_;
|
||||
float16_t *output_ptr = output + t * gru_parm->output_step_;
|
||||
GruStepUnitFp16(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight,
|
||||
state_reset_weight, state_update_weight, state_hidden_weight, bias, hidden_state, gate_buffer,
|
||||
gru_parm);
|
||||
const float16_t *input_ptr = input + t * gru_param->input_step_;
|
||||
float16_t *output_ptr = output + t * gru_param->output_step_;
|
||||
GruStepUnitFp16(output_ptr, input_ptr, weight_g, weight_r, bias, hidden_state, gate_buffer, matmul_buffer,
|
||||
gru_param);
|
||||
}
|
||||
// zero out extra fw outputs
|
||||
for (int t = check_seq_len; t < gru_parm->seq_len_; t++) {
|
||||
float16_t *output_ptr = output + t * gru_parm->output_step_;
|
||||
for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
|
||||
for (int t = check_seq_len; t < gru_param->seq_len_; t++) {
|
||||
float16_t *output_ptr = output + t * gru_param->output_step_;
|
||||
for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) {
|
||||
output_ptr[i] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// backward
|
||||
if (gru_parm->bidirectional_) {
|
||||
input_update_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 3;
|
||||
input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 4;
|
||||
input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 5;
|
||||
|
||||
state_update_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 3;
|
||||
state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 4;
|
||||
state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 5;
|
||||
|
||||
float16_t *backward_output = output + gru_parm->batch_ * gru_parm->hidden_size_;
|
||||
const float16_t *backward_bias = bias + 3 * gru_parm->hidden_size_;
|
||||
float16_t *backward_hidden_state = hidden_state + gru_parm->batch_ * gru_parm->hidden_size_;
|
||||
if (gru_param->bidirectional_) {
|
||||
const float16_t *backward_weight_g = weight_g + 3 * gru_param->col_align_ * gru_param->input_size_;
|
||||
const float16_t *backward_weight_r = weight_r + 3 * gru_param->col_align_ * gru_param->hidden_size_;
|
||||
const float16_t *backward_bias = bias + 6 * gru_param->hidden_size_;
|
||||
float16_t *backward_output = output + gru_param->batch_ * gru_param->hidden_size_;
|
||||
float16_t *backward_hidden_state = hidden_state + gru_param->batch_ * gru_param->hidden_size_;
|
||||
for (int t = check_seq_len - 1; t >= 0; t--) {
|
||||
const float16_t *input_ptr = input + t * gru_parm->input_step_;
|
||||
float16_t *output_ptr = backward_output + t * gru_parm->output_step_;
|
||||
GruStepUnitFp16(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight,
|
||||
state_reset_weight, state_update_weight, state_hidden_weight, backward_bias,
|
||||
backward_hidden_state, gate_buffer, gru_parm);
|
||||
const float16_t *input_ptr = input + t * gru_param->input_step_;
|
||||
float16_t *output_ptr = backward_output + t * gru_param->output_step_;
|
||||
GruStepUnitFp16(output_ptr, input_ptr, backward_weight_g, backward_weight_r, backward_bias, backward_hidden_state,
|
||||
gate_buffer, matmul_buffer, gru_param);
|
||||
}
|
||||
// zero out extra bw outputs
|
||||
for (int t = gru_parm->seq_len_ - 1; t >= check_seq_len; t--) {
|
||||
float16_t *output_ptr = backward_output + t * gru_parm->output_step_;
|
||||
for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
|
||||
for (int t = gru_param->seq_len_ - 1; t >= check_seq_len; t--) {
|
||||
float16_t *output_ptr = backward_output + t * gru_param->output_step_;
|
||||
for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) {
|
||||
output_ptr[i] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,8 +21,8 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r,
|
||||
const float16_t *bias, float16_t *hidden_state, float16_t *gate_buffer, int check_seq_len,
|
||||
const GruParameter *gru_parm);
|
||||
const float16_t *bias, float16_t *hidden_state, float16_t *gate_buffer, float16_t *matmul_buffer[2],
|
||||
int check_seq_len, const GruParameter *gru_param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -18,17 +18,21 @@
|
|||
#include <string.h>
|
||||
#include "nnacl/fp16/activation_fp16.h"
|
||||
#include "nnacl/fp16/arithmetic_fp16.h"
|
||||
#include "nnacl/fp16/matmul_fp16.h"
|
||||
|
||||
void InitGateFp16(float16_t *gate_buffer, const float16_t *bias, const LstmParameter *lstm_parm) {
|
||||
int gate_offest = 0;
|
||||
for (int l = 0; l < 4; l++) {
|
||||
int batch_offest = gate_offest;
|
||||
int bias_offest = l * lstm_parm->hidden_size_;
|
||||
for (int b = 0; b < lstm_parm->batch_; b++) {
|
||||
memcpy(gate_buffer + batch_offest, bias + bias_offest, lstm_parm->hidden_size_ * sizeof(float16_t));
|
||||
batch_offest += lstm_parm->hidden_size_;
|
||||
}
|
||||
gate_offest += lstm_parm->batch_ * lstm_parm->hidden_size_;
|
||||
void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align) {
|
||||
for (int i = 0; i < batch; i++) {
|
||||
const float *src_batch = src + i * col * deep;
|
||||
float16_t *dst_batch = dst + i * col_align * deep;
|
||||
RowMajor2Col8MajorFp16(src_batch, dst_batch, col, deep, true);
|
||||
}
|
||||
}
|
||||
|
||||
void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align) {
|
||||
for (int i = 0; i < batch; i++) {
|
||||
const float16_t *src_batch = src + i * col * deep;
|
||||
float16_t *dst_batch = dst + i * col_align * deep;
|
||||
RowMajor2Col8MajorFp16(src_batch, dst_batch, col, deep, false);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -125,111 +129,111 @@ void UpdataOutputFp16(const float16_t *cell_state, float16_t *output_gate, float
|
|||
}
|
||||
}
|
||||
|
||||
void LstmStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_input_weight,
|
||||
const float16_t *input_forget_weight, const float16_t *input_cell_weight,
|
||||
const float16_t *input_output_weight, const float16_t *state_input_weight,
|
||||
const float16_t *state_forget_weight, const float16_t *state_cell_weight,
|
||||
const float16_t *state_output_weight, const float16_t *bias, float16_t *hidden_state,
|
||||
void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep,
|
||||
int col, bool is_vec) {
|
||||
if (is_vec) {
|
||||
memcpy(c, bias, col * sizeof(float16_t));
|
||||
MatMulAccFp16(c, a, b, row, col, deep);
|
||||
} else {
|
||||
MatMulFp16(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateLstmGateFp16(float16_t *gate_buffer, const float16_t *input, const float16_t *weight, const float16_t *bias,
|
||||
int row, int deep, int col, int col_align, bool is_vec) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
const float16_t *weight_i = weight + deep * col * i;
|
||||
const float16_t *bias_i = bias + col_align * i;
|
||||
float16_t *gate = gate_buffer + row * col * i;
|
||||
LstmMatMulFp16(gate, input, weight_i, bias_i, row, deep, col, is_vec);
|
||||
}
|
||||
}
|
||||
|
||||
void LstmStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_weight,
|
||||
const float16_t *state_weight, const float16_t *bias, float16_t *hidden_state,
|
||||
float16_t *cell_state, float16_t *gate_buffer, float16_t *state_buffer,
|
||||
const LstmParameter *lstm_parm) {
|
||||
InitGateFp16(gate_buffer, bias, lstm_parm);
|
||||
|
||||
float16_t *input_gate = gate_buffer;
|
||||
float16_t *forget_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 2;
|
||||
float16_t *cell_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 3;
|
||||
float16_t *output_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 1;
|
||||
|
||||
float16_t *matmul_buffer[2], const LstmParameter *lstm_param) {
|
||||
bool is_vec = lstm_param->batch_ == 1;
|
||||
// input * weight
|
||||
MatMulAccFp16(input_gate, input, input_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
|
||||
lstm_parm->input_size_);
|
||||
MatMulAccFp16(forget_gate, input, input_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
|
||||
lstm_parm->input_size_);
|
||||
MatMulAccFp16(cell_gate, input, input_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
|
||||
lstm_parm->input_size_);
|
||||
MatMulAccFp16(output_gate, input, input_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
|
||||
lstm_parm->input_size_);
|
||||
if (is_vec) {
|
||||
UpdateLstmGateFp16(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
|
||||
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
|
||||
} else {
|
||||
// pack input for matmul
|
||||
RowMajor2Col16MajorFp16(input, matmul_buffer[0], lstm_param->batch_, lstm_param->input_size_, false);
|
||||
UpdateLstmGateFp16(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
|
||||
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
|
||||
}
|
||||
|
||||
// state * weight
|
||||
MatMulAccFp16(input_gate, hidden_state, state_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
|
||||
lstm_parm->hidden_size_);
|
||||
MatMulAccFp16(forget_gate, hidden_state, state_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
|
||||
lstm_parm->hidden_size_);
|
||||
MatMulAccFp16(cell_gate, hidden_state, state_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
|
||||
lstm_parm->hidden_size_);
|
||||
MatMulAccFp16(output_gate, hidden_state, state_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
|
||||
lstm_parm->hidden_size_);
|
||||
float16_t *state_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 4;
|
||||
const float16_t *state_bias = bias + lstm_param->col_align_ * 4;
|
||||
if (is_vec) {
|
||||
UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
|
||||
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
|
||||
} else {
|
||||
// pack state for matmul
|
||||
RowMajor2Col16MajorFp16(hidden_state, matmul_buffer[1], lstm_param->batch_, lstm_param->hidden_size_, false);
|
||||
UpdateLstmGateFp16(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_,
|
||||
lstm_param->hidden_size_, lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
|
||||
}
|
||||
ElementAddFp16(gate_buffer, state_gate, gate_buffer, 4 * lstm_param->batch_ * lstm_param->hidden_size_);
|
||||
|
||||
float16_t *input_gate = gate_buffer;
|
||||
float16_t *forget_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 2;
|
||||
float16_t *cell_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 3;
|
||||
float16_t *output_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_;
|
||||
// update input_gate
|
||||
SigmoidFp16(input_gate, input_gate, lstm_parm->batch_ * lstm_parm->hidden_size_);
|
||||
SigmoidFp16(input_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_);
|
||||
|
||||
// update forget_gate
|
||||
SigmoidFp16(forget_gate, forget_gate, lstm_parm->batch_ * lstm_parm->hidden_size_);
|
||||
SigmoidFp16(forget_gate, forget_gate, lstm_param->batch_ * lstm_param->hidden_size_);
|
||||
|
||||
// update cell_gate
|
||||
TanhFp16(cell_gate, cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_);
|
||||
TanhFp16(cell_gate, cell_gate, lstm_param->batch_ * lstm_param->hidden_size_);
|
||||
// update cell state
|
||||
UpdataStateFp16(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_parm->batch_,
|
||||
lstm_parm->hidden_size_, lstm_parm->smooth_);
|
||||
UpdataStateFp16(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_param->batch_,
|
||||
lstm_param->hidden_size_, lstm_param->smooth_);
|
||||
|
||||
// update output_gate
|
||||
SigmoidFp16(output_gate, output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_);
|
||||
SigmoidFp16(output_gate, output_gate, lstm_param->batch_ * lstm_param->hidden_size_);
|
||||
// update output
|
||||
UpdataOutputFp16(cell_state, output_gate, hidden_state, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_,
|
||||
lstm_parm->smooth_);
|
||||
memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float16_t));
|
||||
UpdataOutputFp16(cell_state, output_gate, hidden_state, state_buffer, lstm_param->batch_, lstm_param->hidden_size_,
|
||||
lstm_param->smooth_);
|
||||
memcpy(output, hidden_state, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t));
|
||||
|
||||
if (!(lstm_parm->smooth_ >= -FLT_EPSILON && lstm_parm->smooth_ <= FLT_EPSILON)) {
|
||||
memcpy(cell_state, state_buffer, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float16_t));
|
||||
memcpy(hidden_state, state_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_,
|
||||
lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float16_t));
|
||||
if (!(lstm_param->smooth_ >= -FLT_EPSILON && lstm_param->smooth_ <= FLT_EPSILON)) {
|
||||
memcpy(cell_state, state_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t));
|
||||
memcpy(hidden_state, state_buffer + lstm_param->batch_ * lstm_param->hidden_size_,
|
||||
lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t));
|
||||
}
|
||||
}
|
||||
|
||||
void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h,
|
||||
const float16_t *bias, float16_t *hidden_state, float16_t *cell_state, float16_t *gate_buffer,
|
||||
float16_t *state_buffer, const LstmParameter *lstm_parm) {
|
||||
float16_t *state_buffer, float16_t *matmul_buffer[2], const LstmParameter *lstm_param) {
|
||||
// forward
|
||||
const float16_t *input_input_weight = weight_i;
|
||||
const float16_t *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2;
|
||||
const float16_t *input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 3;
|
||||
const float16_t *input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 1;
|
||||
|
||||
const float16_t *state_input_weight = weight_h;
|
||||
const float16_t *state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 2;
|
||||
const float16_t *state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 3;
|
||||
const float16_t *state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 1;
|
||||
|
||||
for (int t = 0; t < lstm_parm->seq_len_; t++) {
|
||||
const float16_t *input_ptr = input + t * lstm_parm->input_step_;
|
||||
float16_t *output_ptr = output + t * lstm_parm->output_step_;
|
||||
LstmStepUnitFp16(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight,
|
||||
input_output_weight, state_input_weight, state_forget_weight, state_cell_weight,
|
||||
state_output_weight, bias, hidden_state, cell_state, gate_buffer, state_buffer, lstm_parm);
|
||||
for (int t = 0; t < lstm_param->seq_len_; t++) {
|
||||
const float16_t *input_ptr = input + t * lstm_param->input_step_;
|
||||
float16_t *output_ptr = output + t * lstm_param->output_step_;
|
||||
LstmStepUnitFp16(output_ptr, input_ptr, weight_i, weight_h, bias, hidden_state, cell_state, gate_buffer,
|
||||
state_buffer, matmul_buffer, lstm_param);
|
||||
}
|
||||
|
||||
// backward
|
||||
if (lstm_parm->bidirectional_) {
|
||||
input_input_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 4;
|
||||
input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 6;
|
||||
input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 7;
|
||||
input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 5;
|
||||
|
||||
state_input_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 4;
|
||||
state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 6;
|
||||
state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 7;
|
||||
state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 5;
|
||||
|
||||
float16_t *backward_output = output + lstm_parm->batch_ * lstm_parm->hidden_size_;
|
||||
const float16_t *backward_bias = bias + 4 * lstm_parm->hidden_size_;
|
||||
float16_t *backward_cell_state = cell_state + lstm_parm->batch_ * lstm_parm->hidden_size_;
|
||||
float16_t *backward_hidden_state = hidden_state + lstm_parm->batch_ * lstm_parm->hidden_size_;
|
||||
for (int t = lstm_parm->seq_len_ - 1; t >= 0; t--) {
|
||||
const float16_t *input_ptr = input + t * lstm_parm->input_step_;
|
||||
float16_t *output_ptr = backward_output + t * lstm_parm->output_step_;
|
||||
LstmStepUnitFp16(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight,
|
||||
input_output_weight, state_input_weight, state_forget_weight, state_cell_weight,
|
||||
state_output_weight, backward_bias, backward_hidden_state, backward_cell_state, gate_buffer,
|
||||
state_buffer, lstm_parm);
|
||||
if (lstm_param->bidirectional_) {
|
||||
const float16_t *backward_weight_i = weight_i + 4 * lstm_param->col_align_ * lstm_param->input_size_;
|
||||
const float16_t *backward_weight_h = weight_h + 4 * lstm_param->col_align_ * lstm_param->hidden_size_;
|
||||
const float16_t *backward_bias = bias + 8 * lstm_param->col_align_;
|
||||
float16_t *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_;
|
||||
float16_t *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_;
|
||||
float16_t *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_;
|
||||
for (int t = lstm_param->seq_len_ - 1; t >= 0; t--) {
|
||||
const float16_t *input_ptr = input + t * lstm_param->input_step_;
|
||||
float16_t *output_ptr = backward_output + t * lstm_param->output_step_;
|
||||
LstmStepUnitFp16(output_ptr, input_ptr, backward_weight_i, backward_weight_h, backward_bias,
|
||||
backward_hidden_state, backward_cell_state, gate_buffer, state_buffer, matmul_buffer,
|
||||
lstm_param);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,13 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align);
|
||||
|
||||
void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align);
|
||||
|
||||
void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep,
|
||||
int col, bool is_vec);
|
||||
|
||||
void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols,
|
||||
int inner_size);
|
||||
|
||||
|
@ -30,7 +37,7 @@ int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float1
|
|||
|
||||
void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h,
|
||||
const float16_t *bias, float16_t *hidden_state, float16_t *cell_state, float16_t *gate_buffer,
|
||||
float16_t *state_buffer, const LstmParameter *lstm_parm);
|
||||
float16_t *state_buffer, float16_t *matmul_buffer[2], const LstmParameter *lstm_param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -40,7 +40,7 @@ void GruStepUnit(float *output, const float *input, const float *input_weight, c
|
|||
gru_param->hidden_size_, gru_param->col_align_, is_vec);
|
||||
} else {
|
||||
// pack input for matmul
|
||||
PackLstmInput(matmul_buffer[0], input, gru_param->batch_, gru_param->input_size_);
|
||||
PackLstmInput(input, matmul_buffer[0], gru_param->batch_, gru_param->input_size_);
|
||||
UpdateGruInputGate(gate_buffer, matmul_buffer[0], input_weight, bias, gru_param->batch_, gru_param->input_size_,
|
||||
gru_param->hidden_size_, gru_param->col_align_, is_vec);
|
||||
}
|
||||
|
@ -62,7 +62,7 @@ void GruStepUnit(float *output, const float *input, const float *input_weight, c
|
|||
LstmMatMul(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
} else {
|
||||
PackLstmInput(matmul_buffer[1], hidden_state, gru_param->batch_, gru_param->hidden_size_);
|
||||
PackLstmInput(hidden_state, matmul_buffer[1], gru_param->batch_, gru_param->hidden_size_);
|
||||
LstmMatMul(state_reset_gate, matmul_buffer[1], state_reset_weight, state_reset_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
LstmMatMul(state_update_gate, matmul_buffer[1], state_update_weight, state_update_bias, gru_param->batch_,
|
||||
|
@ -83,7 +83,7 @@ void GruStepUnit(float *output, const float *input, const float *input_weight, c
|
|||
LstmMatMul(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
} else {
|
||||
PackLstmInput(matmul_buffer[1], reset_gate, gru_param->batch_, gru_param->hidden_size_);
|
||||
PackLstmInput(reset_gate, matmul_buffer[1], gru_param->batch_, gru_param->hidden_size_);
|
||||
LstmMatMul(state_hidden_buffer, matmul_buffer[1], state_hidden_weight, state_hidden_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col,
|
|||
}
|
||||
}
|
||||
|
||||
void PackLstmInput(float *dst, const float *src, int row, int deep) {
|
||||
void PackLstmInput(const float *src, float *dst, int row, int deep) {
|
||||
#ifdef ENABLE_AVX
|
||||
RowMajor2Col6Major(src, dst, row, deep);
|
||||
#elif defined(ENABLE_SSE)
|
||||
|
@ -174,7 +174,7 @@ void LstmStepUnit(float *output, const float *input, const float *input_weight,
|
|||
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
|
||||
} else {
|
||||
// pack input for matmul
|
||||
PackLstmInput(matmul_buffer[0], input, lstm_param->batch_, lstm_param->input_size_);
|
||||
PackLstmInput(input, matmul_buffer[0], lstm_param->batch_, lstm_param->input_size_);
|
||||
UpdateLstmGate(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
|
||||
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
|
||||
}
|
||||
|
@ -187,7 +187,7 @@ void LstmStepUnit(float *output, const float *input, const float *input_weight,
|
|||
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
|
||||
} else {
|
||||
// pack state for matmul
|
||||
PackLstmInput(matmul_buffer[1], hidden_state, lstm_param->batch_, lstm_param->hidden_size_);
|
||||
PackLstmInput(hidden_state, matmul_buffer[1], lstm_param->batch_, lstm_param->hidden_size_);
|
||||
UpdateLstmGate(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
|
||||
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
|
||||
}
|
||||
|
@ -238,7 +238,7 @@ void Lstm(float *output, const float *input, const float *weight_i, const float
|
|||
if (lstm_param->bidirectional_) {
|
||||
const float *backward_weight_i = weight_i + 4 * lstm_param->col_align_ * lstm_param->input_size_;
|
||||
const float *backward_weight_h = weight_h + 4 * lstm_param->col_align_ * lstm_param->hidden_size_;
|
||||
const float *backward_bias = bias + 8 * lstm_param->hidden_size_;
|
||||
const float *backward_bias = bias + 8 * lstm_param->col_align_;
|
||||
float *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_;
|
||||
float *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_;
|
||||
float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_;
|
||||
|
|
|
@ -23,7 +23,7 @@ extern "C" {
|
|||
#endif
|
||||
void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align);
|
||||
|
||||
void PackLstmInput(float *dst, const float *src, int row, int deep);
|
||||
void PackLstmInput(const float *src, float *dst, int row, int deep);
|
||||
|
||||
void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec);
|
||||
|
||||
|
@ -33,7 +33,7 @@ int ElementOptMulAcc(const float *input0, const float input1, float *output, con
|
|||
|
||||
void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
|
||||
float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, float *matmul_buffer[2],
|
||||
const LstmParameter *lstm_parm);
|
||||
const LstmParameter *lstm_param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "nnacl/fp16/gru_fp16.h"
|
||||
#include "nnacl/fp16/cast_fp16.h"
|
||||
#include "nnacl/fp16/lstm_fp16.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
@ -28,21 +30,32 @@ using mindspore::schema::PrimitiveType_Gru;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
void GruFp16CPUKernel::FreeTmpBuffer() {
|
||||
if (gate_buffer_ != nullptr) {
|
||||
free(gate_buffer_);
|
||||
gate_buffer_ = nullptr;
|
||||
if (!is_vec_ || in_tensors_[1]->data_type() == kNumberTypeFloat32) {
|
||||
if (weight_g_ptr_ != nullptr) {
|
||||
free(weight_g_ptr_);
|
||||
weight_g_ptr_ = nullptr;
|
||||
}
|
||||
}
|
||||
if (bias_ptr_ != nullptr) {
|
||||
free(bias_ptr_);
|
||||
bias_ptr_ = nullptr;
|
||||
if (!is_vec_ || in_tensors_[2]->data_type() == kNumberTypeFloat32) {
|
||||
if (weight_r_ptr_ != nullptr) {
|
||||
free(weight_r_ptr_);
|
||||
weight_r_ptr_ = nullptr;
|
||||
}
|
||||
}
|
||||
if (weight_g_ptr_ != nullptr) {
|
||||
free(weight_g_ptr_);
|
||||
weight_g_ptr_ = nullptr;
|
||||
if (!is_vec_ || in_tensors_[3]->data_type() == kNumberTypeFloat32) {
|
||||
if (bias_ptr_ != nullptr) {
|
||||
free(bias_ptr_);
|
||||
bias_ptr_ = nullptr;
|
||||
}
|
||||
}
|
||||
if (weight_r_ptr_ != nullptr) {
|
||||
free(weight_r_ptr_);
|
||||
weight_r_ptr_ = nullptr;
|
||||
}
|
||||
|
||||
void GruFp16CPUKernel::FreeRunBuffer() {
|
||||
context_->allocator->Free(gate_buffer_);
|
||||
if (!is_vec_) {
|
||||
for (int i = 0; i < 2; i++) {
|
||||
context_->allocator->Free(matmul_buffer_[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -50,75 +63,115 @@ int GruFp16CPUKernel::InitParam() {
|
|||
auto input = in_tensors_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
std::vector<int> in_shape = input->shape();
|
||||
gru_parm_->seq_len_ = in_shape.at(0);
|
||||
gru_parm_->batch_ = in_shape.at(1);
|
||||
gru_parm_->input_size_ = in_shape.at(2);
|
||||
gru_param_->seq_len_ = in_shape.at(0);
|
||||
gru_param_->batch_ = in_shape.at(1);
|
||||
gru_param_->input_size_ = in_shape.at(2);
|
||||
|
||||
auto weight_g = in_tensors_.at(1);
|
||||
MS_ASSERT(weight_g != nullptr);
|
||||
std::vector<int> w_shape = weight_g->shape();
|
||||
gru_parm_->hidden_size_ = w_shape.at(1) / 3;
|
||||
gru_param_->hidden_size_ = w_shape.at(1) / 3;
|
||||
|
||||
gru_parm_->input_step_ = gru_parm_->batch_ * gru_parm_->input_size_;
|
||||
gru_parm_->output_step_ = gru_parm_->bidirectional_ ? 2 * gru_parm_->batch_ * gru_parm_->hidden_size_
|
||||
: gru_parm_->batch_ * gru_parm_->hidden_size_;
|
||||
gru_param_->input_step_ = gru_param_->batch_ * gru_param_->input_size_;
|
||||
gru_param_->output_step_ = gru_param_->bidirectional_ ? 2 * gru_param_->batch_ * gru_param_->hidden_size_
|
||||
: gru_param_->batch_ * gru_param_->hidden_size_;
|
||||
|
||||
is_vec_ = gru_param_->batch_ == 1;
|
||||
gru_param_->row_align_ = is_vec_ ? gru_param_->batch_ : UP_ROUND(gru_param_->batch_, C16NUM);
|
||||
gru_param_->col_align_ = is_vec_ ? gru_param_->hidden_size_ : UP_ROUND(gru_param_->hidden_size_, C8NUM);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int GruFp16CPUKernel::InitBuffer() {
|
||||
gate_buffer_ =
|
||||
reinterpret_cast<float16_t *>(malloc(3 * gru_parm_->batch_ * gru_parm_->hidden_size_ * sizeof(float16_t)));
|
||||
if (gate_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc gate_buffer error.";
|
||||
int GruFp16CPUKernel::InitWeight(const lite::Tensor *tensor, float16_t *ptr, int deep) {
|
||||
auto weight_batch = gru_param_->bidirectional_ ? 6 : 3;
|
||||
if (tensor->data_type() == kNumberTypeFloat32) {
|
||||
auto weight_data = reinterpret_cast<float *>(tensor->data_c());
|
||||
is_vec_ ? Float32ToFloat16(weight_data, ptr, tensor->ElementsNum())
|
||||
: PackLstmWeightFp32ToFp16(ptr, weight_data, weight_batch, deep, gru_param_->hidden_size_,
|
||||
gru_param_->col_align_);
|
||||
} else if (tensor->data_type() == kNumberTypeFloat16) {
|
||||
auto weight_data = reinterpret_cast<float16_t *>(tensor->data_c());
|
||||
if (is_vec_) {
|
||||
ptr = weight_data;
|
||||
} else {
|
||||
PackLstmWeightFp16(ptr, weight_data, weight_batch, deep, gru_param_->hidden_size_, gru_param_->col_align_);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type of weight tensor for lstm.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int GruFp16CPUKernel::InitWeightBias() {
|
||||
auto weight_gate = in_tensors_.at(1);
|
||||
MS_ASSERT(weight_gate != nullptr);
|
||||
weight_g_ptr_ = reinterpret_cast<float16_t *>(malloc(weight_gate->ElementsNum() * sizeof(float16_t)));
|
||||
if (weight_g_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_g_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto weight_g_data = reinterpret_cast<float *>(weight_gate->data_c());
|
||||
for (size_t i = 0; i < weight_gate->ElementsNum(); i++) {
|
||||
weight_g_ptr_[i] = (float16_t)weight_g_data[i];
|
||||
}
|
||||
|
||||
auto weight_recu = in_tensors_.at(2);
|
||||
MS_ASSERT(weight_recu != nullptr);
|
||||
weight_r_ptr_ = reinterpret_cast<float16_t *>(malloc(weight_recu->ElementsNum() * sizeof(float16_t)));
|
||||
if (weight_r_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_r_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto weight_r_data = reinterpret_cast<float *>(weight_recu->data_c());
|
||||
for (size_t i = 0; i < weight_recu->ElementsNum(); i++) {
|
||||
weight_r_ptr_[i] = (float16_t)weight_r_data[i];
|
||||
}
|
||||
|
||||
int bias_num = gru_parm_->bidirectional_ ? 2 * 3 * gru_parm_->hidden_size_ : 3 * gru_parm_->hidden_size_;
|
||||
bias_ptr_ = reinterpret_cast<float16_t *>(malloc(bias_num * sizeof(float16_t)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc bias_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
|
||||
const int state_bias_offset = 3 * gru_parm_->hidden_size_;
|
||||
for (int i = 0; i < state_bias_offset; i++) {
|
||||
bias_ptr_[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]);
|
||||
}
|
||||
if (gru_parm_->bidirectional_) {
|
||||
bias_data += 3 * gru_parm_->hidden_size_ * 2;
|
||||
auto backward_bias = bias_ptr_ + 3 * gru_parm_->hidden_size_;
|
||||
for (int i = 0; i < state_bias_offset; i++) {
|
||||
backward_bias[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]);
|
||||
auto weight_batch = gru_param_->bidirectional_ ? 6 : 3;
|
||||
// malloc and init input * weight right matrix buffer
|
||||
auto weight_g = in_tensors_.at(1);
|
||||
MS_ASSERT(weight_g != nullptr);
|
||||
if (!is_vec_ || weight_g->data_type() == kNumberTypeFloat32) {
|
||||
weight_g_ptr_ = reinterpret_cast<float16_t *>(
|
||||
malloc(weight_batch * gru_param_->col_align_ * gru_param_->input_size_ * sizeof(float16_t)));
|
||||
if (weight_g_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_g_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
auto ret = InitWeight(weight_g, weight_g_ptr_, gru_param_->input_size_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel init weight_g failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// malloc and init state * weight right matrix buffer
|
||||
auto weight_r = in_tensors_.at(2);
|
||||
MS_ASSERT(weight_r != nullptr);
|
||||
if (!is_vec_ || weight_r->data_type() == kNumberTypeFloat32) {
|
||||
weight_r_ptr_ = reinterpret_cast<float16_t *>(
|
||||
malloc(weight_batch * gru_param_->col_align_ * gru_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (weight_r_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_r_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
ret = InitWeight(weight_r, weight_r_ptr_, gru_param_->hidden_size_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel init weight_r failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
int bias_batch = gru_param_->bidirectional_ ? 12 : 6;
|
||||
auto bias = in_tensors_.at(3);
|
||||
MS_ASSERT(bias != nullptr);
|
||||
if (!is_vec_ || bias->data_type() == kNumberTypeFloat32) {
|
||||
bias_ptr_ = reinterpret_cast<float16_t *>(malloc(bias_batch * gru_param_->col_align_ * sizeof(float16_t)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc bias_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(bias_ptr_, 0, bias_batch * gru_param_->col_align_ * sizeof(float16_t));
|
||||
}
|
||||
if (bias->data_type() == kNumberTypeFloat32) {
|
||||
auto bias_data = reinterpret_cast<float *>(bias->data_c());
|
||||
for (int i = 0; i < bias_batch; i++) {
|
||||
auto src_batch = bias_data + i * gru_param_->hidden_size_;
|
||||
auto dst_batch = bias_ptr_ + i * gru_param_->col_align_;
|
||||
Float32ToFloat16(src_batch, dst_batch, gru_param_->hidden_size_);
|
||||
}
|
||||
} else if (bias->data_type() == kNumberTypeFloat16) {
|
||||
auto bias_data = reinterpret_cast<float16_t *>(bias->data_c());
|
||||
if (is_vec_) {
|
||||
bias_ptr_ = bias_data;
|
||||
} else {
|
||||
for (int i = 0; i < bias_batch; i++) {
|
||||
auto src_batch = bias_data + i * gru_param_->hidden_size_;
|
||||
auto dst_batch = bias_ptr_ + i * gru_param_->col_align_;
|
||||
memcpy(dst_batch, src_batch, gru_param_->hidden_size_ * sizeof(float16_t));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -130,24 +183,43 @@ int GruFp16CPUKernel::Init() {
|
|||
}
|
||||
|
||||
int GruFp16CPUKernel::ReSize() {
|
||||
FreeTmpBuffer();
|
||||
auto ret = InitParam();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel InitParam error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
FreeTmpBuffer();
|
||||
ret = InitWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel InitWeightBias error.";
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
ret = InitBuffer();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel InitBuffer error.";
|
||||
FreeTmpBuffer();
|
||||
int GruFp16CPUKernel::MallocRunBuffer() {
|
||||
if (!is_vec_) {
|
||||
matmul_buffer_[0] = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(3 * gru_param_->row_align_ * gru_param_->input_size_ * sizeof(float16_t)));
|
||||
if (matmul_buffer_[0] == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc input * weight left matirx error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
matmul_buffer_[1] = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(3 * gru_param_->row_align_ * gru_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (matmul_buffer_[1] == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc state * weight left matirx error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
gate_buffer_ = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(4 * gru_param_->batch_ * gru_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (gate_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc gate_buffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
|
@ -166,22 +238,29 @@ int GruFp16CPUKernel::Run() {
|
|||
MS_ASSERT(output_ptr);
|
||||
auto output_hidden_state = out_tensors_[1];
|
||||
memcpy(output_hidden_state->data_c(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float16_t));
|
||||
int check_seq_len = gru_parm_->seq_len_;
|
||||
int check_seq_len = gru_param_->seq_len_;
|
||||
if (in_tensors_.size() == 6) {
|
||||
auto seq_len = reinterpret_cast<int *>(in_tensors_.at(5)->data_c());
|
||||
if (!std::equal(seq_len + 1, seq_len + gru_parm_->batch_, seq_len)) {
|
||||
if (!std::equal(seq_len + 1, seq_len + gru_param_->batch_, seq_len)) {
|
||||
MS_LOG(ERROR) << "different batch seq_len is currently not supported";
|
||||
return RET_ERROR;
|
||||
}
|
||||
check_seq_len = MSMIN(check_seq_len, MSMAX(0, seq_len[0]));
|
||||
}
|
||||
|
||||
auto ret = MallocRunBuffer();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel MallocRunBuffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_ASSERT(weight_g_ptr_ != nullptr);
|
||||
MS_ASSERT(weight_r_ptr_ != nullptr);
|
||||
MS_ASSERT(bias_ptr_ != nullptr);
|
||||
MS_ASSERT(gate_buffer_ != nullptr);
|
||||
GruFp16(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_,
|
||||
reinterpret_cast<float16_t *>(output_hidden_state->data_c()), gate_buffer_, check_seq_len, gru_parm_);
|
||||
reinterpret_cast<float16_t *>(output_hidden_state->data_c()), gate_buffer_, matmul_buffer_, check_seq_len,
|
||||
gru_param_);
|
||||
FreeRunBuffer();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ class GruFp16CPUKernel : public LiteKernel {
|
|||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||
gru_parm_ = reinterpret_cast<GruParameter *>(op_parameter_);
|
||||
gru_param_ = reinterpret_cast<GruParameter *>(op_parameter_);
|
||||
}
|
||||
|
||||
~GruFp16CPUKernel() override { FreeTmpBuffer(); }
|
||||
|
@ -37,15 +37,19 @@ class GruFp16CPUKernel : public LiteKernel {
|
|||
|
||||
private:
|
||||
void FreeTmpBuffer();
|
||||
void FreeRunBuffer();
|
||||
int InitParam();
|
||||
int InitBuffer();
|
||||
int InitWeight(const lite::Tensor *tensor, float16_t *ptr, int deep);
|
||||
int InitWeightBias();
|
||||
int MallocRunBuffer();
|
||||
|
||||
float16_t *gate_buffer_ = nullptr;
|
||||
float16_t *weight_g_ptr_ = nullptr;
|
||||
float16_t *weight_r_ptr_ = nullptr;
|
||||
float16_t *bias_ptr_ = nullptr;
|
||||
GruParameter *gru_parm_ = nullptr;
|
||||
float16_t *matmul_buffer_[2];
|
||||
bool is_vec_ = false;
|
||||
GruParameter *gru_param_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "nnacl/fp16/lstm_fp16.h"
|
||||
#include "nnacl/fp16/cast_fp16.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
@ -29,25 +30,33 @@ using mindspore::schema::PrimitiveType_Lstm;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
void LstmFp16CPUKernel::FreeTmpBuffer() {
|
||||
if (gate_buffer_ != nullptr) {
|
||||
free(gate_buffer_);
|
||||
gate_buffer_ = nullptr;
|
||||
if (!is_vec_ || in_tensors_[1]->data_type() == kNumberTypeFloat32) {
|
||||
if (weight_i_ptr_ != nullptr) {
|
||||
free(weight_i_ptr_);
|
||||
weight_i_ptr_ = nullptr;
|
||||
}
|
||||
}
|
||||
if (state_buffer_ != nullptr) {
|
||||
free(state_buffer_);
|
||||
state_buffer_ = nullptr;
|
||||
if (!is_vec_ || in_tensors_[2]->data_type() == kNumberTypeFloat32) {
|
||||
if (weight_h_ptr_ != nullptr) {
|
||||
free(weight_h_ptr_);
|
||||
weight_h_ptr_ = nullptr;
|
||||
}
|
||||
}
|
||||
if (weight_i_ptr_ != nullptr) {
|
||||
free(weight_i_ptr_);
|
||||
weight_i_ptr_ = nullptr;
|
||||
if (!is_vec_ || in_tensors_[3]->data_type() == kNumberTypeFloat32) {
|
||||
if (bias_ptr_ != nullptr) {
|
||||
free(bias_ptr_);
|
||||
bias_ptr_ = nullptr;
|
||||
}
|
||||
}
|
||||
if (weight_h_ptr_ != nullptr) {
|
||||
free(weight_h_ptr_);
|
||||
weight_h_ptr_ = nullptr;
|
||||
}
|
||||
if (bias_ptr_ != nullptr) {
|
||||
free(bias_ptr_);
|
||||
bias_ptr_ = nullptr;
|
||||
}
|
||||
|
||||
void LstmFp16CPUKernel::FreeRunBuffer() {
|
||||
context_->allocator->Free(gate_buffer_);
|
||||
context_->allocator->Free(state_buffer_);
|
||||
if (!is_vec_) {
|
||||
for (int i = 0; i < 2; i++) {
|
||||
context_->allocator->Free(matmul_buffer_[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -67,87 +76,107 @@ int LstmFp16CPUKernel::InitParam() {
|
|||
lstm_param_->input_step_ = lstm_param_->batch_ * lstm_param_->input_size_;
|
||||
lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_
|
||||
: lstm_param_->batch_ * lstm_param_->hidden_size_;
|
||||
|
||||
is_vec_ = lstm_param_->batch_ == 1;
|
||||
lstm_param_->row_align_ = is_vec_ ? lstm_param_->batch_ : UP_ROUND(lstm_param_->batch_, C16NUM);
|
||||
lstm_param_->col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, C8NUM);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmFp16CPUKernel::InitBuffer() {
|
||||
gate_buffer_ =
|
||||
reinterpret_cast<float16_t *>(malloc(4 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (gate_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Lstm fp16 malloc gate_buffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!(lstm_param_->smooth_ >= -FLT_EPSILON && lstm_param_->smooth_ <= FLT_EPSILON)) {
|
||||
int buffer_size = 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t);
|
||||
state_buffer_ = reinterpret_cast<float16_t *>(malloc(buffer_size));
|
||||
if (state_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Lstm fp16 malloc state_buffer error.";
|
||||
return RET_ERROR;
|
||||
int LstmFp16CPUKernel::InitWeight(const lite::Tensor *tensor, float16_t *ptr, int deep) {
|
||||
auto weight_batch = lstm_param_->bidirectional_ ? 8 : 4;
|
||||
if (tensor->data_type() == kNumberTypeFloat32) {
|
||||
auto weight_data = reinterpret_cast<float *>(tensor->data_c());
|
||||
is_vec_ ? Float32ToFloat16(weight_data, ptr, tensor->ElementsNum())
|
||||
: PackLstmWeightFp32ToFp16(ptr, weight_data, weight_batch, deep, lstm_param_->hidden_size_,
|
||||
lstm_param_->col_align_);
|
||||
} else if (tensor->data_type() == kNumberTypeFloat16) {
|
||||
auto weight_data = reinterpret_cast<float16_t *>(tensor->data_c());
|
||||
if (is_vec_) {
|
||||
ptr = weight_data;
|
||||
} else {
|
||||
PackLstmWeightFp16(ptr, weight_data, weight_batch, deep, lstm_param_->hidden_size_, lstm_param_->col_align_);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type of weight tensor for lstm.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmFp16CPUKernel::InitWeightBias() {
|
||||
// copy weight_i and weight_h
|
||||
auto weight_batch = lstm_param_->bidirectional_ ? 8 : 4;
|
||||
// malloc and init input * weight right matrix buffer
|
||||
auto weight_i = in_tensors_.at(1);
|
||||
MS_ASSERT(weight_i != nullptr);
|
||||
weight_i_ptr_ = reinterpret_cast<float16_t *>(malloc(weight_i->ElementsNum() * sizeof(float16_t)));
|
||||
if (weight_i_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Lstm fp16 malloc weight_i_ptr_ error.";
|
||||
if (!is_vec_ || weight_i->data_type() == kNumberTypeFloat32) {
|
||||
weight_i_ptr_ = reinterpret_cast<float16_t *>(
|
||||
malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->input_size_ * sizeof(float16_t)));
|
||||
if (weight_i_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc weight_i_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
auto ret = InitWeight(weight_i, weight_i_ptr_, lstm_param_->input_size_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel init weight_i failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto weight_i_data = reinterpret_cast<float *>(weight_i->data_c());
|
||||
for (size_t i = 0; i < weight_i->ElementsNum(); i++) {
|
||||
weight_i_ptr_[i] = (float16_t)weight_i_data[i];
|
||||
}
|
||||
|
||||
// malloc and init state * weight right matrix buffer
|
||||
auto weight_h = in_tensors_.at(2);
|
||||
MS_ASSERT(weight_h != nullptr);
|
||||
weight_h_ptr_ = reinterpret_cast<float16_t *>(malloc(weight_h->ElementsNum() * sizeof(float16_t)));
|
||||
if (weight_h_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Lstm fp16 malloc weight_h_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto weight_h_data = reinterpret_cast<float *>(weight_h->data_c());
|
||||
for (size_t i = 0; i < weight_h->ElementsNum(); i++) {
|
||||
weight_h_ptr_[i] = (float16_t)weight_h_data[i];
|
||||
}
|
||||
|
||||
std::vector<int> w_shape = weight_i->shape();
|
||||
auto hidden_size = w_shape.at(1) / 4;
|
||||
// init bias
|
||||
int bias_num = lstm_param_->bidirectional_ ? 2 * 4 * hidden_size : 4 * hidden_size;
|
||||
bias_ptr_ = reinterpret_cast<float16_t *>(malloc(bias_num * sizeof(float16_t)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Lstm fp16 malloc bias_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
|
||||
const int state_bias_offset = 4 * hidden_size;
|
||||
for (int i = 0; i < state_bias_offset; i++) {
|
||||
bias_ptr_[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]);
|
||||
}
|
||||
if (lstm_param_->bidirectional_) {
|
||||
bias_data += 4 * hidden_size * 2;
|
||||
auto backward_bias = bias_ptr_ + 4 * hidden_size;
|
||||
for (int i = 0; i < state_bias_offset; i++) {
|
||||
backward_bias[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]);
|
||||
if (!is_vec_ || weight_h->data_type() == kNumberTypeFloat32) {
|
||||
weight_h_ptr_ = reinterpret_cast<float16_t *>(
|
||||
malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (weight_h_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc weight_h_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
ret = InitWeight(weight_h, weight_h_ptr_, lstm_param_->hidden_size_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel init weight_h failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
int bias_batch = lstm_param_->bidirectional_ ? 16 : 8;
|
||||
auto bias = in_tensors_.at(3);
|
||||
MS_ASSERT(bias != nullptr);
|
||||
if (!is_vec_ || bias->data_type() == kNumberTypeFloat32) {
|
||||
bias_ptr_ = reinterpret_cast<float16_t *>(malloc(bias_batch * lstm_param_->col_align_ * sizeof(float16_t)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc bias_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(bias_ptr_, 0, bias_batch * lstm_param_->col_align_ * sizeof(float16_t));
|
||||
}
|
||||
if (bias->data_type() == kNumberTypeFloat32) {
|
||||
auto bias_data = reinterpret_cast<float *>(bias->data_c());
|
||||
for (int i = 0; i < bias_batch; i++) {
|
||||
auto src_batch = bias_data + i * lstm_param_->hidden_size_;
|
||||
auto dst_batch = bias_ptr_ + i * lstm_param_->col_align_;
|
||||
Float32ToFloat16(src_batch, dst_batch, lstm_param_->hidden_size_);
|
||||
}
|
||||
} else if (bias->data_type() == kNumberTypeFloat16) {
|
||||
auto bias_data = reinterpret_cast<float16_t *>(bias->data_c());
|
||||
if (is_vec_) {
|
||||
bias_ptr_ = bias_data;
|
||||
} else {
|
||||
for (int i = 0; i < bias_batch; i++) {
|
||||
auto src_batch = bias_data + i * lstm_param_->hidden_size_;
|
||||
auto dst_batch = bias_ptr_ + i * lstm_param_->col_align_;
|
||||
memcpy(dst_batch, src_batch, lstm_param_->hidden_size_ * sizeof(float16_t));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmFp16CPUKernel::Init() {
|
||||
FreeTmpBuffer();
|
||||
auto ret = InitWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Lstm fp16 InitWeightBias error.";
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -161,15 +190,50 @@ int LstmFp16CPUKernel::ReSize() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ret = InitBuffer();
|
||||
FreeTmpBuffer();
|
||||
ret = InitWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Lstm fp16 InitBuffer error.";
|
||||
MS_LOG(ERROR) << "Lstm fp16 InitWeightBias error.";
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmFp16CPUKernel::MallocRunBuffer() {
|
||||
if (!is_vec_) {
|
||||
matmul_buffer_[0] = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->input_size_ * sizeof(float16_t)));
|
||||
if (matmul_buffer_[0] == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc input * weight left matirx error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
matmul_buffer_[1] = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (matmul_buffer_[1] == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state * weight left matirx error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
gate_buffer_ = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(8 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (gate_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc gate_buffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!(lstm_param_->smooth_ >= -FLT_EPSILON && lstm_param_->smooth_ <= FLT_EPSILON)) {
|
||||
int buffer_size = 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t);
|
||||
state_buffer_ = reinterpret_cast<float16_t *>(context_->allocator->Malloc(buffer_size));
|
||||
if (state_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_buffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmFp16CPUKernel::Run() {
|
||||
auto input = in_tensors_.at(kInputIndex);
|
||||
MS_ASSERT(input != nullptr);
|
||||
|
@ -189,13 +253,20 @@ int LstmFp16CPUKernel::Run() {
|
|||
auto output_cell_state = out_tensors_[2];
|
||||
memcpy(output_cell_state->data_c(), cell_state->data_c(), cell_state->ElementsNum() * sizeof(float16_t));
|
||||
|
||||
MS_ASSERT(weight_h_ptr_);
|
||||
auto ret = MallocRunBuffer();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel MallocRunBuffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_ASSERT(weight_i_ptr_);
|
||||
MS_ASSERT(weight_h_ptr_);
|
||||
MS_ASSERT(bias_ptr_);
|
||||
MS_ASSERT(gate_buffer_);
|
||||
LstmFp16(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_,
|
||||
reinterpret_cast<float16_t *>(output_hidden_state->data_c()),
|
||||
reinterpret_cast<float16_t *>(output_cell_state->data_c()), gate_buffer_, state_buffer_, lstm_param_);
|
||||
reinterpret_cast<float16_t *>(output_cell_state->data_c()), gate_buffer_, state_buffer_, matmul_buffer_,
|
||||
lstm_param_);
|
||||
FreeRunBuffer();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -39,15 +39,19 @@ class LstmFp16CPUKernel : public LiteKernel {
|
|||
|
||||
private:
|
||||
void FreeTmpBuffer();
|
||||
void FreeRunBuffer();
|
||||
int InitParam();
|
||||
int InitBuffer();
|
||||
int InitWeight(const lite::Tensor *tensor, float16_t *ptr, int deep);
|
||||
int InitWeightBias();
|
||||
int MallocRunBuffer();
|
||||
|
||||
float16_t *gate_buffer_ = nullptr;
|
||||
float16_t *state_buffer_ = nullptr;
|
||||
float16_t *weight_i_ptr_ = nullptr;
|
||||
float16_t *weight_h_ptr_ = nullptr;
|
||||
float16_t *bias_ptr_ = nullptr;
|
||||
float16_t *matmul_buffer_[2];
|
||||
bool is_vec_ = false;
|
||||
LstmParameter *lstm_param_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -1548,7 +1548,7 @@ function Run_arm64() {
|
|||
echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
|
||||
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test' >> adb_run_cmd.txt
|
||||
if [[ $accuracy_limit == "-1" ]]; then
|
||||
echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --inputShapes='${input_shapes} >> adb_run_cmd.txt
|
||||
echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --enableFp16=true --inputShapes='${input_shapes} >> adb_run_cmd.txt
|
||||
else
|
||||
echo './benchmark --modelFile='${model_name}'.fp16.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold='${accuracy_limit} ' --inputShapes='${input_shapes} >> adb_run_cmd.txt
|
||||
fi
|
||||
|
|
Loading…
Reference in New Issue