!2543 fix applyrmsprop to while list

Merge pull request !2543 from amongo/FixApplyRmsprop
This commit is contained in:
mindspore-ci-bot 2020-06-24 14:40:50 +08:00 committed by Gitee
commit 585acc984d
3 changed files with 26 additions and 12 deletions

View File

@ -231,6 +231,7 @@ const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
// Other miscellaneous
const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");

View File

@ -244,6 +244,7 @@ extern const PrimitivePtr kPrimFakeBprop;
extern const PrimitivePtr kPrimBpropCut;
extern const PrimitivePtr kPrimFakeQuantPerLayer;
extern const PrimitivePtr kPrimFakeQuantPerChannel;
extern const PrimitivePtr kPrimApplyRMSProp;
// Other Miscellaneous
extern const PrimitivePtr kPrimIdentity;

View File

@ -51,18 +51,30 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
// node because it is attribute or ge specific reason.
// Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be
// converted to switch guarded.
std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list(
{{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}},
{prim::kPrimStateSetItem, {1}}, {prim::kPrimTupleGetItem, {2}},
{prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}},
{prim::kPrimReduceSum, {2}}, {prim::kPrimReduceMean, {2}},
{prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}},
{prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}},
{prim::kPrimGatherV2, {3}}, {prim::kPrimReshape, {2}},
{prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}},
{prim::kPrimAssignSub, {1}}, {prim::kPrimTensorSummary, {1}},
{prim::kPrimImageSummary, {1}}, {prim::kPrimScalarSummary, {1}},
{prim::kPrimHistogramSummary, {1}}});
std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list({{prim::kPrimApplyMomentum, {1, 2}},
{prim::kPrimMomentum, {2, 3}},
{prim::kPrimStateSetItem, {1}},
{prim::kPrimTupleGetItem, {2}},
{prim::kPrimEnvGetItem, {1}},
{prim::kPrimEnvSetItem, {1}},
{prim::kPrimReduceSum, {2}},
{prim::kPrimReduceMean, {2}},
{prim::kPrimReduceAll, {2}},
{prim::kPrimCast, {2}},
{prim::kPrimTranspose, {2}},
{prim::kPrimOneHot, {2}},
{prim::kPrimGatherV2, {3}},
{prim::kPrimReshape, {2}},
{prim::kPrimAssign, {1}},
{prim::kPrimAssignAdd, {1}},
{prim::kPrimAssignSub, {1}},
{prim::kPrimTensorSummary, {1}},
{prim::kPrimImageSummary, {1}},
{prim::kPrimScalarSummary, {1}},
{prim::kPrimApplyRMSProp, {6, 7, 8}},
{prim::kPrimCumSum, {2}},
{prim::kPrimTile, {2}},
{prim::kPrimHistogramSummary, {1}}});
for (auto &item : white_list) {
auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) {
return IsPrimitiveCNode(node, item.first) && idx == index;