From 1a0695aba6297f80c572fdef29fa654d2336d372 Mon Sep 17 00:00:00 2001
From: Emir Haleva
Date: Sun, 1 Aug 2021 11:30:57 +0300
Subject: [PATCH] Add LSTMGrad infra and fix static code analysis warnings

---
 .../cpu/nnacl/infer/infer_register.h          |  3 +-
 mindspore/core/ops/grad/lstm_grad.cc          | 98 +++++++++++++++++++
 mindspore/core/ops/grad/lstm_grad.h           | 64 ++++++++++++
 mindspore/lite/schema/ops.fbs                 | 13 +++
 mindspore/lite/src/ops/ops_def.cc             | 13 +++
 mindspore/lite/src/ops/ops_func_declare.h     |  2 +
 mindspore/lite/src/ops/ops_utils.cc           |  5 +
 .../kernel/arm/fp16_grad/dropout_fp16_grad.cc |  2 -
 .../kernel/arm/fp16_grad/pooling_fp16_grad.cc |  1 -
 .../kernel/arm/fp16_grad/resize_fp16_grad.cc  |  1 -
 .../arm/fp16_grad/strided_slice_fp16_grad.cc  |  1 -
 .../fp16_grad/unsorted_segment_sum_fp16.cc    |  1 -
 .../src/runtime/kernel/arm/fp32/lstm_fp32.cc  | 10 ++
 .../src/runtime/kernel/arm/fp32/lstm_fp32.h   |  7 +-
 .../runtime/kernel/arm/fp32_grad/bn_grad.cc   |  2 +-
 .../src/train/train_populate_parameter.cc     | 44 +++++++--
 .../lite/tools/benchmark_train/net_train.cc   |  5 +-
 17 files changed, 250 insertions(+), 22 deletions(-)
 create mode 100644 mindspore/core/ops/grad/lstm_grad.cc
 create mode 100644 mindspore/core/ops/grad/lstm_grad.h

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h
index b05c9017913..351e4f70086 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h
@@ -226,8 +226,9 @@
   PrimType_TensorArrayWrite = 199,
   PrimType_Affine = 200,
   PrimType_Attention = 201,
+  PrimType_LSTMGrad = 202,
   PrimType_MIN = PrimType_NONE,
-  PrimType_MAX = PrimType_Attention + 1
+  PrimType_MAX = PrimType_LSTMGrad + 1
 };
 
 void RegInfer(int prim_type, InferShape func);
diff --git a/mindspore/core/ops/grad/lstm_grad.cc b/mindspore/core/ops/grad/lstm_grad.cc
new file mode 100644
index 00000000000..f85f684612c
--- /dev/null
+++ b/mindspore/core/ops/grad/lstm_grad.cc
@@ -0,0 +1,98 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ops/grad/lstm_grad.h"
+#include
+#include "ops/op_utils.h"
+
+namespace mindspore {
+namespace ops {
+namespace {
+AbstractBasePtr LstmGradInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  // infer shape
+  MS_EXCEPTION_IF_NULL(primitive);
+  return nullptr;
+}
+}  // namespace
+
+void LSTMGrad::set_input_size(const int64_t input_size) {
+  CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name());
+  AddAttr(kInput_size, MakeValue(input_size));
+}
+int64_t LSTMGrad::get_input_size() const { return GetValue<int64_t>(GetAttr(kInput_size)); }
+void LSTMGrad::set_hidden_size(const int64_t hidden_size) {
+  CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name());
+  AddAttr(kHidden_size, MakeValue(hidden_size));
+}
+int64_t LSTMGrad::get_hidden_size() const { return GetValue<int64_t>(GetAttr(kHidden_size)); }
+void LSTMGrad::set_num_layers(const int64_t num_layers) {
+  CheckAndConvertUtils::CheckInteger(kNumLayers, num_layers, kGreaterThan, 0, this->name());
+  AddAttr(kNumLayers, MakeValue(num_layers));
+}
+int64_t LSTMGrad::get_num_layers() const { return GetValue<int64_t>(GetAttr(kNumLayers)); }
+void LSTMGrad::set_has_bias(const bool has_bias) { AddAttr(kHasBias, MakeValue(has_bias)); }
+bool LSTMGrad::get_has_bias() const {
+  auto value_ptr = this->GetAttr(kHasBias);
+  return GetValue<bool>(value_ptr);
+}
+void LSTMGrad::set_dropout(const float dropout) {
+  CheckAndConvertUtils::CheckInRange(kDropout, dropout, kIncludeBoth, {0.0, 1.0}, this->name());
+  AddAttr(kDropout, MakeValue(dropout));
+}
+float LSTMGrad::get_dropout() const {
+  auto value_ptr = this->GetAttr(kDropout);
+  return GetValue<float>(value_ptr);
+}
+void LSTMGrad::set_bidirectional(const bool bidirectional) { AddAttr(kBidirectional, MakeValue(bidirectional)); }
+bool LSTMGrad::get_bidirectional() const {
+  auto value_ptr = this->GetAttr(kBidirectional);
+  return GetValue<bool>(value_ptr);
+}
+void LSTMGrad::set_num_directions(const int64_t num_directions) { AddAttr(kNumDirections, MakeValue(num_directions)); }
+int64_t LSTMGrad::get_num_directions() const { return GetValue<int64_t>(GetAttr(kNumDirections)); }
+void LSTMGrad::set_zoneout_cell(float zoneout_cell) { AddAttr(kZoneoutCell, MakeValue(zoneout_cell)); }
+
+float LSTMGrad::get_zoneout_cell() const { return GetValue<float>(this->GetAttr(kZoneoutCell)); }
+
+void LSTMGrad::set_zoneout_hidden(float zoneout_hidden) { AddAttr(kZoneoutHidden, MakeValue(zoneout_hidden)); }
+
+float LSTMGrad::get_zoneout_hidden() const { return GetValue<float>(this->GetAttr(kZoneoutHidden)); }
+
+void LSTMGrad::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias,
+                    const float dropout, const bool bidirectional, const float zoneout_cell,
+                    const float zoneout_hidden) {
+  this->set_input_size(input_size);
+  this->set_hidden_size(hidden_size);
+  this->set_num_layers(num_layers);
+  this->set_has_bias(has_bias);
+  this->set_dropout(dropout);
+  this->set_bidirectional(bidirectional);
+  if (bidirectional) {
+    this->set_num_directions(2);
+  } else {
+    this->set_num_directions(1);
+  }
+  this->set_zoneout_cell(zoneout_cell);
+  this->set_zoneout_hidden(zoneout_hidden);
+}
+
+AbstractBasePtr LstmGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                              const std::vector<AbstractBasePtr> &input_args) {
+  return std::make_shared(LstmGradInfer(primitive, input_args));
+}
+REGISTER_PRIMITIVE_C(kNameLSTMGrad, LSTMGrad);
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/grad/lstm_grad.h b/mindspore/core/ops/grad/lstm_grad.h
new file mode 100644
index 00000000000..f91323d2c2b
--- /dev/null
+++ b/mindspore/core/ops/grad/lstm_grad.h
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_LSTM_GRAD_H_
+#define MINDSPORE_CORE_OPS_LSTM_GRAD_H_
+#include
+#include
+#include
+#include
+#include "ops/primitive_c.h"
+#include "abstract/abstract_value.h"
+#include "utils/check_convert_utils.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameLSTMGrad = "LSTMGrad";
+class LSTMGrad : public PrimitiveC {
+ public:
+  LSTMGrad() : PrimitiveC(kNameLSTMGrad) {}
+  ~LSTMGrad() = default;
+  MS_DECLARE_PARENT(LSTMGrad, PrimitiveC);
+  void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias,
+            const float dropout, const bool bidirectional = false, const float zoneout_cell = 0.0f,
+            const float zoneout_hidden = 0.0f);
+  void set_input_size(const int64_t input_size);
+  int64_t get_input_size() const;
+  void set_hidden_size(const int64_t hidden_size);
+  int64_t get_hidden_size() const;
+  void set_num_layers(const int64_t num_layers);
+  int64_t get_num_layers() const;
+  void set_has_bias(const bool has_bias);
+  bool get_has_bias() const;
+  void set_dropout(const float dropout);
+  float get_dropout() const;
+  void set_bidirectional(const bool bidirectional);
+  bool get_bidirectional() const;
+  void set_num_directions(const int64_t num_directions);
+  int64_t get_num_directions() const;
+  void set_zoneout_cell(float zoneout_cell);
+  float get_zoneout_cell() const;
+  void set_zoneout_hidden(float zoneout_hidden);
+  float get_zoneout_hidden() const;
+  int64_t get_good_ld(const int64_t dim, const int64_t type_size);
+};
+AbstractBasePtr LstmGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                              const std::vector<AbstractBasePtr> &input_args);
+using PrimLstmGradPtr = std::shared_ptr<LSTMGrad>;
+}  // namespace ops
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CORE_OPS_LSTM_GRAD_H_
diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
index 7efe58bc43d..ded169b7171 100644
--- a/mindspore/lite/schema/ops.fbs
+++ b/mindspore/lite/schema/ops.fbs
@@ -219,6 +219,7 @@ union PrimitiveType {
     TensorArrayWrite,
     Affine,
     Attention,
+    LSTMGrad,
 }
 
 table Abs {
@@ -665,6 +666,18 @@ table LSTM {
     zoneout_hidden: float = 0;
 }
 
+table LSTMGrad {
+    bidirectional: bool;
+    has_bias: bool;
+    input_size: long;
+    hidden_size: long;
+    num_layers: long;
+    num_directions: long;
+    dropout: float;
+    zoneout_cell: float = 0;
+    zoneout_hidden: float = 0;
+}
+
 table L2NormalizeFusion {
     axis: [long];
     epsilon: float;
diff --git a/mindspore/lite/src/ops/ops_def.cc b/mindspore/lite/src/ops/ops_def.cc
index 70972160131..b64ca1619fb 100644
--- a/mindspore/lite/src/ops/ops_def.cc
+++ b/mindspore/lite/src/ops/ops_def.cc
@@ -219,6 +219,7 @@ OP_TYPE(TensorArrayWrite)
 // kaldi affine op
 OP_TYPE(Affine)
 OP_TYPE(Attention)
+OP_TYPE(LSTMGrad)
 OP_TYPE_DEF_END(PrimitiveType)
 
 OP_SCHEMA_DEF(Abs)
@@ -665,6 +666,18 @@ OP_ATTR_WITH_VALUE(zoneout_cell, float, 0)
 OP_ATTR_WITH_VALUE(zoneout_hidden, float, 0)
 OP_SCHEMA_DEF_END(LSTM)
 
+OP_SCHEMA_DEF(LSTMGrad)
+OP_ATTR(bidirectional, bool)
+OP_ATTR(has_bias, bool)
+OP_ATTR(input_size, long)
+OP_ATTR(hidden_size, long)
+OP_ATTR(num_layers, long)
+OP_ATTR(num_directions, long)
+OP_ATTR(dropout, float)
+OP_ATTR_WITH_VALUE(zoneout_cell, float, 0)
+OP_ATTR_WITH_VALUE(zoneout_hidden, float, 0)
+OP_SCHEMA_DEF_END(LSTMGrad)
+
 OP_SCHEMA_DEF(L2NormalizeFusion)
 OP_ATTR(axis, [long])
 OP_ATTR(epsilon, float)
diff --git a/mindspore/lite/src/ops/ops_func_declare.h b/mindspore/lite/src/ops/ops_func_declare.h
index c6b7b5fdc9f..a2dee794b4e 100644
--- a/mindspore/lite/src/ops/ops_func_declare.h
+++ b/mindspore/lite/src/ops/ops_func_declare.h
@@ -183,6 +183,7 @@
 #include "ops/grad/group_conv2d_grad_input.h"
 #include "ops/grad/layer_norm_grad.h"
 #include "ops/grad/log_grad.h"
+#include "ops/grad/lstm_grad.h"
 #include "ops/grad/max_pool_grad.h"
 #include "ops/grad/maximum_grad.h"
 #include "ops/grad/minimum_grad.h"
@@ -344,6 +345,7 @@ FUNC_MSOP2SCHEMAOP_DECLARE(LpNormalization)
 FUNC_MSOP2SCHEMAOP_DECLARE(LRN)
 FUNC_MSOP2SCHEMAOP_DECLARE(LshProjection)
 FUNC_MSOP2SCHEMAOP_DECLARE(LSTM)
+FUNC_MSOP2SCHEMAOP_DECLARE(LSTMGrad)
 FUNC_MSOP2SCHEMAOP_DECLARE(L2NormalizeFusion)
 FUNC_MSOP2SCHEMAOP_DECLARE(MatMul)
 FUNC_MSOP2SCHEMAOP_DECLARE(Maximum)
diff --git a/mindspore/lite/src/ops/ops_utils.cc b/mindspore/lite/src/ops/ops_utils.cc
index 0eb61c703de..10a23304de7 100644
--- a/mindspore/lite/src/ops/ops_utils.cc
+++ b/mindspore/lite/src/ops/ops_utils.cc
@@ -393,6 +393,10 @@ std::unique_ptr<schema::PrimitiveT> LSTMPrimitiveCreator(const AnfNodePtr &node)
   auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::LSTM>>(node);
   return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
 }
+std::unique_ptr<schema::PrimitiveT> LSTMGradPrimitiveCreator(const AnfNodePtr &node) {
+  auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::LSTMGrad>>(node);
+  return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
+}
 std::unique_ptr<schema::PrimitiveT> L2NormalizeFusionPrimitiveCreator(const AnfNodePtr &node) {
   auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::L2NormalizeFusion>>(node);
   return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
@@ -913,6 +917,7 @@ RegistryMSOps g_lpNormalizationPrimitiveCreatorRegistry("LpNormalization", LpNor
 RegistryMSOps g_lrnPrimitiveCreatorRegistry("LRN", LrnPrimitiveCreator);
 RegistryMSOps g_lshProjectionPrimitiveCreatorRegistry("LshProjection", LshProjectionPrimitiveCreator);
 RegistryMSOps g_lSTMPrimitiveCreatorRegistry("LSTM", LSTMPrimitiveCreator);
+RegistryMSOps g_lSTMGradPrimitiveCreatorRegistry("LSTMGrad", LSTMGradPrimitiveCreator);
 RegistryMSOps g_l2NormalizeFusionPrimitiveCreatorRegistry("L2NormalizeFusion", L2NormalizeFusionPrimitiveCreator);
 RegistryMSOps g_matMulPrimitiveCreatorRegistry("MatMul", MatMulPrimitiveCreator);
 RegistryMSOps g_maximumPrimitiveCreatorRegistry("Maximum", MaximumPrimitiveCreator);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16_grad/dropout_fp16_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp16_grad/dropout_fp16_grad.cc
index 706cfbd8698..d9dca4254d9 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16_grad/dropout_fp16_grad.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16_grad/dropout_fp16_grad.cc
@@ -30,7 +30,6 @@
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_DropoutGrad;
 namespace mindspore::kernel {
-
 int DropoutGradCPUKernelFp16::Init() {
   auto param = reinterpret_cast<DropoutParameter *>(op_parameter_);
   if (param == nullptr) {
@@ -63,7 +62,6 @@ int DropoutGradCPUKernelFp16::Execute(int task_id) {
   auto length = in_tensors_.at(kInputIndex)->ElementsNum();
   int stride = UP_DIV(length, thread_count_);
   int count = MSMIN(stride, length - stride * task_id);
-
   if (count > 0) {
     int start = stride * task_id;
     DropoutFp16Grad(&(yt_ptr[start]), &(mask_ptr[start]), &(output_ptr[start]), count, (float16_t)scale_);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16_grad/pooling_fp16_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp16_grad/pooling_fp16_grad.cc
index b1f10fc6e93..a4d557d84ad 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16_grad/pooling_fp16_grad.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16_grad/pooling_fp16_grad.cc
@@ -67,7 +67,6 @@ int PoolingGradCPUKernelFp16::Execute(int task_id) {
   auto output_ptr = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
   int stride = UP_DIV(pool_param->output_batch_, thread_num_);
   int count = MSMIN(stride, pool_param->output_batch_ - stride * task_id);
-
   if (count > 0) {
     int in_batch_size = pool_param->input_h_ * pool_param->input_w_ * pool_param->input_channel_;
     int out_batch_size = pool_param->output_h_ * pool_param->output_w_ * pool_param->input_channel_;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16_grad/resize_fp16_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp16_grad/resize_fp16_grad.cc
index 018ee92d704..74175c9e9b2 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16_grad/resize_fp16_grad.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16_grad/resize_fp16_grad.cc
@@ -28,7 +28,6 @@
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_ResizeGrad;
 namespace mindspore::kernel {
-
 float16_t ScalingFp16(size_t in_size, size_t out_size, bool align_corners) {
   return (align_corners && out_size > 1) ?
     (in_size - 1) / static_cast<float16_t>(out_size - 1) : in_size / static_cast<float16_t>(out_size);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16_grad/strided_slice_fp16_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp16_grad/strided_slice_fp16_grad.cc
index e5414ed79d6..5c952d080be 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16_grad/strided_slice_fp16_grad.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16_grad/strided_slice_fp16_grad.cc
@@ -30,7 +30,6 @@
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_StridedSliceGrad;
 namespace mindspore::kernel {
-
 int StridedSliceGradCPUKernelFp16::Init() {
   if (!InferShapeDone()) {
     return RET_OK;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16_grad/unsorted_segment_sum_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16_grad/unsorted_segment_sum_fp16.cc
index 4c1cfcccd0b..ac0afe73d0e 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16_grad/unsorted_segment_sum_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16_grad/unsorted_segment_sum_fp16.cc
@@ -29,7 +29,6 @@
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_UnsortedSegmentSum;
 namespace mindspore::kernel {
-
 int UnsortedSegmentSumCPUKernelFp16::Init() {
   if (!InferShapeDone()) {
     return RET_OK;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
index 4610e12af62..cb0e4ec6b1c 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
@@ -30,6 +30,8 @@
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_LSTM;
 namespace mindspore::kernel {
+constexpr static int kWorkspaceOutIdx = 4;
+
 void LstmCPUKernel::FreeTmpBuffer() {
   if (weight_i_ptr_ != nullptr) {
     free(weight_i_ptr_);
@@ -361,10 +363,18 @@ int LstmCPUKernel::LstmUnidirectional(float *output, const float *weight_i, cons
     float *output_ptr = output + real_t * lstm_param_->output_step_;
     LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias,
                  hidden_state, cell_state, buffer_, lstm_param_);
+    if (IsTrain() && IsTrainable()) {
+      RecordStates(cell_state, real_t);
+    }
   }
   return RET_OK;
 }
 
+void LstmCPUKernel::RecordStates(float *cell_state, int step) {
+  float *workspace = reinterpret_cast<float *>(out_tensors_[kWorkspaceOutIdx]->MutableData());
+  workspace[step] = *cell_state;
+}
+
 int LstmCPUKernel::InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state) {
   // forward
   // buffer_[packed_input_index] : store packed input
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h
index d9d02d444d0..0d18f316520 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_
-#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_FP32_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_FP32_H_
 
 #include
 #include "src/inner_kernel.h"
@@ -49,6 +49,7 @@ class LstmCPUKernel : public InnerKernel {
   int LstmUnidirectional(float *output, const float *weight_i, const float *weight_h, const float *input_bias,
                          const float *state_bias, float *hidden_state, float *cell_state, bool is_backward);
   int InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state);
+  void RecordStates(float *cell_state, int step);
   const float *weight_loop_;
   const float *bias_loop_;
   float *gate_loop_;
@@ -84,4 +85,4 @@ class LstmCPUKernel : public InnerKernel {
 };
 }  // namespace mindspore::kernel
 
-#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_FP32_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc
index dc6eeca17ee..cc85d51e31c 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc
@@ -27,7 +27,7 @@
 #include "nnacl/fp32_grad/batch_norm.h"
 #include "include/errorcode.h"
 
-using mindspore::kernel::KERNEL_ARCH;
+using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
diff --git a/mindspore/lite/src/train/train_populate_parameter.cc b/mindspore/lite/src/train/train_populate_parameter.cc
index f53764b7d78..cf7257404a5 100644
--- a/mindspore/lite/src/train/train_populate_parameter.cc
+++ b/mindspore/lite/src/train/train_populate_parameter.cc
@@ -18,17 +18,21 @@
 #include "src/ops/populate/populate_register.h"
 #include "src/ops/populate/default_populate.h"
 #include "src/ops/populate/strided_slice_populate.h"
-#include "nnacl/pooling_parameter.h"
-#include "nnacl/fp32_grad/softmax_grad.h"
-#include "nnacl/fp32/activation_fp32.h"
-#include "nnacl/conv_parameter.h"
-#include "nnacl/power_parameter.h"
 #include "nnacl/arithmetic.h"
+#include "nnacl/conv_parameter.h"
+#include "nnacl/lstm_parameter.h"
+#include "nnacl/pooling_parameter.h"
+#include "nnacl/power_parameter.h"
+#include "nnacl/fp32/activation_fp32.h"
+#include "nnacl/fp32_grad/softmax_grad.h"
 #include "nnacl/fp32_grad/optimizer.h"
 #include "nnacl/fp32_grad/batch_norm.h"
 #include "nnacl/fp32_grad/dropout_parameter.h"
 #include "nnacl/fp32_grad/smooth_l1_loss.h"
 #include "nnacl/fp32_grad/resize_grad.h"
+
+using mindspore::lite::Registry;
+
 namespace mindspore {
 namespace kernel {
 namespace {
@@ -471,6 +475,29 @@ OpParameter *PopulateStridedSliceGradParameter(const void *prim) {
   return reinterpret_cast<OpParameter *>(strided_slice_param);
 }
 
+OpParameter *PopulateLstmGradParameter(const void *prim) {
+  auto primitive = static_cast<const schema::Primitive *>(prim);
+  MS_ASSERT(primitive != nullptr);
+  auto value = primitive->value_as_LSTMGrad();
+  if (value == nullptr) {
+    MS_LOG(ERROR) << "value is nullptr.";
+    return nullptr;
+  }
+
+  auto *param = reinterpret_cast<LstmParameter *>(malloc(sizeof(LstmParameter)));
+  if (param == nullptr) {
+    MS_LOG(ERROR) << "malloc LstmParameter failed.";
+    return nullptr;
+  }
+  memset(param, 0, sizeof(LstmParameter));
+
+  param->op_parameter_.type_ = primitive->value_type();
+  param->bidirectional_ = value->bidirectional();
+  param->zoneout_cell_ = value->zoneout_cell();
+  param->zoneout_hidden_ = value->zoneout_hidden();
+  return reinterpret_cast<OpParameter *>(param);
+}
+
 void PopulateTrainParameters() {
   lite::Registry ApplyMomentumParameterRegistry(schema::PrimitiveType_ApplyMomentum, PopulateApplyMomentumParameter,
                                                 lite::SCHEMA_CUR);
@@ -534,10 +561,9 @@
                                             lite::SCHEMA_CUR);
   lite::Registry RsqrtGradParameterRegistry(schema::PrimitiveType_RsqrtGrad, lite::DefaultPopulateParameter,
                                             lite::SCHEMA_CUR);
-  lite::Registry ResizeGradParameterRegistry(schema::PrimitiveType_ResizeGrad, PopulateResizeGradParameter,
-                                             lite::SCHEMA_CUR);
-  lite::Registry AbsGradParameterRegistry(schema::PrimitiveType_AbsGrad, lite::DefaultPopulateParameter,
-                                          lite::SCHEMA_CUR);
+  Registry ResizeGradParameterRegistry(schema::PrimitiveType_ResizeGrad, PopulateResizeGradParameter, lite::SCHEMA_CUR);
+  Registry AbsGradParameterRegistry(schema::PrimitiveType_AbsGrad, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
+  Registry LSTMGradParameterRegistry(schema::PrimitiveType_LSTMGrad, PopulateLstmGradParameter, lite::SCHEMA_CUR);
 }
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc
index 43c2a6c8a0e..8a2333b037f 100644
--- a/mindspore/lite/tools/benchmark_train/net_train.cc
+++ b/mindspore/lite/tools/benchmark_train/net_train.cc
@@ -47,6 +47,7 @@
 constexpr int kFieldsToPrint = 5;
 constexpr int kPrintOffset = 4;
 constexpr int kCPUBindFlag2 = 2;
 constexpr int kCPUBindFlag1 = 1;
+static const int kTHOUSAND = 1000;
 namespace {
 float *ReadFileBuf(const char *file, size_t *size) {
@@ -436,8 +437,8 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string
   }
 
   auto end_prepare_time = GetTimeUs();
-  MS_LOG(INFO) << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / 1000) << " ms";
-  std::cout << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / 1000) << " ms" << std::endl;
+  MS_LOG(INFO) << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms";
+  std::cout << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms" << std::endl;
   // Load input
   MS_LOG(INFO) << "Load input data";
   auto ms_inputs = session->GetInputs();
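
Note (not part of the patch): a minimal usage sketch of the new primitive, assuming it is constructed directly from C++. LSTMGrad::Init and the getters come from the lstm_grad.cc/lstm_grad.h added above; the function name and the concrete sizes below are hypothetical example values.

// Illustrative sketch only: build an LSTMGrad primitive and set its attributes via Init().
#include <memory>
#include "ops/grad/lstm_grad.h"

void BuildLstmGradPrim() {
  auto prim = std::make_shared<mindspore::ops::LSTMGrad>();
  // input_size = 32, hidden_size = 16, num_layers = 1, has_bias = true,
  // dropout = 0.0f, bidirectional = true; Init() derives num_directions = 2 internally.
  prim->Init(32, 16, 1, true, 0.0f, true);
  // Attributes can be read back through the getters, e.g. prim->get_num_directions().
}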