diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index 80170019e7b..e7547e521a1 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -24,6 +24,7 @@ #include "frontend/optimizer/irpass/gradient_eliminate.h" #include "frontend/optimizer/irpass/inline.h" #include "frontend/optimizer/irpass/updatestate_eliminate.h" +#include "frontend/optimizer/irpass/load_eliminate.h" #include "frontend/optimizer/irpass/stopgrad_eliminate.h" #include "frontend/optimizer/irpass/incorporate_call.h" #include "frontend/optimizer/irpass/incorporate_getitem.h" @@ -156,6 +157,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() { switch_call_monad_eliminater_ = MakeSubstitution(std::make_shared(), "switch_call_monad_eliminater", IsCNodeDup); + // Load eliminate + load_eliminater_ = MakeSubstitution(std::make_shared(), "load_eliminater", prim::kPrimLoad); + // StopGradient eliminate stopgrad_eliminater_ = MakeSubstitution(std::make_shared(), "stopgrad_eliminater", prim::kPrimStopGradient); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 003e54d939e..ef43a751c27 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -96,6 +96,7 @@ class OptimizeIRPassLib { SubstitutionPtr updatestate_eliminater_; SubstitutionPtr switch_call_monad_eliminater_; SubstitutionPtr stopgrad_eliminater_; + SubstitutionPtr load_eliminater_; // Incorporation SubstitutionPtr incorporate_getitem_set_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/load_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/load_eliminate.cc new file mode 100644 index 00000000000..9a2785c51ed --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/load_eliminate.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/irpass/load_eliminate.h" + +#include +#include +#include +#include + +#include "frontend/operator/ops.h" + +namespace mindspore::opt::irpass { +namespace { +// Return true if the node has Ref abstract. +bool HasAbstractRef(const AnfNodePtr &node) { + if (node == nullptr) { + return false; + } + auto &abs = node->abstract(); + return (abs != nullptr) && abs->isa(); +} +} // namespace + +AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + auto load_node = dyn_cast(node); + if (load_node == nullptr || load_node->inputs().empty()) { + MS_LOG(WARNING) << "LoadEliminater encounter invalid node: " << node->DebugString(); + return nullptr; + } + constexpr size_t kFirstInputIndex = 1; + auto ¶m = load_node->inputs().at(kFirstInputIndex); + if (!HasAbstractRef(param)) { + return param; + } + return nullptr; +} +} // namespace mindspore::opt::irpass diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/load_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/load_eliminate.h new file mode 100644 index 00000000000..a1a0eee2484 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/load_eliminate.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LOAD_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LOAD_ELIMINATE_H_ + +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/anf_visitor.h" + +namespace mindspore::opt::irpass { +// +// LoadEliminater eliminates redundant Load related nodes. +// +class LoadEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; +}; + +} // namespace mindspore::opt::irpass +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LOAD_ELIMINATE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 9d1ab1a7c71..8ef2e39a942 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -103,6 +103,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { // Safe inlining irpass.inline_, irpass.updatestate_eliminater_, + irpass.load_eliminater_, irpass.stopgrad_eliminater_, irpass.partial_eliminate_, irpass.replace_applicator_, @@ -130,6 +131,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { // Safe inlining irpass.inline_, irpass.updatestate_eliminater_, + irpass.load_eliminater_, irpass.stopgrad_eliminater_, irpass.sparse_tensor_eliminate_, }); @@ -195,6 +197,7 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp // Safe inlining, irpass.inline_, irpass.updatestate_eliminater_, + irpass.load_eliminater_, irpass.switch_call_monad_eliminater_, irpass.stopgrad_eliminater_, irpass.partial_eliminate_, @@ -220,9 +223,9 @@ OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig b_1 = opt::OptPassConfig( {irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_, - irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.stopgrad_eliminater_, - irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, irpass.incorporate_env_getitem_, - irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, + irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_, + irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, + irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); opt::OptPassConfig b_2 = opt::OptPassConfig({ irpass.replace_refkey_by_param_, diff --git a/mindspore/core/abstract/prim_maths.cc b/mindspore/core/abstract/prim_maths.cc index 7bcd7759d2a..81e2841fd54 100644 --- a/mindspore/core/abstract/prim_maths.cc +++ b/mindspore/core/abstract/prim_maths.cc @@ -86,6 +86,10 @@ AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &p // Inputs: one tensor. const std::string op_name = primitive->name(); CheckArgsSize(op_name, args_spec_list, 1); + auto ref = dyn_cast(args_spec_list[0]); + if (ref != nullptr) { + return ref->CloneAsTensor(); + } return args_spec_list[0]->Broaden(); }