forked from mindspore-Ecosystem/mindspore
1. modify the execution method of substitution in irpass; 2. Refactor resolve pass; 3. Split UpdatestateEliminater into multiple passes.
This commit is contained in:
parent
e187cfc889
commit
5b59d0104f
|
@ -186,8 +186,19 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
MakeSubstitution(std::make_shared<SpecializeOnGraphArguments>(), "specialize_transform", IsCNodeGraph);
|
||||
|
||||
// UpdateState eliminate
|
||||
updatestate_eliminater_ =
|
||||
MakeSubstitution(std::make_shared<UpdatestateEliminater>(), "updatestate_eliminater", prim::kPrimUpdateState);
|
||||
updatestate_only_used_node_eliminater_ =
|
||||
MakeSubstitution(std::make_shared<UpdatestateOnlyUsedNodeEliminater>(), "updatestate_only_used_node_eliminater",
|
||||
prim::kPrimUpdateState);
|
||||
updatestate_pure_node_eliminater_ = MakeSubstitution(std::make_shared<UpdatestatePureNodeEliminater>(),
|
||||
"updatestate_pure_node_eliminater", prim::kPrimUpdateState);
|
||||
updatestate_depend_eliminater_ = MakeSubstitution(std::make_shared<UpdatestateDependEliminater>(),
|
||||
"updatestate_depend_eliminater", prim::kPrimUpdateState);
|
||||
updatestate_assign_eliminater_ = MakeSubstitution(std::make_shared<UpdatestateAssignEliminater>(),
|
||||
"updatestate_assign_eliminater", prim::kPrimUpdateState);
|
||||
updatestate_maketuple_eliminater_ = MakeSubstitution(std::make_shared<UpdatestateMakeTupleEliminater>(),
|
||||
"updatestate_maketuple_eliminater", prim::kPrimUpdateState);
|
||||
updatestate_loads_eliminater_ = MakeSubstitution(std::make_shared<UpdatestateLoadsEliminater>(),
|
||||
"updatestate_loads_eliminater", prim::kPrimUpdateState);
|
||||
switch_call_monad_eliminater_ = MakeSubstitution(std::make_shared<SwitchCallMonadParameterEliminater>(),
|
||||
"switch_call_monad_eliminater", IsCNodeDup);
|
||||
|
||||
|
@ -261,13 +272,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
}
|
||||
|
||||
ResolveIRPassLib::ResolveIRPassLib() {
|
||||
resolver_resolve_and_getattr_ =
|
||||
MakeSubstitution(std::make_shared<ResolverResolveAndGetAttr>(), "resolver_resolve_and_getattr",
|
||||
{prim::kPrimGetAttr, prim::kPrimResolve});
|
||||
resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
|
||||
resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetAttr>(), "resolver_getattr", prim::kPrimGetAttr);
|
||||
resolver_getattr_resolve_ =
|
||||
MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "resolver_getattr_resolve", prim::kPrimGetAttr);
|
||||
// In resolver_getattr_resolve_, some patterns have priority over others.
|
||||
resolver_getattr_resolve_ = MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "getattr_resolve",
|
||||
{prim::kPrimGetAttr, prim::kPrimResolve}, opt::CHECK_RENORM, true);
|
||||
}
|
||||
|
||||
InferenceOptPrepareLib::InferenceOptPrepareLib() {
|
||||
|
|
|
@ -108,7 +108,12 @@ class OptimizeIRPassLib {
|
|||
SubstitutionPtr specialize_transform_;
|
||||
|
||||
// Auto-monad related eliminaters.
|
||||
SubstitutionPtr updatestate_eliminater_;
|
||||
SubstitutionPtr updatestate_only_used_node_eliminater_;
|
||||
SubstitutionPtr updatestate_pure_node_eliminater_;
|
||||
SubstitutionPtr updatestate_depend_eliminater_;
|
||||
SubstitutionPtr updatestate_assign_eliminater_;
|
||||
SubstitutionPtr updatestate_maketuple_eliminater_;
|
||||
SubstitutionPtr updatestate_loads_eliminater_;
|
||||
SubstitutionPtr switch_call_monad_eliminater_;
|
||||
SubstitutionPtr stopgrad_eliminater_;
|
||||
SubstitutionPtr load_eliminater_;
|
||||
|
@ -166,10 +171,6 @@ class ResolveIRPassLib {
|
|||
public:
|
||||
ResolveIRPassLib();
|
||||
~ResolveIRPassLib() = default;
|
||||
|
||||
SubstitutionPtr resolver_resolve_and_getattr_;
|
||||
SubstitutionPtr resolver_resolve_;
|
||||
SubstitutionPtr resolver_getattr_;
|
||||
SubstitutionPtr resolver_getattr_resolve_;
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/optimizer/irpass/symbol_resolver.h"
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
|
||||
// {prim::kPrimGetAttr, namespace, attr}
|
||||
// {prim::kPrimGetAttr, bool, attr}
|
||||
// {prim::kPrimResolve, namespace, symbol}
|
||||
AnfNodePtr ResolverGetAttrResolve::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
constexpr char PARSE_SUPER_NAME[] = "namespace";
|
||||
constexpr size_t namespace_index = 1;
|
||||
constexpr size_t symbol_index = 2;
|
||||
|
||||
PatternNode<AnfNodePtr> resolve_node, ns_node, sym_node, attr_node, bool_node;
|
||||
auto GetAttrResolveLambda = [&node, &resolve_node, &attr_node, &optimizer]() -> AnfNodePtr {
|
||||
auto inner = resolve_node.GetNode(node);
|
||||
auto attr = attr_node.GetNode(node);
|
||||
if (IsPrimitiveCNode(inner, prim::kPrimResolve)) {
|
||||
auto resolve_cnode = inner->cast<CNodePtr>();
|
||||
auto namespace_node = resolve_cnode->input(namespace_index);
|
||||
auto symbol_node = resolve_cnode->input(symbol_index);
|
||||
if (!IsValueNode<parse::NameSpace>(namespace_node) || !IsValueNode<parse::Symbol>(symbol_node)) {
|
||||
return nullptr;
|
||||
}
|
||||
// deal with the case of getting attr from a class member
|
||||
// and avoid the case of getting attr from self (the result of ParseSuper)
|
||||
auto ns = GetValueNode<parse::NameSpacePtr>(namespace_node);
|
||||
auto sym = GetValueNode<parse::SymbolPtr>(symbol_node);
|
||||
if (ns->module() == parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && sym->symbol() != PARSE_SUPER_NAME) {
|
||||
return parse::ResolveCellwithAttr(optimizer->manager(), ns, sym, inner, attr);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
auto GetAttrLambda = [&node, &ns_node, &attr_node, &optimizer]() -> AnfNodePtr {
|
||||
auto ns = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
|
||||
auto str = GetValue<std::string>(GetValueNode(attr_node.GetNode(node)));
|
||||
parse::SymbolPtr sym = std::make_shared<parse::Symbol>(str);
|
||||
return parse::ResolveSymbol(optimizer->manager(), ns, sym, node);
|
||||
};
|
||||
|
||||
auto ResolveLambda = [&node, &ns_node, &sym_node, &optimizer]() -> AnfNodePtr {
|
||||
auto ns = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
|
||||
auto sym = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node));
|
||||
auto manager = optimizer->manager();
|
||||
return parse::ResolveSymbol(manager, ns, sym, node);
|
||||
};
|
||||
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
|
||||
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetAttr, resolve_node, attr_node), GetAttrResolveLambda,
|
||||
attr_node.CheckFunc(IsValueNode<StringImm>, node));
|
||||
// {prim::kPrimGetAttr, namespace, attr}
|
||||
MATCH_REPLACE_LAMBDA_IF(
|
||||
node, PPrimitive(prim::kPrimGetAttr, ns_node, attr_node), GetAttrLambda,
|
||||
ns_node.CheckFunc(IsValueNode<parse::NameSpace>, node) && attr_node.CheckFunc(IsValueNode<StringImm>, node));
|
||||
// {prim::kPrimGetAttr, bool, attr}
|
||||
MATCH_REPLACE_IF(
|
||||
node, PPrimitive(prim::kPrimGetAttr, bool_node, attr_node), bool_node,
|
||||
bool_node.CheckFunc(IsValueNode<BoolImm>, node) && attr_node.CheckFunc(IsValueNode<StringImm>, node));
|
||||
// {prim::kPrimResolve, namespace, symbol}
|
||||
MATCH_REPLACE_LAMBDA_IF(
|
||||
node, PPrimitive(prim::kPrimResolve, ns_node, sym_node), ResolveLambda,
|
||||
ns_node.CheckFunc(IsValueNode<parse::NameSpace>, node) && sym_node.CheckFunc(IsValueNode<parse::Symbol>, node));
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -34,117 +34,17 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
const char PARSE_SUPER_NAME[] = "namespace";
|
||||
|
||||
// {prim::kPrimResolve, Ns, Sym}
|
||||
class ResolverResolve : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
AnfVisitor::Match(prim::kPrimResolve, {IsVNode, IsVNode})(node);
|
||||
if (sym_ != nullptr) {
|
||||
return parse::ResolveSymbol(optimizer->manager(), ns_, sym_, node);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Visit(const ValueNodePtr &vnode) override {
|
||||
if (IsValueNode<parse::NameSpace>(vnode)) {
|
||||
ns_ = GetValueNode<parse::NameSpacePtr>(vnode);
|
||||
} else if (ns_ != nullptr && IsValueNode<parse::Symbol>(vnode)) {
|
||||
sym_ = GetValueNode<parse::SymbolPtr>(vnode);
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
ns_ = nullptr;
|
||||
sym_ = nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
parse::NameSpacePtr ns_{nullptr};
|
||||
parse::SymbolPtr sym_{nullptr};
|
||||
};
|
||||
|
||||
// {prim::kPrimGetAttr, Ns, Str}
|
||||
class ResolverGetAttr : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
AnfVisitor::Match(prim::kPrimGetAttr, {IsVNode, IsVNode})(node);
|
||||
if (sym_ != nullptr) {
|
||||
return parse::ResolveSymbol(optimizer->manager(), ns_, sym_, node);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Visit(const AnfNodePtr &node) override {
|
||||
if (IsValueNode<parse::NameSpace>(node)) {
|
||||
ns_ = GetValueNode<parse::NameSpacePtr>(node);
|
||||
} else if (ns_ != nullptr && IsValueNode<StringImm>(node)) {
|
||||
auto str = GetValue<std::string>(GetValueNode(node));
|
||||
sym_ = std::make_shared<parse::Symbol>(str);
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
ns_ = nullptr;
|
||||
sym_ = nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
parse::NameSpacePtr ns_{nullptr};
|
||||
parse::SymbolPtr sym_{nullptr};
|
||||
};
|
||||
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, ns_node, sym_node}, attr_node}
|
||||
// Put GetAttr pattern and Resolve pattern together to ensure that GetAttr pattern always takes precedence over Resolve
|
||||
// pattern. After matching GetAttr pattern, there may be new nodes that can match GetAttr pattern and Resolve pattern.
|
||||
// The same is true for matching Resolve pattern.
|
||||
//
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
|
||||
// {prim::kPrimGetAttr, namespace, attr}
|
||||
// {prim::kPrimGetAttr, bool, attr}
|
||||
// {prim::kPrimResolve, namespace, symbol}
|
||||
class ResolverGetAttrResolve : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
PatternNode<AnfNodePtr> ns_node, sym_node, attr_node;
|
||||
auto ResolveAttrLambda = [&node, &ns_node, &sym_node, &attr_node, &optimizer]() -> AnfNodePtr {
|
||||
auto node_to_getattr = node->cast<CNodePtr>()->input(1);
|
||||
std::string attr_as_string = GetValueNode<StringImmPtr>(attr_node.GetNode(node))->value();
|
||||
|
||||
auto ns_ = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
|
||||
auto sym_ = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node));
|
||||
if (ns_->module() == parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && sym_->symbol() != PARSE_SUPER_NAME) {
|
||||
// deal with the case of getting attr from a class member
|
||||
// and avoid the case of getting attr from self (the result of ParseSuper)
|
||||
auto result = parse::ResolveCellwithAttr(optimizer->manager(), ns_, sym_, node_to_getattr, attr_as_string);
|
||||
return result;
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
MATCH_REPLACE_LAMBDA_IF(
|
||||
node, PPrimitive(prim::kPrimGetAttr, PPrimitive(prim::kPrimResolve, ns_node, sym_node), attr_node),
|
||||
ResolveAttrLambda, attr_node.CheckFunc(IsValueNode<StringImm>, node));
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
class ResolverResolveAndGetAttr : public OptimizerCaller {
|
||||
public:
|
||||
ResolverResolveAndGetAttr() {
|
||||
resolver_optimizers_ = {std::make_shared<ResolverGetAttrResolve>(), std::make_shared<ResolverResolve>(),
|
||||
std::make_shared<ResolverGetAttr>()};
|
||||
}
|
||||
virtual ~ResolverResolveAndGetAttr() = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
AnfNodePtr new_node;
|
||||
for (const auto &resolver_opt : resolver_optimizers_) {
|
||||
new_node = (*resolver_opt)(optimizer, node);
|
||||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<OptimizerCallerPtr> resolver_optimizers_{};
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
|
|
|
@ -22,6 +22,10 @@
|
|||
#include <vector>
|
||||
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/optimizer_caller.h"
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
#include "ir/pattern_matcher.h"
|
||||
|
||||
namespace mindspore::opt::irpass {
|
||||
namespace {
|
||||
|
@ -81,35 +85,7 @@ bool OnlyUsedByTwoNode(const AnfNodePtr &be_used_node, const AnfNodePtr &first_n
|
|||
(first_user == second_node && second_user == first_node);
|
||||
}
|
||||
|
||||
// Eliminate useless node that only used by associated update_state.
|
||||
// Convert:
|
||||
// x1 = node(x, u)
|
||||
// u1 = update_state(u, x1) # update_state is the only user of node
|
||||
// user(u1)
|
||||
// To:
|
||||
// user(u)
|
||||
AnfNodePtr EliminateUpdateStateOnlyUsedNode(const CNodePtr &update_state, const AnfNodePtr &node) {
|
||||
if (!OnlyUsedByOneNode(node, update_state)) {
|
||||
// Skip if UpdateState is not the only user of cnode.
|
||||
return nullptr;
|
||||
}
|
||||
// Replace UpdateState with the input monad.
|
||||
return update_state->input(kInputIndex);
|
||||
}
|
||||
|
||||
// Eliminate UpdateState that attaches a pure (no-side-effect) node.
|
||||
// Convert:
|
||||
// x = pure_node(args) # no side effect
|
||||
// u1 = update_state(u, x)
|
||||
// user(u1)
|
||||
// To:
|
||||
// x = pure_node(args)
|
||||
// user(u)
|
||||
AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const AnfNodePtr &attach) {
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimTupleGetItem)) {
|
||||
// Skip tuple_getitem.
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode = dyn_cast<CNode>(attach);
|
||||
if (cnode == nullptr) {
|
||||
// Skip value node or parameter.
|
||||
|
@ -122,26 +98,11 @@ AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const A
|
|||
return nullptr;
|
||||
}
|
||||
}
|
||||
// Skip Call/Switch/SwitchLayer.
|
||||
auto first_input_node = cnode->input(kFirstInputIndex);
|
||||
if (IsPrimitiveCNode(first_input_node, prim::kPrimCall) || IsPrimitiveCNode(first_input_node, prim::kPrimSwitch) ||
|
||||
IsPrimitiveCNode(first_input_node, prim::kPrimSwitchLayer)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Remove UpdateState by replace it with its input monad.
|
||||
return update_state->input(kInputIndex);
|
||||
}
|
||||
|
||||
// Eliminate redundant UpdateState/Depend pair nodes caused by inline.
|
||||
// Convert:
|
||||
// x1 = Depend(x, u)
|
||||
// u1 = UpdateState(u, x1)
|
||||
// out = x_user(x1)
|
||||
// u2 = u_user(u1)
|
||||
// To:
|
||||
// out = x_user(x)
|
||||
// u2 = u_user(u)
|
||||
AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CNodePtr &depend) {
|
||||
auto input_monad = depend->inputs().back();
|
||||
if (!HasAbstractMonad(input_monad)) {
|
||||
|
@ -638,28 +599,86 @@ AnfNodePtr EliminateUpdateStateBetweenAssignMakeTuple(const CNodePtr &update_sta
|
|||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
// Eliminate useless node that only used by associated update_state.
|
||||
// {prim::kPrimUpdateState, u, {prim::kPrimLoad, m, u}} -> u
|
||||
// {prim::kPrimUpdateState, u, {prim::kPrimPartial, m, u}} -> u
|
||||
// Convert:
|
||||
// x1 = node(x, u)
|
||||
// u1 = update_state(u, x1) # update_state is the only user of x1.
|
||||
// user(u1)
|
||||
// To:
|
||||
// user(u)
|
||||
AnfNodePtr UpdatestateOnlyUsedNodeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
auto update_state_node = dyn_cast<CNode>(node);
|
||||
if (update_state_node == nullptr || update_state_node->inputs().empty()) {
|
||||
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
auto &attach = update_state_node->input(kAttachIndex);
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimPartial) || IsPrimitiveCNode(attach, prim::kPrimLoad)) {
|
||||
// Replace UpdateState with the input monad.
|
||||
if (OnlyUsedByOneNode(attach, update_state_node)) {
|
||||
return update_state_node->input(kInputIndex);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Handle UpdateState(u, Depend(...)).
|
||||
// Eliminate UpdateState that attaches a pure (no-side-effect) node.
|
||||
// Convert:
|
||||
// x = pure_node(args) # no side effect
|
||||
// u1 = update_state(u, x)
|
||||
// user(u1)
|
||||
// To:
|
||||
// x = pure_node(args)
|
||||
// user(u)
|
||||
AnfNodePtr UpdatestatePureNodeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
auto update_state_node = dyn_cast<CNode>(node);
|
||||
if (update_state_node == nullptr || update_state_node->inputs().empty()) {
|
||||
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
auto &attach = update_state_node->input(kAttachIndex);
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimTupleGetItem) || IsPrimitiveCNode(attach, prim::kPrimDepend) ||
|
||||
IsPrimitiveCNode(attach, prim::kPrimPartial) || IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
|
||||
return nullptr;
|
||||
}
|
||||
return EliminateUpdateStateForPureNode(update_state_node, attach);
|
||||
}
|
||||
|
||||
// Eliminate redundant UpdateState/Depend pair nodes caused by inline.
|
||||
// Convert:
|
||||
// x1 = Depend(x, u)
|
||||
// u1 = UpdateState(u, x1)
|
||||
// out = x_user(x1)
|
||||
// u2 = u_user(u1)
|
||||
// To:
|
||||
// out = x_user(x)
|
||||
// u2 = u_user(u)
|
||||
AnfNodePtr UpdatestateDependEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
auto update_state_node = dyn_cast<CNode>(node);
|
||||
if (update_state_node == nullptr || update_state_node->inputs().empty()) {
|
||||
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
auto &attach = update_state_node->input(kAttachIndex);
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimDepend)) {
|
||||
return EliminateUpdateStateWithDepend(update_state_node, attach->cast<CNodePtr>());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Handle UpdateState(u, Partial(...)).
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimPartial)) {
|
||||
return EliminateUpdateStateOnlyUsedNode(update_state_node, attach);
|
||||
// Eliminate UpdateStates between Assign nodes.
|
||||
// Eliminate UpdateStates between Assign and MakeTuple.
|
||||
AnfNodePtr UpdatestateAssignEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
auto update_state_node = dyn_cast<CNode>(node);
|
||||
if (update_state_node == nullptr || update_state_node->inputs().empty()) {
|
||||
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Handle UpdateState(u, Assign(...)).
|
||||
auto &attach = update_state_node->input(kAttachIndex);
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimAssign)) {
|
||||
auto new_node = EliminateUpdateStateBetweenAssigns(update_state_node, attach);
|
||||
if (new_node != nullptr) {
|
||||
|
@ -667,20 +686,15 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
|
|||
}
|
||||
return EliminateUpdateStateBetweenMakeTupleAssign(update_state_node, attach);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Handle UpdateState(u, Load(...)).
|
||||
const bool attach_is_load = IsPrimitiveCNode(attach, prim::kPrimLoad);
|
||||
if (attach_is_load) {
|
||||
auto new_node = EliminateUpdateStateOnlyUsedNode(update_state_node, attach);
|
||||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle UpdateState(u, MakeTuple(...)).
|
||||
const bool attach_is_tuple = IsPrimitiveCNode(attach, prim::kPrimMakeTuple);
|
||||
if (attach_is_tuple) {
|
||||
auto make_tuple = attach->cast<CNodePtr>();
|
||||
// Eliminate UpdateStates which the second input is MakeTuple.
|
||||
AnfNodePtr UpdatestateMakeTupleEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
PatternNode<AnfNodePtr> u, attach;
|
||||
auto MakeTupleLambda = [&node, &u, &attach]() -> AnfNodePtr {
|
||||
auto update_state_node = node->cast<CNodePtr>();
|
||||
auto make_tuple = attach.GetNode(node)->cast<CNodePtr>();
|
||||
auto new_node = EliminateMakeTupleWithDeadNode(update_state_node, make_tuple);
|
||||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
|
@ -689,23 +703,31 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
|
|||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
}
|
||||
new_node = EliminateUpdateStateBetweenAssignMakeTuple(update_state_node, make_tuple);
|
||||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
}
|
||||
return EliminateUpdateStateBetweenAssignMakeTuple(update_state_node, make_tuple);
|
||||
};
|
||||
|
||||
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimUpdateState, u, attach), MakeTupleLambda,
|
||||
IsPrimitiveCNode(attach.GetNode(node), prim::kPrimMakeTuple));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Eliminate UpdateStates for consecutive Loads.
|
||||
AnfNodePtr UpdatestateLoadsEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
auto update_state_node = dyn_cast<CNode>(node);
|
||||
if (update_state_node == nullptr || update_state_node->inputs().empty()) {
|
||||
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
// Merge UpdateStates for Loads.
|
||||
if (attach_is_load || attach_is_tuple) {
|
||||
auto &attach = update_state_node->input(kAttachIndex);
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimLoad) || IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
|
||||
std::vector<CNodePtr> update_states;
|
||||
std::vector<CNodePtr> loads;
|
||||
GetLoadsFromUpdateState(update_state_node, &update_states, &loads);
|
||||
if (update_states.size() > 1 && loads.size() > 1) {
|
||||
return EliminateUpdateStateForLoads(update_state_node, update_states, loads);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
// Eliminate UpdateStates that attaches a no-side-effect node.
|
||||
return EliminateUpdateStateForPureNode(update_state_node, attach);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Eliminate Monad parameter for switch call.
|
||||
|
@ -725,7 +747,7 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
|
|||
// g2 = Partial(..., u)
|
||||
// s = switch(cond, g1, g2)
|
||||
// res = s()
|
||||
AnfNodePtr EliminateMonadParameterForSwitchCall(const AnfNodePtr &node) {
|
||||
AnfNodePtr SwitchCallMonadParameterEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
const CNodePtr &switch_call = dyn_cast<CNode>(node);
|
||||
if (switch_call == nullptr) {
|
||||
return nullptr;
|
||||
|
@ -777,8 +799,4 @@ AnfNodePtr EliminateMonadParameterForSwitchCall(const AnfNodePtr &node) {
|
|||
auto new_switch_call = fg->NewCNode({new_switch_cnode});
|
||||
return new_switch_call;
|
||||
}
|
||||
|
||||
AnfNodePtr SwitchCallMonadParameterEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
return EliminateMonadParameterForSwitchCall(node);
|
||||
}
|
||||
} // namespace mindspore::opt::irpass
|
||||
|
|
|
@ -21,17 +21,44 @@
|
|||
#include "frontend/optimizer/anf_visitor.h"
|
||||
|
||||
namespace mindspore::opt::irpass {
|
||||
//
|
||||
// UpdatestateEliminater eliminates redundant UpdateState related nodes.
|
||||
//
|
||||
class UpdatestateEliminater : public AnfVisitor {
|
||||
// Eliminate useless node that only used by associated update_state.
|
||||
class UpdatestateOnlyUsedNodeEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||
};
|
||||
|
||||
// Eliminate UpdateStates that attaches a no-side-effect node.
|
||||
class UpdatestatePureNodeEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||
};
|
||||
|
||||
// Eliminate redundant UpdateState/Depend pair nodes caused by inline.
|
||||
class UpdatestateDependEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||
};
|
||||
|
||||
// Eliminate UpdateStates between Assign nodes.
|
||||
// Eliminate UpdateStates between Assign and MakeTuple.
|
||||
class UpdatestateAssignEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||
};
|
||||
|
||||
// Eliminate UpdateStates which the second input is MakeTuple.
|
||||
class UpdatestateMakeTupleEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||
};
|
||||
|
||||
// Eliminate UpdateStates for consecutive Loads.
|
||||
class UpdatestateLoadsEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||
};
|
||||
|
||||
//
|
||||
// SwitchCallMonadParameterEliminater eliminates Monad parameter in switch call.
|
||||
//
|
||||
class SwitchCallMonadParameterEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||
|
|
|
@ -30,13 +30,14 @@ namespace mindspore {
|
|||
/* namespace to support opt */
|
||||
namespace opt {
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
|
||||
const RenormAction &renorm_action) {
|
||||
const RenormAction &renorm_action, bool has_priority_pattern) {
|
||||
auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
|
||||
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
|
||||
return std::make_shared<Substitution>(transform, name, fn, renorm_action, has_priority_pattern);
|
||||
}
|
||||
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
|
||||
const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) {
|
||||
const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action,
|
||||
bool has_priority_pattern) {
|
||||
auto fn = [prims](const AnfNodePtr &node) -> bool {
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
|
@ -59,12 +60,13 @@ SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std:
|
|||
return false;
|
||||
};
|
||||
|
||||
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
|
||||
return std::make_shared<Substitution>(transform, name, fn, renorm_action, has_priority_pattern);
|
||||
}
|
||||
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
|
||||
const PredicateFuncType &predicate, const RenormAction &renorm_action) {
|
||||
return std::make_shared<Substitution>(transform, name, predicate, renorm_action);
|
||||
const PredicateFuncType &predicate, const RenormAction &renorm_action,
|
||||
bool has_priority_pattern) {
|
||||
return std::make_shared<Substitution>(transform, name, predicate, renorm_action, has_priority_pattern);
|
||||
}
|
||||
|
||||
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
|
@ -126,16 +128,41 @@ static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &n
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
static void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node, std::deque<AnfNodePtr> *todo,
|
||||
bool change, size_t seen) {
|
||||
static void UpdateTransformingListForSubstitutions(const AnfNodePtr &node, std::deque<AnfNodePtr> *todo, bool change) {
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
(*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
|
||||
}
|
||||
if (node->isa<CNode>()) {
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo));
|
||||
|
||||
if (change) {
|
||||
(*todo).emplace_back(node);
|
||||
} else {
|
||||
if (node->isa<CNode>()) {
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void UpdateTransformingListForIR(const AnfNodePtr &node, std::deque<AnfNodePtr> *todo, bool change,
|
||||
const SubstitutionPtr &substitution) {
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
(*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
|
||||
}
|
||||
|
||||
// If there is a priority pattern in substitution, don't transform the new node,
|
||||
// otherwise some nodes may match the wrong patterns.
|
||||
if (change && substitution != nullptr && !substitution->has_priority_pattern_) {
|
||||
(*todo).emplace_back(node);
|
||||
} else {
|
||||
if (node->isa<CNode>()) {
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void UpdateTransformingListWithUserNodes(const OptimizerPtr &optimizer, const AnfNodePtr &node,
|
||||
std::deque<AnfNodePtr> *todo, bool change, size_t seen) {
|
||||
if (!change) {
|
||||
return;
|
||||
}
|
||||
|
@ -185,11 +212,11 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con
|
|||
change = true;
|
||||
changes = true;
|
||||
node = res;
|
||||
todo.emplace_back(res);
|
||||
break;
|
||||
}
|
||||
}
|
||||
UpdateTransformingList(optimizer, node, &todo, change, seen);
|
||||
UpdateTransformingListForSubstitutions(node, &todo, change);
|
||||
UpdateTransformingListWithUserNodes(optimizer, node, &todo, change, seen);
|
||||
}
|
||||
#ifdef ENABLE_PROFILE
|
||||
MsProfile::StatTime("opt.transforms." + optimizer->name(), GetTime() - start);
|
||||
|
@ -197,7 +224,7 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con
|
|||
return changes;
|
||||
}
|
||||
|
||||
bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &root_node,
|
||||
bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph,
|
||||
const SubstitutionPtr &substitution) const {
|
||||
#ifdef ENABLE_PROFILE
|
||||
double start = GetTime();
|
||||
|
@ -205,7 +232,7 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
|
|||
FuncGraphManagerPtr manager = optimizer->manager();
|
||||
auto seen = NewSeenGeneration();
|
||||
std::deque<AnfNodePtr> todo;
|
||||
todo.emplace_back(root_node);
|
||||
todo.emplace_back(func_graph->output());
|
||||
bool changes = false;
|
||||
|
||||
auto &all_nodes = manager->all_nodes();
|
||||
|
@ -225,7 +252,8 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
|
|||
changes = true;
|
||||
node = res;
|
||||
}
|
||||
UpdateTransformingList(optimizer, node, &todo, change, seen);
|
||||
UpdateTransformingListForIR(node, &todo, change, substitution);
|
||||
UpdateTransformingListWithUserNodes(optimizer, node, &todo, change, seen);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_PROFILE
|
||||
|
@ -268,7 +296,7 @@ bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, con
|
|||
loop = false;
|
||||
for (size_t i = 0; i < list_.size(); i++) {
|
||||
const auto &substitution = list_[i];
|
||||
bool change = ApplySubstitutionToIR(optimizer, func_graph->output(), substitution);
|
||||
bool change = ApplySubstitutionToIR(optimizer, func_graph, substitution);
|
||||
changes = changes || change;
|
||||
loop = loop || change;
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_
|
||||
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -41,11 +42,18 @@ class Substitution {
|
|||
OptimizerCallerPtr transform_;
|
||||
std::string name_;
|
||||
PredicateFuncType predicate_{nullptr};
|
||||
// an enum to mark this Substitution relation to renormalize pass
|
||||
// An enum to mark this Substitution relation to renormalize pass
|
||||
RenormAction renorm_action_;
|
||||
// Determine whether it is a priority substitution, that is, some patterns need to be matched prior to others.
|
||||
bool has_priority_pattern_{false};
|
||||
|
||||
Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate,
|
||||
const RenormAction &renorm_action)
|
||||
: transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {}
|
||||
const RenormAction &renorm_action, bool has_priority_pattern)
|
||||
: transform_(transform),
|
||||
name_(name),
|
||||
predicate_(predicate),
|
||||
renorm_action_(renorm_action),
|
||||
has_priority_pattern_(has_priority_pattern) {}
|
||||
~Substitution() = default;
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node);
|
||||
};
|
||||
|
@ -53,12 +61,13 @@ class Substitution {
|
|||
using SubstitutionPtr = std::shared_ptr<Substitution>;
|
||||
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
|
||||
const RenormAction &action_renorm = CHECK_RENORM);
|
||||
const RenormAction &action_renorm = CHECK_RENORM, bool has_priority_pattern = false);
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
|
||||
const std::vector<PrimitivePtr> &prims,
|
||||
const RenormAction &action_renorm = CHECK_RENORM);
|
||||
const RenormAction &action_renorm = CHECK_RENORM, bool has_priority_pattern = false);
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
|
||||
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);
|
||||
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM,
|
||||
bool has_priority_pattern = false);
|
||||
|
||||
enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptTraverseFromSubstitutionsToIR };
|
||||
|
||||
|
@ -73,15 +82,16 @@ class SubstitutionList {
|
|||
|
||||
private:
|
||||
bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
|
||||
bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &sub) const;
|
||||
bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph,
|
||||
const SubstitutionPtr &sub) const;
|
||||
bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
|
||||
void DisplayStatusOfSubstitution(const std::unordered_map<std::string, std::vector<bool>> &status,
|
||||
const OptimizerPtr &optimizer, size_t space) const;
|
||||
|
||||
std::vector<SubstitutionPtr> list_;
|
||||
// a flag to mark this list of Substitution can only be executed only once
|
||||
bool is_once_;
|
||||
bool global_sensitive_;
|
||||
bool is_once_{false};
|
||||
bool global_sensitive_{false};
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/optimizer/opt.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/irpass/symbol_resolver.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parse {
|
||||
|
@ -306,7 +307,7 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr
|
|||
}
|
||||
|
||||
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
|
||||
const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr) {
|
||||
const SymbolPtr &symbol, const AnfNodePtr &node, const AnfNodePtr &attr) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
|
||||
if (node->func_graph() == nullptr || manager == nullptr) {
|
||||
|
@ -319,14 +320,19 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
|
|||
|
||||
py::object obj = symbol_resolver.result();
|
||||
if (!data_converter::IsCellInstance(obj)) {
|
||||
return nullptr;
|
||||
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
|
||||
AnfNodePtrList inputs = {NewValueNode(prim::kPrimGetAttr), resolved_node, attr};
|
||||
AnfNodePtr res_node = node->func_graph()->NewCNode(inputs);
|
||||
TraceManager::ClearParseOrResolveDebugInfo();
|
||||
return res_node;
|
||||
}
|
||||
|
||||
const std::string fn = PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL;
|
||||
const std::string module = "mindspore._extends.parse.parser";
|
||||
py::object namespace_obj = parse::python_adapter::GetPyFn(module, fn)(obj);
|
||||
auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj);
|
||||
auto new_symbol = std::make_shared<Symbol>(attr);
|
||||
std::string attr_as_string = GetValueNode<StringImmPtr>(attr)->value();
|
||||
auto new_symbol = std::make_shared<Symbol>(attr_as_string);
|
||||
|
||||
AnfNodePtrList inputs = {NewValueNode(prim::kPrimResolve), NewValueNode(new_namespace), NewValueNode(new_symbol)};
|
||||
AnfNodePtr resolved_node = node->func_graph()->NewCNode(inputs);
|
||||
|
@ -336,11 +342,11 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
|
|||
|
||||
namespace {
|
||||
opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
|
||||
// For resolve and getattr primitive.
|
||||
opt::OptPassGroupMap map({
|
||||
{"resolve",
|
||||
{
|
||||
// For resolve and getattr primitive;
|
||||
irpass.resolver_resolve_and_getattr_,
|
||||
irpass.resolver_getattr_resolve_,
|
||||
}},
|
||||
});
|
||||
return map;
|
||||
|
|
|
@ -148,7 +148,7 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr
|
|||
|
||||
// Resolve Cell with attr name.
|
||||
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
|
||||
const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr);
|
||||
const SymbolPtr &symbol, const AnfNodePtr &node, const AnfNodePtr &attr);
|
||||
|
||||
// Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager().
|
||||
bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true);
|
||||
|
|
|
@ -234,7 +234,12 @@ opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
|
||||
// Safe inlining
|
||||
irpass.inline_,
|
||||
irpass.updatestate_eliminater_,
|
||||
irpass.updatestate_depend_eliminater_,
|
||||
irpass.updatestate_assign_eliminater_,
|
||||
irpass.updatestate_maketuple_eliminater_,
|
||||
irpass.updatestate_only_used_node_eliminater_,
|
||||
irpass.updatestate_loads_eliminater_,
|
||||
irpass.updatestate_pure_node_eliminater_,
|
||||
irpass.load_eliminater_,
|
||||
irpass.stopgrad_eliminater_,
|
||||
irpass.partial_eliminate_,
|
||||
|
@ -268,7 +273,12 @@ opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
|
||||
// Safe inlining
|
||||
irpass.inline_,
|
||||
irpass.updatestate_eliminater_,
|
||||
irpass.updatestate_depend_eliminater_,
|
||||
irpass.updatestate_assign_eliminater_,
|
||||
irpass.updatestate_maketuple_eliminater_,
|
||||
irpass.updatestate_only_used_node_eliminater_,
|
||||
irpass.updatestate_loads_eliminater_,
|
||||
irpass.updatestate_pure_node_eliminater_,
|
||||
irpass.load_eliminater_,
|
||||
irpass.stopgrad_eliminater_,
|
||||
irpass.sparse_tensor_eliminate_,
|
||||
|
@ -352,7 +362,12 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp
|
|||
opt::OptPassConfig c_1 = opt::OptPassConfig({
|
||||
// Safe inlining,
|
||||
irpass.inline_,
|
||||
irpass.updatestate_eliminater_,
|
||||
irpass.updatestate_depend_eliminater_,
|
||||
irpass.updatestate_assign_eliminater_,
|
||||
irpass.updatestate_maketuple_eliminater_,
|
||||
irpass.updatestate_only_used_node_eliminater_,
|
||||
irpass.updatestate_loads_eliminater_,
|
||||
irpass.updatestate_pure_node_eliminater_,
|
||||
irpass.load_eliminater_,
|
||||
irpass.switch_call_monad_eliminater_,
|
||||
irpass.stopgrad_eliminater_,
|
||||
|
@ -389,7 +404,12 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.float_tuple_getitem_switch_,
|
||||
irpass.reset_defer_inline_,
|
||||
irpass.inline_,
|
||||
irpass.updatestate_eliminater_,
|
||||
irpass.updatestate_depend_eliminater_,
|
||||
irpass.updatestate_assign_eliminater_,
|
||||
irpass.updatestate_maketuple_eliminater_,
|
||||
irpass.updatestate_only_used_node_eliminater_,
|
||||
irpass.updatestate_loads_eliminater_,
|
||||
irpass.updatestate_pure_node_eliminater_,
|
||||
irpass.load_eliminater_,
|
||||
irpass.stopgrad_eliminater_,
|
||||
irpass.special_op_eliminate_,
|
||||
|
@ -652,10 +672,35 @@ bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph) {
|
|||
res->set_manager(func_graph->manager());
|
||||
|
||||
// opt::irpass::OptimizeIRPassLib is not used here to avoid double free problems in external calls.
|
||||
opt::SubstitutionPtr updatestate_eliminater = opt::MakeSubstitution(
|
||||
std::make_shared<opt::irpass::UpdatestateEliminater>(), "updatestate_eliminater", prim::kPrimUpdateState);
|
||||
opt::SubstitutionPtr updatestate_depend_eliminater =
|
||||
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateDependEliminater>(), "updatestate_depend_eliminater",
|
||||
prim::kPrimUpdateState);
|
||||
opt::SubstitutionPtr updatestate_assign_eliminater =
|
||||
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateAssignEliminater>(), "updatestate_assign_eliminater",
|
||||
prim::kPrimUpdateState);
|
||||
opt::SubstitutionPtr updatestate_maketuple_eliminater =
|
||||
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateMakeTupleEliminater>(),
|
||||
"updatestate_maketuple_eliminater", prim::kPrimUpdateState);
|
||||
opt::SubstitutionPtr updatestate_only_used_node_eliminater =
|
||||
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateOnlyUsedNodeEliminater>(),
|
||||
"updatestate_only_used_node_eliminater", prim::kPrimUpdateState);
|
||||
opt::SubstitutionPtr updatestate_loads_eliminater =
|
||||
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateLoadsEliminater>(), "updatestate_loads_eliminater",
|
||||
prim::kPrimUpdateState);
|
||||
opt::SubstitutionPtr updatestate_pure_node_eliminater =
|
||||
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestatePureNodeEliminater>(),
|
||||
"updatestate_pure_node_eliminater", prim::kPrimUpdateState);
|
||||
|
||||
opt::OptPassConfig updatestate_eliminater = opt::OptPassConfig({
|
||||
updatestate_depend_eliminater,
|
||||
updatestate_assign_eliminater,
|
||||
updatestate_maketuple_eliminater,
|
||||
updatestate_only_used_node_eliminater,
|
||||
updatestate_loads_eliminater,
|
||||
updatestate_pure_node_eliminater,
|
||||
});
|
||||
opt::OptPassGroupMap elim_map({
|
||||
{"updatestate_eliminate", opt::OptPassConfig({updatestate_eliminater})},
|
||||
{"updatestate_eliminater", updatestate_eliminater},
|
||||
{"auto_monad_eliminator", opt::OptPassConfig(opt::AutoMonadEliminator())},
|
||||
});
|
||||
|
||||
|
|
Loading…
Reference in New Issue