handle replace in updatestate_eliminate

This commit is contained in:
huangbingjian 2021-08-13 00:38:52 +08:00
parent a480d07dd2
commit 88059b76ee
8 changed files with 98 additions and 87 deletions

View File

@ -146,7 +146,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
{prim::kPrimGetRefKey, prim::kPrimGetRefValue});
replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
IsValueNode<RefKey>, opt::FORCE_RENORM);
IsValueNode<RefKey>, false, opt::FORCE_RENORM);
replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam);
minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem);
@ -192,13 +192,14 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
updatestate_pure_node_eliminater_ = MakeSubstitution(std::make_shared<UpdatestatePureNodeEliminater>(),
"updatestate_pure_node_eliminater", prim::kPrimUpdateState);
updatestate_depend_eliminater_ = MakeSubstitution(std::make_shared<UpdatestateDependEliminater>(),
"updatestate_depend_eliminater", prim::kPrimUpdateState);
"updatestate_depend_eliminater", prim::kPrimUpdateState, true);
updatestate_assign_eliminater_ = MakeSubstitution(std::make_shared<UpdatestateAssignEliminater>(),
"updatestate_assign_eliminater", prim::kPrimUpdateState);
updatestate_maketuple_eliminater_ = MakeSubstitution(std::make_shared<UpdatestateMakeTupleEliminater>(),
"updatestate_maketuple_eliminater", prim::kPrimUpdateState);
"updatestate_assign_eliminater", prim::kPrimUpdateState, true);
updatestate_maketuple_eliminater_ =
MakeSubstitution(std::make_shared<UpdatestateMakeTupleEliminater>(), "updatestate_maketuple_eliminater",
prim::kPrimUpdateState, true);
updatestate_loads_eliminater_ = MakeSubstitution(std::make_shared<UpdatestateLoadsEliminater>(),
"updatestate_loads_eliminater", prim::kPrimUpdateState);
"updatestate_loads_eliminater", prim::kPrimUpdateState, true);
switch_call_monad_eliminater_ = MakeSubstitution(std::make_shared<SwitchCallMonadParameterEliminater>(),
"switch_call_monad_eliminater", IsCNodeDup);
@ -273,8 +274,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
ResolveIRPassLib::ResolveIRPassLib() {
// In resolver_getattr_resolve_, some patterns have priority over others.
resolver_getattr_resolve_ = MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "getattr_resolve",
{prim::kPrimGetAttr, prim::kPrimResolve}, opt::CHECK_RENORM, true);
resolver_getattr_resolve_ =
MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "getattr_resolve",
{prim::kPrimGetAttr, prim::kPrimResolve}, false, opt::CHECK_RENORM, true);
}
InferenceOptPrepareLib::InferenceOptPrepareLib() {

View File

@ -103,7 +103,8 @@ AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const A
return update_state->input(kInputIndex);
}
AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CNodePtr &depend) {
AnfNodePtr EliminateUpdateStateWithDepend(const OptimizerPtr &optimizer, const CNodePtr &update_state,
const CNodePtr &depend) {
auto input_monad = depend->inputs().back();
if (!HasAbstractMonad(input_monad)) {
// Skip if Depend attach input is not a monad.
@ -131,7 +132,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN
// Replace Depend with its input.
if (depend->size() == kMinDependSize) {
auto depend_input = depend->input(kInputIndex);
mgr->Replace(depend, depend_input);
optimizer->SubstitutionReplace(mgr, depend, depend_input);
} else {
auto inputs = depend->inputs();
inputs.pop_back();
@ -139,7 +140,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN
MS_EXCEPTION_IF_NULL(fg);
auto new_depend = fg->NewCNode(inputs);
new_depend->set_abstract(depend->abstract());
mgr->Replace(depend, new_depend);
optimizer->SubstitutionReplace(mgr, depend, new_depend);
}
// Replace UpdateState node with the input monad of Depend.
return input_monad;
@ -297,7 +298,7 @@ AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_upd
}
// Remove all nodes related to UpdateStates, if they're redundant.
void EliminateUselessNodesForUpdateStates(const std::vector<CNodePtr> &update_states) {
void EliminateUselessNodesForUpdateStates(const OptimizerPtr &optimizer, const std::vector<CNodePtr> &update_states) {
if (update_states.empty()) {
return;
}
@ -309,7 +310,7 @@ void EliminateUselessNodesForUpdateStates(const std::vector<CNodePtr> &update_st
// 1. Remove the use of UpdateState nodes, except the last one.
for (auto i = update_states.size() - 1; i > 0; i--) {
auto &us = update_states[i];
mgr->Replace(us, us->input(kInputIndex));
optimizer->SubstitutionReplace(mgr, us, us->input(kInputIndex));
}
// 2. Remove the Depend users of last UpdateState node.
@ -343,7 +344,7 @@ void EliminateUselessNodesForUpdateStates(const std::vector<CNodePtr> &update_st
for (ssize_t i = depend_nodes.size() - 1; i >= end; i--) {
const auto &depend_node = depend_nodes[i];
const auto &depend_cnode = depend_node->cast<CNodePtr>();
mgr->Replace(depend_cnode, depend_cnode->input(kInputIndex));
optimizer->SubstitutionReplace(mgr, depend_cnode, depend_cnode->input(kInputIndex));
}
}
@ -363,7 +364,8 @@ void EliminateUselessNodesForUpdateStates(const std::vector<CNodePtr> &update_st
// xN = Load(xN, u)
// t = make_tuple(x1, x2, ... , xN)
// u1 = UpdateState(u, t)
AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &update_states,
AnfNodePtr EliminateUpdateStateForLoads(const OptimizerPtr &optimizer, const CNodePtr &old_update_state,
const std::vector<CNodePtr> &update_states,
const std::vector<CNodePtr> &loads) {
auto fg = old_update_state->func_graph();
if (fg == nullptr) {
@ -393,7 +395,7 @@ AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const
}
}
EliminateUselessNodesForUpdateStates(update_states);
EliminateUselessNodesForUpdateStates(optimizer, update_states);
if (make_tuple_inputs.size() == 1) {
// This should not happen.
@ -420,7 +422,8 @@ AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const
// a2 = Assign(para2, value2, u1)
// t = MakeTuple(a1, a2)
// u3 = UpdateState(u1, t)
AnfNodePtr EliminateUpdateStateBetweenAssigns(const CNodePtr &update_state, const AnfNodePtr &assign) {
AnfNodePtr EliminateUpdateStateBetweenAssigns(const OptimizerPtr &optimizer, const CNodePtr &update_state,
const AnfNodePtr &assign) {
auto a2_cnode = assign->cast<CNodePtr>();
if (a2_cnode->size() != kAssignSize) {
return nullptr;
@ -444,7 +447,7 @@ AnfNodePtr EliminateUpdateStateBetweenAssigns(const CNodePtr &update_state, cons
MS_EXCEPTION_IF_NULL(fg);
auto mgr = fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
mgr->Replace(u2, u1);
optimizer->SubstitutionReplace(mgr, u2, u1);
AnfNodePtrList make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), a1, assign};
auto make_tuple = MakeTupleForSameNodes(fg, update_state, make_tuple_inputs);
auto new_update_state = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u1, make_tuple});
@ -472,7 +475,8 @@ AnfNodePtr EliminateUpdateStateBetweenAssigns(const CNodePtr &update_state, cons
// a3 = Assign(para3, value3, u1)
// t = MakeTuple(a1, a2, a3)
// u4 = UpdateState(u1, t)
AnfNodePtr EliminateUpdateStateBetweenMakeTupleAssign(const CNodePtr &update_state, const AnfNodePtr &assign) {
AnfNodePtr EliminateUpdateStateBetweenMakeTupleAssign(const OptimizerPtr &optimizer, const CNodePtr &update_state,
const AnfNodePtr &assign) {
auto a3_cnode = assign->cast<CNodePtr>();
if (a3_cnode->size() != kAssignSize) {
return nullptr;
@ -509,11 +513,11 @@ AnfNodePtr EliminateUpdateStateBetweenMakeTupleAssign(const CNodePtr &update_sta
MS_EXCEPTION_IF_NULL(fg);
auto mgr = fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
mgr->Replace(u3, u1);
optimizer->SubstitutionReplace(mgr, u3, u1);
AnfNodePtrList new_make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), make_tuple_cnode->input(kInputIndex),
make_tuple_cnode->input(kAttachIndex), assign};
auto new_make_tuple = MakeTupleForSameNodes(fg, update_state, new_make_tuple_inputs);
mgr->Replace(make_tuple, new_make_tuple);
optimizer->SubstitutionReplace(mgr, make_tuple, new_make_tuple);
auto new_update_state = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u1, new_make_tuple});
new_update_state->set_abstract(update_state->abstract());
new_update_state->set_scope(update_state->scope());
@ -540,7 +544,8 @@ AnfNodePtr EliminateUpdateStateBetweenMakeTupleAssign(const CNodePtr &update_sta
// a3 = Assign(para3, value3, u1)
// t = MakeTuple(a1, a2, a3)
// u4 = UpdateState(u1, t)
AnfNodePtr EliminateUpdateStateBetweenAssignMakeTuple(const CNodePtr &update_state, const AnfNodePtr &make_tuple) {
AnfNodePtr EliminateUpdateStateBetweenAssignMakeTuple(const OptimizerPtr &optimizer, const CNodePtr &update_state,
const AnfNodePtr &make_tuple) {
auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
if (make_tuple_cnode->size() != kMakeTupleSize || !OnlyUsedByOneNode(make_tuple, update_state)) {
return nullptr;
@ -583,12 +588,12 @@ AnfNodePtr EliminateUpdateStateBetweenAssignMakeTuple(const CNodePtr &update_sta
MS_EXCEPTION_IF_NULL(fg);
auto mgr = fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
mgr->Replace(u2, u1);
optimizer->SubstitutionReplace(mgr, u2, u1);
AnfNodePtrList new_make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), a1,
make_tuple_cnode->input(kInputIndex),
make_tuple_cnode->input(kAttachIndex)};
auto new_make_tuple = MakeTupleForSameNodes(fg, update_state, new_make_tuple_inputs);
mgr->Replace(make_tuple, new_make_tuple);
optimizer->SubstitutionReplace(mgr, make_tuple, new_make_tuple);
auto new_update_state = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u1, new_make_tuple});
new_update_state->set_abstract(update_state->abstract());
new_update_state->set_scope(update_state->scope());
@ -657,7 +662,7 @@ AnfNodePtr UpdatestatePureNodeEliminater::operator()(const OptimizerPtr &, const
// To:
// out = x_user(x)
// u2 = u_user(u)
AnfNodePtr UpdatestateDependEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
AnfNodePtr UpdatestateDependEliminater::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
auto update_state_node = dyn_cast<CNode>(node);
if (update_state_node == nullptr || update_state_node->inputs().empty()) {
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
@ -665,14 +670,14 @@ AnfNodePtr UpdatestateDependEliminater::operator()(const OptimizerPtr &, const A
}
auto &attach = update_state_node->input(kAttachIndex);
if (IsPrimitiveCNode(attach, prim::kPrimDepend)) {
return EliminateUpdateStateWithDepend(update_state_node, attach->cast<CNodePtr>());
return EliminateUpdateStateWithDepend(optimizer, update_state_node, attach->cast<CNodePtr>());
}
return nullptr;
}
// Eliminate UpdateStates between Assign nodes.
// Eliminate UpdateStates between Assign and MakeTuple.
AnfNodePtr UpdatestateAssignEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
AnfNodePtr UpdatestateAssignEliminater::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
auto update_state_node = dyn_cast<CNode>(node);
if (update_state_node == nullptr || update_state_node->inputs().empty()) {
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
@ -680,19 +685,19 @@ AnfNodePtr UpdatestateAssignEliminater::operator()(const OptimizerPtr &, const A
}
auto &attach = update_state_node->input(kAttachIndex);
if (IsPrimitiveCNode(attach, prim::kPrimAssign)) {
auto new_node = EliminateUpdateStateBetweenAssigns(update_state_node, attach);
auto new_node = EliminateUpdateStateBetweenAssigns(optimizer, update_state_node, attach);
if (new_node != nullptr) {
return new_node;
}
return EliminateUpdateStateBetweenMakeTupleAssign(update_state_node, attach);
return EliminateUpdateStateBetweenMakeTupleAssign(optimizer, update_state_node, attach);
}
return nullptr;
}
// Eliminate UpdateStates which the second input is MakeTuple.
AnfNodePtr UpdatestateMakeTupleEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
AnfNodePtr UpdatestateMakeTupleEliminater::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
PatternNode<AnfNodePtr> u, attach;
auto MakeTupleLambda = [&node, &u, &attach]() -> AnfNodePtr {
auto MakeTupleLambda = [&optimizer, &node, &u, &attach]() -> AnfNodePtr {
auto update_state_node = node->cast<CNodePtr>();
auto make_tuple = attach.GetNode(node)->cast<CNodePtr>();
auto new_node = EliminateMakeTupleWithDeadNode(update_state_node, make_tuple);
@ -703,7 +708,7 @@ AnfNodePtr UpdatestateMakeTupleEliminater::operator()(const OptimizerPtr &, cons
if (new_node != nullptr) {
return new_node;
}
return EliminateUpdateStateBetweenAssignMakeTuple(update_state_node, make_tuple);
return EliminateUpdateStateBetweenAssignMakeTuple(optimizer, update_state_node, make_tuple);
};
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimUpdateState, u, attach), MakeTupleLambda,
@ -712,7 +717,7 @@ AnfNodePtr UpdatestateMakeTupleEliminater::operator()(const OptimizerPtr &, cons
}
// Eliminate UpdateStates for consecutive Loads.
AnfNodePtr UpdatestateLoadsEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
AnfNodePtr UpdatestateLoadsEliminater::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
auto update_state_node = dyn_cast<CNode>(node);
if (update_state_node == nullptr || update_state_node->inputs().empty()) {
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
@ -724,7 +729,7 @@ AnfNodePtr UpdatestateLoadsEliminater::operator()(const OptimizerPtr &, const An
std::vector<CNodePtr> loads;
GetLoadsFromUpdateState(update_state_node, &update_states, &loads);
if (update_states.size() > 1 && loads.size() > 1) {
return EliminateUpdateStateForLoads(update_state_node, update_states, loads);
return EliminateUpdateStateForLoads(optimizer, update_state_node, update_states, loads);
}
}
return nullptr;

View File

@ -36,26 +36,26 @@ class UpdatestatePureNodeEliminater : public AnfVisitor {
// Eliminate redundant UpdateState/Depend pair nodes caused by inline.
class UpdatestateDependEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
};
// Eliminate UpdateStates between Assign nodes.
// Eliminate UpdateStates between Assign and MakeTuple.
class UpdatestateAssignEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
};
// Eliminate UpdateStates which the second input is MakeTuple.
class UpdatestateMakeTupleEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
};
// Eliminate UpdateStates for consecutive Loads.
class UpdatestateLoadsEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
};
// SwitchCallMonadParameterEliminater eliminates Monad parameter in switch call.

View File

@ -30,14 +30,15 @@ namespace mindspore {
/* namespace to support opt */
namespace opt {
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &renorm_action, bool has_priority_pattern) {
bool has_node_replacement, const RenormAction &renorm_action,
bool has_priority_pattern) {
auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
return std::make_shared<Substitution>(transform, name, fn, renorm_action, has_priority_pattern);
return std::make_shared<Substitution>(transform, name, fn, has_node_replacement, renorm_action, has_priority_pattern);
}
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action,
bool has_priority_pattern) {
const std::vector<PrimitivePtr> &prims, bool has_node_replacement,
const RenormAction &renorm_action, bool has_priority_pattern) {
auto fn = [prims](const AnfNodePtr &node) -> bool {
if (!node->isa<CNode>()) {
return false;
@ -60,13 +61,14 @@ SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std:
return false;
};
return std::make_shared<Substitution>(transform, name, fn, renorm_action, has_priority_pattern);
return std::make_shared<Substitution>(transform, name, fn, has_node_replacement, renorm_action, has_priority_pattern);
}
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &renorm_action,
bool has_priority_pattern) {
return std::make_shared<Substitution>(transform, name, predicate, renorm_action, has_priority_pattern);
const PredicateFuncType &predicate, bool has_node_replacement,
const RenormAction &renorm_action, bool has_priority_pattern) {
return std::make_shared<Substitution>(transform, name, predicate, has_node_replacement, renorm_action,
has_priority_pattern);
}
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
@ -212,6 +214,14 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con
change = true;
changes = true;
node = res;
// If there is a node replacement in the substitution, add replaced nodes to todo list.
if (substitution->has_node_replacement_) {
for (auto &replaced_node : optimizer->substitution_replaced_nodes()) {
UpdateTransformingListForSubstitutions(replaced_node, &todo, change);
UpdateTransformingListWithUserNodes(optimizer, replaced_node, &todo, change, seen);
}
optimizer->clear_substitution_replaced_nodes();
}
break;
}
}
@ -251,6 +261,14 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
change = true;
changes = true;
node = res;
// If there is a node replacement in the substitution, add replaced nodes to todo list.
if (substitution->has_node_replacement_) {
for (auto &replaced_node : optimizer->substitution_replaced_nodes()) {
UpdateTransformingListForIR(replaced_node, &todo, change, substitution);
UpdateTransformingListWithUserNodes(optimizer, replaced_node, &todo, change, seen);
}
optimizer->clear_substitution_replaced_nodes();
}
}
UpdateTransformingListForIR(node, &todo, change, substitution);
UpdateTransformingListWithUserNodes(optimizer, node, &todo, change, seen);

View File

@ -42,16 +42,19 @@ class Substitution {
OptimizerCallerPtr transform_;
std::string name_;
PredicateFuncType predicate_{nullptr};
// An enum to mark this Substitution relation to renormalize pass
// Determine whether there is a node replacement in the substitution, such as manager->Replace(old_node, new_node).
bool has_node_replacement_{false};
// An enum to mark this Substitution relation to renormalize pass.
RenormAction renorm_action_;
// Determine whether it is a priority substitution, that is, some patterns need to be matched prior to others.
bool has_priority_pattern_{false};
Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate,
const RenormAction &renorm_action, bool has_priority_pattern)
bool has_node_replacement, const RenormAction &renorm_action, bool has_priority_pattern)
: transform_(transform),
name_(name),
predicate_(predicate),
has_node_replacement_(has_node_replacement),
renorm_action_(renorm_action),
has_priority_pattern_(has_priority_pattern) {}
~Substitution() = default;
@ -61,13 +64,14 @@ class Substitution {
using SubstitutionPtr = std::shared_ptr<Substitution>;
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &action_renorm = CHECK_RENORM, bool has_priority_pattern = false);
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims,
const RenormAction &action_renorm = CHECK_RENORM, bool has_priority_pattern = false);
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM,
bool has_node_replacement = false, const RenormAction &action_renorm = CHECK_RENORM,
bool has_priority_pattern = false);
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims, bool has_node_replacement = false,
const RenormAction &action_renorm = CHECK_RENORM, bool has_priority_pattern = false);
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, bool has_node_replacement = false,
const RenormAction &action_renorm = CHECK_RENORM, bool has_priority_pattern = false);
enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptTraverseFromSubstitutionsToIR };

