1. modify the execution method of substitution in irpass; 2. Refactor resolve pass; 3. Split UpdatestateEliminater into multiple passes.

This commit is contained in:
huangbingjian 2021-07-27 19:06:41 +08:00
parent e187cfc889
commit 5b59d0104f
11 changed files with 379 additions and 246 deletions

View File

@ -186,8 +186,19 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
MakeSubstitution(std::make_shared<SpecializeOnGraphArguments>(), "specialize_transform", IsCNodeGraph);
// UpdateState eliminate
updatestate_eliminater_ =
MakeSubstitution(std::make_shared<UpdatestateEliminater>(), "updatestate_eliminater", prim::kPrimUpdateState);
updatestate_only_used_node_eliminater_ =
MakeSubstitution(std::make_shared<UpdatestateOnlyUsedNodeEliminater>(), "updatestate_only_used_node_eliminater",
prim::kPrimUpdateState);
updatestate_pure_node_eliminater_ = MakeSubstitution(std::make_shared<UpdatestatePureNodeEliminater>(),
"updatestate_pure_node_eliminater", prim::kPrimUpdateState);
updatestate_depend_eliminater_ = MakeSubstitution(std::make_shared<UpdatestateDependEliminater>(),
"updatestate_depend_eliminater", prim::kPrimUpdateState);
updatestate_assign_eliminater_ = MakeSubstitution(std::make_shared<UpdatestateAssignEliminater>(),
"updatestate_assign_eliminater", prim::kPrimUpdateState);
updatestate_maketuple_eliminater_ = MakeSubstitution(std::make_shared<UpdatestateMakeTupleEliminater>(),
"updatestate_maketuple_eliminater", prim::kPrimUpdateState);
updatestate_loads_eliminater_ = MakeSubstitution(std::make_shared<UpdatestateLoadsEliminater>(),
"updatestate_loads_eliminater", prim::kPrimUpdateState);
switch_call_monad_eliminater_ = MakeSubstitution(std::make_shared<SwitchCallMonadParameterEliminater>(),
"switch_call_monad_eliminater", IsCNodeDup);
@ -261,13 +272,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
}
ResolveIRPassLib::ResolveIRPassLib() {
resolver_resolve_and_getattr_ =
MakeSubstitution(std::make_shared<ResolverResolveAndGetAttr>(), "resolver_resolve_and_getattr",
{prim::kPrimGetAttr, prim::kPrimResolve});
resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetAttr>(), "resolver_getattr", prim::kPrimGetAttr);
resolver_getattr_resolve_ =
MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "resolver_getattr_resolve", prim::kPrimGetAttr);
// In resolver_getattr_resolve_, some patterns have priority over others.
resolver_getattr_resolve_ = MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "getattr_resolve",
{prim::kPrimGetAttr, prim::kPrimResolve}, opt::CHECK_RENORM, true);
}
InferenceOptPrepareLib::InferenceOptPrepareLib() {

View File

@ -108,7 +108,12 @@ class OptimizeIRPassLib {
SubstitutionPtr specialize_transform_;
// Auto-monad related eliminaters.
SubstitutionPtr updatestate_eliminater_;
SubstitutionPtr updatestate_only_used_node_eliminater_;
SubstitutionPtr updatestate_pure_node_eliminater_;
SubstitutionPtr updatestate_depend_eliminater_;
SubstitutionPtr updatestate_assign_eliminater_;
SubstitutionPtr updatestate_maketuple_eliminater_;
SubstitutionPtr updatestate_loads_eliminater_;
SubstitutionPtr switch_call_monad_eliminater_;
SubstitutionPtr stopgrad_eliminater_;
SubstitutionPtr load_eliminater_;
@ -166,10 +171,6 @@ class ResolveIRPassLib {
public:
ResolveIRPassLib();
~ResolveIRPassLib() = default;
SubstitutionPtr resolver_resolve_and_getattr_;
SubstitutionPtr resolver_resolve_;
SubstitutionPtr resolver_getattr_;
SubstitutionPtr resolver_getattr_resolve_;
};

View File

@ -0,0 +1,91 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/optimizer/irpass/symbol_resolver.h"
#include <string>
#include <memory>
#include <vector>
namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
// {prim::kPrimGetAttr, namespace, attr}
// {prim::kPrimGetAttr, bool, attr}
// {prim::kPrimResolve, namespace, symbol}
AnfNodePtr ResolverGetAttrResolve::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
constexpr char PARSE_SUPER_NAME[] = "namespace";
constexpr size_t namespace_index = 1;
constexpr size_t symbol_index = 2;
PatternNode<AnfNodePtr> resolve_node, ns_node, sym_node, attr_node, bool_node;
auto GetAttrResolveLambda = [&node, &resolve_node, &attr_node, &optimizer]() -> AnfNodePtr {
auto inner = resolve_node.GetNode(node);
auto attr = attr_node.GetNode(node);
if (IsPrimitiveCNode(inner, prim::kPrimResolve)) {
auto resolve_cnode = inner->cast<CNodePtr>();
auto namespace_node = resolve_cnode->input(namespace_index);
auto symbol_node = resolve_cnode->input(symbol_index);
if (!IsValueNode<parse::NameSpace>(namespace_node) || !IsValueNode<parse::Symbol>(symbol_node)) {
return nullptr;
}
// deal with the case of getting attr from a class member
// and avoid the case of getting attr from self (the result of ParseSuper)
auto ns = GetValueNode<parse::NameSpacePtr>(namespace_node);
auto sym = GetValueNode<parse::SymbolPtr>(symbol_node);
if (ns->module() == parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && sym->symbol() != PARSE_SUPER_NAME) {
return parse::ResolveCellwithAttr(optimizer->manager(), ns, sym, inner, attr);
}
}
return nullptr;
};
auto GetAttrLambda = [&node, &ns_node, &attr_node, &optimizer]() -> AnfNodePtr {
auto ns = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
auto str = GetValue<std::string>(GetValueNode(attr_node.GetNode(node)));
parse::SymbolPtr sym = std::make_shared<parse::Symbol>(str);
return parse::ResolveSymbol(optimizer->manager(), ns, sym, node);
};
auto ResolveLambda = [&node, &ns_node, &sym_node, &optimizer]() -> AnfNodePtr {
auto ns = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
auto sym = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node));
auto manager = optimizer->manager();
return parse::ResolveSymbol(manager, ns, sym, node);
};
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetAttr, resolve_node, attr_node), GetAttrResolveLambda,
attr_node.CheckFunc(IsValueNode<StringImm>, node));
// {prim::kPrimGetAttr, namespace, attr}
MATCH_REPLACE_LAMBDA_IF(
node, PPrimitive(prim::kPrimGetAttr, ns_node, attr_node), GetAttrLambda,
ns_node.CheckFunc(IsValueNode<parse::NameSpace>, node) && attr_node.CheckFunc(IsValueNode<StringImm>, node));
// {prim::kPrimGetAttr, bool, attr}
MATCH_REPLACE_IF(
node, PPrimitive(prim::kPrimGetAttr, bool_node, attr_node), bool_node,
bool_node.CheckFunc(IsValueNode<BoolImm>, node) && attr_node.CheckFunc(IsValueNode<StringImm>, node));
// {prim::kPrimResolve, namespace, symbol}
MATCH_REPLACE_LAMBDA_IF(
node, PPrimitive(prim::kPrimResolve, ns_node, sym_node), ResolveLambda,
ns_node.CheckFunc(IsValueNode<parse::NameSpace>, node) && sym_node.CheckFunc(IsValueNode<parse::Symbol>, node));
return nullptr;
}
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -34,117 +34,17 @@
namespace mindspore {
namespace opt {
namespace irpass {
const char PARSE_SUPER_NAME[] = "namespace";
// {prim::kPrimResolve, Ns, Sym}
class ResolverResolve : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimResolve, {IsVNode, IsVNode})(node);
if (sym_ != nullptr) {
return parse::ResolveSymbol(optimizer->manager(), ns_, sym_, node);
}
return nullptr;
}
void Visit(const ValueNodePtr &vnode) override {
if (IsValueNode<parse::NameSpace>(vnode)) {
ns_ = GetValueNode<parse::NameSpacePtr>(vnode);
} else if (ns_ != nullptr && IsValueNode<parse::Symbol>(vnode)) {
sym_ = GetValueNode<parse::SymbolPtr>(vnode);
}
}
void Reset() {
ns_ = nullptr;
sym_ = nullptr;
}
private:
parse::NameSpacePtr ns_{nullptr};
parse::SymbolPtr sym_{nullptr};
};
// {prim::kPrimGetAttr, Ns, Str}
class ResolverGetAttr : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimGetAttr, {IsVNode, IsVNode})(node);
if (sym_ != nullptr) {
return parse::ResolveSymbol(optimizer->manager(), ns_, sym_, node);
}
return nullptr;
}
void Visit(const AnfNodePtr &node) override {
if (IsValueNode<parse::NameSpace>(node)) {
ns_ = GetValueNode<parse::NameSpacePtr>(node);
} else if (ns_ != nullptr && IsValueNode<StringImm>(node)) {
auto str = GetValue<std::string>(GetValueNode(node));
sym_ = std::make_shared<parse::Symbol>(str);
}
}
void Reset() {
ns_ = nullptr;
sym_ = nullptr;
}
private:
parse::NameSpacePtr ns_{nullptr};
parse::SymbolPtr sym_{nullptr};
};
// {prim::kPrimGetAttr, {prim::kPrimResolve, ns_node, sym_node}, attr_node}
// Put GetAttr pattern and Resolve pattern together to ensure that GetAttr pattern always takes precedence over Resolve
// pattern. After matching GetAttr pattern, there may be new nodes that can match GetAttr pattern and Resolve pattern.
// The same is true for matching Resolve pattern.
//
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
// {prim::kPrimGetAttr, namespace, attr}
// {prim::kPrimGetAttr, bool, attr}
// {prim::kPrimResolve, namespace, symbol}
class ResolverGetAttrResolve : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> ns_node, sym_node, attr_node;
auto ResolveAttrLambda = [&node, &ns_node, &sym_node, &attr_node, &optimizer]() -> AnfNodePtr {
auto node_to_getattr = node->cast<CNodePtr>()->input(1);
std::string attr_as_string = GetValueNode<StringImmPtr>(attr_node.GetNode(node))->value();
auto ns_ = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
auto sym_ = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node));
if (ns_->module() == parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && sym_->symbol() != PARSE_SUPER_NAME) {
// deal with the case of getting attr from a class member
// and avoid the case of getting attr from self (the result of ParseSuper)
auto result = parse::ResolveCellwithAttr(optimizer->manager(), ns_, sym_, node_to_getattr, attr_as_string);
return result;
}
return nullptr;
};
MATCH_REPLACE_LAMBDA_IF(
node, PPrimitive(prim::kPrimGetAttr, PPrimitive(prim::kPrimResolve, ns_node, sym_node), attr_node),
ResolveAttrLambda, attr_node.CheckFunc(IsValueNode<StringImm>, node));
return nullptr;
}
};
class ResolverResolveAndGetAttr : public OptimizerCaller {
public:
ResolverResolveAndGetAttr() {
resolver_optimizers_ = {std::make_shared<ResolverGetAttrResolve>(), std::make_shared<ResolverResolve>(),
std::make_shared<ResolverGetAttr>()};
}
virtual ~ResolverResolveAndGetAttr() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (const auto &resolver_opt : resolver_optimizers_) {
new_node = (*resolver_opt)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
}
return nullptr;
}
private:
std::vector<OptimizerCallerPtr> resolver_optimizers_{};
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
};
} // namespace irpass
} // namespace opt

