!19434 Modify context searching and creating routine, not searching all the time.
Merge pull request !19434 from 张清华/opt0
This commit is contained in:
commit
db0a064cfd
|
@ -58,16 +58,6 @@ void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &,
|
|||
}
|
||||
} // namespace
|
||||
|
||||
AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &engine,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
AbstractBasePtrList normalized_args_spec_list = NormalizeArgs(args_spec_list);
|
||||
normalized_args_spec_list = BroadenUndeterminedArgs(normalized_args_spec_list);
|
||||
FuncGraphPtr fg = GetFuncGraph(engine, normalized_args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(parent_context_);
|
||||
AnalysisContextPtr context = parent_context_->NewFuncGraphContext(fg, normalized_args_spec_list);
|
||||
return context;
|
||||
}
|
||||
|
||||
void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame,
|
||||
const StackFramePtr &new_stack_frame) {
|
||||
// Enter new func graph.
|
||||
|
@ -216,7 +206,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
<< parent_context_->func_graph()->ToString() << "()->" << AnalysisResultCacheMgr::GetThreadid() << ":"
|
||||
<< fg->ToString() << "();";
|
||||
}
|
||||
auto context = parent_context_->NewFuncGraphContext(fg, args_abs_list);
|
||||
auto context = parent_context_->NewContext(fg, args_abs_list);
|
||||
auto func_graph_evaluator = dyn_cast<FuncGraphEvaluator>(shared_from_base<BaseFuncGraphEvaluator>());
|
||||
if (func_graph_evaluator != nullptr) {
|
||||
if (engine->root_func_graph() == func_graph_evaluator->func_graph()) {
|
||||
|
|
|
@ -234,7 +234,7 @@ class BaseFuncGraphEvaluator : public Evaluator {
|
|||
class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
|
||||
public:
|
||||
FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context)
|
||||
: BaseFuncGraphEvaluator(context->FindParentContext(func_graph)), func_graph_(func_graph) {}
|
||||
: BaseFuncGraphEvaluator(context), func_graph_(func_graph) {}
|
||||
|
||||
~FuncGraphEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator);
|
||||
|
|
|
@ -466,7 +466,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &nod
|
|||
if (func->context() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
|
||||
}
|
||||
AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals);
|
||||
AnalysisContextPtr context = MakeContext(engine_, real_eval, argvals);
|
||||
MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
|
||||
<< ", graph: " << context->func_graph()->get_return()->DebugString();
|
||||
if (context->func_graph()->stub()) {
|
||||
|
@ -480,6 +480,17 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &nod
|
|||
return BuildValueNode(v, abs);
|
||||
}
|
||||
|
||||
AnalysisContextPtr FuncGraphSpecializer::MakeContext(const AnalysisEnginePtr &engine,
|
||||
const BaseFuncGraphEvaluatorPtr &evaluator,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
AbstractBasePtrList normalized_args_spec_list = evaluator->NormalizeArgs(args_spec_list);
|
||||
normalized_args_spec_list = evaluator->BroadenUndeterminedArgs(normalized_args_spec_list);
|
||||
FuncGraphPtr fg = evaluator->GetFuncGraph(engine, normalized_args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(evaluator->parent_context());
|
||||
AnalysisContextPtr new_context = evaluator->parent_context()->NewContext(fg, normalized_args_spec_list);
|
||||
return new_context;
|
||||
}
|
||||
|
||||
AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) {
|
||||
auto new_inputs = new_node->inputs();
|
||||
AnfNodePtr func = new_inputs[0];
|
||||
|
|
|
@ -41,6 +41,7 @@ enum SpecializeStatusCode {
|
|||
};
|
||||
|
||||
class FuncGraphSpecializer;
|
||||
using BaseFuncGraphEvaluatorPtr = std::shared_ptr<BaseFuncGraphEvaluator>;
|
||||
|
||||
// Specialize a func graph using analyzed abstract values.
|
||||
class ProgramSpecializer {
|
||||
|
@ -103,7 +104,10 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
|
|||
void ProcessNode(const AnfNodePtr &node);
|
||||
void ProcessCNode(const CNodePtr &new_node);
|
||||
|
||||
AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node);
|
||||
inline AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node);
|
||||
inline AnalysisContextPtr MakeContext(const AnalysisEnginePtr &engine, const BaseFuncGraphEvaluatorPtr &evaluator,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
inline void AddTodoItem(const AnfNodePtr &node) { todo_.push_back(node); }
|
||||
// Get node replicated by Cloner.
|
||||
AnfNodePtr GetReplicatedNode(const AnfNodePtr &node);
|
||||
|
|
|
@ -38,9 +38,8 @@ AnalysisContextPtr StackFrame::GetParentContext(const BaseFuncGraphEvaluatorPtr
|
|||
const AbstractFunctionPtr &graph_func) {
|
||||
AnalysisContextPtr parent_context = nullptr;
|
||||
auto func_graph_abs = dyn_cast<FuncGraphAbstractClosure>(graph_func);
|
||||
if (func_graph_abs != nullptr) { // Find parent context for FuncGraphAbstractClosure.
|
||||
auto branch_fg = func_graph_abs->func_graph();
|
||||
parent_context = func_graph_abs->context()->FindParentContext(branch_fg);
|
||||
if (func_graph_abs != nullptr) { // Set parent context for FuncGraphAbstractClosure.
|
||||
parent_context = func_graph_abs->context();
|
||||
} else if (graph_func->isa<MetaFuncGraphAbstractClosure>()) { // Or DummyContext for MetaFuncGraphAbstractClosure.
|
||||
parent_context = fg_evaluator->parent_context();
|
||||
if (parent_context == nullptr) {
|
||||
|
@ -85,7 +84,7 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
|
|||
// Find parent context and create new context.
|
||||
AnalysisContextPtr parent_context = GetParentContext(fg_evaluator, graph_func);
|
||||
MS_EXCEPTION_IF_NULL(parent_context);
|
||||
auto new_context = parent_context->NewFuncGraphContext(fg, args_abs_list);
|
||||
auto new_context = parent_context->NewContext(fg, args_abs_list);
|
||||
|
||||
// Evaluate the parameters with new context.
|
||||
for (size_t i = 0; i < nargs; i++) {
|
||||
|
|
|
@ -92,7 +92,7 @@ using ConfigPtrList = std::vector<ConfigPtr>;
|
|||
class AnfNodeConfig : public Config {
|
||||
public:
|
||||
AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context)
|
||||
: Config(), engine_(std::weak_ptr<AnalysisEngine>(engine)), node_(node) {
|
||||
: Config(), engine_(std::weak_ptr<AnalysisEngine>(engine)), node_(node), context_(nullptr) {
|
||||
FuncGraphPtr fg;
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
auto v = node->cast<ValueNodePtr>();
|
||||
|
@ -100,9 +100,17 @@ class AnfNodeConfig : public Config {
|
|||
} else {
|
||||
fg = node->func_graph();
|
||||
}
|
||||
context_ = nullptr;
|
||||
if (context != nullptr) {
|
||||
context_ = context->FindParentContext(fg);
|
||||
|
||||
if (context == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (context->func_graph() == fg) {
|
||||
// Usually `node` is CNode and not a FV, or top graph's ValueNodes.
|
||||
context_ = context;
|
||||
} else {
|
||||
// If `node` is FV, FuncGraph, or other graph ValueNodes.
|
||||
// Non-FuncGraph ValueNodes will always get a DummyContext since `fg` is null.
|
||||
context_ = context->FindOwnOrParentContext(fg);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -23,36 +23,17 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
AnalysisContextPtr AnalysisContext::NewContext(AnalysisContextPtr parent_context, FuncGraphPtr fg,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(parent_context);
|
||||
auto children_context_map_iter = parent_context->children_cache_.find(fg);
|
||||
if (children_context_map_iter != parent_context->children_cache_.end()) {
|
||||
auto children_context_map = children_context_map_iter->second;
|
||||
auto children_context_iter = children_context_map.find(args_spec_list);
|
||||
if (children_context_iter != children_context_map.end()) {
|
||||
return children_context_iter->second.lock();
|
||||
}
|
||||
}
|
||||
AnalysisContextPtr new_context = std::make_shared<AnalysisContext>(parent_context, fg, args_spec_list);
|
||||
// Reference to myself, so use weak_ptr to break reference cycle.
|
||||
auto weak_context = std::weak_ptr<AnalysisContext>(new_context);
|
||||
new_context->parent_cache_[fg] = weak_context;
|
||||
parent_context->children_cache_[fg][args_spec_list] = weak_context;
|
||||
return new_context;
|
||||
}
|
||||
|
||||
AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func_graph,
|
||||
AnalysisContextPtr AnalysisContext::NewContext(const FuncGraphPtr &func_graph,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Find func graph's parent and its parent context firstly.
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
FuncGraphPtr parent_graph = func_graph->parent();
|
||||
AnalysisContextPtr parent_context = nullptr;
|
||||
auto iter = parent_cache_.find(parent_graph);
|
||||
if (iter != parent_cache_.end()) {
|
||||
auto iter = extant_context_cache_.find(parent_graph);
|
||||
if (iter != extant_context_cache_.end()) {
|
||||
parent_context = iter->second.lock();
|
||||
}
|
||||
// If this happen, it will be a bug in code. But we raise exception to keep the scene.
|
||||
if (parent_context == nullptr) {
|
||||
if (parent_context == nullptr) { // If parent context is not found, we'll raise exception.
|
||||
std::ostringstream oss;
|
||||
oss << "BUG: Failed to find parent context in current context: " << this->ToString()
|
||||
<< ", func_graph: " << func_graph->ToString() << ", parent_graph: ";
|
||||
|
@ -63,31 +44,48 @@ AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func
|
|||
}
|
||||
MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
|
||||
}
|
||||
return NewContext(parent_context, func_graph, args_spec_list);
|
||||
|
||||
// Check if we created a context for func graph with the same arguments before.
|
||||
auto children_context_map_iter = parent_context->children_cache_.find(func_graph);
|
||||
if (children_context_map_iter != parent_context->children_cache_.end()) {
|
||||
auto children_context_map = children_context_map_iter->second;
|
||||
auto children_context_iter = children_context_map.find(args_spec_list);
|
||||
if (children_context_iter != children_context_map.end()) {
|
||||
return children_context_iter->second.lock();
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new context for the func graph and its specific arguments.
|
||||
AnalysisContextPtr new_context = std::make_shared<AnalysisContext>(parent_context, func_graph, args_spec_list);
|
||||
// To avoid cycle-reference, use weak_ptr here.
|
||||
auto weak_new_context = std::weak_ptr<AnalysisContext>(new_context);
|
||||
new_context->extant_context_cache_[func_graph] = weak_new_context;
|
||||
parent_context->children_cache_[func_graph][args_spec_list] = weak_new_context;
|
||||
return new_context;
|
||||
}
|
||||
|
||||
AnalysisContextPtr AnalysisContext::FindParentContext(const FuncGraphPtr &func_graph) {
|
||||
auto p_iter = parent_cache_.find(func_graph);
|
||||
AnalysisContextPtr parent_context = nullptr;
|
||||
if (p_iter != parent_cache_.end()) {
|
||||
parent_context = p_iter->second.lock();
|
||||
AnalysisContextPtr AnalysisContext::FindOwnOrParentContext(const FuncGraphPtr &func_graph) {
|
||||
auto p_iter = extant_context_cache_.find(func_graph);
|
||||
AnalysisContextPtr extant_context = nullptr;
|
||||
if (p_iter != extant_context_cache_.end()) {
|
||||
extant_context = p_iter->second.lock();
|
||||
} else {
|
||||
auto iter_parent = parent_cache_.find(func_graph->parent());
|
||||
if (iter_parent != parent_cache_.end()) {
|
||||
parent_context = iter_parent->second.lock();
|
||||
auto iter_parent = extant_context_cache_.find(func_graph->parent());
|
||||
if (iter_parent != extant_context_cache_.end()) {
|
||||
extant_context = iter_parent->second.lock();
|
||||
}
|
||||
}
|
||||
// If this happen, it would be a bug in code. But we raise exception to keep the scene.
|
||||
if (parent_context == nullptr) {
|
||||
if (extant_context == nullptr) {
|
||||
std::ostringstream oss;
|
||||
oss << "BUG: Failed to find parent context for: " << func_graph->ToString() << ", parent_graph: ";
|
||||
oss << "BUG: Failed to find context for: " << func_graph->ToString() << ", parent_graph: ";
|
||||
if (func_graph->parent() != nullptr) {
|
||||
oss << func_graph->parent()->ToString();
|
||||
} else {
|
||||
oss << "nullptr";
|
||||
}
|
||||
oss << " parent_cache_: {";
|
||||
for (auto iter : parent_cache_) {
|
||||
oss << " extant context list: {";
|
||||
for (auto iter : extant_context_cache_) {
|
||||
if (iter.first == nullptr) {
|
||||
oss << " [graph: nullptr";
|
||||
} else {
|
||||
|
@ -100,12 +98,12 @@ AnalysisContextPtr AnalysisContext::FindParentContext(const FuncGraphPtr &func_g
|
|||
oss << "}";
|
||||
MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
|
||||
}
|
||||
return parent_context;
|
||||
return extant_context;
|
||||
}
|
||||
|
||||
AnalysisContextPtr AnalysisContext::DummyContext() {
|
||||
AnalysisContextPtr dummy_context = std::make_shared<AnalysisContext>(nullptr, nullptr, AbstractBasePtrList());
|
||||
dummy_context->parent_cache_[nullptr] = std::weak_ptr<AnalysisContext>(dummy_context);
|
||||
dummy_context->extant_context_cache_[nullptr] = std::weak_ptr<AnalysisContext>(dummy_context);
|
||||
return dummy_context;
|
||||
}
|
||||
|
||||
|
|
|
@ -39,20 +39,17 @@ class AnalysisContext {
|
|||
AnalysisContext(const AnalysisContextPtr &parent, const FuncGraphPtr &fg, const AbstractBasePtrList &args_spec_list)
|
||||
: parent_(parent), func_graph_(fg), args_spec_list_(args_spec_list) {
|
||||
if (parent_ != nullptr) {
|
||||
parent_cache_ = parent_->parent_cache_;
|
||||
extant_context_cache_ = parent_->extant_context_cache_;
|
||||
}
|
||||
}
|
||||
|
||||
~AnalysisContext() = default;
|
||||
|
||||
// Helper function to wrapper constructor to save shared_ptr in parent_cache.
|
||||
AnalysisContextPtr NewContext(AnalysisContextPtr parent, FuncGraphPtr fg, const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
// Extend this context with values for another graph.
|
||||
AnalysisContextPtr NewFuncGraphContext(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
|
||||
AnalysisContextPtr NewContext(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
// Return a context restricted to a graph's dependencies.
|
||||
AnalysisContextPtr FindParentContext(const FuncGraphPtr &graph);
|
||||
// Return a context restricted to a graph and its parent.
|
||||
AnalysisContextPtr FindOwnOrParentContext(const FuncGraphPtr &graph);
|
||||
bool operator==(const AnalysisContext &other) const;
|
||||
std::size_t hash();
|
||||
static AnalysisContextPtr DummyContext();
|
||||
|
@ -67,7 +64,11 @@ class AnalysisContext {
|
|||
AnalysisContextPtr parent_;
|
||||
FuncGraphPtr func_graph_;
|
||||
AbstractBasePtrList args_spec_list_;
|
||||
std::unordered_map<FuncGraphPtr, AnalysisContextWeakPtr> parent_cache_;
|
||||
// Record all created context for each func graph.
|
||||
// `extant_context_cache_` is copied from its parent context.
|
||||
std::unordered_map<FuncGraphPtr, AnalysisContextWeakPtr> extant_context_cache_;
|
||||
// Record all created child contexts from this context.
|
||||
// Like: key: [func_graph & arguments], value: [child_context]
|
||||
std::unordered_map<FuncGraphPtr, ArgsSpecToAnalysisContextMap> children_cache_;
|
||||
};
|
||||
|
||||
|
|
|
@ -294,23 +294,23 @@ TEST_F(TestInferGraph, test_context) {
|
|||
|
||||
AnalysisContextPtr dummy_context = AnalysisContext::DummyContext();
|
||||
|
||||
AnalysisContextPtr f_context = dummy_context->NewFuncGraphContext(graph_f_, AbstractBasePtrList());
|
||||
ASSERT_TRUE(f_context->FindParentContext(graph_f_) = f_context);
|
||||
ASSERT_TRUE(f_context->FindParentContext(nullptr) = dummy_context);
|
||||
AnalysisContextPtr f_context = dummy_context->NewContext(graph_f_, AbstractBasePtrList());
|
||||
ASSERT_TRUE(f_context->FindOwnOrParentContext(graph_f_) = f_context);
|
||||
ASSERT_TRUE(f_context->FindOwnOrParentContext(nullptr) = dummy_context);
|
||||
|
||||
AnalysisContextPtr g_context = f_context->NewFuncGraphContext(graph_g_, AbstractBasePtrList());
|
||||
ASSERT_TRUE(g_context->FindParentContext(graph_g_) = g_context);
|
||||
ASSERT_TRUE(g_context->FindParentContext(graph_f_) = dummy_context);
|
||||
ASSERT_TRUE(g_context->FindParentContext(nullptr) = dummy_context);
|
||||
AnalysisContextPtr g_context = f_context->NewContext(graph_g_, AbstractBasePtrList());
|
||||
ASSERT_TRUE(g_context->FindOwnOrParentContext(graph_g_) = g_context);
|
||||
ASSERT_TRUE(g_context->FindOwnOrParentContext(graph_f_) = dummy_context);
|
||||
ASSERT_TRUE(g_context->FindOwnOrParentContext(nullptr) = dummy_context);
|
||||
|
||||
AnalysisContextPtr alpha_context = dummy_context->NewFuncGraphContext(graph_alpha_, AbstractBasePtrList());
|
||||
ASSERT_TRUE(alpha_context->FindParentContext(graph_alpha_) = alpha_context);
|
||||
ASSERT_TRUE(alpha_context->FindParentContext(nullptr) = dummy_context);
|
||||
AnalysisContextPtr alpha_context = dummy_context->NewContext(graph_alpha_, AbstractBasePtrList());
|
||||
ASSERT_TRUE(alpha_context->FindOwnOrParentContext(graph_alpha_) = alpha_context);
|
||||
ASSERT_TRUE(alpha_context->FindOwnOrParentContext(nullptr) = dummy_context);
|
||||
|
||||
AnalysisContextPtr beta_context = alpha_context->NewFuncGraphContext(graph_beta_, AbstractBasePtrList());
|
||||
ASSERT_TRUE(beta_context->FindParentContext(graph_beta_) = beta_context);
|
||||
ASSERT_TRUE(beta_context->FindParentContext(graph_alpha_) = alpha_context);
|
||||
ASSERT_TRUE(beta_context->FindParentContext(nullptr) = dummy_context);
|
||||
AnalysisContextPtr beta_context = alpha_context->NewContext(graph_beta_, AbstractBasePtrList());
|
||||
ASSERT_TRUE(beta_context->FindOwnOrParentContext(graph_beta_) = beta_context);
|
||||
ASSERT_TRUE(beta_context->FindOwnOrParentContext(graph_alpha_) = alpha_context);
|
||||
ASSERT_TRUE(beta_context->FindOwnOrParentContext(nullptr) = dummy_context);
|
||||
}
|
||||
|
||||
class TestInferMetaGraph : public UT::Common {
|
||||
|
|
Loading…
Reference in New Issue