Add DynamicRNN for old backend.

This commit is contained in:
liuxiao93 2020-10-19 17:19:49 +08:00
parent b6715eb790
commit 54c96fe13b
4 changed files with 54 additions and 2 deletions

View File

@@ -183,6 +183,8 @@ constexpr const char kNameBasicLSTMCell[] = "BasicLSTMCell";
constexpr const char kNameBasicLSTMCellInputGrad[] = "BasicLSTMCellInputGrad";
constexpr const char kNameBasicLSTMCellWeightGrad[] = "BasicLSTMCellWeightGrad";
constexpr const char kNameBasicLSTMCellCStateGrad[] = "BasicLSTMCellCStateGrad";
// Adapter-registration keys for the DynamicRNN forward and backward (grad) ops;
// matched by REG_ADPT_DESC in the rnn op_declare translation unit.
constexpr const char kNameDynamicRNN[] = "DynamicRNN";
constexpr const char kNameDynamicRNNGrad[] = "DynamicRNNGrad";
constexpr const char kNameL2Loss[] = "L2Loss";
constexpr const char kNameCTCLoss[] = "CTCLoss";
constexpr const char kNameRange[] = "Range";

View File

@@ -48,4 +48,48 @@ ATTR_MAP(BasicLSTMCellCStateGrad) = {{"forget_bias", ATTR_DESC(forget_bias, AnyT
{"activation", ATTR_DESC(activation, AnyTraits<std::string>())}};
OUTPUT_MAP(BasicLSTMCellCStateGrad) = {{0, OUTPUT_DESC(dgate)}, {1, OUTPUT_DESC(dct_1)}};
REG_ADPT_DESC(BasicLSTMCellCStateGrad, kNameBasicLSTMCellCStateGrad, ADPT_DESC(BasicLSTMCellCStateGrad))
// DynamicRNN: adapter mapping the MindSpore DynamicRNN primitive onto the GE op.
// Inputs are 1-indexed (adapter convention); outputs are 0-indexed.
INPUT_MAP(DynamicRNN) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(w)}, {3, INPUT_DESC(b)},
                         {4, INPUT_DESC(seq_length)}, {5, INPUT_DESC(init_h)}, {6, INPUT_DESC(init_c)},
                         {7, INPUT_DESC(wci)}, {8, INPUT_DESC(wcf)}, {9, INPUT_DESC(wco)},
                         {10, INPUT_DESC(mask)}};
ATTR_MAP(DynamicRNN) = {{"cell_type", ATTR_DESC(cell_type, AnyTraits<std::string>())},
                        {"direction", ATTR_DESC(direction, AnyTraits<std::string>())},
                        {"cell_depth", ATTR_DESC(cell_depth, AnyTraits<int64_t>())},
                        {"use_peephole", ATTR_DESC(use_peephole, AnyTraits<bool>())},
                        {"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())},
                        {"cell_clip", ATTR_DESC(cell_clip, AnyTraits<float>())},
                        {"num_proj", ATTR_DESC(num_proj, AnyTraits<int64_t>())},
                        {"time_major", ATTR_DESC(time_major, AnyTraits<bool>())},
                        // Fix: key must be the primitive attr name "activation" — the previous
                        // key "ivation" could never match, so the activation attribute was
                        // silently dropped when converting the primitive to the GE op.
                        {"activation", ATTR_DESC(activation, AnyTraits<std::string>())},
                        {"forget_bias", ATTR_DESC(forget_bias, AnyTraits<float>())},
                        {"is_training", ATTR_DESC(is_training, AnyTraits<bool>())}};
OUTPUT_MAP(DynamicRNN) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(output_h)}, {2, OUTPUT_DESC(output_c)},
                          {3, OUTPUT_DESC(i)}, {4, OUTPUT_DESC(j)}, {5, OUTPUT_DESC(f)},
                          {6, OUTPUT_DESC(o)}, {7, OUTPUT_DESC(tanhc)}};
