forked from mindspore-Ecosystem/mindspore
!12461 1、fix square_sum_fusion pattern not match 2、add pass:load_eliminate
From: @Margaret_wangrui Reviewed-by: @hwhewei,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
cab5ece4e9
|
@ -24,6 +24,7 @@
|
|||
#include "frontend/optimizer/irpass/gradient_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/inline.h"
|
||||
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/load_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/stopgrad_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/incorporate_call.h"
|
||||
#include "frontend/optimizer/irpass/incorporate_getitem.h"
|
||||
|
@ -156,6 +157,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
switch_call_monad_eliminater_ = MakeSubstitution(std::make_shared<SwitchCallMonadParameterEliminater>(),
|
||||
"switch_call_monad_eliminater", IsCNodeDup);
|
||||
|
||||
// Load eliminate
|
||||
load_eliminater_ = MakeSubstitution(std::make_shared<LoadEliminater>(), "load_eliminater", prim::kPrimLoad);
|
||||
|
||||
// StopGradient eliminate
|
||||
stopgrad_eliminater_ =
|
||||
MakeSubstitution(std::make_shared<StopGradientEliminater>(), "stopgrad_eliminater", prim::kPrimStopGradient);
|
||||
|
|
|
@ -96,6 +96,7 @@ class OptimizeIRPassLib {
|
|||
SubstitutionPtr updatestate_eliminater_;
|
||||
SubstitutionPtr switch_call_monad_eliminater_;
|
||||
SubstitutionPtr stopgrad_eliminater_;
|
||||
SubstitutionPtr load_eliminater_;
|
||||
|
||||
// Incorporation
|
||||
SubstitutionPtr incorporate_getitem_set_;
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/optimizer/irpass/load_eliminate.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/operator/ops.h"
|
||||
|
||||
namespace mindspore::opt::irpass {
|
||||
namespace {
|
||||
// Return true if the node has Ref abstract.
|
||||
bool HasAbstractRef(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto &abs = node->abstract();
|
||||
return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
auto load_node = dyn_cast<CNode>(node);
|
||||
if (load_node == nullptr || load_node->inputs().empty()) {
|
||||
MS_LOG(WARNING) << "LoadEliminater encounter invalid node: " << node->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
constexpr size_t kFirstInputIndex = 1;
|
||||
auto ¶m = load_node->inputs().at(kFirstInputIndex);
|
||||
if (!HasAbstractRef(param)) {
|
||||
return param;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace mindspore::opt::irpass
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LOAD_ELIMINATE_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LOAD_ELIMINATE_H_
|
||||
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
|
||||
namespace mindspore::opt::irpass {
|
||||
//
|
||||
// LoadEliminater eliminates redundant Load related nodes.
|
||||
//
|
||||
class LoadEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||
};
|
||||
|
||||
} // namespace mindspore::opt::irpass
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LOAD_ELIMINATE_H_
|
|
@ -103,6 +103,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
// Safe inlining
|
||||
irpass.inline_,
|
||||
irpass.updatestate_eliminater_,
|
||||
irpass.load_eliminater_,
|
||||
irpass.stopgrad_eliminater_,
|
||||
irpass.partial_eliminate_,
|
||||
irpass.replace_applicator_,
|
||||
|
@ -130,6 +131,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
// Safe inlining
|
||||
irpass.inline_,
|
||||
irpass.updatestate_eliminater_,
|
||||
irpass.load_eliminater_,
|
||||
irpass.stopgrad_eliminater_,
|
||||
irpass.sparse_tensor_eliminate_,
|
||||
});
|
||||
|
@ -195,6 +197,7 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp
|
|||
// Safe inlining,
|
||||
irpass.inline_,
|
||||
irpass.updatestate_eliminater_,
|
||||
irpass.load_eliminater_,
|
||||
irpass.switch_call_monad_eliminater_,
|
||||
irpass.stopgrad_eliminater_,
|
||||
irpass.partial_eliminate_,
|
||||
|
@ -220,9 +223,9 @@ OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib
|
|||
OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
opt::OptPassConfig b_1 = opt::OptPassConfig(
|
||||
{irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_,
|
||||
irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.stopgrad_eliminater_,
|
||||
irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, irpass.incorporate_env_getitem_,
|
||||
irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
|
||||
irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_,
|
||||
irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
|
||||
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
|
||||
irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_});
|
||||
opt::OptPassConfig b_2 = opt::OptPassConfig({
|
||||
irpass.replace_refkey_by_param_,
|
||||
|
|
|
@ -86,6 +86,10 @@ AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|||
// Inputs: one tensor.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto ref = dyn_cast<abstract::AbstractRef>(args_spec_list[0]);
|
||||
if (ref != nullptr) {
|
||||
return ref->CloneAsTensor();
|
||||
}
|
||||
return args_spec_list[0]->Broaden();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue