!21204 [MS][Lite][ToD] Lay LSTMGrad infrastructure and fix static code analysis warnings

Merge pull request !21204 from ehaleva/export
This commit is contained in:
i-robot 2021-08-02 01:53:01 +00:00 committed by Gitee
commit e948611be1
17 changed files with 250 additions and 22 deletions

View File

@ -226,8 +226,9 @@ enum PrimType {
PrimType_TensorArrayWrite = 199, PrimType_TensorArrayWrite = 199,
PrimType_Affine = 200, PrimType_Affine = 200,
PrimType_Attention = 201, PrimType_Attention = 201,
PrimType_LSTMGrad = 202,
PrimType_MIN = PrimType_NONE, PrimType_MIN = PrimType_NONE,
PrimType_MAX = PrimType_Attention + 1 PrimType_MAX = PrimType_LSTMGrad + 1
}; };
void RegInfer(int prim_type, InferShape func); void RegInfer(int prim_type, InferShape func);

View File

@ -0,0 +1,98 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/lstm_grad.h"
#include <algorithm>
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
namespace {
AbstractBasePtr LstmGradInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
// infer shape
MS_EXCEPTION_IF_NULL(primitive);
return nullptr;
}
} // namespace
void LSTMGrad::set_input_size(const int64_t input_size) {
CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name());
AddAttr(kInput_size, MakeValue(input_size));
}
int64_t LSTMGrad::get_input_size() const { return GetValue<int64_t>(GetAttr(kInput_size)); }
void LSTMGrad::set_hidden_size(const int64_t hidden_size) {
CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name());
AddAttr(kHidden_size, MakeValue(hidden_size));
}
int64_t LSTMGrad::get_hidden_size() const { return GetValue<int64_t>(GetAttr(kHidden_size)); }
void LSTMGrad::set_num_layers(const int64_t num_layers) {
CheckAndConvertUtils::CheckInteger(kNumLayers, num_layers, kGreaterThan, 0, this->name());
AddAttr(kNumLayers, MakeValue(num_layers));
}
int64_t LSTMGrad::get_num_layers() const { return GetValue<int64_t>(GetAttr(kNumLayers)); }
void LSTMGrad::set_has_bias(const bool has_bias) { AddAttr(kHasBias, MakeValue(has_bias)); }
bool LSTMGrad::get_has_bias() const {
auto value_ptr = this->GetAttr(kHasBias);
return GetValue<bool>(value_ptr);
}
void LSTMGrad::set_dropout(const float dropout) {
CheckAndConvertUtils::CheckInRange<float>(kDropout, dropout, kIncludeBoth, {0.0, 1.0}, this->name());
AddAttr(kDropout, MakeValue(dropout));
}
float LSTMGrad::get_dropout() const {
auto value_ptr = this->GetAttr(kDropout);
return GetValue<float>(value_ptr);
}
void LSTMGrad::set_bidirectional(const bool bidirectional) { AddAttr(kBidirectional, MakeValue(bidirectional)); }
bool LSTMGrad::get_bidirectional() const {
auto value_ptr = this->GetAttr(kBidirectional);
return GetValue<bool>(value_ptr);
}
void LSTMGrad::set_num_directions(const int64_t num_directions) { AddAttr(kNumDirections, MakeValue(num_directions)); }
int64_t LSTMGrad::get_num_directions() const { return GetValue<int64_t>(GetAttr(kNumDirections)); }
void LSTMGrad::set_zoneout_cell(float zoneout_cell) { AddAttr(kZoneoutCell, MakeValue(zoneout_cell)); }
float LSTMGrad::get_zoneout_cell() const { return GetValue<float>(this->GetAttr(kZoneoutCell)); }
void LSTMGrad::set_zoneout_hidden(float zoneout_hidden) { AddAttr(kZoneoutHidden, MakeValue(zoneout_hidden)); }
float LSTMGrad::get_zoneout_hidden() const { return GetValue<float>(this->GetAttr(kZoneoutHidden)); }
void LSTMGrad::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias,
const float dropout, const bool bidirectional, const float zoneout_cell,
const float zoneout_hidden) {
this->set_input_size(input_size);
this->set_hidden_size(hidden_size);
this->set_num_layers(num_layers);
this->set_has_bias(has_bias);
this->set_dropout(dropout);
this->set_bidirectional(bidirectional);
if (bidirectional) {
this->set_num_directions(2);
} else {
this->set_num_directions(1);
}
this->set_zoneout_cell(zoneout_cell);
this->set_zoneout_hidden(zoneout_hidden);
}
AbstractBasePtr LstmGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(LstmGradInfer(primitive, input_args));
}
REGISTER_PRIMITIVE_C(kNameLSTMGrad, LSTMGrad);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_LSTM_GRAD_H_
#define MINDSPORE_CORE_OPS_LSTM_GRAD_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameLSTMGrad = "LSTMGrad";
class LSTMGrad : public PrimitiveC {
public:
LSTMGrad() : PrimitiveC(kNameLSTMGrad) {}
~LSTMGrad() = default;
MS_DECLARE_PARENT(LSTMGrad, PrimitiveC);
void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias,
const float dropout, const bool bidirectional = false, const float zoneout_cell = 0.0f,
const float zoneout_hidden = 0.0f);
void set_input_size(const int64_t input_size);
int64_t get_input_size() const;
void set_hidden_size(const int64_t hidden_size);
int64_t get_hidden_size() const;
void set_num_layers(const int64_t num_layers);
int64_t get_num_layers() const;
void set_has_bias(const bool has_bias);
bool get_has_bias() const;
void set_dropout(const float dropout);
float get_dropout() const;
void set_bidirectional(const bool bidirectional);
bool get_bidirectional() const;
void set_num_directions(const int64_t num_directions);
int64_t get_num_directions() const;
void set_zoneout_cell(float zoneout_cell);
float get_zoneout_cell() const;
void set_zoneout_hidden(float zoneout_hidden);
float get_zoneout_hidden() const;
int64_t get_good_ld(const int64_t dim, const int64_t type_size);
};
AbstractBasePtr LstmGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimLstmGradPtr = std::shared_ptr<LSTMGrad>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_LSTM_GRAD_H_

