forked from mindspore-Ecosystem/mindspore
!12345 Traverse all nodes once, then traverse all Substitutions on each node.
From: @zh_qh Reviewed-by: Signed-off-by:
This commit is contained in:
commit
86e3099c05
|
@ -31,7 +31,6 @@
|
|||
#include "backend/kernel_compiler/kernel_build_info.h"
|
||||
#include "common/trans.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
#include "utils/trace_base.h"
|
||||
|
||||
|
@ -1806,14 +1805,7 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) {
|
|||
args_spec_list.emplace_back(real_input->abstract());
|
||||
}
|
||||
}
|
||||
|
||||
auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
|
||||
auto ret = prim_eval_implement_map.find(primitive);
|
||||
if (ret == prim_eval_implement_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << primitive->name()
|
||||
<< " primitive type:" << primitive->type_name();
|
||||
}
|
||||
auto eval_result = ret->second.impl_(nullptr, primitive, args_spec_list);
|
||||
auto eval_result = abstract::CppInferShape(primitive, args_spec_list);
|
||||
node->set_abstract(eval_result);
|
||||
}
|
||||
} // namespace session
|
||||
|
|
|
@ -230,6 +230,8 @@ ResolveIRPassLib::ResolveIRPassLib() {
|
|||
{prim::kPrimGetAttr, prim::kPrimResolve});
|
||||
resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
|
||||
resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetAttr>(), "resolver_getattr", prim::kPrimGetAttr);
|
||||
resolver_getattr_resolve_ =
|
||||
MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "resolver_getattr_resolve", prim::kPrimGetAttr);
|
||||
}
|
||||
|
||||
InferenceOptPrepareLib::InferenceOptPrepareLib() {
|
||||
|
|
|
@ -154,6 +154,7 @@ class ResolveIRPassLib {
|
|||
SubstitutionPtr resolver_resolve_and_getattr_;
|
||||
SubstitutionPtr resolver_resolve_;
|
||||
SubstitutionPtr resolver_getattr_;
|
||||
SubstitutionPtr resolver_getattr_resolve_;
|
||||
};
|
||||
|
||||
class InferenceOptPrepareLib {
|
||||
|
|
|
@ -71,7 +71,7 @@ AnfNodePtr EliminateUpdateStateOnlyUsedNode(const CNodePtr &update_state, const
|
|||
return nullptr;
|
||||
}
|
||||
// Replace UpdateState with the input monad.
|
||||
return update_state->inputs().at(kInputIndex);
|
||||
return update_state->input(kInputIndex);
|
||||
}
|
||||
|
||||
// Eliminate UpdateState that attaches a pure (no-side-effect) node.
|
||||
|
@ -100,7 +100,7 @@ AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const A
|
|||
}
|
||||
}
|
||||
// Remove UpdateState by replacing it with its input monad.
|
||||
return update_state->inputs().at(kInputIndex);
|
||||
return update_state->input(kInputIndex);
|
||||
}
|
||||
|
||||
// Eliminate redundant UpdateState/Depend pair nodes caused by inline.
|
||||
|
@ -118,7 +118,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN
|
|||
// Skip if Depend attach input is not a monad.
|
||||
return nullptr;
|
||||
}
|
||||
auto update_monad = update_state->inputs().at(kInputIndex);
|
||||
auto update_monad = update_state->input(kInputIndex);
|
||||
if (!HasAbstractMonad(update_monad)) {
|
||||
// Skip if UpdateState input is not a monad.
|
||||
MS_LOG(WARNING) << "Not a monad input: " << update_state->DebugString();
|
||||
|
@ -139,7 +139,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN
|
|||
}
|
||||
// Replace Depend with its input.
|
||||
if (depend->size() == kMinDependSize) {
|
||||
auto depend_input = depend->inputs().at(kInputIndex);
|
||||
auto depend_input = depend->input(kInputIndex);
|
||||
mgr->Replace(depend, depend_input);
|
||||
} else {
|
||||
auto inputs = depend->inputs();
|
||||
|
@ -163,7 +163,7 @@ AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CN
|
|||
if (make_tuple->size() != kMakeTupleSize) {
|
||||
return nullptr;
|
||||
}
|
||||
auto &node = make_tuple->inputs().at(kAttachIndex);
|
||||
auto &node = make_tuple->input(kAttachIndex);
|
||||
auto node_abs = node->abstract();
|
||||
if (node_abs == nullptr || !node_abs->isa<abstract::AbstractError>()) {
|
||||
return nullptr;
|
||||
|
@ -173,7 +173,7 @@ AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CN
|
|||
return nullptr;
|
||||
}
|
||||
// Create a new UpdateState to replace the old one.
|
||||
const auto &attach = make_tuple->inputs().at(kInputIndex);
|
||||
const auto &attach = make_tuple->input(kInputIndex);
|
||||
auto new_update_state = fg->NewCNode({update_state->input(0), update_state->input(1), attach});
|
||||
new_update_state->set_abstract(update_state->abstract());
|
||||
new_update_state->set_scope(update_state->scope());
|
||||
|
@ -206,42 +206,47 @@ AnfNodePtr EliminateUpdateStateWithMakeTupleFunc(const CNodePtr &update_state, c
|
|||
if (make_tuple->size() != kMakeTupleSize) {
|
||||
return nullptr;
|
||||
}
|
||||
auto &first_input = make_tuple->inputs().at(kInputIndex);
|
||||
auto &first_input = make_tuple->input(kInputIndex);
|
||||
if (IsValueNode<FuncGraph>(first_input) && OnlyMakeTupleUseFunc(make_tuple, first_input)) {
|
||||
return update_state->input(1);
|
||||
}
|
||||
auto &second_input = make_tuple->inputs().at(kAttachIndex);
|
||||
auto &second_input = make_tuple->input(kAttachIndex);
|
||||
if (IsValueNode<FuncGraph>(second_input) && OnlyMakeTupleUseFunc(make_tuple, second_input)) {
|
||||
return update_state->input(1);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *loads);
|
||||
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *loads);
|
||||
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads);
|
||||
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
|
||||
std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads);
|
||||
|
||||
// Search consecutive load nodes from UpdateState node.
|
||||
size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *loads) {
|
||||
auto &attach = update_state->inputs().at(kAttachIndex);
|
||||
size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *update_states,
|
||||
std::vector<CNodePtr> *loads) {
|
||||
auto &attach = update_state->input(kAttachIndex);
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimLoad)) {
|
||||
return GetLoadsFollowLoad(attach->cast<CNodePtr>(), loads);
|
||||
update_states->emplace_back(update_state);
|
||||
return GetLoadsFollowLoad(attach->cast<CNodePtr>(), update_states, loads);
|
||||
}
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
|
||||
return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), loads);
|
||||
update_states->emplace_back(update_state);
|
||||
return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), update_states, loads);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *loads) {
|
||||
loads->push_back(load);
|
||||
auto &load_attach = load->inputs().at(kAttachIndex);
|
||||
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) {
|
||||
loads->emplace_back(load);
|
||||
auto &load_attach = load->input(kAttachIndex);
|
||||
if (IsPrimitiveCNode(load_attach, prim::kPrimUpdateState)) {
|
||||
return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), loads) + 1;
|
||||
return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), update_states, loads) + 1;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *loads) {
|
||||
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
|
||||
std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) {
|
||||
if (!OnlyUpdateStateUse(update_state, make_tuple)) {
|
||||
// UpdateState should be the only user of make_tuple.
|
||||
return 0;
|
||||
|
@ -256,12 +261,12 @@ size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tu
|
|||
// Add load nodes from tuple elements.
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
auto &element = inputs.at(i);
|
||||
loads->push_back(element->cast<CNodePtr>());
|
||||
loads->emplace_back(element->cast<CNodePtr>());
|
||||
}
|
||||
// Follow prev update state if found.
|
||||
auto prev_node = update_state->inputs().at(kInputIndex);
|
||||
auto prev_node = update_state->input(kInputIndex);
|
||||
if (IsPrimitiveCNode(prev_node, prim::kPrimUpdateState)) {
|
||||
return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), loads) + 1;
|
||||
return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), update_states, loads) + 1;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
@ -301,7 +306,8 @@ AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_upd
|
|||
// xN = Load(xN, u)
|
||||
// t = make_tuple(x1, x2, ... , xN)
|
||||
// u1 = UpdateState(u, t)
|
||||
AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &loads) {
|
||||
AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &update_states,
|
||||
const std::vector<CNodePtr> &loads) {
|
||||
auto fg = old_update_state->func_graph();
|
||||
if (fg == nullptr) {
|
||||
return nullptr;
|
||||
|
@ -315,20 +321,24 @@ AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const
|
|||
std::set<AnfNodePtr> loaded_para_set;
|
||||
make_tuple_inputs.reserve(loads.size() + 1);
|
||||
make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
auto input_monad = loads.back()->inputs().at(kAttachIndex);
|
||||
auto input_monad = loads.back()->input(kAttachIndex);
|
||||
for (auto iter = loads.rbegin(); iter != loads.rend(); ++iter) {
|
||||
auto &load = *iter;
|
||||
auto result = loaded_para_set.emplace(load->inputs().at(kInputIndex));
|
||||
auto result = loaded_para_set.emplace(load->input(kInputIndex));
|
||||
const bool is_new_load = result.second;
|
||||
if (is_new_load) {
|
||||
// Put Load node as a tuple element, if the parameter is not loaded by other Load.
|
||||
make_tuple_inputs.emplace_back(load);
|
||||
}
|
||||
if (load->inputs().at(kAttachIndex) != input_monad) {
|
||||
if (load->input(kAttachIndex) != input_monad) {
|
||||
// Make all loads use the same input monad.
|
||||
mgr->SetEdge(load, kAttachIndex, input_monad);
|
||||
}
|
||||
}
|
||||
for (auto i = update_states.size() - 1; i > 0; i--) {
|
||||
auto &us = update_states[i];
|
||||
mgr->Replace(us, us->input(kInputIndex));
|
||||
}
|
||||
if (make_tuple_inputs.size() == 1) {
|
||||
// This should not happen.
|
||||
MS_LOG(WARNING) << "No loads for " << old_update_state->DebugString(2);
|
||||
|
@ -538,7 +548,7 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
|
|||
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
auto &attach = update_state_node->inputs().at(kAttachIndex);
|
||||
auto &attach = update_state_node->input(kAttachIndex);
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimDepend)) {
|
||||
return EliminateUpdateStateWithDepend(update_state_node, attach->cast<CNodePtr>());
|
||||
}
|
||||
|
@ -586,9 +596,10 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
|
|||
return new_node;
|
||||
}
|
||||
}
|
||||
std::vector<CNodePtr> update_states;
|
||||
std::vector<CNodePtr> loads;
|
||||
if (GetLoadsFromUpdateState(update_state_node, &loads) > 1 && loads.size() > 1) {
|
||||
return EliminateUpdateStateForLoads(update_state_node, loads);
|
||||
if (GetLoadsFromUpdateState(update_state_node, &update_states, &loads) > 1 && loads.size() > 1) {
|
||||
return EliminateUpdateStateForLoads(update_state_node, update_states, loads);
|
||||
}
|
||||
// Eliminate UpdateStates that attaches a no-side-effect node.
|
||||
if (!attach_is_load && !attach_is_tuple) {
|
||||
|
|
|
@ -103,8 +103,62 @@ static bool isTraversable(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node,
|
||||
const SubstitutionPtr &transform) const {
|
||||
// Try a single substitution on a single node.
//
// Returns the replacement node when the substitution's predicate accepts `node` and the
// substitution yields a node that is neither null nor `node` itself; in that case `node`
// has already been replaced by the result through the optimizer's manager.
// Returns nullptr when the predicate rejects the node or the substitution changes nothing.
static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node,
                                     const SubstitutionPtr &substitution) {
  // Fetched before the predicate runs, exactly as callers have always observed.
  auto manager = optimizer->manager();
  if (!substitution->predicate_(node)) {
    // This substitution is not applicable to the node.
    return nullptr;
  }

  // Keep the trace guard alive across both the substitution call and the replacement.
  TraceGuard trace_guard(std::make_shared<TraceOpt>(node->debug_info()));
  auto new_node = (*substitution)(optimizer, node);
  const bool produced_new_node = (new_node != nullptr) && (new_node != node);
  if (!produced_new_node) {
    return nullptr;
  }

#ifdef ENABLE_PROFILE
  double replace_start = GetTime();
#endif
  MS_LOG(DEBUG) << "Replace " << node->DebugString() << " with " << new_node->DebugString() << ", by "
                << substitution->name_;
  (void)manager->Replace(node, new_node);
#ifdef ENABLE_PROFILE
  MsProfile::StatTime("replace." + substitution->name_, GetTime() - replace_start);
#endif
  return new_node;
}
|
||||
|
||||
static inline void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node,
|
||||
std::deque<AnfNodePtr> *todo, bool change, size_t seen) {
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
(*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
|
||||
}
|
||||
if (node->isa<CNode>()) {
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo));
|
||||
}
|
||||
|
||||
if (!change) {
|
||||
return;
|
||||
}
|
||||
auto manager = optimizer->manager();
|
||||
auto &node_users = manager->node_users();
|
||||
auto users_iterator = node_users.find(node);
|
||||
if (users_iterator == node_users.end()) {
|
||||
return;
|
||||
}
|
||||
auto users = users_iterator->second;
|
||||
for (auto &use : users) {
|
||||
auto use_node = use.first;
|
||||
if (use_node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
(*todo).emplace_back(use_node);
|
||||
if (use_node->seen_ == seen) {
|
||||
use_node->seen_--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
|
||||
#ifdef ENABLE_PROFILE
|
||||
double start = GetTime();
|
||||
#endif
|
||||
|
@ -113,7 +167,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
|
|||
// 1024 is for the initial capacity of deque
|
||||
std::deque<AnfNodePtr> todo(1024);
|
||||
todo.clear();
|
||||
todo.push_back(root_node);
|
||||
todo.emplace_back(func_graph->output());
|
||||
bool changes = false;
|
||||
|
||||
auto &all_nodes = manager->all_nodes();
|
||||
|
@ -121,59 +175,61 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
|
|||
AnfNodePtr node = todo.front();
|
||||
todo.pop_front();
|
||||
|
||||
// check whether this node has been matched.
|
||||
if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
|
||||
continue;
|
||||
}
|
||||
node->seen_ = seen;
|
||||
|
||||
// select nodes that this transform can be applied.
|
||||
bool is_match = transform->predicate_(node);
|
||||
|
||||
// apply transform on this node
|
||||
bool change = false;
|
||||
if (is_match) {
|
||||
TraceGuard trace_guard(std::make_shared<TraceOpt>(node->debug_info()));
|
||||
auto ret = (*transform)(optimizer, node);
|
||||
if (ret != nullptr && ret != node) {
|
||||
for (auto &substitution : list_) {
|
||||
auto res = DoTransform(optimizer, node, substitution);
|
||||
if (res != nullptr) {
|
||||
change = true;
|
||||
changes = true;
|
||||
#ifdef ENABLE_PROFILE
|
||||
double t = GetTime();
|
||||
#endif
|
||||
MS_LOG(DEBUG) << "transform: " << transform->name_ << " will replace: " << node->DebugString()
|
||||
<< " with: " << ret->DebugString();
|
||||
(void)manager->Replace(node, ret);
|
||||
#ifdef ENABLE_PROFILE
|
||||
MsProfile::StatTime("replace." + transform->name_, GetTime() - t);
|
||||
#endif
|
||||
node = ret;
|
||||
node = res;
|
||||
todo.emplace_back(res);
|
||||
break;
|
||||
}
|
||||
}
|
||||
UpdateTransformingList(optimizer, node, &todo, change, seen);
|
||||
}
|
||||
#ifdef ENABLE_PROFILE
|
||||
MsProfile::StatTime("opt.transforms." + optimizer->name(), GetTime() - start);
|
||||
#endif
|
||||
return changes;
|
||||
}
|
||||
|
||||
// find success, and add them to todo list
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
todo.push_back(GetValueNode<FuncGraphPtr>(node)->output());
|
||||
}
|
||||
bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &root_node,
|
||||
const SubstitutionPtr &substitution) const {
|
||||
#ifdef ENABLE_PROFILE
|
||||
double start = GetTime();
|
||||
#endif
|
||||
FuncGraphManagerPtr manager = optimizer->manager();
|
||||
auto seen = NewSeenGeneration();
|
||||
// 1024 is for the initial capacity of deque
|
||||
std::deque<AnfNodePtr> todo(1024);
|
||||
todo.clear();
|
||||
todo.emplace_back(root_node);
|
||||
bool changes = false;
|
||||
|
||||
if (node->isa<CNode>()) {
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo));
|
||||
}
|
||||
auto &all_nodes = manager->all_nodes();
|
||||
while (!todo.empty()) {
|
||||
AnfNodePtr node = todo.front();
|
||||
todo.pop_front();
|
||||
|
||||
auto &node_users = manager->node_users();
|
||||
if (change && node_users.find(node) != node_users.end()) {
|
||||
for (auto &use : node_users[node]) {
|
||||
auto use_node = use.first;
|
||||
if (use_node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
todo.push_back(use_node);
|
||||
if (use_node->seen_ == seen) {
|
||||
use_node->seen_--;
|
||||
}
|
||||
}
|
||||
if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
|
||||
continue;
|
||||
}
|
||||
node->seen_ = seen;
|
||||
|
||||
bool change = false;
|
||||
auto res = DoTransform(optimizer, node, substitution);
|
||||
if (res != nullptr) {
|
||||
change = true;
|
||||
changes = true;
|
||||
node = res;
|
||||
}
|
||||
UpdateTransformingList(optimizer, node, &todo, change, seen);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_PROFILE
|
||||
|
@ -182,13 +238,29 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
|
|||
return changes;
|
||||
}
|
||||
|
||||
bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const {
|
||||
MS_EXCEPTION_IF_NULL(optimizer);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
FuncGraphManagerPtr manager = optimizer->manager();
|
||||
manager->AddFuncGraph(func_graph);
|
||||
bool SubstitutionList::ApplySubstitutionsToIRForIsolate(const OptimizerPtr &optimizer) const {
|
||||
const auto &manager = optimizer->manager();
|
||||
const auto &nodes = manager->isolate_nodes();
|
||||
bool changes = false;
|
||||
bool loop = true;
|
||||
while (loop) {
|
||||
loop = false;
|
||||
std::for_each(list_.cbegin(), list_.cend(), [&](const auto &substitution) {
|
||||
std::for_each(nodes.cbegin(), nodes.cend(), [&](const auto &node) {
|
||||
bool change = ApplySubstitutionToIR(optimizer, node, substitution);
|
||||
changes = changes || change;
|
||||
loop = loop || change;
|
||||
});
|
||||
});
|
||||
if (is_once_) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return changes;
|
||||
}
|
||||
|
||||
// for transform status counting
|
||||
bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
|
||||
// Add for substitution status counting
|
||||
size_t space = 0;
|
||||
std::unordered_map<std::string, std::vector<bool>> status;
|
||||
if (optimizer->is_on_debug_) {
|
||||
|
@ -197,47 +269,39 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize
|
|||
}
|
||||
}
|
||||
|
||||
bool loop = false;
|
||||
bool changes = false;
|
||||
|
||||
do {
|
||||
bool loop = true;
|
||||
while (loop) {
|
||||
loop = false;
|
||||
for (size_t i = 0; i < list_.size(); i++) {
|
||||
auto change = ApplyTransform(optimizer, func_graph->output(), list_[i]);
|
||||
const auto &substitution = list_[i];
|
||||
bool change = ApplySubstitutionToIR(optimizer, func_graph->output(), substitution);
|
||||
changes = changes || change;
|
||||
loop = loop || change;
|
||||
|
||||
// apply transform on isolate nodes.
|
||||
auto &isolate_nodes = manager->isolate_nodes();
|
||||
for (auto &node : isolate_nodes) {
|
||||
change = ApplyTransform(optimizer, node, list_[i]);
|
||||
changes = changes || change;
|
||||
loop = loop || change;
|
||||
}
|
||||
|
||||
// record the status of each transform
|
||||
static const auto enable_dump_pass_ir = (common::GetEnv("ENV_DUMP_PASS_IR") == "1");
|
||||
if (enable_dump_pass_ir && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
auto fg_name = optimizer->name() + "_" + std::to_string(optimizer->CurPass_.counter) + "_" +
|
||||
optimizer->CurPass_.name + "_" + list_[i]->name_;
|
||||
auto fg_name = optimizer->name() + "_r" + std::to_string(optimizer->CurPass_.counter) + "_" +
|
||||
optimizer->CurPass_.name + "_" + substitution->name_;
|
||||
DumpIR(fg_name + ".ir", func_graph);
|
||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
||||
func_graph->DumpFuncGraph(fg_name);
|
||||
ExportIR(fg_name + ".dat", "", func_graph);
|
||||
}
|
||||
}
|
||||
|
||||
// Record the status of each substitution
|
||||
if (optimizer->is_on_debug_) {
|
||||
status[list_[i]->name_ + std::to_string(i)].push_back(change);
|
||||
space = std::max(list_[i]->name_.size(), space);
|
||||
status[substitution->name_ + std::to_string(i)].push_back(change);
|
||||
space = std::max(substitution->name_.size(), space);
|
||||
}
|
||||
}
|
||||
|
||||
if (is_once_) {
|
||||
break;
|
||||
}
|
||||
} while (loop);
|
||||
}
|
||||
|
||||
// display the status of each transform
|
||||
// Display the status of each substitution
|
||||
if (optimizer->is_on_debug_) {
|
||||
std::stringstream ss;
|
||||
ss << std::endl
|
||||
|
@ -253,7 +317,37 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize
|
|||
}
|
||||
MS_LOG(DEBUG) << ss.str();
|
||||
}
|
||||
return changes;
|
||||
}
|
||||
|
||||
// Apply this substitution list to func_graph and report whether anything changed.
//
// The graph is first registered with the optimizer's manager. The traversal strategy is then
// chosen:
//   - nodes-first (ApplyIRToSubstitutions) when the environment variable
//     ENV_TRAVERSE_SUBSTITUTIONS_MODE is not "1", the execution mode is not PyNative, and
//     optimizer->traverse_nodes_first() is true;
//   - substitutions-first (ApplySubstitutionsToIR) otherwise.
// The environment variable is read once, on the first call, and cached in a function-local static.
// Finally, if the manager holds isolate nodes, the substitutions are applied to them as well.
//
// Both func_graph and optimizer must be non-null (checked with MS_EXCEPTION_IF_NULL).
bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const {
  MS_EXCEPTION_IF_NULL(optimizer);
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = optimizer->manager();
  manager->AddFuncGraph(func_graph);

  static const auto traverse_mode =
    (common::GetEnv("ENV_TRAVERSE_SUBSTITUTIONS_MODE") == "1" ? kOptTraverseFromSubstitutionsToIR
                                                              : kOptTraverseFromIRToSubstitutions);
  // The checks short-circuit in this order: cached mode, execution mode, optimizer flag.
  const bool nodes_first = traverse_mode == kOptTraverseFromIRToSubstitutions &&
                           MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
                           optimizer->traverse_nodes_first();
  bool changes =
    nodes_first ? ApplyIRToSubstitutions(optimizer, func_graph) : ApplySubstitutionsToIR(optimizer, func_graph);

  if (manager->isolate_nodes().empty()) {
    return changes;
  }

#ifdef ENABLE_PROFILE
  double isolate_start = GetTime();
#endif
  if (ApplySubstitutionsToIRForIsolate(optimizer)) {
    changes = true;
  }
#ifdef ENABLE_PROFILE
  MsProfile::StatTime("opt.isolate.transform." + optimizer->name(), GetTime() - isolate_start);
#endif
  return changes;
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -59,6 +59,8 @@ SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std:
|
|||
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
|
||||
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);
|
||||
|
||||
// Traversal strategy used by SubstitutionList when applying its substitutions to a graph.
//   kOptTraverseFromIRToSubstitutions: walk the nodes and try the substitutions on each node
//                                      (SubstitutionList::ApplyIRToSubstitutions).
//   kOptTraverseFromSubstitutionsToIR: for each substitution in turn, walk the IR
//                                      (SubstitutionList::ApplySubstitutionsToIR).
enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptTraverseFromSubstitutionsToIR };
|
||||
|
||||
class SubstitutionList {
|
||||
public:
|
||||
explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false)
|
||||
|
@ -68,7 +70,10 @@ class SubstitutionList {
|
|||
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const;
|
||||
|
||||
private:
|
||||
bool ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &transform) const;
|
||||
bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
|
||||
bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &sub) const;
|
||||
bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
|
||||
bool ApplySubstitutionsToIRForIsolate(const OptimizerPtr &optimizer) const;
|
||||
std::vector<SubstitutionPtr> list_;
|
||||
// A flag marking that this list of Substitutions is to be executed only once.
|
||||
bool is_once_;
|
||||
|
|
|
@ -88,13 +88,14 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>;
|
|||
|
||||
class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
||||
public:
|
||||
Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr)
|
||||
Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr, bool traverse_nodes_first = true)
|
||||
: name_(name),
|
||||
resource_(resource_ptr),
|
||||
run_only_once_(false),
|
||||
is_watch_renormalize_(false),
|
||||
is_enable_(true),
|
||||
is_untyped_generated_(false) {}
|
||||
is_untyped_generated_(false),
|
||||
traverse_nodes_first_(traverse_nodes_first) {}
|
||||
virtual ~Optimizer() = default;
|
||||
|
||||
void Init(const OptPassGroupMap &passes, bool run_only_once) {
|
||||
|
@ -129,8 +130,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
|
||||
static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr,
|
||||
const OptPassGroupMap &passes, bool run_only_once = false,
|
||||
bool watch_renormalize = false) {
|
||||
OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr);
|
||||
bool watch_renormalize = false, bool traverse_nodes_first = true) {
|
||||
OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr, traverse_nodes_first);
|
||||
optimizer->Init(passes, run_only_once);
|
||||
if (watch_renormalize) {
|
||||
optimizer->enable_watch_renormalize();
|
||||
|
@ -223,6 +224,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
bool is_watch_renormalize() { return is_watch_renormalize_; }
|
||||
void set_enable(bool enable) { is_enable_ = enable; }
|
||||
|
||||
bool traverse_nodes_first() { return traverse_nodes_first_; }
|
||||
|
||||
struct {
|
||||
int64_t counter;
|
||||
std::string name;
|
||||
|
@ -239,6 +242,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
bool is_watch_renormalize_;
|
||||
bool is_enable_;
|
||||
bool is_untyped_generated_;
|
||||
bool traverse_nodes_first_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -308,7 +308,8 @@ bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBa
|
|||
return false;
|
||||
}
|
||||
opt::irpass::ResolveIRPassLib irpass;
|
||||
opt::OptimizerPtr opt_resolve = opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass));
|
||||
opt::OptimizerPtr opt_resolve =
|
||||
opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass), false, false, false);
|
||||
|
||||
(void)parse::python_adapter::set_python_scoped();
|
||||
|
||||
|
|
|
@ -246,7 +246,7 @@ class CNode : public AnfNode, public EffectInfoHolder {
|
|||
bool IsApply(const PrimitivePtr &) const;
|
||||
|
||||
const size_t size() const { return inputs_.size(); }
|
||||
const AnfNodePtr input(size_t i) const { return inputs_[i]; }
|
||||
const AnfNodePtr &input(size_t i) const { return inputs_.at(i); }
|
||||
const std::vector<AnfNodePtr> &inputs() const { return inputs_; }
|
||||
void add_input(const AnfNodePtr &input);
|
||||
void set_input(size_t i, const AnfNodePtr &input);
|
||||
|
|
Loading…
Reference in New Issue