forked from mindspore-Ecosystem/mindspore
!3190 Add op BasicLSTMCell for GE.
Merge pull request !3190 from liuxiao93/BasicLSTMCell
This commit is contained in:
commit
b34d40a1df
|
@ -201,6 +201,10 @@ const char kNameBatchToSpace[] = "BatchToSpace";
|
|||
const char kNameAtan2[] = "Atan2";
|
||||
const char kNameApplyRMSProp[] = "ApplyRMSProp";
|
||||
const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp";
|
||||
const char kNameBasicLSTMCell[] = "BasicLSTMCell";
|
||||
const char kNameBasicLSTMCellInputGrad[] = "BasicLSTMCellInputGrad";
|
||||
const char kNameBasicLSTMCellWeightGrad[] = "BasicLSTMCellWeightGrad";
|
||||
const char kNameBasicLSTMCellCStateGrad[] = "BasicLSTMCellCStateGrad";
|
||||
const char kNameL2Loss[] = "L2Loss";
|
||||
const char kNameCTCLoss[] = "CTCLoss";
|
||||
const char kNameRange[] = "Range";
|
||||
|
@ -410,6 +414,10 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
|||
{string(kNameAtan2), ADPT_DESC(Atan2)},
|
||||
{string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)},
|
||||
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSPropD)},
|
||||
{string(kNameBasicLSTMCell), ADPT_DESC(BasicLSTMCell)},
|
||||
{string(kNameBasicLSTMCellInputGrad), ADPT_DESC(BasicLSTMCellInputGrad)},
|
||||
{string(kNameBasicLSTMCellWeightGrad), ADPT_DESC(BasicLSTMCellWeightGrad)},
|
||||
{string(kNameBasicLSTMCellCStateGrad), ADPT_DESC(BasicLSTMCellCStateGrad)},
|
||||
{string(kNameL2Loss), ADPT_DESC(L2Loss)},
|
||||
{string(kNameCTCLoss), ADPT_DESC(CTCLoss)},
|
||||
{string(kNameRange), ADPT_DESC(RangeD)},
|
||||
|
|
|
@ -1292,6 +1292,34 @@ ATTR_MAP(ApplyCenteredRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTra
|
|||
OUTPUT_MAP(ApplyCenteredRMSPropD) = {
|
||||
{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(mg)}, {2, OUTPUT_DESC(ms)}, {3, OUTPUT_DESC(mom)}};
|
||||
|
||||
// BasicLSTMCell
|
||||
INPUT_MAP(BasicLSTMCell) = {
|
||||
{1, INPUT_DESC(x)}, {2, INPUT_DESC(h)}, {3, INPUT_DESC(c)}, {4, INPUT_DESC(w)}, {5, INPUT_DESC(b)}};
|
||||
ATTR_MAP(BasicLSTMCell) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())},
|
||||
{"forget_bias", ATTR_DESC(forget_bias, AnyTraits<float>())},
|
||||
{"state_is_tuple", ATTR_DESC(state_is_tuple, AnyTraits<bool>())},
|
||||
{"activation", ATTR_DESC(activation, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(BasicLSTMCell) = {{0, OUTPUT_DESC(ct)}, {1, OUTPUT_DESC(ht)}, {2, OUTPUT_DESC(it)}, {3, OUTPUT_DESC(jt)},
|
||||
{4, OUTPUT_DESC(ft)}, {5, OUTPUT_DESC(ot)}, {7, OUTPUT_DESC(tanhct)}};
|
||||
|
||||
// BasicLSTMCellInputGrad
|
||||
INPUT_MAP(BasicLSTMCellInputGrad) = {{1, INPUT_DESC(dgate)}, {2, INPUT_DESC(w)}};
|
||||
ATTR_MAP(BasicLSTMCellInputGrad) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())}};
|
||||
OUTPUT_MAP(BasicLSTMCellInputGrad) = {{0, OUTPUT_DESC(dxt)}, {1, OUTPUT_DESC(dht)}};
|
||||
|
||||
// BasicLSTMCellWeightGrad
|
||||
INPUT_MAP(BasicLSTMCellWeightGrad) = {{1, INPUT_DESC(h)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(dgate)}};
|
||||
ATTR_MAP(BasicLSTMCellWeightGrad) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(BasicLSTMCellWeightGrad) = {{0, OUTPUT_DESC(dw)}, {1, OUTPUT_DESC(db)}};
|
||||
|
||||
// BasicLSTMCellCStateGrad
|
||||
INPUT_MAP(BasicLSTMCellCStateGrad) = {{1, INPUT_DESC(c)}, {2, INPUT_DESC(dht)}, {3, INPUT_DESC(dct)},
|
||||
{4, INPUT_DESC(it)}, {5, INPUT_DESC(jt)}, {6, INPUT_DESC(ft)},
|
||||
{7, INPUT_DESC(ot)}, {8, INPUT_DESC(tanhct)}};
|
||||
ATTR_MAP(BasicLSTMCellCStateGrad) = {{"forget_bias", ATTR_DESC(forget_bias, AnyTraits<float>())},
|
||||
{"activation", ATTR_DESC(activation, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(BasicLSTMCellCStateGrad) = {{0, OUTPUT_DESC(dgate)}, {1, OUTPUT_DESC(dct_1)}};
|
||||
|
||||
// L2Loss
|
||||
INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(L2Loss) = EMPTY_ATTR_MAP;
|
||||
|
|
|
@ -488,6 +488,14 @@ DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD)
|
|||
DECLARE_OP_USE_OUTPUT(ApplyRMSPropD)
|
||||
DECLARE_OP_ADAPTER(ApplyCenteredRMSPropD)
|
||||
DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSPropD)
|
||||
DECLARE_OP_ADAPTER(BasicLSTMCell)
|
||||
DECLARE_OP_USE_OUTPUT(BasicLSTMCell)
|
||||
DECLARE_OP_ADAPTER(BasicLSTMCellInputGrad)
|
||||
DECLARE_OP_USE_OUTPUT(BasicLSTMCellInputGrad)
|
||||
DECLARE_OP_ADAPTER(BasicLSTMCellWeightGrad)
|
||||
DECLARE_OP_USE_OUTPUT(BasicLSTMCellWeightGrad)
|
||||
DECLARE_OP_ADAPTER(BasicLSTMCellCStateGrad)
|
||||
DECLARE_OP_USE_OUTPUT(BasicLSTMCellCStateGrad)
|
||||
DECLARE_OP_ADAPTER(L2Loss)
|
||||
DECLARE_OP_USE_OUTPUT(L2Loss)
|
||||
DECLARE_OP_ADAPTER(CTCLoss)
|
||||
|
|
Loading…
Reference in New Issue