!19434 Modify context searching and creating routine, not searching all the time.

Merge pull request !19434 from 张清华/opt0
This commit is contained in:
i-robot 2021-07-06 13:00:32 +00:00 committed by Gitee
commit db0a064cfd
9 changed files with 94 additions and 83 deletions

View File

@ -58,16 +58,6 @@ void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &,
} }
} // namespace } // namespace
AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &engine,
const AbstractBasePtrList &args_spec_list) {
AbstractBasePtrList normalized_args_spec_list = NormalizeArgs(args_spec_list);
normalized_args_spec_list = BroadenUndeterminedArgs(normalized_args_spec_list);
FuncGraphPtr fg = GetFuncGraph(engine, normalized_args_spec_list);
MS_EXCEPTION_IF_NULL(parent_context_);
AnalysisContextPtr context = parent_context_->NewFuncGraphContext(fg, normalized_args_spec_list);
return context;
}
void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame, void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame,
const StackFramePtr &new_stack_frame) { const StackFramePtr &new_stack_frame) {
// Enter new func graph. // Enter new func graph.
@ -216,7 +206,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
<< parent_context_->func_graph()->ToString() << "()->" << AnalysisResultCacheMgr::GetThreadid() << ":" << parent_context_->func_graph()->ToString() << "()->" << AnalysisResultCacheMgr::GetThreadid() << ":"
<< fg->ToString() << "();"; << fg->ToString() << "();";
} }
auto context = parent_context_->NewFuncGraphContext(fg, args_abs_list); auto context = parent_context_->NewContext(fg, args_abs_list);
auto func_graph_evaluator = dyn_cast<FuncGraphEvaluator>(shared_from_base<BaseFuncGraphEvaluator>()); auto func_graph_evaluator = dyn_cast<FuncGraphEvaluator>(shared_from_base<BaseFuncGraphEvaluator>());
if (func_graph_evaluator != nullptr) { if (func_graph_evaluator != nullptr) {
if (engine->root_func_graph() == func_graph_evaluator->func_graph()) { if (engine->root_func_graph() == func_graph_evaluator->func_graph()) {

View File

@ -234,7 +234,7 @@ class BaseFuncGraphEvaluator : public Evaluator {
class FuncGraphEvaluator : public BaseFuncGraphEvaluator { class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
public: public:
FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context) FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context)
: BaseFuncGraphEvaluator(context->FindParentContext(func_graph)), func_graph_(func_graph) {} : BaseFuncGraphEvaluator(context), func_graph_(func_graph) {}
~FuncGraphEvaluator() override = default; ~FuncGraphEvaluator() override = default;
MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator); MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator);

View File

@ -466,7 +466,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &nod
if (func->context() == nullptr) { if (func->context() == nullptr) {
MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info()); MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
} }
AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals); AnalysisContextPtr context = MakeContext(engine_, real_eval, argvals);
MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size() MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
<< ", graph: " << context->func_graph()->get_return()->DebugString(); << ", graph: " << context->func_graph()->get_return()->DebugString();
if (context->func_graph()->stub()) { if (context->func_graph()->stub()) {
@ -480,6 +480,17 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &nod
return BuildValueNode(v, abs); return BuildValueNode(v, abs);
} }
AnalysisContextPtr FuncGraphSpecializer::MakeContext(const AnalysisEnginePtr &engine,
const BaseFuncGraphEvaluatorPtr &evaluator,
const AbstractBasePtrList &args_spec_list) {
AbstractBasePtrList normalized_args_spec_list = evaluator->NormalizeArgs(args_spec_list);
normalized_args_spec_list = evaluator->BroadenUndeterminedArgs(normalized_args_spec_list);
FuncGraphPtr fg = evaluator->GetFuncGraph(engine, normalized_args_spec_list);
MS_EXCEPTION_IF_NULL(evaluator->parent_context());
AnalysisContextPtr new_context = evaluator->parent_context()->NewContext(fg, normalized_args_spec_list);
return new_context;
}
AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) { AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) {
auto new_inputs = new_node->inputs(); auto new_inputs = new_node->inputs();
AnfNodePtr func = new_inputs[0]; AnfNodePtr func = new_inputs[0];

View File

@ -41,6 +41,7 @@ enum SpecializeStatusCode {
}; };
class FuncGraphSpecializer; class FuncGraphSpecializer;
using BaseFuncGraphEvaluatorPtr = std::shared_ptr<BaseFuncGraphEvaluator>;
// Specialize a func graph using analyzed abstract values. // Specialize a func graph using analyzed abstract values.
class ProgramSpecializer { class ProgramSpecializer {
@ -103,7 +104,10 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
void ProcessNode(const AnfNodePtr &node); void ProcessNode(const AnfNodePtr &node);
void ProcessCNode(const CNodePtr &new_node); void ProcessCNode(const CNodePtr &new_node);
AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node); inline AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node);
inline AnalysisContextPtr MakeContext(const AnalysisEnginePtr &engine, const BaseFuncGraphEvaluatorPtr &evaluator,
const AbstractBasePtrList &args_spec_list);
inline void AddTodoItem(const AnfNodePtr &node) { todo_.push_back(node); } inline void AddTodoItem(const AnfNodePtr &node) { todo_.push_back(node); }
// Get node replicated by Cloner. // Get node replicated by Cloner.
AnfNodePtr GetReplicatedNode(const AnfNodePtr &node); AnfNodePtr GetReplicatedNode(const AnfNodePtr &node);

