From 701c0794d965f76e1c858e5b1897d31d79ac2731 Mon Sep 17 00:00:00 2001 From: huanghui Date: Fri, 25 Mar 2022 16:02:47 +0800 Subject: [PATCH] fix addn not dump for gradient --- mindspore/ccsrc/frontend/optimizer/irpass.cc | 1 + mindspore/ccsrc/frontend/optimizer/irpass.h | 1 + .../frontend/optimizer/irpass/merge_addn.h | 50 ++++++++++++++++++- mindspore/ccsrc/pipeline/jit/pass.cc | 1 + 4 files changed, 52 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index 023422864f9..afafdece78e 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -183,6 +183,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { // Addn merge_addn_ = MakeSubstitution(std::make_shared(), "merge_addn", prim::kPrimAddN); addn_zero_filter_ = MakeSubstitution(std::make_shared(), "addn_zero_filter", prim::kPrimAddN); + addn_check_dump_ = MakeSubstitution(std::make_shared(), "addn_check_dump", prim::kPrimAddN); // AccumulateNV2 accumulaten_eliminater_ = diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 45c0ba3105d..d94a667f9a6 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -97,6 +97,7 @@ class OptimizeIRPassLib { // AddN SubstitutionPtr merge_addn_; SubstitutionPtr addn_zero_filter_; + SubstitutionPtr addn_check_dump_; // AccumulateNV2 SubstitutionPtr accumulaten_eliminater_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h b/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h index a86cfca088e..f6645b6fe5d 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -225,6 +225,54 @@ class AddNZeroFilter : public AnfVisitor { std::vector filtered_Xs_{}, Xs_{}; bool has_zero_like_{false}; }; + +// {PrimAddN, {kPrimMakeTuple, Xs}} +class AddNCheckDump : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node); + + // Only handle gradient addn. + if (node->scope()->name().find("Gradients/") != 0) { + return nullptr; + } + + if (set_dump_) { + AnfUtils::SetDumpFlag(node); + } + + return nullptr; + } + + void Visit(const CNodePtr &cnode) override { + if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { + return; + } + if (cnode->size() < kSizeThree) { + return; + } + + // When all of inputs has dump flag, we need set dump flag for AddN. + set_dump_ = true; + for (size_t i = 1; i < cnode->size(); ++i) { + auto input = cnode->input(i); + MS_EXCEPTION_IF_NULL(input); + if (IsPrimitiveCNode(input, prim::kPrimTupleGetItem) || IsPrimitiveCNode(input, prim::kPrimDepend)) { + input = input->cast()->input(kIndexOne); + } + if (!input->isa() || !AnfUtils::GetDumpFlag(input)) { + set_dump_ = false; + break; + } + } + } + + void Reset() { set_dump_ = false; } + + private: + bool set_dump_{false}; +}; } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 51f1a71ed14..a226c292c43 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -313,6 +313,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.switch_simplify_, irpass.specialize_transform_, irpass.merge_addn_, + irpass.addn_check_dump_, irpass.float_tuple_getitem_switch_, irpass.float_environ_get_switch_, irpass.inline_,