View File

@ -22,6 +22,10 @@
#include <vector>
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/pattern_matcher.h"
namespace mindspore::opt::irpass {
namespace {
@ -81,35 +85,7 @@ bool OnlyUsedByTwoNode(const AnfNodePtr &be_used_node, const AnfNodePtr &first_n
(first_user == second_node && second_user == first_node);
}
// Eliminate useless node that only used by associated update_state.
// Convert:
// x1 = node(x, u)
// u1 = update_state(u, x1) # update_state is the only user of node
// user(u1)
// To:
// user(u)
AnfNodePtr EliminateUpdateStateOnlyUsedNode(const CNodePtr &update_state, const AnfNodePtr &node) {
if (!OnlyUsedByOneNode(node, update_state)) {
// Skip if UpdateState is not the only user of cnode.
return nullptr;
}
// Replace UpdateState with the input monad.
return update_state->input(kInputIndex);
}
// Eliminate UpdateState that attaches a pure (no-side-effect) node.
// Convert:
// x = pure_node(args) # no side effect
// u1 = update_state(u, x)
// user(u1)
// To:
// x = pure_node(args)
// user(u)
AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const AnfNodePtr &attach) {
if (IsPrimitiveCNode(attach, prim::kPrimTupleGetItem)) {
// Skip tuple_getitem.
return nullptr;
}
auto cnode = dyn_cast<CNode>(attach);
if (cnode == nullptr) {
// Skip value node or parameter.
@ -122,26 +98,11 @@ AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const A
return nullptr;
}
}
// Skip Call/Switch/SwitchLayer.
auto first_input_node = cnode->input(kFirstInputIndex);
if (IsPrimitiveCNode(first_input_node, prim::kPrimCall) || IsPrimitiveCNode(first_input_node, prim::kPrimSwitch) ||
IsPrimitiveCNode(first_input_node, prim::kPrimSwitchLayer)) {
return nullptr;
}
// Remove UpdateState by replace it with its input monad.
return update_state->input(kInputIndex);
}
// Eliminate redundant UpdateState/Depend pair nodes caused by inline.
// Convert:
// x1 = Depend(x, u)
// u1 = UpdateState(u, x1)
// out = x_user(x1)
// u2 = u_user(u1)
// To:
// out = x_user(x)
// u2 = u_user(u)
AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CNodePtr &depend) {
auto input_monad = depend->inputs().back();
if (!HasAbstractMonad(input_monad)) {
@ -638,28 +599,86 @@ AnfNodePtr EliminateUpdateStateBetweenAssignMakeTuple(const CNodePtr &update_sta
}
return nullptr;
}
} // namespace
AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
// Eliminate useless node that only used by associated update_state.
// {prim::kPrimUpdateState, u, {prim::kPrimLoad, m, u}} -> u
// {prim::kPrimUpdateState, u, {prim::kPrimPartial, m, u}} -> u
// Convert:
// x1 = node(x, u)
// u1 = update_state(u, x1) # update_state is the only user of x1.
// user(u1)
// To:
// user(u)
AnfNodePtr UpdatestateOnlyUsedNodeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
auto update_state_node = dyn_cast<CNode>(node);
if (update_state_node == nullptr || update_state_node->inputs().empty()) {
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
return nullptr;
}
auto &attach = update_state_node->input(kAttachIndex);
if (IsPrimitiveCNode(attach, prim::kPrimPartial) || IsPrimitiveCNode(attach, prim::kPrimLoad)) {
// Replace UpdateState with the input monad.
if (OnlyUsedByOneNode(attach, update_state_node)) {
return update_state_node->input(kInputIndex);
}
}
return nullptr;
}
// Handle UpdateState(u, Depend(...)).
// Eliminate UpdateState that attaches a pure (no-side-effect) node.
// Convert:
// x = pure_node(args) # no side effect
// u1 = update_state(u, x)
// user(u1)
// To:
// x = pure_node(args)
// user(u)
AnfNodePtr UpdatestatePureNodeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
auto update_state_node = dyn_cast<CNode>(node);
if (update_state_node == nullptr || update_state_node->inputs().empty()) {
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
return nullptr;
}
auto &attach = update_state_node->input(kAttachIndex);
if (IsPrimitiveCNode(attach, prim::kPrimTupleGetItem) || IsPrimitiveCNode(attach, prim::kPrimDepend) ||
IsPrimitiveCNode(attach, prim::kPrimPartial) || IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
return nullptr;
}
return EliminateUpdateStateForPureNode(update_state_node, attach);
}
// Eliminate redundant UpdateState/Depend pair nodes caused by inline.
// Convert:
// x1 = Depend(x, u)
// u1 = UpdateState(u, x1)
// out = x_user(x1)
// u2 = u_user(u1)
// To:
// out = x_user(x)
// u2 = u_user(u)
AnfNodePtr UpdatestateDependEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
auto update_state_node = dyn_cast<CNode>(node);
if (update_state_node == nullptr || update_state_node->inputs().empty()) {
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
return nullptr;
}
auto &attach = update_state_node->input(kAttachIndex);
if (IsPrimitiveCNode(attach, prim::kPrimDepend)) {
return EliminateUpdateStateWithDepend(update_state_node, attach->cast<CNodePtr>());
}
return nullptr;
}
// Handle UpdateState(u, Partial(...)).
if (IsPrimitiveCNode(attach, prim::kPrimPartial)) {
return EliminateUpdateStateOnlyUsedNode(update_state_node, attach);
// Eliminate UpdateStates between Assign nodes.
// Eliminate UpdateStates between Assign and MakeTuple.
AnfNodePtr UpdatestateAssignEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
auto update_state_node = dyn_cast<CNode>(node);
if (update_state_node == nullptr || update_state_node->inputs().empty()) {
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
return nullptr;
}
// Handle UpdateState(u, Assign(...)).
auto &attach = update_state_node->input(kAttachIndex);
if (IsPrimitiveCNode(attach, prim::kPrimAssign)) {
auto new_node = EliminateUpdateStateBetweenAssigns(update_state_node, attach);
if (new_node != nullptr) {
@ -667,20 +686,15 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
}
return EliminateUpdateStateBetweenMakeTupleAssign(update_state_node, attach);
}
return nullptr;
}
// Handle UpdateState(u, Load(...)).
const bool attach_is_load = IsPrimitiveCNode(attach, prim::kPrimLoad);
if (attach_is_load) {
auto new_node = EliminateUpdateStateOnlyUsedNode(update_state_node, attach);
if (new_node != nullptr) {
return new_node;
}
}
// Handle UpdateState(u, MakeTuple(...)).
const bool attach_is_tuple = IsPrimitiveCNode(attach, prim::kPrimMakeTuple);
if (attach_is_tuple) {
auto make_tuple = attach->cast<CNodePtr>();
// Eliminate UpdateStates which the second input is MakeTuple.
AnfNodePtr UpdatestateMakeTupleEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
PatternNode<AnfNodePtr> u, attach;
auto MakeTupleLambda = [&node, &u, &attach]() -> AnfNodePtr {
auto update_state_node = node->cast<CNodePtr>();
auto make_tuple = attach.GetNode(node)->cast<CNodePtr>();
auto new_node = EliminateMakeTupleWithDeadNode(update_state_node, make_tuple);
if (new_node != nullptr) {
return new_node;
@ -689,23 +703,31 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
if (new_node != nullptr) {
return new_node;
}
new_node = EliminateUpdateStateBetweenAssignMakeTuple(update_state_node, make_tuple);
if (new_node != nullptr) {
return new_node;
}
return EliminateUpdateStateBetweenAssignMakeTuple(update_state_node, make_tuple);
};
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimUpdateState, u, attach), MakeTupleLambda,
IsPrimitiveCNode(attach.GetNode(node), prim::kPrimMakeTuple));
return nullptr;
}
// Eliminate UpdateStates for consecutive Loads.
AnfNodePtr UpdatestateLoadsEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
auto update_state_node = dyn_cast<CNode>(node);
if (update_state_node == nullptr || update_state_node->inputs().empty()) {
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
return nullptr;
}
// Merge UpdateStates for Loads.
if (attach_is_load || attach_is_tuple) {
auto &attach = update_state_node->input(kAttachIndex);
if (IsPrimitiveCNode(attach, prim::kPrimLoad) || IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
std::vector<CNodePtr> update_states;
std::vector<CNodePtr> loads;
GetLoadsFromUpdateState(update_state_node, &update_states, &loads);
if (update_states.size() > 1 && loads.size() > 1) {
return EliminateUpdateStateForLoads(update_state_node, update_states, loads);
}
return nullptr;
}
// Eliminate UpdateStates that attaches a no-side-effect node.
return EliminateUpdateStateForPureNode(update_state_node, attach);
return nullptr;
}
// Eliminate Monad parameter for switch call.
@ -725,7 +747,7 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
// g2 = Partial(..., u)
// s = switch(cond, g1, g2)
// res = s()
AnfNodePtr EliminateMonadParameterForSwitchCall(const AnfNodePtr &node) {
AnfNodePtr SwitchCallMonadParameterEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
const CNodePtr &switch_call = dyn_cast<CNode>(node);
if (switch_call == nullptr) {
return nullptr;
@ -777,8 +799,4 @@ AnfNodePtr EliminateMonadParameterForSwitchCall(const AnfNodePtr &node) {
auto new_switch_call = fg->NewCNode({new_switch_cnode});
return new_switch_call;
}
AnfNodePtr SwitchCallMonadParameterEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
return EliminateMonadParameterForSwitchCall(node);
}
} // namespace mindspore::opt::irpass

