fix optimizer unify

This commit is contained in:
reku1997 2023-03-06 20:02:11 +08:00
parent 232aede207
commit 8c38f35bfc
2 changed files with 3 additions and 7 deletions

View File

@ -55,13 +55,9 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
{"ApplyAdagradV2", kTupleTensor2},
{"ApplyAdamWithAmsgrad", kTupleTensor4},
{"ApplyAddSign", kTupleTensor2},
{"ApplyCenteredRMSProp", kTupleTensor4},
{"ApplyFtrl", kTupleTensor3},
{"ApplyKerasMomentum", kTupleTensor2},
{"ApplyMomentum", kTupleTensor2},
{"ApplyPowerSign", kTupleTensor2},
{"ApplyProximalAdagrad", kTupleTensor2},
{"ApplyRMSProp", kTupleTensor3},
{"ArgMaxWithValue", kTupleTensor2},
{"ArgMinWithValue", kTupleTensor2},
{"BNTrainingReduce", kTupleTensor2},

View File

@ -102,7 +102,7 @@ void FtrlUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) {
.AddVar(kOptL1)
.AddVar(kOptL2)
.AddVar(kOptLrPower)
.AddVar(kOptU)
.AddSeqVar(kOptU)
.AddCNode(kMOptimizer, {prim::kPrimApplyFtrl, kOptVar, kOptAccum, kOptLinear, kOptGrad, kOptLr, kOptL1, kOptL2,
kOptLrPower, kOptU});
}
@ -124,7 +124,7 @@ void MomentumUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) {
.AddVar(kOptLr)
.AddVar(kOptGrad)
.AddVar(kMomentum)
.AddVar(kOptU)
.AddSeqVar(kOptU)
.AddCNode(kMOptimizer, {prim::kPrimApplyMomentum, kOptVar, kOptAccum, kOptLr, kOptGrad, kMomentum, kOptU});
}
@ -164,7 +164,7 @@ void CenteredRMSPropUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) {
.AddVar(kRho)
.AddVar(kMomentum)
.AddVar(kEpsilon)
.AddVar(kOptU)
.AddSeqVar(kOptU)
.AddCNode(kMOptimizer, {prim::kPrimApplyCenteredRMSProp, kOptVar, kMg, kMs, kMom, kOptGrad, kOptLr, kRho, kMomentum,
kEpsilon, kOptU});
}