From b37a0cef63940247b9659dd3e4814cf37028769b Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Mon, 5 Jul 2021 17:19:29 +0800 Subject: [PATCH] Modify context searching and creating routine. --- .../pipeline/jit/static_analysis/evaluator.cc | 12 +-- .../pipeline/jit/static_analysis/evaluator.h | 2 +- .../jit/static_analysis/program_specialize.cc | 13 +++- .../jit/static_analysis/program_specialize.h | 6 +- .../jit/static_analysis/stack_frame.cc | 7 +- .../jit/static_analysis/static_analysis.h | 16 +++- mindspore/core/abstract/analysis_context.cc | 76 +++++++++---------- mindspore/core/abstract/analysis_context.h | 17 +++-- .../static_analysis/static_analysis_test.cc | 28 +++---- 9 files changed, 94 insertions(+), 83 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index 08ef601c04f..665e4088ad4 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -58,16 +58,6 @@ void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, } } // namespace -AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &engine, - const AbstractBasePtrList &args_spec_list) { - AbstractBasePtrList normalized_args_spec_list = NormalizeArgs(args_spec_list); - normalized_args_spec_list = BroadenUndeterminedArgs(normalized_args_spec_list); - FuncGraphPtr fg = GetFuncGraph(engine, normalized_args_spec_list); - MS_EXCEPTION_IF_NULL(parent_context_); - AnalysisContextPtr context = parent_context_->NewFuncGraphContext(fg, normalized_args_spec_list); - return context; -} - void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame, const StackFramePtr &new_stack_frame) { // Enter new func graph. @@ -216,7 +206,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr << parent_context_->func_graph()->ToString() << "()->" << AnalysisResultCacheMgr::GetThreadid() << ":" << fg->ToString() << "();"; } - auto context = parent_context_->NewFuncGraphContext(fg, args_abs_list); + auto context = parent_context_->NewContext(fg, args_abs_list); auto func_graph_evaluator = dyn_cast(shared_from_base()); if (func_graph_evaluator != nullptr) { if (engine->root_func_graph() == func_graph_evaluator->func_graph()) { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h index a588163bdb1..59836101e83 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h @@ -234,7 +234,7 @@ class BaseFuncGraphEvaluator : public Evaluator { class FuncGraphEvaluator : public BaseFuncGraphEvaluator { public: FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context) - : BaseFuncGraphEvaluator(context->FindParentContext(func_graph)), func_graph_(func_graph) {} + : BaseFuncGraphEvaluator(context), func_graph_(func_graph) {} ~FuncGraphEvaluator() override = default; MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc index 9c73dbe4110..bac1bc8d12c 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc @@ -466,7 +466,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &nod if (func->context() == nullptr) { MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info()); } - AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals); + AnalysisContextPtr context = MakeContext(engine_, real_eval, argvals); MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size() << ", graph: " << context->func_graph()->get_return()->DebugString(); if (context->func_graph()->stub()) { @@ -480,6 +480,17 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &nod return BuildValueNode(v, abs); } +AnalysisContextPtr FuncGraphSpecializer::MakeContext(const AnalysisEnginePtr &engine, + const BaseFuncGraphEvaluatorPtr &evaluator, + const AbstractBasePtrList &args_spec_list) { + AbstractBasePtrList normalized_args_spec_list = evaluator->NormalizeArgs(args_spec_list); + normalized_args_spec_list = evaluator->BroadenUndeterminedArgs(normalized_args_spec_list); + FuncGraphPtr fg = evaluator->GetFuncGraph(engine, normalized_args_spec_list); + MS_EXCEPTION_IF_NULL(evaluator->parent_context()); + AnalysisContextPtr new_context = evaluator->parent_context()->NewContext(fg, normalized_args_spec_list); + return new_context; +} + AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) { auto new_inputs = new_node->inputs(); AnfNodePtr func = new_inputs[0]; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h index 50a9204c291..d22333c223a 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h @@ -41,6 +41,7 @@ enum SpecializeStatusCode { }; class FuncGraphSpecializer; +using BaseFuncGraphEvaluatorPtr = std::shared_ptr; // Specialize a func graph using analyzed abstract values. class ProgramSpecializer { @@ -103,7 +104,10 @@ class FuncGraphSpecializer : public std::enable_shared_from_this(graph_func); - if (func_graph_abs != nullptr) { // Find parent context for FuncGraphAbstractClosure. - auto branch_fg = func_graph_abs->func_graph(); - parent_context = func_graph_abs->context()->FindParentContext(branch_fg); + if (func_graph_abs != nullptr) { // Set parent context for FuncGraphAbstractClosure. + parent_context = func_graph_abs->context(); } else if (graph_func->isa()) { // Or DummyContext for MetaFuncGraphAbstractClosure. parent_context = fg_evaluator->parent_context(); if (parent_context == nullptr) { @@ -85,7 +84,7 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr // Find parent context and create new context. AnalysisContextPtr parent_context = GetParentContext(fg_evaluator, graph_func); MS_EXCEPTION_IF_NULL(parent_context); - auto new_context = parent_context->NewFuncGraphContext(fg, args_abs_list); + auto new_context = parent_context->NewContext(fg, args_abs_list); // Evaluate the parameters with new context. for (size_t i = 0; i < nargs; i++) { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h index 12ffc23b0ca..18057ca80fd 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h @@ -92,7 +92,7 @@ using ConfigPtrList = std::vector; class AnfNodeConfig : public Config { public: AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context) - : Config(), engine_(std::weak_ptr(engine)), node_(node) { + : Config(), engine_(std::weak_ptr(engine)), node_(node), context_(nullptr) { FuncGraphPtr fg; if (IsValueNode(node)) { auto v = node->cast(); @@ -100,9 +100,17 @@ class AnfNodeConfig : public Config { } else { fg = node->func_graph(); } - context_ = nullptr; - if (context != nullptr) { - context_ = context->FindParentContext(fg); + + if (context == nullptr) { + return; + } + if (context->func_graph() == fg) { + // Usually `node` is CNode and not a FV, or top graph's ValueNodes. + context_ = context; + } else { + // If `node` is FV, FuncGraph, or other graph ValueNodes. + // Non-FuncGraph ValueNodes will always get a DummyContext since `fg` is null. + context_ = context->FindOwnOrParentContext(fg); } } diff --git a/mindspore/core/abstract/analysis_context.cc b/mindspore/core/abstract/analysis_context.cc index 11b9012880f..5f52a9ff2ac 100644 --- a/mindspore/core/abstract/analysis_context.cc +++ b/mindspore/core/abstract/analysis_context.cc @@ -23,36 +23,17 @@ namespace mindspore { namespace abstract { -AnalysisContextPtr AnalysisContext::NewContext(AnalysisContextPtr parent_context, FuncGraphPtr fg, +AnalysisContextPtr AnalysisContext::NewContext(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list) { - MS_EXCEPTION_IF_NULL(parent_context); - auto children_context_map_iter = parent_context->children_cache_.find(fg); - if (children_context_map_iter != parent_context->children_cache_.end()) { - auto children_context_map = children_context_map_iter->second; - auto children_context_iter = children_context_map.find(args_spec_list); - if (children_context_iter != children_context_map.end()) { - return children_context_iter->second.lock(); - } - } - AnalysisContextPtr new_context = std::make_shared(parent_context, fg, args_spec_list); - // Reference to myself, so use weak_ptr to break reference cycle. - auto weak_context = std::weak_ptr(new_context); - new_context->parent_cache_[fg] = weak_context; - parent_context->children_cache_[fg][args_spec_list] = weak_context; - return new_context; -} - -AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func_graph, - const AbstractBasePtrList &args_spec_list) { + // Find func graph's parent and its parent context firstly. MS_EXCEPTION_IF_NULL(func_graph); FuncGraphPtr parent_graph = func_graph->parent(); AnalysisContextPtr parent_context = nullptr; - auto iter = parent_cache_.find(parent_graph); - if (iter != parent_cache_.end()) { + auto iter = extant_context_cache_.find(parent_graph); + if (iter != extant_context_cache_.end()) { parent_context = iter->second.lock(); } - // If this happen, it will be a bug in code. But we raise exception to keep the scene. - if (parent_context == nullptr) { + if (parent_context == nullptr) { // If parent context is not found, we'll raise exception. std::ostringstream oss; oss << "BUG: Failed to find parent context in current context: " << this->ToString() << ", func_graph: " << func_graph->ToString() << ", parent_graph: "; @@ -63,31 +44,48 @@ AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func } MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); } - return NewContext(parent_context, func_graph, args_spec_list); + + // Check if we created a context for func graph with the same arguments before. + auto children_context_map_iter = parent_context->children_cache_.find(func_graph); + if (children_context_map_iter != parent_context->children_cache_.end()) { + auto children_context_map = children_context_map_iter->second; + auto children_context_iter = children_context_map.find(args_spec_list); + if (children_context_iter != children_context_map.end()) { + return children_context_iter->second.lock(); + } + } + + // Create a new context for the func graph and its specific arguments. + AnalysisContextPtr new_context = std::make_shared(parent_context, func_graph, args_spec_list); + // To avoid cycle-reference, use weak_ptr here. + auto weak_new_context = std::weak_ptr(new_context); + new_context->extant_context_cache_[func_graph] = weak_new_context; + parent_context->children_cache_[func_graph][args_spec_list] = weak_new_context; + return new_context; } -AnalysisContextPtr AnalysisContext::FindParentContext(const FuncGraphPtr &func_graph) { - auto p_iter = parent_cache_.find(func_graph); - AnalysisContextPtr parent_context = nullptr; - if (p_iter != parent_cache_.end()) { - parent_context = p_iter->second.lock(); +AnalysisContextPtr AnalysisContext::FindOwnOrParentContext(const FuncGraphPtr &func_graph) { + auto p_iter = extant_context_cache_.find(func_graph); + AnalysisContextPtr extant_context = nullptr; + if (p_iter != extant_context_cache_.end()) { + extant_context = p_iter->second.lock(); } else { - auto iter_parent = parent_cache_.find(func_graph->parent()); - if (iter_parent != parent_cache_.end()) { - parent_context = iter_parent->second.lock(); + auto iter_parent = extant_context_cache_.find(func_graph->parent()); + if (iter_parent != extant_context_cache_.end()) { + extant_context = iter_parent->second.lock(); } } // If this happen, it would be a bug in code. But we raise exception to keep the scene. - if (parent_context == nullptr) { + if (extant_context == nullptr) { std::ostringstream oss; - oss << "BUG: Failed to find parent context for: " << func_graph->ToString() << ", parent_graph: "; + oss << "BUG: Failed to find context for: " << func_graph->ToString() << ", parent_graph: "; if (func_graph->parent() != nullptr) { oss << func_graph->parent()->ToString(); } else { oss << "nullptr"; } - oss << " parent_cache_: {"; - for (auto iter : parent_cache_) { + oss << " extant context list: {"; + for (auto iter : extant_context_cache_) { if (iter.first == nullptr) { oss << " [graph: nullptr"; } else { @@ -100,12 +98,12 @@ AnalysisContextPtr AnalysisContext::FindParentContext(const FuncGraphPtr &func_g oss << "}"; MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); } - return parent_context; + return extant_context; } AnalysisContextPtr AnalysisContext::DummyContext() { AnalysisContextPtr dummy_context = std::make_shared(nullptr, nullptr, AbstractBasePtrList()); - dummy_context->parent_cache_[nullptr] = std::weak_ptr(dummy_context); + dummy_context->extant_context_cache_[nullptr] = std::weak_ptr(dummy_context); return dummy_context; } diff --git a/mindspore/core/abstract/analysis_context.h b/mindspore/core/abstract/analysis_context.h index b2e94ed5567..e097888ebc7 100644 --- a/mindspore/core/abstract/analysis_context.h +++ b/mindspore/core/abstract/analysis_context.h @@ -39,20 +39,17 @@ class AnalysisContext { AnalysisContext(const AnalysisContextPtr &parent, const FuncGraphPtr &fg, const AbstractBasePtrList &args_spec_list) : parent_(parent), func_graph_(fg), args_spec_list_(args_spec_list) { if (parent_ != nullptr) { - parent_cache_ = parent_->parent_cache_; + extant_context_cache_ = parent_->extant_context_cache_; } } ~AnalysisContext() = default; - // Helper function to wrapper constructor to save shared_ptr in parent_cache. - AnalysisContextPtr NewContext(AnalysisContextPtr parent, FuncGraphPtr fg, const AbstractBasePtrList &args_spec_list); - // Extend this context with values for another graph. - AnalysisContextPtr NewFuncGraphContext(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); + AnalysisContextPtr NewContext(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); - // Return a context restricted to a graph's dependencies. - AnalysisContextPtr FindParentContext(const FuncGraphPtr &graph); + // Return a context restricted to a graph and its parent. + AnalysisContextPtr FindOwnOrParentContext(const FuncGraphPtr &graph); bool operator==(const AnalysisContext &other) const; std::size_t hash(); static AnalysisContextPtr DummyContext(); @@ -67,7 +64,11 @@ class AnalysisContext { AnalysisContextPtr parent_; FuncGraphPtr func_graph_; AbstractBasePtrList args_spec_list_; - std::unordered_map parent_cache_; + // Record all created context for each func graph. + // `extant_context_cache_` is copied from its parent context. + std::unordered_map extant_context_cache_; + // Record all created child contexts from this context. + // Like: key: [func_graph & arguments], value: [child_context] std::unordered_map children_cache_; }; diff --git a/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc b/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc index d4d16ceb920..dd01ebb18dc 100644 --- a/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc @@ -294,23 +294,23 @@ TEST_F(TestInferGraph, test_context) { AnalysisContextPtr dummy_context = AnalysisContext::DummyContext(); - AnalysisContextPtr f_context = dummy_context->NewFuncGraphContext(graph_f_, AbstractBasePtrList()); - ASSERT_TRUE(f_context->FindParentContext(graph_f_) = f_context); - ASSERT_TRUE(f_context->FindParentContext(nullptr) = dummy_context); + AnalysisContextPtr f_context = dummy_context->NewContext(graph_f_, AbstractBasePtrList()); + ASSERT_TRUE(f_context->FindOwnOrParentContext(graph_f_) = f_context); + ASSERT_TRUE(f_context->FindOwnOrParentContext(nullptr) = dummy_context); - AnalysisContextPtr g_context = f_context->NewFuncGraphContext(graph_g_, AbstractBasePtrList()); - ASSERT_TRUE(g_context->FindParentContext(graph_g_) = g_context); - ASSERT_TRUE(g_context->FindParentContext(graph_f_) = dummy_context); - ASSERT_TRUE(g_context->FindParentContext(nullptr) = dummy_context); + AnalysisContextPtr g_context = f_context->NewContext(graph_g_, AbstractBasePtrList()); + ASSERT_TRUE(g_context->FindOwnOrParentContext(graph_g_) = g_context); + ASSERT_TRUE(g_context->FindOwnOrParentContext(graph_f_) = dummy_context); + ASSERT_TRUE(g_context->FindOwnOrParentContext(nullptr) = dummy_context); - AnalysisContextPtr alpha_context = dummy_context->NewFuncGraphContext(graph_alpha_, AbstractBasePtrList()); - ASSERT_TRUE(alpha_context->FindParentContext(graph_alpha_) = alpha_context); - ASSERT_TRUE(alpha_context->FindParentContext(nullptr) = dummy_context); + AnalysisContextPtr alpha_context = dummy_context->NewContext(graph_alpha_, AbstractBasePtrList()); + ASSERT_TRUE(alpha_context->FindOwnOrParentContext(graph_alpha_) = alpha_context); + ASSERT_TRUE(alpha_context->FindOwnOrParentContext(nullptr) = dummy_context); - AnalysisContextPtr beta_context = alpha_context->NewFuncGraphContext(graph_beta_, AbstractBasePtrList()); - ASSERT_TRUE(beta_context->FindParentContext(graph_beta_) = beta_context); - ASSERT_TRUE(beta_context->FindParentContext(graph_alpha_) = alpha_context); - ASSERT_TRUE(beta_context->FindParentContext(nullptr) = dummy_context); + AnalysisContextPtr beta_context = alpha_context->NewContext(graph_beta_, AbstractBasePtrList()); + ASSERT_TRUE(beta_context->FindOwnOrParentContext(graph_beta_) = beta_context); + ASSERT_TRUE(beta_context->FindOwnOrParentContext(graph_alpha_) = alpha_context); + ASSERT_TRUE(beta_context->FindOwnOrParentContext(nullptr) = dummy_context); } class TestInferMetaGraph : public UT::Common {