diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
index bdd99e0c7de..41e92b6de4e 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
@@ -31,7 +31,6 @@
 #include "backend/kernel_compiler/kernel_build_info.h"
 #include "common/trans.h"
 #include "abstract/param_validator.h"
-#include "abstract/primitive_infer_map.h"
 #include "pipeline/jit/static_analysis/static_analysis.h"
 #include "utils/trace_base.h"
 
@@ -1806,14 +1805,7 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) {
       args_spec_list.emplace_back(real_input->abstract());
     }
   }
-
-  auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
-  auto ret = prim_eval_implement_map.find(primitive);
-  if (ret == prim_eval_implement_map.end()) {
-    MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << primitive->name()
-                      << " primitive type:" << primitive->type_name();
-  }
-  auto eval_result = ret->second.impl_(nullptr, primitive, args_spec_list);
+  auto eval_result = abstract::CppInferShape(primitive, args_spec_list);
   node->set_abstract(eval_result);
 }
 }  // namespace session
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc
index e7547e521a1..8b718f73a31 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc
@@ -230,6 +230,8 @@ ResolveIRPassLib::ResolveIRPassLib() {
                                                  {prim::kPrimGetAttr, prim::kPrimResolve});
   resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
   resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetAttr>(), "resolver_getattr", prim::kPrimGetAttr);
+  resolver_getattr_resolve_ =
+    MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "resolver_getattr_resolve", prim::kPrimGetAttr);
 }
 
 InferenceOptPrepareLib::InferenceOptPrepareLib() {
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h
index ef43a751c27..26139fa9388 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.h
@@ -154,6 +154,7 @@ class ResolveIRPassLib {
   SubstitutionPtr resolver_resolve_and_getattr_;
   SubstitutionPtr resolver_resolve_;
   SubstitutionPtr resolver_getattr_;
+  SubstitutionPtr resolver_getattr_resolve_;
 };
 
 class InferenceOptPrepareLib {
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc
index d2be02d8ae7..fc6c96cb8b2 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc
@@ -71,7 +71,7 @@ AnfNodePtr EliminateUpdateStateOnlyUsedNode(const CNodePtr &update_state, const
     return nullptr;
   }
   // Replace UpdateState with the input monad.
-  return update_state->inputs().at(kInputIndex);
+  return update_state->input(kInputIndex);
 }
 
 // Eliminate UpdateState that attaches a pure (no-side-effect) node.
@@ -100,7 +100,7 @@ AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const A
     }
   }
   // Remove UpdateState by replace it with its input monad.
-  return update_state->inputs().at(kInputIndex);
+  return update_state->input(kInputIndex);
 }
 
 // Eliminate redundant UpdateState/Depend pair nodes caused by inline.
@@ -118,7 +118,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN
     // Skip if Depend attach input is not a monad.
     return nullptr;
   }
-  auto update_monad = update_state->inputs().at(kInputIndex);
+  auto update_monad = update_state->input(kInputIndex);
   if (!HasAbstractMonad(update_monad)) {
     // Skip if UpdateState input is not a monad.
     MS_LOG(WARNING) << "Not a monad input: " << update_state->DebugString();
@@ -139,7 +139,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN
   }
   // Replace Depend with its input.
   if (depend->size() == kMinDependSize) {
-    auto depend_input = depend->inputs().at(kInputIndex);
+    auto depend_input = depend->input(kInputIndex);
     mgr->Replace(depend, depend_input);
   } else {
     auto inputs = depend->inputs();
@@ -163,7 +163,7 @@ AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CN
   if (make_tuple->size() != kMakeTupleSize) {
     return nullptr;
   }
-  auto &node = make_tuple->inputs().at(kAttachIndex);
+  auto &node = make_tuple->input(kAttachIndex);
   auto node_abs = node->abstract();
   if (node_abs == nullptr || !node_abs->isa<abstract::AbstractError>()) {
     return nullptr;
@@ -173,7 +173,7 @@ AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CN
     return nullptr;
   }
   // Create a new UpdateState to replace the old one.
-  const auto &attach = make_tuple->inputs().at(kInputIndex);
+  const auto &attach = make_tuple->input(kInputIndex);
   auto new_update_state = fg->NewCNode({update_state->input(0), update_state->input(1), attach});
   new_update_state->set_abstract(update_state->abstract());
   new_update_state->set_scope(update_state->scope());
@@ -206,42 +206,47 @@ AnfNodePtr EliminateUpdateStateWithMakeTupleFunc(const CNodePtr &update_state, c
   if (make_tuple->size() != kMakeTupleSize) {
     return nullptr;
   }
-  auto &first_input = make_tuple->inputs().at(kInputIndex);
+  auto &first_input = make_tuple->input(kInputIndex);
   if (IsValueNode<FuncGraph>(first_input) && OnlyMakeTupleUseFunc(make_tuple, first_input)) {
     return update_state->input(1);
   }
-  auto &second_input = make_tuple->inputs().at(kAttachIndex);
+  auto &second_input = make_tuple->input(kAttachIndex);
   if (IsValueNode<FuncGraph>(second_input) && OnlyMakeTupleUseFunc(make_tuple, second_input)) {
     return update_state->input(1);
   }
   return nullptr;
 }
 
-size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *loads);
-size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *loads);
+size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads);
+size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
+                           std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads);
 
 // Search consecutive load nodes from UpdateState node.
