!12461 1. Fix square_sum_fusion pattern not matching; 2. Add pass: load_eliminate

From: @Margaret_wangrui
Reviewed-by: @hwhewei,@zh_qh
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-02-22 15:55:39 +08:00 committed by Gitee
commit cab5ece4e9
6 changed files with 99 additions and 3 deletions

View File

@ -24,6 +24,7 @@
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/inline.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
#include "frontend/optimizer/irpass/load_eliminate.h"
#include "frontend/optimizer/irpass/stopgrad_eliminate.h"
#include "frontend/optimizer/irpass/incorporate_call.h"
#include "frontend/optimizer/irpass/incorporate_getitem.h"
@ -156,6 +157,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
switch_call_monad_eliminater_ = MakeSubstitution(std::make_shared<SwitchCallMonadParameterEliminater>(),
"switch_call_monad_eliminater", IsCNodeDup);
// Load eliminate
load_eliminater_ = MakeSubstitution(std::make_shared<LoadEliminater>(), "load_eliminater", prim::kPrimLoad);
// StopGradient eliminate
stopgrad_eliminater_ =
MakeSubstitution(std::make_shared<StopGradientEliminater>(), "stopgrad_eliminater", prim::kPrimStopGradient);

View File

@ -96,6 +96,7 @@ class OptimizeIRPassLib {
SubstitutionPtr updatestate_eliminater_;
SubstitutionPtr switch_call_monad_eliminater_;
SubstitutionPtr stopgrad_eliminater_;
SubstitutionPtr load_eliminater_;
// Incorporation
SubstitutionPtr incorporate_getitem_set_;

View File

@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/optimizer/irpass/load_eliminate.h"
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "frontend/operator/ops.h"
namespace mindspore::opt::irpass {
namespace {
// Return true if the node has Ref abstract.
bool HasAbstractRef(const AnfNodePtr &node) {
if (node == nullptr) {
return false;
}
auto &abs = node->abstract();
return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
}
} // namespace
// Eliminate a redundant Load node: Load(x, u) can be replaced by x when x is
// not a Ref (loading a non-Ref value is a no-op). Returns the replacement
// node, or nullptr when no substitution applies.
AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  // A Load CNode is expected to be (prim::kPrimLoad, parameter, monad), so the
  // loaded value lives at input index 1.
  constexpr size_t kFirstInputIndex = 1;
  auto load_node = dyn_cast<CNode>(node);
  // Guard with `size() <= kFirstInputIndex` rather than `empty()`: a CNode
  // holding only the primitive (size 1) would make `.at(1)` below throw
  // std::out_of_range instead of being rejected here.
  if (load_node == nullptr || load_node->inputs().size() <= kFirstInputIndex) {
    MS_LOG(WARNING) << "LoadEliminater encounter invalid node: " << node->DebugString();
    return nullptr;
  }
  auto &param = load_node->inputs().at(kFirstInputIndex);
  if (!HasAbstractRef(param)) {
    // Non-Ref input: the Load is redundant, forward the input directly.
    return param;
  }
  // Ref input: the Load is meaningful, keep it.
  return nullptr;
}
} // namespace mindspore::opt::irpass

View File

@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LOAD_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LOAD_ELIMINATE_H_
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
namespace mindspore::opt::irpass {
//
// LoadEliminater eliminates redundant Load related nodes.
// Registered as the "load_eliminater" substitution on prim::kPrimLoad; a Load
// whose input is not a Ref is replaced by the input itself.
//
class LoadEliminater : public AnfVisitor {
 public:
  // Returns the replacement node for a redundant Load, or nullptr to keep it.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
} // namespace mindspore::opt::irpass
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LOAD_ELIMINATE_H_

View File

@ -103,6 +103,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
// Safe inlining
irpass.inline_,
irpass.updatestate_eliminater_,
irpass.load_eliminater_,
irpass.stopgrad_eliminater_,
irpass.partial_eliminate_,
irpass.replace_applicator_,
@ -130,6 +131,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
// Safe inlining
irpass.inline_,
irpass.updatestate_eliminater_,
irpass.load_eliminater_,
irpass.stopgrad_eliminater_,
irpass.sparse_tensor_eliminate_,
});
@ -195,6 +197,7 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp
// Safe inlining,
irpass.inline_,
irpass.updatestate_eliminater_,
irpass.load_eliminater_,
irpass.switch_call_monad_eliminater_,
irpass.stopgrad_eliminater_,
irpass.partial_eliminate_,
@ -220,9 +223,9 @@ OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib
OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig b_1 = opt::OptPassConfig(
{irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.stopgrad_eliminater_,
irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, irpass.incorporate_env_getitem_,
irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_,
irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_});
opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.replace_refkey_by_param_,

View File

@ -86,6 +86,10 @@ AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &p
// Inputs: one tensor.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto ref = dyn_cast<abstract::AbstractRef>(args_spec_list[0]);
if (ref != nullptr) {
return ref->CloneAsTensor();
}
return args_spec_list[0]->Broaden();
}