forked from mindspore-Ecosystem/mindspore
fix recursive specialize
This commit is contained in:
parent
dba4f661d7
commit
fabeee307b
|
@ -25,7 +25,6 @@
|
|||
#include "frontend/operator/composite/do_signature.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "ir/graph_utils.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "pipeline/jit/debug/trace.h"
|
||||
|
@ -188,26 +187,23 @@ FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisConte
|
|||
top_context_ = context;
|
||||
MS_LOG(INFO) << "Specialize set top func graph context: " << context->ToString();
|
||||
}
|
||||
auto res = SpecializeFuncGraph(fg, context);
|
||||
auto top_func_graph_spec = NewFuncGraphSpecializer(context, fg);
|
||||
PushFuncGraphTodoItem(top_func_graph_spec);
|
||||
while (!func_graph_todo_items_.empty()) {
|
||||
auto current_fg_spec = func_graph_todo_items_.top();
|
||||
if (current_fg_spec->done()) {
|
||||
func_graph_todo_items_.pop();
|
||||
continue;
|
||||
}
|
||||
// run current func graph specializer
|
||||
current_fg_spec->Run();
|
||||
}
|
||||
auto res = top_func_graph_spec->specialized_func_graph();
|
||||
MS_LOG(DEBUG) << "Specialized top graph:" << res->ToString();
|
||||
EliminateCollectedSequenceNodes(this);
|
||||
return res;
|
||||
}
|
||||
|
||||
FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto iter = specializations_.find(context);
|
||||
if (iter != specializations_.end()) {
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
return iter->second->specialized_func_graph();
|
||||
}
|
||||
auto fg_spec = std::make_shared<FuncGraphSpecializer>(this, fg, context);
|
||||
auto specialized_func_graph = fg_spec->specialized_func_graph();
|
||||
specializations_[context] = fg_spec;
|
||||
fg_spec->Run();
|
||||
return specialized_func_graph;
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto iter = specializations_.find(context);
|
||||
|
@ -217,6 +213,19 @@ std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecialize
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
FuncGraphSpecializerPtr ProgramSpecializer::NewFuncGraphSpecializer(const AnalysisContextPtr &context,
|
||||
const FuncGraphPtr &fg) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto result = specializations_.emplace(context, nullptr);
|
||||
if (result.second) {
|
||||
MS_LOG(DEBUG) << "Make new specializer of context: " << context->ToString() << ", fg: " << fg->ToString();
|
||||
auto fg_spec = std::make_shared<FuncGraphSpecializer>(this, fg, context);
|
||||
result.first->second = fg_spec;
|
||||
return fg_spec;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Specializer exist in cache, can't not create again, context: " << context->ToString();
|
||||
}
|
||||
|
||||
void ProgramSpecializer::PutSpecializedAbstract(const CNodePtr &cnode, const AnfNodePtr &func,
|
||||
const AbstractFunctionPtr &old_abs_func,
|
||||
const AbstractFunctionPtr &new_abs_func) {
|
||||
|
@ -363,7 +372,8 @@ FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const Fu
|
|||
const AnalysisContextPtr &context)
|
||||
: specializer_(s), func_graph_(fg), context_(context) {
|
||||
parent_ = s->GetFuncGraphSpecializer(context->parent());
|
||||
if (parent_ == nullptr && context->parent()->func_graph() != nullptr) { // If context's not dummy context.
|
||||
MS_EXCEPTION_IF_NULL(context->parent());
|
||||
if (ParentNotSpecialized(context)) {
|
||||
MS_LOG(EXCEPTION) << "Parent func graph should be handled in advance, fg: " << fg->ToString()
|
||||
<< ", context: " << context->ToString() << ", parent context: " << context->parent()->ToString();
|
||||
}
|
||||
|
@ -502,7 +512,6 @@ void FuncGraphSpecializer::Run() {
|
|||
? specialized_func_graph_->get_return()->DebugString()
|
||||
: "return null"
|
||||
: "FG(null)");
|
||||
ProcessDeferredCNode();
|
||||
}
|
||||
|
||||
void FuncGraphSpecializer::FirstPass() {
|
||||
|
@ -547,9 +556,22 @@ void FuncGraphSpecializer::FirstPass() {
|
|||
|
||||
// Specialize CNode in func graphs
|
||||
void FuncGraphSpecializer::SecondPass() {
|
||||
for (const auto &cnode : BroadFirstSearchGraphCNodes(specialized_func_graph_->return_node())) {
|
||||
ProcessCNode(cnode);
|
||||
if (second_pass_todo_.empty()) {
|
||||
second_pass_todo_ = BroadFirstSearchGraphCNodes(specialized_func_graph_->return_node());
|
||||
}
|
||||
MS_LOG(DEBUG) << "Start in index: " << second_pass_todo_inedx_ << ", fg: " << func_graph_->ToString()
|
||||
<< ", todo list size:" << second_pass_todo_.size();
|
||||
while (second_pass_todo_inedx_ < second_pass_todo_.size()) {
|
||||
auto success = ProcessCNode(second_pass_todo_[second_pass_todo_inedx_]);
|
||||
if (!success) {
|
||||
MS_LOG(DEBUG) << "Suspend in index:" << second_pass_todo_inedx_
|
||||
<< ", node: " << second_pass_todo_[second_pass_todo_inedx_]->DebugString();
|
||||
return;
|
||||
}
|
||||
++second_pass_todo_inedx_;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Set done of fg: " << func_graph_->ToString();
|
||||
done_ = true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -921,6 +943,10 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const CNodePtr &cnode, con
|
|||
ScopeGuard scope_guard(func->scope());
|
||||
AnfNodePtr repl = BuildSpecializedNodeInner(cnode, func, abs, func_abs, argvals, &errcode);
|
||||
if (repl == nullptr) {
|
||||
// If errcode is success, it means child graph specialize.
|
||||
if (errcode == kSpecializeSuccess) {
|
||||
return nullptr;
|
||||
}
|
||||
if (errcode == kSpecializeDead) {
|
||||
const auto err_dead_value = std::make_shared<ErrorValue>(ErrorValueType::kDead);
|
||||
const auto err_dead_abstract = std::make_shared<AbstractError>(err_dead_value, func);
|
||||
|
@ -1010,10 +1036,21 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const CNodePtr &cnode
|
|||
<< ", " << func->ToString();
|
||||
return func;
|
||||
}
|
||||
if (RecordDeferredCNode(cnode, context)) {
|
||||
return func;
|
||||
// Get the upper most func graph of which parent has been specialized.
|
||||
while (ParentNotSpecialized(context)) {
|
||||
context = context->parent();
|
||||
}
|
||||
FuncGraphPtr func_graph = specializer_->SpecializeFuncGraph(context->func_graph(), context);
|
||||
auto fg_spec = specializer_->GetFuncGraphSpecializer(context);
|
||||
// If func graph specializer dose not exist before, make a new specializer and push to stack, and return nullptr.
|
||||
if (fg_spec == nullptr) {
|
||||
fg_spec = specializer_->NewFuncGraphSpecializer(context, context->func_graph());
|
||||
specializer_->PushFuncGraphTodoItem(fg_spec);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
FuncGraphPtr func_graph = fg_spec->specialized_func_graph();
|
||||
MS_LOG(DEBUG) << "Get spec fg of func graph:" << context->func_graph()->ToString()
|
||||
<< ", specialized fg:" << func_graph->ToString();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
func_graph->set_flag(kFuncGraphFlagUndetermined, false);
|
||||
static auto dummy_context = AnalysisContext::DummyContext();
|
||||
|
@ -1051,6 +1088,9 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &c
|
|||
|
||||
ScopeGuard scope_guard(cnode->scope());
|
||||
auto specialized_node = BuildSpecializedNode(cnode, func, backed_fnval, args);
|
||||
if (specialized_node == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto wrapped_node = specialized_node;
|
||||
if (fnval->isa<PartialAbstractClosure>()) {
|
||||
auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
|
||||
|
@ -1161,13 +1201,16 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
|
|||
return std::make_pair(AbstractBasePtrList(), nullptr);
|
||||
}
|
||||
|
||||
void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
|
||||
bool IsHighOrderCall(const AnfNodePtr &func) {
|
||||
return !func->isa<ValueNode>() && func->abstract()->isa<AbstractFunction>() &&
|
||||
!func->abstract()->isa<AbstractFuncUnion>();
|
||||
}
|
||||
|
||||
bool FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (specializer_->seen().count(cnode) > 0) {
|
||||
return;
|
||||
return true;
|
||||
}
|
||||
specializer_->AddSeen(cnode);
|
||||
|
||||
auto new_inputs = cnode->inputs();
|
||||
if (new_inputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Inputs of CNode is empty";
|
||||
|
@ -1191,8 +1234,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
|
|||
(void)new_inputs.insert(new_inputs.cbegin(), func);
|
||||
|
||||
// Deal with the CNode|Parameter function call including Partial closure ahead.
|
||||
if (!func->isa<ValueNode>() && func->abstract()->isa<AbstractFunction>() &&
|
||||
!func->abstract()->isa<AbstractFuncUnion>()) {
|
||||
if (IsHighOrderCall(func)) {
|
||||
auto func_abs = func->abstract()->cast<AbstractFunctionPtr>();
|
||||
EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs);
|
||||
std::pair<AbstractBasePtrList, AbstractBasePtr> result;
|
||||
|
@ -1206,6 +1248,9 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
|
|||
if (status == kSpecializePoly ||
|
||||
(func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER))) {
|
||||
auto wrapped_node = BuildSpecializedParameterNode(cnode);
|
||||
if (wrapped_node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Partial closure is handled, wrapped_node: " << wrapped_node->DebugString(recursive_level);
|
||||
new_inputs[0] = wrapped_node;
|
||||
}
|
||||
|
@ -1225,7 +1270,11 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
|
|||
for (size_t i = 1; i < new_inputs.size(); ++i) {
|
||||
argvals.push_back(new_inputs[i]->abstract());
|
||||
}
|
||||
new_inputs[0] = BuildSpecializedNode(cnode, func, fnval, argvals);
|
||||
auto specialized_func_node = BuildSpecializedNode(cnode, func, fnval, argvals);
|
||||
if (specialized_func_node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
new_inputs[0] = specialized_func_node;
|
||||
MS_LOG(DEBUG) << "Specalize func: " << func->type_name() << "/" << func->DebugString(recursive_level)
|
||||
<< ", new_func: " << new_inputs[0]->DebugString(recursive_level) << ", argvals: " << argvals;
|
||||
}
|
||||
|
@ -1236,7 +1285,9 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
|
|||
auto &old_node = new_inputs[i];
|
||||
if (CanSpecializeValueNode(old_node)) {
|
||||
auto new_node = BuildSpecializedNode(cnode, old_node, old_node->abstract(), std::vector<AbstractBasePtr>{});
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
if (new_node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Specalize arg[" << i << "]: " << old_node->DebugString(recursive_level)
|
||||
<< ", new_node: " << new_node->DebugString(recursive_level);
|
||||
new_inputs[i] = new_node;
|
||||
|
@ -1252,35 +1303,17 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
|
|||
if (enable_eliminate_unused_element && !enable_only_mark_unused_element) {
|
||||
EliminateUnusedSequenceItem(cnode);
|
||||
}
|
||||
// Only success processed node can be added to seen.
|
||||
specializer_->AddSeen(cnode);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FuncGraphSpecializer::RecordDeferredCNode(const CNodePtr &cnode, const AnalysisContextPtr &context) {
|
||||
bool FuncGraphSpecializer::ParentNotSpecialized(const AnalysisContextPtr &context) {
|
||||
auto parent_context = context->parent();
|
||||
auto parent_not_specialized = parent_context != nullptr && parent_context->func_graph() != nullptr &&
|
||||
specializer_->GetFuncGraphSpecializer(parent_context) == nullptr;
|
||||
if (parent_not_specialized) {
|
||||
MS_LOG(DEBUG) << "Add deferred specialize cnode: " << cnode->DebugString()
|
||||
<< ", parent context:" << parent_context->ToString();
|
||||
specializer_->EraseSeen(cnode);
|
||||
specializer_->AddDeferSpecializeNode(parent_context, cnode, this);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void FuncGraphSpecializer::ProcessDeferredCNode() {
|
||||
// Check whether there were call nodes depend on current func_graph_spcializer finished.
|
||||
auto parent_context_it = specializer_->defer_specialize_nodes().find(context_);
|
||||
if (parent_context_it == specializer_->defer_specialize_nodes().cend()) {
|
||||
return;
|
||||
}
|
||||
const auto &defer_specialize_call_nodes = parent_context_it->second;
|
||||
for (const auto &[func_sepcializer, deferred_node] : defer_specialize_call_nodes) {
|
||||
MS_LOG(DEBUG) << "Start second pass of deferred node: " << deferred_node->DebugString();
|
||||
func_sepcializer->ProcessCNode(deferred_node);
|
||||
}
|
||||
// Erase the current context which was parent context of deferred nodes.
|
||||
specializer_->RemoveDeferSpecializeNode(parent_context_it);
|
||||
auto parent_specializer = specializer_->GetFuncGraphSpecializer(parent_context);
|
||||
// If can't get specializer of parent and parent is not DummyContext, it means parent not specialized.
|
||||
auto parent_not_specialized = parent_specializer == nullptr && parent_context->func_graph() != nullptr;
|
||||
return parent_not_specialized;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <stack>
|
||||
|
||||
#include "utils/hash_map.h"
|
||||
#include "utils/hash_set.h"
|
||||
|
@ -60,8 +61,9 @@ class ProgramSpecializer {
|
|||
void EraseSeen(const AnfNodePtr &node) { (void)seen_.erase(node); }
|
||||
|
||||
std::shared_ptr<FuncGraphSpecializer> GetFuncGraphSpecializer(const AnalysisContextPtr &context);
|
||||
// Specialze one FuncGraph in a given context.
|
||||
FuncGraphPtr SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context);
|
||||
|
||||
std::shared_ptr<FuncGraphSpecializer> NewFuncGraphSpecializer(const AnalysisContextPtr &context,
|
||||
const FuncGraphPtr &fg);
|
||||
|
||||
std::shared_ptr<AnalysisEngine> engine() { return engine_; }
|
||||
|
||||
|
@ -105,6 +107,10 @@ class ProgramSpecializer {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
void PushFuncGraphTodoItem(const std::shared_ptr<FuncGraphSpecializer> &fg_spec) {
|
||||
func_graph_todo_items_.push(fg_spec);
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<AnalysisEngine> engine_;
|
||||
mindspore::HashSet<AnfNodePtr> seen_;
|
||||
|
@ -127,6 +133,8 @@ class ProgramSpecializer {
|
|||
mindspore::HashMap<FuncGraphPtr, std::pair<bool, AbstractFunctionPtr>> func_graph_to_abstract_map_;
|
||||
|
||||
AbstractFunctionPtr SpecializeAbstractFuncRecursively(const AbstractFunctionPtr &old_abs_func);
|
||||
|
||||
std::stack<std::shared_ptr<FuncGraphSpecializer>> func_graph_todo_items_;
|
||||
};
|
||||
|
||||
class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecializer> {
|
||||
|
@ -138,6 +146,8 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
|
|||
|
||||
std::shared_ptr<FuncGraphSpecializer> GetTopSpecializer(const AnfNodePtr &node);
|
||||
|
||||
bool done() const { return done_; }
|
||||
|
||||
private:
|
||||
ProgramSpecializer *specializer_;
|
||||
FuncGraphPtr func_graph_;
|
||||
|
@ -149,14 +159,15 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
|
|||
std::vector<AnfNodePtr> todo_;
|
||||
mindspore::HashSet<AnfNodePtr> marked_;
|
||||
mindspore::HashMap<EvaluatorPtr, EvaluatorCacheMgrPtr> eval_cache_;
|
||||
std::vector<CNodePtr> second_pass_todo_;
|
||||
size_t second_pass_todo_inedx_{0};
|
||||
bool done_{false};
|
||||
|
||||
void FirstPass();
|
||||
void SecondPass();
|
||||
void ProcessNode(const AnfNodePtr &node);
|
||||
void ProcessCNode(const CNodePtr &cnode);
|
||||
// If cnode need deferred specialized, return true, otherwise return false.
|
||||
bool RecordDeferredCNode(const CNodePtr &cnode, const AnalysisContextPtr &context);
|
||||
void ProcessDeferredCNode();
|
||||
bool ProcessCNode(const CNodePtr &cnode);
|
||||
bool ParentNotSpecialized(const AnalysisContextPtr &context);
|
||||
|
||||
void EliminateUnusedSequenceItem(const CNodePtr &cnode) const;
|
||||
|
||||
|
@ -209,6 +220,7 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
|
|||
std::pair<AbstractBasePtrList, AbstractBasePtr> BuildFromBroadedArgsVal(const EvaluatorPtr &eval);
|
||||
void UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node);
|
||||
};
|
||||
using FuncGraphSpecializerPtr = std::shared_ptr<FuncGraphSpecializer>;
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_SPECIALIZE_H_
|
||||
|
|
Loading…
Reference in New Issue