forked from mindspore-Ecosystem/mindspore
!3943 add arm cpu fp32 op: lstm
Merge pull request !3943 from yangruoqi713/lite
This commit is contained in:
commit
4b2e26f16b
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/ops.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "src/ir/tensor.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
const int kLstmInputNum = 6;
|
||||
const int kLstmOutputNum = 3;
|
||||
/**
 * Infers shape, data type and format of the three LSTM outputs.
 *
 * inputs_[0] is the data tensor and inputs_[1] is the input weight tensor; both must be 3-D.
 * The hidden size is weight shape[1] / 4 (the weight packs four gates along that axis).
 *
 * outputs_[0] takes the input shape with the last dimension replaced by hidden_size, plus an
 * extra dimension of size 2 inserted at index 1 when the primitive is bidirectional.
 * outputs_[1] and outputs_[2] take the input shape with dim 0 set to the number of directions
 * and the last dimension set to hidden_size.
 *
 * @return RET_OK on success, RET_INPUT_TENSOR_ERROR on a wrong tensor count, RET_ERROR when the
 *         input or weight tensor is not 3-D.
 */
int Lstm::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
  MS_ASSERT(this->primitive != nullptr);
  if (inputs_.size() != kLstmInputNum || outputs_.size() != kLstmOutputNum) {
    MS_LOG(ERROR) << "OpLstm inputs or outputs size error.";
    return RET_INPUT_TENSOR_ERROR;
  }
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  // The hidden size must be derived from the weight tensor (second input), not from the data input.
  auto weight_i = inputs_[1];
  MS_ASSERT(weight_i != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);

  std::vector<int> in_shape = input->shape();
  std::vector<int> w_shape = weight_i->shape();  // layer, hidden_size * 4, input_size
  if (in_shape.size() != 3 || w_shape.size() != 3) {
    MS_LOG(ERROR) << "OpLstm input dims should be 3.";
    return RET_ERROR;
  }

  auto lstm_prim = this->primitive->value_as_Lstm();
  int hidden_size = w_shape[1] / 4;

  // set output
  std::vector<int> out_shape(in_shape);
  out_shape[2] = hidden_size;
  if (lstm_prim->bidirection()) {
    out_shape.insert(out_shape.begin() + 1, 2);
  }
  output->set_shape(out_shape);

  // set hidden state, cell state
  std::vector<int> state_shape(in_shape);
  state_shape[0] = lstm_prim->bidirection() ? 2 : 1;
  state_shape[2] = hidden_size;
  outputs_[1]->set_shape(state_shape);
  outputs_[2]->set_shape(state_shape);

  for (int i = 0; i < kLstmOutputNum; i++) {
    outputs_[i]->set_data_type(input->data_type());
    outputs_[i]->SetFormat(input->GetFormat());
  }
  return RET_OK;
}
|
||||
} // namespace mindspore::lite
|
|
@ -770,6 +770,13 @@ class QuantDTypeCast : public Primitive {
|
|||
const schema::QuantDTypeCast *GetAttribute() const { return this->primitive->value_as_QuantDTypeCast(); }
|
||||
int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override;
|
||||
};
|
||||
|
||||
class Lstm : public Primitive {
|
||||
public:
|
||||
explicit Lstm(schema::Primitive *primitive) : Primitive(primitive) {}
|
||||
const schema::Lstm *GetAttribute() const { return this->primitive->value_as_Lstm(); }
|
||||
int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_OPS_OPS_H_
|
||||
|
|
|
@ -67,6 +67,7 @@
|
|||
#include "src/runtime/kernel/arm/opclib/fp32/space_to_depth.h"
|
||||
#include "src/runtime/kernel/arm/opclib/fp32/space_to_batch.h"
|
||||
#include "src/runtime/kernel/arm/opclib/int8/quant_dtype_cast.h"
|
||||
#include "src/runtime/kernel/arm/opclib/fp32/lstm.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
OpParameter *PopulateFillParameter(const lite::Primitive *primitive) {
|
||||
|
@ -1169,6 +1170,23 @@ OpParameter *PopulatePriorBoxParameter(const lite::Primitive *primitive) {
|
|||
return reinterpret_cast<OpParameter *>(prior_box_param);
|
||||
}
|
||||
|
||||
// Builds an LstmParameter from an Lstm primitive.
// Only the op type and the bidirectional flag are filled in here.
// Returns nullptr (after logging) when allocation fails or the primitive carries no Lstm attribute.
OpParameter *PopulateLstmParameter(const lite::Primitive *primitive) {
  auto *lstm_param = new (std::nothrow) LstmParameter();
  if (lstm_param == nullptr) {
    MS_LOG(ERROR) << "new LstmParameter fail!";
    return nullptr;
  }
  lstm_param->op_parameter_.type_ = primitive->Type();

  const auto *lstm_attr = primitive->Value()->value_as_Lstm();
  if (lstm_attr != nullptr) {
    lstm_param->bidirectional_ = lstm_attr->bidirection();
    return reinterpret_cast<OpParameter *>(lstm_param);
  }

  // No Lstm attribute available: release the half-built parameter.
  delete (lstm_param);
  MS_LOG(ERROR) << "get Lstm param nullptr.";
  return nullptr;
}
|
||||
|
||||
PopulateParameterRegistry::PopulateParameterRegistry() {
|
||||
populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter;
|
||||
populate_parameter_funcs_[schema::PrimitiveType_Activation] = PopulateActivationParameter;
|
||||
|
@ -1244,6 +1262,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
|
|||
populate_parameter_funcs_[schema::PrimitiveType_Split] = PopulateSplitParameter;
|
||||
populate_parameter_funcs_[schema::PrimitiveType_PriorBox] = PopulatePriorBoxParameter;
|
||||
populate_parameter_funcs_[schema::PrimitiveType_QuantDTypeCast] = PopulateQuantDTypeCastParameter;
|
||||
populate_parameter_funcs_[schema::PrimitiveType_Lstm] = PopulateLstmParameter;
|
||||
}
|
||||
|
||||
PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() {
|
||||
|
|
|
@ -0,0 +1,185 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/fp32/lstm.h"
|
||||
#include <vector>
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_Lstm;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int LstmCPUKernel::InitParam() {
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
std::vector<int> in_shape = input->shape();
|
||||
lstm_parm_->seq_len_ = in_shape[0];
|
||||
lstm_parm_->batch_ = in_shape[1];
|
||||
lstm_parm_->input_size_ = in_shape[2];
|
||||
|
||||
auto weight_i = inputs_[1];
|
||||
MS_ASSERT(weight_i != nullptr);
|
||||
std::vector<int> w_shape = weight_i->shape();
|
||||
lstm_parm_->hidden_size_ = w_shape[1] / 4;
|
||||
|
||||
lstm_parm_->input_step_ = lstm_parm_->batch_ * lstm_parm_->input_size_;
|
||||
lstm_parm_->output_step_ = lstm_parm_->bidirectional_ ? 2 * lstm_parm_->batch_ * lstm_parm_->hidden_size_
|
||||
: lstm_parm_->batch_ * lstm_parm_->hidden_size_;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmCPUKernel::InitBuffer() {
|
||||
gate_buffer_ = reinterpret_cast<float *>(malloc(4 * lstm_parm_->batch_ * lstm_parm_->hidden_size_ * sizeof(float)));
|
||||
if (gate_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmCPUKernel::InitWeightBias() {
|
||||
// copy weight_i and weight_h
|
||||
auto weight_i = inputs_.at(1);
|
||||
MS_ASSERT(weight_i != nullptr);
|
||||
weight_i_ptr_ = reinterpret_cast<float *>(malloc(weight_i->ElementsNum() * sizeof(float)));
|
||||
if (weight_i_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memcpy(weight_i_ptr_, weight_i->Data(), weight_i->ElementsNum() * sizeof(float));
|
||||
|
||||
auto weight_h = inputs_.at(2);
|
||||
MS_ASSERT(weight_h != nullptr);
|
||||
weight_h_ptr_ = reinterpret_cast<float *>(malloc(weight_h->ElementsNum() * sizeof(float)));
|
||||
if (weight_h_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memcpy(weight_h_ptr_, weight_h->Data(), weight_h->ElementsNum() * sizeof(float));
|
||||
|
||||
// init bias
|
||||
int bias_num = lstm_parm_->bidirectional_ ? 2 * 4 * lstm_parm_->hidden_size_ : 4 * lstm_parm_->hidden_size_;
|
||||
bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel malloc bias_ptr_ error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto bias_data = reinterpret_cast<float *>(inputs_.at(3)->Data());
|
||||
int state_bias_offset = 4 * lstm_parm_->hidden_size_;
|
||||
for (int i = 0; i < state_bias_offset; i++) {
|
||||
bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset];
|
||||
}
|
||||
if (lstm_parm_->bidirectional_) {
|
||||
bias_data += 4 * lstm_parm_->hidden_size_ * 2;
|
||||
auto backward_bias = bias_ptr_ + 4 * lstm_parm_->hidden_size_;
|
||||
for (int i = 0; i < state_bias_offset; i++) {
|
||||
backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset];
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmCPUKernel::Init() {
|
||||
auto ret = InitParam();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel InitParam error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ret = InitWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ret = InitBuffer();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel InitBuffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmCPUKernel::ReSize() {
|
||||
free(gate_buffer_);
|
||||
|
||||
auto ret = InitParam();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel InitParam error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ret = InitBuffer();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "LstmCPUKernel InitBuffer error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int LstmCPUKernel::Run() {
|
||||
auto input = inputs_.at(kInputIndex);
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto hidden_state = inputs_.at(4);
|
||||
MS_ASSERT(hidden_state != nullptr);
|
||||
auto cell_state = inputs_.at(5);
|
||||
MS_ASSERT(cell_state != nullptr);
|
||||
auto output = outputs_.at(0);
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
||||
auto input_ptr = reinterpret_cast<float *>(input->Data());
|
||||
auto output_ptr = reinterpret_cast<float *>(output->Data());
|
||||
|
||||
auto output_hidden_state = outputs_[1];
|
||||
memcpy(output_hidden_state->Data(), hidden_state->Data(), hidden_state->ElementsNum() * sizeof(float));
|
||||
auto output_cell_state = outputs_[2];
|
||||
memcpy(output_cell_state->Data(), cell_state->Data(), cell_state->ElementsNum() * sizeof(float));
|
||||
|
||||
Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_,
|
||||
reinterpret_cast<float *>(output_hidden_state->Data()), reinterpret_cast<float *>(output_cell_state->Data()),
|
||||
gate_buffer_, lstm_parm_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
// Factory registered for the fp32 CPU Lstm kernel.
// Allocates a LstmCPUKernel and initialises it; returns nullptr (after logging) when the
// allocation or Init() fails.
kernel::LiteKernel *CpuLstmKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                         const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *opParameter,
                                         const lite::Context *ctx, const kernel::KernelKey &desc) {
  MS_ASSERT(opParameter != nullptr);
  MS_ASSERT(desc.type == schema::PrimitiveType_Lstm);

  auto *kernel = new (std::nothrow) LstmCPUKernel(opParameter, inputs, outputs, ctx);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "kernel is nullptr.";
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    // Log before destroying the kernel: opParameter is read here, and the kernel holds a pointer
    // to it. NOTE(review): whether the LiteKernel destructor releases opParameter is not visible
    // from this file; logging first is safe either way.
    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
    delete kernel;
    return nullptr;
  }
  return kernel;
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Lstm, CpuLstmKernelCreator)
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/arm/opclib/fp32/lstm.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class LstmCPUKernel : public LiteKernel {
|
||||
public:
|
||||
LstmCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
|
||||
: LiteKernel(parameter, inputs, outputs) {
|
||||
lstm_parm_ = reinterpret_cast<LstmParameter *>(opParameter);
|
||||
}
|
||||
|
||||
~LstmCPUKernel() override {
|
||||
free(gate_buffer_);
|
||||
free(weight_i_ptr_);
|
||||
free(weight_h_ptr_);
|
||||
free(bias_ptr_);
|
||||
}
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
|
||||
int InitParam();
|
||||
int InitBuffer();
|
||||
int InitWeightBias();
|
||||
|
||||
private:
|
||||
float *gate_buffer_;
|
||||
float *weight_i_ptr_;
|
||||
float *weight_h_ptr_;
|
||||
float *bias_ptr_;
|
||||
LstmParameter *lstm_parm_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_
|
|
@ -0,0 +1,158 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/opclib/fp32/lstm.h"
|
||||
#include <string.h>
|
||||
#include "src/runtime/kernel/arm/opclib/fp32/activation.h"
|
||||
#include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h"
|
||||
|
||||
void InitGate(float *gate_buffer, const float *bias, LstmParameter *lstm_parm) {
|
||||
int gate_offest = 0;
|
||||
for (int l = 0; l < 4; l++) {
|
||||
int batch_offest = gate_offest;
|
||||
int bias_offest = l * lstm_parm->hidden_size_;
|
||||
for (int b = 0; b < lstm_parm->batch_; b++) {
|
||||
memcpy(gate_buffer + batch_offest, bias + bias_offest, lstm_parm->hidden_size_ * sizeof(float));
|
||||
batch_offest += lstm_parm->hidden_size_;
|
||||
}
|
||||
gate_offest += lstm_parm->batch_ * lstm_parm->hidden_size_;
|
||||
}
|
||||
}
|
||||
|
||||
// input: [row, inner_size]; weight: [col, inner_size]; output: [row, col]
|
||||
// Accumulating matrix product: output[r][c] += dot(input row r, weight row c).
// input is [rows, inner_size], weight is [cols, inner_size] (i.e. used transposed),
// output is [rows, cols] and is added to, not overwritten.
void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size) {
  for (int row = 0; row < rows; ++row) {
    const float *in_row = input + row * inner_size;
    float *out_row = output + row * cols;
    for (int col = 0; col < cols; ++col) {
      const float *w_row = weight + col * inner_size;
      float dot = 0;
      for (int k = 0; k < inner_size; ++k) {
        dot += in_row[k] * w_row[k];
      }
      out_row[col] += dot;
    }
  }
}
|
||||
|
||||
// Element-wise multiply-accumulate: output[i] += input0[i] * input1[i] for i in [0, element_size).
void ElementMulAcc(float *input0, float *input1, float *output, int element_size) {
  const float *lhs = input0;
  const float *rhs = input1;
  float *acc = output;
  for (int remaining = element_size; remaining > 0; --remaining) {
    *acc += (*lhs) * (*rhs);
    ++lhs;
    ++rhs;
    ++acc;
  }
}
|
||||
|
||||
// Updates the cell state in place over batch * hidden_size elements:
// first ElementMul(forget_gate, cell_state) is written back into cell_state
// (presumably an element-wise product; ElementMul is declared elsewhere),
// then input_gate * cell_gate is accumulated on top of it.
void UpdataState(float *cell_state, float *forget_gate, float *input_gate, float *cell_gate, int batch,
                 int hidden_size) {
  const int element_num = batch * hidden_size;
  ElementMul(forget_gate, cell_state, cell_state, element_num);
  ElementMulAcc(input_gate, cell_gate, cell_state, element_num);
}
|
||||
|
||||
// Produces the new hidden state over batch * hidden_size elements:
// Tanh of the cell state is written into hidden_state, which is then combined in place with
// output_gate through ElementMul (both helpers are declared elsewhere).
void UpdataOutput(float *cell_state, float *output_gate, float *hidden_state, int batch, int hidden_size) {
  const int element_num = batch * hidden_size;
  Tanh(cell_state, element_num, hidden_state);
  ElementMul(hidden_state, output_gate, hidden_state, element_num);
}
|
||||
|
||||
void LstmStepUnit(float *output, const float *input, const float *input_input_weight, const float *input_forget_weight,
|
||||
const float *input_cell_weight, const float *input_output_weight, const float *state_input_weight,
|
||||
const float *state_forget_weight, const float *state_cell_weight, const float *state_output_weight,
|
||||
const float *bias, float *hidden_state, float *cell_state, float *gate_buffer,
|
||||
LstmParameter *lstm_parm) {
|
||||
InitGate(gate_buffer, bias, lstm_parm);
|
||||
|
||||
float *input_gate = gate_buffer;
|
||||
float *forget_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 2;
|
||||
float *cell_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 3;
|
||||
float *output_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 1;
|
||||
|
||||
// input * weight
|
||||
MatMulAcc(input_gate, input, input_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_, lstm_parm->input_size_);
|
||||
MatMulAcc(forget_gate, input, input_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
|
||||
lstm_parm->input_size_);
|
||||
MatMulAcc(cell_gate, input, input_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_, lstm_parm->input_size_);
|
||||
MatMulAcc(output_gate, input, input_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
|
||||
lstm_parm->input_size_);
|
||||
|
||||
// state * weight
|
||||
MatMulAcc(input_gate, hidden_state, state_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
|
||||
lstm_parm->hidden_size_);
|
||||
MatMulAcc(forget_gate, hidden_state, state_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
|
||||
lstm_parm->hidden_size_);
|
||||
MatMulAcc(cell_gate, hidden_state, state_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
|
||||
lstm_parm->hidden_size_);
|
||||
MatMulAcc(output_gate, hidden_state, state_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_,
|
||||
lstm_parm->hidden_size_);
|
||||
|
||||
// update input_gate
|
||||
Sigmoid(input_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, input_gate);
|
||||
|
||||
// update forget_gate
|
||||
Sigmoid(forget_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, forget_gate);
|
||||
|
||||
// update cell_gate
|
||||
Tanh(cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, cell_gate);
|
||||
// update cell state
|
||||
UpdataState(cell_state, forget_gate, input_gate, cell_gate, lstm_parm->batch_, lstm_parm->hidden_size_);
|
||||
|
||||
// update output_gate
|
||||
Sigmoid(output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, output_gate);
|
||||
// update output
|
||||
UpdataOutput(cell_state, output_gate, hidden_state, lstm_parm->batch_, lstm_parm->hidden_size_);
|
||||
memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float));
|
||||
}
|
||||
|
||||
void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
|
||||
float *hidden_state, float *cell_state, float *gate_buffer, LstmParameter *lstm_parm) {
|
||||
// forward
|
||||
const float *input_input_weight = weight_i;
|
||||
const float *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2;
|
||||
const float *input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 3;
|
||||
const float *input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 1;
|
||||
|
||||
const float *state_input_weight = weight_h;
|
||||
const float *state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 2;
|
||||
const float *state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 3;
|
||||
const float *state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 1;
|
||||
|
||||
for (int t = 0; t < lstm_parm->seq_len_; t++) {
|
||||
const float *input_ptr = input + t * lstm_parm->input_step_;
|
||||
float *output_ptr = output + t * lstm_parm->output_step_;
|
||||
LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, input_output_weight,
|
||||
state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, bias, hidden_state,
|
||||
cell_state, gate_buffer, lstm_parm);
|
||||
}
|
||||
|
||||
// backward
|
||||
if (lstm_parm->bidirectional_) {
|
||||
input_input_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 4;
|
||||
input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 6;
|
||||
input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 7;
|
||||
input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 5;
|
||||
|
||||
state_input_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 4;
|
||||
state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 6;
|
||||
state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 7;
|
||||
state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 5;
|
||||
|
||||
float *backward_output = output + lstm_parm->batch_ * lstm_parm->hidden_size_;
|
||||
const float *backward_bias = bias + 4 * lstm_parm->hidden_size_;
|
||||
float *backward_cell_state = cell_state + lstm_parm->batch_ * lstm_parm->hidden_size_;
|
||||
float *backward_hidden_state = hidden_state + lstm_parm->batch_ * lstm_parm->hidden_size_;
|
||||
for (int t = lstm_parm->seq_len_ - 1; t >= 0; t--) {
|
||||
const float *input_ptr = input + t * lstm_parm->input_step_;
|
||||
float *output_ptr = backward_output + t * lstm_parm->output_step_;
|
||||
LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight,
|
||||
input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, state_output_weight,
|
||||
backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, lstm_parm);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_LSTM_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_LSTM_H_
|
||||
|
||||
#include "src/runtime/kernel/arm/opclib/op_base.h"
|
||||
|
||||
struct LstmParameter {
|
||||
OpParameter op_parameter_;
|
||||
int input_size_;
|
||||
int hidden_size_; // output_size
|
||||
int seq_len_;
|
||||
int batch_;
|
||||
int input_step_;
|
||||
int output_step_;
|
||||
bool bidirectional_;
|
||||
};
|
||||
|
||||
void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
|
||||
float *hidden_state, float *cell_state, float *gate_buffer, LstmParameter *lstm_parm);
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_LSTM_H_
|
|
@ -0,0 +1,330 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include "utils/log_adapter.h"
|
||||
#include "common/common_test.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/arm/opclib/fp32/lstm.h"
|
||||
#include "mindspore/lite/src/kernel_registry.h"
|
||||
#include "mindspore/lite/src/ops/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
// Test fixture for the fp32 LSTM kernel tests.
class LstmFp32 : public mindspore::Common {
 public:
  LstmFp32() = default;
};
|
||||
|
||||
// Fills in the dimensions used by the accuracy test:
// 4 time steps, batch 1, input size 2, hidden size 3, single direction.
void InitLstmParam(LstmParameter *lstm_param) {
  LstmParameter &param = *lstm_param;
  param.seq_len_ = 4;
  param.batch_ = 1;
  param.input_size_ = 2;
  param.hidden_size_ = 3;
  param.bidirectional_ = false;
}
|
||||
|
||||
// Builds the tensors for the forward (single direction) accuracy test.
// inputs receives: input, weight_i, weight_h, bias and the same zero state tensor twice
// (used as both initial hidden and initial cell state). outputs receives three zero-filled
// tensors. All tensors are heap-allocated; the caller is responsible for deleting them.
void InitLstmForwardCreator(std::vector<lite::tensor::Tensor *> *inputs, std::vector<lite::tensor::Tensor *> *outputs,
                            const LstmParameter *lstm_param) {
  const int seq_len = lstm_param->seq_len_;
  const int batch = lstm_param->batch_;
  const int input_size = lstm_param->input_size_;
  const int hidden_size = lstm_param->hidden_size_;

  // float32 NHWC tensor of the given shape, filled with a copy of `values`.
  auto make_const_tensor = [](const std::vector<int> &shape, const std::vector<float> &values) {
    auto *tensor = new lite::tensor::Tensor;
    tensor->set_data_type(kNumberTypeFloat32);
    tensor->SetFormat(schema::Format_NHWC);
    tensor->set_shape(shape);
    tensor->MallocData();
    memcpy(tensor->Data(), values.data(), values.size() * sizeof(float));
    return tensor;
  };

  // float32 NHWC tensor of the given shape with all bytes zeroed.
  auto make_zero_tensor = [](const std::vector<int> &shape) {
    auto *tensor = new lite::tensor::Tensor;
    tensor->set_data_type(kNumberTypeFloat32);
    tensor->set_shape(shape);
    tensor->SetFormat(schema::Format_NHWC);
    tensor->MallocData();
    memset(tensor->Data(), 0, tensor->ElementsNum() * sizeof(float));
    return tensor;
  };

  // input (no format is set on this tensor)
  std::vector<float> input_data = {1.3889, -0.3006, -0.1787, 2.1504, -0.3181, 0.4945, -0.4758, -0.8187};
  auto *input = new lite::tensor::Tensor;
  input->set_data_type(kNumberTypeFloat32);
  input->set_shape({seq_len, batch, input_size});
  input->MallocData();
  memcpy(input->Data(), input_data.data(), input_data.size() * sizeof(float));

  // weight_i
  std::vector<float> weight_i_data = {0.21368974,  -0.3778776,  0.05025542,  0.09011161,  0.18355745, 0.5491228,
                                      -0.14186832, -0.4655916,  0.49541366,  -0.44039622, 0.5625571,  0.23325664,
                                      0.3449825,   -0.42750397, 0.01911497,  -0.4125802,  -0.56690466, 0.50593233,
                                      -0.29129684, -0.27841482, 0.01964372,  -0.42543447, 0.41720617, -0.30054367};
  auto *weight_i = make_const_tensor({1, hidden_size * 4, input_size}, weight_i_data);

  // weight_h
  std::vector<float> weight_h_data = {
    -0.03424168, 0.00643545,  0.36867607, -0.08598137, 0.19804275, -0.11319417, -0.0244593,  -0.16440144, -0.07268238,
    0.09828371,  0.33358777,  0.53381383, -0.39431244, -0.06005383, -0.3520246, 0.42687547,  0.5772828,   0.5380008,
    -0.16130409, -0.24737108, 0.42409766, -0.50648475, 0.48223662, -0.5221103,  -0.49216837, -0.29084128, 0.3408438,
    0.34080023,  0.49467337,  0.23473483, 0.01759732,  0.04691631, 0.45574808,  -0.29481018, 0.29442167,  -0.36718};
  auto *weight_h = make_const_tensor({1, hidden_size * 4, hidden_size}, weight_h_data);

  // bias
  std::vector<float> bias_data = {-0.00207639, 0.16391152,  -0.00069344, -0.32945693, -0.367423,   0.28301108,
                                  -0.17930457, 0.5278388,   0.12598747,  -0.53130764, 0.1479364,   0.16695255,
                                  -0.00708795, -0.46417096, -0.23966661, -0.17496741, -0.19166365, -0.50466555,
                                  -0.23593256, -0.3911457,  0.51128435,  0.5128727,   0.253451,    -0.51891875};
  auto *bias = make_const_tensor({1, hidden_size * 4 * 2}, bias_data);

  // zero initial state, shared by the hidden-state and cell-state inputs
  std::vector<float> state_data = {0, 0, 0};
  auto *state = make_const_tensor({1, batch, hidden_size}, state_data);

  inputs->push_back(input);
  inputs->push_back(weight_i);
  inputs->push_back(weight_h);
  inputs->push_back(bias);
  inputs->push_back(state);
  inputs->push_back(state);

  // zero-filled output buffers
  auto *output = make_zero_tensor({seq_len, batch, hidden_size});
  auto *cell_state = make_zero_tensor({1, batch, hidden_size});
  auto *hidden_state = make_zero_tensor({1, batch, hidden_size});

  outputs->push_back(output);
  outputs->push_back(cell_state);
  outputs->push_back(hidden_state);
}
|
||||
|
||||
// Prints every element of the output tensor on one line, then compares the tensor against the
// expected values through Common::CompareOutputData with tolerance 0.0001.
void CompareOutput(lite::tensor::Tensor *output, std::vector<float> data) {
  auto *actual = reinterpret_cast<float *>(output->Data());
  const auto element_num = output->ElementsNum();
  for (int idx = 0; idx < element_num; ++idx) {
    std::cout << actual[idx] << ", ";
  }
  std::cout << std::endl;

  Common::CompareOutputData(actual, data.data(), element_num, 0.0001);
}
|
||||
|
||||
// Creates the registered fp32 CPU Lstm kernel, runs it once on the forward fixture and checks
// all three outputs against reference values.
TEST_F(LstmFp32, LstmForwardFp32Accuracy) {
  // prepare stage
  auto lstm_param = new LstmParameter();
  InitLstmParam(lstm_param);

  // init ctx
  auto ctx = new lite::Context();
  ctx->thread_num_ = 1;

  // init tensor
  std::vector<lite::tensor::Tensor *> inputs;
  std::vector<lite::tensor::Tensor *> outputs;
  InitLstmForwardCreator(&inputs, &outputs, lstm_param);

  // register op
  kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, mindspore::schema::PrimitiveType_Lstm};
  auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
  ASSERT_NE(creator, nullptr);
  kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(lstm_param), ctx, desc);
  ASSERT_NE(kernel, nullptr);
  // op run
  kernel->Run();

  std::cout << "==================output data=================" << std::endl;
  std::vector<float> output0_data = {-0.0702, 0.1225,  0.0876,  -0.0357, -0.0227, -0.2294,
                                     -0.0345, -0.0108, -0.2002, 0.0451,  0.0853,  -0.1205};
  CompareOutput(outputs[0], output0_data);

  std::vector<float> output1_data = {0.0451, 0.0853, -0.1205};
  CompareOutput(outputs[1], output1_data);

  std::vector<float> output2_data = {0.0989, 0.2094, -0.4132};
  CompareOutput(outputs[2], output2_data);

  delete lstm_param;
  // The last two inputs are the same state tensor, so the final entry must not be deleted again.
  for (size_t i = 0; i + 1 < inputs.size(); i++) {
    delete inputs[i];
  }
  for (size_t i = 0; i < outputs.size(); i++) {
    delete outputs[i];
  }
  delete kernel;
  // The context was previously leaked; release it once the kernel is gone.
  delete ctx;
  MS_LOG(INFO) << "LstmFp32 forward accuracy passed";
}
|
||||
|
||||
// Allocates and fills the six input tensors and three zero-initialised output tensors used by
// the LstmBackwardFp32Accuracy test and appends them to `inputs` / `outputs`.
//
// inputs:     receives input, weight_i, weight_h, bias and the same state tensor twice.
// outputs:    receives output, cell_state and hidden_state.
// lstm_param: supplies seq_len_, batch_, input_size_ and hidden_size_ for the tensor shapes.
//
// Every tensor is created with `new`; the caller is responsible for releasing them. All
// parameter tensors use a leading dimension of 2.
void InitLstmBackwardCreator(std::vector<lite::tensor::Tensor *> *inputs, std::vector<lite::tensor::Tensor *> *outputs,
                             const LstmParameter *lstm_param) {
  // Builds a float32 NHWC tensor of the given shape and copies `values` into its buffer.
  const auto make_filled = [](const std::vector<int> &shape, const std::vector<float> &values) {
    auto *tensor = new lite::tensor::Tensor;
    tensor->set_data_type(kNumberTypeFloat32);
    tensor->SetFormat(schema::Format_NHWC);
    tensor->set_shape(shape);
    tensor->MallocData();
    memcpy(tensor->Data(), values.data(), values.size() * sizeof(float));
    return tensor;
  };

  // Builds a float32 NHWC tensor of the given shape whose buffer is cleared to zero.
  const auto make_zeroed = [](const std::vector<int> &shape) {
    auto *tensor = new lite::tensor::Tensor;
    tensor->set_data_type(kNumberTypeFloat32);
    tensor->set_shape(shape);
    tensor->SetFormat(schema::Format_NHWC);
    tensor->MallocData();
    memset(tensor->Data(), 0, tensor->ElementsNum() * sizeof(float));
    return tensor;
  };

  // input sequence (no format is set on this tensor)
  const std::vector<float> x_values = {1.4305, 0.5342, -0.9221, 0.0527, 2.3770, -0.3697, -0.2833, -2.1285};
  auto *x = new lite::tensor::Tensor;
  x->set_data_type(kNumberTypeFloat32);
  x->set_shape({lstm_param->seq_len_, lstm_param->batch_, lstm_param->input_size_});
  x->MallocData();
  memcpy(x->Data(), x_values.data(), x_values.size() * sizeof(float));
  inputs->push_back(x);

  // input-to-hidden weights
  const std::vector<float> w_input_values = {
    -0.19253477, -0.007966279, -0.06039094, 0.27697134,  -0.5071223,  0.18996351,  0.20472168,  -0.1007814,
    0.04282999,  0.20836472,   -0.4654655,  0.050321221, -0.3431457,  0.22256428,  0.29294532,  0.45042896,
    0.20468240,  0.13078391,   -0.20987969, -0.3173505,  -0.3813517,  0.10205835,  0.21858131,  -0.0386473,
    0.5512280,   -0.2763766,   -0.3593936,  -0.5181975,  0.3469863,   -0.38533931, 0.010202527, -0.46598294,
    -0.5740513,  0.06127524,   -0.03960543, 0.2478809,   -0.17296993, 0.19159525,  -0.4976995,  0.05985528,
    0.3653409,   0.386924,     0.3170289,   -0.08830952, -0.31105759, 0.3110240,   0.15174299,  0.287579894};
  inputs->push_back(make_filled({2, lstm_param->hidden_size_ * 4, lstm_param->input_size_}, w_input_values));

  // hidden-to-hidden weights
  const std::vector<float> w_hidden_values = {
    0.106934666,  -0.50430017,  0.33296257,   -0.288117021, -0.38019785,  -0.147071093, 0.422707557,  0.41497004,
    -0.5329730,   -0.430150926, -0.032713949, 0.35401260,   0.179495036,  -0.14158579,  0.380428612,  -0.175597071,
    0.54088723,   -0.403292059, -0.287720531, -0.51250511,  -0.15405902,  -0.440592586, 0.16726928,   -0.0163397789,
    0.51673841,   0.5094323,    -0.137105107, -0.181070089, -0.47221425,  -0.38046866,  -0.206725060, 0.248537719,
    -0.23961094,  -0.117781728, 0.426800847,  0.0266208052, -0.197408229, 0.54831492,   -0.280048757, -0.125062286,
    -0.29929456,  0.42354834,   -0.401066303, 0.356340110,  0.54629492,   -0.15852552,  0.131406366,  -0.101815432,
    0.0121276974, -0.53553336,  0.121099889,  0.060554087,  0.46259057,   -0.49666053,  0.090806663,  0.20542401,
    -0.38674920,  -0.23874849,  -0.5222138,   0.57537007,   0.113343358,  -0.35233467,  -0.25532332,  0.159506142,
    0.35996592,   -0.201961308, -0.16323345,  0.119177639,  -0.12677872,  -0.175229549, -0.160024613, -0.21058899};
  inputs->push_back(make_filled({2, lstm_param->hidden_size_ * 4, lstm_param->hidden_size_}, w_hidden_values));

  // bias
  const std::vector<float> bias_values = {
    0.57061123,   -0.25357073,  -0.146834075, 0.412972748,  -0.27809411,  -0.0542128682, -0.45384609,  -0.53261917,
    0.222133636,  -0.18093895,  -0.045559883, 0.09109061,   0.080319643,  0.455167174,   0.36235427,   -0.00164419412,
    -0.135566502, 0.41905909,   -0.450117409, 0.50565385,   -0.077815443, -0.47051778,   -0.141349375, -0.338519752,
    0.48683023,   0.282384872,  0.13399660,   -0.382526844, -0.23370727,  -0.184681564,  0.45679104,   -0.339453905,
    0.452010273,  0.0552094578, 0.328843057,  0.127738714,  -0.127084732, -0.334061294,  -0.46742400,  -0.401568055,
    0.23712641,   -0.052937567, 0.272351622,  0.42767739,   0.303884744,  -0.46025499,   -0.43985402,  0.256422877};
  inputs->push_back(make_filled({2, lstm_param->hidden_size_ * 4 * 2}, bias_values));

  // initial state: one all-zero tensor that is registered as both of the last two inputs
  const std::vector<float> zero_state_values = {0, 0, 0, 0, 0, 0};
  auto *initial_state = make_filled({2, lstm_param->batch_, lstm_param->hidden_size_}, zero_state_values);
  inputs->push_back(initial_state);
  inputs->push_back(initial_state);

  // zero-initialised output buffers
  outputs->push_back(make_zeroed({lstm_param->seq_len_, 2, lstm_param->batch_, lstm_param->hidden_size_}));
  outputs->push_back(make_zeroed({2, lstm_param->batch_, lstm_param->hidden_size_}));  // cell state
  outputs->push_back(make_zeroed({2, lstm_param->batch_, lstm_param->hidden_size_}));  // hidden state
}
|
||||
|
||||
// Same flow as the forward accuracy test, but with bidirectional_ enabled and the tensors from
// InitLstmBackwardCreator: creates the CPU fp32 Lstm kernel through the kernel registry, runs
// it once and compares the three outputs against hard-coded expectations.
TEST_F(LstmFp32, LstmBackwardFp32Accuracy) {
  // prepare stage
  auto lstm_param = new LstmParameter();
  InitLstmParam(lstm_param);
  lstm_param->bidirectional_ = true;

  // init ctx
  auto ctx = new lite::Context();
  ctx->thread_num_ = 1;

  // init tensor
  std::vector<lite::tensor::Tensor *> inputs;
  std::vector<lite::tensor::Tensor *> outputs;
  InitLstmBackwardCreator(&inputs, &outputs, lstm_param);

  // register op
  kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, mindspore::schema::PrimitiveType_Lstm};
  auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
  ASSERT_NE(creator, nullptr);
  kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(lstm_param), ctx, desc);
  ASSERT_NE(kernel, nullptr);

  // op run
  // NOTE(review): the result of Run() is ignored here; consider asserting on it once its
  // status type is confirmed.
  kernel->Run();

  std::cout << "==================output data=================" << std::endl;
  std::vector<float> output0_data = {-0.2922, -0.1416, 0.0077,  -0.0422, -0.0585, 0.2061,  -0.2385, -0.0146,
                                     -0.1796, -0.0554, -0.0973, 0.1013,  -0.3062, -0.1516, -0.0310, 0.0459,
                                     -0.0784, 0.0949,  0.0249,  -0.0653, -0.0869, -0.1113, -0.2155, -0.0500};
  CompareOutput(outputs[0], output0_data);

  std::vector<float> output1_data = {0.0249, -0.0653, -0.0869, -0.0422, -0.0585, 0.2061};
  CompareOutput(outputs[1], output1_data);

  std::vector<float> output2_data = {0.0373, -0.2322, -0.1477, -0.1621, -0.1808, 0.5146};
  CompareOutput(outputs[2], output2_data);

  delete lstm_param;
  // InitLstmBackwardCreator pushes the same `state` tensor as the last two inputs, so the final
  // entry is skipped to avoid a double delete. `i + 1 < size()` also stays safe when empty.
  for (size_t i = 0; i + 1 < inputs.size(); i++) {
    delete inputs[i];
  }
  for (auto *tensor : outputs) {
    delete tensor;
  }
  delete kernel;
  // Release the context after the kernel that was created with it.
  // Assumes the kernel does not take ownership of ctx — confirm against the kernel creator.
  delete ctx;
  MS_LOG(INFO) << "LstmFp32 backward accuracy passed";
}
|
||||
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue