forked from mindspore-Ecosystem/mindspore
!11989 [MSLITE][Develop] optimize arm cpu fp32 op lstm: use matmul calculate function
From: @yangruoqi713 Reviewed-by: @zhang_xue_tong Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
4d866ea64a
|
@ -19,20 +19,7 @@
|
|||
#include <float.h>
|
||||
#include "nnacl/fp32/activation_fp32.h"
|
||||
#include "nnacl/fp32/arithmetic_fp32.h"
|
||||
#include "nnacl/fp32/mul_fp32.h"
|
||||
|
||||
void InitGate(float *gate_buffer, const float *bias, const LstmParameter *lstm_parm) {
|
||||
int gate_offest = 0;
|
||||
for (int l = 0; l < 4; l++) {
|
||||
int batch_offest = gate_offest;
|
||||
int bias_offest = l * lstm_parm->hidden_size_;
|
||||
for (int b = 0; b < lstm_parm->batch_; b++) {
|
||||
memcpy(gate_buffer + batch_offest, bias + bias_offest, lstm_parm->hidden_size_ * sizeof(float));
|
||||
batch_offest += lstm_parm->hidden_size_;
|
||||
}
|
||||
gate_offest += lstm_parm->batch_ * lstm_parm->hidden_size_;
|
||||
}
|
||||
}
|
||||
#include "nnacl/fp32/matmul_fp32.h"
|
||||
|
||||
// input: [row, inner_size]; weight: [col, inner_size]; output: [row, col]
|
||||
void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size) {
|
||||
|
@ -134,106 +121,131 @@ void UpdataOutput(const float *cell_state, const float *output_gate, float *hidd
|
|||
}
|
||||
}
|
||||
|
||||
void LstmStepUnit(float *output, const float *input, const float *input_input_weight, const float *input_forget_weight,
|
||||
const float *input_cell_weight, const float *input_output_weight, const float *state_input_weight,
|
||||
const float *state_forget_weight, const float *state_cell_weight, const float *state_output_weight,
|
||||
const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
|
||||
const LstmParameter *lstm_parm) {
|
||||
InitGate(gate_buffer, bias, lstm_parm);
|
||||
// Computes one gate's pre-activation into c from the left operand a, the right operand b and bias.
//   row/deep/col  dimensions handed to the matmul routines.
//   is_vec        selects the routine:
//                   true  -> c is first filled with the col bias values, then MatMulAcc(c, a, b, row, col, deep)
//                            is applied on top of it;
//                   false -> MatMulOpt is called with the bias, no activation (ActType_No), col passed for both
//                            trailing size arguments, and OutType_Nhwc.
void LstmMatmul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) {
  if (!is_vec) {
    MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc);
    return;
  }
  // Seed the output with the bias before MatMulAcc works on it.
  memcpy(c, bias, col * sizeof(float));
  MatMulAcc(c, a, b, row, col, deep);
}
|
||||
|
||||
// Repacks src (row x deep) into dst with the RowMajor2Col{N}Major helper chosen at build time:
//   ENABLE_AVX               -> RowMajor2Col6Major
//   ENABLE_SSE (without AVX) -> RowMajor2Col4Major
//   neither                  -> RowMajor2Col12Major
void PackLstmInput(float *dst, const float *src, int row, int deep) {
#if !defined(ENABLE_AVX) && !defined(ENABLE_SSE)
  RowMajor2Col12Major(src, dst, row, deep);
#elif defined(ENABLE_AVX)
  RowMajor2Col6Major(src, dst, row, deep);
#else
  RowMajor2Col4Major(src, dst, row, deep);
#endif
}
|
||||
|
||||
// Computes the four gate pre-activations (input * weight + bias) of one LSTM step into gate_buffer.
//
// Each of weight, bias and gate_buffer holds four consecutive slots; callers in this file read them back as
//   slot 0: input gate, slot 1: output gate, slot 2: forget gate, slot 3: cell gate.
//
//   gate_buffer  output, slot k starts at gate_buffer + k * row * col.
//   input        left operand, forwarded unchanged to LstmMatmul.
//   weight       right operand, slot k starts at weight + k * deep * col_align.
//   bias         slot k starts at bias + k * col_align.
//   row/deep/col matmul dimensions, forwarded to LstmMatmul.
//   col_align    column count each weight/bias slot is padded to.
//   is_vec       forwarded to LstmMatmul to select the matmul routine.
//
// The weight slot stride uses col_align rather than col: the kernel packs every weight slot at
// dst + i * col_align * deep, so a stride of deep * col lands inside the previous slot whenever col is not a
// multiple of the column tile. In the is_vec path col_align equals col, so that path is unaffected.
void UpdateGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep,
                int col, int col_align, bool is_vec) {
  const int weight_stride = deep * col_align;
  const int gate_stride = row * col;
  // The four products write disjoint regions of gate_buffer, so the slot order is irrelevant.
  for (int slot = 0; slot < 4; slot++) {
    LstmMatmul(gate_buffer + slot * gate_stride, input, weight + slot * weight_stride, bias + slot * col_align, row,
               deep, col, is_vec);
  }
}
|
||||
|
||||
// Runs one LSTM time step for a single direction.
//
//   output         receives a copy of hidden_state after the step (batch_ * hidden_size_ floats).
//   input          input of this time step.
//   input_weight   gate weights applied to the input, forwarded to UpdateGate.
//   state_weight   gate weights applied to the hidden state, forwarded to UpdateGate.
//   bias           input-gate biases; the state-gate biases start 4 * col_align_ floats further on.
//   hidden_state   updated in place by UpdataOutput (and overwritten from state_buffer when smooth_ is non-zero).
//   cell_state     updated in place by UpdataState (and overwritten from state_buffer when smooth_ is non-zero).
//   gate_buffer    scratch: the first 4 * batch_ * hidden_size_ floats hold the input contribution and the final
//                  gates, the following 4 * batch_ * hidden_size_ floats hold the state contribution.
//   state_buffer   forwarded to UpdataState / UpdataOutput; only read here when smooth_ is non-zero.
//   matmul_buffer  [0] packed input, [1] packed hidden state; only touched when batch_ != 1.
//   lstm_param     shape and smoothing parameters.
void LstmStepUnit(float *output, const float *input, const float *input_weight, const float *state_weight,
                  const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
                  float *matmul_buffer[2], const LstmParameter *lstm_param) {
  const bool is_vec = lstm_param->batch_ == 1;
  const int state_size = lstm_param->batch_ * lstm_param->hidden_size_;

  // input * weight
  const float *gate_input = input;
  if (!is_vec) {
    // pack input for matmul
    PackLstmInput(matmul_buffer[0], input, lstm_param->batch_, lstm_param->input_size_);
    gate_input = matmul_buffer[0];
  }
  UpdateGate(gate_buffer, gate_input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_,
             lstm_param->hidden_size_, lstm_param->col_align_, is_vec);

  // state * weight
  float *state_gate = gate_buffer + state_size * 4;
  const float *state_bias = bias + lstm_param->col_align_ * 4;
  const float *gate_state = hidden_state;
  if (!is_vec) {
    // pack state for matmul
    PackLstmInput(matmul_buffer[1], hidden_state, lstm_param->batch_, lstm_param->hidden_size_);
    gate_state = matmul_buffer[1];
  }
  UpdateGate(state_gate, gate_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_,
             lstm_param->hidden_size_, lstm_param->col_align_, is_vec);

  // Combine both contributions in the first half of gate_buffer.
  ElementAdd(gate_buffer, state_gate, gate_buffer, 4 * state_size);

  float *input_gate = gate_buffer;
  float *forget_gate = gate_buffer + state_size * 2;
  float *cell_gate = gate_buffer + state_size * 3;
  float *output_gate = gate_buffer + state_size;

  // update input_gate
  Sigmoid(input_gate, state_size, input_gate);

  // update forget_gate
  Sigmoid(forget_gate, state_size, forget_gate);

  // update cell_gate
  Tanh(cell_gate, state_size, cell_gate);

  // update cell state
  UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_param->batch_,
              lstm_param->hidden_size_, lstm_param->smooth_);

  // update output_gate
  Sigmoid(output_gate, state_size, output_gate);

  // update output
  UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_param->batch_, lstm_param->hidden_size_,
               lstm_param->smooth_);
  memcpy(output, hidden_state, state_size * sizeof(float));

  // With a non-zero smooth_ the states are taken from state_buffer: cell state first, hidden state after it.
  if (!(lstm_param->smooth_ >= -FLT_EPSILON && lstm_param->smooth_ <= FLT_EPSILON)) {
    memcpy(cell_state, state_buffer, state_size * sizeof(float));
    memcpy(hidden_state, state_buffer + state_size, state_size * sizeof(float));
  }
}
|
||||
|
||||
void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
|
||||
float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
|
||||
const LstmParameter *lstm_parm) {
|
||||
float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, float *matmul_buffer[2],
|
||||
const LstmParameter *lstm_param) {
|
||||
// forward
|
||||
const float *input_input_weight = weight_i;
|
||||
const float *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2;
|
||||
const float *input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 3;
|
||||
const float *input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 1;
|
||||
|
||||
const float *state_input_weight = weight_h;
|
||||
const float *state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 2;
|
||||
const float *state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 3;
|
||||
const float *state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 1;
|
||||
|
||||
for (int t = 0; t < lstm_parm->seq_len_; t++) {
|
||||
const float *input_ptr = input + t * lstm_parm->input_step_;
|
||||
float *output_ptr = output + t * lstm_parm->output_step_;
|
||||
LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, input_output_weight,
|
||||
state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, bias, hidden_state,
|
||||
cell_state, gate_buffer, state_buffer, lstm_parm);
|
||||
for (int t = 0; t < lstm_param->seq_len_; t++) {
|
||||
const float *input_ptr = input + t * lstm_param->input_step_;
|
||||
float *output_ptr = output + t * lstm_param->output_step_;
|
||||
LstmStepUnit(output_ptr, input_ptr, weight_i, weight_h, bias, hidden_state, cell_state, gate_buffer, state_buffer,
|
||||
matmul_buffer, lstm_param);
|
||||
}
|
||||
|
||||
// backward
|
||||
if (lstm_parm->bidirectional_) {
|
||||
input_input_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 4;
|
||||
input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 6;
|
||||
input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 7;
|
||||
input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 5;
|
||||
|
||||
state_input_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 4;
|
||||
state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 6;
|
||||
state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 7;
|
||||
state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 5;
|
||||
|
||||
float *backward_output = output + lstm_parm->batch_ * lstm_parm->hidden_size_;
|
||||
const float *backward_bias = bias + 4 * lstm_parm->hidden_size_;
|
||||
float *backward_cell_state = cell_state + lstm_parm->batch_ * lstm_parm->hidden_size_;
|
||||
float *backward_hidden_state = hidden_state + lstm_parm->batch_ * lstm_parm->hidden_size_;
|
||||
for (int t = lstm_parm->seq_len_ - 1; t >= 0; t--) {
|
||||
const float *input_ptr = input + t * lstm_parm->input_step_;
|
||||
float *output_ptr = backward_output + t * lstm_parm->output_step_;
|
||||
LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight,
|
||||
input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, state_output_weight,
|
||||
backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, state_buffer, lstm_parm);
|
||||
if (lstm_param->bidirectional_) {
|
||||
const float *backward_weight_i = weight_i + 4 * lstm_param->col_align_ * lstm_param->input_size_;
|
||||
const float *backward_weight_h = weight_h + 4 * lstm_param->col_align_ * lstm_param->hidden_size_;
|
||||
const float *backward_bias = bias + 8 * lstm_param->hidden_size_;
|
||||
float *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_;
|
||||
float *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_;
|
||||
float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_;
|
||||
for (int t = lstm_param->seq_len_ - 1; t >= 0; t--) {
|
||||
const float *input_ptr = input + t * lstm_param->input_step_;
|
||||
float *output_ptr = backward_output + t * lstm_param->output_step_;
|
||||
LstmStepUnit(output_ptr, input_ptr, backward_weight_i, backward_weight_h, backward_bias, backward_hidden_state,
|
||||
backward_cell_state, gate_buffer, state_buffer, matmul_buffer, lstm_param);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ void ElementMulAcc(const float *input0, const float *input1, float *output, int
|
|||
int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size);
|
||||
|
||||
void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
|
||||
float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer,
|
||||
float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, float *matmul_buffer[2],
|
||||
const LstmParameter *lstm_parm);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -34,6 +34,8 @@ typedef struct LstmParameter {
|
|||
// output_hidden = old_hidden * smooth + new_hidden * (1 - smooth)
|
||||
// output_cell = old_cell * smooth + new_cell * (1 - smooth)
|
||||
float smooth_;
|
||||
int col_align_;
|
||||
int row_align_;
|
||||
} LstmParameter;
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_LSTM_PARAMETER_H_
|
||||
|
|
|
@ -84,9 +84,8 @@ std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, const size_t
|
|||
|
||||
// Returns true when op_type is one of the operator types in the packed-op list below.
// NOTE(review): PrimitiveType_Lstm is intentionally absent. The Lstm CPU kernel reads the weight and bias tensor
// data directly in Run() when batch is 1, so that data has to stay available after initialisation — confirm that
// this is what membership in this list controls.
bool IsPackedOp(schema::PrimitiveType op_type) {
  static std::vector<schema::PrimitiveType> packed_ops = {
    schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DepthwiseConv2D,
    schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_MatMul};
  return IsContain(packed_ops, op_type);
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -20,35 +20,104 @@
|
|||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "nnacl/fp32/matmul_fp32.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_MEMORY_FAILED;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_Lstm;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
void LstmCPUKernel::FreeTmpBuffer() {
|
||||
if (gate_buffer_ != nullptr) {
|
||||
free(gate_buffer_);
|
||||
gate_buffer_ = nullptr;
|
||||
if (!is_vec_) {
|
||||
if (weight_i_ptr_ != nullptr) {
|
||||
free(weight_i_ptr_);
|
||||
weight_i_ptr_ = nullptr;
|
||||
}
|
||||
if (weight_h_ptr_ != nullptr) {
|
||||
free(weight_h_ptr_);
|
||||
weight_h_ptr_ = nullptr;
|
||||
}
|
||||
if (bias_ptr_ != nullptr) {
|
||||
free(bias_ptr_);
|
||||
bias_ptr_ = nullptr;
|
||||
}
|
||||
}
|
||||
if (state_buffer_ != nullptr) {
|
||||
free(state_buffer_);
|
||||
state_buffer_ = nullptr;
|
||||
}
|
||||
|
||||
void LstmCPUKernel::FreeRunBuffer() {
|
||||
context_->allocator->Free(gate_buffer_);
|
||||
context_->allocator->Free(state_buffer_);
|
||||
if (!is_vec_) {
|
||||
for (int i = 0; i < 2; i++) {
|
||||
context_->allocator->Free(matmul_buffer_[i]);
|
||||
}
|
||||
}
|
||||
if (weight_i_ptr_ != nullptr) {
|
||||
free(weight_i_ptr_);
|
||||
weight_i_ptr_ = nullptr;
|
||||
}
|
||||
|
||||
int InitRightMatrix(float *dst, const float *src, int batch, int deep, int col, int col_align, bool is_vec) {
|
||||
for (int i = 0; i < batch; i++) {
|
||||
auto src_batch = src + i * col * deep;
|
||||
auto dst_batch = dst + i * col_align * deep;
|
||||
#ifdef ENABLE_AVX
|
||||
RowMajor2Col16Major(src_batch, dst_batch, col, deep);
|
||||
#elif defined(ENABLE_ARM32)
|
||||
RowMajor2Col4Major(src_batch, dst_batch, col, deep);
|
||||
#else
|
||||
RowMajor2Col8Major(src_batch, dst_batch, col, deep);
|
||||
#endif
|
||||
}
|
||||
if (weight_h_ptr_ != nullptr) {
|
||||
free(weight_h_ptr_);
|
||||
weight_h_ptr_ = nullptr;
|
||||
}
|
||||
if (bias_ptr_ != nullptr) {
|
||||
free(bias_ptr_);
|
||||
bias_ptr_ = nullptr;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmCPUKernel::InitWeightBias() {
|
||||
auto weight_batch = lstm_param_->bidirectional_ ? 8 : 4;
|
||||
|
||||
if (!is_vec_) {
|
||||
// malloc and init input * weight right matrix buffer
|
||||
auto weight_i = in_tensors_.at(1);
|
||||
MS_ASSERT(weight_i != nullptr);
|
||||
weight_i_ptr_ = reinterpret_cast<float *>(
|
||||
malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->input_size_ * sizeof(float)));
|
||||
if (weight_i_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto weight_i_data = reinterpret_cast<float *>(weight_i->data_c());
|
||||
InitRightMatrix(weight_i_ptr_, weight_i_data, weight_batch, lstm_param_->input_size_, lstm_param_->hidden_size_,
|
||||
lstm_param_->col_align_, is_vec_);
|
||||
|
||||
// malloc and init state * weight right matrix buffer
|
||||
auto weight_h = in_tensors_.at(2);
|
||||
MS_ASSERT(weight_h != nullptr);
|
||||
weight_h_ptr_ = reinterpret_cast<float *>(
|
||||
malloc(weight_batch * lstm_param_->col_align_ * lstm_param_->hidden_size_ * sizeof(float)));
|
||||
if (weight_h_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto weight_h_data = reinterpret_cast<float *>(weight_h->data_c());
|
||||
InitRightMatrix(weight_h_ptr_, weight_h_data, weight_batch, lstm_param_->hidden_size_, lstm_param_->hidden_size_,
|
||||
lstm_param_->col_align_, is_vec_);
|
||||
|
||||
// init bias
|
||||
int bias_batch = lstm_param_->bidirectional_ ? 16 : 8;
|
||||
bias_ptr_ = reinterpret_cast<float *>(malloc(bias_batch * lstm_param_->col_align_ * sizeof(float)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc bias_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(bias_ptr_, 0, bias_batch * lstm_param_->col_align_ * sizeof(float));
|
||||
auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
|
||||
for (int i = 0; i < bias_batch; i++) {
|
||||
auto src_batch = bias_data + i * lstm_param_->hidden_size_;
|
||||
auto dst_batch = bias_ptr_ + i * lstm_param_->col_align_;
|
||||
memcpy(dst_batch, src_batch, lstm_param_->hidden_size_ * sizeof(float));
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmCPUKernel::InitParam() {
|
||||
|
@ -67,80 +136,27 @@ int LstmCPUKernel::InitParam() {
|
|||
lstm_param_->input_step_ = lstm_param_->batch_ * lstm_param_->input_size_;
|
||||
lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_
|
||||
: lstm_param_->batch_ * lstm_param_->hidden_size_;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmCPUKernel::InitBuffer() {
|
||||
gate_buffer_ = reinterpret_cast<float *>(malloc(4 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float)));
|
||||
if (gate_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!(lstm_param_->smooth_ >= -FLT_EPSILON && lstm_param_->smooth_ <= FLT_EPSILON)) {
|
||||
int buffer_size = 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float);
|
||||
state_buffer_ = reinterpret_cast<float *>(malloc(buffer_size));
|
||||
if (state_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmCPUKernel::InitWeightBias() {
|
||||
// copy weight_i and weight_h
|
||||
auto weight_i = in_tensors_.at(1);
|
||||
MS_ASSERT(weight_i != nullptr);
|
||||
weight_i_ptr_ = reinterpret_cast<float *>(malloc(weight_i->ElementsNum() * sizeof(float)));
|
||||
if (weight_i_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memcpy(weight_i_ptr_, weight_i->data_c(), weight_i->ElementsNum() * sizeof(float));
|
||||
|
||||
auto weight_h = in_tensors_.at(2);
|
||||
MS_ASSERT(weight_h != nullptr);
|
||||
weight_h_ptr_ = reinterpret_cast<float *>(malloc(weight_h->ElementsNum() * sizeof(float)));
|
||||
if (weight_h_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memcpy(weight_h_ptr_, weight_h->data_c(), weight_h->ElementsNum() * sizeof(float));
|
||||
|
||||
std::vector<int> w_shape = weight_i->shape();
|
||||
auto hidden_size = w_shape.at(1) / 4;
|
||||
// init bias
|
||||
int bias_num = lstm_param_->bidirectional_ ? 2 * 4 * hidden_size : 4 * hidden_size;
|
||||
bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc bias_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
|
||||
const int state_bias_offset = 4 * hidden_size;
|
||||
for (int i = 0; i < state_bias_offset; i++) {
|
||||
bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset];
|
||||
}
|
||||
if (lstm_param_->bidirectional_) {
|
||||
bias_data += 4 * hidden_size * 2;
|
||||
auto backward_bias = bias_ptr_ + 4 * hidden_size;
|
||||
for (int i = 0; i < state_bias_offset; i++) {
|
||||
backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset];
|
||||
}
|
||||
}
|
||||
#ifdef ENABLE_AVX
|
||||
row_tile_ = C6NUM;
|
||||
col_tile_ = C16NUM;
|
||||
#elif defined(ENABLE_ARM32)
|
||||
row_tile_ = C12NUM;
|
||||
col_tile_ = C4NUM;
|
||||
#elif defined(ENABLE_SSE)
|
||||
row_tile_ = C4NUM;
|
||||
col_tile_ = C8NUM;
|
||||
#else
|
||||
row_tile_ = C12NUM;
|
||||
col_tile_ = C8NUM;
|
||||
#endif
|
||||
is_vec_ = lstm_param_->batch_ == 1;
|
||||
lstm_param_->row_align_ = is_vec_ ? 1 : UP_ROUND(lstm_param_->batch_, row_tile_);
|
||||
lstm_param_->col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, col_tile_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmCPUKernel::Init() {
|
||||
FreeTmpBuffer();
|
||||
auto ret = InitWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error.";
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -154,15 +170,50 @@ int LstmCPUKernel::ReSize() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ret = InitBuffer();
|
||||
FreeTmpBuffer();
|
||||
ret = InitWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel InitBuffer error.";
|
||||
MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error.";
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmCPUKernel::MallocRunBuffer() {
|
||||
if (!is_vec_) {
|
||||
matmul_buffer_[0] = reinterpret_cast<float *>(
|
||||
context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->input_size_ * sizeof(float)));
|
||||
if (matmul_buffer_[0] == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc input * weight left matirx error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
matmul_buffer_[1] = reinterpret_cast<float *>(
|
||||
context_->allocator->Malloc(4 * lstm_param_->row_align_ * lstm_param_->hidden_size_ * sizeof(float)));
|
||||
if (matmul_buffer_[1] == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc state * weight left matirx error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
gate_buffer_ = reinterpret_cast<float *>(
|
||||
context_->allocator->Malloc(8 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float)));
|
||||
if (gate_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!(lstm_param_->smooth_ >= -FLT_EPSILON && lstm_param_->smooth_ <= FLT_EPSILON)) {
|
||||
int buffer_size = 2 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float);
|
||||
state_buffer_ = reinterpret_cast<float *>(context_->allocator->Malloc(buffer_size));
|
||||
if (state_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc state_buffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmCPUKernel::Run() {
|
||||
auto input = in_tensors_.at(kInputIndex);
|
||||
MS_ASSERT(input != nullptr);
|
||||
|
@ -182,13 +233,26 @@ int LstmCPUKernel::Run() {
|
|||
auto output_cell_state = out_tensors_[2];
|
||||
memcpy(output_cell_state->data_c(), cell_state->data_c(), cell_state->ElementsNum() * sizeof(float));
|
||||
|
||||
auto ret = MallocRunBuffer();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel InitRunBuffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (is_vec_) {
|
||||
weight_i_ptr_ = reinterpret_cast<float *>(in_tensors_[1]->data_c());
|
||||
weight_h_ptr_ = reinterpret_cast<float *>(in_tensors_[2]->data_c());
|
||||
bias_ptr_ = reinterpret_cast<float *>(in_tensors_[3]->data_c());
|
||||
}
|
||||
|
||||
MS_ASSERT(weight_h_ptr_);
|
||||
MS_ASSERT(weight_i_ptr_);
|
||||
MS_ASSERT(bias_ptr_);
|
||||
MS_ASSERT(gate_buffer_);
|
||||
Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_,
|
||||
reinterpret_cast<float *>(output_hidden_state->data_c()), reinterpret_cast<float *>(output_cell_state->data_c()),
|
||||
gate_buffer_, state_buffer_, lstm_param_);
|
||||
gate_buffer_, state_buffer_, matmul_buffer_, lstm_param_);
|
||||
FreeRunBuffer();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -39,8 +39,9 @@ class LstmCPUKernel : public LiteKernel {
|
|||
|
||||
private:
|
||||
void FreeTmpBuffer();
|
||||
void FreeRunBuffer();
|
||||
int InitParam();
|
||||
int InitBuffer();
|
||||
int MallocRunBuffer();
|
||||
int InitWeightBias();
|
||||
|
||||
float *gate_buffer_ = nullptr;
|
||||
|
@ -48,6 +49,10 @@ class LstmCPUKernel : public LiteKernel {
|
|||
float *weight_i_ptr_ = nullptr;
|
||||
float *weight_h_ptr_ = nullptr;
|
||||
float *bias_ptr_ = nullptr;
|
||||
float *matmul_buffer_[2];
|
||||
int row_tile_ = 0;
|
||||
int col_tile_ = 0;
|
||||
bool is_vec_ = false;
|
||||
LstmParameter *lstm_param_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
|
Loading…
Reference in New Issue