View File

@ -38,9 +38,8 @@ AnalysisContextPtr StackFrame::GetParentContext(const BaseFuncGraphEvaluatorPtr
const AbstractFunctionPtr &graph_func) { const AbstractFunctionPtr &graph_func) {
AnalysisContextPtr parent_context = nullptr; AnalysisContextPtr parent_context = nullptr;
auto func_graph_abs = dyn_cast<FuncGraphAbstractClosure>(graph_func); auto func_graph_abs = dyn_cast<FuncGraphAbstractClosure>(graph_func);
if (func_graph_abs != nullptr) { // Find parent context for FuncGraphAbstractClosure. if (func_graph_abs != nullptr) { // Set parent context for FuncGraphAbstractClosure.
auto branch_fg = func_graph_abs->func_graph(); parent_context = func_graph_abs->context();
parent_context = func_graph_abs->context()->FindParentContext(branch_fg);
} else if (graph_func->isa<MetaFuncGraphAbstractClosure>()) { // Or DummyContext for MetaFuncGraphAbstractClosure. } else if (graph_func->isa<MetaFuncGraphAbstractClosure>()) { // Or DummyContext for MetaFuncGraphAbstractClosure.
parent_context = fg_evaluator->parent_context(); parent_context = fg_evaluator->parent_context();
if (parent_context == nullptr) { if (parent_context == nullptr) {
@ -85,7 +84,7 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
// Find parent context and create new context. // Find parent context and create new context.
AnalysisContextPtr parent_context = GetParentContext(fg_evaluator, graph_func); AnalysisContextPtr parent_context = GetParentContext(fg_evaluator, graph_func);
MS_EXCEPTION_IF_NULL(parent_context); MS_EXCEPTION_IF_NULL(parent_context);
auto new_context = parent_context->NewFuncGraphContext(fg, args_abs_list); auto new_context = parent_context->NewContext(fg, args_abs_list);
// Evaluate the parameters with new context. // Evaluate the parameters with new context.
for (size_t i = 0; i < nargs; i++) { for (size_t i = 0; i < nargs; i++) {

View File

@ -92,7 +92,7 @@ using ConfigPtrList = std::vector<ConfigPtr>;
class AnfNodeConfig : public Config { class AnfNodeConfig : public Config {
public: public:
AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context) AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context)
: Config(), engine_(std::weak_ptr<AnalysisEngine>(engine)), node_(node) { : Config(), engine_(std::weak_ptr<AnalysisEngine>(engine)), node_(node), context_(nullptr) {
FuncGraphPtr fg; FuncGraphPtr fg;
if (IsValueNode<FuncGraph>(node)) { if (IsValueNode<FuncGraph>(node)) {
auto v = node->cast<ValueNodePtr>(); auto v = node->cast<ValueNodePtr>();
@ -100,9 +100,17 @@ class AnfNodeConfig : public Config {
} else { } else {
fg = node->func_graph(); fg = node->func_graph();
} }
context_ = nullptr;
if (context != nullptr) { if (context == nullptr) {
context_ = context->FindParentContext(fg); return;
}
if (context->func_graph() == fg) {
// Usually `node` is CNode and not a FV, or top graph's ValueNodes.
context_ = context;
} else {
// If `node` is FV, FuncGraph, or other graph ValueNodes.
// Non-FuncGraph ValueNodes will always get a DummyContext since `fg` is null.
context_ = context->FindOwnOrParentContext(fg);
} }
} }

View File

@ -23,36 +23,17 @@
namespace mindspore { namespace mindspore {
namespace abstract { namespace abstract {
AnalysisContextPtr AnalysisContext::NewContext(AnalysisContextPtr parent_context, FuncGraphPtr fg, AnalysisContextPtr AnalysisContext::NewContext(const FuncGraphPtr &func_graph,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(parent_context); // Find func graph's parent and its parent context firstly.
auto children_context_map_iter = parent_context->children_cache_.find(fg);
if (children_context_map_iter != parent_context->children_cache_.end()) {
auto children_context_map = children_context_map_iter->second;
auto children_context_iter = children_context_map.find(args_spec_list);
if (children_context_iter != children_context_map.end()) {
return children_context_iter->second.lock();
}
}
AnalysisContextPtr new_context = std::make_shared<AnalysisContext>(parent_context, fg, args_spec_list);
// Reference to myself, so use weak_ptr to break reference cycle.
auto weak_context = std::weak_ptr<AnalysisContext>(new_context);
new_context->parent_cache_[fg] = weak_context;
parent_context->children_cache_[fg][args_spec_list] = weak_context;
return new_context;
}
AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func_graph,
const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphPtr parent_graph = func_graph->parent(); FuncGraphPtr parent_graph = func_graph->parent();
AnalysisContextPtr parent_context = nullptr; AnalysisContextPtr parent_context = nullptr;
auto iter = parent_cache_.find(parent_graph); auto iter = extant_context_cache_.find(parent_graph);
if (iter != parent_cache_.end()) { if (iter != extant_context_cache_.end()) {
parent_context = iter->second.lock(); parent_context = iter->second.lock();
} }
// If this happen, it will be a bug in code. But we raise exception to keep the scene. if (parent_context == nullptr) { // If parent context is not found, we'll raise exception.
if (parent_context == nullptr) {
std::ostringstream oss; std::ostringstream oss;
oss << "BUG: Failed to find parent context in current context: " << this->ToString() oss << "BUG: Failed to find parent context in current context: " << this->ToString()
<< ", func_graph: " << func_graph->ToString() << ", parent_graph: "; << ", func_graph: " << func_graph->ToString() << ", parent_graph: ";
@ -63,31 +44,48 @@ AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func
} }
MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
} }
return NewContext(parent_context, func_graph, args_spec_list);
// Check if we created a context for func graph with the same arguments before.
auto children_context_map_iter = parent_context->children_cache_.find(func_graph);
if (children_context_map_iter != parent_context->children_cache_.end()) {
auto children_context_map = children_context_map_iter->second;
auto children_context_iter = children_context_map.find(args_spec_list);
if (children_context_iter != children_context_map.end()) {
return children_context_iter->second.lock();
}
}
// Create a new context for the func graph and its specific arguments.
AnalysisContextPtr new_context = std::make_shared<AnalysisContext>(parent_context, func_graph, args_spec_list);
// To avoid cycle-reference, use weak_ptr here.
auto weak_new_context = std::weak_ptr<AnalysisContext>(new_context);
new_context->extant_context_cache_[func_graph] = weak_new_context;
parent_context->children_cache_[func_graph][args_spec_list] = weak_new_context;
return new_context;
} }
AnalysisContextPtr AnalysisContext::FindParentContext(const FuncGraphPtr &func_graph) { AnalysisContextPtr AnalysisContext::FindOwnOrParentContext(const FuncGraphPtr &func_graph) {
auto p_iter = parent_cache_.find(func_graph); auto p_iter = extant_context_cache_.find(func_graph);
AnalysisContextPtr parent_context = nullptr; AnalysisContextPtr extant_context = nullptr;
if (p_iter != parent_cache_.end()) { if (p_iter != extant_context_cache_.end()) {
parent_context = p_iter->second.lock(); extant_context = p_iter->second.lock();
} else { } else {
auto iter_parent = parent_cache_.find(func_graph->parent()); auto iter_parent = extant_context_cache_.find(func_graph->parent());
if (iter_parent != parent_cache_.end()) { if (iter_parent != extant_context_cache_.end()) {
parent_context = iter_parent->second.lock(); extant_context = iter_parent->second.lock();
} }
} }
// If this happen, it would be a bug in code. But we raise exception to keep the scene. // If this happen, it would be a bug in code. But we raise exception to keep the scene.
if (parent_context == nullptr) { if (extant_context == nullptr) {
std::ostringstream oss; std::ostringstream oss;
oss << "BUG: Failed to find parent context for: " << func_graph->ToString() << ", parent_graph: "; oss << "BUG: Failed to find context for: " << func_graph->ToString() << ", parent_graph: ";
if (func_graph->parent() != nullptr) { if (func_graph->parent() != nullptr) {
oss << func_graph->parent()->ToString(); oss << func_graph->parent()->ToString();
} else { } else {
oss << "nullptr"; oss << "nullptr";
} }
oss << " parent_cache_: {"; oss << " extant context list: {";
for (auto iter : parent_cache_) { for (auto iter : extant_context_cache_) {
if (iter.first == nullptr) { if (iter.first == nullptr) {
oss << " [graph: nullptr"; oss << " [graph: nullptr";
} else { } else {
@ -100,12 +98,12 @@ AnalysisContextPtr AnalysisContext::FindParentContext(const FuncGraphPtr &func_g
oss << "}"; oss << "}";
MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
} }
return parent_context; return extant_context;
} }
AnalysisContextPtr AnalysisContext::DummyContext() { AnalysisContextPtr AnalysisContext::DummyContext() {
AnalysisContextPtr dummy_context = std::make_shared<AnalysisContext>(nullptr, nullptr, AbstractBasePtrList()); AnalysisContextPtr dummy_context = std::make_shared<AnalysisContext>(nullptr, nullptr, AbstractBasePtrList());
dummy_context->parent_cache_[nullptr] = std::weak_ptr<AnalysisContext>(dummy_context); dummy_context->extant_context_cache_[nullptr] = std::weak_ptr<AnalysisContext>(dummy_context);
return dummy_context; return dummy_context;
} }

View File

@ -39,20 +39,17 @@ class AnalysisContext {
AnalysisContext(const AnalysisContextPtr &parent, const FuncGraphPtr &fg, const AbstractBasePtrList &args_spec_list) AnalysisContext(const AnalysisContextPtr &parent, const FuncGraphPtr &fg, const AbstractBasePtrList &args_spec_list)
: parent_(parent), func_graph_(fg), args_spec_list_(args_spec_list) { : parent_(parent), func_graph_(fg), args_spec_list_(args_spec_list) {
if (parent_ != nullptr) { if (parent_ != nullptr) {
parent_cache_ = parent_->parent_cache_; extant_context_cache_ = parent_->extant_context_cache_;
} }
} }
~AnalysisContext() = default; ~AnalysisContext() = default;
// Helper function to wrapper constructor to save shared_ptr in parent_cache.
AnalysisContextPtr NewContext(AnalysisContextPtr parent, FuncGraphPtr fg, const AbstractBasePtrList &args_spec_list);
// Extend this context with values for another graph. // Extend this context with values for another graph.
AnalysisContextPtr NewFuncGraphContext(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); AnalysisContextPtr NewContext(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
// Return a context restricted to a graph's dependencies. // Return a context restricted to a graph and its parent.
AnalysisContextPtr FindParentContext(const FuncGraphPtr &graph); AnalysisContextPtr FindOwnOrParentContext(const FuncGraphPtr &graph);
bool operator==(const AnalysisContext &other) const; bool operator==(const AnalysisContext &other) const;
std::size_t hash(); std::size_t hash();
static AnalysisContextPtr DummyContext(); static AnalysisContextPtr DummyContext();
@ -67,7 +64,11 @@ class AnalysisContext {
AnalysisContextPtr parent_; AnalysisContextPtr parent_;
FuncGraphPtr func_graph_; FuncGraphPtr func_graph_;
AbstractBasePtrList args_spec_list_; AbstractBasePtrList args_spec_list_;
std::unordered_map<FuncGraphPtr, AnalysisContextWeakPtr> parent_cache_; // Record all created context for each func graph.
// `extant_context_cache_` is copied from its parent context.
std::unordered_map<FuncGraphPtr, AnalysisContextWeakPtr> extant_context_cache_;
// Record all created child contexts from this context.
// Like: key: [func_graph & arguments], value: [child_context]
std::unordered_map<FuncGraphPtr, ArgsSpecToAnalysisContextMap> children_cache_; std::unordered_map<FuncGraphPtr, ArgsSpecToAnalysisContextMap> children_cache_;
}; };

View File

@ -294,23 +294,23 @@ TEST_F(TestInferGraph, test_context) {
AnalysisContextPtr dummy_context = AnalysisContext::DummyContext(); AnalysisContextPtr dummy_context = AnalysisContext::DummyContext();
AnalysisContextPtr f_context = dummy_context->NewFuncGraphContext(graph_f_, AbstractBasePtrList()); AnalysisContextPtr f_context = dummy_context->NewContext(graph_f_, AbstractBasePtrList());
ASSERT_TRUE(f_context->FindParentContext(graph_f_) = f_context); ASSERT_TRUE(f_context->FindOwnOrParentContext(graph_f_) = f_context);
ASSERT_TRUE(f_context->FindParentContext(nullptr) = dummy_context); ASSERT_TRUE(f_context->FindOwnOrParentContext(nullptr) = dummy_context);
AnalysisContextPtr g_context = f_context->NewFuncGraphContext(graph_g_, AbstractBasePtrList()); AnalysisContextPtr g_context = f_context->NewContext(graph_g_, AbstractBasePtrList());
ASSERT_TRUE(g_context->FindParentContext(graph_g_) = g_context); ASSERT_TRUE(g_context->FindOwnOrParentContext(graph_g_) = g_context);
ASSERT_TRUE(g_context->FindParentContext(graph_f_) = dummy_context); ASSERT_TRUE(g_context->FindOwnOrParentContext(graph_f_) = dummy_context);
ASSERT_TRUE(g_context->FindParentContext(nullptr) = dummy_context); ASSERT_TRUE(g_context->FindOwnOrParentContext(nullptr) = dummy_context);
AnalysisContextPtr alpha_context = dummy_context->NewFuncGraphContext(graph_alpha_, AbstractBasePtrList()); AnalysisContextPtr alpha_context = dummy_context->NewContext(graph_alpha_, AbstractBasePtrList());
ASSERT_TRUE(alpha_context->FindParentContext(graph_alpha_) = alpha_context); ASSERT_TRUE(alpha_context->FindOwnOrParentContext(graph_alpha_) = alpha_context);
ASSERT_TRUE(alpha_context->FindParentContext(nullptr) = dummy_context); ASSERT_TRUE(alpha_context->FindOwnOrParentContext(nullptr) = dummy_context);
AnalysisContextPtr beta_context = alpha_context->NewFuncGraphContext(graph_beta_, AbstractBasePtrList()); AnalysisContextPtr beta_context = alpha_context->NewContext(graph_beta_, AbstractBasePtrList());
ASSERT_TRUE(beta_context->FindParentContext(graph_beta_) = beta_context); ASSERT_TRUE(beta_context->FindOwnOrParentContext(graph_beta_) = beta_context);
ASSERT_TRUE(beta_context->FindParentContext(graph_alpha_) = alpha_context); ASSERT_TRUE(beta_context->FindOwnOrParentContext(graph_alpha_) = alpha_context);
ASSERT_TRUE(beta_context->FindParentContext(nullptr) = dummy_context); ASSERT_TRUE(beta_context->FindOwnOrParentContext(nullptr) = dummy_context);
} }
class TestInferMetaGraph : public UT::Common { class TestInferMetaGraph : public UT::Common {