View File

@ -21,17 +21,44 @@
#include "frontend/optimizer/anf_visitor.h"
namespace mindspore::opt::irpass {
//
// UpdatestateEliminater eliminates redundant UpdateState related nodes.
//
class UpdatestateEliminater : public AnfVisitor {
// Eliminate useless node that only used by associated update_state.
class UpdatestateOnlyUsedNodeEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
// Eliminate UpdateStates that attaches a no-side-effect node.
class UpdatestatePureNodeEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
// Eliminate redundant UpdateState/Depend pair nodes caused by inline.
class UpdatestateDependEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
// Eliminate UpdateStates between Assign nodes.
// Eliminate UpdateStates between Assign and MakeTuple.
class UpdatestateAssignEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
// Eliminate UpdateStates which the second input is MakeTuple.
class UpdatestateMakeTupleEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
// Eliminate UpdateStates for consecutive Loads.
class UpdatestateLoadsEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
//
// SwitchCallMonadParameterEliminater eliminates Monad parameter in switch call.
//
class SwitchCallMonadParameterEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;

View File

@ -30,13 +30,14 @@ namespace mindspore {
/* namespace to support opt */
namespace opt {
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &renorm_action) {
const RenormAction &renorm_action, bool has_priority_pattern) {
auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
return std::make_shared<Substitution>(transform, name, fn, renorm_action, has_priority_pattern);
}
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) {
const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action,
bool has_priority_pattern) {
auto fn = [prims](const AnfNodePtr &node) -> bool {
if (!node->isa<CNode>()) {
return false;
@ -59,12 +60,13 @@ SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std:
return false;
};
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
return std::make_shared<Substitution>(transform, name, fn, renorm_action, has_priority_pattern);
}
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &renorm_action) {
return std::make_shared<Substitution>(transform, name, predicate, renorm_action);
const PredicateFuncType &predicate, const RenormAction &renorm_action,
bool has_priority_pattern) {
return std::make_shared<Substitution>(transform, name, predicate, renorm_action, has_priority_pattern);
}
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
@ -126,16 +128,41 @@ static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &n
return nullptr;
}
static void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node, std::deque<AnfNodePtr> *todo,
bool change, size_t seen) {
static void UpdateTransformingListForSubstitutions(const AnfNodePtr &node, std::deque<AnfNodePtr> *todo, bool change) {
if (IsValueNode<FuncGraph>(node)) {
(*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
}
if (node->isa<CNode>()) {
auto &inputs = node->cast<CNodePtr>()->inputs();
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo));
if (change) {
(*todo).emplace_back(node);
} else {
if (node->isa<CNode>()) {
auto &inputs = node->cast<CNodePtr>()->inputs();
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo));
}
}
}
static void UpdateTransformingListForIR(const AnfNodePtr &node, std::deque<AnfNodePtr> *todo, bool change,
const SubstitutionPtr &substitution) {
if (IsValueNode<FuncGraph>(node)) {
(*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
}
// If there is a priority pattern in substitution, don't transform the new node,
// otherwise some nodes may match the wrong patterns.
if (change && substitution != nullptr && !substitution->has_priority_pattern_) {
(*todo).emplace_back(node);
} else {
if (node->isa<CNode>()) {
auto &inputs = node->cast<CNodePtr>()->inputs();
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo));
}
}
}
static void UpdateTransformingListWithUserNodes(const OptimizerPtr &optimizer, const AnfNodePtr &node,
std::deque<AnfNodePtr> *todo, bool change, size_t seen) {
if (!change) {
return;
}
@ -185,11 +212,11 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con
change = true;
changes = true;
node = res;
todo.emplace_back(res);
break;
}
}
UpdateTransformingList(optimizer, node, &todo, change, seen);
UpdateTransformingListForSubstitutions(node, &todo, change);
UpdateTransformingListWithUserNodes(optimizer, node, &todo, change, seen);
}
#ifdef ENABLE_PROFILE
MsProfile::StatTime("opt.transforms." + optimizer->name(), GetTime() - start);
@ -197,7 +224,7 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con
return changes;
}
bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &root_node,
bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph,
const SubstitutionPtr &substitution) const {
#ifdef ENABLE_PROFILE
double start = GetTime();
@ -205,7 +232,7 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
FuncGraphManagerPtr manager = optimizer->manager();
auto seen = NewSeenGeneration();
std::deque<AnfNodePtr> todo;
todo.emplace_back(root_node);
todo.emplace_back(func_graph->output());
bool changes = false;
auto &all_nodes = manager->all_nodes();
@ -225,7 +252,8 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
changes = true;
node = res;
}
UpdateTransformingList(optimizer, node, &todo, change, seen);
UpdateTransformingListForIR(node, &todo, change, substitution);
UpdateTransformingListWithUserNodes(optimizer, node, &todo, change, seen);
}
#ifdef ENABLE_PROFILE
@ -268,7 +296,7 @@ bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, con
loop = false;
for (size_t i = 0; i < list_.size(); i++) {
const auto &substitution = list_[i];
bool change = ApplySubstitutionToIR(optimizer, func_graph->output(), substitution);
bool change = ApplySubstitutionToIR(optimizer, func_graph, substitution);
changes = changes || change;
loop = loop || change;

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_OPT_H_
#include <deque>
#include <memory>
#include <string>
#include <vector>
@ -41,11 +42,18 @@ class Substitution {
OptimizerCallerPtr transform_;
std::string name_;
PredicateFuncType predicate_{nullptr};
// an enum to mark this Substitution relation to renormalize pass
// An enum to mark this Substitution relation to renormalize pass
RenormAction renorm_action_;
// Determine whether it is a priority substitution, that is, some patterns need to be matched prior to others.
bool has_priority_pattern_{false};
Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate,
const RenormAction &renorm_action)
: transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {}
const RenormAction &renorm_action, bool has_priority_pattern)
: transform_(transform),
name_(name),
predicate_(predicate),
renorm_action_(renorm_action),
has_priority_pattern_(has_priority_pattern) {}
~Substitution() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node);
};
@ -53,12 +61,13 @@ class Substitution {
using SubstitutionPtr = std::shared_ptr<Substitution>;
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &action_renorm = CHECK_RENORM);
const RenormAction &action_renorm = CHECK_RENORM, bool has_priority_pattern = false);
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims,
const RenormAction &action_renorm = CHECK_RENORM);
const RenormAction &action_renorm = CHECK_RENORM, bool has_priority_pattern = false);
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM,
bool has_priority_pattern = false);
enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptTraverseFromSubstitutionsToIR };
@ -73,15 +82,16 @@ class SubstitutionList {
private:
bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &sub) const;
bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph,
const SubstitutionPtr &sub) const;
bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
void DisplayStatusOfSubstitution(const std::unordered_map<std::string, std::vector<bool>> &status,
const OptimizerPtr &optimizer, size_t space) const;
std::vector<SubstitutionPtr> list_;
// a flag to mark this list of Substitution can only be executed only once
bool is_once_;
bool global_sensitive_;
bool is_once_{false};
bool global_sensitive_{false};
};
} // namespace opt
} // namespace mindspore

