fix recursive specialize

This commit is contained in:
chenfei 2023-01-05 19:14:02 +08:00
parent dba4f661d7
commit fabeee307b
2 changed files with 109 additions and 64 deletions

View File

@ -25,7 +25,6 @@
#include "frontend/operator/composite/do_signature.h"
#include "abstract/abstract_function.h"
#include "abstract/utils.h"
#include "include/common/utils/utils.h"
#include "ir/graph_utils.h"
#include "utils/log_adapter.h"
#include "pipeline/jit/debug/trace.h"
@ -188,26 +187,23 @@ FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisConte
top_context_ = context;
MS_LOG(INFO) << "Specialize set top func graph context: " << context->ToString();
}
auto res = SpecializeFuncGraph(fg, context);
auto top_func_graph_spec = NewFuncGraphSpecializer(context, fg);
PushFuncGraphTodoItem(top_func_graph_spec);
while (!func_graph_todo_items_.empty()) {
auto current_fg_spec = func_graph_todo_items_.top();
if (current_fg_spec->done()) {
func_graph_todo_items_.pop();
continue;
}
// run current func graph specializer
current_fg_spec->Run();
}
auto res = top_func_graph_spec->specialized_func_graph();
MS_LOG(DEBUG) << "Specialized top graph:" << res->ToString();
EliminateCollectedSequenceNodes(this);
return res;
}
FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
MS_EXCEPTION_IF_NULL(fg);
MS_EXCEPTION_IF_NULL(context);
auto iter = specializations_.find(context);
if (iter != specializations_.end()) {
MS_EXCEPTION_IF_NULL(iter->second);
return iter->second->specialized_func_graph();
}
auto fg_spec = std::make_shared<FuncGraphSpecializer>(this, fg, context);
auto specialized_func_graph = fg_spec->specialized_func_graph();
specializations_[context] = fg_spec;
fg_spec->Run();
return specialized_func_graph;
}
std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) {
MS_EXCEPTION_IF_NULL(context);
auto iter = specializations_.find(context);
@ -217,6 +213,19 @@ std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecialize
return nullptr;
}
FuncGraphSpecializerPtr ProgramSpecializer::NewFuncGraphSpecializer(const AnalysisContextPtr &context,
const FuncGraphPtr &fg) {
MS_EXCEPTION_IF_NULL(context);
auto result = specializations_.emplace(context, nullptr);
if (result.second) {
MS_LOG(DEBUG) << "Make new specializer of context: " << context->ToString() << ", fg: " << fg->ToString();
auto fg_spec = std::make_shared<FuncGraphSpecializer>(this, fg, context);
result.first->second = fg_spec;
return fg_spec;
}
MS_LOG(EXCEPTION) << "Specializer exist in cache, can't not create again, context: " << context->ToString();
}
void ProgramSpecializer::PutSpecializedAbstract(const CNodePtr &cnode, const AnfNodePtr &func,
const AbstractFunctionPtr &old_abs_func,
const AbstractFunctionPtr &new_abs_func) {
@ -363,7 +372,8 @@ FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const Fu
const AnalysisContextPtr &context)
: specializer_(s), func_graph_(fg), context_(context) {
parent_ = s->GetFuncGraphSpecializer(context->parent());
if (parent_ == nullptr && context->parent()->func_graph() != nullptr) { // If context's not dummy context.
MS_EXCEPTION_IF_NULL(context->parent());
if (ParentNotSpecialized(context)) {
MS_LOG(EXCEPTION) << "Parent func graph should be handled in advance, fg: " << fg->ToString()
<< ", context: " << context->ToString() << ", parent context: " << context->parent()->ToString();
}
@ -502,7 +512,6 @@ void FuncGraphSpecializer::Run() {
? specialized_func_graph_->get_return()->DebugString()
: "return null"
: "FG(null)");
ProcessDeferredCNode();
}
void FuncGraphSpecializer::FirstPass() {
@ -547,9 +556,22 @@ void FuncGraphSpecializer::FirstPass() {
// Specialize CNode in func graphs
void FuncGraphSpecializer::SecondPass() {
for (const auto &cnode : BroadFirstSearchGraphCNodes(specialized_func_graph_->return_node())) {
ProcessCNode(cnode);
if (second_pass_todo_.empty()) {
second_pass_todo_ = BroadFirstSearchGraphCNodes(specialized_func_graph_->return_node());
}
MS_LOG(DEBUG) << "Start in index: " << second_pass_todo_inedx_ << ", fg: " << func_graph_->ToString()
<< ", todo list size:" << second_pass_todo_.size();
while (second_pass_todo_inedx_ < second_pass_todo_.size()) {
auto success = ProcessCNode(second_pass_todo_[second_pass_todo_inedx_]);
if (!success) {
MS_LOG(DEBUG) << "Suspend in index:" << second_pass_todo_inedx_
<< ", node: " << second_pass_todo_[second_pass_todo_inedx_]->DebugString();
return;
}
++second_pass_todo_inedx_;
}
MS_LOG(DEBUG) << "Set done of fg: " << func_graph_->ToString();
done_ = true;
}
namespace {
@ -921,6 +943,10 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const CNodePtr &cnode, con
ScopeGuard scope_guard(func->scope());
AnfNodePtr repl = BuildSpecializedNodeInner(cnode, func, abs, func_abs, argvals, &errcode);
if (repl == nullptr) {
// If errcode is success, it means child graph specialize.
if (errcode == kSpecializeSuccess) {
return nullptr;
}
if (errcode == kSpecializeDead) {
const auto err_dead_value = std::make_shared<ErrorValue>(ErrorValueType::kDead);
const auto err_dead_abstract = std::make_shared<AbstractError>(err_dead_value, func);
@ -1010,10 +1036,21 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const CNodePtr &cnode
<< ", " << func->ToString();
return func;
}
if (RecordDeferredCNode(cnode, context)) {
return func;
// Get the upper most func graph of which parent has been specialized.
while (ParentNotSpecialized(context)) {
context = context->parent();
}
FuncGraphPtr func_graph = specializer_->SpecializeFuncGraph(context->func_graph(), context);
auto fg_spec = specializer_->GetFuncGraphSpecializer(context);
// If func graph specializer dose not exist before, make a new specializer and push to stack, and return nullptr.
if (fg_spec == nullptr) {
fg_spec = specializer_->NewFuncGraphSpecializer(context, context->func_graph());
specializer_->PushFuncGraphTodoItem(fg_spec);
return nullptr;
}
FuncGraphPtr func_graph = fg_spec->specialized_func_graph();
MS_LOG(DEBUG) << "Get spec fg of func graph:" << context->func_graph()->ToString()
<< ", specialized fg:" << func_graph->ToString();
MS_EXCEPTION_IF_NULL(func_graph);
func_graph->set_flag(kFuncGraphFlagUndetermined, false);
static auto dummy_context = AnalysisContext::DummyContext();
@ -1051,6 +1088,9 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &c
ScopeGuard scope_guard(cnode->scope());
auto specialized_node = BuildSpecializedNode(cnode, func, backed_fnval, args);
if (specialized_node == nullptr) {
return nullptr;
}
auto wrapped_node = specialized_node;
if (fnval->isa<PartialAbstractClosure>()) {
auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
@ -1161,13 +1201,16 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
return std::make_pair(AbstractBasePtrList(), nullptr);
}
void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
bool IsHighOrderCall(const AnfNodePtr &func) {
return !func->isa<ValueNode>() && func->abstract()->isa<AbstractFunction>() &&
!func->abstract()->isa<AbstractFuncUnion>();
}
bool FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (specializer_->seen().count(cnode) > 0) {
return;
return true;
}
specializer_->AddSeen(cnode);
auto new_inputs = cnode->inputs();
if (new_inputs.empty()) {
MS_LOG(EXCEPTION) << "Inputs of CNode is empty";
@ -1191,8 +1234,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
(void)new_inputs.insert(new_inputs.cbegin(), func);
// Deal with the CNode|Parameter function call including Partial closure ahead.
if (!func->isa<ValueNode>() && func->abstract()->isa<AbstractFunction>() &&
!func->abstract()->isa<AbstractFuncUnion>()) {
if (IsHighOrderCall(func)) {
auto func_abs = func->abstract()->cast<AbstractFunctionPtr>();
EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs);
std::pair<AbstractBasePtrList, AbstractBasePtr> result;
@ -1206,6 +1248,9 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
if (status == kSpecializePoly ||
(func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER))) {
auto wrapped_node = BuildSpecializedParameterNode(cnode);
if (wrapped_node == nullptr) {
return false;
}
MS_LOG(DEBUG) << "Partial closure is handled, wrapped_node: " << wrapped_node->DebugString(recursive_level);
new_inputs[0] = wrapped_node;
}
@ -1225,7 +1270,11 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
for (size_t i = 1; i < new_inputs.size(); ++i) {
argvals.push_back(new_inputs[i]->abstract());
}
new_inputs[0] = BuildSpecializedNode(cnode, func, fnval, argvals);
auto specialized_func_node = BuildSpecializedNode(cnode, func, fnval, argvals);
if (specialized_func_node == nullptr) {
return false;
}
new_inputs[0] = specialized_func_node;
MS_LOG(DEBUG) << "Specalize func: " << func->type_name() << "/" << func->DebugString(recursive_level)
<< ", new_func: " << new_inputs[0]->DebugString(recursive_level) << ", argvals: " << argvals;
}
@ -1236,7 +1285,9 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
auto &old_node = new_inputs[i];
if (CanSpecializeValueNode(old_node)) {
auto new_node = BuildSpecializedNode(cnode, old_node, old_node->abstract(), std::vector<AbstractBasePtr>{});
MS_EXCEPTION_IF_NULL(new_node);
if (new_node == nullptr) {
return false;
}
MS_LOG(DEBUG) << "Specalize arg[" << i << "]: " << old_node->DebugString(recursive_level)
<< ", new_node: " << new_node->DebugString(recursive_level);
new_inputs[i] = new_node;
@ -1252,35 +1303,17 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
if (enable_eliminate_unused_element && !enable_only_mark_unused_element) {
EliminateUnusedSequenceItem(cnode);
}
// Only success processed node can be added to seen.
specializer_->AddSeen(cnode);
return true;
}
bool FuncGraphSpecializer::RecordDeferredCNode(const CNodePtr &cnode, const AnalysisContextPtr &context) {
bool FuncGraphSpecializer::ParentNotSpecialized(const AnalysisContextPtr &context) {
auto parent_context = context->parent();
auto parent_not_specialized = parent_context != nullptr && parent_context->func_graph() != nullptr &&
specializer_->GetFuncGraphSpecializer(parent_context) == nullptr;
if (parent_not_specialized) {
MS_LOG(DEBUG) << "Add deferred specialize cnode: " << cnode->DebugString()
<< ", parent context:" << parent_context->ToString();
specializer_->EraseSeen(cnode);
specializer_->AddDeferSpecializeNode(parent_context, cnode, this);
return true;
}
return false;
}
void FuncGraphSpecializer::ProcessDeferredCNode() {
// Check whether there were call nodes depend on current func_graph_spcializer finished.
auto parent_context_it = specializer_->defer_specialize_nodes().find(context_);
if (parent_context_it == specializer_->defer_specialize_nodes().cend()) {
return;
}
const auto &defer_specialize_call_nodes = parent_context_it->second;
for (const auto &[func_sepcializer, deferred_node] : defer_specialize_call_nodes) {
MS_LOG(DEBUG) << "Start second pass of deferred node: " << deferred_node->DebugString();
func_sepcializer->ProcessCNode(deferred_node);
}
// Erase the current context which was parent context of deferred nodes.
specializer_->RemoveDeferSpecializeNode(parent_context_it);
auto parent_specializer = specializer_->GetFuncGraphSpecializer(parent_context);
// If can't get specializer of parent and parent is not DummyContext, it means parent not specialized.
auto parent_not_specialized = parent_specializer == nullptr && parent_context->func_graph() != nullptr;
return parent_not_specialized;
}
namespace {

View File

@ -25,6 +25,7 @@
#include <utility>
#include <vector>
#include <unordered_map>
#include <stack>
#include "utils/hash_map.h"
#include "utils/hash_set.h"
@ -60,8 +61,9 @@ class ProgramSpecializer {
void EraseSeen(const AnfNodePtr &node) { (void)seen_.erase(node); }
std::shared_ptr<FuncGraphSpecializer> GetFuncGraphSpecializer(const AnalysisContextPtr &context);
// Specialze one FuncGraph in a given context.
FuncGraphPtr SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context);
std::shared_ptr<FuncGraphSpecializer> NewFuncGraphSpecializer(const AnalysisContextPtr &context,
const FuncGraphPtr &fg);
std::shared_ptr<AnalysisEngine> engine() { return engine_; }
@ -105,6 +107,10 @@ class ProgramSpecializer {
return nullptr;
}
void PushFuncGraphTodoItem(const std::shared_ptr<FuncGraphSpecializer> &fg_spec) {
func_graph_todo_items_.push(fg_spec);
}
private:
std::shared_ptr<AnalysisEngine> engine_;
mindspore::HashSet<AnfNodePtr> seen_;
@ -127,6 +133,8 @@ class ProgramSpecializer {
mindspore::HashMap<FuncGraphPtr, std::pair<bool, AbstractFunctionPtr>> func_graph_to_abstract_map_;
AbstractFunctionPtr SpecializeAbstractFuncRecursively(const AbstractFunctionPtr &old_abs_func);
std::stack<std::shared_ptr<FuncGraphSpecializer>> func_graph_todo_items_;
};
class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecializer> {
@ -138,6 +146,8 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
std::shared_ptr<FuncGraphSpecializer> GetTopSpecializer(const AnfNodePtr &node);
bool done() const { return done_; }
private:
ProgramSpecializer *specializer_;
FuncGraphPtr func_graph_;
@ -149,14 +159,15 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
std::vector<AnfNodePtr> todo_;
mindspore::HashSet<AnfNodePtr> marked_;
mindspore::HashMap<EvaluatorPtr, EvaluatorCacheMgrPtr> eval_cache_;
std::vector<CNodePtr> second_pass_todo_;
size_t second_pass_todo_inedx_{0};
bool done_{false};
void FirstPass();
void SecondPass();
void ProcessNode(const AnfNodePtr &node);
void ProcessCNode(const CNodePtr &cnode);
// If cnode need deferred specialized, return true, otherwise return false.
bool RecordDeferredCNode(const CNodePtr &cnode, const AnalysisContextPtr &context);
void ProcessDeferredCNode();
bool ProcessCNode(const CNodePtr &cnode);
bool ParentNotSpecialized(const AnalysisContextPtr &context);
void EliminateUnusedSequenceItem(const CNodePtr &cnode) const;
@ -209,6 +220,7 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
std::pair<AbstractBasePtrList, AbstractBasePtr> BuildFromBroadedArgsVal(const EvaluatorPtr &eval);
void UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node);
};
using FuncGraphSpecializerPtr = std::shared_ptr<FuncGraphSpecializer>;
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_SPECIALIZE_H_