forked from mindspore-Ecosystem/mindspore
!13594 [MSLITE][Develop] optimize arm cpu fp32 op: gru
From: @yangruoqi713 Reviewed-by: @zhang_xue_tong,@hangangqiang Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
40a6da4b57
|
@ -20,40 +20,22 @@
|
|||
#include "nnacl/fp16/arithmetic_fp16.h"
|
||||
#include "nnacl/fp16/matmul_fp16.h"
|
||||
|
||||
void UpdateGruInputGateFp16(float16_t *gate_buffer, const float16_t *input, const float16_t *weight,
|
||||
const float16_t *bias, int row, int deep, int col, int col_align, bool is_vec) {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
const float16_t *weight_i = weight + deep * col * i;
|
||||
const float16_t *bias_i = bias + col_align * i;
|
||||
float16_t *gate = gate_buffer + row * col * i;
|
||||
LstmMatMulFp16(gate, input, weight_i, bias_i, row, deep, col, is_vec);
|
||||
}
|
||||
}
|
||||
|
||||
void GruStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_weight,
|
||||
const float16_t *state_weight, const float16_t *bias, float16_t *hidden_state,
|
||||
float16_t *gate_buffer, float16_t *matmul_buffer[2], const GruParameter *gru_param) {
|
||||
void GruStepUnitFp16(float16_t *output, float16_t *update_gate, float16_t *reset_gate, float16_t *hidden_buffer,
|
||||
const float16_t *state_weight, const float16_t *state_bias, float16_t *hidden_state,
|
||||
float16_t *buffer[4], const GruParameter *gru_param) {
|
||||
float16_t *packed_state = buffer[2];
|
||||
float16_t *state_gate = buffer[3];
|
||||
bool is_vec = gru_param->batch_ == 1;
|
||||
// input * weight
|
||||
if (is_vec) {
|
||||
UpdateGruInputGateFp16(gate_buffer, input, input_weight, bias, gru_param->batch_, gru_param->input_size_,
|
||||
gru_param->hidden_size_, gru_param->col_align_, is_vec);
|
||||
} else {
|
||||
// pack input for matmul
|
||||
RowMajor2Col16MajorFp16(input, matmul_buffer[0], gru_param->batch_, gru_param->input_size_, false);
|
||||
UpdateGruInputGateFp16(gate_buffer, matmul_buffer[0], input_weight, bias, gru_param->batch_, gru_param->input_size_,
|
||||
gru_param->hidden_size_, gru_param->col_align_, is_vec);
|
||||
}
|
||||
|
||||
const float16_t *state_update_weight = state_weight;
|
||||
const float16_t *state_reset_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_;
|
||||
const float16_t *state_hidden_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_ * 2;
|
||||
float16_t *state_update_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 3;
|
||||
float16_t *state_reset_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 4;
|
||||
float16_t *state_hidden_buffer = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 5;
|
||||
const float16_t *state_update_bias = bias + gru_param->hidden_size_ * 3;
|
||||
const float16_t *state_reset_bias = bias + gru_param->hidden_size_ * 4;
|
||||
const float16_t *state_hidden_bias = bias + gru_param->hidden_size_ * 5;
|
||||
float16_t *state_update_gate = state_gate;
|
||||
float16_t *state_reset_gate = state_gate + gru_param->batch_ * gru_param->hidden_size_;
|
||||
float16_t *state_hidden_buffer = state_gate + gru_param->batch_ * gru_param->hidden_size_ * 2;
|
||||
const float16_t *state_update_bias = state_bias;
|
||||
const float16_t *state_reset_bias = state_bias + gru_param->hidden_size_;
|
||||
const float16_t *state_hidden_bias = state_bias + gru_param->hidden_size_ * 2;
|
||||
|
||||
// state * weight
|
||||
if (is_vec) {
|
||||
|
@ -62,17 +44,15 @@ void GruStepUnitFp16(float16_t *output, const float16_t *input, const float16_t
|
|||
LstmMatMulFp16(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
} else {
|
||||
RowMajor2Col16MajorFp16(hidden_state, matmul_buffer[1], gru_param->batch_, gru_param->hidden_size_, false);
|
||||
LstmMatMulFp16(state_reset_gate, matmul_buffer[1], state_reset_weight, state_reset_bias, gru_param->batch_,
|
||||
RowMajor2Col16MajorFp16(hidden_state, packed_state, gru_param->batch_, gru_param->hidden_size_, false);
|
||||
LstmMatMulFp16(state_reset_gate, packed_state, state_reset_weight, state_reset_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
LstmMatMulFp16(state_update_gate, matmul_buffer[1], state_update_weight, state_update_bias, gru_param->batch_,
|
||||
LstmMatMulFp16(state_update_gate, packed_state, state_update_weight, state_update_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
}
|
||||
|
||||
ElementAddFp16(gate_buffer, state_update_gate, gate_buffer, gru_param->batch_ * gru_param->hidden_size_ * 2);
|
||||
float16_t *update_gate = gate_buffer;
|
||||
float16_t *reset_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_;
|
||||
float16_t *hidden_buffer = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 2;
|
||||
ElementAddFp16(update_gate, state_update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_);
|
||||
ElementAddFp16(reset_gate, state_update_gate + gru_param->batch_ * gru_param->hidden_size_, reset_gate,
|
||||
gru_param->batch_ * gru_param->hidden_size_);
|
||||
|
||||
// update reset_gate
|
||||
SigmoidFp16(reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_);
|
||||
|
@ -85,8 +65,8 @@ void GruStepUnitFp16(float16_t *output, const float16_t *input, const float16_t
|
|||
LstmMatMulFp16(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
} else {
|
||||
RowMajor2Col16MajorFp16(reset_gate, matmul_buffer[1], gru_param->batch_, gru_param->hidden_size_, false);
|
||||
LstmMatMulFp16(state_hidden_buffer, matmul_buffer[1], state_hidden_weight, state_hidden_bias, gru_param->batch_,
|
||||
RowMajor2Col16MajorFp16(reset_gate, packed_state, gru_param->batch_, gru_param->hidden_size_, false);
|
||||
LstmMatMulFp16(state_hidden_buffer, packed_state, state_hidden_weight, state_hidden_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
}
|
||||
ElementAddFp16(hidden_buffer, state_hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_);
|
||||
|
@ -106,16 +86,41 @@ void GruStepUnitFp16(float16_t *output, const float16_t *input, const float16_t
|
|||
memcpy(output, hidden_state, gru_param->batch_ * gru_param->hidden_size_ * sizeof(float16_t));
|
||||
}
|
||||
|
||||
void GruUnidirectionalFp16(float16_t *output, const float16_t *packed_input, const float16_t *weight_g,
|
||||
const float16_t *weight_r, const float16_t *input_bias, const float16_t *state_bias,
|
||||
float16_t *hidden_state, float16_t *buffer[4], const GruParameter *gru_param,
|
||||
bool is_backward) {
|
||||
float16_t *gate = buffer[1];
|
||||
for (int i = 0; i < 3; i++) {
|
||||
const float16_t *weight_loop = weight_g + gru_param->input_size_ * gru_param->input_col_align_ * i;
|
||||
const float16_t *bias_loop = input_bias + gru_param->input_col_align_ * i;
|
||||
float16_t *gate_loop = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * i;
|
||||
MatMulFp16(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, gru_param->input_size_,
|
||||
gru_param->seq_len_ * gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, OutType_Nhwc);
|
||||
}
|
||||
|
||||
float16_t *update_gate = gate;
|
||||
float16_t *reset_gate = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_;
|
||||
float16_t *hidden_buffer = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * 2;
|
||||
for (int t = 0; t < gru_param->seq_len_; t++) {
|
||||
int real_t = is_backward ? gru_param->seq_len_ - t - 1 : t;
|
||||
float16_t *update_gate_t = update_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t;
|
||||
float16_t *reset_gate_t = reset_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t;
|
||||
float16_t *hidden_buffer_t = hidden_buffer + gru_param->batch_ * gru_param->hidden_size_ * real_t;
|
||||
float16_t *output_ptr = output + real_t * gru_param->output_step_;
|
||||
GruStepUnitFp16(output_ptr, update_gate_t, reset_gate_t, hidden_buffer_t, weight_r, state_bias, hidden_state,
|
||||
buffer, gru_param);
|
||||
}
|
||||
}
|
||||
|
||||
void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r,
|
||||
const float16_t *bias, float16_t *hidden_state, float16_t *gate_buffer, float16_t *matmul_buffer[2],
|
||||
const float16_t *input_bias, const float16_t *state_bias, float16_t *hidden_state, float16_t *buffer[4],
|
||||
int check_seq_len, const GruParameter *gru_param) {
|
||||
// forward
|
||||
for (int t = 0; t < check_seq_len; t++) {
|
||||
const float16_t *input_ptr = input + t * gru_param->input_step_;
|
||||
float16_t *output_ptr = output + t * gru_param->output_step_;
|
||||
GruStepUnitFp16(output_ptr, input_ptr, weight_g, weight_r, bias, hidden_state, gate_buffer, matmul_buffer,
|
||||
gru_param);
|
||||
}
|
||||
float16_t *packed_input = buffer[0];
|
||||
RowMajor2Col16MajorFp16(input, packed_input, gru_param->seq_len_ * gru_param->batch_, gru_param->input_size_, false);
|
||||
GruUnidirectionalFp16(output, packed_input, weight_g, weight_r, input_bias, state_bias, hidden_state, buffer,
|
||||
gru_param, false);
|
||||
// zero out extra fw outputs
|
||||
for (int t = check_seq_len; t < gru_param->seq_len_; t++) {
|
||||
float16_t *output_ptr = output + t * gru_param->output_step_;
|
||||
|
@ -126,17 +131,15 @@ void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_
|
|||
|
||||
// backward
|
||||
if (gru_param->bidirectional_) {
|
||||
const float16_t *backward_weight_g = weight_g + 3 * gru_param->col_align_ * gru_param->input_size_;
|
||||
const float16_t *backward_weight_r = weight_r + 3 * gru_param->col_align_ * gru_param->hidden_size_;
|
||||
const float16_t *backward_bias = bias + 6 * gru_param->hidden_size_;
|
||||
const float16_t *backward_weight_g = weight_g + 3 * gru_param->input_col_align_ * gru_param->input_size_;
|
||||
const float16_t *backward_weight_r = weight_r + 3 * gru_param->state_col_align_ * gru_param->hidden_size_;
|
||||
const float16_t *backward_input_bias = input_bias + 3 * gru_param->input_col_align_;
|
||||
const float16_t *backward_state_bias = state_bias + 3 * gru_param->state_col_align_;
|
||||
float16_t *backward_output = output + gru_param->batch_ * gru_param->hidden_size_;
|
||||
float16_t *backward_hidden_state = hidden_state + gru_param->batch_ * gru_param->hidden_size_;
|
||||
for (int t = check_seq_len - 1; t >= 0; t--) {
|
||||
const float16_t *input_ptr = input + t * gru_param->input_step_;
|
||||
float16_t *output_ptr = backward_output + t * gru_param->output_step_;
|
||||
GruStepUnitFp16(output_ptr, input_ptr, backward_weight_g, backward_weight_r, backward_bias, backward_hidden_state,
|
||||
gate_buffer, matmul_buffer, gru_param);
|
||||
}
|
||||
GruUnidirectionalFp16(backward_output, packed_input, backward_weight_g, backward_weight_r, backward_input_bias,
|
||||
backward_state_bias, backward_hidden_state, buffer, gru_param, true);
|
||||
|
||||
// zero out extra bw outputs
|
||||
for (int t = gru_param->seq_len_ - 1; t >= check_seq_len; t--) {
|
||||
float16_t *output_ptr = backward_output + t * gru_param->output_step_;
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r,
|
||||
const float16_t *bias, float16_t *hidden_state, float16_t *gate_buffer, float16_t *matmul_buffer[2],
|
||||
const float16_t *input_bias, const float16_t *state_bias, float16_t *hidden_state, float16_t *buffer[4],
|
||||
int check_seq_len, const GruParameter *gru_param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "nnacl/fp16/activation_fp16.h"
|
||||
#include "nnacl/fp16/arithmetic_fp16.h"
|
||||
#include "nnacl/fp16/matmul_fp16.h"
|
||||
#include "nnacl/fp16/cast_fp16.h"
|
||||
|
||||
void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align) {
|
||||
for (int i = 0; i < batch; i++) {
|
||||
|
@ -37,6 +38,43 @@ void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int dee
|
|||
}
|
||||
}
|
||||
|
||||
void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align,
|
||||
bool is_bidirectional) {
|
||||
int unidirectional_batch = is_bidirectional ? batch / 2 : batch;
|
||||
for (int i = 0; i < unidirectional_batch; i++) {
|
||||
const float *src_batch = src + i * col;
|
||||
float16_t *dst_batch = dst + i * col_align;
|
||||
Float32ToFloat16(src_batch, dst_batch, col);
|
||||
}
|
||||
if (is_bidirectional) {
|
||||
const float *backward_src = src + batch * col;
|
||||
float16_t *backward_dst = dst + unidirectional_batch * col_align;
|
||||
for (int i = 0; i < unidirectional_batch; i++) {
|
||||
const float *backward_src_batch = backward_src + i * col;
|
||||
float16_t *backward_dst_batch = backward_dst + i * col_align;
|
||||
Float32ToFloat16(backward_src_batch, backward_dst_batch, col);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional) {
|
||||
int unidirectional_batch = is_bidirectional ? batch / 2 : batch;
|
||||
for (int i = 0; i < unidirectional_batch; i++) {
|
||||
const float16_t *src_batch = src + i * col;
|
||||
float16_t *dst_batch = dst + i * col_align;
|
||||
memcpy(dst_batch, src_batch, col * sizeof(float16_t));
|
||||
}
|
||||
if (is_bidirectional) {
|
||||
const float16_t *backward_src = src + batch * col;
|
||||
float16_t *backward_dst = dst + unidirectional_batch * col_align;
|
||||
for (int i = 0; i < unidirectional_batch; i++) {
|
||||
const float16_t *backward_src_batch = backward_src + i * col;
|
||||
float16_t *backward_dst_batch = backward_dst + i * col_align;
|
||||
memcpy(backward_dst_batch, backward_src_batch, col * sizeof(float16_t));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// input: [row, inner_size]; weight: [col, inner_size]; output: [row, col]
|
||||
void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols,
|
||||
int inner_size) {
|
||||
|
@ -149,40 +187,32 @@ void UpdateLstmGateFp16(float16_t *gate_buffer, const float16_t *input, const fl
|
|||
}
|
||||
}
|
||||
|
||||
void LstmStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_weight,
|
||||
const float16_t *state_weight, const float16_t *bias, float16_t *hidden_state,
|
||||
float16_t *cell_state, float16_t *gate_buffer, float16_t *state_buffer[2],
|
||||
float16_t *matmul_buffer[2], const LstmParameter *lstm_param) {
|
||||
void LstmStepUnitFp16(float16_t *output, float16_t *input_gate, float16_t *forget_gate, float16_t *cell_gate,
|
||||
float16_t *output_gate, const float16_t *state_weight, const float16_t *state_bias,
|
||||
float16_t *hidden_state, float16_t *cell_state, float16_t *buffer[6],
|
||||
const LstmParameter *lstm_param) {
|
||||
float16_t *packed_state = buffer[2];
|
||||
float16_t *state_gate = buffer[3];
|
||||
float16_t *cell_buffer = buffer[4];
|
||||
float16_t *hidden_buffer = buffer[5];
|
||||
bool is_vec = lstm_param->batch_ == 1;
|
||||
// input * weight
|
||||
if (is_vec) {
|
||||
UpdateLstmGateFp16(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
|
||||
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
|
||||
} else {
|
||||
// pack input for matmul
|
||||
RowMajor2Col16MajorFp16(input, matmul_buffer[0], lstm_param->batch_, lstm_param->input_size_, false);
|
||||
UpdateLstmGateFp16(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
|
||||
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
|
||||
}
|
||||
|
||||
// state * weight
|
||||
float16_t *state_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 4;
|
||||
const float16_t *state_bias = bias + lstm_param->col_align_ * 4;
|
||||
if (is_vec) {
|
||||
UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
|
||||
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
|
||||
lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec);
|
||||
} else {
|
||||
// pack state for matmul
|
||||
RowMajor2Col16MajorFp16(hidden_state, matmul_buffer[1], lstm_param->batch_, lstm_param->hidden_size_, false);
|
||||
UpdateLstmGateFp16(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_,
|
||||
lstm_param->hidden_size_, lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
|
||||
RowMajor2Col16MajorFp16(hidden_state, packed_state, lstm_param->batch_, lstm_param->hidden_size_, false);
|
||||
UpdateLstmGateFp16(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
|
||||
lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec);
|
||||
}
|
||||
ElementAddFp16(gate_buffer, state_gate, gate_buffer, 4 * lstm_param->batch_ * lstm_param->hidden_size_);
|
||||
ElementAddFp16(input_gate, state_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_);
|
||||
ElementAddFp16(forget_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 2, forget_gate,
|
||||
lstm_param->batch_ * lstm_param->hidden_size_);
|
||||
ElementAddFp16(cell_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 3, cell_gate,
|
||||
lstm_param->batch_ * lstm_param->hidden_size_);
|
||||
ElementAddFp16(output_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_, output_gate,
|
||||
lstm_param->batch_ * lstm_param->hidden_size_);
|
||||
|
||||
float16_t *input_gate = gate_buffer;
|
||||
float16_t *forget_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 2;
|
||||
float16_t *cell_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 3;
|
||||
float16_t *output_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_;
|
||||
// update input_gate
|
||||
SigmoidFp16(input_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_);
|
||||
|
||||
|
@ -192,50 +222,76 @@ void LstmStepUnitFp16(float16_t *output, const float16_t *input, const float16_t
|
|||
// update cell_gate
|
||||
TanhFp16(cell_gate, cell_gate, lstm_param->batch_ * lstm_param->hidden_size_);
|
||||
// update cell state
|
||||
UpdataStateFp16(cell_state, forget_gate, input_gate, cell_gate, state_buffer[0], lstm_param->batch_,
|
||||
UpdataStateFp16(cell_state, forget_gate, input_gate, cell_gate, cell_buffer, lstm_param->batch_,
|
||||
lstm_param->hidden_size_, lstm_param->zoneout_cell_);
|
||||
|
||||
// update output_gate
|
||||
SigmoidFp16(output_gate, output_gate, lstm_param->batch_ * lstm_param->hidden_size_);
|
||||
// update output
|
||||
UpdataOutputFp16(cell_state, output_gate, hidden_state, state_buffer[1], lstm_param->batch_, lstm_param->hidden_size_,
|
||||
UpdataOutputFp16(cell_state, output_gate, hidden_state, hidden_buffer, lstm_param->batch_, lstm_param->hidden_size_,
|
||||
lstm_param->zoneout_hidden_);
|
||||
memcpy(output, hidden_state, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t));
|
||||
|
||||
if (!(lstm_param->zoneout_cell_ >= -FLT_EPSILON && lstm_param->zoneout_cell_ <= FLT_EPSILON)) {
|
||||
memcpy(cell_state, state_buffer[0], lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t));
|
||||
memcpy(cell_state, cell_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t));
|
||||
}
|
||||
|
||||
if (!(lstm_param->zoneout_hidden_ >= -FLT_EPSILON && lstm_param->zoneout_hidden_ <= FLT_EPSILON)) {
|
||||
memcpy(hidden_state, state_buffer[1], lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t));
|
||||
memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t));
|
||||
}
|
||||
}
|
||||
|
||||
void LstmUnidirectionalFp16(float16_t *output, const float16_t *packed_input, const float16_t *weight_i,
|
||||
const float16_t *weight_h, const float16_t *input_bias, const float16_t *state_bias,
|
||||
float16_t *hidden_state, float16_t *cell_state, float16_t *buffer[6],
|
||||
const LstmParameter *lstm_param, bool is_backward) {
|
||||
float16_t *gate = buffer[1];
|
||||
for (int i = 0; i < 4; i++) {
|
||||
const float16_t *weight_loop = weight_i + lstm_param->input_size_ * lstm_param->input_col_align_ * i;
|
||||
const float16_t *bias_loop = input_bias + lstm_param->input_col_align_ * i;
|
||||
float16_t *gate_loop = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * i;
|
||||
MatMulFp16(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_,
|
||||
lstm_param->seq_len_ * lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_,
|
||||
OutType_Nhwc);
|
||||
}
|
||||
|
||||
float16_t *input_gate = gate;
|
||||
float16_t *forget_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 2;
|
||||
float16_t *cell_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 3;
|
||||
float16_t *output_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_;
|
||||
for (int t = 0; t < lstm_param->seq_len_; t++) {
|
||||
int real_t = is_backward ? lstm_param->seq_len_ - t - 1 : t;
|
||||
float16_t *input_gate_t = input_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t;
|
||||
float16_t *forget_gate_t = forget_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t;
|
||||
float16_t *cell_gate_t = cell_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t;
|
||||
float16_t *output_gate_t = output_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t;
|
||||
float16_t *output_ptr = output + real_t * lstm_param->output_step_;
|
||||
LstmStepUnitFp16(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias,
|
||||
hidden_state, cell_state, buffer, lstm_param);
|
||||
}
|
||||
}
|
||||
|
||||
void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h,
|
||||
const float16_t *bias, float16_t *hidden_state, float16_t *cell_state, float16_t *gate_buffer,
|
||||
float16_t *state_buffer[2], float16_t *matmul_buffer[2], const LstmParameter *lstm_param) {
|
||||
const float16_t *input_bias, const float16_t *state_bias, float16_t *hidden_state, float16_t *cell_state,
|
||||
float16_t *buffer[6], const LstmParameter *lstm_param) {
|
||||
// forward
|
||||
for (int t = 0; t < lstm_param->seq_len_; t++) {
|
||||
const float16_t *input_ptr = input + t * lstm_param->input_step_;
|
||||
float16_t *output_ptr = output + t * lstm_param->output_step_;
|
||||
LstmStepUnitFp16(output_ptr, input_ptr, weight_i, weight_h, bias, hidden_state, cell_state, gate_buffer,
|
||||
state_buffer, matmul_buffer, lstm_param);
|
||||
}
|
||||
float16_t *packed_input = buffer[0];
|
||||
RowMajor2Col16MajorFp16(input, packed_input, lstm_param->seq_len_ * lstm_param->batch_, lstm_param->input_size_,
|
||||
false);
|
||||
LstmUnidirectionalFp16(output, packed_input, weight_i, weight_h, input_bias, state_bias, hidden_state, cell_state,
|
||||
buffer, lstm_param, false);
|
||||
|
||||
// backward
|
||||
if (lstm_param->bidirectional_) {
|
||||
const float16_t *backward_weight_i = weight_i + 4 * lstm_param->col_align_ * lstm_param->input_size_;
|
||||
const float16_t *backward_weight_h = weight_h + 4 * lstm_param->col_align_ * lstm_param->hidden_size_;
|
||||
const float16_t *backward_bias = bias + 8 * lstm_param->col_align_;
|
||||
const float16_t *backward_weight_i = weight_i + 4 * lstm_param->input_col_align_ * lstm_param->input_size_;
|
||||
const float16_t *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->hidden_size_;
|
||||
const float16_t *backward_input_bias = input_bias + 4 * lstm_param->input_col_align_;
|
||||
const float16_t *backward_state_bias = state_bias + 4 * lstm_param->state_col_align_;
|
||||
float16_t *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_;
|
||||
float16_t *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_;
|
||||
float16_t *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_;
|
||||
for (int t = lstm_param->seq_len_ - 1; t >= 0; t--) {
|
||||
const float16_t *input_ptr = input + t * lstm_param->input_step_;
|
||||
float16_t *output_ptr = backward_output + t * lstm_param->output_step_;
|
||||
LstmStepUnitFp16(output_ptr, input_ptr, backward_weight_i, backward_weight_h, backward_bias,
|
||||
backward_hidden_state, backward_cell_state, gate_buffer, state_buffer, matmul_buffer,
|
||||
lstm_param);
|
||||
}
|
||||
|
||||
LstmUnidirectionalFp16(backward_output, packed_input, backward_weight_i, backward_weight_h, backward_input_bias,
|
||||
backward_state_bias, backward_hidden_state, backward_cell_state, buffer, lstm_param, true);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,10 @@ void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int d
|
|||
|
||||
void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align);
|
||||
|
||||
void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional);
|
||||
|
||||
void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional);
|
||||
|
||||
void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep,
|
||||
int col, bool is_vec);
|
||||
|
||||
|
@ -36,8 +40,8 @@ void ElementMulAccFp16(const float16_t *input0, const float16_t *input1, float16
|
|||
int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float16_t *output, const int element_size);
|
||||
|
||||
void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h,
|
||||
const float16_t *bias, float16_t *hidden_state, float16_t *cell_state, float16_t *gate_buffer,
|
||||
float16_t *state_buffer[2], float16_t *matmul_buffer[2], const LstmParameter *lstm_param);
|
||||
const float16_t *input_bias, const float16_t *state_bias, float16_t *hidden_state, float16_t *cell_state,
|
||||
float16_t *buffer[6], const LstmParameter *lstm_param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -20,40 +20,21 @@
|
|||
#include "nnacl/fp32/arithmetic_fp32.h"
|
||||
#include "nnacl/fp32/matmul_fp32.h"
|
||||
|
||||
void UpdateGruInputGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row,
|
||||
int deep, int col, int col_align, bool is_vec) {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
const float *weight_i = weight + deep * col * i;
|
||||
const float *bias_i = bias + col_align * i;
|
||||
float *gate = gate_buffer + row * col * i;
|
||||
LstmMatMul(gate, input, weight_i, bias_i, row, deep, col, is_vec);
|
||||
}
|
||||
}
|
||||
|
||||
void GruStepUnit(float *output, const float *input, const float *input_weight, const float *state_weight,
|
||||
const float *bias, float *hidden_state, float *gate_buffer, float *matmul_buffer[2],
|
||||
const GruParameter *gru_param) {
|
||||
void GruStepUnit(float *output, float *update_gate, float *reset_gate, float *hidden_buffer, const float *state_weight,
|
||||
const float *state_bias, float *hidden_state, float *buffer[4], const GruParameter *gru_param) {
|
||||
float *packed_state = buffer[2];
|
||||
float *state_gate = buffer[3];
|
||||
bool is_vec = gru_param->batch_ == 1;
|
||||
// input * weight
|
||||
if (is_vec) {
|
||||
UpdateGruInputGate(gate_buffer, input, input_weight, bias, gru_param->batch_, gru_param->input_size_,
|
||||
gru_param->hidden_size_, gru_param->col_align_, is_vec);
|
||||
} else {
|
||||
// pack input for matmul
|
||||
PackLstmInput(input, matmul_buffer[0], gru_param->batch_, gru_param->input_size_);
|
||||
UpdateGruInputGate(gate_buffer, matmul_buffer[0], input_weight, bias, gru_param->batch_, gru_param->input_size_,
|
||||
gru_param->hidden_size_, gru_param->col_align_, is_vec);
|
||||
}
|
||||
|
||||
const float *state_update_weight = state_weight;
|
||||
const float *state_reset_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_;
|
||||
const float *state_hidden_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_ * 2;
|
||||
float *state_update_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 3;
|
||||
float *state_reset_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 4;
|
||||
float *state_hidden_buffer = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 5;
|
||||
const float *state_update_bias = bias + gru_param->hidden_size_ * 3;
|
||||
const float *state_reset_bias = bias + gru_param->hidden_size_ * 4;
|
||||
const float *state_hidden_bias = bias + gru_param->hidden_size_ * 5;
|
||||
float *state_update_gate = state_gate;
|
||||
float *state_reset_gate = state_gate + gru_param->batch_ * gru_param->hidden_size_;
|
||||
float *state_hidden_buffer = state_gate + gru_param->batch_ * gru_param->hidden_size_ * 2;
|
||||
const float *state_update_bias = state_bias;
|
||||
const float *state_reset_bias = state_bias + gru_param->hidden_size_;
|
||||
const float *state_hidden_bias = state_bias + gru_param->hidden_size_ * 2;
|
||||
|
||||
// state * weight
|
||||
if (is_vec) {
|
||||
|
@ -62,16 +43,15 @@ void GruStepUnit(float *output, const float *input, const float *input_weight, c
|
|||
LstmMatMul(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
} else {
|
||||
PackLstmInput(hidden_state, matmul_buffer[1], gru_param->batch_, gru_param->hidden_size_);
|
||||
LstmMatMul(state_reset_gate, matmul_buffer[1], state_reset_weight, state_reset_bias, gru_param->batch_,
|
||||
PackLstmInput(hidden_state, packed_state, gru_param->batch_, gru_param->hidden_size_);
|
||||
LstmMatMul(state_reset_gate, packed_state, state_reset_weight, state_reset_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
LstmMatMul(state_update_gate, matmul_buffer[1], state_update_weight, state_update_bias, gru_param->batch_,
|
||||
LstmMatMul(state_update_gate, packed_state, state_update_weight, state_update_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
}
|
||||
ElementAdd(gate_buffer, state_update_gate, gate_buffer, gru_param->batch_ * gru_param->hidden_size_ * 2);
|
||||
float *update_gate = gate_buffer;
|
||||
float *reset_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_;
|
||||
float *hidden_buffer = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 2;
|
||||
ElementAdd(update_gate, state_update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_);
|
||||
ElementAdd(reset_gate, state_update_gate + gru_param->batch_ * gru_param->hidden_size_, reset_gate,
|
||||
gru_param->batch_ * gru_param->hidden_size_);
|
||||
|
||||
// update reset_gate
|
||||
Sigmoid(reset_gate, gru_param->batch_ * gru_param->hidden_size_, reset_gate);
|
||||
|
@ -83,8 +63,8 @@ void GruStepUnit(float *output, const float *input, const float *input_weight, c
|
|||
LstmMatMul(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
} else {
|
||||
PackLstmInput(reset_gate, matmul_buffer[1], gru_param->batch_, gru_param->hidden_size_);
|
||||
LstmMatMul(state_hidden_buffer, matmul_buffer[1], state_hidden_weight, state_hidden_bias, gru_param->batch_,
|
||||
PackLstmInput(reset_gate, packed_state, gru_param->batch_, gru_param->hidden_size_);
|
||||
LstmMatMul(state_hidden_buffer, packed_state, state_hidden_weight, state_hidden_bias, gru_param->batch_,
|
||||
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
|
||||
}
|
||||
ElementAdd(hidden_buffer, state_hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_);
|
||||
|
@ -104,15 +84,41 @@ void GruStepUnit(float *output, const float *input, const float *input_weight, c
|
|||
memcpy(output, hidden_state, gru_param->batch_ * gru_param->hidden_size_ * sizeof(float));
|
||||
}
|
||||
|
||||
void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias,
|
||||
float *hidden_state, float *gate_buffer, float *matmul_buffer[2], int check_seq_len,
|
||||
void GruUnidirectional(float *output, const float *packed_input, const float *weight_g, const float *weight_r,
|
||||
const float *input_bias, const float *state_bias, float *hidden_state, float *buffer[4],
|
||||
const GruParameter *gru_param, bool is_backward) {
|
||||
float *gate = buffer[1];
|
||||
for (int i = 0; i < 3; i++) {
|
||||
const float *weight_loop = weight_g + gru_param->input_size_ * gru_param->input_col_align_ * i;
|
||||
const float *bias_loop = input_bias + gru_param->input_col_align_ * i;
|
||||
float *gate_loop = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * i;
|
||||
MatMulOpt(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, gru_param->input_size_,
|
||||
gru_param->seq_len_ * gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, OutType_Nhwc);
|
||||
}
|
||||
|
||||
float *update_gate = gate;
|
||||
float *reset_gate = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_;
|
||||
float *hidden_buffer = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * 2;
|
||||
for (int t = 0; t < gru_param->seq_len_; t++) {
|
||||
int real_t = is_backward ? gru_param->seq_len_ - t - 1 : t;
|
||||
float *update_gate_t = update_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t;
|
||||
float *reset_gate_t = reset_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t;
|
||||
float *hidden_buffer_t = hidden_buffer + gru_param->batch_ * gru_param->hidden_size_ * real_t;
|
||||
float *output_ptr = output + real_t * gru_param->output_step_;
|
||||
GruStepUnit(output_ptr, update_gate_t, reset_gate_t, hidden_buffer_t, weight_r, state_bias, hidden_state, buffer,
|
||||
gru_param);
|
||||
}
|
||||
}
|
||||
|
||||
void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *input_bias,
|
||||
const float *state_bias, float *hidden_state, float *buffer[4], int check_seq_len,
|
||||
const GruParameter *gru_param) {
|
||||
// forward
|
||||
for (int t = 0; t < check_seq_len; t++) {
|
||||
const float *input_ptr = input + t * gru_param->input_step_;
|
||||
float *output_ptr = output + t * gru_param->output_step_;
|
||||
GruStepUnit(output_ptr, input_ptr, weight_g, weight_r, bias, hidden_state, gate_buffer, matmul_buffer, gru_param);
|
||||
}
|
||||
float *packed_input = buffer[0];
|
||||
PackLstmInput(input, packed_input, gru_param->seq_len_ * gru_param->batch_, gru_param->input_size_);
|
||||
GruUnidirectional(output, packed_input, weight_g, weight_r, input_bias, state_bias, hidden_state, buffer, gru_param,
|
||||
false);
|
||||
|
||||
// zero out extra fw outputs
|
||||
for (int t = check_seq_len; t < gru_param->seq_len_; t++) {
|
||||
float *output_ptr = output + t * gru_param->output_step_;
|
||||
|
@ -123,17 +129,16 @@ void Gru(float *output, const float *input, const float *weight_g, const float *
|
|||
|
||||
// backward
|
||||
if (gru_param->bidirectional_) {
|
||||
const float *backward_weight_g = weight_g + 3 * gru_param->col_align_ * gru_param->input_size_;
|
||||
const float *backward_weight_r = weight_r + 3 * gru_param->col_align_ * gru_param->hidden_size_;
|
||||
const float *backward_bias = bias + 6 * gru_param->hidden_size_;
|
||||
const float *backward_weight_g = weight_g + 3 * gru_param->input_col_align_ * gru_param->input_size_;
|
||||
const float *backward_weight_r = weight_r + 3 * gru_param->state_col_align_ * gru_param->hidden_size_;
|
||||
const float *backward_input_bias = input_bias + 3 * gru_param->input_col_align_;
|
||||
const float *backward_state_bias = state_bias + 3 * gru_param->state_col_align_;
|
||||
float *backward_output = output + gru_param->batch_ * gru_param->hidden_size_;
|
||||
float *backward_hidden_state = hidden_state + gru_param->batch_ * gru_param->hidden_size_;
|
||||
for (int t = check_seq_len - 1; t >= 0; t--) {
|
||||
const float *input_ptr = input + t * gru_param->input_step_;
|
||||
float *output_ptr = backward_output + t * gru_param->output_step_;
|
||||
GruStepUnit(output_ptr, input_ptr, backward_weight_g, backward_weight_r, backward_bias, backward_hidden_state,
|
||||
gate_buffer, matmul_buffer, gru_param);
|
||||
}
|
||||
|
||||
GruUnidirectional(backward_output, packed_input, backward_weight_g, backward_weight_r, backward_input_bias,
|
||||
backward_state_bias, backward_hidden_state, buffer, gru_param, true);
|
||||
|
||||
// zero out extra bw outputs
|
||||
for (int t = gru_param->seq_len_ - 1; t >= check_seq_len; t--) {
|
||||
float *output_ptr = backward_output + t * gru_param->output_step_;
|
||||
|
|
|
@ -20,8 +20,8 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias,
|
||||
float *hidden_state, float *gate_buffer, float *matmul_buffer[2], int check_seq_len,
|
||||
void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *input_bias,
|
||||
const float *state_bias, float *hidden_state, float *buffer[4], int check_seq_len,
|
||||
const GruParameter *gru_parm);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -63,37 +63,6 @@ void PackLstmInput(const float *src, float *dst, int row, int deep) {
|
|||
#endif
|
||||
}
|
||||
|
||||
// input: [row, inner_size]; weight: [col, inner_size]; output: [row, col]
|
||||
void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size) {
|
||||
for (int r = 0; r < rows; r++) {
|
||||
for (int c = 0; c < cols; c++) {
|
||||
float res = 0;
|
||||
const float *input_col = input + r * inner_size;
|
||||
const float *weight_col = weight + c * inner_size;
|
||||
int index = 0;
|
||||
#ifdef ENABLE_ARM
|
||||
float32x4_t out = vdupq_n_f32(0.0f);
|
||||
for (; index <= inner_size - 4; index += 4) {
|
||||
float32x4_t in_0 = vld1q_f32(input_col + index);
|
||||
float32x4_t in_1 = vld1q_f32(weight_col + index);
|
||||
out = vmlaq_f32(out, in_1, in_0);
|
||||
}
|
||||
#ifdef ENABLE_ARM64
|
||||
res += vaddvq_f32(out);
|
||||
#else
|
||||
float32x2_t add2 = vadd_f32(vget_low_f32(out), vget_high_f32(out));
|
||||
float32x2_t add4 = vpadd_f32(add2, add2);
|
||||
res += vget_lane_f32(add4, 0);
|
||||
#endif
|
||||
#endif
|
||||
for (; index < inner_size; index++) {
|
||||
res += input_col[index] * weight_col[index];
|
||||
}
|
||||
output[r * cols + c] += res;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) {
|
||||
if (is_vec) {
|
||||
MatVecMulFp32(a, b, c, bias, ActType_No, deep, col);
|
||||
|
@ -182,7 +151,11 @@ void UpdateLstmGate(float *gate_buffer, const float *input, const float *weight,
|
|||
|
||||
void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate,
|
||||
const float *state_weight, const float *state_bias, float *hidden_state, float *cell_state,
|
||||
float *state_gate, float *state_buffer[2], float *packed_state, const LstmParameter *lstm_param) {
|
||||
float *buffer[6], const LstmParameter *lstm_param) {
|
||||
float *packed_state = buffer[2];
|
||||
float *state_gate = buffer[3];
|
||||
float *cell_buffer = buffer[4];
|
||||
float *hidden_buffer = buffer[5];
|
||||
bool is_vec = lstm_param->batch_ == 1;
|
||||
// state * weight
|
||||
if (is_vec) {
|
||||
|
@ -211,31 +184,29 @@ void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *c
|
|||
// update cell_gate
|
||||
Tanh(cell_gate, lstm_param->batch_ * lstm_param->hidden_size_, cell_gate);
|
||||
// update cell state
|
||||
UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer[0], lstm_param->batch_,
|
||||
lstm_param->hidden_size_, lstm_param->zoneout_cell_);
|
||||
UpdataState(cell_state, forget_gate, input_gate, cell_gate, cell_buffer, lstm_param->batch_, lstm_param->hidden_size_,
|
||||
lstm_param->zoneout_cell_);
|
||||
|
||||
// update output_gate
|
||||
Sigmoid(output_gate, lstm_param->batch_ * lstm_param->hidden_size_, output_gate);
|
||||
// update output
|
||||
UpdataOutput(cell_state, output_gate, hidden_state, state_buffer[1], lstm_param->batch_, lstm_param->hidden_size_,
|
||||
UpdataOutput(cell_state, output_gate, hidden_state, hidden_buffer, lstm_param->batch_, lstm_param->hidden_size_,
|
||||
lstm_param->zoneout_hidden_);
|
||||
memcpy(output, hidden_state, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float));
|
||||
|
||||
if (!(lstm_param->zoneout_cell_ >= -FLT_EPSILON && lstm_param->zoneout_cell_ <= FLT_EPSILON)) {
|
||||
memcpy(cell_state, state_buffer[0], lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float));
|
||||
memcpy(cell_state, cell_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float));
|
||||
}
|
||||
|
||||
if (!(lstm_param->zoneout_hidden_ >= -FLT_EPSILON && lstm_param->zoneout_hidden_ <= FLT_EPSILON)) {
|
||||
memcpy(hidden_state, state_buffer[1], lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float));
|
||||
memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
void LstmUnidirectional(float *output, const float *packed_input, const float *weight_i, const float *weight_h,
|
||||
const float *input_bias, const float *state_bias, float *hidden_state, float *cell_state,
|
||||
float *state_buffer[2], float *buffer[4], const LstmParameter *lstm_param, bool is_backward) {
|
||||
float *buffer[6], const LstmParameter *lstm_param, bool is_backward) {
|
||||
float *gate = buffer[1];
|
||||
float *packed_state = buffer[2];
|
||||
float *state_gate = buffer[3];
|
||||
for (int i = 0; i < 4; i++) {
|
||||
const float *weight_loop = weight_i + lstm_param->input_size_ * lstm_param->input_col_align_ * i;
|
||||
const float *bias_loop = input_bias + lstm_param->input_col_align_ * i;
|
||||
|
@ -257,18 +228,18 @@ void LstmUnidirectional(float *output, const float *packed_input, const float *w
|
|||
float *output_gate_t = output_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t;
|
||||
float *output_ptr = output + real_t * lstm_param->output_step_;
|
||||
LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias,
|
||||
hidden_state, cell_state, state_gate, state_buffer, packed_state, lstm_param);
|
||||
hidden_state, cell_state, buffer, lstm_param);
|
||||
}
|
||||
}
|
||||
|
||||
void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias,
|
||||
const float *state_bias, float *hidden_state, float *cell_state, float *state_buffer[2], float *buffer[4],
|
||||
const float *state_bias, float *hidden_state, float *cell_state, float *buffer[6],
|
||||
const LstmParameter *lstm_param) {
|
||||
// forward
|
||||
float *packed_input = buffer[0];
|
||||
PackLstmInput(input, packed_input, lstm_param->seq_len_ * lstm_param->batch_, lstm_param->input_size_);
|
||||
LstmUnidirectional(output, packed_input, weight_i, weight_h, input_bias, state_bias, hidden_state, cell_state,
|
||||
state_buffer, buffer, lstm_param, false);
|
||||
LstmUnidirectional(output, packed_input, weight_i, weight_h, input_bias, state_bias, hidden_state, cell_state, buffer,
|
||||
lstm_param, false);
|
||||
|
||||
// backward
|
||||
if (lstm_param->bidirectional_) {
|
||||
|
@ -281,7 +252,6 @@ void Lstm(float *output, const float *input, const float *weight_i, const float
|
|||
float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_;
|
||||
|
||||
LstmUnidirectional(backward_output, packed_input, backward_weight_i, backward_weight_h, backward_input_bias,
|
||||
backward_state_bias, backward_hidden_state, backward_cell_state, state_buffer, buffer,
|
||||
lstm_param, true);
|
||||
backward_state_bias, backward_hidden_state, backward_cell_state, buffer, lstm_param, true);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ void ElementMulAcc(const float *input0, const float *input1, float *output, int
|
|||
int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size);
|
||||
|
||||
void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias,
|
||||
const float *state_bias, float *hidden_state, float *cell_state, float *state_buffer[2], float *buffer[4],
|
||||
const float *state_bias, float *hidden_state, float *cell_state, float *buffer[6],
|
||||
const LstmParameter *lstm_param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -27,11 +27,12 @@ typedef struct GruParameter {
|
|||
int seq_len_;
|
||||
int batch_;
|
||||
// other parameter
|
||||
int input_step_;
|
||||
int output_step_;
|
||||
bool bidirectional_;
|
||||
int col_align_;
|
||||
int row_align_;
|
||||
int input_row_align_;
|
||||
int input_col_align_;
|
||||
int state_row_align_;
|
||||
int state_col_align_;
|
||||
} GruParameter;
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_GRU_PARAMETER_H_
|
||||
|
|
|
@ -27,7 +27,6 @@ typedef struct LstmParameter {
|
|||
int seq_len_;
|
||||
int batch_;
|
||||
// other parameter
|
||||
int input_step_;
|
||||
int output_step_;
|
||||
bool bidirectional_;
|
||||
float zoneout_cell_;
|
||||
|
@ -36,8 +35,6 @@ typedef struct LstmParameter {
|
|||
int input_col_align_;
|
||||
int state_row_align_;
|
||||
int state_col_align_;
|
||||
int col_align_;
|
||||
int row_align_;
|
||||
} LstmParameter;
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_LSTM_PARAMETER_H_
|
||||
|
|
|
@ -30,33 +30,31 @@ using mindspore::schema::PrimitiveType_GRU;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
void GruFp16CPUKernel::FreeTmpBuffer() {
|
||||
if (!is_vec_ || in_tensors_[1]->data_type() == kNumberTypeFloat32) {
|
||||
if (weight_g_ptr_ != nullptr) {
|
||||
free(weight_g_ptr_);
|
||||
weight_g_ptr_ = nullptr;
|
||||
}
|
||||
if (input_bias_ != nullptr) {
|
||||
free(input_bias_);
|
||||
input_bias_ = nullptr;
|
||||
}
|
||||
if (!is_vec_ || in_tensors_[2]->data_type() == kNumberTypeFloat32) {
|
||||
if (weight_r_ptr_ != nullptr) {
|
||||
free(weight_r_ptr_);
|
||||
weight_r_ptr_ = nullptr;
|
||||
}
|
||||
}
|
||||
if (!is_vec_ || in_tensors_[3]->data_type() == kNumberTypeFloat32) {
|
||||
if (bias_ptr_ != nullptr) {
|
||||
free(bias_ptr_);
|
||||
bias_ptr_ = nullptr;
|
||||
}
|
||||
if (state_bias_ != nullptr) {
|
||||
free(state_bias_);
|
||||
state_bias_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void GruFp16CPUKernel::FreeRunBuffer() {
|
||||
context_->allocator->Free(gate_buffer_);
|
||||
context_->allocator->Free(buffer_[0]);
|
||||
context_->allocator->Free(buffer_[1]);
|
||||
if (!is_vec_) {
|
||||
for (int i = 0; i < 2; i++) {
|
||||
context_->allocator->Free(matmul_buffer_[i]);
|
||||
}
|
||||
context_->allocator->Free(buffer_[2]);
|
||||
}
|
||||
context_->allocator->Free(buffer_[3]);
|
||||
}
|
||||
|
||||
int GruFp16CPUKernel::InitParam() {
|
||||
|
@ -71,105 +69,120 @@ int GruFp16CPUKernel::InitParam() {
|
|||
MS_ASSERT(weight_g != nullptr);
|
||||
std::vector<int> w_shape = weight_g->shape();
|
||||
gru_param_->hidden_size_ = w_shape.at(1) / 3;
|
||||
|
||||
gru_param_->input_step_ = gru_param_->batch_ * gru_param_->input_size_;
|
||||
weight_batch_ = gru_param_->bidirectional_ ? 6 : 3;
|
||||
gru_param_->output_step_ = gru_param_->bidirectional_ ? 2 * gru_param_->batch_ * gru_param_->hidden_size_
|
||||
: gru_param_->batch_ * gru_param_->hidden_size_;
|
||||
|
||||
gru_param_->input_row_align_ = UP_ROUND(gru_param_->seq_len_ * gru_param_->batch_, C16NUM);
|
||||
gru_param_->input_col_align_ = UP_ROUND(gru_param_->hidden_size_, C8NUM);
|
||||
|
||||
is_vec_ = gru_param_->batch_ == 1;
|
||||
gru_param_->row_align_ = is_vec_ ? gru_param_->batch_ : UP_ROUND(gru_param_->batch_, C16NUM);
|
||||
gru_param_->col_align_ = is_vec_ ? gru_param_->hidden_size_ : UP_ROUND(gru_param_->hidden_size_, C8NUM);
|
||||
gru_param_->state_row_align_ = is_vec_ ? gru_param_->batch_ : UP_ROUND(gru_param_->batch_, C16NUM);
|
||||
gru_param_->state_col_align_ = is_vec_ ? gru_param_->hidden_size_ : UP_ROUND(gru_param_->hidden_size_, C8NUM);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int GruFp16CPUKernel::InitWeight(const lite::Tensor *tensor, float16_t *ptr, int deep) {
|
||||
auto weight_batch = gru_param_->bidirectional_ ? 6 : 3;
|
||||
if (tensor->data_type() == kNumberTypeFloat32) {
|
||||
auto weight_data = reinterpret_cast<float *>(tensor->data_c());
|
||||
is_vec_ ? Float32ToFloat16(weight_data, ptr, tensor->ElementsNum())
|
||||
: PackLstmWeightFp32ToFp16(ptr, weight_data, weight_batch, deep, gru_param_->hidden_size_,
|
||||
gru_param_->col_align_);
|
||||
} else if (tensor->data_type() == kNumberTypeFloat16) {
|
||||
auto weight_data = reinterpret_cast<float16_t *>(tensor->data_c());
|
||||
if (is_vec_) {
|
||||
ptr = weight_data;
|
||||
} else {
|
||||
PackLstmWeightFp16(ptr, weight_data, weight_batch, deep, gru_param_->hidden_size_, gru_param_->col_align_);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type of weight tensor for lstm.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int GruFp16CPUKernel::InitWeightBias() {
|
||||
auto weight_batch = gru_param_->bidirectional_ ? 6 : 3;
|
||||
int GruFp16CPUKernel::InitInputWeightBias() {
|
||||
// malloc and init input * weight right matrix buffer
|
||||
// input -- row: seq_len * batch; col: input_size
|
||||
// weight -- row: hidden_size; col: input_size, need transpose
|
||||
// result -- row: seq_len * batch; col: hidden_size
|
||||
auto weight_g = in_tensors_.at(1);
|
||||
MS_ASSERT(weight_g != nullptr);
|
||||
if (!is_vec_ || weight_g->data_type() == kNumberTypeFloat32) {
|
||||
weight_g_ptr_ = reinterpret_cast<float16_t *>(
|
||||
malloc(weight_batch * gru_param_->col_align_ * gru_param_->input_size_ * sizeof(float16_t)));
|
||||
malloc(weight_batch_ * gru_param_->input_col_align_ * gru_param_->input_size_ * sizeof(float16_t)));
|
||||
if (weight_g_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_g_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
auto ret = InitWeight(weight_g, weight_g_ptr_, gru_param_->input_size_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel init weight_g failed.";
|
||||
if (weight_g->data_type() == kNumberTypeFloat32) {
|
||||
PackLstmWeightFp32ToFp16(weight_g_ptr_, reinterpret_cast<float *>(weight_g->data_c()), weight_batch_,
|
||||
gru_param_->input_size_, gru_param_->hidden_size_, gru_param_->input_col_align_);
|
||||
} else if (weight_g->data_type() == kNumberTypeFloat16) {
|
||||
PackLstmWeightFp16(weight_g_ptr_, reinterpret_cast<float16_t *>(weight_g->data_c()), weight_batch_,
|
||||
gru_param_->input_size_, gru_param_->hidden_size_, gru_param_->input_col_align_);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type of weight_g tensor for gru.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// malloc and init state * weight right matrix buffer
|
||||
// input bias
|
||||
auto bias = in_tensors_.at(3);
|
||||
MS_ASSERT(bias != nullptr);
|
||||
input_bias_ = reinterpret_cast<float16_t *>(malloc(weight_batch_ * gru_param_->input_col_align_ * sizeof(float16_t)));
|
||||
if (input_bias_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc input_bias_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(input_bias_, 0, weight_batch_ * gru_param_->input_col_align_ * sizeof(float16_t));
|
||||
if (bias->data_type() == kNumberTypeFloat32) {
|
||||
PackLstmBiasFp32ToFp16(input_bias_, reinterpret_cast<float *>(bias->data_c()), weight_batch_,
|
||||
gru_param_->hidden_size_, gru_param_->input_col_align_, gru_param_->bidirectional_);
|
||||
} else if (bias->data_type() == kNumberTypeFloat16) {
|
||||
PackLstmBiasFp16(input_bias_, reinterpret_cast<float16_t *>(bias->data_c()), weight_batch_,
|
||||
gru_param_->hidden_size_, gru_param_->input_col_align_, gru_param_->bidirectional_);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type of bias tensor for gru.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int GruFp16CPUKernel::InitStateWeightBias() {
|
||||
// malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times.
|
||||
// state -- row: batch; col: hidden_size
|
||||
// weight -- row: hidden_size; col: hidden_size, need transpose
|
||||
// result -- row: batch; col: hidden_size
|
||||
auto weight_r = in_tensors_.at(2);
|
||||
MS_ASSERT(weight_r != nullptr);
|
||||
if (!is_vec_ || weight_r->data_type() == kNumberTypeFloat32) {
|
||||
weight_r_ptr_ = reinterpret_cast<float16_t *>(
|
||||
malloc(weight_batch * gru_param_->col_align_ * gru_param_->hidden_size_ * sizeof(float16_t)));
|
||||
malloc(weight_batch_ * gru_param_->state_col_align_ * gru_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (weight_r_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_r_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
ret = InitWeight(weight_r, weight_r_ptr_, gru_param_->hidden_size_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel init weight_r failed.";
|
||||
|
||||
if (!is_vec_) {
|
||||
if (weight_r->data_type() == kNumberTypeFloat32) {
|
||||
PackLstmWeightFp32ToFp16(weight_r_ptr_, reinterpret_cast<float *>(weight_r->data_c()), weight_batch_,
|
||||
gru_param_->hidden_size_, gru_param_->hidden_size_, gru_param_->state_col_align_);
|
||||
} else if (weight_r->data_type() == kNumberTypeFloat16) {
|
||||
PackLstmWeightFp16(weight_r_ptr_, reinterpret_cast<float16_t *>(weight_r->data_c()), weight_batch_,
|
||||
gru_param_->hidden_size_, gru_param_->hidden_size_, gru_param_->state_col_align_);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type of weight_r tensor for gru.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
if (weight_r->data_type() == kNumberTypeFloat32) {
|
||||
Float32ToFloat16(reinterpret_cast<float *>(weight_r->data_c()), weight_r_ptr_, weight_r->ElementsNum());
|
||||
} else if (weight_r->data_type() == kNumberTypeFloat16) {
|
||||
memcpy(weight_r_ptr_, reinterpret_cast<float16_t *>(weight_r->data_c()), weight_r->ElementsNum());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type of weight_r tensor for gru.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
int bias_batch = gru_param_->bidirectional_ ? 12 : 6;
|
||||
// state bias
|
||||
auto bias = in_tensors_.at(3);
|
||||
MS_ASSERT(bias != nullptr);
|
||||
if (!is_vec_ || bias->data_type() == kNumberTypeFloat32) {
|
||||
bias_ptr_ = reinterpret_cast<float16_t *>(malloc(bias_batch * gru_param_->col_align_ * sizeof(float16_t)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc bias_ptr_ error.";
|
||||
state_bias_ = reinterpret_cast<float16_t *>(malloc(weight_batch_ * gru_param_->state_col_align_ * sizeof(float16_t)));
|
||||
if (state_bias_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc state_bias_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(bias_ptr_, 0, bias_batch * gru_param_->col_align_ * sizeof(float16_t));
|
||||
}
|
||||
memset(state_bias_, 0, weight_batch_ * gru_param_->state_col_align_ * sizeof(float16_t));
|
||||
if (bias->data_type() == kNumberTypeFloat32) {
|
||||
auto bias_data = reinterpret_cast<float *>(bias->data_c());
|
||||
for (int i = 0; i < bias_batch; i++) {
|
||||
auto src_batch = bias_data + i * gru_param_->hidden_size_;
|
||||
auto dst_batch = bias_ptr_ + i * gru_param_->col_align_;
|
||||
Float32ToFloat16(src_batch, dst_batch, gru_param_->hidden_size_);
|
||||
}
|
||||
auto state_bias_data = reinterpret_cast<float *>(bias->data_c()) + 3 * gru_param_->hidden_size_;
|
||||
PackLstmBiasFp32ToFp16(state_bias_, state_bias_data, weight_batch_, gru_param_->hidden_size_,
|
||||
gru_param_->state_col_align_, gru_param_->bidirectional_);
|
||||
} else if (bias->data_type() == kNumberTypeFloat16) {
|
||||
auto bias_data = reinterpret_cast<float16_t *>(bias->data_c());
|
||||
if (is_vec_) {
|
||||
bias_ptr_ = bias_data;
|
||||
auto state_bias_data = reinterpret_cast<float16_t *>(bias->data_c()) + 3 * gru_param_->hidden_size_;
|
||||
PackLstmBiasFp16(state_bias_, state_bias_data, weight_batch_, gru_param_->hidden_size_,
|
||||
gru_param_->state_col_align_, gru_param_->bidirectional_);
|
||||
} else {
|
||||
for (int i = 0; i < bias_batch; i++) {
|
||||
auto src_batch = bias_data + i * gru_param_->hidden_size_;
|
||||
auto dst_batch = bias_ptr_ + i * gru_param_->col_align_;
|
||||
memcpy(dst_batch, src_batch, gru_param_->hidden_size_ * sizeof(float16_t));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm.";
|
||||
MS_LOG(ERROR) << "Unsupported data type of bias tensor for gru.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
|
@ -190,9 +203,16 @@ int GruFp16CPUKernel::ReSize() {
|
|||
}
|
||||
|
||||
FreeTmpBuffer();
|
||||
ret = InitWeightBias();
|
||||
ret = InitInputWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel InitWeightBias error.";
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel InitInputWeightBias error.";
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ret = InitStateWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel InitStateWeightBias error.";
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -200,26 +220,36 @@ int GruFp16CPUKernel::ReSize() {
|
|||
}
|
||||
|
||||
int GruFp16CPUKernel::MallocRunBuffer() {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
buffer_[i] = nullptr;
|
||||
}
|
||||
buffer_[0] = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(gru_param_->input_row_align_ * gru_param_->input_size_ * sizeof(float16_t)));
|
||||
if (buffer_[0] == nullptr) {
|
||||
MS_LOG(ERROR) << "GruCPUKernel malloc input * weight left matirx error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
buffer_[1] = reinterpret_cast<float16_t *>(context_->allocator->Malloc(3 * gru_param_->seq_len_ * gru_param_->batch_ *
|
||||
gru_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (buffer_[1] == nullptr) {
|
||||
MS_LOG(ERROR) << "GruCPUKernel malloc input * weight result matirx error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (!is_vec_) {
|
||||
matmul_buffer_[0] = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(3 * gru_param_->row_align_ * gru_param_->input_size_ * sizeof(float16_t)));
|
||||
if (matmul_buffer_[0] == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc input * weight left matirx error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
matmul_buffer_[1] = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(3 * gru_param_->row_align_ * gru_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (matmul_buffer_[1] == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc state * weight left matirx error.";
|
||||
buffer_[2] = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(gru_param_->state_row_align_ * gru_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (buffer_[2] == nullptr) {
|
||||
MS_LOG(ERROR) << "GruCPUKernel malloc state * weight left matirx error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
gate_buffer_ = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(4 * gru_param_->batch_ * gru_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (gate_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruFp16CPUKernel malloc gate_buffer error.";
|
||||
buffer_[3] = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(3 * gru_param_->batch_ * gru_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (buffer_[3] == nullptr) {
|
||||
MS_LOG(ERROR) << "GruCPUKernel malloc state gate buffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
|
@ -255,11 +285,10 @@ int GruFp16CPUKernel::Run() {
|
|||
}
|
||||
MS_ASSERT(weight_g_ptr_ != nullptr);
|
||||
MS_ASSERT(weight_r_ptr_ != nullptr);
|
||||
MS_ASSERT(bias_ptr_ != nullptr);
|
||||
MS_ASSERT(gate_buffer_ != nullptr);
|
||||
GruFp16(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_,
|
||||
reinterpret_cast<float16_t *>(output_hidden_state->data_c()), gate_buffer_, matmul_buffer_, check_seq_len,
|
||||
gru_param_);
|
||||
MS_ASSERT(input_bias_ != nullptr);
|
||||
MS_ASSERT(state_bias_ != nullptr);
|
||||
GruFp16(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, input_bias_, state_bias_,
|
||||
reinterpret_cast<float16_t *>(output_hidden_state->data_c()), buffer_, check_seq_len, gru_param_);
|
||||
FreeRunBuffer();
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -38,15 +38,16 @@ class GruFp16CPUKernel : public LiteKernel {
|
|||
void FreeTmpBuffer();
|
||||
void FreeRunBuffer();
|
||||
int InitParam();
|
||||
int InitWeight(const lite::Tensor *tensor, float16_t *ptr, int deep);
|
||||
int InitWeightBias();
|
||||
int InitInputWeightBias();
|
||||
int InitStateWeightBias();
|
||||
int MallocRunBuffer();
|
||||
|
||||
float16_t *gate_buffer_ = nullptr;
|
||||
float16_t *weight_g_ptr_ = nullptr;
|
||||
float16_t *weight_r_ptr_ = nullptr;
|
||||
float16_t *bias_ptr_ = nullptr;
|
||||
float16_t *matmul_buffer_[2];
|
||||
float16_t *input_bias_ = nullptr;
|
||||
float16_t *state_bias_ = nullptr;
|
||||
float16_t *buffer_[4];
|
||||
int weight_batch_ = 0;
|
||||
bool is_vec_ = false;
|
||||
GruParameter *gru_param_ = nullptr;
|
||||
};
|
||||
|
|
|
@ -31,35 +31,36 @@ using mindspore::schema::PrimitiveType_LSTM;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
void LstmFp16CPUKernel::FreeTmpBuffer() {
|
||||
if (!is_vec_ || in_tensors_[1]->data_type() == kNumberTypeFloat32) {
|
||||
if (weight_i_ptr_ != nullptr) {
|
||||
free(weight_i_ptr_);
|
||||
weight_i_ptr_ = nullptr;
|
||||
}
|
||||
if (input_bias_ != nullptr) {
|
||||
free(input_bias_);
|
||||
input_bias_ = nullptr;
|
||||
}
|
||||
if (!is_vec_ || in_tensors_[2]->data_type() == kNumberTypeFloat32) {
|
||||
if (weight_h_ptr_ != nullptr) {
|
||||
free(weight_h_ptr_);
|
||||
weight_h_ptr_ = nullptr;
|
||||
}
|
||||
}
|
||||
if (!is_vec_ || in_tensors_[3]->data_type() == kNumberTypeFloat32) {
|
||||
if (bias_ptr_ != nullptr) {
|
||||
free(bias_ptr_);
|
||||
bias_ptr_ = nullptr;
|
||||
}
|
||||
if (state_bias_ != nullptr) {
|
||||
free(state_bias_);
|
||||
state_bias_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void LstmFp16CPUKernel::FreeRunBuffer() {
|
||||
context_->allocator->Free(gate_buffer_);
|
||||
for (int i = 0; i < 2; i++) {
|
||||
context_->allocator->Free(state_buffer_[i]);
|
||||
}
|
||||
context_->allocator->Free(buffer_[0]);
|
||||
context_->allocator->Free(buffer_[1]);
|
||||
if (!is_vec_) {
|
||||
for (int i = 0; i < 2; i++) {
|
||||
context_->allocator->Free(matmul_buffer_[i]);
|
||||
context_->allocator->Free(buffer_[2]);
|
||||
}
|
||||
context_->allocator->Free(buffer_[3]);
|
||||
if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) {
|
||||
context_->allocator->Free(buffer_[4]);
|
||||
}
|
||||
if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) {
|
||||
context_->allocator->Free(buffer_[5]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -76,102 +77,119 @@ int LstmFp16CPUKernel::InitParam() {
|
|||
std::vector<int> w_shape = weight_i->shape();
|
||||
lstm_param_->hidden_size_ = w_shape.at(1) / 4;
|
||||
|
||||
lstm_param_->input_step_ = lstm_param_->batch_ * lstm_param_->input_size_;
|
||||
lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_
|
||||
: lstm_param_->batch_ * lstm_param_->hidden_size_;
|
||||
weight_batch_ = lstm_param_->bidirectional_ ? 8 : 4;
|
||||
lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, C16NUM);
|
||||
lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, C8NUM);
|
||||
|
||||
is_vec_ = lstm_param_->batch_ == 1;
|
||||
lstm_param_->row_align_ = is_vec_ ? lstm_param_->batch_ : UP_ROUND(lstm_param_->batch_, C16NUM);
|
||||
lstm_param_->col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, C8NUM);
|
||||
lstm_param_->state_row_align_ = is_vec_ ? lstm_param_->batch_ : UP_ROUND(lstm_param_->batch_, C16NUM);
|
||||
lstm_param_->state_col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, C8NUM);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmFp16CPUKernel::InitWeight(const lite::Tensor *tensor, float16_t *ptr, int deep) {
|
||||
auto weight_batch = lstm_param_->bidirectional_ ? 8 : 4;
|
||||
if (tensor->data_type() == kNumberTypeFloat32) {
|
||||
auto weight_data = reinterpret_cast<float *>(tensor->data_c());
|
||||
is_vec_ ? Float32ToFloat16(weight_data, ptr, tensor->ElementsNum())
|
||||
: PackLstmWeightFp32ToFp16(ptr, weight_data, weight_batch, deep, lstm_param_->hidden_size_,
|
||||
lstm_param_->col_align_);
|
||||
} else if (tensor->data_type() == kNumberTypeFloat16) {
|
||||
auto weight_data = reinterpret_cast<float16_t *>(tensor->data_c());
|
||||
if (is_vec_) {
|
||||
ptr = weight_data;
|
||||
} else {
|
||||
PackLstmWeightFp16(ptr, weight_data, weight_batch, deep, lstm_param_->hidden_size_, lstm_param_->col_align_);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type of weight tensor for lstm.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmFp16CPUKernel::InitWeightBias() {
|
||||
auto weight_batch = lstm_param_->bidirectional_ ? 8 : 4;
|
||||
int LstmFp16CPUKernel::InitInputWeightBias() {
|
||||
// malloc and init input * weight right matrix buffer
|
||||
// input -- row: seq_len * batch; col: input_size
|
||||
// weight -- row: hidden_size; col: input_size, need transpose
|
||||
// result -- row: seq_len * batch; col: hidden_size
|
||||
auto weight_i = in_tensors_.at(1);
|
||||
MS_ASSERT(weight_i != nullptr);
|
||||
if (!is_vec_ || weight_i->data_type() == kNumberTypeFloat32) {
|
||||
weight_i_ptr_ = reinterpret_cast<float16_t *>(
|
||||
malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->input_size_ * sizeof(float16_t)));
|
||||
malloc(weight_batch_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float16_t)));
|
||||
if (weight_i_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc weight_i_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
auto ret = InitWeight(weight_i, weight_i_ptr_, lstm_param_->input_size_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel init weight_i failed.";
|
||||
if (weight_i->data_type() == kNumberTypeFloat32) {
|
||||
PackLstmWeightFp32ToFp16(weight_i_ptr_, reinterpret_cast<float *>(weight_i->data_c()), weight_batch_,
|
||||
lstm_param_->input_size_, lstm_param_->hidden_size_, lstm_param_->input_col_align_);
|
||||
} else if (weight_i->data_type() == kNumberTypeFloat16) {
|
||||
PackLstmWeightFp16(weight_i_ptr_, reinterpret_cast<float16_t *>(weight_i->data_c()), weight_batch_,
|
||||
lstm_param_->input_size_, lstm_param_->hidden_size_, lstm_param_->input_col_align_);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type of weight_i tensor for lstm.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// malloc and init state * weight right matrix buffer
|
||||
// input bias
|
||||
auto bias = in_tensors_.at(3);
|
||||
MS_ASSERT(bias != nullptr);
|
||||
input_bias_ =
|
||||
reinterpret_cast<float16_t *>(malloc(weight_batch_ * lstm_param_->input_col_align_ * sizeof(float16_t)));
|
||||
if (input_bias_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc input_bias_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(input_bias_, 0, weight_batch_ * lstm_param_->input_col_align_ * sizeof(float16_t));
|
||||
if (bias->data_type() == kNumberTypeFloat32) {
|
||||
PackLstmBiasFp32ToFp16(input_bias_, reinterpret_cast<float *>(bias->data_c()), weight_batch_,
|
||||
lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_);
|
||||
} else if (bias->data_type() == kNumberTypeFloat16) {
|
||||
PackLstmBiasFp16(input_bias_, reinterpret_cast<float16_t *>(bias->data_c()), weight_batch_,
|
||||
lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmFp16CPUKernel::InitStateWeightBias() {
|
||||
// malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times.
|
||||
// state -- row: batch; col: hidden_size
|
||||
// weight -- row: hidden_size; col: hidden_size, need transpose
|
||||
// result -- row: batch; col: hidden_size
|
||||
auto weight_h = in_tensors_.at(2);
|
||||
MS_ASSERT(weight_h != nullptr);
|
||||
if (!is_vec_ || weight_h->data_type() == kNumberTypeFloat32) {
|
||||
weight_h_ptr_ = reinterpret_cast<float16_t *>(
|
||||
malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->hidden_size_ * sizeof(float16_t)));
|
||||
malloc(weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (weight_h_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc weight_h_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
ret = InitWeight(weight_h, weight_h_ptr_, lstm_param_->hidden_size_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel init weight_h failed.";
|
||||
|
||||
if (!is_vec_) {
|
||||
if (weight_h->data_type() == kNumberTypeFloat32) {
|
||||
PackLstmWeightFp32ToFp16(weight_h_ptr_, reinterpret_cast<float *>(weight_h->data_c()), weight_batch_,
|
||||
lstm_param_->hidden_size_, lstm_param_->hidden_size_, lstm_param_->state_col_align_);
|
||||
} else if (weight_h->data_type() == kNumberTypeFloat16) {
|
||||
PackLstmWeightFp16(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h->data_c()), weight_batch_,
|
||||
lstm_param_->hidden_size_, lstm_param_->hidden_size_, lstm_param_->state_col_align_);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type of weight_h tensor for lstm.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
if (weight_h->data_type() == kNumberTypeFloat32) {
|
||||
Float32ToFloat16(reinterpret_cast<float *>(weight_h->data_c()), weight_h_ptr_, weight_h->ElementsNum());
|
||||
} else if (weight_h->data_type() == kNumberTypeFloat16) {
|
||||
memcpy(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h->data_c()), weight_h->ElementsNum());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type of weight_h tensor for lstm.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
int bias_batch = lstm_param_->bidirectional_ ? 16 : 8;
|
||||
// state bias
|
||||
auto bias = in_tensors_.at(3);
|
||||
MS_ASSERT(bias != nullptr);
|
||||
if (!is_vec_ || bias->data_type() == kNumberTypeFloat32) {
|
||||
bias_ptr_ = reinterpret_cast<float16_t *>(malloc(bias_batch * lstm_param_->col_align_ * sizeof(float16_t)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc bias_ptr_ error.";
|
||||
state_bias_ =
|
||||
reinterpret_cast<float16_t *>(malloc(weight_batch_ * lstm_param_->state_col_align_ * sizeof(float16_t)));
|
||||
if (state_bias_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_bias_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(bias_ptr_, 0, bias_batch * lstm_param_->col_align_ * sizeof(float16_t));
|
||||
}
|
||||
memset(state_bias_, 0, weight_batch_ * lstm_param_->state_col_align_ * sizeof(float16_t));
|
||||
if (bias->data_type() == kNumberTypeFloat32) {
|
||||
auto bias_data = reinterpret_cast<float *>(bias->data_c());
|
||||
for (int i = 0; i < bias_batch; i++) {
|
||||
auto src_batch = bias_data + i * lstm_param_->hidden_size_;
|
||||
auto dst_batch = bias_ptr_ + i * lstm_param_->col_align_;
|
||||
Float32ToFloat16(src_batch, dst_batch, lstm_param_->hidden_size_);
|
||||
}
|
||||
auto state_bias_data = reinterpret_cast<float *>(bias->data_c()) + 4 * lstm_param_->hidden_size_;
|
||||
PackLstmBiasFp32ToFp16(state_bias_, state_bias_data, weight_batch_, lstm_param_->hidden_size_,
|
||||
lstm_param_->state_col_align_, lstm_param_->bidirectional_);
|
||||
} else if (bias->data_type() == kNumberTypeFloat16) {
|
||||
auto bias_data = reinterpret_cast<float16_t *>(bias->data_c());
|
||||
if (is_vec_) {
|
||||
bias_ptr_ = bias_data;
|
||||
} else {
|
||||
for (int i = 0; i < bias_batch; i++) {
|
||||
auto src_batch = bias_data + i * lstm_param_->hidden_size_;
|
||||
auto dst_batch = bias_ptr_ + i * lstm_param_->col_align_;
|
||||
memcpy(dst_batch, src_batch, lstm_param_->hidden_size_ * sizeof(float16_t));
|
||||
}
|
||||
}
|
||||
auto state_bias_data = reinterpret_cast<float16_t *>(bias->data_c()) + 4 * lstm_param_->hidden_size_;
|
||||
PackLstmBiasFp16(state_bias_, state_bias_data, weight_batch_, lstm_param_->hidden_size_,
|
||||
lstm_param_->state_col_align_, lstm_param_->bidirectional_);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm.";
|
||||
return RET_ERROR;
|
||||
|
@ -194,9 +212,16 @@ int LstmFp16CPUKernel::ReSize() {
|
|||
}
|
||||
|
||||
FreeTmpBuffer();
|
||||
ret = InitWeightBias();
|
||||
ret = InitInputWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Lstm fp16 InitWeightBias error.";
|
||||
MS_LOG(ERROR) << "Lstm fp16 InitInputWeightBias error.";
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ret = InitStateWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Lstm fp16 InitStateWeightBias error.";
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -204,42 +229,51 @@ int LstmFp16CPUKernel::ReSize() {
|
|||
}
|
||||
|
||||
int LstmFp16CPUKernel::MallocRunBuffer() {
|
||||
if (!is_vec_) {
|
||||
matmul_buffer_[0] = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->input_size_ * sizeof(float16_t)));
|
||||
if (matmul_buffer_[0] == nullptr) {
|
||||
for (int i = 0; i < 6; i++) {
|
||||
buffer_[i] = nullptr;
|
||||
}
|
||||
buffer_[0] = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(lstm_param_->input_row_align_ * lstm_param_->input_size_ * sizeof(float16_t)));
|
||||
if (buffer_[0] == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc input * weight left matirx error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
matmul_buffer_[1] = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (matmul_buffer_[1] == nullptr) {
|
||||
buffer_[1] = reinterpret_cast<float16_t *>(context_->allocator->Malloc(
|
||||
4 * lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (buffer_[1] == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state * weight left matirx error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (!is_vec_) {
|
||||
buffer_[2] = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (buffer_[2] == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state * weight left matirx error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
gate_buffer_ = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(8 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (gate_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc gate_buffer error.";
|
||||
buffer_[3] = reinterpret_cast<float16_t *>(
|
||||
context_->allocator->Malloc(4 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t)));
|
||||
if (buffer_[3] == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state gate buffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
state_buffer_[0] = nullptr;
|
||||
state_buffer_[1] = nullptr;
|
||||
|
||||
if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) {
|
||||
int buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t);
|
||||
state_buffer_[0] = reinterpret_cast<float16_t *>(context_->allocator->Malloc(buffer_size));
|
||||
if (state_buffer_[0] == nullptr) {
|
||||
buffer_[4] = reinterpret_cast<float16_t *>(context_->allocator->Malloc(buffer_size));
|
||||
if (buffer_[4] == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_buffer for cell error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) {
|
||||
int buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t);
|
||||
state_buffer_[1] = reinterpret_cast<float16_t *>(context_->allocator->Malloc(buffer_size));
|
||||
if (state_buffer_[1] == nullptr) {
|
||||
buffer_[5] = reinterpret_cast<float16_t *>(context_->allocator->Malloc(buffer_size));
|
||||
if (buffer_[5] == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_buffer for hidden error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -273,12 +307,11 @@ int LstmFp16CPUKernel::Run() {
|
|||
}
|
||||
MS_ASSERT(weight_i_ptr_);
|
||||
MS_ASSERT(weight_h_ptr_);
|
||||
MS_ASSERT(bias_ptr_);
|
||||
MS_ASSERT(gate_buffer_);
|
||||
LstmFp16(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_,
|
||||
MS_ASSERT(input_bias_);
|
||||
MS_ASSERT(state_bias_);
|
||||
LstmFp16(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_,
|
||||
reinterpret_cast<float16_t *>(output_hidden_state->data_c()),
|
||||
reinterpret_cast<float16_t *>(output_cell_state->data_c()), gate_buffer_, state_buffer_, matmul_buffer_,
|
||||
lstm_param_);
|
||||
reinterpret_cast<float16_t *>(output_cell_state->data_c()), buffer_, lstm_param_);
|
||||
FreeRunBuffer();
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -40,16 +40,16 @@ class LstmFp16CPUKernel : public LiteKernel {
|
|||
void FreeTmpBuffer();
|
||||
void FreeRunBuffer();
|
||||
int InitParam();
|
||||
int InitWeight(const lite::Tensor *tensor, float16_t *ptr, int deep);
|
||||
int InitWeightBias();
|
||||
int InitInputWeightBias();
|
||||
int InitStateWeightBias();
|
||||
int MallocRunBuffer();
|
||||
|
||||
float16_t *gate_buffer_ = nullptr;
|
||||
float16_t *state_buffer_[2];
|
||||
float16_t *weight_i_ptr_ = nullptr;
|
||||
float16_t *weight_h_ptr_ = nullptr;
|
||||
float16_t *bias_ptr_ = nullptr;
|
||||
float16_t *matmul_buffer_[2];
|
||||
float16_t *input_bias_ = nullptr;
|
||||
float16_t *state_bias_ = nullptr;
|
||||
float16_t *buffer_[6];
|
||||
int weight_batch_ = 0;
|
||||
bool is_vec_ = false;
|
||||
LstmParameter *lstm_param_ = nullptr;
|
||||
};
|
||||
|
|
|
@ -29,29 +29,33 @@ using mindspore::schema::PrimitiveType_GRU;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
void GruCPUKernel::FreeTmpBuffer() {
|
||||
if (!is_vec_) {
|
||||
if (weight_g_ptr_ != nullptr) {
|
||||
free(weight_g_ptr_);
|
||||
weight_g_ptr_ = nullptr;
|
||||
}
|
||||
if (input_bias_ != nullptr) {
|
||||
free(input_bias_);
|
||||
input_bias_ = nullptr;
|
||||
}
|
||||
if (!is_vec_) {
|
||||
if (weight_r_ptr_ != nullptr) {
|
||||
free(weight_r_ptr_);
|
||||
weight_r_ptr_ = nullptr;
|
||||
}
|
||||
if (bias_ptr_ != nullptr) {
|
||||
free(bias_ptr_);
|
||||
bias_ptr_ = nullptr;
|
||||
}
|
||||
if (state_bias_ != nullptr) {
|
||||
free(state_bias_);
|
||||
state_bias_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void GruCPUKernel::FreeRunBuffer() {
|
||||
context_->allocator->Free(gate_buffer_);
|
||||
context_->allocator->Free(buffer_[0]);
|
||||
context_->allocator->Free(buffer_[1]);
|
||||
if (!is_vec_) {
|
||||
for (int i = 0; i < 2; i++) {
|
||||
context_->allocator->Free(matmul_buffer_[i]);
|
||||
}
|
||||
context_->allocator->Free(buffer_[2]);
|
||||
}
|
||||
context_->allocator->Free(buffer_[3]);
|
||||
}
|
||||
|
||||
int GruCPUKernel::InitParam() {
|
||||
|
@ -67,9 +71,9 @@ int GruCPUKernel::InitParam() {
|
|||
std::vector<int> w_shape = weight_g->shape();
|
||||
gru_param_->hidden_size_ = w_shape.at(1) / 3;
|
||||
|
||||
gru_param_->input_step_ = gru_param_->batch_ * gru_param_->input_size_;
|
||||
gru_param_->output_step_ = gru_param_->bidirectional_ ? 2 * gru_param_->batch_ * gru_param_->hidden_size_
|
||||
: gru_param_->batch_ * gru_param_->hidden_size_;
|
||||
weight_batch_ = gru_param_->bidirectional_ ? 6 : 3;
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
row_tile_ = C6NUM;
|
||||
|
@ -84,56 +88,75 @@ int GruCPUKernel::InitParam() {
|
|||
row_tile_ = C12NUM;
|
||||
col_tile_ = C8NUM;
|
||||
#endif
|
||||
gru_param_->input_row_align_ = UP_ROUND(gru_param_->seq_len_ * gru_param_->batch_, row_tile_);
|
||||
gru_param_->input_col_align_ = UP_ROUND(gru_param_->hidden_size_, col_tile_);
|
||||
|
||||
is_vec_ = gru_param_->batch_ == 1;
|
||||
gru_param_->row_align_ = is_vec_ ? 1 : UP_ROUND(gru_param_->batch_, row_tile_);
|
||||
gru_param_->col_align_ = is_vec_ ? gru_param_->hidden_size_ : UP_ROUND(gru_param_->hidden_size_, col_tile_);
|
||||
gru_param_->state_row_align_ = is_vec_ ? 1 : UP_ROUND(gru_param_->batch_, row_tile_);
|
||||
gru_param_->state_col_align_ = is_vec_ ? gru_param_->hidden_size_ : UP_ROUND(gru_param_->hidden_size_, col_tile_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int GruCPUKernel::InitWeightBias() {
|
||||
auto weight_batch = gru_param_->bidirectional_ ? 6 : 3;
|
||||
if (!is_vec_) {
|
||||
int GruCPUKernel::InitInputWeightBias() {
|
||||
// malloc and init input * weight right matrix buffer
|
||||
// input -- row: seq_len * batch; col: input_size
|
||||
// weight -- row: hidden_size; col: input_size, need transpose
|
||||
// result -- row: seq_len * batch; col: hidden_size
|
||||
auto weight_g = in_tensors_.at(1);
|
||||
MS_ASSERT(weight_g != nullptr);
|
||||
weight_g_ptr_ = reinterpret_cast<float *>(
|
||||
malloc(weight_batch * gru_param_->col_align_ * gru_param_->input_size_ * sizeof(float)));
|
||||
malloc(weight_batch_ * gru_param_->input_col_align_ * gru_param_->input_size_ * sizeof(float)));
|
||||
if (weight_g_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruCPUKernel malloc weight_g_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto weight_i_data = reinterpret_cast<float *>(weight_g->data_c());
|
||||
PackLstmWeight(weight_g_ptr_, weight_i_data, weight_batch, gru_param_->input_size_, gru_param_->hidden_size_,
|
||||
gru_param_->col_align_);
|
||||
auto weight_g_data = reinterpret_cast<float *>(weight_g->data_c());
|
||||
PackLstmWeight(weight_g_ptr_, weight_g_data, weight_batch_, gru_param_->input_size_, gru_param_->hidden_size_,
|
||||
gru_param_->input_col_align_);
|
||||
|
||||
// malloc and init state * weight right matrix buffer
|
||||
// input bias
|
||||
input_bias_ = reinterpret_cast<float *>(malloc(weight_batch_ * gru_param_->input_col_align_ * sizeof(float)));
|
||||
if (input_bias_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruCPUKernel malloc input_bias_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(input_bias_, 0, weight_batch_ * gru_param_->input_col_align_ * sizeof(float));
|
||||
PackLstmBias(input_bias_, reinterpret_cast<float *>(in_tensors_.at(3)->data_c()), weight_batch_,
|
||||
gru_param_->hidden_size_, gru_param_->input_col_align_, gru_param_->bidirectional_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int GruCPUKernel::InitStateWeightBias() {
|
||||
// malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times.
|
||||
// state -- row: batch; col: hidden_size
|
||||
// weight -- row: hidden_size; col: hidden_size, need transpose
|
||||
// result -- row: batch; col: hidden_size
|
||||
auto weight_r = in_tensors_.at(2);
|
||||
MS_ASSERT(weight_r != nullptr);
|
||||
auto weight_r_data = reinterpret_cast<float *>(weight_r->data_c());
|
||||
if (!is_vec_) {
|
||||
weight_r_ptr_ = reinterpret_cast<float *>(
|
||||
malloc(weight_batch * gru_param_->col_align_ * gru_param_->hidden_size_ * sizeof(float)));
|
||||
malloc(weight_batch_ * gru_param_->state_col_align_ * gru_param_->hidden_size_ * sizeof(float)));
|
||||
if (weight_r_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruCPUKernel malloc weight_r_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto weight_r_data = reinterpret_cast<float *>(weight_r->data_c());
|
||||
PackLstmWeight(weight_r_ptr_, weight_r_data, weight_batch, gru_param_->hidden_size_, gru_param_->hidden_size_,
|
||||
gru_param_->col_align_);
|
||||
PackLstmWeight(weight_r_ptr_, weight_r_data, weight_batch_, gru_param_->hidden_size_, gru_param_->hidden_size_,
|
||||
gru_param_->state_col_align_);
|
||||
} else {
|
||||
weight_r_ptr_ = weight_r_data;
|
||||
}
|
||||
|
||||
// init bias
|
||||
int bias_batch = gru_param_->bidirectional_ ? 16 : 8;
|
||||
bias_ptr_ = reinterpret_cast<float *>(malloc(bias_batch * gru_param_->col_align_ * sizeof(float)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruCPUKernel malloc bias_ptr_ error.";
|
||||
// state bias
|
||||
state_bias_ = reinterpret_cast<float *>(malloc(weight_batch_ * gru_param_->state_col_align_ * sizeof(float)));
|
||||
if (state_bias_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruCPUKernel malloc state_bias_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(bias_ptr_, 0, bias_batch * gru_param_->col_align_ * sizeof(float));
|
||||
auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
|
||||
for (int i = 0; i < bias_batch; i++) {
|
||||
auto src_batch = bias_data + i * gru_param_->hidden_size_;
|
||||
auto dst_batch = bias_ptr_ + i * gru_param_->col_align_;
|
||||
memcpy(dst_batch, src_batch, gru_param_->hidden_size_ * sizeof(float));
|
||||
}
|
||||
}
|
||||
memset(state_bias_, 0, weight_batch_ * gru_param_->state_col_align_ * sizeof(float));
|
||||
auto state_bias = reinterpret_cast<float *>(in_tensors_.at(3)->data_c()) + 3 * gru_param_->hidden_size_;
|
||||
PackLstmBias(state_bias_, state_bias, weight_batch_, gru_param_->hidden_size_, gru_param_->state_col_align_,
|
||||
gru_param_->bidirectional_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -152,9 +175,16 @@ int GruCPUKernel::ReSize() {
|
|||
}
|
||||
|
||||
FreeTmpBuffer();
|
||||
ret = InitWeightBias();
|
||||
ret = InitInputWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "GruCPUKernel InitWeightBias error.";
|
||||
MS_LOG(ERROR) << "GruCPUKernel InitInputWeightBias error.";
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ret = InitStateWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "GruCPUKernel InitStateWeightBias error.";
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -162,25 +192,36 @@ int GruCPUKernel::ReSize() {
|
|||
}
|
||||
|
||||
int GruCPUKernel::MallocRunBuffer() {
|
||||
if (!is_vec_) {
|
||||
matmul_buffer_[0] = reinterpret_cast<float *>(
|
||||
context_->allocator->Malloc(3 * gru_param_->row_align_ * gru_param_->input_size_ * sizeof(float)));
|
||||
if (matmul_buffer_[0] == nullptr) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
buffer_[i] = nullptr;
|
||||
}
|
||||
buffer_[0] = reinterpret_cast<float *>(
|
||||
context_->allocator->Malloc(gru_param_->input_row_align_ * gru_param_->input_size_ * sizeof(float)));
|
||||
if (buffer_[0] == nullptr) {
|
||||
MS_LOG(ERROR) << "GruCPUKernel malloc input * weight left matirx error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
matmul_buffer_[1] = reinterpret_cast<float *>(
|
||||
context_->allocator->Malloc(3 * gru_param_->row_align_ * gru_param_->hidden_size_ * sizeof(float)));
|
||||
if (matmul_buffer_[1] == nullptr) {
|
||||
buffer_[1] = reinterpret_cast<float *>(context_->allocator->Malloc(3 * gru_param_->seq_len_ * gru_param_->batch_ *
|
||||
gru_param_->hidden_size_ * sizeof(float)));
|
||||
if (buffer_[1] == nullptr) {
|
||||
MS_LOG(ERROR) << "GruCPUKernel malloc input * weight result matirx error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (!is_vec_) {
|
||||
buffer_[2] = reinterpret_cast<float *>(
|
||||
context_->allocator->Malloc(gru_param_->state_row_align_ * gru_param_->hidden_size_ * sizeof(float)));
|
||||
if (buffer_[2] == nullptr) {
|
||||
MS_LOG(ERROR) << "GruCPUKernel malloc state * weight left matirx error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
gate_buffer_ = reinterpret_cast<float *>(
|
||||
context_->allocator->Malloc(6 * gru_param_->batch_ * gru_param_->hidden_size_ * sizeof(float)));
|
||||
if (gate_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "GruCPUKernel malloc gate_buffer error.";
|
||||
|
||||
buffer_[3] = reinterpret_cast<float *>(
|
||||
context_->allocator->Malloc(3 * gru_param_->batch_ * gru_param_->hidden_size_ * sizeof(float)));
|
||||
if (buffer_[3] == nullptr) {
|
||||
MS_LOG(ERROR) << "GruCPUKernel malloc state gate buffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
|
@ -215,18 +256,12 @@ int GruCPUKernel::Run() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (is_vec_) {
|
||||
weight_g_ptr_ = reinterpret_cast<float *>(in_tensors_[1]->data_c());
|
||||
weight_r_ptr_ = reinterpret_cast<float *>(in_tensors_[2]->data_c());
|
||||
bias_ptr_ = reinterpret_cast<float *>(in_tensors_[3]->data_c());
|
||||
}
|
||||
MS_ASSERT(weight_g_ptr_ != nullptr);
|
||||
MS_ASSERT(weight_r_ptr_ != nullptr);
|
||||
MS_ASSERT(bias_ptr_ != nullptr);
|
||||
MS_ASSERT(gate_buffer_ != nullptr);
|
||||
Gru(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_,
|
||||
reinterpret_cast<float *>(output_hidden_state->data_c()), gate_buffer_, matmul_buffer_, check_seq_len,
|
||||
gru_param_);
|
||||
MS_ASSERT(input_bias_ != nullptr);
|
||||
MS_ASSERT(state_bias_ != nullptr);
|
||||
Gru(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, input_bias_, state_bias_,
|
||||
reinterpret_cast<float *>(output_hidden_state->data_c()), buffer_, check_seq_len, gru_param_);
|
||||
FreeRunBuffer();
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -39,15 +39,17 @@ class GruCPUKernel : public LiteKernel {
|
|||
void FreeRunBuffer();
|
||||
int InitParam();
|
||||
int MallocRunBuffer();
|
||||
int InitWeightBias();
|
||||
int InitInputWeightBias();
|
||||
int InitStateWeightBias();
|
||||
|
||||
float *gate_buffer_ = nullptr;
|
||||
float *weight_g_ptr_ = nullptr;
|
||||
float *weight_r_ptr_ = nullptr;
|
||||
float *bias_ptr_ = nullptr;
|
||||
float *matmul_buffer_[2];
|
||||
float *input_bias_ = nullptr;
|
||||
float *state_bias_ = nullptr;
|
||||
float *buffer_[4];
|
||||
int row_tile_ = 0;
|
||||
int col_tile_ = 0;
|
||||
int weight_batch_ = 0;
|
||||
bool is_vec_ = false;
|
||||
GruParameter *gru_param_ = nullptr;
|
||||
};
|
||||
|
|
|
@ -52,15 +52,18 @@ void LstmCPUKernel::FreeTmpBuffer() {
|
|||
}
|
||||
|
||||
void LstmCPUKernel::FreeRunBuffer() {
|
||||
for (int i = 0; i < 2; i++) {
|
||||
context_->allocator->Free(state_buffer_[i]);
|
||||
}
|
||||
context_->allocator->Free(buffer_[0]);
|
||||
context_->allocator->Free(buffer_[1]);
|
||||
if (!is_vec_) {
|
||||
context_->allocator->Free(buffer_[2]);
|
||||
}
|
||||
context_->allocator->Free(buffer_[3]);
|
||||
if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) {
|
||||
context_->allocator->Free(buffer_[4]);
|
||||
}
|
||||
if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) {
|
||||
context_->allocator->Free(buffer_[5]);
|
||||
}
|
||||
}
|
||||
|
||||
int LstmCPUKernel::InitInputWeightBias() {
|
||||
|
@ -197,7 +200,7 @@ int LstmCPUKernel::ReSize() {
|
|||
}
|
||||
|
||||
int LstmCPUKernel::MallocRunBuffer() {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (int i = 0; i < 6; i++) {
|
||||
buffer_[i] = nullptr;
|
||||
}
|
||||
buffer_[0] = reinterpret_cast<float *>(
|
||||
|
@ -216,7 +219,7 @@ int LstmCPUKernel::MallocRunBuffer() {
|
|||
|
||||
if (!is_vec_) {
|
||||
buffer_[2] = reinterpret_cast<float *>(
|
||||
context_->allocator->Malloc(4 * lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * sizeof(float)));
|
||||
context_->allocator->Malloc(lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * sizeof(float)));
|
||||
if (buffer_[2] == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc state * weight left matirx error.";
|
||||
return RET_ERROR;
|
||||
|
@ -229,20 +232,19 @@ int LstmCPUKernel::MallocRunBuffer() {
|
|||
MS_LOG(ERROR) << "LstmCPUKernel malloc state gate buffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
state_buffer_[0] = nullptr;
|
||||
state_buffer_[1] = nullptr;
|
||||
|
||||
if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) {
|
||||
auto buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float);
|
||||
state_buffer_[0] = reinterpret_cast<float *>(context_->allocator->Malloc(buffer_size));
|
||||
if (state_buffer_[0] == nullptr) {
|
||||
buffer_[4] = reinterpret_cast<float *>(context_->allocator->Malloc(buffer_size));
|
||||
if (buffer_[4] == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer for cell error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) {
|
||||
auto buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float);
|
||||
state_buffer_[1] = reinterpret_cast<float *>(context_->allocator->Malloc(buffer_size));
|
||||
if (state_buffer_[1] == nullptr) {
|
||||
buffer_[5] = reinterpret_cast<float *>(context_->allocator->Malloc(buffer_size));
|
||||
if (buffer_[5] == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer for hidden error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -281,7 +283,7 @@ int LstmCPUKernel::Run() {
|
|||
MS_ASSERT(state_bias_);
|
||||
Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_,
|
||||
reinterpret_cast<float *>(output_hidden_state->data_c()), reinterpret_cast<float *>(output_cell_state->data_c()),
|
||||
state_buffer_, buffer_, lstm_param_);
|
||||
buffer_, lstm_param_);
|
||||
FreeRunBuffer();
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -44,12 +44,11 @@ class LstmCPUKernel : public LiteKernel {
|
|||
int InitInputWeightBias();
|
||||
int InitStateWeightBias();
|
||||
|
||||
float *state_buffer_[2];
|
||||
float *weight_i_ptr_ = nullptr;
|
||||
float *weight_h_ptr_ = nullptr;
|
||||
float *input_bias_ = nullptr;
|
||||
float *state_bias_ = nullptr;
|
||||
float *buffer_[4];
|
||||
float *buffer_[6];
|
||||
int row_tile_ = 0;
|
||||
int col_tile_ = 0;
|
||||
int weight_batch_ = 0;
|
||||
|
|
Loading…
Reference in New Issue