!3190 Add op BasicLSTMCell for GE.

Merge pull request !3190 from liuxiao93/BasicLSTMCell
This commit is contained in:
mindspore-ci-bot 2020-07-18 19:28:11 +08:00 committed by Gitee
commit b34d40a1df
3 changed files with 44 additions and 0 deletions

View File

@ -201,6 +201,10 @@ const char kNameBatchToSpace[] = "BatchToSpace";
const char kNameAtan2[] = "Atan2";
const char kNameApplyRMSProp[] = "ApplyRMSProp";
const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp";
const char kNameBasicLSTMCell[] = "BasicLSTMCell";
const char kNameBasicLSTMCellInputGrad[] = "BasicLSTMCellInputGrad";
const char kNameBasicLSTMCellWeightGrad[] = "BasicLSTMCellWeightGrad";
const char kNameBasicLSTMCellCStateGrad[] = "BasicLSTMCellCStateGrad";
const char kNameL2Loss[] = "L2Loss";
const char kNameCTCLoss[] = "CTCLoss";
const char kNameRange[] = "Range";
@ -410,6 +414,10 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameAtan2), ADPT_DESC(Atan2)},
{string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)},
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSPropD)},
{string(kNameBasicLSTMCell), ADPT_DESC(BasicLSTMCell)},
{string(kNameBasicLSTMCellInputGrad), ADPT_DESC(BasicLSTMCellInputGrad)},
{string(kNameBasicLSTMCellWeightGrad), ADPT_DESC(BasicLSTMCellWeightGrad)},
{string(kNameBasicLSTMCellCStateGrad), ADPT_DESC(BasicLSTMCellCStateGrad)},
{string(kNameL2Loss), ADPT_DESC(L2Loss)},
{string(kNameCTCLoss), ADPT_DESC(CTCLoss)},
{string(kNameRange), ADPT_DESC(RangeD)},

View File

@ -1292,6 +1292,34 @@ ATTR_MAP(ApplyCenteredRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTra
OUTPUT_MAP(ApplyCenteredRMSPropD) = {
{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(mg)}, {2, OUTPUT_DESC(ms)}, {3, OUTPUT_DESC(mom)}};
// BasicLSTMCell
INPUT_MAP(BasicLSTMCell) = {
{1, INPUT_DESC(x)}, {2, INPUT_DESC(h)}, {3, INPUT_DESC(c)}, {4, INPUT_DESC(w)}, {5, INPUT_DESC(b)}};
ATTR_MAP(BasicLSTMCell) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())},
{"forget_bias", ATTR_DESC(forget_bias, AnyTraits<float>())},
{"state_is_tuple", ATTR_DESC(state_is_tuple, AnyTraits<bool>())},
{"activation", ATTR_DESC(activation, AnyTraits<std::string>())}};
OUTPUT_MAP(BasicLSTMCell) = {{0, OUTPUT_DESC(ct)}, {1, OUTPUT_DESC(ht)}, {2, OUTPUT_DESC(it)}, {3, OUTPUT_DESC(jt)},
{4, OUTPUT_DESC(ft)}, {5, OUTPUT_DESC(ot)}, {7, OUTPUT_DESC(tanhct)}};
// BasicLSTMCellInputGrad
INPUT_MAP(BasicLSTMCellInputGrad) = {{1, INPUT_DESC(dgate)}, {2, INPUT_DESC(w)}};
ATTR_MAP(BasicLSTMCellInputGrad) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())}};
OUTPUT_MAP(BasicLSTMCellInputGrad) = {{0, OUTPUT_DESC(dxt)}, {1, OUTPUT_DESC(dht)}};
// BasicLSTMCellWeightGrad
INPUT_MAP(BasicLSTMCellWeightGrad) = {{1, INPUT_DESC(h)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(dgate)}};
ATTR_MAP(BasicLSTMCellWeightGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(BasicLSTMCellWeightGrad) = {{0, OUTPUT_DESC(dw)}, {1, OUTPUT_DESC(db)}};
// BasicLSTMCellCStateGrad
INPUT_MAP(BasicLSTMCellCStateGrad) = {{1, INPUT_DESC(c)}, {2, INPUT_DESC(dht)}, {3, INPUT_DESC(dct)},
{4, INPUT_DESC(it)}, {5, INPUT_DESC(jt)}, {6, INPUT_DESC(ft)},
{7, INPUT_DESC(ot)}, {8, INPUT_DESC(tanhct)}};
ATTR_MAP(BasicLSTMCellCStateGrad) = {{"forget_bias", ATTR_DESC(forget_bias, AnyTraits<float>())},
{"activation", ATTR_DESC(activation, AnyTraits<std::string>())}};
OUTPUT_MAP(BasicLSTMCellCStateGrad) = {{0, OUTPUT_DESC(dgate)}, {1, OUTPUT_DESC(dct_1)}};
// L2Loss
INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}};
ATTR_MAP(L2Loss) = EMPTY_ATTR_MAP;

View File

@ -488,6 +488,14 @@ DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD)
DECLARE_OP_USE_OUTPUT(ApplyRMSPropD)
DECLARE_OP_ADAPTER(ApplyCenteredRMSPropD)
DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSPropD)
DECLARE_OP_ADAPTER(BasicLSTMCell)
DECLARE_OP_USE_OUTPUT(BasicLSTMCell)
DECLARE_OP_ADAPTER(BasicLSTMCellInputGrad)
DECLARE_OP_USE_OUTPUT(BasicLSTMCellInputGrad)
DECLARE_OP_ADAPTER(BasicLSTMCellWeightGrad)
DECLARE_OP_USE_OUTPUT(BasicLSTMCellWeightGrad)
DECLARE_OP_ADAPTER(BasicLSTMCellCStateGrad)
DECLARE_OP_USE_OUTPUT(BasicLSTMCellCStateGrad)
DECLARE_OP_ADAPTER(L2Loss)
DECLARE_OP_USE_OUTPUT(L2Loss)
DECLARE_OP_ADAPTER(CTCLoss)