diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc index 0991c31b00a..96d88f6e616 100644 --- a/mindspore/ccsrc/optimizer/irpass.cc +++ b/mindspore/ccsrc/optimizer/irpass.cc @@ -40,6 +40,7 @@ #include "optimizer/irpass/incorporate_getitem.h" #include "optimizer/irpass/incorporate_call.h" #include "optimizer/irpass/grad_var_prepare.h" +#include "optimizer/irpass/param_replace.h" namespace mindspore { namespace opt { @@ -81,6 +82,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue}); replace_refkey_by_param_ = MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode); + replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); // Gradient transforms expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ); diff --git a/mindspore/ccsrc/optimizer/irpass.h b/mindspore/ccsrc/optimizer/irpass.h index bdaf42b3ed1..00274bdcc80 100644 --- a/mindspore/ccsrc/optimizer/irpass.h +++ b/mindspore/ccsrc/optimizer/irpass.h @@ -58,6 +58,7 @@ class OptimizeIRPassLib { SubstitutionPtr make_ref_eliminate_; SubstitutionPtr get_make_ref_eliminate_; SubstitutionPtr replace_refkey_by_param_; + SubstitutionPtr replace_old_param_; // Branch culling SubstitutionPtr switch_simplify_; diff --git a/mindspore/ccsrc/optimizer/irpass/param_replace.h b/mindspore/ccsrc/optimizer/irpass/param_replace.h new file mode 100644 index 00000000000..c0c4c832d7a --- /dev/null +++ b/mindspore/ccsrc/optimizer/irpass/param_replace.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_ + +#include + +#include "optimizer/optimizer.h" +#include "optimizer/irpass.h" +#include "ir/visitor.h" +#include "operator/ops.h" +#include "pipeline/parse/parse.h" + +namespace mindspore { +namespace opt { +namespace irpass { +class ReplaceOldParam : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + if (!IsParam(node)) { + return nullptr; + } + auto resource = std::dynamic_pointer_cast(optimizer->resource()); + MS_EXCEPTION_IF_NULL(resource); + + auto top_graph = resource->func_graph(); // parse::Parser::GetTopFuncGraph(); + MS_EXCEPTION_IF_NULL(top_graph); + + auto param_node = node->cast(); + if (!param_node->has_default() || node->func_graph() == top_graph) { + return nullptr; + } + auto para_name = param_node->name(); + for (const auto &tnode : top_graph->parameters()) { + auto para = tnode->cast(); + if (para != nullptr && para->name() == para_name) { + return para; + } + } + return nullptr; + } +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_ diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index baf4bea7ece..d77fee84aa7 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -88,6 +88,7 @@ FuncGraphPtr Renormalize(const ResourcePtr& res, const FuncGraphPtr& func_graph, double t2 = GetTime(); #endif auto ret = ProgramSpecialize(res, func_graph, result.context); + res->set_func_graph(ret); #ifdef ENABLE_PROFILE double t3 = GetTime(); MsProfile::StatTime("renormalize.infer", t2 - t1); diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc index a58ecf41b60..9248590f272 100644 --- a/mindspore/ccsrc/pipeline/pass.cc +++ b/mindspore/ccsrc/pipeline/pass.cc @@ -114,11 +114,9 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) { opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); opt::irpass::ResolveIRPassLib resolve_irpass; - opt::OptPassConfig resolve_pass = opt::OptPassConfig({ - resolve_irpass.resolver_resolve_, - resolve_irpass.resolver_getattr_, - irpass.get_make_ref_eliminate_, - }); + opt::OptPassConfig resolve_pass = + opt::OptPassConfig({resolve_irpass.resolver_resolve_, resolve_irpass.resolver_getattr_, + irpass.get_make_ref_eliminate_, irpass.replace_old_param_}); OptPassGroupMap map_a({{"a_1", a_1}, {"a_2", a_2}, diff --git a/tests/ut/python/pynative_mode/test_insert_grad_of.py b/tests/ut/python/pynative_mode/test_insert_grad_of.py index 104ac4d1c7b..a11c5fa2b1b 100644 --- a/tests/ut/python/pynative_mode/test_insert_grad_of.py +++ b/tests/ut/python/pynative_mode/test_insert_grad_of.py @@ -129,7 +129,7 @@ def test_cell_assign(): self.matrix_g = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_g") def save_gradient(self, dout): - self.matrix_g = dout + self.matrix_g = dout + self.matrix_g return dout def construct(self, x, y):