forked from mindspore-Ecosystem/mindspore
Add DynamicRNN for old backend.
This commit is contained in:
parent
b6715eb790
commit
54c96fe13b
|
@ -183,6 +183,8 @@ constexpr const char kNameBasicLSTMCell[] = "BasicLSTMCell";
|
|||
constexpr const char kNameBasicLSTMCellInputGrad[] = "BasicLSTMCellInputGrad";
|
||||
constexpr const char kNameBasicLSTMCellWeightGrad[] = "BasicLSTMCellWeightGrad";
|
||||
constexpr const char kNameBasicLSTMCellCStateGrad[] = "BasicLSTMCellCStateGrad";
|
||||
constexpr const char kNameDynamicRNN[] = "DynamicRNN";
|
||||
constexpr const char kNameDynamicRNNGrad[] = "DynamicRNNGrad";
|
||||
constexpr const char kNameL2Loss[] = "L2Loss";
|
||||
constexpr const char kNameCTCLoss[] = "CTCLoss";
|
||||
constexpr const char kNameRange[] = "Range";
|
||||
|
|
|
@ -48,4 +48,48 @@ ATTR_MAP(BasicLSTMCellCStateGrad) = {{"forget_bias", ATTR_DESC(forget_bias, AnyT
|
|||
{"activation", ATTR_DESC(activation, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(BasicLSTMCellCStateGrad) = {{0, OUTPUT_DESC(dgate)}, {1, OUTPUT_DESC(dct_1)}};
|
||||
REG_ADPT_DESC(BasicLSTMCellCStateGrad, kNameBasicLSTMCellCStateGrad, ADPT_DESC(BasicLSTMCellCStateGrad))
|
||||
|
||||
// DynamicRNN
|
||||
INPUT_MAP(DynamicRNN) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(w)}, {3, INPUT_DESC(b)},
|
||||
{4, INPUT_DESC(seq_length)}, {5, INPUT_DESC(init_h)}, {6, INPUT_DESC(init_c)},
|
||||
{7, INPUT_DESC(wci)}, {8, INPUT_DESC(wcf)}, {9, INPUT_DESC(wco)},
|
||||
{10, INPUT_DESC(mask)}};
|
||||
ATTR_MAP(DynamicRNN) = {{"cell_type", ATTR_DESC(cell_type, AnyTraits<std::string>())},
|
||||
{"direction", ATTR_DESC(direction, AnyTraits<std::string>())},
|
||||
{"cell_depth", ATTR_DESC(cell_depth, AnyTraits<int64_t>())},
|
||||
{"use_peephole", ATTR_DESC(use_peephole, AnyTraits<bool>())},
|
||||
{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())},
|
||||
{"cell_clip", ATTR_DESC(cell_clip, AnyTraits<float>())},
|
||||
{"num_proj", ATTR_DESC(num_proj, AnyTraits<int64_t>())},
|
||||
{"time_major", ATTR_DESC(time_major, AnyTraits<bool>())},
|
||||
{"ivation", ATTR_DESC(activation, AnyTraits<std::string>())},
|
||||
{"forget_bias", ATTR_DESC(forget_bias, AnyTraits<float>())},
|
||||
{"is_training", ATTR_DESC(is_training, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(DynamicRNN) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(output_h)}, {2, OUTPUT_DESC(output_c)},
|
||||
{3, OUTPUT_DESC(i)}, {4, OUTPUT_DESC(j)}, {5, OUTPUT_DESC(f)},
|
||||
{6, OUTPUT_DESC(o)}, {7, OUTPUT_DESC(tanhc)}};
|
||||
REG_ADPT_DESC(DynamicRNN, kNameDynamicRNN, ADPT_DESC(DynamicRNN))
|
||||
|
||||
// DynamicRNNGrad
|
||||
INPUT_MAP(DynamicRNNGrad) = {
|
||||
{1, INPUT_DESC(x)}, {2, INPUT_DESC(w)}, {3, INPUT_DESC(b)}, {4, INPUT_DESC(y)},
|
||||
{5, INPUT_DESC(init_h)}, {6, INPUT_DESC(init_c)}, {7, INPUT_DESC(h)}, {8, INPUT_DESC(c)},
|
||||
{9, INPUT_DESC(dy)}, {10, INPUT_DESC(dh)}, {11, INPUT_DESC(dc)}, {12, INPUT_DESC(i)},
|
||||
{13, INPUT_DESC(j)}, {14, INPUT_DESC(f)}, {15, INPUT_DESC(o)}, {16, INPUT_DESC(tanhct)}};
|
||||
|
||||
ATTR_MAP(DynamicRNNGrad) = {{"cell_type", ATTR_DESC(cell_type, AnyTraits<std::string>())},
|
||||
{"direction", ATTR_DESC(direction, AnyTraits<std::string>())},
|
||||
{"cell_depth", ATTR_DESC(cell_depth, AnyTraits<int64_t>())},
|
||||
{"use_peephole", ATTR_DESC(use_peephole, AnyTraits<bool>())},
|
||||
{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())},
|
||||
{"cell_clip", ATTR_DESC(cell_clip, AnyTraits<float>())},
|
||||
{"num_proj", ATTR_DESC(num_proj, AnyTraits<int64_t>())},
|
||||
{"time_major", ATTR_DESC(time_major, AnyTraits<bool>())},
|
||||
{"forget_bias", ATTR_DESC(forget_bias, AnyTraits<float>())}};
|
||||
OUTPUT_MAP(DynamicRNNGrad) = {{0, OUTPUT_DESC(dw)},
|
||||
{1, OUTPUT_DESC(db)},
|
||||
{2, OUTPUT_DESC(dx)},
|
||||
{3, OUTPUT_DESC(dh_prev)},
|
||||
{4, OUTPUT_DESC(dc_prev)}};
|
||||
REG_ADPT_DESC(DynamicRNNGrad, kNameDynamicRNNGrad, ADPT_DESC(DynamicRNNGrad))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -19,8 +19,8 @@
|
|||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "transform/graph_ir/op_declare/op_declare_macro.h"
|
||||
#include "ops/rnn.h"
|
||||
#include "transform/graph_ir/op_declare/op_declare_macro.h"
|
||||
|
||||
namespace mindspore::transform {
|
||||
DECLARE_OP_ADAPTER(BasicLSTMCell)
|
||||
|
@ -34,5 +34,11 @@ DECLARE_OP_USE_OUTPUT(BasicLSTMCellWeightGrad)
|
|||
|
||||
DECLARE_OP_ADAPTER(BasicLSTMCellCStateGrad)
|
||||
DECLARE_OP_USE_OUTPUT(BasicLSTMCellCStateGrad)
|
||||
|
||||
DECLARE_OP_ADAPTER(DynamicRNN)
|
||||
DECLARE_OP_USE_OUTPUT(DynamicRNN)
|
||||
|
||||
DECLARE_OP_ADAPTER(DynamicRNNGrad)
|
||||
DECLARE_OP_USE_OUTPUT(DynamicRNNGrad)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_RNN_DECLARE_H_
|
||||
|
|
|
@ -1062,7 +1062,7 @@ class DynamicRNNGrad(PrimitiveWithInfer):
|
|||
keep_prob=-1.0,
|
||||
cell_clip=-1.0,
|
||||
num_proj=0,
|
||||
time_major=False,
|
||||
time_major=True,
|
||||
forget_bias=0.0):
|
||||
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
|
||||
self.add_prim_attr("io_format", "ND")
|
||||
|
|
Loading…
Reference in New Issue