!31933 Fix AddN not being dumped for gradients
Merge pull request !31933 from huanghui/fix-addn
This commit is contained in:
commit
6eae8fa1f2
|
@ -183,6 +183,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
// AddN
|
||||
merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN);
|
||||
addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN);
|
||||
addn_check_dump_ = MakeSubstitution(std::make_shared<AddNCheckDump>(), "addn_check_dump", prim::kPrimAddN);
|
||||
|
||||
// AccumulateNV2
|
||||
accumulaten_eliminater_ =
|
||||
|
|
|
@ -97,6 +97,7 @@ class OptimizeIRPassLib {
|
|||
// AddN
|
||||
SubstitutionPtr merge_addn_;
|
||||
SubstitutionPtr addn_zero_filter_;
|
||||
SubstitutionPtr addn_check_dump_;
|
||||
|
||||
// AccumulateNV2
|
||||
SubstitutionPtr accumulaten_eliminater_;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -225,6 +225,54 @@ class AddNZeroFilter : public AnfVisitor {
|
|||
std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
|
||||
bool has_zero_like_{false};
|
||||
};
|
||||
|
||||
// {PrimAddN, {kPrimMakeTuple, Xs}}
|
||||
class AddNCheckDump : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
|
||||
|
||||
// Only handle gradient addn.
|
||||
if (node->scope()->name().find("Gradients/") != 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (set_dump_) {
|
||||
AnfUtils::SetDumpFlag(node);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Visit(const CNodePtr &cnode) override {
|
||||
if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
|
||||
return;
|
||||
}
|
||||
if (cnode->size() < kSizeThree) {
|
||||
return;
|
||||
}
|
||||
|
||||
// When all of inputs has dump flag, we need set dump flag for AddN.
|
||||
set_dump_ = true;
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
auto input = cnode->input(i);
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
if (IsPrimitiveCNode(input, prim::kPrimTupleGetItem) || IsPrimitiveCNode(input, prim::kPrimDepend)) {
|
||||
input = input->cast<CNodePtr>()->input(kIndexOne);
|
||||
}
|
||||
if (!input->isa<CNode>() || !AnfUtils::GetDumpFlag(input)) {
|
||||
set_dump_ = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() { set_dump_ = false; }
|
||||
|
||||
private:
|
||||
bool set_dump_{false};
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -313,6 +313,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.switch_simplify_,
|
||||
irpass.specialize_transform_,
|
||||
irpass.merge_addn_,
|
||||
irpass.addn_check_dump_,
|
||||
irpass.float_tuple_getitem_switch_,
|
||||
irpass.float_environ_get_switch_,
|
||||
irpass.inline_,
|
||||
|
|
Loading…
Reference in New Issue