-size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *loads) {
-  auto &attach = update_state->inputs().at(kAttachIndex);
+size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *update_states,
+                               std::vector<CNodePtr> *loads) {
+  auto &attach = update_state->input(kAttachIndex);
   if (IsPrimitiveCNode(attach, prim::kPrimLoad)) {
-    return GetLoadsFollowLoad(attach->cast<CNodePtr>(), loads);
+    update_states->emplace_back(update_state);
+    return GetLoadsFollowLoad(attach->cast<CNodePtr>(), update_states, loads);
   }
   if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
-    return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), loads);
+    update_states->emplace_back(update_state);
+    return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), update_states, loads);
   }
   return 0;
 }
 
-size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *loads) {
-  loads->push_back(load);
-  auto &load_attach = load->inputs().at(kAttachIndex);
+size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) {
+  loads->emplace_back(load);
+  auto &load_attach = load->input(kAttachIndex);
   if (IsPrimitiveCNode(load_attach, prim::kPrimUpdateState)) {
-    return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), loads) + 1;
+    return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), update_states, loads) + 1;
   }
   return 1;
 }
 
-size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *loads) {
+size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
+                           std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) {
   if (!OnlyUpdateStateUse(update_state, make_tuple)) {
     // UpdateState should be the only user of make_tuple.
     return 0;
@@ -256,12 +261,12 @@ size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tu
   // Add load nodes from tuple elements.
   for (size_t i = 1; i < inputs.size(); ++i) {
     auto &element = inputs.at(i);
-    loads->push_back(element->cast<CNodePtr>());
+    loads->emplace_back(element->cast<CNodePtr>());
   }
   // Follow prev update state if found.
-  auto prev_node = update_state->inputs().at(kInputIndex);
+  auto prev_node = update_state->input(kInputIndex);
   if (IsPrimitiveCNode(prev_node, prim::kPrimUpdateState)) {
-    return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), loads) + 1;
+    return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), update_states, loads) + 1;
   }
   return 1;
 }
@@ -301,7 +306,8 @@ AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_upd
 //   xN = Load(xN, u)
 //   t = make_tuple(x1, x2, ..., xN)
 //   u1 = UpdateState(u, t)
-AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &loads) {
+AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &update_states,
+                                        const std::vector<CNodePtr> &loads) {
   auto fg = old_update_state->func_graph();
   if (fg == nullptr) {
     return nullptr;
@@ -315,20 +321,24 @@ AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const
   std::set<AnfNodePtr> loaded_para_set;
   make_tuple_inputs.reserve(loads.size() + 1);
   make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
-  auto input_monad = loads.back()->inputs().at(kAttachIndex);
+  auto input_monad = loads.back()->input(kAttachIndex);
   for (auto iter = loads.rbegin(); iter != loads.rend(); ++iter) {
     auto &load = *iter;
-    auto result = loaded_para_set.emplace(load->inputs().at(kInputIndex));
+    auto result = loaded_para_set.emplace(load->input(kInputIndex));
     const bool is_new_load = result.second;
     if (is_new_load) {
       // Put Load node as a tuple element, if the parameter is not loaded by other Load.
       make_tuple_inputs.emplace_back(load);
     }
-    if (load->inputs().at(kAttachIndex) != input_monad) {
+    if (load->input(kAttachIndex) != input_monad) {
       // Set all load use same input monad.
       mgr->SetEdge(load, kAttachIndex, input_monad);
     }
   }
+  for (auto i = update_states.size() - 1; i > 0; i--) {
+    auto &us = update_states[i];
+    mgr->Replace(us, us->input(kInputIndex));
+  }
   if (make_tuple_inputs.size() == 1) {
     // This should not happen.
     MS_LOG(WARNING) << "No loads for " << old_update_state->DebugString(2);
@@ -538,7 +548,7 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
     MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
     return nullptr;
   }
-  auto &attach = update_state_node->inputs().at(kAttachIndex);
+  auto &attach = update_state_node->input(kAttachIndex);
   if (IsPrimitiveCNode(attach, prim::kPrimDepend)) {
     return EliminateUpdateStateWithDepend(update_state_node, attach->cast<CNodePtr>());
   }
@@ -586,9 +596,10 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
       return new_node;
     }
   }
+  std::vector<CNodePtr> update_states;
   std::vector<CNodePtr> loads;
-  if (GetLoadsFromUpdateState(update_state_node, &loads) > 1 && loads.size() > 1) {
-    return EliminateUpdateStateForLoads(update_state_node, loads);
+  if (GetLoadsFromUpdateState(update_state_node, &update_states, &loads) > 1 && loads.size() > 1) {
+    return EliminateUpdateStateForLoads(update_state_node, update_states, loads);
   }
   // Eliminate UpdateStates that attaches a no-side-effect node.
   if (!attach_is_load && !attach_is_tuple) {
diff --git a/mindspore/ccsrc/frontend/optimizer/opt.cc b/mindspore/ccsrc/frontend/optimizer/opt.cc
index 5a4c56cc20f..e9c557a4d29 100644
--- a/mindspore/ccsrc/frontend/optimizer/opt.cc
+++ b/mindspore/ccsrc/frontend/optimizer/opt.cc
@@ -103,8 +103,62 @@ static bool isTraversable(const AnfNodePtr &node) {
   return false;
 }
 
-bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node,
-                                      const SubstitutionPtr &transform) const {
+static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node,
+                                     const SubstitutionPtr &substitution) {
+  auto manager = optimizer->manager();
+  bool is_match = substitution->predicate_(node);
+  if (is_match) {
+    TraceGuard trace_guard(std::make_shared<TraceOpt>(node->debug_info()));
+    auto res = (*substitution)(optimizer, node);
+    if (res != nullptr && res != node) {
+#ifdef ENABLE_PROFILE
+      double t = GetTime();
+#endif
+      MS_LOG(DEBUG) << "Replace " << node->DebugString() << " with " << res->DebugString() << ", by "
+                    << substitution->name_;
+      (void)manager->Replace(node, res);
+#ifdef ENABLE_PROFILE
+      MsProfile::StatTime("replace." + substitution->name_, GetTime() - t);
+#endif
+      return res;
+    }
+  }
+  return nullptr;
+}
+
+static inline void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node,
+                                          std::deque<AnfNodePtr> *todo, bool change, size_t seen) {
+  if (IsValueNode<FuncGraph>(node)) {
+    (*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
+  }
+  if (node->isa<CNode>()) {
+    auto &inputs = node->cast<CNodePtr>()->inputs();
+    (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo));
+  }
+
+  if (!change) {
+    return;
+  }
+  auto manager = optimizer->manager();
+  auto &node_users = manager->node_users();
+  auto users_iterator = node_users.find(node);
+  if (users_iterator == node_users.end()) {
+    return;
+  }
+  auto users = users_iterator->second;
+  for (auto &use : users) {
+    auto use_node = use.first;
+    if (use_node == nullptr) {
+      continue;
+    }
+    (*todo).emplace_back(use_node);
+    if (use_node->seen_ == seen) {
+      use_node->seen_--;
+    }
+  }
+}
+
+bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
 #ifdef ENABLE_PROFILE
   double start = GetTime();
 #endif
@@ -113,7 +167,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
   // 1024 is for the initial capacity of deque
   std::deque<AnfNodePtr> todo(1024);
   todo.clear();
-  todo.push_back(root_node);
+  todo.emplace_back(func_graph->output());
   bool changes = false;
 
   auto &all_nodes = manager->all_nodes();
@@ -121,59 +175,61 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
     AnfNodePtr node = todo.front();
     todo.pop_front();
 
-    // check whether this node has been matched.
     if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
       continue;
     }
     node->seen_ = seen;
 
