forked from mindspore-Ecosystem/mindspore
!2543 fix applyrmsprop to while list
Merge pull request !2543 from amongo/FixApplyRmsprop
This commit is contained in:
commit
585acc984d
|
@ -231,6 +231,7 @@ const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
|
|||
const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
|
||||
const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
|
||||
const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
|
||||
const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
|
||||
|
||||
// Other miscellaneous
|
||||
const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
|
||||
|
|
|
@ -244,6 +244,7 @@ extern const PrimitivePtr kPrimFakeBprop;
|
|||
extern const PrimitivePtr kPrimBpropCut;
|
||||
extern const PrimitivePtr kPrimFakeQuantPerLayer;
|
||||
extern const PrimitivePtr kPrimFakeQuantPerChannel;
|
||||
extern const PrimitivePtr kPrimApplyRMSProp;
|
||||
|
||||
// Other Miscellaneous
|
||||
extern const PrimitivePtr kPrimIdentity;
|
||||
|
|
|
@ -51,18 +51,30 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
|
|||
// node because it is attribute or ge specific reason.
|
||||
// Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be
|
||||
// converted to switch guarded.
|
||||
std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list(
|
||||
{{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}},
|
||||
{prim::kPrimStateSetItem, {1}}, {prim::kPrimTupleGetItem, {2}},
|
||||
{prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}},
|
||||
{prim::kPrimReduceSum, {2}}, {prim::kPrimReduceMean, {2}},
|
||||
{prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}},
|
||||
{prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}},
|
||||
{prim::kPrimGatherV2, {3}}, {prim::kPrimReshape, {2}},
|
||||
{prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}},
|
||||
{prim::kPrimAssignSub, {1}}, {prim::kPrimTensorSummary, {1}},
|
||||
{prim::kPrimImageSummary, {1}}, {prim::kPrimScalarSummary, {1}},
|
||||
{prim::kPrimHistogramSummary, {1}}});
|
||||
std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list({{prim::kPrimApplyMomentum, {1, 2}},
|
||||
{prim::kPrimMomentum, {2, 3}},
|
||||
{prim::kPrimStateSetItem, {1}},
|
||||
{prim::kPrimTupleGetItem, {2}},
|
||||
{prim::kPrimEnvGetItem, {1}},
|
||||
{prim::kPrimEnvSetItem, {1}},
|
||||
{prim::kPrimReduceSum, {2}},
|
||||
{prim::kPrimReduceMean, {2}},
|
||||
{prim::kPrimReduceAll, {2}},
|
||||
{prim::kPrimCast, {2}},
|
||||
{prim::kPrimTranspose, {2}},
|
||||
{prim::kPrimOneHot, {2}},
|
||||
{prim::kPrimGatherV2, {3}},
|
||||
{prim::kPrimReshape, {2}},
|
||||
{prim::kPrimAssign, {1}},
|
||||
{prim::kPrimAssignAdd, {1}},
|
||||
{prim::kPrimAssignSub, {1}},
|
||||
{prim::kPrimTensorSummary, {1}},
|
||||
{prim::kPrimImageSummary, {1}},
|
||||
{prim::kPrimScalarSummary, {1}},
|
||||
{prim::kPrimApplyRMSProp, {6, 7, 8}},
|
||||
{prim::kPrimCumSum, {2}},
|
||||
{prim::kPrimTile, {2}},
|
||||
{prim::kPrimHistogramSummary, {1}}});
|
||||
for (auto &item : white_list) {
|
||||
auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) {
|
||||
return IsPrimitiveCNode(node, item.first) && idx == index;
|
||||
|
|
Loading…
Reference in New Issue