!14312 add lstm fp32 coder
From: @yangjie159 Reviewed-by: @wangchengyuan,@hangangqiang Signed-off-by: @wangchengyuan
commit 4812f7c6eb

@@ -84,6 +84,7 @@ set(CODER_OPCODERS_SRC
        ${MICRO_DIR}/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc
        ${MICRO_DIR}/coder/opcoders/nnacl/fp32/full_connection_fp32_coder.cc
        ${MICRO_DIR}/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc
        ${MICRO_DIR}/coder/opcoders/nnacl/fp32/lstm_fp32_coder.cc
        ${MICRO_DIR}/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc
        ${MICRO_DIR}/coder/opcoders/nnacl/fp32/matmul_fp32_coder.cc
        ${MICRO_DIR}/coder/opcoders/nnacl/fp32/pad_fp32_coder.cc

@@ -0,0 +1,194 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "coder/opcoders/nnacl/fp32/lstm_fp32_coder.h"
#include <float.h>
#include <string>
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
#include "coder/log.h"
#include "coder/opcoders/file_collector.h"

using mindspore::schema::PrimitiveType_LSTM;

namespace mindspore::lite::micro::nnacl {
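// Positions of the initial hidden-state and cell-state tensors in the LSTM node's input list.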
constexpr int kFifthIndex = 5;
constexpr int kSixthIndex = 6;

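// Packs the input-to-gate weight matrix and bias into col-aligned buffers; the packing itself is
// emitted as generated init code that runs once on the target before inference.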
int LstmFP32Coder::InitInputWeightBias(CoderContext *const context) {
  NNaclFp32Serializer init_code;
  Tensor *weight_i = input_tensors_.at(kWeightIndex);
  MS_CHECK_PTR(weight_i);
  size_t weight_i_size = weight_batch_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float);
  weight_i_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight));
  MS_CHECK_PTR(weight_i_ptr_);

  init_code.CodeMallocExpression(weight_i_ptr_, weight_i_size);
  init_code.CodeFunction("memset", weight_i_ptr_, 0, weight_i_size);
  init_code.CodeFunction("PackLstmWeight", weight_i_ptr_, weight_i, weight_batch_, lstm_param_->input_size_,
                         lstm_param_->hidden_size_, lstm_param_->input_col_align_);

  Tensor *bias_i = input_tensors_.at(kInputSize2);
  MS_CHECK_PTR(bias_i);
  input_bias_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight));
  size_t bias_i_size = weight_batch_ * lstm_param_->input_col_align_ * sizeof(float);
  init_code.CodeMallocExpression(input_bias_, bias_i_size);
  init_code.CodeFunction("memset", input_bias_, 0, bias_i_size);
  init_code.CodeFunction("PackLstmBias", input_bias_, bias_i, weight_batch_, lstm_param_->hidden_size_,
                         lstm_param_->input_col_align_, lstm_param_->bidirectional_);
  context->AppendInitCode(init_code.str());
  return RET_OK;
}

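// Packs the recurrent (state) weight matrix and bias. When batch == 1 (is_vec_), the state weight is
// consumed in its original layout, so it is copied at code-generation time (kOfflinePackWeight) instead
// of being packed by generated init code. The state bias starts 4 * hidden_size floats into the
// combined bias tensor.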
int LstmFP32Coder::InitStateWeightBias(CoderContext *const context) {
  NNaclFp32Serializer init_code;
  Tensor *weight_h = input_tensors().at(kInputSize1);
  MS_CHECK_PTR(weight_h);
  if (!is_vec_) {
    size_t weight_h_size = weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->hidden_size_ * sizeof(float);
    weight_h_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight));
    MS_CHECK_PTR(weight_h_ptr_);
    init_code.CodeMallocExpression(weight_h_ptr_, weight_h_size);
    init_code.CodeFunction("memset", weight_h_ptr_, 0, weight_h_size);
    init_code.CodeFunction("PackLstmWeight", weight_h_ptr_, weight_h, weight_batch_, lstm_param_->hidden_size_,
                           lstm_param_->hidden_size_, lstm_param_->state_col_align_);
  } else {
    size_t weight_h_size = weight_h->Size();
    weight_h_ptr_ =
      reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, weight_h->Size(), kOfflinePackWeight));
    MS_CHECK_PTR(weight_h_ptr_);
    MS_CHECK_RET_CODE(memcpy_s(weight_h_ptr_, weight_h_size, weight_h->data_c(), weight_h_size),
                      "copy weight h data failed");
  }

  state_bias_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight));
  size_t state_bias_size = weight_batch_ * lstm_param_->state_col_align_ * sizeof(float);
  init_code.CodeMallocExpression(state_bias_, state_bias_size);
  init_code.CodeFunction("memset", state_bias_, 0, state_bias_size);

  Tensor *bias_i = input_tensors_.at(kInputSize2);
  MS_CHECK_PTR(bias_i);
  std::string state_bias_addr =
    allocator_->GetRuntimeAddr(bias_i) + "+" + std::to_string(4 * lstm_param_->hidden_size_);
  init_code.CodeFunction("PackLstmBias", state_bias_, state_bias_addr, weight_batch_, lstm_param_->hidden_size_,
                         lstm_param_->state_col_align_, lstm_param_->bidirectional_);
  context->AppendInitCode(init_code.str());
  return RET_OK;
}

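// Derives the LSTM dimensions from the tensor shapes: the input is {seq_len, batch, input_size} and
// dim 1 of the gate weight equals 4 * hidden_size. Row/col tiles follow the MatMul tiling of the
// target (col tile 4 on ARM32, 8 elsewhere), and the alignment fields are rounded up accordingly.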
int LstmFP32Coder::InitParam() {
  std::vector<int> in_shape = input_tensor_->shape();
  lstm_param_->seq_len_ = in_shape.at(0);
  lstm_param_->batch_ = in_shape.at(1);
  lstm_param_->input_size_ = in_shape.at(2);

  auto weight_i = input_tensors_.at(1);
  MS_ASSERT(weight_i != nullptr);
  std::vector<int> w_shape = weight_i->shape();
  lstm_param_->hidden_size_ = w_shape.at(1) / 4;
  lstm_param_->output_step_ = lstm_param_->bidirectional_ ? 2 * lstm_param_->batch_ * lstm_param_->hidden_size_
                                                          : lstm_param_->batch_ * lstm_param_->hidden_size_;
  weight_batch_ = lstm_param_->bidirectional_ ? 8 : 4;

  if (target_ == kARM32A || target_ == kARM32M) {
    row_tile_ = C12NUM;
    col_tile_ = C4NUM;
  } else {
    row_tile_ = C12NUM;
    col_tile_ = C8NUM;
  }
  lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, row_tile_);
  lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, col_tile_);

  is_vec_ = lstm_param_->batch_ == 1;
  lstm_param_->state_row_align_ = is_vec_ ? 1 : UP_ROUND(lstm_param_->batch_, row_tile_);
  lstm_param_->state_col_align_ = is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, col_tile_);
  return RET_OK;
}

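// Worked example (assumed shapes): seq_len = 16, batch = 1, hidden_size = 10 on a non-ARM32 target
// gives row_tile_ = 12, col_tile_ = 8, so input_row_align_ = UP_ROUND(16 * 1, 12) = 24 and
// input_col_align_ = UP_ROUND(10, 8) = 16; batch == 1 sets is_vec_, so state_col_align_ stays 10.
//
// MallocRunBuffer reserves the workspaces the generated Lstm() call consumes:
//   buffer_[0] packed input, buffer_[1] input-gate results for the whole sequence, buffer_[2] packed
//   state input, buffer_[3] per-step state-gate results, buffer_[4]/buffer_[5] zoneout copies of the
//   cell and hidden state (allocated only when the corresponding zoneout rate is non-zero).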
int LstmFP32Coder::MallocRunBuffer(CoderContext *const context) {
  buffer_[0] = reinterpret_cast<float *>(allocator_->Malloc(
    kNumberTypeFloat32, lstm_param_->input_row_align_ * lstm_param_->input_size_ * sizeof(float), kWorkspace));
  MS_CHECK_PTR(buffer_[0]);
  buffer_[1] = reinterpret_cast<float *>(allocator_->Malloc(
    kNumberTypeFloat32, 4 * lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float),
    kWorkspace));
  MS_CHECK_PTR(buffer_[1]);
  if (!is_vec_) {
    buffer_[2] = reinterpret_cast<float *>(allocator_->Malloc(
      kNumberTypeFloat32, lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * sizeof(float), kWorkspace));
    MS_CHECK_PTR(buffer_[2]);
  }
  buffer_[3] = reinterpret_cast<float *>(allocator_->Malloc(
    kNumberTypeFloat32, 4 * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float), kWorkspace));
  MS_CHECK_PTR(buffer_[3]);
  if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) {
    buffer_[4] = reinterpret_cast<float *>(allocator_->Malloc(
      kNumberTypeFloat32, lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float), kWorkspace));
    MS_CHECK_PTR(buffer_[4]);
  }
  if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) {
    buffer_[5] = reinterpret_cast<float *>(allocator_->Malloc(
      kNumberTypeFloat32, lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float), kWorkspace));
    MS_CHECK_PTR(buffer_[5]);
  }
  return RET_OK;
}

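// Prepare caches the LstmParameter from the op's parameter_ and then runs ReSize, which chains
// InitParam -> InitInputWeightBias -> InitStateWeightBias -> MallocRunBuffer.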
int LstmFP32Coder::ReSize(CoderContext *const context) {
  MS_CHECK_RET_CODE(InitParam(), "init params of lstm coder failed");
  MS_CHECK_RET_CODE(InitInputWeightBias(context), "init input weight and bias failed");
  MS_CHECK_RET_CODE(InitStateWeightBias(context), "init state weight and bias failed");
  MS_CHECK_RET_CODE(MallocRunBuffer(context), "malloc run buffer failed");
  return RET_OK;
}

int LstmFP32Coder::Prepare(CoderContext *const context) {
  lstm_param_ = reinterpret_cast<LstmParameter *>(parameter_);
  return ReSize(context);
}

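// DoCode emits the per-inference code: pull in the nnacl LSTM kernel sources, copy the initial
// hidden/cell state into the state outputs, then call Lstm() with the packed weights and the
// workspace buffers allocated above.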
int LstmFP32Coder::DoCode(CoderContext *context) {
  Collect(context,
          {
            "nnacl/lstm_parameter.h",
            "nnacl/fp32/lstm_fp32.h",
          },
          {
            "lstm_fp32.c",
          });

  Tensor *hidden_state = input_tensors_.at(kFifthIndex);
  MS_CHECK_PTR(hidden_state);
  Tensor *cell_state = input_tensors_.at(kSixthIndex);
  MS_CHECK_PTR(cell_state);
  Tensor *output_hidden_state = output_tensors_[1];
  MS_CHECK_PTR(output_hidden_state);
  Tensor *output_cell_state = output_tensors_[2];
  MS_CHECK_PTR(output_cell_state);

  NNaclFp32Serializer code;
  code.CodeStruct("lstm_param", *lstm_param_);
  code.CodeFunction("memcpy", output_hidden_state, hidden_state, hidden_state->Size());
  code.CodeFunction("memcpy", output_cell_state, cell_state, cell_state->Size());
  code.CodeFunction("Lstm", output_tensor_, input_tensor_, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_,
                    output_hidden_state, output_cell_state, buffer_, lstm_param_);
  context->AppendCode(code.str());
  return RET_OK;
}

REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_LSTM, CPUOpCoderCreator<LstmFP32Coder>)

}  // namespace mindspore::lite::micro::nnacl

@@ -0,0 +1,55 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_LSTM_FP32_CODER_H_
#define MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_LSTM_FP32_CODER_H_

#include <vector>
#include "coder/opcoders/op_coder.h"
#include "nnacl/lstm_parameter.h"

namespace mindspore::lite::micro::nnacl {
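// Operator coder that emits C inference code for the float32 LSTM kernel in nnacl.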
class LstmFP32Coder final : public OperatorCoder {
 public:
  LstmFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
                const Model::Node *node, size_t node_index, Target target)
      : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {}

  ~LstmFP32Coder() override = default;

  int Prepare(CoderContext *const context) override;
  int DoCode(CoderContext *const context) override;

 private:
  int InitParam();
  int ReSize(CoderContext *const context);
  int MallocRunBuffer(CoderContext *const context);
  int InitInputWeightBias(CoderContext *const context);
  int InitStateWeightBias(CoderContext *const context);

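  // Packed weight/bias buffers and the six per-inference workspace pointers managed by the allocator.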
  float *weight_i_ptr_{nullptr};
  float *weight_h_ptr_{nullptr};
  float *input_bias_{nullptr};
  float *state_bias_{nullptr};
  float *buffer_[6];
  int row_tile_{0};
  int col_tile_{0};
  int weight_batch_{0};
  bool is_vec_{false};
  LstmParameter *lstm_param_{nullptr};
};
}  // namespace mindspore::lite::micro::nnacl
#endif  // MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_LSTM_FP32_CODER_H_

@@ -108,6 +108,14 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const TransposePar
                 transpose_parameter.data_size_);
}

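// Emits a brace-initialized LstmParameter definition into the generated C file, roughly of the form
//   LstmParameter lstm_param = {op_parameter, input_size, hidden_size, seq_len, batch, ...};
// so the argument order below must follow the member order of LstmParameter in nnacl/lstm_parameter.h.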
void NNaclFp32Serializer::CodeStruct(const std::string &name, const LstmParameter &lstm_parameter) {
  CodeBaseStruct("LstmParameter", name, lstm_parameter.op_parameter_, lstm_parameter.input_size_,
                 lstm_parameter.hidden_size_, lstm_parameter.seq_len_, lstm_parameter.batch_,
                 lstm_parameter.output_step_, lstm_parameter.bidirectional_, lstm_parameter.zoneout_cell_,
                 lstm_parameter.zoneout_hidden_, lstm_parameter.input_row_align_, lstm_parameter.input_col_align_,
                 lstm_parameter.state_row_align_, lstm_parameter.state_col_align_);
}

void NNaclFp32Serializer::CodeStruct(const std::string &name, const DeQuantArg &de_quant_arg) {
  // The clusters field is currently unused; it will be supported in the future.
  CodeBaseStruct("DeQuantArg", name, de_quant_arg.scale, de_quant_arg.zeroPoint, de_quant_arg.var_corr,

@@ -30,6 +30,7 @@
#include "nnacl/pooling_parameter.h"
#include "nnacl/softmax_parameter.h"
#include "nnacl/splice_parameter.h"
#include "nnacl/lstm_parameter.h"
#include "wrapper/fp32/dequant_int8_to_fp32_wrapper.h"
namespace mindspore::lite::micro::nnacl {

@@ -43,6 +44,7 @@ class NNaclFp32Serializer : public Serializer {
  void CodeStruct(const std::string &name, const ArithmeticParameter &arithmetic_parameter);
  void CodeStruct(const std::string &name, const ConvParameter &conv_parameter);
  void CodeStruct(const std::string &name, const MatMulParameter &mat_mul_parameter);
  void CodeStruct(const std::string &name, const LstmParameter &lstm_parameter);
  void CodeStruct(const std::string &name, const ScaleParameter &scale_parameter);
  void CodeStruct(const std::string &name, const SliceParameter &slice_parameter);
  void CodeStruct(const std::string &name, const TileParameter &tile_parameter);