View File

@ -28,6 +28,7 @@
#include "frontend/operator/ops.h"
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/symbol_resolver.h"
namespace mindspore {
namespace parse {
@ -306,7 +307,7 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr
}
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr) {
const SymbolPtr &symbol, const AnfNodePtr &node, const AnfNodePtr &attr) {
MS_EXCEPTION_IF_NULL(node);
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
if (node->func_graph() == nullptr || manager == nullptr) {
@ -319,14 +320,19 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
py::object obj = symbol_resolver.result();
if (!data_converter::IsCellInstance(obj)) {
return nullptr;
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
AnfNodePtrList inputs = {NewValueNode(prim::kPrimGetAttr), resolved_node, attr};
AnfNodePtr res_node = node->func_graph()->NewCNode(inputs);
TraceManager::ClearParseOrResolveDebugInfo();
return res_node;
}
const std::string fn = PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL;
const std::string module = "mindspore._extends.parse.parser";
py::object namespace_obj = parse::python_adapter::GetPyFn(module, fn)(obj);
auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj);
auto new_symbol = std::make_shared<Symbol>(attr);
std::string attr_as_string = GetValueNode<StringImmPtr>(attr)->value();
auto new_symbol = std::make_shared<Symbol>(attr_as_string);
AnfNodePtrList inputs = {NewValueNode(prim::kPrimResolve), NewValueNode(new_namespace), NewValueNode(new_symbol)};
AnfNodePtr resolved_node = node->func_graph()->NewCNode(inputs);
@ -336,11 +342,11 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
namespace {
opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
// For resolve and getattr primitive.
opt::OptPassGroupMap map({
{"resolve",
{
// For resolve and getattr primitive;
irpass.resolver_resolve_and_getattr_,
irpass.resolver_getattr_resolve_,
}},
});
return map;

View File

@ -148,7 +148,7 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr
// Resolve Cell with attr name.
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr);
const SymbolPtr &symbol, const AnfNodePtr &node, const AnfNodePtr &attr);
// Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager().
bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true);

