diff --git a/mindspore/ccsrc/transform/graph_ir/convert.cc b/mindspore/ccsrc/transform/graph_ir/convert.cc index 3cb9697465c..e05f3e01a03 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.cc +++ b/mindspore/ccsrc/transform/graph_ir/convert.cc @@ -201,6 +201,10 @@ const char kNameBatchToSpace[] = "BatchToSpace"; const char kNameAtan2[] = "Atan2"; const char kNameApplyRMSProp[] = "ApplyRMSProp"; const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp"; +const char kNameBasicLSTMCell[] = "BasicLSTMCell"; +const char kNameBasicLSTMCellInputGrad[] = "BasicLSTMCellInputGrad"; +const char kNameBasicLSTMCellWeightGrad[] = "BasicLSTMCellWeightGrad"; +const char kNameBasicLSTMCellCStateGrad[] = "BasicLSTMCellCStateGrad"; const char kNameL2Loss[] = "L2Loss"; const char kNameCTCLoss[] = "CTCLoss"; const char kNameRange[] = "Range"; @@ -410,6 +414,10 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {string(kNameAtan2), ADPT_DESC(Atan2)}, {string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)}, {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSPropD)}, + {string(kNameBasicLSTMCell), ADPT_DESC(BasicLSTMCell)}, + {string(kNameBasicLSTMCellInputGrad), ADPT_DESC(BasicLSTMCellInputGrad)}, + {string(kNameBasicLSTMCellWeightGrad), ADPT_DESC(BasicLSTMCellWeightGrad)}, + {string(kNameBasicLSTMCellCStateGrad), ADPT_DESC(BasicLSTMCellCStateGrad)}, {string(kNameL2Loss), ADPT_DESC(L2Loss)}, {string(kNameCTCLoss), ADPT_DESC(CTCLoss)}, {string(kNameRange), ADPT_DESC(RangeD)}, diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare.cc index 372051926c4..7efc6061586 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare.cc +++ b/mindspore/ccsrc/transform/graph_ir/op_declare.cc @@ -1292,6 +1292,34 @@ ATTR_MAP(ApplyCenteredRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTra OUTPUT_MAP(ApplyCenteredRMSPropD) = { {0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(mg)}, {2, OUTPUT_DESC(ms)}, {3, OUTPUT_DESC(mom)}}; +// BasicLSTMCell +INPUT_MAP(BasicLSTMCell) = { + {1, INPUT_DESC(x)}, {2, INPUT_DESC(h)}, {3, INPUT_DESC(c)}, {4, INPUT_DESC(w)}, {5, INPUT_DESC(b)}}; +ATTR_MAP(BasicLSTMCell) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits())}, + {"forget_bias", ATTR_DESC(forget_bias, AnyTraits())}, + {"state_is_tuple", ATTR_DESC(state_is_tuple, AnyTraits())}, + {"activation", ATTR_DESC(activation, AnyTraits())}}; +OUTPUT_MAP(BasicLSTMCell) = {{0, OUTPUT_DESC(ct)}, {1, OUTPUT_DESC(ht)}, {2, OUTPUT_DESC(it)}, {3, OUTPUT_DESC(jt)}, + {4, OUTPUT_DESC(ft)}, {5, OUTPUT_DESC(ot)}, {7, OUTPUT_DESC(tanhct)}}; + +// BasicLSTMCellInputGrad +INPUT_MAP(BasicLSTMCellInputGrad) = {{1, INPUT_DESC(dgate)}, {2, INPUT_DESC(w)}}; +ATTR_MAP(BasicLSTMCellInputGrad) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits())}}; +OUTPUT_MAP(BasicLSTMCellInputGrad) = {{0, OUTPUT_DESC(dxt)}, {1, OUTPUT_DESC(dht)}}; + +// BasicLSTMCellWeightGrad +INPUT_MAP(BasicLSTMCellWeightGrad) = {{1, INPUT_DESC(h)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(dgate)}}; +ATTR_MAP(BasicLSTMCellWeightGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(BasicLSTMCellWeightGrad) = {{0, OUTPUT_DESC(dw)}, {1, OUTPUT_DESC(db)}}; + +// BasicLSTMCellCStateGrad +INPUT_MAP(BasicLSTMCellCStateGrad) = {{1, INPUT_DESC(c)}, {2, INPUT_DESC(dht)}, {3, INPUT_DESC(dct)}, + {4, INPUT_DESC(it)}, {5, INPUT_DESC(jt)}, {6, INPUT_DESC(ft)}, + {7, INPUT_DESC(ot)}, {8, INPUT_DESC(tanhct)}}; +ATTR_MAP(BasicLSTMCellCStateGrad) = {{"forget_bias", ATTR_DESC(forget_bias, AnyTraits())}, + {"activation", ATTR_DESC(activation, AnyTraits())}}; +OUTPUT_MAP(BasicLSTMCellCStateGrad) = {{0, OUTPUT_DESC(dgate)}, {1, OUTPUT_DESC(dct_1)}}; + // L2Loss INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}}; ATTR_MAP(L2Loss) = EMPTY_ATTR_MAP; diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare.h index 93462c4071a..c34a40df9a6 100755 --- a/mindspore/ccsrc/transform/graph_ir/op_declare.h +++ b/mindspore/ccsrc/transform/graph_ir/op_declare.h @@ -488,6 +488,14 @@ DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD) DECLARE_OP_USE_OUTPUT(ApplyRMSPropD) DECLARE_OP_ADAPTER(ApplyCenteredRMSPropD) DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSPropD) +DECLARE_OP_ADAPTER(BasicLSTMCell) +DECLARE_OP_USE_OUTPUT(BasicLSTMCell) +DECLARE_OP_ADAPTER(BasicLSTMCellInputGrad) +DECLARE_OP_USE_OUTPUT(BasicLSTMCellInputGrad) +DECLARE_OP_ADAPTER(BasicLSTMCellWeightGrad) +DECLARE_OP_USE_OUTPUT(BasicLSTMCellWeightGrad) +DECLARE_OP_ADAPTER(BasicLSTMCellCStateGrad) +DECLARE_OP_USE_OUTPUT(BasicLSTMCellCStateGrad) DECLARE_OP_ADAPTER(L2Loss) DECLARE_OP_USE_OUTPUT(L2Loss) DECLARE_OP_ADAPTER(CTCLoss)