[MSLITE][Develop] optimize arm cpu fp32 op gru: use matmul

Author: yangruoqi713  2021-02-03 15:56:25 +08:00
commit 53568e0e8b (parent 4d866ea64a)
8 changed files with 279 additions and 213 deletions


@@ -18,115 +18,126 @@
#include "nnacl/fp32/lstm_fp32.h"
#include "nnacl/fp32/activation_fp32.h"
#include "nnacl/fp32/arithmetic_fp32.h"
#include "nnacl/fp32/matmul_fp32.h"
void InitGruGate(float *gate_buffer, const float *bias, const GruParameter *gru_parm) {
int gate_offest = 0;
for (int l = 0; l < 3; l++) {
int batch_offest = gate_offest;
int bias_offest = l * gru_parm->hidden_size_;
for (int b = 0; b < gru_parm->batch_; b++) {
memcpy(gate_buffer + batch_offest, bias + bias_offest, gru_parm->hidden_size_ * sizeof(float));
batch_offest += gru_parm->hidden_size_;
}
gate_offest += gru_parm->batch_ * gru_parm->hidden_size_;
void UpdateGruInputGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row,
int deep, int col, int col_align, bool is_vec) {
for (int i = 0; i < 3; i++) {
const float *weight_i = weight + deep * col * i;
const float *bias_i = bias + col_align * i;
float *gate = gate_buffer + row * col * i;
LstmMatMul(gate, input, weight_i, bias_i, row, deep, col, is_vec);
}
}
void GruStepUnit(float *output, const float *input, const float *input_reset_weight, const float *input_update_weight,
const float *input_hidden_weight, const float *state_reset_weight, const float *state_update_weight,
const float *state_hidden_weight, const float *bias, float *hidden_state, float *gate_buffer,
const GruParameter *gru_parm) {
InitGruGate(gate_buffer, bias, gru_parm);
float *update_gate = gate_buffer;
float *reset_gate = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_;
float *hidden_buffer = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_ * 2;
void GruStepUnit(float *output, const float *input, const float *input_weight, const float *state_weight,
const float *bias, float *hidden_state, float *gate_buffer, float *matmul_buffer[2],
const GruParameter *gru_param) {
bool is_vec = gru_param->batch_ == 1;
// input * weight
MatMulAcc(reset_gate, input, input_reset_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
MatMulAcc(update_gate, input, input_update_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
MatMulAcc(hidden_buffer, input, input_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
if (is_vec) {
UpdateGruInputGate(gate_buffer, input, input_weight, bias, gru_param->batch_, gru_param->input_size_,
gru_param->hidden_size_, gru_param->col_align_, is_vec);
} else {
// pack input for matmul
PackLstmInput(matmul_buffer[0], input, gru_param->batch_, gru_param->input_size_);
UpdateGruInputGate(gate_buffer, matmul_buffer[0], input_weight, bias, gru_param->batch_, gru_param->input_size_,
gru_param->hidden_size_, gru_param->col_align_, is_vec);
}
const float *state_update_weight = state_weight;
const float *state_reset_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_;
const float *state_hidden_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_ * 2;
float *state_update_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 3;
float *state_reset_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 4;
float *state_hidden_buffer = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 5;
const float *state_update_bias = bias + gru_param->hidden_size_ * 3;
const float *state_reset_bias = bias + gru_param->hidden_size_ * 4;
const float *state_hidden_bias = bias + gru_param->hidden_size_ * 5;
// state * weight
MatMulAcc(reset_gate, hidden_state, state_reset_weight, gru_parm->batch_, gru_parm->hidden_size_,
gru_parm->hidden_size_);
MatMulAcc(update_gate, hidden_state, state_update_weight, gru_parm->batch_, gru_parm->hidden_size_,
gru_parm->hidden_size_);
if (is_vec) {
LstmMatMul(state_reset_gate, hidden_state, state_reset_weight, state_reset_bias, gru_param->batch_,
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
LstmMatMul(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_,
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
} else {
PackLstmInput(matmul_buffer[1], hidden_state, gru_param->batch_, gru_param->hidden_size_);
LstmMatMul(state_reset_gate, matmul_buffer[1], state_reset_weight, state_reset_bias, gru_param->batch_,
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
LstmMatMul(state_update_gate, matmul_buffer[1], state_update_weight, state_update_bias, gru_param->batch_,
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
}
ElementAdd(gate_buffer, state_update_gate, gate_buffer, gru_param->batch_ * gru_param->hidden_size_ * 2);
float *update_gate = gate_buffer;
float *reset_gate = gate_buffer + gru_param->batch_ * gru_param->hidden_size_;
float *hidden_buffer = gate_buffer + gru_param->batch_ * gru_param->hidden_size_ * 2;
// update reset_gate
Sigmoid(reset_gate, gru_parm->batch_ * gru_parm->hidden_size_, reset_gate);
Sigmoid(reset_gate, gru_param->batch_ * gru_param->hidden_size_, reset_gate);
// update update_gate
Sigmoid(update_gate, gru_parm->batch_ * gru_parm->hidden_size_, update_gate);
Sigmoid(update_gate, gru_param->batch_ * gru_param->hidden_size_, update_gate);
ElementMul(hidden_state, reset_gate, reset_gate, gru_parm->batch_ * gru_parm->hidden_size_);
MatMulAcc(hidden_buffer, reset_gate, state_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_,
gru_parm->hidden_size_);
ElementMul(hidden_state, reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_);
if (is_vec) {
LstmMatMul(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_,
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
} else {
PackLstmInput(matmul_buffer[1], reset_gate, gru_param->batch_, gru_param->hidden_size_);
LstmMatMul(state_hidden_buffer, matmul_buffer[1], state_hidden_weight, state_hidden_bias, gru_param->batch_,
gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
}
ElementAdd(hidden_buffer, state_hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_);
Tanh(hidden_buffer, gru_parm->batch_ * gru_parm->hidden_size_, hidden_buffer);
Tanh(hidden_buffer, gru_param->batch_ * gru_param->hidden_size_, hidden_buffer);
ElementMul(update_gate, hidden_state, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);
ElementMul(update_gate, hidden_state, hidden_state, gru_param->batch_ * gru_param->hidden_size_);
ArithmeticParameter parameter;
parameter.in_elements_num0_ = 1;
parameter.in_elements_num1_ = gru_parm->batch_ * gru_parm->hidden_size_;
parameter.in_elements_num1_ = gru_param->batch_ * gru_param->hidden_size_;
const float one = 1.0f;
ElementOptSub(&one, update_gate, update_gate, gru_parm->batch_ * gru_parm->hidden_size_, &parameter);
ElementOptSub(&one, update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_, &parameter);
ElementMulAcc(update_gate, hidden_buffer, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);
ElementMulAcc(update_gate, hidden_buffer, hidden_state, gru_param->batch_ * gru_param->hidden_size_);
memcpy(output, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_ * sizeof(float));
memcpy(output, hidden_state, gru_param->batch_ * gru_param->hidden_size_ * sizeof(float));
}
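Note: GruStepUnit above is the standard GRU cell. In conventional notation (symbols are ours, not from the source; ⊙ is element-wise multiply), with z the update gate, r the reset gate and n the candidate state:

  z_t = sigmoid(W_z x_t + U_z h_{t-1} + b_z)
  r_t = sigmoid(W_r x_t + U_r h_{t-1} + b_r)
  n_t = tanh(W_n x_t + U_n (r_t ⊙ h_{t-1}) + b_n)
  h_t = z_t ⊙ h_{t-1} + (1 - z_t) ⊙ n_t

The W* x_t products come from the single UpdateGruInputGate call, the U* terms from the state-side matmuls, and the remaining element-wise steps map one-to-one onto the Sigmoid/Tanh/ElementMul/ElementMulAcc calls. Each b here is the sum of the separate input-side and state-side bias vectors that the code keeps apart.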
void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias,
float *hidden_state, float *gate_buffer, int check_seq_len, const GruParameter *gru_parm) {
float *hidden_state, float *gate_buffer, float *matmul_buffer[2], int check_seq_len,
const GruParameter *gru_param) {
// forward
const float *input_update_weight = weight_g;
const float *input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_;
const float *input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 2;
const float *state_update_weight = weight_r;
const float *state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_;
const float *state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 2;
for (int t = 0; t < check_seq_len; t++) {
const float *input_ptr = input + t * gru_parm->input_step_;
float *output_ptr = output + t * gru_parm->output_step_;
GruStepUnit(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight, state_reset_weight,
state_update_weight, state_hidden_weight, bias, hidden_state, gate_buffer, gru_parm);
const float *input_ptr = input + t * gru_param->input_step_;
float *output_ptr = output + t * gru_param->output_step_;
GruStepUnit(output_ptr, input_ptr, weight_g, weight_r, bias, hidden_state, gate_buffer, matmul_buffer, gru_param);
}
// zero out extra fw outputs
for (int t = check_seq_len; t < gru_parm->seq_len_; t++) {
float *output_ptr = output + t * gru_parm->output_step_;
for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
for (int t = check_seq_len; t < gru_param->seq_len_; t++) {
float *output_ptr = output + t * gru_param->output_step_;
for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) {
output_ptr[i] = 0.0f;
}
}
// backward
if (gru_parm->bidirectional_) {
input_update_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 3;
input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 4;
input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 5;
state_update_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 3;
state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 4;
state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 5;
float *backward_output = output + gru_parm->batch_ * gru_parm->hidden_size_;
const float *backward_bias = bias + 3 * gru_parm->hidden_size_;
float *backward_hidden_state = hidden_state + gru_parm->batch_ * gru_parm->hidden_size_;
if (gru_param->bidirectional_) {
const float *backward_weight_g = weight_g + 3 * gru_param->col_align_ * gru_param->input_size_;
const float *backward_weight_r = weight_r + 3 * gru_param->col_align_ * gru_param->hidden_size_;
const float *backward_bias = bias + 6 * gru_param->hidden_size_;
float *backward_output = output + gru_param->batch_ * gru_param->hidden_size_;
float *backward_hidden_state = hidden_state + gru_param->batch_ * gru_param->hidden_size_;
for (int t = check_seq_len - 1; t >= 0; t--) {
const float *input_ptr = input + t * gru_parm->input_step_;
float *output_ptr = backward_output + t * gru_parm->output_step_;
GruStepUnit(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight,
state_reset_weight, state_update_weight, state_hidden_weight, backward_bias, backward_hidden_state,
gate_buffer, gru_parm);
const float *input_ptr = input + t * gru_param->input_step_;
float *output_ptr = backward_output + t * gru_param->output_step_;
GruStepUnit(output_ptr, input_ptr, backward_weight_g, backward_weight_r, backward_bias, backward_hidden_state,
gate_buffer, matmul_buffer, gru_param);
}
// zero out extra bw outputs
for (int t = gru_parm->seq_len_ - 1; t >= check_seq_len; t--) {
float *output_ptr = backward_output + t * gru_parm->output_step_;
for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
for (int t = gru_param->seq_len_ - 1; t >= check_seq_len; t--) {
float *output_ptr = backward_output + t * gru_param->output_step_;
for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) {
output_ptr[i] = 0.0f;
}
}
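Note: the gist of the change in this file is that the old per-element MatMulAcc accumulation is replaced by the packed, tiled matmul path already used by the LSTM kernels. A rough sketch of the non-vector path (batch_ > 1), with hypothetical buffer names but only calls that appear in this commit:

  // pack the activations (left matrix) into tile order once per step
  PackLstmInput(packed_input, input, batch, input_size);
  // packed_weight and bias were prepared at init time with PackLstmWeight and the bias-packing loop
  MatMulOpt(packed_input, packed_weight, gate, bias, ActType_No,
            input_size /* deep */, batch /* row */, hidden_size /* col */,
            hidden_size /* stride */, OutType_Nhwc);

When batch_ == 1 (is_vec) no packing is done: LstmMatMul memcpy's the bias into the gate buffer and accumulates with MatMulAcc, keeping the simple matrix-vector path.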


@@ -21,7 +21,8 @@
extern "C" {
#endif
void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *bias,
float *hidden_state, float *gate_buffer, int check_seq_len, const GruParameter *gru_parm);
float *hidden_state, float *gate_buffer, float *matmul_buffer[2], int check_seq_len,
const GruParameter *gru_parm);
#ifdef __cplusplus
}
#endif


@@ -21,6 +21,30 @@
#include "nnacl/fp32/arithmetic_fp32.h"
#include "nnacl/fp32/matmul_fp32.h"
void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align) {
for (int i = 0; i < batch; i++) {
const float *src_batch = src + i * col * deep;
float *dst_batch = dst + i * col_align * deep;
#ifdef ENABLE_AVX
RowMajor2Col16Major(src_batch, dst_batch, col, deep);
#elif defined(ENABLE_ARM32)
RowMajor2Col4Major(src_batch, dst_batch, col, deep);
#else
RowMajor2Col8Major(src_batch, dst_batch, col, deep);
#endif
}
}
void PackLstmInput(float *dst, const float *src, int row, int deep) {
#ifdef ENABLE_AVX
RowMajor2Col6Major(src, dst, row, deep);
#elif defined(ENABLE_SSE)
RowMajor2Col4Major(src, dst, row, deep);
#else
RowMajor2Col12Major(src, dst, row, deep);
#endif
}
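Note: PackLstmWeight lays out each gate's [col, deep] weight block in the column-tile format MatMulOpt expects for its right-hand matrix, padding col up to col_align, while PackLstmInput does the row-tile packing of the activations. A sizing sketch for the default (non-AVX/SSE/ARM32) branch, with illustrative values; UP_ROUND and the C*NUM tile constants come from nnacl, and weight_data stands for the raw weight tensor:

  int deep = 64, hidden_size = 30, weight_batch = 3;     // e.g. one direction of a GRU
  int col_align = UP_ROUND(hidden_size, C8NUM);          // 30 -> 32
  float *packed_w = (float *)malloc(weight_batch * col_align * deep * sizeof(float));
  if (packed_w != NULL) {
    PackLstmWeight(packed_w, weight_data, weight_batch, deep, hidden_size, col_align);
  }

weight_batch is 3 per direction for GRU and 4 for LSTM, which is exactly how the kernels below size their packed weight buffers.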
// input: [row, inner_size]; weight: [col, inner_size]; output: [row, col]
void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size) {
for (int r = 0; r < rows; r++) {
@@ -52,6 +76,15 @@ void MatMulAcc(float *output, const float *input, const float *weight, int rows,
}
}
void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) {
if (is_vec) {
memcpy(c, bias, col * sizeof(float));
MatMulAcc(c, a, b, row, col, deep);
} else {
MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc);
}
}
void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size) {
int index = 0;
#ifdef ENABLE_ARM
@@ -121,74 +154,42 @@ void UpdataOutput(const float *cell_state, const float *output_gate, float *hidd
}
}
void LstmMatmul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) {
if (is_vec) {
memcpy(c, bias, col * sizeof(float));
MatMulAcc(c, a, b, row, col, deep);
} else {
MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc);
void UpdateLstmGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep,
int col, int col_align, bool is_vec) {
for (int i = 0; i < 4; i++) {
const float *weight_i = weight + deep * col * i;
const float *bias_i = bias + col_align * i;
float *gate = gate_buffer + row * col * i;
LstmMatMul(gate, input, weight_i, bias_i, row, deep, col, is_vec);
}
}
void PackLstmInput(float *dst, const float *src, int row, int deep) {
#ifdef ENABLE_AVX
RowMajor2Col6Major(src, dst, row, deep);
#elif defined(ENABLE_SSE)
RowMajor2Col4Major(src, dst, row, deep);
#else
RowMajor2Col12Major(src, dst, row, deep);
#endif
}
void UpdateGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep,
int col, int col_align, bool is_vec) {
const float *input_weight = weight;
const float *forget_weight = weight + deep * col * 2;
const float *cell_weight = weight + deep * col * 3;
const float *output_weight = weight + deep * col;
const float *input_bias = bias;
const float *forget_bias = bias + col_align * 2;
const float *cell_bias = bias + col_align * 3;
const float *output_bias = bias + col_align;
float *input_gate = gate_buffer;
float *forget_gate = gate_buffer + row * col * 2;
float *cell_gate = gate_buffer + row * col * 3;
float *output_gate = gate_buffer + row * col;
LstmMatmul(input_gate, input, input_weight, input_bias, row, deep, col, is_vec);
LstmMatmul(forget_gate, input, forget_weight, forget_bias, row, deep, col, is_vec);
LstmMatmul(cell_gate, input, cell_weight, cell_bias, row, deep, col, is_vec);
LstmMatmul(output_gate, input, output_weight, output_bias, row, deep, col, is_vec);
}
void LstmStepUnit(float *output, const float *input, const float *input_weight, const float *state_weight,
const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
float *matmul_buffer[2], const LstmParameter *lstm_param) {
bool is_vec = lstm_param->batch_ == 1;
// input * weight
if (is_vec) {
UpdateGate(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
UpdateLstmGate(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
} else {
// pack input for matmul
PackLstmInput(matmul_buffer[0], input, lstm_param->batch_, lstm_param->input_size_);
UpdateGate(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
UpdateLstmGate(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
}
// state * weight
float *state_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 4;
const float *state_bias = bias + lstm_param->col_align_ * 4;
if (is_vec) {
UpdateGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
UpdateLstmGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
} else {
// pack state for matmul
PackLstmInput(matmul_buffer[1], hidden_state, lstm_param->batch_, lstm_param->hidden_size_);
UpdateGate(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
UpdateLstmGate(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
lstm_param->hidden_size_, lstm_param->col_align_, is_vec);
}
ElementAdd(gate_buffer, state_gate, gate_buffer, 4 * lstm_param->batch_ * lstm_param->hidden_size_);


@@ -21,7 +21,11 @@
#ifdef __cplusplus
extern "C" {
#endif
void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size);
void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align);
void PackLstmInput(float *dst, const float *src, int row, int deep);
void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec);
void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size);


@@ -30,6 +30,8 @@ typedef struct GruParameter {
int input_step_;
int output_step_;
bool bidirectional_;
int col_align_;
int row_align_;
} GruParameter;
#endif // MINDSPORE_LITE_NNACL_GRU_PARAMETER_H_


@@ -19,6 +19,7 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "nnacl/fp32/gru_fp32.h"
#include "nnacl/fp32/lstm_fp32.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@@ -28,82 +29,109 @@ using mindspore::schema::PrimitiveType_Gru;
namespace mindspore::kernel {
void GruCPUKernel::FreeTmpBuffer() {
if (gate_buffer_ != nullptr) {
free(gate_buffer_);
gate_buffer_ = nullptr;
if (!is_vec_) {
if (weight_g_ptr_ != nullptr) {
free(weight_g_ptr_);
weight_g_ptr_ = nullptr;
}
if (weight_r_ptr_ != nullptr) {
free(weight_r_ptr_);
weight_r_ptr_ = nullptr;
}
if (bias_ptr_ != nullptr) {
free(bias_ptr_);
bias_ptr_ = nullptr;
}
}
if (bias_ptr_ != nullptr) {
free(bias_ptr_);
bias_ptr_ = nullptr;
}
void GruCPUKernel::FreeRunBuffer() {
context_->allocator->Free(gate_buffer_);
if (!is_vec_) {
for (int i = 0; i < 2; i++) {
context_->allocator->Free(matmul_buffer_[i]);
}
}
weight_g_ptr_ = nullptr;
weight_r_ptr_ = nullptr;
}
int GruCPUKernel::InitParam() {
auto input = in_tensors_.front();
MS_ASSERT(input != nullptr);
std::vector<int> in_shape = input->shape();
gru_parm_->seq_len_ = in_shape.at(0);
gru_parm_->batch_ = in_shape.at(1);
gru_parm_->input_size_ = in_shape.at(2);
gru_param_->seq_len_ = in_shape.at(0);
gru_param_->batch_ = in_shape.at(1);
gru_param_->input_size_ = in_shape.at(2);
auto weight_g = in_tensors_.at(1);
MS_ASSERT(weight_g != nullptr);
std::vector<int> w_shape = weight_g->shape();
gru_parm_->hidden_size_ = w_shape.at(1) / 3;
gru_param_->hidden_size_ = w_shape.at(1) / 3;
gru_parm_->input_step_ = gru_parm_->batch_ * gru_parm_->input_size_;
gru_parm_->output_step_ = gru_parm_->bidirectional_ ? 2 * gru_parm_->batch_ * gru_parm_->hidden_size_
: gru_parm_->batch_ * gru_parm_->hidden_size_;
return RET_OK;
}
gru_param_->input_step_ = gru_param_->batch_ * gru_param_->input_size_;
gru_param_->output_step_ = gru_param_->bidirectional_ ? 2 * gru_param_->batch_ * gru_param_->hidden_size_
: gru_param_->batch_ * gru_param_->hidden_size_;
int GruCPUKernel::InitBuffer() {
gate_buffer_ = reinterpret_cast<float *>(malloc(3 * gru_parm_->batch_ * gru_parm_->hidden_size_ * sizeof(float)));
if (gate_buffer_ == nullptr) {
MS_LOG(ERROR) << "GruCPUKernel malloc gate_buffer error.";
return RET_ERROR;
}
#ifdef ENABLE_AVX
row_tile_ = C6NUM;
col_tile_ = C16NUM;
#elif defined(ENABLE_ARM32)
row_tile_ = C12NUM;
col_tile_ = C4NUM;
#elif defined(ENABLE_SSE)
row_tile_ = C4NUM;
col_tile_ = C8NUM;
#else
row_tile_ = C12NUM;
col_tile_ = C8NUM;
#endif
is_vec_ = gru_param_->batch_ == 1;
gru_param_->row_align_ = is_vec_ ? 1 : UP_ROUND(gru_param_->batch_, row_tile_);
gru_param_->col_align_ = is_vec_ ? gru_param_->hidden_size_ : UP_ROUND(gru_param_->hidden_size_, col_tile_);
return RET_OK;
}
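Note: a worked example of what InitParam now sets up (illustrative numbers, not from the source), taking the default branch (e.g. ARM64 NEON builds):

  // batch_ = 4, hidden_size_ = 30  ->  row_tile_ = C12NUM, col_tile_ = C8NUM
  // row_align_ = UP_ROUND(4, 12)  = 12   // rows of the packed left matrices
  // col_align_ = UP_ROUND(30, 8)  = 32   // cols of the packed weights and bias slots
  // batch_ = 1                    ->  is_vec_: no packing, row_align_ = 1, col_align_ = 30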
int GruCPUKernel::InitWeightBias() {
auto weight_gate = in_tensors_.at(1);
MS_ASSERT(weight_gate != nullptr);
weight_g_ptr_ = reinterpret_cast<float *>(malloc(weight_gate->ElementsNum() * sizeof(float)));
if (weight_g_ptr_ == nullptr) {
MS_LOG(ERROR) << "GruCPUKernel malloc weight_g_ptr_ error.";
return RET_ERROR;
}
memcpy(weight_g_ptr_, weight_gate->data_c(), weight_gate->ElementsNum() * sizeof(float));
auto weight_batch = gru_param_->bidirectional_ ? 6 : 3;
if (!is_vec_) {
// malloc and init input * weight right matrix buffer
auto weight_g = in_tensors_.at(1);
MS_ASSERT(weight_g != nullptr);
weight_g_ptr_ = reinterpret_cast<float *>(
malloc(weight_batch * gru_param_->col_align_ * gru_param_->input_size_ * sizeof(float)));
if (weight_g_ptr_ == nullptr) {
MS_LOG(ERROR) << "GruCPUKernel malloc weight_g_ptr_ error.";
return RET_ERROR;
}
auto weight_i_data = reinterpret_cast<float *>(weight_g->data_c());
PackLstmWeight(weight_g_ptr_, weight_i_data, weight_batch, gru_param_->input_size_, gru_param_->hidden_size_,
gru_param_->col_align_);
auto weight_recu = in_tensors_.at(2);
MS_ASSERT(weight_recu != nullptr);
weight_r_ptr_ = reinterpret_cast<float *>(malloc(weight_recu->ElementsNum() * sizeof(float)));
if (weight_r_ptr_ == nullptr) {
MS_LOG(ERROR) << "GruCPUKernel malloc weight_r_ptr_ error.";
return RET_ERROR;
}
memcpy(weight_r_ptr_, weight_recu->data_c(), weight_recu->ElementsNum() * sizeof(float));
// malloc and init state * weight right matrix buffer
auto weight_r = in_tensors_.at(2);
MS_ASSERT(weight_r != nullptr);
weight_r_ptr_ = reinterpret_cast<float *>(
malloc(weight_batch * gru_param_->col_align_ * gru_param_->hidden_size_ * sizeof(float)));
if (weight_r_ptr_ == nullptr) {
MS_LOG(ERROR) << "GruCPUKernel malloc weight_r_ptr_ error.";
return RET_ERROR;
}
auto weight_r_data = reinterpret_cast<float *>(weight_r->data_c());
PackLstmWeight(weight_r_ptr_, weight_r_data, weight_batch, gru_param_->hidden_size_, gru_param_->hidden_size_,
gru_param_->col_align_);
int bias_num = gru_parm_->bidirectional_ ? 2 * 3 * gru_parm_->hidden_size_ : 3 * gru_parm_->hidden_size_;
bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "GruCPUKernel malloc bias_ptr_ error.";
return RET_ERROR;
}
auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
const int state_bias_offset = 3 * gru_parm_->hidden_size_;
for (int i = 0; i < state_bias_offset; i++) {
bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset];
}
if (gru_parm_->bidirectional_) {
bias_data += 3 * gru_parm_->hidden_size_ * 2;
auto backward_bias = bias_ptr_ + 3 * gru_parm_->hidden_size_;
for (int i = 0; i < state_bias_offset; i++) {
backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset];
// init bias
int bias_batch = gru_param_->bidirectional_ ? 16 : 8;
bias_ptr_ = reinterpret_cast<float *>(malloc(bias_batch * gru_param_->col_align_ * sizeof(float)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "GruCPUKernel malloc bias_ptr_ error.";
return RET_ERROR;
}
memset(bias_ptr_, 0, bias_batch * gru_param_->col_align_ * sizeof(float));
auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
for (int i = 0; i < bias_batch; i++) {
auto src_batch = bias_data + i * gru_param_->hidden_size_;
auto dst_batch = bias_ptr_ + i * gru_param_->col_align_;
memcpy(dst_batch, src_batch, gru_param_->hidden_size_ * sizeof(float));
}
}
return RET_OK;
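Note: the copy loop above zero-pads each hidden_size_-long bias vector into its own col_align_-wide slot, so the per-gate pointer arithmetic bias + col_align * i in UpdateGruInputGate / UpdateLstmGate lands on a slot boundary. Schematically (widths illustrative):

  // bias_ptr_ after packing, one slot per gate bias vector:
  // [ b0 (hidden_size_ floats) | zeros ][ b1 | zeros ] ...  each slot col_align_ floats wide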
@@ -117,24 +145,42 @@ int GruCPUKernel::Init() {
}
int GruCPUKernel::ReSize() {
FreeTmpBuffer();
auto ret = InitParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "GruCPUKernel InitParam error.";
return RET_ERROR;
}
FreeTmpBuffer();
ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "GruCPUKernel InitWeightBias error.";
FreeTmpBuffer();
return RET_ERROR;
}
return RET_OK;
}
ret = InitBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "GruCPUKernel InitBuffer error.";
FreeTmpBuffer();
int GruCPUKernel::MallocRunBuffer() {
if (!is_vec_) {
matmul_buffer_[0] = reinterpret_cast<float *>(
context_->allocator->Malloc(3 * gru_param_->row_align_ * gru_param_->input_size_ * sizeof(float)));
if (matmul_buffer_[0] == nullptr) {
MS_LOG(ERROR) << "GruCPUKernel malloc input * weight left matirx error.";
return RET_ERROR;
}
matmul_buffer_[1] = reinterpret_cast<float *>(
context_->allocator->Malloc(3 * gru_param_->row_align_ * gru_param_->hidden_size_ * sizeof(float)));
if (matmul_buffer_[1] == nullptr) {
MS_LOG(ERROR) << "GruCPUKernel malloc state * weight left matirx error.";
return RET_ERROR;
}
}
gate_buffer_ = reinterpret_cast<float *>(
context_->allocator->Malloc(6 * gru_param_->batch_ * gru_param_->hidden_size_ * sizeof(float)));
if (gate_buffer_ == nullptr) {
MS_LOG(ERROR) << "GruCPUKernel malloc gate_buffer error.";
return RET_ERROR;
}
return RET_OK;
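Note: the 6 * batch_ * hidden_size_ sizing matches how GruStepUnit slices the gate buffer, three input-side products followed by three state-side products (offsets copied from the code above; batch and hidden stand for gru_param->batch_ and gru_param->hidden_size_):

  float *update_gate         = gate_buffer;                        // W_z * x_t
  float *reset_gate          = gate_buffer + batch * hidden * 1;   // W_r * x_t
  float *hidden_buffer       = gate_buffer + batch * hidden * 2;   // W_n * x_t
  float *state_update_gate   = gate_buffer + batch * hidden * 3;   // U_z * h_{t-1}
  float *state_reset_gate    = gate_buffer + batch * hidden * 4;   // U_r * h_{t-1}
  float *state_hidden_buffer = gate_buffer + batch * hidden * 5;   // U_n * (r_t ⊙ h_{t-1})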
@@ -153,22 +199,35 @@ int GruCPUKernel::Run() {
MS_ASSERT(output_ptr);
auto output_hidden_state = out_tensors_[1];
memcpy(output_hidden_state->data_c(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float));
int check_seq_len = gru_parm_->seq_len_;
int check_seq_len = gru_param_->seq_len_;
if (in_tensors_.size() == 6) {
auto seq_len = reinterpret_cast<int *>(in_tensors_.at(5)->data_c());
if (!std::equal(seq_len + 1, seq_len + gru_parm_->batch_, seq_len)) {
if (!std::equal(seq_len + 1, seq_len + gru_param_->batch_, seq_len)) {
MS_LOG(ERROR) << "different batch seq_len is currently not supported";
return RET_ERROR;
}
check_seq_len = MSMIN(check_seq_len, MSMAX(0, seq_len[0]));
}
auto ret = MallocRunBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "GruCPUKernel MallocRunBuffer error.";
return RET_ERROR;
}
if (is_vec_) {
weight_g_ptr_ = reinterpret_cast<float *>(in_tensors_[1]->data_c());
weight_r_ptr_ = reinterpret_cast<float *>(in_tensors_[2]->data_c());
bias_ptr_ = reinterpret_cast<float *>(in_tensors_[3]->data_c());
}
MS_ASSERT(weight_g_ptr_ != nullptr);
MS_ASSERT(weight_r_ptr_ != nullptr);
MS_ASSERT(bias_ptr_ != nullptr);
MS_ASSERT(gate_buffer_ != nullptr);
Gru(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_,
reinterpret_cast<float *>(output_hidden_state->data_c()), gate_buffer_, check_seq_len, gru_parm_);
reinterpret_cast<float *>(output_hidden_state->data_c()), gate_buffer_, matmul_buffer_, check_seq_len,
gru_param_);
FreeRunBuffer();
return RET_OK;
}


@@ -26,7 +26,7 @@ class GruCPUKernel : public LiteKernel {
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
gru_parm_ = reinterpret_cast<GruParameter *>(op_parameter_);
gru_param_ = reinterpret_cast<GruParameter *>(op_parameter_);
}
~GruCPUKernel() override { FreeTmpBuffer(); }
@@ -37,15 +37,20 @@ class GruCPUKernel : public LiteKernel {
private:
void FreeTmpBuffer();
void FreeRunBuffer();
int InitParam();
int InitBuffer();
int MallocRunBuffer();
int InitWeightBias();
float *gate_buffer_ = nullptr;
float *weight_g_ptr_ = nullptr;
float *weight_r_ptr_ = nullptr;
float *bias_ptr_ = nullptr;
GruParameter *gru_parm_ = nullptr;
float *matmul_buffer_[2];
int row_tile_ = 0;
int col_tile_ = 0;
bool is_vec_ = false;
GruParameter *gru_param_ = nullptr;
};
} // namespace mindspore::kernel


@@ -57,24 +57,8 @@ void LstmCPUKernel::FreeRunBuffer() {
}
}
int InitRightMatrix(float *dst, const float *src, int batch, int deep, int col, int col_align, bool is_vec) {
for (int i = 0; i < batch; i++) {
auto src_batch = src + i * col * deep;
auto dst_batch = dst + i * col_align * deep;
#ifdef ENABLE_AVX
RowMajor2Col16Major(src_batch, dst_batch, col, deep);
#elif defined(ENABLE_ARM32)
RowMajor2Col4Major(src_batch, dst_batch, col, deep);
#else
RowMajor2Col8Major(src_batch, dst_batch, col, deep);
#endif
}
return RET_OK;
}
int LstmCPUKernel::InitWeightBias() {
auto weight_batch = lstm_param_->bidirectional_ ? 8 : 4;
if (!is_vec_) {
// malloc and init input * weight right matrix buffer
auto weight_i = in_tensors_.at(1);
@@ -86,8 +70,8 @@ int LstmCPUKernel::InitWeightBias() {
return RET_ERROR;
}
auto weight_i_data = reinterpret_cast<float *>(weight_i->data_c());
InitRightMatrix(weight_i_ptr_, weight_i_data, weight_batch, lstm_param_->input_size_, lstm_param_->hidden_size_,
lstm_param_->col_align_, is_vec_);
PackLstmWeight(weight_i_ptr_, weight_i_data, weight_batch, lstm_param_->input_size_, lstm_param_->hidden_size_,
lstm_param_->col_align_);
// malloc and init state * weight right matrix buffer
auto weight_h = in_tensors_.at(2);
@@ -99,8 +83,8 @@ int LstmCPUKernel::InitWeightBias() {
return RET_ERROR;
}
auto weight_h_data = reinterpret_cast<float *>(weight_h->data_c());
InitRightMatrix(weight_h_ptr_, weight_h_data, weight_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_,
lstm_param_->col_align_, is_vec_);
PackLstmWeight(weight_h_ptr_, weight_h_data, weight_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_,
lstm_param_->col_align_);
// init bias
int bias_batch = lstm_param_->bidirectional_ ? 16 : 8;
@@ -235,7 +219,7 @@ int LstmCPUKernel::Run() {
auto ret = MallocRunBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "LstmCPUKernel InitRunBuffer error.";
MS_LOG(ERROR) << "LstmCPUKernel MallocRunBuffer error.";
return RET_ERROR;
}
@@ -244,7 +228,6 @@
weight_h_ptr_ = reinterpret_cast<float *>(in_tensors_[2]->data_c());
bias_ptr_ = reinterpret_cast<float *>(in_tensors_[3]->data_c());
}
MS_ASSERT(weight_h_ptr_);
MS_ASSERT(weight_i_ptr_);
MS_ASSERT(bias_ptr_);