!12345 Traverse all nodes once, then traverse all Substitutions on each node.

From: @zh_qh
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-22 19:23:01 +08:00 committed by Gitee
commit 86e3099c05
9 changed files with 223 additions and 113 deletions

View File

@ -31,7 +31,6 @@
#include "backend/kernel_compiler/kernel_build_info.h"
#include "common/trans.h"
#include "abstract/param_validator.h"
#include "abstract/primitive_infer_map.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "utils/trace_base.h"
@ -1806,14 +1805,7 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) {
args_spec_list.emplace_back(real_input->abstract());
}
}
auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
auto ret = prim_eval_implement_map.find(primitive);
if (ret == prim_eval_implement_map.end()) {
MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << primitive->name()
<< " primitive type:" << primitive->type_name();
}
auto eval_result = ret->second.impl_(nullptr, primitive, args_spec_list);
auto eval_result = abstract::CppInferShape(primitive, args_spec_list);
node->set_abstract(eval_result);
}
} // namespace session

View File

@ -230,6 +230,8 @@ ResolveIRPassLib::ResolveIRPassLib() {
{prim::kPrimGetAttr, prim::kPrimResolve});
resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetAttr>(), "resolver_getattr", prim::kPrimGetAttr);
resolver_getattr_resolve_ =
MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "resolver_getattr_resolve", prim::kPrimGetAttr);
}
InferenceOptPrepareLib::InferenceOptPrepareLib() {

View File

@ -154,6 +154,7 @@ class ResolveIRPassLib {
SubstitutionPtr resolver_resolve_and_getattr_;
SubstitutionPtr resolver_resolve_;
SubstitutionPtr resolver_getattr_;
SubstitutionPtr resolver_getattr_resolve_;
};
class InferenceOptPrepareLib {

View File

@ -71,7 +71,7 @@ AnfNodePtr EliminateUpdateStateOnlyUsedNode(const CNodePtr &update_state, const
return nullptr;
}
// Replace UpdateState with the input monad.
return update_state->inputs().at(kInputIndex);
return update_state->input(kInputIndex);
}
// Eliminate UpdateState that attaches a pure (no-side-effect) node.
@ -100,7 +100,7 @@ AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const A
}
}
// Remove UpdateState by replace it with its input monad.
return update_state->inputs().at(kInputIndex);
return update_state->input(kInputIndex);
}
// Eliminate redundant UpdateState/Depend pair nodes caused by inline.
@ -118,7 +118,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN
// Skip if Depend attach input is not a monad.
return nullptr;
}
auto update_monad = update_state->inputs().at(kInputIndex);
auto update_monad = update_state->input(kInputIndex);
if (!HasAbstractMonad(update_monad)) {
// Skip if UpdateState input is not a monad.
MS_LOG(WARNING) << "Not a monad input: " << update_state->DebugString();
@ -139,7 +139,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN
}
// Replace Depend with its input.
if (depend->size() == kMinDependSize) {
auto depend_input = depend->inputs().at(kInputIndex);
auto depend_input = depend->input(kInputIndex);
mgr->Replace(depend, depend_input);
} else {
auto inputs = depend->inputs();
@ -163,7 +163,7 @@ AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CN
if (make_tuple->size() != kMakeTupleSize) {
return nullptr;
}
auto &node = make_tuple->inputs().at(kAttachIndex);
auto &node = make_tuple->input(kAttachIndex);
auto node_abs = node->abstract();
if (node_abs == nullptr || !node_abs->isa<abstract::AbstractError>()) {
return nullptr;
@ -173,7 +173,7 @@ AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CN
return nullptr;
}
// Create a new UpdateState to replace the old one.
const auto &attach = make_tuple->inputs().at(kInputIndex);
const auto &attach = make_tuple->input(kInputIndex);
auto new_update_state = fg->NewCNode({update_state->input(0), update_state->input(1), attach});
new_update_state->set_abstract(update_state->abstract());
new_update_state->set_scope(update_state->scope());
@ -206,42 +206,47 @@ AnfNodePtr EliminateUpdateStateWithMakeTupleFunc(const CNodePtr &update_state, c
if (make_tuple->size() != kMakeTupleSize) {
return nullptr;
}
auto &first_input = make_tuple->inputs().at(kInputIndex);
auto &first_input = make_tuple->input(kInputIndex);
if (IsValueNode<FuncGraph>(first_input) && OnlyMakeTupleUseFunc(make_tuple, first_input)) {
return update_state->input(1);
}
auto &second_input = make_tuple->inputs().at(kAttachIndex);
auto &second_input = make_tuple->input(kAttachIndex);
if (IsValueNode<FuncGraph>(second_input) && OnlyMakeTupleUseFunc(make_tuple, second_input)) {
return update_state->input(1);
}
return nullptr;
}
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *loads);
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *loads);
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads);
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads);
// Search consecutive load nodes from UpdateState node.
size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *loads) {
auto &attach = update_state->inputs().at(kAttachIndex);
size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *update_states,
std::vector<CNodePtr> *loads) {
auto &attach = update_state->input(kAttachIndex);
if (IsPrimitiveCNode(attach, prim::kPrimLoad)) {
return GetLoadsFollowLoad(attach->cast<CNodePtr>(), loads);
update_states->emplace_back(update_state);
return GetLoadsFollowLoad(attach->cast<CNodePtr>(), update_states, loads);
}
if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), loads);
update_states->emplace_back(update_state);
return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), update_states, loads);
}
return 0;
}
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *loads) {
loads->push_back(load);
auto &load_attach = load->inputs().at(kAttachIndex);
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) {
loads->emplace_back(load);
auto &load_attach = load->input(kAttachIndex);
if (IsPrimitiveCNode(load_attach, prim::kPrimUpdateState)) {
return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), loads) + 1;
return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), update_states, loads) + 1;
}
return 1;
}
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *loads) {
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) {
if (!OnlyUpdateStateUse(update_state, make_tuple)) {
// UpdateState should be the only user of
return 0;
@ -256,12 +261,12 @@ size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tu
// Add load nodes from tuple elements.
for (size_t i = 1; i < inputs.size(); ++i) {
auto &element = inputs.at(i);
loads->push_back(element->cast<CNodePtr>());
loads->emplace_back(element->cast<CNodePtr>());
}
// Follow prev update state if found.
auto prev_node = update_state->inputs().at(kInputIndex);
auto prev_node = update_state->input(kInputIndex);
if (IsPrimitiveCNode(prev_node, prim::kPrimUpdateState)) {
return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), loads) + 1;
return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), update_states, loads) + 1;
}
return 1;
}
@ -301,7 +306,8 @@ AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_upd
// xN = Load(xN, u)
// t = make_tuple(x1, x2, ... , xN)
// u1 = UpdateState(u, t)
AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &loads) {
AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &update_states,
const std::vector<CNodePtr> &loads) {
auto fg = old_update_state->func_graph();
if (fg == nullptr) {
return nullptr;
@ -315,20 +321,24 @@ AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const
std::set<AnfNodePtr> loaded_para_set;
make_tuple_inputs.reserve(loads.size() + 1);
make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
auto input_monad = loads.back()->inputs().at(kAttachIndex);
auto input_monad = loads.back()->input(kAttachIndex);
for (auto iter = loads.rbegin(); iter != loads.rend(); ++iter) {
auto &load = *iter;
auto result = loaded_para_set.emplace(load->inputs().at(kInputIndex));
auto result = loaded_para_set.emplace(load->input(kInputIndex));
const bool is_new_load = result.second;
if (is_new_load) {
// Put Load node as a tuple element, if the parameter is not loaded by other Load.
make_tuple_inputs.emplace_back(load);
}
if (load->inputs().at(kAttachIndex) != input_monad) {
if (load->input(kAttachIndex) != input_monad) {
// Set all load use same input monad.
mgr->SetEdge(load, kAttachIndex, input_monad);
}
}
for (auto i = update_states.size() - 1; i > 0; i--) {
auto &us = update_states[i];
mgr->Replace(us, us->input(kInputIndex));
}
if (make_tuple_inputs.size() == 1) {
// This should not happen.
MS_LOG(WARNING) << "No loads for " << old_update_state->DebugString(2);
@ -538,7 +548,7 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
return nullptr;
}
auto &attach = update_state_node->inputs().at(kAttachIndex);
auto &attach = update_state_node->input(kAttachIndex);
if (IsPrimitiveCNode(attach, prim::kPrimDepend)) {
return EliminateUpdateStateWithDepend(update_state_node, attach->cast<CNodePtr>());
}
@ -586,9 +596,10 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
return new_node;
}
}
std::vector<CNodePtr> update_states;
std::vector<CNodePtr> loads;
if (GetLoadsFromUpdateState(update_state_node, &loads) > 1 && loads.size() > 1) {
return EliminateUpdateStateForLoads(update_state_node, loads);
if (GetLoadsFromUpdateState(update_state_node, &update_states, &loads) > 1 && loads.size() > 1) {
return EliminateUpdateStateForLoads(update_state_node, update_states, loads);
}
// Eliminate UpdateStates that attaches a no-side-effect node.
if (!attach_is_load && !attach_is_tuple) {

View File

@ -103,8 +103,62 @@ static bool isTraversable(const AnfNodePtr &node) {
return false;
}
bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node,
const SubstitutionPtr &transform) const {
static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node,
const SubstitutionPtr &substitution) {
auto manager = optimizer->manager();
bool is_match = substitution->predicate_(node);
if (is_match) {
TraceGuard trace_guard(std::make_shared<TraceOpt>(node->debug_info()));
auto res = (*substitution)(optimizer, node);
if (res != nullptr && res != node) {
#ifdef ENABLE_PROFILE
double t = GetTime();
#endif
MS_LOG(DEBUG) << "Replace " << node->DebugString() << " with " << res->DebugString() << ", by "
<< substitution->name_;
(void)manager->Replace(node, res);
#ifdef ENABLE_PROFILE
MsProfile::StatTime("replace." + substitution->name_, GetTime() - t);
#endif
return res;
}
}
return nullptr;
}
static inline void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node,
std::deque<AnfNodePtr> *todo, bool change, size_t seen) {
if (IsValueNode<FuncGraph>(node)) {
(*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
}
if (node->isa<CNode>()) {
auto &inputs = node->cast<CNodePtr>()->inputs();
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo));
}
if (!change) {
return;
}
auto manager = optimizer->manager();
auto &node_users = manager->node_users();
auto users_iterator = node_users.find(node);
if (users_iterator == node_users.end()) {
return;
}
auto users = users_iterator->second;
for (auto &use : users) {
auto use_node = use.first;
if (use_node == nullptr) {
continue;
}
(*todo).emplace_back(use_node);
if (use_node->seen_ == seen) {
use_node->seen_--;
}
}
}
bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
#ifdef ENABLE_PROFILE
double start = GetTime();
#endif
@ -113,7 +167,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
// 1024 is for the initial capacity of deque
std::deque<AnfNodePtr> todo(1024);
todo.clear();
todo.push_back(root_node);
todo.emplace_back(func_graph->output());
bool changes = false;
auto &all_nodes = manager->all_nodes();
@ -121,59 +175,61 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
AnfNodePtr node = todo.front();
todo.pop_front();
// check whether this node has been matched.
if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
continue;
}
node->seen_ = seen;
// select nodes that this transform can be applied.
bool is_match = transform->predicate_(node);
// apply transform on this node
bool change = false;
if (is_match) {
TraceGuard trace_guard(std::make_shared<TraceOpt>(node->debug_info()));
auto ret = (*transform)(optimizer, node);
if (ret != nullptr && ret != node) {
for (auto &substitution : list_) {
auto res = DoTransform(optimizer, node, substitution);
if (res != nullptr) {
change = true;
changes = true;
node = res;
todo.emplace_back(res);
break;
}
}
UpdateTransformingList(optimizer, node, &todo, change, seen);
}
#ifdef ENABLE_PROFILE
double t = GetTime();
MsProfile::StatTime("opt.transforms." + optimizer->name(), GetTime() - start);
#endif
MS_LOG(DEBUG) << "transform: " << transform->name_ << " will replace: " << node->DebugString()
<< " with: " << ret->DebugString();
(void)manager->Replace(node, ret);
return changes;
}
bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &root_node,
const SubstitutionPtr &substitution) const {
#ifdef ENABLE_PROFILE
MsProfile::StatTime("replace." + transform->name_, GetTime() - t);
double start = GetTime();
#endif
node = ret;
}
}
FuncGraphManagerPtr manager = optimizer->manager();
auto seen = NewSeenGeneration();
// 1024 is for the initial capacity of deque
std::deque<AnfNodePtr> todo(1024);
todo.clear();
todo.emplace_back(root_node);
bool changes = false;
// find success, and add them to todo list
if (IsValueNode<FuncGraph>(node)) {
todo.push_back(GetValueNode<FuncGraphPtr>(node)->output());
}
auto &all_nodes = manager->all_nodes();
while (!todo.empty()) {
AnfNodePtr node = todo.front();
todo.pop_front();
if (node->isa<CNode>()) {
auto &inputs = node->cast<CNodePtr>()->inputs();
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo));
}
auto &node_users = manager->node_users();
if (change && node_users.find(node) != node_users.end()) {
for (auto &use : node_users[node]) {
auto use_node = use.first;
if (use_node == nullptr) {
if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
continue;
}
todo.push_back(use_node);
if (use_node->seen_ == seen) {
use_node->seen_--;
}
}
node->seen_ = seen;
bool change = false;
auto res = DoTransform(optimizer, node, substitution);
if (res != nullptr) {
change = true;
changes = true;
node = res;
}
UpdateTransformingList(optimizer, node, &todo, change, seen);
}
#ifdef ENABLE_PROFILE
@ -182,13 +238,29 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
return changes;
}
bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const {
MS_EXCEPTION_IF_NULL(optimizer);
MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphManagerPtr manager = optimizer->manager();
manager->AddFuncGraph(func_graph);
bool SubstitutionList::ApplySubstitutionsToIRForIsolate(const OptimizerPtr &optimizer) const {
const auto &manager = optimizer->manager();
const auto &nodes = manager->isolate_nodes();
bool changes = false;
bool loop = true;
while (loop) {
loop = false;
std::for_each(list_.cbegin(), list_.cend(), [&](const auto &substitution) {
std::for_each(nodes.cbegin(), nodes.cend(), [&](const auto &node) {
bool change = ApplySubstitutionToIR(optimizer, node, substitution);
changes = changes || change;
loop = loop || change;
});
});
if (is_once_) {
break;
}
}
return changes;
}
// for transform status counting
bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
// Add for substitution status counting
size_t space = 0;
std::unordered_map<std::string, std::vector<bool>> status;
if (optimizer->is_on_debug_) {
@ -197,47 +269,39 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize
}
}
bool loop = false;
bool changes = false;
do {
bool loop = true;
while (loop) {
loop = false;
for (size_t i = 0; i < list_.size(); i++) {
auto change = ApplyTransform(optimizer, func_graph->output(), list_[i]);
const auto &substitution = list_[i];
bool change = ApplySubstitutionToIR(optimizer, func_graph->output(), substitution);
changes = changes || change;
loop = loop || change;
// apply transform on isolate nodes.
auto &isolate_nodes = manager->isolate_nodes();
for (auto &node : isolate_nodes) {
change = ApplyTransform(optimizer, node, list_[i]);
changes = changes || change;
loop = loop || change;
}
// record the status of each transform
static const auto enable_dump_pass_ir = (common::GetEnv("ENV_DUMP_PASS_IR") == "1");
if (enable_dump_pass_ir && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
auto fg_name = optimizer->name() + "_" + std::to_string(optimizer->CurPass_.counter) + "_" +
optimizer->CurPass_.name + "_" + list_[i]->name_;
auto fg_name = optimizer->name() + "_r" + std::to_string(optimizer->CurPass_.counter) + "_" +
optimizer->CurPass_.name + "_" + substitution->name_;
DumpIR(fg_name + ".ir", func_graph);
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
func_graph->DumpFuncGraph(fg_name);
ExportIR(fg_name + ".dat", "", func_graph);
}
}
if (optimizer->is_on_debug_) {
status[list_[i]->name_ + std::to_string(i)].push_back(change);
space = std::max(list_[i]->name_.size(), space);
}
}
// Record the status of each substitution
if (optimizer->is_on_debug_) {
status[substitution->name_ + std::to_string(i)].push_back(change);
space = std::max(substitution->name_.size(), space);
}
}
if (is_once_) {
break;
}
} while (loop);
}
// display the status of each transform
// Display the status of each substitution
if (optimizer->is_on_debug_) {
std::stringstream ss;
ss << std::endl
@ -253,7 +317,37 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize
}
MS_LOG(DEBUG) << ss.str();
}
return changes;
}
bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const {
MS_EXCEPTION_IF_NULL(optimizer);
MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphManagerPtr manager = optimizer->manager();
manager->AddFuncGraph(func_graph);
bool changes = false;
static const auto traverse_mode =
(common::GetEnv("ENV_TRAVERSE_SUBSTITUTIONS_MODE") != "1" ? kOptTraverseFromIRToSubstitutions
: kOptTraverseFromSubstitutionsToIR);
if (traverse_mode == kOptTraverseFromIRToSubstitutions &&
MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
optimizer->traverse_nodes_first()) {
changes = ApplyIRToSubstitutions(optimizer, func_graph);
} else {
changes = ApplySubstitutionsToIR(optimizer, func_graph);
}
bool has_isolate = !manager->isolate_nodes().empty();
if (has_isolate) {
#ifdef ENABLE_PROFILE
double t = GetTime();
#endif
bool change = ApplySubstitutionsToIRForIsolate(optimizer);
changes = changes || change;
#ifdef ENABLE_PROFILE
MsProfile::StatTime("opt.isolate.transform." + optimizer->name(), GetTime() - t);
#endif
}
return changes;
}
} // namespace opt

