forked from mindspore-Ecosystem/mindspore
fix optimizer unify
This commit is contained in:
parent
232aede207
commit
8c38f35bfc
|
@ -55,13 +55,9 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
|
|||
{"ApplyAdagradV2", kTupleTensor2},
|
||||
{"ApplyAdamWithAmsgrad", kTupleTensor4},
|
||||
{"ApplyAddSign", kTupleTensor2},
|
||||
{"ApplyCenteredRMSProp", kTupleTensor4},
|
||||
{"ApplyFtrl", kTupleTensor3},
|
||||
{"ApplyKerasMomentum", kTupleTensor2},
|
||||
{"ApplyMomentum", kTupleTensor2},
|
||||
{"ApplyPowerSign", kTupleTensor2},
|
||||
{"ApplyProximalAdagrad", kTupleTensor2},
|
||||
{"ApplyRMSProp", kTupleTensor3},
|
||||
{"ArgMaxWithValue", kTupleTensor2},
|
||||
{"ArgMinWithValue", kTupleTensor2},
|
||||
{"BNTrainingReduce", kTupleTensor2},
|
||||
|
|
|
@ -102,7 +102,7 @@ void FtrlUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) {
|
|||
.AddVar(kOptL1)
|
||||
.AddVar(kOptL2)
|
||||
.AddVar(kOptLrPower)
|
||||
.AddVar(kOptU)
|
||||
.AddSeqVar(kOptU)
|
||||
.AddCNode(kMOptimizer, {prim::kPrimApplyFtrl, kOptVar, kOptAccum, kOptLinear, kOptGrad, kOptLr, kOptL1, kOptL2,
|
||||
kOptLrPower, kOptU});
|
||||
}
|
||||
|
@ -124,7 +124,7 @@ void MomentumUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) {
|
|||
.AddVar(kOptLr)
|
||||
.AddVar(kOptGrad)
|
||||
.AddVar(kMomentum)
|
||||
.AddVar(kOptU)
|
||||
.AddSeqVar(kOptU)
|
||||
.AddCNode(kMOptimizer, {prim::kPrimApplyMomentum, kOptVar, kOptAccum, kOptLr, kOptGrad, kMomentum, kOptU});
|
||||
}
|
||||
|
||||
|
@ -164,7 +164,7 @@ void CenteredRMSPropUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) {
|
|||
.AddVar(kRho)
|
||||
.AddVar(kMomentum)
|
||||
.AddVar(kEpsilon)
|
||||
.AddVar(kOptU)
|
||||
.AddSeqVar(kOptU)
|
||||
.AddCNode(kMOptimizer, {prim::kPrimApplyCenteredRMSProp, kOptVar, kMg, kMs, kMom, kOptGrad, kOptLr, kRho, kMomentum,
|
||||
kEpsilon, kOptU});
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue