forked from mindspore-Ecosystem/mindspore
!21204 [MS][Lite][ToD] Lay LSTMGrad infrastructure and fix static code analysis warnings
Merge pull request !21204 from ehaleva/export
This commit is contained in:
commit
e948611be1
|
@ -226,8 +226,9 @@ enum PrimType {
|
|||
PrimType_TensorArrayWrite = 199,
|
||||
PrimType_Affine = 200,
|
||||
PrimType_Attention = 201,
|
||||
PrimType_LSTMGrad = 202,
|
||||
PrimType_MIN = PrimType_NONE,
|
||||
PrimType_MAX = PrimType_Attention + 1
|
||||
PrimType_MAX = PrimType_LSTMGrad + 1
|
||||
};
|
||||
|
||||
void RegInfer(int prim_type, InferShape func);
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/grad/lstm_grad.h"
|
||||
#include <algorithm>
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
AbstractBasePtr LstmGradInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
// infer shape
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void LSTMGrad::set_input_size(const int64_t input_size) {
|
||||
CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name());
|
||||
AddAttr(kInput_size, MakeValue(input_size));
|
||||
}
|
||||
int64_t LSTMGrad::get_input_size() const { return GetValue<int64_t>(GetAttr(kInput_size)); }
|
||||
void LSTMGrad::set_hidden_size(const int64_t hidden_size) {
|
||||
CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name());
|
||||
AddAttr(kHidden_size, MakeValue(hidden_size));
|
||||
}
|
||||
int64_t LSTMGrad::get_hidden_size() const { return GetValue<int64_t>(GetAttr(kHidden_size)); }
|
||||
void LSTMGrad::set_num_layers(const int64_t num_layers) {
|
||||
CheckAndConvertUtils::CheckInteger(kNumLayers, num_layers, kGreaterThan, 0, this->name());
|
||||
AddAttr(kNumLayers, MakeValue(num_layers));
|
||||
}
|
||||
int64_t LSTMGrad::get_num_layers() const { return GetValue<int64_t>(GetAttr(kNumLayers)); }
|
||||
void LSTMGrad::set_has_bias(const bool has_bias) { AddAttr(kHasBias, MakeValue(has_bias)); }
|
||||
bool LSTMGrad::get_has_bias() const {
|
||||
auto value_ptr = this->GetAttr(kHasBias);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
void LSTMGrad::set_dropout(const float dropout) {
|
||||
CheckAndConvertUtils::CheckInRange<float>(kDropout, dropout, kIncludeBoth, {0.0, 1.0}, this->name());
|
||||
AddAttr(kDropout, MakeValue(dropout));
|
||||
}
|
||||
float LSTMGrad::get_dropout() const {
|
||||
auto value_ptr = this->GetAttr(kDropout);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
void LSTMGrad::set_bidirectional(const bool bidirectional) { AddAttr(kBidirectional, MakeValue(bidirectional)); }
|
||||
bool LSTMGrad::get_bidirectional() const {
|
||||
auto value_ptr = this->GetAttr(kBidirectional);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
void LSTMGrad::set_num_directions(const int64_t num_directions) { AddAttr(kNumDirections, MakeValue(num_directions)); }
|
||||
int64_t LSTMGrad::get_num_directions() const { return GetValue<int64_t>(GetAttr(kNumDirections)); }
|
||||
void LSTMGrad::set_zoneout_cell(float zoneout_cell) { AddAttr(kZoneoutCell, MakeValue(zoneout_cell)); }
|
||||
|
||||
float LSTMGrad::get_zoneout_cell() const { return GetValue<float>(this->GetAttr(kZoneoutCell)); }
|
||||
|
||||
void LSTMGrad::set_zoneout_hidden(float zoneout_hidden) { AddAttr(kZoneoutHidden, MakeValue(zoneout_hidden)); }
|
||||
|
||||
float LSTMGrad::get_zoneout_hidden() const { return GetValue<float>(this->GetAttr(kZoneoutHidden)); }
|
||||
|
||||
void LSTMGrad::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias,
|
||||
const float dropout, const bool bidirectional, const float zoneout_cell,
|
||||
const float zoneout_hidden) {
|
||||
this->set_input_size(input_size);
|
||||
this->set_hidden_size(hidden_size);
|
||||
this->set_num_layers(num_layers);
|
||||
this->set_has_bias(has_bias);
|
||||
this->set_dropout(dropout);
|
||||
this->set_bidirectional(bidirectional);
|
||||
if (bidirectional) {
|
||||
this->set_num_directions(2);
|
||||
} else {
|
||||
this->set_num_directions(1);
|
||||
}
|
||||
this->set_zoneout_cell(zoneout_cell);
|
||||
this->set_zoneout_hidden(zoneout_hidden);
|
||||
}
|
||||
|
||||
AbstractBasePtr LstmGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(LstmGradInfer(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameLSTMGrad, LSTMGrad);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_LSTM_GRAD_H_
|
||||
#define MINDSPORE_CORE_OPS_LSTM_GRAD_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameLSTMGrad = "LSTMGrad";
|
||||
class LSTMGrad : public PrimitiveC {
|
||||
public:
|
||||
LSTMGrad() : PrimitiveC(kNameLSTMGrad) {}
|
||||
~LSTMGrad() = default;
|
||||
MS_DECLARE_PARENT(LSTMGrad, PrimitiveC);
|
||||
void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias,
|
||||
const float dropout, const bool bidirectional = false, const float zoneout_cell = 0.0f,
|
||||
const float zoneout_hidden = 0.0f);
|
||||
void set_input_size(const int64_t input_size);
|
||||
int64_t get_input_size() const;
|
||||
void set_hidden_size(const int64_t hidden_size);
|
||||
int64_t get_hidden_size() const;
|
||||
void set_num_layers(const int64_t num_layers);
|
||||
int64_t get_num_layers() const;
|
||||
void set_has_bias(const bool has_bias);
|
||||
bool get_has_bias() const;
|
||||
void set_dropout(const float dropout);
|
||||
float get_dropout() const;
|
||||
void set_bidirectional(const bool bidirectional);
|
||||
bool get_bidirectional() const;
|
||||
void set_num_directions(const int64_t num_directions);
|
||||
int64_t get_num_directions() const;
|
||||
void set_zoneout_cell(float zoneout_cell);
|
||||
float get_zoneout_cell() const;
|
||||
void set_zoneout_hidden(float zoneout_hidden);
|
||||
float get_zoneout_hidden() const;
|
||||
int64_t get_good_ld(const int64_t dim, const int64_t type_size);
|
||||
};
|
||||
AbstractBasePtr LstmGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimLstmGradPtr = std::shared_ptr<LSTMGrad>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_LSTM_GRAD_H_
|
|
@ -219,6 +219,7 @@ union PrimitiveType {
|
|||
TensorArrayWrite,
|
||||
Affine,
|
||||
Attention,
|
||||
LSTMGrad,
|
||||
}
|
||||
|
||||
table Abs {
|
||||
|
@ -665,6 +666,18 @@ table LSTM {
|
|||
zoneout_hidden: float = 0;
|
||||
}
|
||||
|
||||
table LSTMGrad {
|
||||
bidirectional: bool;
|
||||
has_bias: bool;
|
||||
input_size: long;
|
||||
hidden_size: long;
|
||||
num_layers: long;
|
||||
num_directions: long;
|
||||
dropout: float;
|
||||
zoneout_cell: float = 0;
|
||||
zoneout_hidden: float = 0;
|
||||
}
|
||||
|
||||
table L2NormalizeFusion {
|
||||
axis: [long];
|
||||
epsilon: float;
|
||||
|
|
|
@ -219,6 +219,7 @@ OP_TYPE(TensorArrayWrite)
|
|||
// kaldi affine op
|
||||
OP_TYPE(Affine)
|
||||
OP_TYPE(Attention)
|
||||
OP_TYPE(LSTMGrad)
|
||||
OP_TYPE_DEF_END(PrimitiveType)
|
||||
|
||||
OP_SCHEMA_DEF(Abs)
|
||||
|
@ -665,6 +666,18 @@ OP_ATTR_WITH_VALUE(zoneout_cell, float, 0)
|
|||
OP_ATTR_WITH_VALUE(zoneout_hidden, float, 0)
|
||||
OP_SCHEMA_DEF_END(LSTM)
|
||||
|
||||
OP_SCHEMA_DEF(LSTMGrad)
|
||||
OP_ATTR(bidirectional, bool)
|
||||
OP_ATTR(has_bias, bool)
|
||||
OP_ATTR(input_size, long)
|
||||
OP_ATTR(hidden_size, long)
|
||||
OP_ATTR(num_layers, long)
|
||||
OP_ATTR(num_directions, long)
|
||||
OP_ATTR(dropout, float)
|
||||
OP_ATTR_WITH_VALUE(zoneout_cell, float, 0)
|
||||
OP_ATTR_WITH_VALUE(zoneout_hidden, float, 0)
|
||||
OP_SCHEMA_DEF_END(LSTMGrad)
|
||||
|
||||
OP_SCHEMA_DEF(L2NormalizeFusion)
|
||||
OP_ATTR(axis, [long])
|
||||
OP_ATTR(epsilon, float)
|
||||
|
|
|
@ -183,6 +183,7 @@
|
|||
#include "ops/grad/group_conv2d_grad_input.h"
|
||||
#include "ops/grad/layer_norm_grad.h"
|
||||
#include "ops/grad/log_grad.h"
|
||||
#include "ops/grad/lstm_grad.h"
|
||||
#include "ops/grad/max_pool_grad.h"
|
||||
#include "ops/grad/maximum_grad.h"
|
||||
#include "ops/grad/minimum_grad.h"
|
||||
|
@ -344,6 +345,7 @@ FUNC_MSOP2SCHEMAOP_DECLARE(LpNormalization)
|
|||
FUNC_MSOP2SCHEMAOP_DECLARE(LRN)
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(LshProjection)
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(LSTM)
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(LSTMGrad)
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(L2NormalizeFusion)
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(MatMul)
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(Maximum)
|
||||
|
|
|
@ -393,6 +393,10 @@ std::unique_ptr<schema::PrimitiveT> LSTMPrimitiveCreator(const AnfNodePtr &node)
|
|||
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::LSTM>>(node);
|
||||
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
|
||||
}
|
||||
std::unique_ptr<schema::PrimitiveT> LSTMGradPrimitiveCreator(const AnfNodePtr &node) {
|
||||
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::LSTMGrad>>(node);
|
||||
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
|
||||
}
|
||||
std::unique_ptr<schema::PrimitiveT> L2NormalizeFusionPrimitiveCreator(const AnfNodePtr &node) {
|
||||
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::L2NormalizeFusion>>(node);
|
||||
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
|
||||
|
@ -913,6 +917,7 @@ RegistryMSOps g_lpNormalizationPrimitiveCreatorRegistry("LpNormalization", LpNor
|
|||
RegistryMSOps g_lrnPrimitiveCreatorRegistry("LRN", LrnPrimitiveCreator);
|
||||
RegistryMSOps g_lshProjectionPrimitiveCreatorRegistry("LshProjection", LshProjectionPrimitiveCreator);
|
||||
RegistryMSOps g_lSTMPrimitiveCreatorRegistry("LSTM", LSTMPrimitiveCreator);
|
||||
RegistryMSOps g_lSTMGradPrimitiveCreatorRegistry("LSTMGrad", LSTMGradPrimitiveCreator);
|
||||
RegistryMSOps g_l2NormalizeFusionPrimitiveCreatorRegistry("L2NormalizeFusion", L2NormalizeFusionPrimitiveCreator);
|
||||
RegistryMSOps g_matMulPrimitiveCreatorRegistry("MatMul", MatMulPrimitiveCreator);
|
||||
RegistryMSOps g_maximumPrimitiveCreatorRegistry("Maximum", MaximumPrimitiveCreator);
|
||||
|
|
|
@ -30,7 +30,6 @@ using mindspore::lite::RET_OK;
|
|||
using mindspore::schema::PrimitiveType_DropoutGrad;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
int DropoutGradCPUKernelFp16::Init() {
|
||||
auto param = reinterpret_cast<DropoutParameter *>(op_parameter_);
|
||||
if (param == nullptr) {
|
||||
|
@ -63,7 +62,6 @@ int DropoutGradCPUKernelFp16::Execute(int task_id) {
|
|||
auto length = in_tensors_.at(kInputIndex)->ElementsNum();
|
||||
int stride = UP_DIV(length, thread_count_);
|
||||
int count = MSMIN(stride, length - stride * task_id);
|
||||
|
||||
if (count > 0) {
|
||||
int start = stride * task_id;
|
||||
DropoutFp16Grad(&(yt_ptr[start]), &(mask_ptr[start]), &(output_ptr[start]), count, (float16_t)scale_);
|
||||
|
|
|
@ -67,7 +67,6 @@ int PoolingGradCPUKernelFp16::Execute(int task_id) {
|
|||
auto output_ptr = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
|
||||
int stride = UP_DIV(pool_param->output_batch_, thread_num_);
|
||||
int count = MSMIN(stride, pool_param->output_batch_ - stride * task_id);
|
||||
|
||||
if (count > 0) {
|
||||
int in_batch_size = pool_param->input_h_ * pool_param->input_w_ * pool_param->input_channel_;
|
||||
int out_batch_size = pool_param->output_h_ * pool_param->output_w_ * pool_param->input_channel_;
|
||||
|
|
|
@ -28,7 +28,6 @@ using mindspore::lite::RET_OK;
|
|||
using mindspore::schema::PrimitiveType_ResizeGrad;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
float16_t ScalingFp16(size_t in_size, size_t out_size, bool align_corners) {
|
||||
return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float16_t>(out_size - 1)
|
||||
: in_size / static_cast<float16_t>(out_size);
|
||||
|
|
|
@ -30,7 +30,6 @@ using mindspore::lite::RET_OK;
|
|||
using mindspore::schema::PrimitiveType_StridedSliceGrad;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
int StridedSliceGradCPUKernelFp16::Init() {
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
|
|
|
@ -29,7 +29,6 @@ using mindspore::lite::RET_OK;
|
|||
using mindspore::schema::PrimitiveType_UnsortedSegmentSum;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
int UnsortedSegmentSumCPUKernelFp16::Init() {
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
|
|
|
@ -30,6 +30,8 @@ using mindspore::lite::RET_OK;
|
|||
using mindspore::schema::PrimitiveType_LSTM;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
constexpr static int kWorkspaceOutIdx = 4;
|
||||
|
||||
void LstmCPUKernel::FreeTmpBuffer() {
|
||||
if (weight_i_ptr_ != nullptr) {
|
||||
free(weight_i_ptr_);
|
||||
|
@ -361,10 +363,18 @@ int LstmCPUKernel::LstmUnidirectional(float *output, const float *weight_i, cons
|
|||
float *output_ptr = output + real_t * lstm_param_->output_step_;
|
||||
LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias,
|
||||
hidden_state, cell_state, buffer_, lstm_param_);
|
||||
if (IsTrain() && IsTrainable()) {
|
||||
RecordStates(cell_state, real_t);
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void LstmCPUKernel::RecordStates(float *cell_state, int step) {
|
||||
float *workspace = reinterpret_cast<float *>(out_tensors_[kWorkspaceOutIdx]->MutableData());
|
||||
workspace[step] = *cell_state;
|
||||
}
|
||||
|
||||
int LstmCPUKernel::InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state) {
|
||||
// forward
|
||||
// buffer_[packed_input_index] : store packed input
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_FP32_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_FP32_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/inner_kernel.h"
|
||||
|
@ -49,6 +49,7 @@ class LstmCPUKernel : public InnerKernel {
|
|||
int LstmUnidirectional(float *output, const float *weight_i, const float *weight_h, const float *input_bias,
|
||||
const float *state_bias, float *hidden_state, float *cell_state, bool is_backward);
|
||||
int InnerExecute(float *output, const float *input, float *hidden_state, float *cell_state);
|
||||
void RecordStates(float *cell_state, int step);
|
||||
const float *weight_loop_;
|
||||
const float *bias_loop_;
|
||||
float *gate_loop_;
|
||||
|
@ -84,4 +85,4 @@ class LstmCPUKernel : public InnerKernel {
|
|||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_FP32_H_
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include "nnacl/fp32_grad/batch_norm.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH;
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
|
|
|
@ -18,17 +18,21 @@
|
|||
#include "src/ops/populate/populate_register.h"
|
||||
#include "src/ops/populate/default_populate.h"
|
||||
#include "src/ops/populate/strided_slice_populate.h"
|
||||
#include "nnacl/pooling_parameter.h"
|
||||
#include "nnacl/fp32_grad/softmax_grad.h"
|
||||
#include "nnacl/fp32/activation_fp32.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/power_parameter.h"
|
||||
#include "nnacl/arithmetic.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/lstm_parameter.h"
|
||||
#include "nnacl/pooling_parameter.h"
|
||||
#include "nnacl/power_parameter.h"
|
||||
#include "nnacl/fp32/activation_fp32.h"
|
||||
#include "nnacl/fp32_grad/softmax_grad.h"
|
||||
#include "nnacl/fp32_grad/optimizer.h"
|
||||
#include "nnacl/fp32_grad/batch_norm.h"
|
||||
#include "nnacl/fp32_grad/dropout_parameter.h"
|
||||
#include "nnacl/fp32_grad/smooth_l1_loss.h"
|
||||
#include "nnacl/fp32_grad/resize_grad.h"
|
||||
|
||||
using mindspore::lite::Registry;
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
|
@ -471,6 +475,29 @@ OpParameter *PopulateStridedSliceGradParameter(const void *prim) {
|
|||
return reinterpret_cast<OpParameter *>(strided_slice_param);
|
||||
}
|
||||
|
||||
OpParameter *PopulateLstmGradParameter(const void *prim) {
|
||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
auto value = primitive->value_as_LSTMGrad();
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "value is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto *param = reinterpret_cast<LstmParameter *>(malloc(sizeof(LstmParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc LstmParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(param, 0, sizeof(LstmParameter));
|
||||
|
||||
param->op_parameter_.type_ = primitive->value_type();
|
||||
param->bidirectional_ = value->bidirectional();
|
||||
param->zoneout_cell_ = value->zoneout_cell();
|
||||
param->zoneout_hidden_ = value->zoneout_hidden();
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
|
||||
void PopulateTrainParameters() {
|
||||
lite::Registry ApplyMomentumParameterRegistry(schema::PrimitiveType_ApplyMomentum, PopulateApplyMomentumParameter,
|
||||
lite::SCHEMA_CUR);
|
||||
|
@ -534,10 +561,9 @@ void PopulateTrainParameters() {
|
|||
lite::SCHEMA_CUR);
|
||||
lite::Registry RsqrtGradParameterRegistry(schema::PrimitiveType_RsqrtGrad, lite::DefaultPopulateParameter,
|
||||
lite::SCHEMA_CUR);
|
||||
lite::Registry ResizeGradParameterRegistry(schema::PrimitiveType_ResizeGrad, PopulateResizeGradParameter,
|
||||
lite::SCHEMA_CUR);
|
||||
lite::Registry AbsGradParameterRegistry(schema::PrimitiveType_AbsGrad, lite::DefaultPopulateParameter,
|
||||
lite::SCHEMA_CUR);
|
||||
Registry ResizeGradParameterRegistry(schema::PrimitiveType_ResizeGrad, PopulateResizeGradParameter, lite::SCHEMA_CUR);
|
||||
Registry AbsGradParameterRegistry(schema::PrimitiveType_AbsGrad, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
|
||||
Registry LSTMGradParameterRegistry(schema::PrimitiveType_LSTMGrad, PopulateLstmGradParameter, lite::SCHEMA_CUR);
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -47,6 +47,7 @@ constexpr int kFieldsToPrint = 5;
|
|||
constexpr int kPrintOffset = 4;
|
||||
constexpr int kCPUBindFlag2 = 2;
|
||||
constexpr int kCPUBindFlag1 = 1;
|
||||
static const int kTHOUSAND = 1000;
|
||||
|
||||
namespace {
|
||||
float *ReadFileBuf(const char *file, size_t *size) {
|
||||
|
@ -436,8 +437,8 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string
|
|||
}
|
||||
|
||||
auto end_prepare_time = GetTimeUs();
|
||||
MS_LOG(INFO) << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / 1000) << " ms";
|
||||
std::cout << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / 1000) << " ms" << std::endl;
|
||||
MS_LOG(INFO) << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms";
|
||||
std::cout << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms" << std::endl;
|
||||
// Load input
|
||||
MS_LOG(INFO) << "Load input data";
|
||||
auto ms_inputs = session->GetInputs();
|
||||
|
|
Loading…
Reference in New Issue