View File

@ -226,6 +226,15 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
MS_LOG(EXCEPTION) << "No ResourceBase exists.";
}
// Only for the case that manager->replace() has to be called in substitution. This interface can only be used in
// substitution. Note that it is not recommended to replace nodes other than the input node in substitution.
void SubstitutionReplace(const FuncGraphManagerPtr &manager, const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
manager->Replace(old_node, new_node);
substitution_replaced_nodes_.emplace_back(new_node);
}
std::vector<AnfNodePtr> substitution_replaced_nodes() const { return substitution_replaced_nodes_; }
void clear_substitution_replaced_nodes() { substitution_replaced_nodes_.clear(); }
const std::string name() const { return name_; }
void set_is_untyped_generated() { is_untyped_generated_ = true; }
@ -250,6 +259,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
pipeline::ResourceBasePtr resource_;
std::vector<OptPass> passes_;
std::vector<std::string> pass_names_;
std::vector<AnfNodePtr> substitution_replaced_nodes_;
bool run_only_once_;
bool is_watch_renormalize_;
bool is_enable_;

View File

@ -686,19 +686,19 @@ bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph) {
// opt::irpass::OptimizeIRPassLib is not used here to avoid double free problems in external calls.
opt::SubstitutionPtr updatestate_depend_eliminater =
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateDependEliminater>(), "updatestate_depend_eliminater",
prim::kPrimUpdateState);
prim::kPrimUpdateState, true);
opt::SubstitutionPtr updatestate_assign_eliminater =
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateAssignEliminater>(), "updatestate_assign_eliminater",
prim::kPrimUpdateState);
prim::kPrimUpdateState, true);
opt::SubstitutionPtr updatestate_maketuple_eliminater =
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateMakeTupleEliminater>(),
"updatestate_maketuple_eliminater", prim::kPrimUpdateState);
"updatestate_maketuple_eliminater", prim::kPrimUpdateState, true);
opt::SubstitutionPtr updatestate_only_used_node_eliminater =
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateOnlyUsedNodeEliminater>(),
"updatestate_only_used_node_eliminater", prim::kPrimUpdateState);
opt::SubstitutionPtr updatestate_loads_eliminater =
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateLoadsEliminater>(), "updatestate_loads_eliminater",
prim::kPrimUpdateState);
prim::kPrimUpdateState, true);
opt::SubstitutionPtr updatestate_pure_node_eliminater =
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestatePureNodeEliminater>(),
"updatestate_pure_node_eliminater", prim::kPrimUpdateState);