View File

@ -59,6 +59,8 @@ SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std:
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);
enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptTraverseFromSubstitutionsToIR };
class SubstitutionList {
public:
explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false)
@ -68,7 +70,10 @@ class SubstitutionList {
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const;
private:
bool ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &transform) const;
bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &sub) const;
bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
bool ApplySubstitutionsToIRForIsolate(const OptimizerPtr &optimizer) const;
std::vector<SubstitutionPtr> list_;
// a flag to mark this list of Substitution can only be executed only once
bool is_once_;

View File

@ -88,13 +88,14 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>;
class Optimizer : public std::enable_shared_from_this<Optimizer> {
public:
Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr)
Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr, bool traverse_nodes_first = true)
: name_(name),
resource_(resource_ptr),
run_only_once_(false),
is_watch_renormalize_(false),
is_enable_(true),
is_untyped_generated_(false) {}
is_untyped_generated_(false),
traverse_nodes_first_(traverse_nodes_first) {}
virtual ~Optimizer() = default;
void Init(const OptPassGroupMap &passes, bool run_only_once) {
@ -129,8 +130,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr,
const OptPassGroupMap &passes, bool run_only_once = false,
bool watch_renormalize = false) {
OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr);
bool watch_renormalize = false, bool traverse_nodes_first = true) {
OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr, traverse_nodes_first);
optimizer->Init(passes, run_only_once);
if (watch_renormalize) {
optimizer->enable_watch_renormalize();
@ -223,6 +224,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
bool is_watch_renormalize() { return is_watch_renormalize_; }
void set_enable(bool enable) { is_enable_ = enable; }
bool traverse_nodes_first() { return traverse_nodes_first_; }
struct {
int64_t counter;
std::string name;
@ -239,6 +242,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
bool is_watch_renormalize_;
bool is_enable_;
bool is_untyped_generated_;
bool traverse_nodes_first_;
};
} // namespace opt
} // namespace mindspore

View File

@ -308,7 +308,8 @@ bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBa
return false;
}
opt::irpass::ResolveIRPassLib irpass;
opt::OptimizerPtr opt_resolve = opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass));
opt::OptimizerPtr opt_resolve =
opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass), false, false, false);
(void)parse::python_adapter::set_python_scoped();

View File

@ -246,7 +246,7 @@ class CNode : public AnfNode, public EffectInfoHolder {
bool IsApply(const PrimitivePtr &) const;
const size_t size() const { return inputs_.size(); }
const AnfNodePtr input(size_t i) const { return inputs_[i]; }
const AnfNodePtr &input(size_t i) const { return inputs_.at(i); }
const std::vector<AnfNodePtr> &inputs() const { return inputs_; }
void add_input(const AnfNodePtr &input);
void set_input(size_t i, const AnfNodePtr &input);