-    // select nodes that this transform can be applied.
-    bool is_match = transform->predicate_(node);
-
-    // apply transform on this node
     bool change = false;
-    if (is_match) {
-      TraceGuard trace_guard(std::make_shared<TraceOpt>(node->debug_info()));
-      auto ret = (*transform)(optimizer, node);
-      if (ret != nullptr && ret != node) {
+    for (auto &substitution : list_) {
+      auto res = DoTransform(optimizer, node, substitution);
+      if (res != nullptr) {
         change = true;
         changes = true;
-#ifdef ENABLE_PROFILE
-        double t = GetTime();
-#endif
-        MS_LOG(DEBUG) << "transform: " << transform->name_ << " will replace: " << node->DebugString()
-                      << " with: " << ret->DebugString();
-        (void)manager->Replace(node, ret);
-#ifdef ENABLE_PROFILE
-        MsProfile::StatTime("replace." + transform->name_, GetTime() - t);
-#endif
-        node = ret;
+        node = res;
+        todo.emplace_back(res);
+        break;
       }
     }
+    UpdateTransformingList(optimizer, node, &todo, change, seen);
+  }
+#ifdef ENABLE_PROFILE
+  MsProfile::StatTime("opt.transforms." + optimizer->name(), GetTime() - start);
+#endif
+  return changes;
+}
 
-    // find success, and add them to todo list
-    if (IsValueNode<FuncGraph>(node)) {
-      todo.push_back(GetValueNode<FuncGraphPtr>(node)->output());
-    }
+bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &root_node,
+                                             const SubstitutionPtr &substitution) const {
+#ifdef ENABLE_PROFILE
+  double start = GetTime();
+#endif
+  FuncGraphManagerPtr manager = optimizer->manager();
+  auto seen = NewSeenGeneration();
+  // 1024 is for the initial capacity of deque
+  std::deque<AnfNodePtr> todo(1024);
+  todo.clear();
+  todo.emplace_back(root_node);
+  bool changes = false;
 
-    if (node->isa<CNode>()) {
-      auto &inputs = node->cast<CNodePtr>()->inputs();
-      (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo));
-    }
+  auto &all_nodes = manager->all_nodes();
+  while (!todo.empty()) {
+    AnfNodePtr node = todo.front();
+    todo.pop_front();
 
-    auto &node_users = manager->node_users();
-    if (change && node_users.find(node) != node_users.end()) {
-      for (auto &use : node_users[node]) {
-        auto use_node = use.first;
-        if (use_node == nullptr) {
-          continue;
-        }
-        todo.push_back(use_node);
-        if (use_node->seen_ == seen) {
-          use_node->seen_--;
-        }
-      }
+    if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
+      continue;
     }
+    node->seen_ = seen;
+
+    bool change = false;
+    auto res = DoTransform(optimizer, node, substitution);
+    if (res != nullptr) {
+      change = true;
+      changes = true;
+      node = res;
+    }
+    UpdateTransformingList(optimizer, node, &todo, change, seen);
   }
 
 #ifdef ENABLE_PROFILE
@@ -182,13 +238,29 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
   return changes;
 }
 
-bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const {
-  MS_EXCEPTION_IF_NULL(optimizer);
-  MS_EXCEPTION_IF_NULL(func_graph);
-  FuncGraphManagerPtr manager = optimizer->manager();
-  manager->AddFuncGraph(func_graph);
+bool SubstitutionList::ApplySubstitutionsToIRForIsolate(const OptimizerPtr &optimizer) const {
+  const auto &manager = optimizer->manager();
+  const auto &nodes = manager->isolate_nodes();
+  bool changes = false;
+  bool loop = true;
+  while (loop) {
+    loop = false;
+    std::for_each(list_.cbegin(), list_.cend(), [&](const auto &substitution) {
+      std::for_each(nodes.cbegin(), nodes.cend(), [&](const auto &node) {
+        bool change = ApplySubstitutionToIR(optimizer, node, substitution);
+        changes = changes || change;
+        loop = loop || change;
+      });
+    });
+    if (is_once_) {
+      break;
+    }
+  }
+  return changes;
+}
 
-  // for transform status counting
+bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
+  // Add for substitution status counting
   size_t space = 0;
   std::unordered_map<std::string, std::vector<bool>> status;
   if (optimizer->is_on_debug_) {
@@ -197,47 +269,39 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize
     }
   }
 
-  bool loop = false;
   bool changes = false;
-
-  do {
+  bool loop = true;
+  while (loop) {
     loop = false;
     for (size_t i = 0; i < list_.size(); i++) {
-      auto change = ApplyTransform(optimizer, func_graph->output(), list_[i]);
+      const auto &substitution = list_[i];
+      bool change = ApplySubstitutionToIR(optimizer, func_graph->output(), substitution);
       changes = changes || change;
       loop = loop || change;
 
-      // apply transform on isolate nodes.
-      auto &isolate_nodes = manager->isolate_nodes();
-      for (auto &node : isolate_nodes) {
-        change = ApplyTransform(optimizer, node, list_[i]);
-        changes = changes || change;
-        loop = loop || change;
-      }
-
-      // record the status of each transform
      static const auto enable_dump_pass_ir = (common::GetEnv("ENV_DUMP_PASS_IR") == "1");
      if (enable_dump_pass_ir && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
-        auto fg_name = optimizer->name() + "_" + std::to_string(optimizer->CurPass_.counter) + "_" +
-                       optimizer->CurPass_.name + "_" + list_[i]->name_;
+        auto fg_name = optimizer->name() + "_r" + std::to_string(optimizer->CurPass_.counter) + "_" +
+                       optimizer->CurPass_.name + "_" + substitution->name_;
         DumpIR(fg_name + ".ir", func_graph);
         if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
           func_graph->DumpFuncGraph(fg_name);
           ExportIR(fg_name + ".dat", "", func_graph);
         }
       }
+
+      // Record the status of each substitution
       if (optimizer->is_on_debug_) {
-        status[list_[i]->name_ + std::to_string(i)].push_back(change);
-        space = std::max(list_[i]->name_.size(), space);
+        status[substitution->name_ + std::to_string(i)].push_back(change);
+        space = std::max(substitution->name_.size(), space);
       }
     }
-
     if (is_once_) {
       break;
     }
-  } while (loop);
+  }
 
-  // display the status of each transform
+  // Display the status of each substitution
   if (optimizer->is_on_debug_) {
     std::stringstream ss;
     ss << std::endl
@@ -253,7 +317,37 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize
     }
     MS_LOG(DEBUG) << ss.str();
   }
+  return changes;
+}
+
+bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const {
+  MS_EXCEPTION_IF_NULL(optimizer);
+  MS_EXCEPTION_IF_NULL(func_graph);
+  FuncGraphManagerPtr manager = optimizer->manager();
+  manager->AddFuncGraph(func_graph);
+  bool changes = false;
+  static const auto traverse_mode =
+    (common::GetEnv("ENV_TRAVERSE_SUBSTITUTIONS_MODE") != "1" ? kOptTraverseFromIRToSubstitutions : kOptTraverseFromSubstitutionsToIR);
+  if (traverse_mode == kOptTraverseFromIRToSubstitutions &&
+      MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
+      optimizer->traverse_nodes_first()) {
+    changes = ApplyIRToSubstitutions(optimizer, func_graph);
+  } else {
+    changes = ApplySubstitutionsToIR(optimizer, func_graph);
+  }
+
+  bool has_isolate = !manager->isolate_nodes().empty();
+  if (has_isolate) {
+#ifdef ENABLE_PROFILE
+    double t = GetTime();
+#endif
+    bool change = ApplySubstitutionsToIRForIsolate(optimizer);
+    changes = changes || change;
+#ifdef ENABLE_PROFILE
+    MsProfile::StatTime("opt.isolate.transform." + optimizer->name(), GetTime() - t);
+#endif
+  }
 
   return changes;
 }
 }  // namespace opt