REG_ADPT_DESC(DynamicRNN, kNameDynamicRNN, ADPT_DESC(DynamicRNN))
// DynamicRNNGrad: adapter mapping the MindSpore DynamicRNNGrad primitive onto
// the GE backward op. Inputs are 1-indexed (adapter convention); outputs are
// 0-indexed.
// NOTE(review): unlike the forward DynamicRNN map above, no "activation" attr
// is mapped here — confirm against the GE DynamicRNNGrad op spec that the
// backward op does not take it.
INPUT_MAP(DynamicRNNGrad) = {
{1, INPUT_DESC(x)}, {2, INPUT_DESC(w)}, {3, INPUT_DESC(b)}, {4, INPUT_DESC(y)},
{5, INPUT_DESC(init_h)}, {6, INPUT_DESC(init_c)}, {7, INPUT_DESC(h)}, {8, INPUT_DESC(c)},
{9, INPUT_DESC(dy)}, {10, INPUT_DESC(dh)}, {11, INPUT_DESC(dc)}, {12, INPUT_DESC(i)},
{13, INPUT_DESC(j)}, {14, INPUT_DESC(f)}, {15, INPUT_DESC(o)}, {16, INPUT_DESC(tanhct)}};
ATTR_MAP(DynamicRNNGrad) = {{"cell_type", ATTR_DESC(cell_type, AnyTraits<std::string>())},
{"direction", ATTR_DESC(direction, AnyTraits<std::string>())},
{"cell_depth", ATTR_DESC(cell_depth, AnyTraits<int64_t>())},
{"use_peephole", ATTR_DESC(use_peephole, AnyTraits<bool>())},
{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())},
{"cell_clip", ATTR_DESC(cell_clip, AnyTraits<float>())},
{"num_proj", ATTR_DESC(num_proj, AnyTraits<int64_t>())},
{"time_major", ATTR_DESC(time_major, AnyTraits<bool>())},
{"forget_bias", ATTR_DESC(forget_bias, AnyTraits<float>())}};
// Gradients w.r.t. weights, bias, input, and the initial hidden/cell states.
OUTPUT_MAP(DynamicRNNGrad) = {{0, OUTPUT_DESC(dw)},
{1, OUTPUT_DESC(db)},
{2, OUTPUT_DESC(dx)},
{3, OUTPUT_DESC(dh_prev)},
{4, OUTPUT_DESC(dc_prev)}};
REG_ADPT_DESC(DynamicRNNGrad, kNameDynamicRNNGrad, ADPT_DESC(DynamicRNNGrad))
} // namespace mindspore::transform

View File

@@ -19,8 +19,8 @@
#include <string>
#include <unordered_map>
#include "transform/graph_ir/op_declare/op_declare_macro.h"
#include "ops/rnn.h"
#include "transform/graph_ir/op_declare/op_declare_macro.h"
namespace mindspore::transform {
DECLARE_OP_ADAPTER(BasicLSTMCell)
@@ -34,5 +34,11 @@ DECLARE_OP_USE_OUTPUT(BasicLSTMCellWeightGrad)
DECLARE_OP_ADAPTER(BasicLSTMCellCStateGrad)
DECLARE_OP_USE_OUTPUT(BasicLSTMCellCStateGrad)
// Adapter declarations for the DynamicRNN forward and backward (grad) ops;
// the corresponding INPUT/ATTR/OUTPUT maps live in the matching .cc file.
DECLARE_OP_ADAPTER(DynamicRNN)
DECLARE_OP_USE_OUTPUT(DynamicRNN)
DECLARE_OP_ADAPTER(DynamicRNNGrad)
DECLARE_OP_USE_OUTPUT(DynamicRNNGrad)
} // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_RNN_DECLARE_H_

View File

@@ -1062,7 +1062,7 @@ class DynamicRNNGrad(PrimitiveWithInfer):
keep_prob=-1.0,
cell_clip=-1.0,
num_proj=0,
time_major=False,
time_major=True,
forget_bias=0.0):
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
self.add_prim_attr("io_format", "ND")