!49859 fix optimizer unify

Merge pull request !49859 from 王禹程/fix_opt
This commit is contained in:
i-robot 2023-03-07 12:57:48 +00:00 committed by Gitee
commit f6fa8a0da5
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 3 additions and 7 deletions

View File

@ -55,13 +55,9 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
{"ApplyAdagradV2", kTupleTensor2},
{"ApplyAdamWithAmsgrad", kTupleTensor4},
{"ApplyAddSign", kTupleTensor2},
{"ApplyCenteredRMSProp", kTupleTensor4},
{"ApplyFtrl", kTupleTensor3},
{"ApplyKerasMomentum", kTupleTensor2},
{"ApplyMomentum", kTupleTensor2},
{"ApplyPowerSign", kTupleTensor2},
{"ApplyProximalAdagrad", kTupleTensor2},
{"ApplyRMSProp", kTupleTensor3},
{"ArgMaxWithValue", kTupleTensor2},
{"ArgMinWithValue", kTupleTensor2},
{"BNTrainingReduce", kTupleTensor2},

View File

@ -102,7 +102,7 @@ void FtrlUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) {
.AddVar(kOptL1)
.AddVar(kOptL2)
.AddVar(kOptLrPower)
.AddVar(kOptU)
.AddSeqVar(kOptU)
.AddCNode(kMOptimizer, {prim::kPrimApplyFtrl, kOptVar, kOptAccum, kOptLinear, kOptGrad, kOptLr, kOptL1, kOptL2,
kOptLrPower, kOptU});
}
@ -124,7 +124,7 @@ void MomentumUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) {
.AddVar(kOptLr)
.AddVar(kOptGrad)
.AddVar(kMomentum)
.AddVar(kOptU)
.AddSeqVar(kOptU)
.AddCNode(kMOptimizer, {prim::kPrimApplyMomentum, kOptVar, kOptAccum, kOptLr, kOptGrad, kMomentum, kOptU});
}
@ -164,7 +164,7 @@ void CenteredRMSPropUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) {
.AddVar(kRho)
.AddVar(kMomentum)
.AddVar(kEpsilon)
.AddVar(kOptU)
.AddSeqVar(kOptU)
.AddCNode(kMOptimizer, {prim::kPrimApplyCenteredRMSProp, kOptVar, kMg, kMs, kMom, kOptGrad, kOptLr, kRho, kMomentum,
kEpsilon, kOptU});
}