View File

@ -219,6 +219,7 @@ union PrimitiveType {
TensorArrayWrite, TensorArrayWrite,
Affine, Affine,
Attention, Attention,
LSTMGrad,
} }
table Abs { table Abs {
@ -665,6 +666,18 @@ table LSTM {
zoneout_hidden: float = 0; zoneout_hidden: float = 0;
} }
table LSTMGrad {
bidirectional: bool;
has_bias: bool;
input_size: long;
hidden_size: long;
num_layers: long;
num_directions: long;
dropout: float;
zoneout_cell: float = 0;
zoneout_hidden: float = 0;
}
table L2NormalizeFusion { table L2NormalizeFusion {
axis: [long]; axis: [long];
epsilon: float; epsilon: float;

View File

@ -219,6 +219,7 @@ OP_TYPE(TensorArrayWrite)
// kaldi affine op // kaldi affine op
OP_TYPE(Affine) OP_TYPE(Affine)
OP_TYPE(Attention) OP_TYPE(Attention)
OP_TYPE(LSTMGrad)
OP_TYPE_DEF_END(PrimitiveType) OP_TYPE_DEF_END(PrimitiveType)
OP_SCHEMA_DEF(Abs) OP_SCHEMA_DEF(Abs)
@ -665,6 +666,18 @@ OP_ATTR_WITH_VALUE(zoneout_cell, float, 0)
OP_ATTR_WITH_VALUE(zoneout_hidden, float, 0) OP_ATTR_WITH_VALUE(zoneout_hidden, float, 0)
OP_SCHEMA_DEF_END(LSTM) OP_SCHEMA_DEF_END(LSTM)
OP_SCHEMA_DEF(LSTMGrad)
OP_ATTR(bidirectional, bool)
OP_ATTR(has_bias, bool)
OP_ATTR(input_size, long)
OP_ATTR(hidden_size, long)
OP_ATTR(num_layers, long)
OP_ATTR(num_directions, long)
OP_ATTR(dropout, float)
OP_ATTR_WITH_VALUE(zoneout_cell, float, 0)
OP_ATTR_WITH_VALUE(zoneout_hidden, float, 0)
OP_SCHEMA_DEF_END(LSTMGrad)
OP_SCHEMA_DEF(L2NormalizeFusion) OP_SCHEMA_DEF(L2NormalizeFusion)
OP_ATTR(axis, [long]) OP_ATTR(axis, [long])
OP_ATTR(epsilon, float) OP_ATTR(epsilon, float)

View File

@ -183,6 +183,7 @@
#include "ops/grad/group_conv2d_grad_input.h" #include "ops/grad/group_conv2d_grad_input.h"
#include "ops/grad/layer_norm_grad.h" #include "ops/grad/layer_norm_grad.h"
#include "ops/grad/log_grad.h" #include "ops/grad/log_grad.h"
#include "ops/grad/lstm_grad.h"
#include "ops/grad/max_pool_grad.h" #include "ops/grad/max_pool_grad.h"
#include "ops/grad/maximum_grad.h" #include "ops/grad/maximum_grad.h"
#include "ops/grad/minimum_grad.h" #include "ops/grad/minimum_grad.h"
@ -344,6 +345,7 @@ FUNC_MSOP2SCHEMAOP_DECLARE(LpNormalization)
FUNC_MSOP2SCHEMAOP_DECLARE(LRN) FUNC_MSOP2SCHEMAOP_DECLARE(LRN)
FUNC_MSOP2SCHEMAOP_DECLARE(LshProjection) FUNC_MSOP2SCHEMAOP_DECLARE(LshProjection)
FUNC_MSOP2SCHEMAOP_DECLARE(LSTM) FUNC_MSOP2SCHEMAOP_DECLARE(LSTM)
FUNC_MSOP2SCHEMAOP_DECLARE(LSTMGrad)
FUNC_MSOP2SCHEMAOP_DECLARE(L2NormalizeFusion) FUNC_MSOP2SCHEMAOP_DECLARE(L2NormalizeFusion)
FUNC_MSOP2SCHEMAOP_DECLARE(MatMul) FUNC_MSOP2SCHEMAOP_DECLARE(MatMul)
FUNC_MSOP2SCHEMAOP_DECLARE(Maximum) FUNC_MSOP2SCHEMAOP_DECLARE(Maximum)

View File

@ -393,6 +393,10 @@ std::unique_ptr<schema::PrimitiveT> LSTMPrimitiveCreator(const AnfNodePtr &node)
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::LSTM>>(node); auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::LSTM>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
} }
std::unique_ptr<schema::PrimitiveT> LSTMGradPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::LSTMGrad>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
std::unique_ptr<schema::PrimitiveT> L2NormalizeFusionPrimitiveCreator(const AnfNodePtr &node) { std::unique_ptr<schema::PrimitiveT> L2NormalizeFusionPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::L2NormalizeFusion>>(node); auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::L2NormalizeFusion>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
@ -913,6 +917,7 @@ RegistryMSOps g_lpNormalizationPrimitiveCreatorRegistry("LpNormalization", LpNor
RegistryMSOps g_lrnPrimitiveCreatorRegistry("LRN", LrnPrimitiveCreator); RegistryMSOps g_lrnPrimitiveCreatorRegistry("LRN", LrnPrimitiveCreator);
RegistryMSOps g_lshProjectionPrimitiveCreatorRegistry("LshProjection", LshProjectionPrimitiveCreator); RegistryMSOps g_lshProjectionPrimitiveCreatorRegistry("LshProjection", LshProjectionPrimitiveCreator);
RegistryMSOps g_lSTMPrimitiveCreatorRegistry("LSTM", LSTMPrimitiveCreator); RegistryMSOps g_lSTMPrimitiveCreatorRegistry("LSTM", LSTMPrimitiveCreator);
RegistryMSOps g_lSTMGradPrimitiveCreatorRegistry("LSTMGrad", LSTMGradPrimitiveCreator);
RegistryMSOps g_l2NormalizeFusionPrimitiveCreatorRegistry("L2NormalizeFusion", L2NormalizeFusionPrimitiveCreator); RegistryMSOps g_l2NormalizeFusionPrimitiveCreatorRegistry("L2NormalizeFusion", L2NormalizeFusionPrimitiveCreator);
RegistryMSOps g_matMulPrimitiveCreatorRegistry("MatMul", MatMulPrimitiveCreator); RegistryMSOps g_matMulPrimitiveCreatorRegistry("MatMul", MatMulPrimitiveCreator);
RegistryMSOps g_maximumPrimitiveCreatorRegistry("Maximum", MaximumPrimitiveCreator); RegistryMSOps g_maximumPrimitiveCreatorRegistry("Maximum", MaximumPrimitiveCreator);

View File

@ -30,7 +30,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DropoutGrad; using mindspore::schema::PrimitiveType_DropoutGrad;
namespace mindspore::kernel { namespace mindspore::kernel {
int DropoutGradCPUKernelFp16::Init() { int DropoutGradCPUKernelFp16::Init() {
auto param = reinterpret_cast<DropoutParameter *>(op_parameter_); auto param = reinterpret_cast<DropoutParameter *>(op_parameter_);
if (param == nullptr) { if (param == nullptr) {
@ -63,7 +62,6 @@ int DropoutGradCPUKernelFp16::Execute(int task_id) {
auto length = in_tensors_.at(kInputIndex)->ElementsNum(); auto length = in_tensors_.at(kInputIndex)->ElementsNum();
int stride = UP_DIV(length, thread_count_); int stride = UP_DIV(length, thread_count_);
int count = MSMIN(stride, length - stride * task_id); int count = MSMIN(stride, length - stride * task_id);
if (count > 0) { if (count > 0) {
int start = stride * task_id; int start = stride * task_id;
DropoutFp16Grad(&(yt_ptr[start]), &(mask_ptr[start]), &(output_ptr[start]), count, (float16_t)scale_); DropoutFp16Grad(&(yt_ptr[start]), &(mask_ptr[start]), &(output_ptr[start]), count, (float16_t)scale_);

View File

@ -67,7 +67,6 @@ int PoolingGradCPUKernelFp16::Execute(int task_id) {
auto output_ptr = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c()); auto output_ptr = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
int stride = UP_DIV(pool_param->output_batch_, thread_num_); int stride = UP_DIV(pool_param->output_batch_, thread_num_);
int count = MSMIN(stride, pool_param->output_batch_ - stride * task_id); int count = MSMIN(stride, pool_param->output_batch_ - stride * task_id);
if (count > 0) { if (count > 0) {
int in_batch_size = pool_param->input_h_ * pool_param->input_w_ * pool_param->input_channel_; int in_batch_size = pool_param->input_h_ * pool_param->input_w_ * pool_param->input_channel_;
int out_batch_size = pool_param->output_h_ * pool_param->output_w_ * pool_param->input_channel_; int out_batch_size = pool_param->output_h_ * pool_param->output_w_ * pool_param->input_channel_;

View File

@ -28,7 +28,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_ResizeGrad; using mindspore::schema::PrimitiveType_ResizeGrad;
namespace mindspore::kernel { namespace mindspore::kernel {
float16_t ScalingFp16(size_t in_size, size_t out_size, bool align_corners) { float16_t ScalingFp16(size_t in_size, size_t out_size, bool align_corners) {
return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float16_t>(out_size - 1) return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float16_t>(out_size - 1)
: in_size / static_cast<float16_t>(out_size); : in_size / static_cast<float16_t>(out_size);

View File

@ -30,7 +30,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_StridedSliceGrad; using mindspore::schema::PrimitiveType_StridedSliceGrad;
namespace mindspore::kernel { namespace mindspore::kernel {
int StridedSliceGradCPUKernelFp16::Init() { int StridedSliceGradCPUKernelFp16::Init() {
if (!InferShapeDone()) { if (!InferShapeDone()) {
return RET_OK; return RET_OK;

View File

@ -29,7 +29,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_UnsortedSegmentSum; using mindspore::schema::PrimitiveType_UnsortedSegmentSum;
namespace mindspore::kernel { namespace mindspore::kernel {
int UnsortedSegmentSumCPUKernelFp16::Init() { int UnsortedSegmentSumCPUKernelFp16::Init() {
if (!InferShapeDone()) { if (!InferShapeDone()) {
return RET_OK; return RET_OK;

View File

@ -30,6 +30,8 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_LSTM; using mindspore::schema::PrimitiveType_LSTM;
namespace mindspore::kernel { namespace mindspore::kernel {
constexpr static int kWorkspaceOutIdx = 4;
void LstmCPUKernel::FreeTmpBuffer() { void LstmCPUKernel::FreeTmpBuffer() {
if (weight_i_ptr_ != nullptr) { if (weight_i_ptr_ != nullptr) {
free(weight_i_ptr_); free(weight_i_ptr_);
@ -361,10 +363,18 @@ int LstmCPUKernel::LstmUnidirectional(float *output, const float *weight_i, cons
float *output_ptr = output + real_t * lstm_param_->output_step_; float *output_ptr = output + real_t * lstm_param_->output_step_;
LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias,
hidden_state, cell_state, buffer_, lstm_param_); hidden_state, cell_state, buffer_, lstm_param_);
if (IsTrain() && IsTrainable()) {
RecordStates(cell_state, real_t);
}
} }
return RET_OK; return RET_OK;
} }
void LstmCPUKernel::RecordStates(float *cell_state, int step) {
float *workspace = reinterpret_cast<float *>(out_tensors_[kWorkspaceOutIdx]->MutableData());
workspace[step] = *cell_state;
}
int LstmCPUKernel::InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state) { int LstmCPUKernel::InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state) {
// forward // forward
// buffer_[packed_input_index] : store packed input // buffer_[packed_input_index] : store packed input

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_FP32_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_FP32_H_
#include <vector> #include <vector>
#include "src/inner_kernel.h" #include "src/inner_kernel.h"
@ -49,6 +49,7 @@ class LstmCPUKernel : public InnerKernel {
int LstmUnidirectional(float *output, const float *weight_i, const float *weight_h, const float *input_bias, int LstmUnidirectional(float *output, const float *weight_i, const float *weight_h, const float *input_bias,
const float *state_bias, float *hidden_state, float *cell_state, bool is_backward); const float *state_bias, float *hidden_state, float *cell_state, bool is_backward);
int InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state); int InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state);
void RecordStates(float *cell_state, int step);
const float *weight_loop_; const float *weight_loop_;
const float *bias_loop_; const float *bias_loop_;
float *gate_loop_; float *gate_loop_;
@ -84,4 +85,4 @@ class LstmCPUKernel : public InnerKernel {
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_FP32_H_

View File

@ -27,7 +27,7 @@
#include "nnacl/fp32_grad/batch_norm.h" #include "nnacl/fp32_grad/batch_norm.h"
#include "include/errorcode.h" #include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH; using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR; using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK; using mindspore::lite::RET_OK;

View File

@ -18,17 +18,21 @@
#include "src/ops/populate/populate_register.h" #include "src/ops/populate/populate_register.h"
#include "src/ops/populate/default_populate.h" #include "src/ops/populate/default_populate.h"
#include "src/ops/populate/strided_slice_populate.h" #include "src/ops/populate/strided_slice_populate.h"
#include "nnacl/pooling_parameter.h"
#include "nnacl/fp32_grad/softmax_grad.h"
#include "nnacl/fp32/activation_fp32.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/power_parameter.h"
#include "nnacl/arithmetic.h" #include "nnacl/arithmetic.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/lstm_parameter.h"
#include "nnacl/pooling_parameter.h"
#include "nnacl/power_parameter.h"
#include "nnacl/fp32/activation_fp32.h"
#include "nnacl/fp32_grad/softmax_grad.h"
#include "nnacl/fp32_grad/optimizer.h" #include "nnacl/fp32_grad/optimizer.h"
#include "nnacl/fp32_grad/batch_norm.h" #include "nnacl/fp32_grad/batch_norm.h"
#include "nnacl/fp32_grad/dropout_parameter.h" #include "nnacl/fp32_grad/dropout_parameter.h"
#include "nnacl/fp32_grad/smooth_l1_loss.h" #include "nnacl/fp32_grad/smooth_l1_loss.h"
#include "nnacl/fp32_grad/resize_grad.h" #include "nnacl/fp32_grad/resize_grad.h"
using mindspore::lite::Registry;
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
namespace { namespace {
@ -471,6 +475,29 @@ OpParameter *PopulateStridedSliceGradParameter(const void *prim) {
return reinterpret_cast<OpParameter *>(strided_slice_param); return reinterpret_cast<OpParameter *>(strided_slice_param);
} }
OpParameter *PopulateLstmGradParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_LSTMGrad();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr.";
return nullptr;
}
auto *param = reinterpret_cast<LstmParameter *>(malloc(sizeof(LstmParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc LstmParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(LstmParameter));
param->op_parameter_.type_ = primitive->value_type();
param->bidirectional_ = value->bidirectional();
param->zoneout_cell_ = value->zoneout_cell();
param->zoneout_hidden_ = value->zoneout_hidden();
return reinterpret_cast<OpParameter *>(param);
}
void PopulateTrainParameters() { void PopulateTrainParameters() {
lite::Registry ApplyMomentumParameterRegistry(schema::PrimitiveType_ApplyMomentum, PopulateApplyMomentumParameter, lite::Registry ApplyMomentumParameterRegistry(schema::PrimitiveType_ApplyMomentum, PopulateApplyMomentumParameter,
lite::SCHEMA_CUR); lite::SCHEMA_CUR);
@ -534,10 +561,9 @@ void PopulateTrainParameters() {
lite::SCHEMA_CUR); lite::SCHEMA_CUR);
lite::Registry RsqrtGradParameterRegistry(schema::PrimitiveType_RsqrtGrad, lite::DefaultPopulateParameter, lite::Registry RsqrtGradParameterRegistry(schema::PrimitiveType_RsqrtGrad, lite::DefaultPopulateParameter,
lite::SCHEMA_CUR); lite::SCHEMA_CUR);
lite::Registry ResizeGradParameterRegistry(schema::PrimitiveType_ResizeGrad, PopulateResizeGradParameter, Registry ResizeGradParameterRegistry(schema::PrimitiveType_ResizeGrad, PopulateResizeGradParameter, lite::SCHEMA_CUR);
lite::SCHEMA_CUR); Registry AbsGradParameterRegistry(schema::PrimitiveType_AbsGrad, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
lite::Registry AbsGradParameterRegistry(schema::PrimitiveType_AbsGrad, lite::DefaultPopulateParameter, Registry LSTMGradParameterRegistry(schema::PrimitiveType_LSTMGrad, PopulateLstmGradParameter, lite::SCHEMA_CUR);
lite::SCHEMA_CUR);
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -47,6 +47,7 @@ constexpr int kFieldsToPrint = 5;
constexpr int kPrintOffset = 4; constexpr int kPrintOffset = 4;
constexpr int kCPUBindFlag2 = 2; constexpr int kCPUBindFlag2 = 2;
constexpr int kCPUBindFlag1 = 1; constexpr int kCPUBindFlag1 = 1;
static const int kTHOUSAND = 1000;
namespace { namespace {
float *ReadFileBuf(const char *file, size_t *size) { float *ReadFileBuf(const char *file, size_t *size) {
@ -436,8 +437,8 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string
} }
auto end_prepare_time = GetTimeUs(); auto end_prepare_time = GetTimeUs();
MS_LOG(INFO) << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / 1000) << " ms"; MS_LOG(INFO) << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms";
std::cout << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / 1000) << " ms" << std::endl; std::cout << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms" << std::endl;
// Load input // Load input
MS_LOG(INFO) << "Load input data"; MS_LOG(INFO) << "Load input data";
auto ms_inputs = session->GetInputs(); auto ms_inputs = session->GetInputs();