!19434 Modify context searching and creating routine, not searching all the time.
Merge pull request !19434 from 张清华/opt0
This commit is contained in:
commit
db0a064cfd
|
@ -58,16 +58,6 @@ void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &,
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &engine,
|
|
||||||
const AbstractBasePtrList &args_spec_list) {
|
|
||||||
AbstractBasePtrList normalized_args_spec_list = NormalizeArgs(args_spec_list);
|
|
||||||
normalized_args_spec_list = BroadenUndeterminedArgs(normalized_args_spec_list);
|
|
||||||
FuncGraphPtr fg = GetFuncGraph(engine, normalized_args_spec_list);
|
|
||||||
MS_EXCEPTION_IF_NULL(parent_context_);
|
|
||||||
AnalysisContextPtr context = parent_context_->NewFuncGraphContext(fg, normalized_args_spec_list);
|
|
||||||
return context;
|
|
||||||
}
|
|
||||||
|
|
||||||
void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame,
|
void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame,
|
||||||
const StackFramePtr &new_stack_frame) {
|
const StackFramePtr &new_stack_frame) {
|
||||||
// Enter new func graph.
|
// Enter new func graph.
|
||||||
|
@ -216,7 +206,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
||||||
<< parent_context_->func_graph()->ToString() << "()->" << AnalysisResultCacheMgr::GetThreadid() << ":"
|
<< parent_context_->func_graph()->ToString() << "()->" << AnalysisResultCacheMgr::GetThreadid() << ":"
|
||||||
<< fg->ToString() << "();";
|
<< fg->ToString() << "();";
|
||||||
}
|
}
|
||||||
auto context = parent_context_->NewFuncGraphContext(fg, args_abs_list);
|
auto context = parent_context_->NewContext(fg, args_abs_list);
|
||||||
auto func_graph_evaluator = dyn_cast<FuncGraphEvaluator>(shared_from_base<BaseFuncGraphEvaluator>());
|
auto func_graph_evaluator = dyn_cast<FuncGraphEvaluator>(shared_from_base<BaseFuncGraphEvaluator>());
|
||||||
if (func_graph_evaluator != nullptr) {
|
if (func_graph_evaluator != nullptr) {
|
||||||
if (engine->root_func_graph() == func_graph_evaluator->func_graph()) {
|
if (engine->root_func_graph() == func_graph_evaluator->func_graph()) {
|
||||||
|
|
|
@ -234,7 +234,7 @@ class BaseFuncGraphEvaluator : public Evaluator {
|
||||||
class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
|
class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
|
||||||
public:
|
public:
|
||||||
FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context)
|
FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context)
|
||||||
: BaseFuncGraphEvaluator(context->FindParentContext(func_graph)), func_graph_(func_graph) {}
|
: BaseFuncGraphEvaluator(context), func_graph_(func_graph) {}
|
||||||
|
|
||||||
~FuncGraphEvaluator() override = default;
|
~FuncGraphEvaluator() override = default;
|
||||||
MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator);
|
MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator);
|
||||||
|
|
|
@ -466,7 +466,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &nod
|
||||||
if (func->context() == nullptr) {
|
if (func->context() == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
|
MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
|
||||||
}
|
}
|
||||||
AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals);
|
AnalysisContextPtr context = MakeContext(engine_, real_eval, argvals);
|
||||||
MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
|
MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
|
||||||
<< ", graph: " << context->func_graph()->get_return()->DebugString();
|
<< ", graph: " << context->func_graph()->get_return()->DebugString();
|
||||||
if (context->func_graph()->stub()) {
|
if (context->func_graph()->stub()) {
|
||||||
|
@ -480,6 +480,17 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &nod
|
||||||
return BuildValueNode(v, abs);
|
return BuildValueNode(v, abs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AnalysisContextPtr FuncGraphSpecializer::MakeContext(const AnalysisEnginePtr &engine,
|
||||||
|
const BaseFuncGraphEvaluatorPtr &evaluator,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
AbstractBasePtrList normalized_args_spec_list = evaluator->NormalizeArgs(args_spec_list);
|
||||||
|
normalized_args_spec_list = evaluator->BroadenUndeterminedArgs(normalized_args_spec_list);
|
||||||
|
FuncGraphPtr fg = evaluator->GetFuncGraph(engine, normalized_args_spec_list);
|
||||||
|
MS_EXCEPTION_IF_NULL(evaluator->parent_context());
|
||||||
|
AnalysisContextPtr new_context = evaluator->parent_context()->NewContext(fg, normalized_args_spec_list);
|
||||||
|
return new_context;
|
||||||
|
}
|
||||||
|
|
||||||
AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) {
|
AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) {
|
||||||
auto new_inputs = new_node->inputs();
|
auto new_inputs = new_node->inputs();
|
||||||
AnfNodePtr func = new_inputs[0];
|
AnfNodePtr func = new_inputs[0];
|
||||||
|
|
|
@ -41,6 +41,7 @@ enum SpecializeStatusCode {
|
||||||
};
|
};
|
||||||
|
|
||||||
class FuncGraphSpecializer;
|
class FuncGraphSpecializer;
|
||||||
|
using BaseFuncGraphEvaluatorPtr = std::shared_ptr<BaseFuncGraphEvaluator>;
|
||||||
|
|
||||||
// Specialize a func graph using analyzed abstract values.
|
// Specialize a func graph using analyzed abstract values.
|
||||||
class ProgramSpecializer {
|
class ProgramSpecializer {
|
||||||
|
@ -103,7 +104,10 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
|
||||||
void ProcessNode(const AnfNodePtr &node);
|
void ProcessNode(const AnfNodePtr &node);
|
||||||
void ProcessCNode(const CNodePtr &new_node);
|
void ProcessCNode(const CNodePtr &new_node);
|
||||||
|
|
||||||
AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node);
|
inline AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node);
|
||||||
|
inline AnalysisContextPtr MakeContext(const AnalysisEnginePtr &engine, const BaseFuncGraphEvaluatorPtr &evaluator,
|
||||||
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
|
||||||
inline void AddTodoItem(const AnfNodePtr &node) { todo_.push_back(node); }
|
inline void AddTodoItem(const AnfNodePtr &node) { todo_.push_back(node); }
|
||||||
// Get node replicated by Cloner.
|
// Get node replicated by Cloner.
|
||||||
AnfNodePtr GetReplicatedNode(const AnfNodePtr &node);
|
AnfNodePtr GetReplicatedNode(const AnfNodePtr &node);
|
||||||
|
|
|
@ -38,9 +38,8 @@ AnalysisContextPtr StackFrame::GetParentContext(const BaseFuncGraphEvaluatorPtr
|
||||||
const AbstractFunctionPtr &graph_func) {
|
const AbstractFunctionPtr &graph_func) {
|
||||||
AnalysisContextPtr parent_context = nullptr;
|
AnalysisContextPtr parent_context = nullptr;
|
||||||
auto func_graph_abs = dyn_cast<FuncGraphAbstractClosure>(graph_func);
|
auto func_graph_abs = dyn_cast<FuncGraphAbstractClosure>(graph_func);
|
||||||
if (func_graph_abs != nullptr) { // Find parent context for FuncGraphAbstractClosure.
|
if (func_graph_abs != nullptr) { // Set parent context for FuncGraphAbstractClosure.
|
||||||
auto branch_fg = func_graph_abs->func_graph();
|
parent_context = func_graph_abs->context();
|
||||||
parent_context = func_graph_abs->context()->FindParentContext(branch_fg);
|
|
||||||
} else if (graph_func->isa<MetaFuncGraphAbstractClosure>()) { // Or DummyContext for MetaFuncGraphAbstractClosure.
|
} else if (graph_func->isa<MetaFuncGraphAbstractClosure>()) { // Or DummyContext for MetaFuncGraphAbstractClosure.
|
||||||
parent_context = fg_evaluator->parent_context();
|
parent_context = fg_evaluator->parent_context();
|
||||||
if (parent_context == nullptr) {
|
if (parent_context == nullptr) {
|
||||||
|
@ -85,7 +84,7 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
|
||||||
// Find parent context and create new context.
|
// Find parent context and create new context.
|
||||||
AnalysisContextPtr parent_context = GetParentContext(fg_evaluator, graph_func);
|
AnalysisContextPtr parent_context = GetParentContext(fg_evaluator, graph_func);
|
||||||
MS_EXCEPTION_IF_NULL(parent_context);
|
MS_EXCEPTION_IF_NULL(parent_context);
|
||||||
auto new_context = parent_context->NewFuncGraphContext(fg, args_abs_list);
|
auto new_context = parent_context->NewContext(fg, args_abs_list);
|
||||||
|
|
||||||
// Evaluate the parameters with new context.
|
// Evaluate the parameters with new context.
|
||||||
for (size_t i = 0; i < nargs; i++) {
|
for (size_t i = 0; i < nargs; i++) {
|
||||||
|
|
|
@ -92,7 +92,7 @@ using ConfigPtrList = std::vector<ConfigPtr>;
|
||||||
class AnfNodeConfig : public Config {
|
class AnfNodeConfig : public Config {
|
||||||
public:
|
public:
|
||||||
AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context)
|
AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context)
|
||||||
: Config(), engine_(std::weak_ptr<AnalysisEngine>(engine)), node_(node) {
|
: Config(), engine_(std::weak_ptr<AnalysisEngine>(engine)), node_(node), context_(nullptr) {
|
||||||
FuncGraphPtr fg;
|
FuncGraphPtr fg;
|
||||||
if (IsValueNode<FuncGraph>(node)) {
|
if (IsValueNode<FuncGraph>(node)) {
|
||||||
auto v = node->cast<ValueNodePtr>();
|
auto v = node->cast<ValueNodePtr>();
|
||||||
|
@ -100,9 +100,17 @@ class AnfNodeConfig : public Config {
|
||||||
} else {
|
} else {
|
||||||
fg = node->func_graph();
|
fg = node->func_graph();
|
||||||
}
|
}
|
||||||
context_ = nullptr;
|
|
||||||
if (context != nullptr) {
|
if (context == nullptr) {
|
||||||
context_ = context->FindParentContext(fg);
|
return;
|
||||||
|
}
|
||||||
|
if (context->func_graph() == fg) {
|
||||||
|
// Usually `node` is CNode and not a FV, or top graph's ValueNodes.
|
||||||
|
context_ = context;
|
||||||
|
} else {
|
||||||
|
// If `node` is FV, FuncGraph, or other graph ValueNodes.
|
||||||
|
// Non-FuncGraph ValueNodes will always get a DummyContext since `fg` is null.
|
||||||
|
context_ = context->FindOwnOrParentContext(fg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,36 +23,17 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace abstract {
|
namespace abstract {
|
||||||
AnalysisContextPtr AnalysisContext::NewContext(AnalysisContextPtr parent_context, FuncGraphPtr fg,
|
AnalysisContextPtr AnalysisContext::NewContext(const FuncGraphPtr &func_graph,
|
||||||
const AbstractBasePtrList &args_spec_list) {
|
|
||||||
MS_EXCEPTION_IF_NULL(parent_context);
|
|
||||||
auto children_context_map_iter = parent_context->children_cache_.find(fg);
|
|
||||||
if (children_context_map_iter != parent_context->children_cache_.end()) {
|
|
||||||
auto children_context_map = children_context_map_iter->second;
|
|
||||||
auto children_context_iter = children_context_map.find(args_spec_list);
|
|
||||||
if (children_context_iter != children_context_map.end()) {
|
|
||||||
return children_context_iter->second.lock();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
AnalysisContextPtr new_context = std::make_shared<AnalysisContext>(parent_context, fg, args_spec_list);
|
|
||||||
// Reference to myself, so use weak_ptr to break reference cycle.
|
|
||||||
auto weak_context = std::weak_ptr<AnalysisContext>(new_context);
|
|
||||||
new_context->parent_cache_[fg] = weak_context;
|
|
||||||
parent_context->children_cache_[fg][args_spec_list] = weak_context;
|
|
||||||
return new_context;
|
|
||||||
}
|
|
||||||
|
|
||||||
AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func_graph,
|
|
||||||
const AbstractBasePtrList &args_spec_list) {
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Find func graph's parent and its parent context firstly.
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
FuncGraphPtr parent_graph = func_graph->parent();
|
FuncGraphPtr parent_graph = func_graph->parent();
|
||||||
AnalysisContextPtr parent_context = nullptr;
|
AnalysisContextPtr parent_context = nullptr;
|
||||||
auto iter = parent_cache_.find(parent_graph);
|
auto iter = extant_context_cache_.find(parent_graph);
|
||||||
if (iter != parent_cache_.end()) {
|
if (iter != extant_context_cache_.end()) {
|
||||||
parent_context = iter->second.lock();
|
parent_context = iter->second.lock();
|
||||||
}
|
}
|
||||||
// If this happen, it will be a bug in code. But we raise exception to keep the scene.
|
if (parent_context == nullptr) { // If parent context is not found, we'll raise exception.
|
||||||
if (parent_context == nullptr) {
|
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
oss << "BUG: Failed to find parent context in current context: " << this->ToString()
|
oss << "BUG: Failed to find parent context in current context: " << this->ToString()
|
||||||
<< ", func_graph: " << func_graph->ToString() << ", parent_graph: ";
|
<< ", func_graph: " << func_graph->ToString() << ", parent_graph: ";
|
||||||
|
@ -63,31 +44,48 @@ AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func
|
||||||
}
|
}
|
||||||
MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
|
MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
|
||||||
}
|
}
|
||||||
return NewContext(parent_context, func_graph, args_spec_list);
|
|
||||||
|
// Check if we created a context for func graph with the same arguments before.
|
||||||
|
auto children_context_map_iter = parent_context->children_cache_.find(func_graph);
|
||||||
|
if (children_context_map_iter != parent_context->children_cache_.end()) {
|
||||||
|
auto children_context_map = children_context_map_iter->second;
|
||||||
|
auto children_context_iter = children_context_map.find(args_spec_list);
|
||||||
|
if (children_context_iter != children_context_map.end()) {
|
||||||
|
return children_context_iter->second.lock();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
AnalysisContextPtr AnalysisContext::FindParentContext(const FuncGraphPtr &func_graph) {
|
// Create a new context for the func graph and its specific arguments.
|
||||||
auto p_iter = parent_cache_.find(func_graph);
|
AnalysisContextPtr new_context = std::make_shared<AnalysisContext>(parent_context, func_graph, args_spec_list);
|
||||||
AnalysisContextPtr parent_context = nullptr;
|
// To avoid cycle-reference, use weak_ptr here.
|
||||||
if (p_iter != parent_cache_.end()) {
|
auto weak_new_context = std::weak_ptr<AnalysisContext>(new_context);
|
||||||
parent_context = p_iter->second.lock();
|
new_context->extant_context_cache_[func_graph] = weak_new_context;
|
||||||
|
parent_context->children_cache_[func_graph][args_spec_list] = weak_new_context;
|
||||||
|
return new_context;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnalysisContextPtr AnalysisContext::FindOwnOrParentContext(const FuncGraphPtr &func_graph) {
|
||||||
|
auto p_iter = extant_context_cache_.find(func_graph);
|
||||||
|
AnalysisContextPtr extant_context = nullptr;
|
||||||
|
if (p_iter != extant_context_cache_.end()) {
|
||||||
|
extant_context = p_iter->second.lock();
|
||||||
} else {
|
} else {
|
||||||
auto iter_parent = parent_cache_.find(func_graph->parent());
|
auto iter_parent = extant_context_cache_.find(func_graph->parent());
|
||||||
if (iter_parent != parent_cache_.end()) {
|
if (iter_parent != extant_context_cache_.end()) {
|
||||||
parent_context = iter_parent->second.lock();
|
extant_context = iter_parent->second.lock();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// If this happen, it would be a bug in code. But we raise exception to keep the scene.
|
// If this happen, it would be a bug in code. But we raise exception to keep the scene.
|
||||||
if (parent_context == nullptr) {
|
if (extant_context == nullptr) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
oss << "BUG: Failed to find parent context for: " << func_graph->ToString() << ", parent_graph: ";
|
oss << "BUG: Failed to find context for: " << func_graph->ToString() << ", parent_graph: ";
|
||||||
if (func_graph->parent() != nullptr) {
|
if (func_graph->parent() != nullptr) {
|
||||||
oss << func_graph->parent()->ToString();
|
oss << func_graph->parent()->ToString();
|
||||||
} else {
|
} else {
|
||||||
oss << "nullptr";
|
oss << "nullptr";
|
||||||
}
|
}
|
||||||
oss << " parent_cache_: {";
|
oss << " extant context list: {";
|
||||||
for (auto iter : parent_cache_) {
|
for (auto iter : extant_context_cache_) {
|
||||||
if (iter.first == nullptr) {
|
if (iter.first == nullptr) {
|
||||||
oss << " [graph: nullptr";
|
oss << " [graph: nullptr";
|
||||||
} else {
|
} else {
|
||||||
|
@ -100,12 +98,12 @@ AnalysisContextPtr AnalysisContext::FindParentContext(const FuncGraphPtr &func_g
|
||||||
oss << "}";
|
oss << "}";
|
||||||
MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
|
MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
|
||||||
}
|
}
|
||||||
return parent_context;
|
return extant_context;
|
||||||
}
|
}
|
||||||
|
|
||||||
AnalysisContextPtr AnalysisContext::DummyContext() {
|
AnalysisContextPtr AnalysisContext::DummyContext() {
|
||||||
AnalysisContextPtr dummy_context = std::make_shared<AnalysisContext>(nullptr, nullptr, AbstractBasePtrList());
|
AnalysisContextPtr dummy_context = std::make_shared<AnalysisContext>(nullptr, nullptr, AbstractBasePtrList());
|
||||||
dummy_context->parent_cache_[nullptr] = std::weak_ptr<AnalysisContext>(dummy_context);
|
dummy_context->extant_context_cache_[nullptr] = std::weak_ptr<AnalysisContext>(dummy_context);
|
||||||
return dummy_context;
|
return dummy_context;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -39,20 +39,17 @@ class AnalysisContext {
|
||||||
AnalysisContext(const AnalysisContextPtr &parent, const FuncGraphPtr &fg, const AbstractBasePtrList &args_spec_list)
|
AnalysisContext(const AnalysisContextPtr &parent, const FuncGraphPtr &fg, const AbstractBasePtrList &args_spec_list)
|
||||||
: parent_(parent), func_graph_(fg), args_spec_list_(args_spec_list) {
|
: parent_(parent), func_graph_(fg), args_spec_list_(args_spec_list) {
|
||||||
if (parent_ != nullptr) {
|
if (parent_ != nullptr) {
|
||||||
parent_cache_ = parent_->parent_cache_;
|
extant_context_cache_ = parent_->extant_context_cache_;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
~AnalysisContext() = default;
|
~AnalysisContext() = default;
|
||||||
|
|
||||||
// Helper function to wrapper constructor to save shared_ptr in parent_cache.
|
|
||||||
AnalysisContextPtr NewContext(AnalysisContextPtr parent, FuncGraphPtr fg, const AbstractBasePtrList &args_spec_list);
|
|
||||||
|
|
||||||
// Extend this context with values for another graph.
|
// Extend this context with values for another graph.
|
||||||
AnalysisContextPtr NewFuncGraphContext(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
|
AnalysisContextPtr NewContext(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
|
||||||
|
|
||||||
// Return a context restricted to a graph's dependencies.
|
// Return a context restricted to a graph and its parent.
|
||||||
AnalysisContextPtr FindParentContext(const FuncGraphPtr &graph);
|
AnalysisContextPtr FindOwnOrParentContext(const FuncGraphPtr &graph);
|
||||||
bool operator==(const AnalysisContext &other) const;
|
bool operator==(const AnalysisContext &other) const;
|
||||||
std::size_t hash();
|
std::size_t hash();
|
||||||
static AnalysisContextPtr DummyContext();
|
static AnalysisContextPtr DummyContext();
|
||||||
|
@ -67,7 +64,11 @@ class AnalysisContext {
|
||||||
AnalysisContextPtr parent_;
|
AnalysisContextPtr parent_;
|
||||||
FuncGraphPtr func_graph_;
|
FuncGraphPtr func_graph_;
|
||||||
AbstractBasePtrList args_spec_list_;
|
AbstractBasePtrList args_spec_list_;
|
||||||
std::unordered_map<FuncGraphPtr, AnalysisContextWeakPtr> parent_cache_;
|
// Record all created context for each func graph.
|
||||||
|
// `extant_context_cache_` is copied from its parent context.
|
||||||
|
std::unordered_map<FuncGraphPtr, AnalysisContextWeakPtr> extant_context_cache_;
|
||||||
|
// Record all created child contexts from this context.
|
||||||
|
// Like: key: [func_graph & arguments], value: [child_context]
|
||||||
std::unordered_map<FuncGraphPtr, ArgsSpecToAnalysisContextMap> children_cache_;
|
std::unordered_map<FuncGraphPtr, ArgsSpecToAnalysisContextMap> children_cache_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -294,23 +294,23 @@ TEST_F(TestInferGraph, test_context) {
|
||||||
|
|
||||||
AnalysisContextPtr dummy_context = AnalysisContext::DummyContext();
|
AnalysisContextPtr dummy_context = AnalysisContext::DummyContext();
|
||||||
|
|
||||||
AnalysisContextPtr f_context = dummy_context->NewFuncGraphContext(graph_f_, AbstractBasePtrList());
|
AnalysisContextPtr f_context = dummy_context->NewContext(graph_f_, AbstractBasePtrList());
|
||||||
ASSERT_TRUE(f_context->FindParentContext(graph_f_) = f_context);
|
ASSERT_TRUE(f_context->FindOwnOrParentContext(graph_f_) = f_context);
|
||||||
ASSERT_TRUE(f_context->FindParentContext(nullptr) = dummy_context);
|
ASSERT_TRUE(f_context->FindOwnOrParentContext(nullptr) = dummy_context);
|
||||||
|
|
||||||
AnalysisContextPtr g_context = f_context->NewFuncGraphContext(graph_g_, AbstractBasePtrList());
|
AnalysisContextPtr g_context = f_context->NewContext(graph_g_, AbstractBasePtrList());
|
||||||
ASSERT_TRUE(g_context->FindParentContext(graph_g_) = g_context);
|
ASSERT_TRUE(g_context->FindOwnOrParentContext(graph_g_) = g_context);
|
||||||
ASSERT_TRUE(g_context->FindParentContext(graph_f_) = dummy_context);
|
ASSERT_TRUE(g_context->FindOwnOrParentContext(graph_f_) = dummy_context);
|
||||||
ASSERT_TRUE(g_context->FindParentContext(nullptr) = dummy_context);
|
ASSERT_TRUE(g_context->FindOwnOrParentContext(nullptr) = dummy_context);
|
||||||
|
|
||||||
AnalysisContextPtr alpha_context = dummy_context->NewFuncGraphContext(graph_alpha_, AbstractBasePtrList());
|
AnalysisContextPtr alpha_context = dummy_context->NewContext(graph_alpha_, AbstractBasePtrList());
|
||||||
ASSERT_TRUE(alpha_context->FindParentContext(graph_alpha_) = alpha_context);
|
ASSERT_TRUE(alpha_context->FindOwnOrParentContext(graph_alpha_) = alpha_context);
|
||||||
ASSERT_TRUE(alpha_context->FindParentContext(nullptr) = dummy_context);
|
ASSERT_TRUE(alpha_context->FindOwnOrParentContext(nullptr) = dummy_context);
|
||||||
|
|
||||||
AnalysisContextPtr beta_context = alpha_context->NewFuncGraphContext(graph_beta_, AbstractBasePtrList());
|
AnalysisContextPtr beta_context = alpha_context->NewContext(graph_beta_, AbstractBasePtrList());
|
||||||
ASSERT_TRUE(beta_context->FindParentContext(graph_beta_) = beta_context);
|
ASSERT_TRUE(beta_context->FindOwnOrParentContext(graph_beta_) = beta_context);
|
||||||
ASSERT_TRUE(beta_context->FindParentContext(graph_alpha_) = alpha_context);
|
ASSERT_TRUE(beta_context->FindOwnOrParentContext(graph_alpha_) = alpha_context);
|
||||||
ASSERT_TRUE(beta_context->FindParentContext(nullptr) = dummy_context);
|
ASSERT_TRUE(beta_context->FindOwnOrParentContext(nullptr) = dummy_context);
|
||||||
}
|
}
|
||||||
|
|
||||||
class TestInferMetaGraph : public UT::Common {
|
class TestInferMetaGraph : public UT::Common {
|
||||||
|
|
Loading…
Reference in New Issue