View File

@ -21,7 +21,7 @@ import mindspore as ms
import mindspore.ops.operations as P
import mindspore.nn as nn
from mindspore.nn import Cell
from mindspore.nn import ReLU, BatchNorm2d, Conv2d, Dense, PReLU, ParameterUpdate
from mindspore.nn import ReLU, BatchNorm2d, Conv2d, ParameterUpdate
from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits
from mindspore import context, Tensor
from mindspore.common.parameter import Parameter
@ -1220,27 +1220,6 @@ def read_file():
return content
# Net contain Prelu,BN,Conv,Dense which have weight value
class NetRrelu(Cell):
def __init__(self, in_channel, out_channel):
super().__init__()
self.relu = PReLU(channel=in_channel, w=0.25)
self.bn = BatchNorm2d(num_features=in_channel)
self.conv = Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=2, stride=1, has_bias=False,
weight_init='ones', pad_mode='same')
self.mean = P.ReduceMean(keep_dims=False)
self.fc = Dense(in_channels=out_channel, out_channels=out_channel,
weight_init='ones', bias_init='zeros', has_bias=True)
def construct(self, x):
x = self.relu(x)
x = self.bn(x)
x = self.conv(x)
x = self.mean(x, (2, 3))
x = self.fc(x)
return x
def check_keep_batchnorm_fp32_false(kwargs, level):
if ms.context.get_context("device_target") == "GPU":
if level == "O2":
@ -1274,13 +1253,6 @@ def use_build_train_network_check_cast_num(network, level, inputs, label, cast_n
return out_me
def test_auto_mixed_precision_train_prelunet(with_save_graphs):
net2 = NetRrelu(3, 12)
input32 = Tensor(np.ones([1, 3, 2, 2]).astype(np.float32))
label32 = Tensor(np.zeros([1, 12]).astype(np.float32))
use_build_train_network_check_cast_num(net2, "O2", input32, label32, 16)
class AssignNet(Cell):
def __init__(self):
super().__init__()