View File

@ -234,7 +234,12 @@ opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
// Safe inlining
irpass.inline_,
irpass.updatestate_eliminater_,
irpass.updatestate_depend_eliminater_,
irpass.updatestate_assign_eliminater_,
irpass.updatestate_maketuple_eliminater_,
irpass.updatestate_only_used_node_eliminater_,
irpass.updatestate_loads_eliminater_,
irpass.updatestate_pure_node_eliminater_,
irpass.load_eliminater_,
irpass.stopgrad_eliminater_,
irpass.partial_eliminate_,
@ -268,7 +273,12 @@ opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
// Safe inlining
irpass.inline_,
irpass.updatestate_eliminater_,
irpass.updatestate_depend_eliminater_,
irpass.updatestate_assign_eliminater_,
irpass.updatestate_maketuple_eliminater_,
irpass.updatestate_only_used_node_eliminater_,
irpass.updatestate_loads_eliminater_,
irpass.updatestate_pure_node_eliminater_,
irpass.load_eliminater_,
irpass.stopgrad_eliminater_,
irpass.sparse_tensor_eliminate_,
@ -352,7 +362,12 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp
opt::OptPassConfig c_1 = opt::OptPassConfig({
// Safe inlining,
irpass.inline_,
irpass.updatestate_eliminater_,
irpass.updatestate_depend_eliminater_,
irpass.updatestate_assign_eliminater_,
irpass.updatestate_maketuple_eliminater_,
irpass.updatestate_only_used_node_eliminater_,
irpass.updatestate_loads_eliminater_,
irpass.updatestate_pure_node_eliminater_,
irpass.load_eliminater_,
irpass.switch_call_monad_eliminater_,
irpass.stopgrad_eliminater_,
@ -389,7 +404,12 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_,
irpass.inline_,
irpass.updatestate_eliminater_,
irpass.updatestate_depend_eliminater_,
irpass.updatestate_assign_eliminater_,
irpass.updatestate_maketuple_eliminater_,
irpass.updatestate_only_used_node_eliminater_,
irpass.updatestate_loads_eliminater_,
irpass.updatestate_pure_node_eliminater_,
irpass.load_eliminater_,
irpass.stopgrad_eliminater_,
irpass.special_op_eliminate_,
@ -652,10 +672,35 @@ bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph) {
res->set_manager(func_graph->manager());
// opt::irpass::OptimizeIRPassLib is not used here to avoid double free problems in external calls.
opt::SubstitutionPtr updatestate_eliminater = opt::MakeSubstitution(
std::make_shared<opt::irpass::UpdatestateEliminater>(), "updatestate_eliminater", prim::kPrimUpdateState);
opt::SubstitutionPtr updatestate_depend_eliminater =
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateDependEliminater>(), "updatestate_depend_eliminater",
prim::kPrimUpdateState);
opt::SubstitutionPtr updatestate_assign_eliminater =
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateAssignEliminater>(), "updatestate_assign_eliminater",
prim::kPrimUpdateState);
opt::SubstitutionPtr updatestate_maketuple_eliminater =
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateMakeTupleEliminater>(),
"updatestate_maketuple_eliminater", prim::kPrimUpdateState);
opt::SubstitutionPtr updatestate_only_used_node_eliminater =
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateOnlyUsedNodeEliminater>(),
"updatestate_only_used_node_eliminater", prim::kPrimUpdateState);
opt::SubstitutionPtr updatestate_loads_eliminater =
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateLoadsEliminater>(), "updatestate_loads_eliminater",
prim::kPrimUpdateState);
opt::SubstitutionPtr updatestate_pure_node_eliminater =
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestatePureNodeEliminater>(),
"updatestate_pure_node_eliminater", prim::kPrimUpdateState);
opt::OptPassConfig updatestate_eliminater = opt::OptPassConfig({
updatestate_depend_eliminater,
updatestate_assign_eliminater,
updatestate_maketuple_eliminater,
updatestate_only_used_node_eliminater,
updatestate_loads_eliminater,
updatestate_pure_node_eliminater,
});
opt::OptPassGroupMap elim_map({
{"updatestate_eliminate", opt::OptPassConfig({updatestate_eliminater})},
{"updatestate_eliminater", updatestate_eliminater},
{"auto_monad_eliminator", opt::OptPassConfig(opt::AutoMonadEliminator())},
});