forked from mindspore-Ecosystem/mindspore
handle replace in updatestate_eliminate
This commit is contained in:
parent
a480d07dd2
commit
88059b76ee
|
@ -146,7 +146,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
{prim::kPrimGetRefKey, prim::kPrimGetRefValue});
|
||||
|
||||
replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
|
||||
IsValueNode<RefKey>, opt::FORCE_RENORM);
|
||||
IsValueNode<RefKey>, false, opt::FORCE_RENORM);
|
||||
replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam);
|
||||
minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem);
|
||||
|
||||
|
@ -192,13 +192,14 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
updatestate_pure_node_eliminater_ = MakeSubstitution(std::make_shared<UpdatestatePureNodeEliminater>(),
|
||||
"updatestate_pure_node_eliminater", prim::kPrimUpdateState);
|
||||
updatestate_depend_eliminater_ = MakeSubstitution(std::make_shared<UpdatestateDependEliminater>(),
|
||||
"updatestate_depend_eliminater", prim::kPrimUpdateState);
|
||||
"updatestate_depend_eliminater", prim::kPrimUpdateState, true);
|
||||
updatestate_assign_eliminater_ = MakeSubstitution(std::make_shared<UpdatestateAssignEliminater>(),
|
||||
"updatestate_assign_eliminater", prim::kPrimUpdateState);
|
||||
updatestate_maketuple_eliminater_ = MakeSubstitution(std::make_shared<UpdatestateMakeTupleEliminater>(),
|
||||
"updatestate_maketuple_eliminater", prim::kPrimUpdateState);
|
||||
"updatestate_assign_eliminater", prim::kPrimUpdateState, true);
|
||||
updatestate_maketuple_eliminater_ =
|
||||
MakeSubstitution(std::make_shared<UpdatestateMakeTupleEliminater>(), "updatestate_maketuple_eliminater",
|
||||
prim::kPrimUpdateState, true);
|
||||
updatestate_loads_eliminater_ = MakeSubstitution(std::make_shared<UpdatestateLoadsEliminater>(),
|
||||
"updatestate_loads_eliminater", prim::kPrimUpdateState);
|
||||
"updatestate_loads_eliminater", prim::kPrimUpdateState, true);
|
||||
switch_call_monad_eliminater_ = MakeSubstitution(std::make_shared<SwitchCallMonadParameterEliminater>(),
|
||||
"switch_call_monad_eliminater", IsCNodeDup);
|
||||
|
||||
|
@ -273,8 +274,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
|
||||
ResolveIRPassLib::ResolveIRPassLib() {
|
||||
// In resolver_getattr_resolve_, some patterns have priority over others.
|
||||
resolver_getattr_resolve_ = MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "getattr_resolve",
|
||||
{prim::kPrimGetAttr, prim::kPrimResolve}, opt::CHECK_RENORM, true);
|
||||
resolver_getattr_resolve_ =
|
||||
MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "getattr_resolve",
|
||||
{prim::kPrimGetAttr, prim::kPrimResolve}, false, opt::CHECK_RENORM, true);
|
||||
}
|
||||
|
||||
InferenceOptPrepareLib::InferenceOptPrepareLib() {
|
||||
|
|
|
@ -103,7 +103,8 @@ AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const A
|
|||
return update_state->input(kInputIndex);
|
||||
}
|
||||
|
||||
AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CNodePtr &depend) {
|
||||
AnfNodePtr EliminateUpdateStateWithDepend(const OptimizerPtr &optimizer, const CNodePtr &update_state,
|
||||
const CNodePtr &depend) {
|
||||
auto input_monad = depend->inputs().back();
|
||||
if (!HasAbstractMonad(input_monad)) {
|
||||
// Skip if Depend attach input is not a monad.
|
||||
|
@ -131,7 +132,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN
|
|||
// Replace Depend with its input.
|
||||
if (depend->size() == kMinDependSize) {
|
||||
auto depend_input = depend->input(kInputIndex);
|
||||
mgr->Replace(depend, depend_input);
|
||||
optimizer->SubstitutionReplace(mgr, depend, depend_input);
|
||||
} else {
|
||||
auto inputs = depend->inputs();
|
||||
inputs.pop_back();
|
||||
|
@ -139,7 +140,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN
|
|||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto new_depend = fg->NewCNode(inputs);
|
||||
new_depend->set_abstract(depend->abstract());
|
||||
mgr->Replace(depend, new_depend);
|
||||
optimizer->SubstitutionReplace(mgr, depend, new_depend);
|
||||
}
|
||||
// Replace UpdateState node with the input monad of Depend.
|
||||
return input_monad;
|
||||
|
@ -297,7 +298,7 @@ AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_upd
|
|||
}
|
||||
|
||||
// Remove all nodes related to UpdateStates, if they're redundant.
|
||||
void EliminateUselessNodesForUpdateStates(const std::vector<CNodePtr> &update_states) {
|
||||
void EliminateUselessNodesForUpdateStates(const OptimizerPtr &optimizer, const std::vector<CNodePtr> &update_states) {
|
||||
if (update_states.empty()) {
|
||||
return;
|
||||
}
|
||||
|
@ -309,7 +310,7 @@ void EliminateUselessNodesForUpdateStates(const std::vector<CNodePtr> &update_st
|
|||
// 1. Remove the use of UpdateState nodes, except the last one.
|
||||
for (auto i = update_states.size() - 1; i > 0; i--) {
|
||||
auto &us = update_states[i];
|
||||
mgr->Replace(us, us->input(kInputIndex));
|
||||
optimizer->SubstitutionReplace(mgr, us, us->input(kInputIndex));
|
||||
}
|
||||
|
||||
// 2. Remove the Depend users of last UpdateState node.
|
||||
|
@ -343,7 +344,7 @@ void EliminateUselessNodesForUpdateStates(const std::vector<CNodePtr> &update_st
|
|||
for (ssize_t i = depend_nodes.size() - 1; i >= end; i--) {
|
||||
const auto &depend_node = depend_nodes[i];
|
||||
const auto &depend_cnode = depend_node->cast<CNodePtr>();
|
||||
mgr->Replace(depend_cnode, depend_cnode->input(kInputIndex));
|
||||
optimizer->SubstitutionReplace(mgr, depend_cnode, depend_cnode->input(kInputIndex));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -363,7 +364,8 @@ void EliminateUselessNodesForUpdateStates(const std::vector<CNodePtr> &update_st
|
|||
// xN = Load(xN, u)
|
||||
// t = make_tuple(x1, x2, ... , xN)
|
||||
// u1 = UpdateState(u, t)
|
||||
AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &update_states,
|
||||
AnfNodePtr EliminateUpdateStateForLoads(const OptimizerPtr &optimizer, const CNodePtr &old_update_state,
|
||||
const std::vector<CNodePtr> &update_states,
|
||||
const std::vector<CNodePtr> &loads) {
|
||||
auto fg = old_update_state->func_graph();
|
||||
if (fg == nullptr) {
|
||||
|
@ -393,7 +395,7 @@ AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const
|
|||
}
|
||||
}
|
||||
|
||||
EliminateUselessNodesForUpdateStates(update_states);
|
||||
EliminateUselessNodesForUpdateStates(optimizer, update_states);
|
||||
|
||||
if (make_tuple_inputs.size() == 1) {
|
||||
// This should not happen.
|
||||
|
@ -420,7 +422,8 @@ AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const
|
|||
// a2 = Assign(para2, value2, u1)
|
||||
// t = MakeTuple(a1, a2)
|
||||
// u3 = UpdateState(u1, t)
|
||||
AnfNodePtr EliminateUpdateStateBetweenAssigns(const CNodePtr &update_state, const AnfNodePtr &assign) {
|
||||
AnfNodePtr EliminateUpdateStateBetweenAssigns(const OptimizerPtr &optimizer, const CNodePtr &update_state,
|
||||
const AnfNodePtr &assign) {
|
||||
auto a2_cnode = assign->cast<CNodePtr>();
|
||||
if (a2_cnode->size() != kAssignSize) {
|
||||
return nullptr;
|
||||
|
@ -444,7 +447,7 @@ AnfNodePtr EliminateUpdateStateBetweenAssigns(const CNodePtr &update_state, cons
|
|||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto mgr = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mgr);
|
||||
mgr->Replace(u2, u1);
|
||||
optimizer->SubstitutionReplace(mgr, u2, u1);
|
||||
AnfNodePtrList make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), a1, assign};
|
||||
auto make_tuple = MakeTupleForSameNodes(fg, update_state, make_tuple_inputs);
|
||||
auto new_update_state = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u1, make_tuple});
|
||||
|
@ -472,7 +475,8 @@ AnfNodePtr EliminateUpdateStateBetweenAssigns(const CNodePtr &update_state, cons
|
|||
// a3 = Assign(para3, value3, u1)
|
||||
// t = MakeTuple(a1, a2, a3)
|
||||
// u4 = UpdateState(u1, t)
|
||||
AnfNodePtr EliminateUpdateStateBetweenMakeTupleAssign(const CNodePtr &update_state, const AnfNodePtr &assign) {
|
||||
AnfNodePtr EliminateUpdateStateBetweenMakeTupleAssign(const OptimizerPtr &optimizer, const CNodePtr &update_state,
|
||||
const AnfNodePtr &assign) {
|
||||
auto a3_cnode = assign->cast<CNodePtr>();
|
||||
if (a3_cnode->size() != kAssignSize) {
|
||||
return nullptr;
|
||||
|
@ -509,11 +513,11 @@ AnfNodePtr EliminateUpdateStateBetweenMakeTupleAssign(const CNodePtr &update_sta
|
|||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto mgr = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mgr);
|
||||
mgr->Replace(u3, u1);
|
||||
optimizer->SubstitutionReplace(mgr, u3, u1);
|
||||
AnfNodePtrList new_make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), make_tuple_cnode->input(kInputIndex),
|
||||
make_tuple_cnode->input(kAttachIndex), assign};
|
||||
auto new_make_tuple = MakeTupleForSameNodes(fg, update_state, new_make_tuple_inputs);
|
||||
mgr->Replace(make_tuple, new_make_tuple);
|
||||
optimizer->SubstitutionReplace(mgr, make_tuple, new_make_tuple);
|
||||
auto new_update_state = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u1, new_make_tuple});
|
||||
new_update_state->set_abstract(update_state->abstract());
|
||||
new_update_state->set_scope(update_state->scope());
|
||||
|
@ -540,7 +544,8 @@ AnfNodePtr EliminateUpdateStateBetweenMakeTupleAssign(const CNodePtr &update_sta
|
|||
// a3 = Assign(para3, value3, u1)
|
||||
// t = MakeTuple(a1, a2, a3)
|
||||
// u4 = UpdateState(u1, t)
|
||||
AnfNodePtr EliminateUpdateStateBetweenAssignMakeTuple(const CNodePtr &update_state, const AnfNodePtr &make_tuple) {
|
||||
AnfNodePtr EliminateUpdateStateBetweenAssignMakeTuple(const OptimizerPtr &optimizer, const CNodePtr &update_state,
|
||||
const AnfNodePtr &make_tuple) {
|
||||
auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
|
||||
if (make_tuple_cnode->size() != kMakeTupleSize || !OnlyUsedByOneNode(make_tuple, update_state)) {
|
||||
return nullptr;
|
||||
|
@ -583,12 +588,12 @@ AnfNodePtr EliminateUpdateStateBetweenAssignMakeTuple(const CNodePtr &update_sta
|
|||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto mgr = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mgr);
|
||||
mgr->Replace(u2, u1);
|
||||
optimizer->SubstitutionReplace(mgr, u2, u1);
|
||||
AnfNodePtrList new_make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), a1,
|
||||
make_tuple_cnode->input(kInputIndex),
|
||||
make_tuple_cnode->input(kAttachIndex)};
|
||||
auto new_make_tuple = MakeTupleForSameNodes(fg, update_state, new_make_tuple_inputs);
|
||||
mgr->Replace(make_tuple, new_make_tuple);
|
||||
optimizer->SubstitutionReplace(mgr, make_tuple, new_make_tuple);
|
||||
auto new_update_state = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u1, new_make_tuple});
|
||||
new_update_state->set_abstract(update_state->abstract());
|
||||
new_update_state->set_scope(update_state->scope());
|
||||
|
@ -657,7 +662,7 @@ AnfNodePtr UpdatestatePureNodeEliminater::operator()(const OptimizerPtr &, const
|
|||
// To:
|
||||
// out = x_user(x)
|
||||
// u2 = u_user(u)
|
||||
AnfNodePtr UpdatestateDependEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
AnfNodePtr UpdatestateDependEliminater::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
auto update_state_node = dyn_cast<CNode>(node);
|
||||
if (update_state_node == nullptr || update_state_node->inputs().empty()) {
|
||||
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
|
||||
|
@ -665,14 +670,14 @@ AnfNodePtr UpdatestateDependEliminater::operator()(const OptimizerPtr &, const A
|
|||
}
|
||||
auto &attach = update_state_node->input(kAttachIndex);
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimDepend)) {
|
||||
return EliminateUpdateStateWithDepend(update_state_node, attach->cast<CNodePtr>());
|
||||
return EliminateUpdateStateWithDepend(optimizer, update_state_node, attach->cast<CNodePtr>());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Eliminate UpdateStates between Assign nodes.
|
||||
// Eliminate UpdateStates between Assign and MakeTuple.
|
||||
AnfNodePtr UpdatestateAssignEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
AnfNodePtr UpdatestateAssignEliminater::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
auto update_state_node = dyn_cast<CNode>(node);
|
||||
if (update_state_node == nullptr || update_state_node->inputs().empty()) {
|
||||
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
|
||||
|
@ -680,19 +685,19 @@ AnfNodePtr UpdatestateAssignEliminater::operator()(const OptimizerPtr &, const A
|
|||
}
|
||||
auto &attach = update_state_node->input(kAttachIndex);
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimAssign)) {
|
||||
auto new_node = EliminateUpdateStateBetweenAssigns(update_state_node, attach);
|
||||
auto new_node = EliminateUpdateStateBetweenAssigns(optimizer, update_state_node, attach);
|
||||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
}
|
||||
return EliminateUpdateStateBetweenMakeTupleAssign(update_state_node, attach);
|
||||
return EliminateUpdateStateBetweenMakeTupleAssign(optimizer, update_state_node, attach);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Eliminate UpdateStates which the second input is MakeTuple.
|
||||
AnfNodePtr UpdatestateMakeTupleEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
AnfNodePtr UpdatestateMakeTupleEliminater::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
PatternNode<AnfNodePtr> u, attach;
|
||||
auto MakeTupleLambda = [&node, &u, &attach]() -> AnfNodePtr {
|
||||
auto MakeTupleLambda = [&optimizer, &node, &u, &attach]() -> AnfNodePtr {
|
||||
auto update_state_node = node->cast<CNodePtr>();
|
||||
auto make_tuple = attach.GetNode(node)->cast<CNodePtr>();
|
||||
auto new_node = EliminateMakeTupleWithDeadNode(update_state_node, make_tuple);
|
||||
|
@ -703,7 +708,7 @@ AnfNodePtr UpdatestateMakeTupleEliminater::operator()(const OptimizerPtr &, cons
|
|||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
}
|
||||
return EliminateUpdateStateBetweenAssignMakeTuple(update_state_node, make_tuple);
|
||||
return EliminateUpdateStateBetweenAssignMakeTuple(optimizer, update_state_node, make_tuple);
|
||||
};
|
||||
|
||||
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimUpdateState, u, attach), MakeTupleLambda,
|
||||
|
@ -712,7 +717,7 @@ AnfNodePtr UpdatestateMakeTupleEliminater::operator()(const OptimizerPtr &, cons
|
|||
}
|
||||
|
||||
// Eliminate UpdateStates for consecutive Loads.
|
||||
AnfNodePtr UpdatestateLoadsEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
AnfNodePtr UpdatestateLoadsEliminater::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
auto update_state_node = dyn_cast<CNode>(node);
|
||||
if (update_state_node == nullptr || update_state_node->inputs().empty()) {
|
||||
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
|
||||
|
@ -724,7 +729,7 @@ AnfNodePtr UpdatestateLoadsEliminater::operator()(const OptimizerPtr &, const An
|
|||
std::vector<CNodePtr> loads;
|
||||
GetLoadsFromUpdateState(update_state_node, &update_states, &loads);
|
||||
if (update_states.size() > 1 && loads.size() > 1) {
|
||||
return EliminateUpdateStateForLoads(update_state_node, update_states, loads);
|
||||
return EliminateUpdateStateForLoads(optimizer, update_state_node, update_states, loads);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
|
|
|
@ -36,26 +36,26 @@ class UpdatestatePureNodeEliminater : public AnfVisitor {
|
|||
// Eliminate redundant UpdateState/Depend pair nodes caused by inline.
|
||||
class UpdatestateDependEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
|
||||
};
|
||||
|
||||
// Eliminate UpdateStates between Assign nodes.
|
||||
// Eliminate UpdateStates between Assign and MakeTuple.
|
||||
class UpdatestateAssignEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
|
||||
};
|
||||
|
||||
// Eliminate UpdateStates which the second input is MakeTuple.
|
||||
class UpdatestateMakeTupleEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
|
||||
};
|
||||
|
||||
// Eliminate UpdateStates for consecutive Loads.
|
||||
class UpdatestateLoadsEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
|
||||
};
|
||||
|
||||
// SwitchCallMonadParameterEliminater eliminates Monad parameter in switch call.
|
||||
|
|
|
@ -30,14 +30,15 @@ namespace mindspore {
|
|||
/* namespace to support opt */
|
||||
namespace opt {
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
|
||||
const RenormAction &renorm_action, bool has_priority_pattern) {
|
||||
bool has_node_replacement, const RenormAction &renorm_action,
|
||||
bool has_priority_pattern) {
|
||||
auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
|
||||
return std::make_shared<Substitution>(transform, name, fn, renorm_action, has_priority_pattern);
|
||||
return std::make_shared<Substitution>(transform, name, fn, has_node_replacement, renorm_action, has_priority_pattern);
|
||||
}
|
||||
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
|
||||
const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action,
|
||||
bool has_priority_pattern) {
|
||||
const std::vector<PrimitivePtr> &prims, bool has_node_replacement,
|
||||
const RenormAction &renorm_action, bool has_priority_pattern) {
|
||||
auto fn = [prims](const AnfNodePtr &node) -> bool {
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
|
@ -60,13 +61,14 @@ SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std:
|
|||
return false;
|
||||
};
|
||||
|
||||
return std::make_shared<Substitution>(transform, name, fn, renorm_action, has_priority_pattern);
|
||||
return std::make_shared<Substitution>(transform, name, fn, has_node_replacement, renorm_action, has_priority_pattern);
|
||||
}
|
||||
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
|
||||
const PredicateFuncType &predicate, const RenormAction &renorm_action,
|
||||
bool has_priority_pattern) {
|
||||
return std::make_shared<Substitution>(transform, name, predicate, renorm_action, has_priority_pattern);
|
||||
const PredicateFuncType &predicate, bool has_node_replacement,
|
||||
const RenormAction &renorm_action, bool has_priority_pattern) {
|
||||
return std::make_shared<Substitution>(transform, name, predicate, has_node_replacement, renorm_action,
|
||||
has_priority_pattern);
|
||||
}
|
||||
|
||||
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
|
@ -212,6 +214,14 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con
|
|||
change = true;
|
||||
changes = true;
|
||||
node = res;
|
||||
// If there is a node replacement in the substitution, add replaced nodes to todo list.
|
||||
if (substitution->has_node_replacement_) {
|
||||
for (auto &replaced_node : optimizer->substitution_replaced_nodes()) {
|
||||
UpdateTransformingListForSubstitutions(replaced_node, &todo, change);
|
||||
UpdateTransformingListWithUserNodes(optimizer, replaced_node, &todo, change, seen);
|
||||
}
|
||||
optimizer->clear_substitution_replaced_nodes();
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -251,6 +261,14 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
|
|||
change = true;
|
||||
changes = true;
|
||||
node = res;
|
||||
// If there is a node replacement in the substitution, add replaced nodes to todo list.
|
||||
if (substitution->has_node_replacement_) {
|
||||
for (auto &replaced_node : optimizer->substitution_replaced_nodes()) {
|
||||
UpdateTransformingListForIR(replaced_node, &todo, change, substitution);
|
||||
UpdateTransformingListWithUserNodes(optimizer, replaced_node, &todo, change, seen);
|
||||
}
|
||||
optimizer->clear_substitution_replaced_nodes();
|
||||
}
|
||||
}
|
||||
UpdateTransformingListForIR(node, &todo, change, substitution);
|
||||
UpdateTransformingListWithUserNodes(optimizer, node, &todo, change, seen);
|
||||
|
|
|
@ -42,16 +42,19 @@ class Substitution {
|
|||
OptimizerCallerPtr transform_;
|
||||
std::string name_;
|
||||
PredicateFuncType predicate_{nullptr};
|
||||
// An enum to mark this Substitution relation to renormalize pass
|
||||
// Determine whether there is a node replacement in the substitution, such as manager->Replace(old_node, new_node).
|
||||
bool has_node_replacement_{false};
|
||||
// An enum to mark this Substitution relation to renormalize pass.
|
||||
RenormAction renorm_action_;
|
||||
// Determine whether it is a priority substitution, that is, some patterns need to be matched prior to others.
|
||||
bool has_priority_pattern_{false};
|
||||
|
||||
Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate,
|
||||
const RenormAction &renorm_action, bool has_priority_pattern)
|
||||
bool has_node_replacement, const RenormAction &renorm_action, bool has_priority_pattern)
|
||||
: transform_(transform),
|
||||
name_(name),
|
||||
predicate_(predicate),
|
||||
has_node_replacement_(has_node_replacement),
|
||||
renorm_action_(renorm_action),
|
||||
has_priority_pattern_(has_priority_pattern) {}
|
||||
~Substitution() = default;
|
||||
|
@ -61,13 +64,14 @@ class Substitution {
|
|||
using SubstitutionPtr = std::shared_ptr<Substitution>;
|
||||
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
|
||||
const RenormAction &action_renorm = CHECK_RENORM, bool has_priority_pattern = false);
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
|
||||
const std::vector<PrimitivePtr> &prims,
|
||||
const RenormAction &action_renorm = CHECK_RENORM, bool has_priority_pattern = false);
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
|
||||
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM,
|
||||
bool has_node_replacement = false, const RenormAction &action_renorm = CHECK_RENORM,
|
||||
bool has_priority_pattern = false);
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
|
||||
const std::vector<PrimitivePtr> &prims, bool has_node_replacement = false,
|
||||
const RenormAction &action_renorm = CHECK_RENORM, bool has_priority_pattern = false);
|
||||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
|
||||
const PredicateFuncType &predicate, bool has_node_replacement = false,
|
||||
const RenormAction &action_renorm = CHECK_RENORM, bool has_priority_pattern = false);
|
||||
|
||||
enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptTraverseFromSubstitutionsToIR };
|
||||
|
||||
|
|
|
@ -226,6 +226,15 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
MS_LOG(EXCEPTION) << "No ResourceBase exists.";
|
||||
}
|
||||
|
||||
// Only for the case that manager->replace() has to be called in substitution. This interface can only be used in
|
||||
// substitution. Note that it is not recommended to replace nodes other than the input node in substitution.
|
||||
void SubstitutionReplace(const FuncGraphManagerPtr &manager, const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
|
||||
manager->Replace(old_node, new_node);
|
||||
substitution_replaced_nodes_.emplace_back(new_node);
|
||||
}
|
||||
std::vector<AnfNodePtr> substitution_replaced_nodes() const { return substitution_replaced_nodes_; }
|
||||
void clear_substitution_replaced_nodes() { substitution_replaced_nodes_.clear(); }
|
||||
|
||||
const std::string name() const { return name_; }
|
||||
|
||||
void set_is_untyped_generated() { is_untyped_generated_ = true; }
|
||||
|
@ -250,6 +259,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
pipeline::ResourceBasePtr resource_;
|
||||
std::vector<OptPass> passes_;
|
||||
std::vector<std::string> pass_names_;
|
||||
std::vector<AnfNodePtr> substitution_replaced_nodes_;
|
||||
bool run_only_once_;
|
||||
bool is_watch_renormalize_;
|
||||
bool is_enable_;
|
||||
|
|
|
@ -686,19 +686,19 @@ bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph) {
|
|||
// opt::irpass::OptimizeIRPassLib is not used here to avoid double free problems in external calls.
|
||||
opt::SubstitutionPtr updatestate_depend_eliminater =
|
||||
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateDependEliminater>(), "updatestate_depend_eliminater",
|
||||
prim::kPrimUpdateState);
|
||||
prim::kPrimUpdateState, true);
|
||||
opt::SubstitutionPtr updatestate_assign_eliminater =
|
||||
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateAssignEliminater>(), "updatestate_assign_eliminater",
|
||||
prim::kPrimUpdateState);
|
||||
prim::kPrimUpdateState, true);
|
||||
opt::SubstitutionPtr updatestate_maketuple_eliminater =
|
||||
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateMakeTupleEliminater>(),
|
||||
"updatestate_maketuple_eliminater", prim::kPrimUpdateState);
|
||||
"updatestate_maketuple_eliminater", prim::kPrimUpdateState, true);
|
||||
opt::SubstitutionPtr updatestate_only_used_node_eliminater =
|
||||
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateOnlyUsedNodeEliminater>(),
|
||||
"updatestate_only_used_node_eliminater", prim::kPrimUpdateState);
|
||||
opt::SubstitutionPtr updatestate_loads_eliminater =
|
||||
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateLoadsEliminater>(), "updatestate_loads_eliminater",
|
||||
prim::kPrimUpdateState);
|
||||
prim::kPrimUpdateState, true);
|
||||
opt::SubstitutionPtr updatestate_pure_node_eliminater =
|
||||
opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestatePureNodeEliminater>(),
|
||||
"updatestate_pure_node_eliminater", prim::kPrimUpdateState);
|
||||
|
|
|
@ -21,7 +21,7 @@ import mindspore as ms
|
|||
import mindspore.ops.operations as P
|
||||
import mindspore.nn as nn
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.nn import ReLU, BatchNorm2d, Conv2d, Dense, PReLU, ParameterUpdate
|
||||
from mindspore.nn import ReLU, BatchNorm2d, Conv2d, ParameterUpdate
|
||||
from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
@ -1220,27 +1220,6 @@ def read_file():
|
|||
return content
|
||||
|
||||
|
||||
# Net contain Prelu,BN,Conv,Dense which have weight value
|
||||
class NetRrelu(Cell):
|
||||
def __init__(self, in_channel, out_channel):
|
||||
super().__init__()
|
||||
self.relu = PReLU(channel=in_channel, w=0.25)
|
||||
self.bn = BatchNorm2d(num_features=in_channel)
|
||||
self.conv = Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=2, stride=1, has_bias=False,
|
||||
weight_init='ones', pad_mode='same')
|
||||
self.mean = P.ReduceMean(keep_dims=False)
|
||||
self.fc = Dense(in_channels=out_channel, out_channels=out_channel,
|
||||
weight_init='ones', bias_init='zeros', has_bias=True)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.relu(x)
|
||||
x = self.bn(x)
|
||||
x = self.conv(x)
|
||||
x = self.mean(x, (2, 3))
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
def check_keep_batchnorm_fp32_false(kwargs, level):
|
||||
if ms.context.get_context("device_target") == "GPU":
|
||||
if level == "O2":
|
||||
|
@ -1274,13 +1253,6 @@ def use_build_train_network_check_cast_num(network, level, inputs, label, cast_n
|
|||
return out_me
|
||||
|
||||
|
||||
def test_auto_mixed_precision_train_prelunet(with_save_graphs):
|
||||
net2 = NetRrelu(3, 12)
|
||||
input32 = Tensor(np.ones([1, 3, 2, 2]).astype(np.float32))
|
||||
label32 = Tensor(np.zeros([1, 12]).astype(np.float32))
|
||||
use_build_train_network_check_cast_num(net2, "O2", input32, label32, 16)
|
||||
|
||||
|
||||
class AssignNet(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue