!31933 fix addn not dump for gradient

Merge pull request !31933 from huanghui/fix-addn
This commit is contained in:
i-robot 2022-03-29 01:59:52 +00:00 committed by Gitee
commit 6eae8fa1f2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 52 additions and 1 deletions

View File

@ -183,6 +183,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Addn
merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN);
addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN);
addn_check_dump_ = MakeSubstitution(std::make_shared<AddNCheckDump>(), "addn_check_dump", prim::kPrimAddN);
// AccumulateNV2
accumulaten_eliminater_ =

View File

@ -97,6 +97,7 @@ class OptimizeIRPassLib {
// AddN
SubstitutionPtr merge_addn_;
SubstitutionPtr addn_zero_filter_;
SubstitutionPtr addn_check_dump_;
// AccumulateNV2
SubstitutionPtr accumulaten_eliminater_;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -225,6 +225,54 @@ class AddNZeroFilter : public AnfVisitor {
std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
bool has_zero_like_{false};
};
// {PrimAddN, {kPrimMakeTuple, Xs}}
class AddNCheckDump : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
// Only handle gradient addn.
if (node->scope()->name().find("Gradients/") != 0) {
return nullptr;
}
if (set_dump_) {
AnfUtils::SetDumpFlag(node);
}
return nullptr;
}
void Visit(const CNodePtr &cnode) override {
if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
return;
}
if (cnode->size() < kSizeThree) {
return;
}
// When all of inputs has dump flag, we need set dump flag for AddN.
set_dump_ = true;
for (size_t i = 1; i < cnode->size(); ++i) {
auto input = cnode->input(i);
MS_EXCEPTION_IF_NULL(input);
if (IsPrimitiveCNode(input, prim::kPrimTupleGetItem) || IsPrimitiveCNode(input, prim::kPrimDepend)) {
input = input->cast<CNodePtr>()->input(kIndexOne);
}
if (!input->isa<CNode>() || !AnfUtils::GetDumpFlag(input)) {
set_dump_ = false;
break;
}
}
}
void Reset() { set_dump_ = false; }
private:
bool set_dump_{false};
};
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -313,6 +313,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.switch_simplify_,
irpass.specialize_transform_,
irpass.merge_addn_,
irpass.addn_check_dump_,
irpass.float_tuple_getitem_switch_,
irpass.float_environ_get_switch_,
irpass.inline_,