diff --git a/mindspore/ccsrc/frontend/optimizer/opt.h b/mindspore/ccsrc/frontend/optimizer/opt.h
index eadc9db0717..f176b627049 100644
--- a/mindspore/ccsrc/frontend/optimizer/opt.h
+++ b/mindspore/ccsrc/frontend/optimizer/opt.h
@@ -59,6 +59,8 @@ SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std:
 SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
                                  const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);
 
+enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptTraverseFromSubstitutionsToIR };
+
 class SubstitutionList {
  public:
   explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false)
@@ -68,7 +70,10 @@ class SubstitutionList {
   bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const;
 
  private:
-  bool ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &transform) const;
+  bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
+  bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &sub) const;
+  bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
+  bool ApplySubstitutionsToIRForIsolate(const OptimizerPtr &optimizer) const;
   std::vector<SubstitutionPtr> list_;
   // a flag to mark this list of Substitution can only be executed only once
   bool is_once_;
diff --git a/mindspore/ccsrc/frontend/optimizer/optimizer.h b/mindspore/ccsrc/frontend/optimizer/optimizer.h
index 5436ce294d9..17c381ec56c 100644
--- a/mindspore/ccsrc/frontend/optimizer/optimizer.h
+++ b/mindspore/ccsrc/frontend/optimizer/optimizer.h
@@ -88,13 +88,14 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>;
 class Optimizer : public std::enable_shared_from_this<Optimizer> {
  public:
-  Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr)
+  Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr, bool traverse_nodes_first = true)
       : name_(name),
         resource_(resource_ptr),
         run_only_once_(false),
         is_watch_renormalize_(false),
         is_enable_(true),
-        is_untyped_generated_(false) {}
+        is_untyped_generated_(false),
+        traverse_nodes_first_(traverse_nodes_first) {}
 
   virtual ~Optimizer() = default;
 
   void Init(const OptPassGroupMap &passes, bool run_only_once) {
@@ -129,8 +130,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
 
   static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr,
                                                   const OptPassGroupMap &passes, bool run_only_once = false,
-                                                  bool watch_renormalize = false) {
-    OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr);
+                                                  bool watch_renormalize = false, bool traverse_nodes_first = true) {
+    OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr, traverse_nodes_first);
     optimizer->Init(passes, run_only_once);
     if (watch_renormalize) {
       optimizer->enable_watch_renormalize();
@@ -223,6 +224,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
   bool is_watch_renormalize() { return is_watch_renormalize_; }
   void set_enable(bool enable) { is_enable_ = enable; }
 
+  bool traverse_nodes_first() { return traverse_nodes_first_; }
+
   struct {
     int64_t counter;
     std::string name;
@@ -239,6 +242,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
   bool is_watch_renormalize_;
   bool is_enable_;
   bool is_untyped_generated_;
+  bool traverse_nodes_first_;
 };
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
index 69c636dc0a8..4c002f81798 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
@@ -308,7 +308,8 @@ bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBa
     return false;
   }
   opt::irpass::ResolveIRPassLib irpass;
-  opt::OptimizerPtr opt_resolve = opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass));
+  opt::OptimizerPtr opt_resolve =
+    opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass), false, false, false);
 
   (void)parse::python_adapter::set_python_scoped();
 
diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h
index 08c4a60b1b3..4391d70cb3e 100644
--- a/mindspore/core/ir/anf.h
+++ b/mindspore/core/ir/anf.h
@@ -246,7 +246,7 @@ class CNode : public AnfNode, public EffectInfoHolder {
 
   bool IsApply(const PrimitivePtr &) const;
   const size_t size() const { return inputs_.size(); }
-  const AnfNodePtr input(size_t i) const { return inputs_[i]; }
+  const AnfNodePtr &input(size_t i) const { return inputs_.at(i); }
   const std::vector<AnfNodePtr> &inputs() const { return inputs_; }
   void add_input(const AnfNodePtr &input);
   void set_input(size_t i, const AnfNodePtr &input);