Dump func graph where to call the node.

This commit is contained in:
Zhang Qinghua 2021-07-09 20:14:55 +08:00
parent 607ffdf63a
commit cbb2d17efb
8 changed files with 92 additions and 77 deletions

View File

@ -105,7 +105,7 @@ void DumpInferStack(std::ostringstream &oss) {
continue;
}
auto args_spec_list = context->args_spec_list();
oss << " #" << index++ << " " << GetGraphParamString(graph, args_spec_list);
oss << " #" << index++ << " " << GetGraphParamString(graph, args_spec_list) << "\n";
}
}
@ -128,7 +128,7 @@ class AnalyzeFailExporter : public AnfExporter {
AnalyzeFailExporter() : AnfExporter(true, false) {}
~AnalyzeFailExporter() override = default;
bool ExportFuncGraph(const std::string &filename, const std::vector<abstract::AnfNodeConfigPtr> &node_cfg_stack);
bool ExportFuncGraph(const std::string &filename, const std::vector<abstract::AnfNodeConfigPtr> &node_config_stack);
private:
void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph, const TaggedNodeMap &tagged_cnodes_map);
@ -139,23 +139,23 @@ class AnalyzeFailExporter : public AnfExporter {
std::string GetNodeType(const AnfNodePtr &nd) override;
AbstractBasePtr GetNodeAbstract(const AnfNodePtr &nd);
AnfNodeConfigPtr GetFordwardConfigPtr(const AnfNodeConfigPtr &cfg);
AnfNodeConfigPtr GetFordwardConfig(const AnfNodeConfigPtr &cfg);
void ProcessFuncGraphCall(const CNodePtr &node, std::string *const op_comment);
void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node);
AnalysisContextPtr cur_ctx_ = nullptr;
AnalysisContextPtr current_context_ = nullptr;
AnalysisEnginePtr engine_ = nullptr;
};
std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs(
const std::vector<abstract::AnfNodeConfigPtr> &node_cfg_stack) {
const std::vector<abstract::AnfNodeConfigPtr> &node_config_stack) {
std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs;
for (size_t i = 0; i < node_cfg_stack.size(); ++i) {
auto node_cfg = node_cfg_stack[i];
MS_EXCEPTION_IF_NULL(node_cfg);
auto fg = node_cfg->context()->func_graph();
for (size_t i = 0; i < node_config_stack.size(); ++i) {
auto node_config = node_config_stack[i];
MS_EXCEPTION_IF_NULL(node_config);
auto fg = node_config->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto node = node_cfg->node();
auto node = node_config->node();
tagged_func_graphs[fg][node] = i;
}
return tagged_func_graphs;
@ -167,12 +167,13 @@ bool OutputAnalyzedGraphWithType(const string &file_path) {
}
std::string AnalyzeFailExporter::GetNodeType(const AnfNodePtr &node) {
if (cur_ctx_ == nullptr) {
if (current_context_ == nullptr) {
return AnfExporter::GetNodeType(node);
}
MS_EXCEPTION_IF_NULL(engine_);
auto cfg = engine_->MakeConfig(node, cur_ctx_);
FuncGraphPtr dummy_call_func_graph = nullptr;
auto cfg = engine_->MakeConfig(node, current_context_, dummy_call_func_graph);
auto ret = abstract::AnalysisResultCacheMgr::GetInstance().GetValue(cfg);
if (ret == nullptr) {
return "Undefined";
@ -181,23 +182,24 @@ std::string AnalyzeFailExporter::GetNodeType(const AnfNodePtr &node) {
}
AbstractBasePtr AnalyzeFailExporter::GetNodeAbstract(const AnfNodePtr &node) {
if (cur_ctx_ == nullptr) {
if (current_context_ == nullptr) {
return nullptr;
}
MS_EXCEPTION_IF_NULL(engine_);
auto cfg = engine_->MakeConfig(node, cur_ctx_);
FuncGraphPtr dummy_call_func_graph = nullptr;
auto cfg = engine_->MakeConfig(node, current_context_, dummy_call_func_graph);
auto ret = abstract::AnalysisResultCacheMgr::GetInstance().GetValue(cfg);
return ret == nullptr ? nullptr : ret->abstract();
}
AnfNodeConfigPtr AnalyzeFailExporter::GetFordwardConfigPtr(const AnfNodeConfigPtr &cfg) {
AnfNodeConfigPtr AnalyzeFailExporter::GetFordwardConfig(const AnfNodeConfigPtr &cfg) {
AnfNodeConfigPtr cur_cfg = cfg;
auto iter = engine_->anfnode_config_map().find(cur_cfg);
while (iter != engine_->anfnode_config_map().end()) {
auto node = cur_cfg->node();
cur_cfg = iter->second;
MS_LOG(DEBUG) << "Get forword node: " << node.get() << "[" << node->ToString() << "] --> " << cur_cfg->node().get()
<< "[" << cur_cfg->node()->ToString() << "]";
MS_LOG(DEBUG) << "Get forword node: " << node << "[" << node->DebugString() << "] --> " << cur_cfg->node() << "["
<< cur_cfg->node()->DebugString() << "]";
iter = engine_->anfnode_config_map().find(cur_cfg);
}
return cur_cfg;
@ -205,13 +207,15 @@ AnfNodeConfigPtr AnalyzeFailExporter::GetFordwardConfigPtr(const AnfNodeConfigPt
void AnalyzeFailExporter::ProcessFuncGraphCall(const CNodePtr &node, std::string *const op_comment) {
if (node == nullptr) {
MS_LOG(ERROR) << "Node is nullptr";
return;
}
auto cfg = engine_->MakeConfig(node, cur_ctx_);
cfg = GetFordwardConfigPtr(cfg);
FuncGraphPtr dummy_call_func_graph = nullptr;
auto cfg = engine_->MakeConfig(node, current_context_, dummy_call_func_graph);
cfg = GetFordwardConfig(cfg);
auto cnode = dyn_cast<CNode>(cfg->node());
if (cnode == nullptr) {
MS_LOG(DEBUG) << "CNode is nullptr";
MS_LOG(ERROR) << "CNode is nullptr";
return;
}
@ -369,8 +373,8 @@ void AnalyzeFailExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraph
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map;
ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText();
if (cur_ctx_ != nullptr) {
ofs << " @ctx.addr=" << cur_ctx_.get();
if (current_context_ != nullptr) {
ofs << " @ctx.addr=" << current_context_.get();
}
ofs << "\n";
if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
@ -397,26 +401,24 @@ void AnalyzeFailExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraph
}
bool AnalyzeFailExporter::ExportFuncGraph(const std::string &filename,
const std::vector<abstract::AnfNodeConfigPtr> &node_cfg_stack) {
if (node_cfg_stack.empty()) {
const std::vector<abstract::AnfNodeConfigPtr> &node_config_stack) {
if (node_config_stack.empty()) {
MS_LOG(DEBUG) << "Node configs is empty";
return false;
}
std::ofstream ofs(filename);
if (!ofs.is_open()) {
MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
return false;
}
auto tagged_func_graphs = CalcTaggedFuncGraphs(node_cfg_stack);
auto tagged_func_graphs = CalcTaggedFuncGraphs(node_config_stack);
std::unordered_set<FuncGraphPtr> printed_func_graphs; // Check if func graph has been printed.
// Output graph on the analysis stack
for (const auto &node_cfg : node_cfg_stack) {
auto ctx = node_cfg->context();
auto fg = ctx->func_graph();
for (const auto &node_config : node_config_stack) {
auto fg = node_config->func_graph();
if (fg == nullptr) {
MS_LOG(ERROR) << "FuncGraph is null, context: " << node_cfg->context()->ToString();
MS_LOG(ERROR) << "FuncGraph is null, context: " << node_config->ToString();
continue;
}
if (printed_func_graphs.find(fg) != printed_func_graphs.end()) {
@ -425,16 +427,16 @@ bool AnalyzeFailExporter::ExportFuncGraph(const std::string &filename,
printed_func_graphs.emplace(fg);
if (engine_ == nullptr) {
engine_ = node_cfg->engine();
engine_ = node_config->engine();
}
cur_ctx_ = ctx; // Set current context.
current_context_ = node_config->context(); // Set current context.
ExportOneFuncGraph(ofs, fg, tagged_func_graphs[fg]);
ofs << "\n\n";
}
ofs << "#===============================================================================\n";
ofs << "# num of function graphs printed: " << printed_func_graphs.size() << "/" << node_cfg_stack.size() << "\n";
ofs << "# num of function graphs in stack: " << node_config_stack.size() << "\n";
ofs.close();
return true;
}
@ -468,9 +470,9 @@ void GetEvalStackInfo(std::ostringstream &oss) {
int index = 0;
std::string last_location_info = "";
for (size_t i = 0; i < stack.size(); ++i) {
auto node_cfg = stack[i];
auto node_config = stack[i];
auto cnode = dyn_cast<CNode>(node_cfg->node());
auto cnode = dyn_cast<CNode>(node_config->node());
if (cnode == nullptr) {
MS_LOG(DEBUG) << "CNode of elements[" << i << "] is nullptr.";
continue;
@ -513,7 +515,7 @@ void TraceGraphEvalLeave(const abstract::AnalysisContextPtr &context) {
graph_infer_stack.pop();
}
void TraceEvalCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); }
void TraceEvalCNodeEnter(const abstract::AnfNodeConfigPtr &node_config) { cnode_debug_stack.push_back(node_config); }
void TraceEvalCNodeLeave() { cnode_debug_stack.pop_back(); }

View File

@ -63,7 +63,7 @@ void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, co
// Enter new func graph.
auto &current_node = current_stack_frame->CurrentNode();
auto current_context = current_stack_frame->current_context();
AnfNodeConfigPtr call_conf = engine->MakeConfig(current_node, current_context);
AnfNodeConfigPtr call_conf = engine->MakeConfig(current_node, current_context, current_context->func_graph());
auto evaluator = new_stack_frame->evaluator();
MS_EXCEPTION_IF_NULL(evaluator);
auto new_context = new_stack_frame->current_context();
@ -158,7 +158,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine
});
AbstractBasePtr res_base = nullptr;
for (const auto &node : all_nodes) {
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, context);
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, context, fg);
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg << "/" << fg->ToString()
<< ", node_conf: " << node_conf->ToString();
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
@ -221,7 +221,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
for (size_t i = 0; i < nargs; i++) {
const auto &arg = args_abs_list[i];
const auto &node = parameters[i];
AnfNodeConfigPtr conf = engine->MakeConfig(node, context);
AnfNodeConfigPtr conf = engine->MakeConfig(node, context, fg);
engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg, nullptr));
MS_LOG(DEBUG) << GetInferThread() << "Set Param: " << conf->ToString() << " = " << arg->ToString();
}

View File

@ -96,7 +96,7 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
args_inputs);
}
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context());
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
if (out_node->isa<CNode>()) {
auto out_cnode = out_node->cast<CNodePtr>();
@ -181,7 +181,7 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
}
ScopeGuard scope_guard(scope);
AnfNodePtr new_vnode = NewValueNode(new_graph);
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context());
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context(), out_conf->func_graph());
return engine->ForwardConfig(out_conf, fn_conf);
}
@ -263,7 +263,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
constexpr size_t source_node_index = 2;
AnfNodePtr new_node =
MixedPrecisionCastHelper(out_node_inputs[source_node_index], args_spec_list[1], out_node_inputs[1], func_graph);
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context());
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
if (new_node->isa<CNode>()) {
auto new_cnode = new_node->cast<CNodePtr>();
@ -813,7 +813,7 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_
new_cnode = func_graph->NewCNode({new_cnode});
}
AnalysisEnginePtr eng = old_conf->engine();
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context());
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context(), old_conf->func_graph());
return eng->ForwardConfig(old_conf, fn_conf);
}
@ -859,7 +859,7 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engin
func_graph->ReplaceInOrder(out_node, new_node);
AnalysisEnginePtr eng = out_conf->engine();
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context());
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
return eng->ForwardConfig(out_conf, fn_conf);
}
@ -1277,7 +1277,7 @@ class PartialEvaluator : public Evaluator {
ScopeGuard scope_guard(scope);
CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs);
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context());
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
return engine->ForwardConfig(out_conf, fn_conf);
}
};

View File

@ -829,7 +829,7 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin
}
AnfNodeConfigPtr FuncGraphSpecializer::MakeConfig(const AnfNodePtr &node) {
return engine_->MakeConfig(node, context_);
return engine_->MakeConfig(node, context_, func_graph_); // `func_graph_` is dummy here.
}
} // namespace abstract
} // namespace mindspore

View File

@ -25,7 +25,7 @@ AbstractBasePtrList StackFrame::GenerateArgsAbsList(const AnalysisEnginePtr &eng
AbstractBasePtrList args_abs_list;
auto &inputs = current_cnode->inputs();
for (std::size_t i = 1; i < inputs.size(); i++) {
auto config = engine->MakeConfig(inputs[i], current_context_);
auto config = engine->MakeConfig(inputs[i], current_context_, current_context_->func_graph());
auto abs = config->ObtainEvalResult()->abstract();
args_abs_list.push_back(abs);
}
@ -90,7 +90,7 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
for (size_t i = 0; i < nargs; i++) {
const auto &arg_abs = args_abs_list[i];
const auto &node = fg->parameters()[i];
AnfNodeConfigPtr conf = engine->MakeConfig(node, new_context);
AnfNodeConfigPtr conf = engine->MakeConfig(node, new_context, new_context->func_graph());
engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg_abs, nullptr));
}
@ -107,14 +107,14 @@ StackFramePtr StackFrame::Jump(const AnalysisEnginePtr &engine) {
return nullptr;
}
auto cnode = current_node->cast<CNodePtr>();
auto maybe_func = engine->GetCNodeOperatorAbstract(cnode, current_context_);
auto maybe_func = engine->GetCNodeOperatorAbstract(cnode, current_context_, current_context_->func_graph());
if (!maybe_func->isa<abstract::MetaFuncGraphAbstractClosure>() &&
!maybe_func->isa<abstract::FuncGraphAbstractClosure>()) {
return nullptr; // Not call FuncGraph or MetaFuncGraph.
}
// It's FuncGraph Call or MetaFuncGraph Call. `maybe_func` is definitely a AbstractFunction.
AnfNodeConfigPtr call_node_conf = engine->MakeConfig(cnode, current_context_);
AnfNodeConfigPtr call_node_conf = engine->MakeConfig(cnode, current_context_, current_context_->func_graph());
// Enter the call CNode.
trace::TraceEvalCNodeEnter(call_node_conf);
auto res = DoJump(engine, cnode, dyn_cast<AbstractFunction>(maybe_func));
@ -129,7 +129,7 @@ EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) {
auto &current_node = NextNode();
MS_LOG(DEBUG) << "current_node: " << current_node->DebugString()
<< ", current_context_: " << current_context_->ToString();
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_);
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_, current_context_->func_graph());
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
MS_LOG(DEBUG) << GetInferThread() << "Eval(" << node_conf->ToString()
<< ") = " << node_eval_result->abstract()->ToString();
@ -153,7 +153,7 @@ void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last
auto &current_node = NextNode();
MS_LOG(DEBUG) << "current_node: " << current_node->DebugString()
<< ", current_context_: " << current_context_->ToString();
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_);
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_, current_context_->func_graph());
engine->SaveEvalResultInCache(node_conf, result);
// Leave the call CNode.

View File

@ -127,18 +127,17 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
MS_EXCEPTION_IF_NULL(func_graph_manager_);
func_graph_manager_->AddFuncGraph(func_graph);
root_func_graph_ = func_graph;
AnalysisContextPtr empty_context = AnalysisContext::DummyContext();
// Running the analyzer.
ResetFunctionCallDepth();
ResetStackFrameDepth();
AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list);
AnalysisContextPtr dummy_context = AnalysisContext::DummyContext();
AnalysisContextPtr root_context = Run(func_graph, dummy_context, args_conf_list);
MS_EXCEPTION_IF_NULL(root_context);
MS_EXCEPTION_IF_NULL(root_context->func_graph());
AnfNodeConfigPtr output_conf = MakeConfig(root_context->func_graph()->get_return(), root_context);
auto root_context_fg = root_context->func_graph();
MS_EXCEPTION_IF_NULL(root_context_fg);
AnfNodeConfigPtr output_conf = MakeConfig(root_context_fg->get_return(), root_context, root_context_fg);
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(INFO) << func_graph->ToString() << ": Run finished.";
@ -291,7 +290,8 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co
return out;
}
AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context) {
AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context,
const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
if (inputs.empty()) {
@ -300,29 +300,29 @@ AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode,
AnfNodePtr func_node = inputs[0];
MS_EXCEPTION_IF_NULL(func_node);
MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString();
AnfNodeConfigPtr func_conf = MakeConfig(func_node, context);
AnfNodeConfigPtr func_conf = MakeConfig(func_node, context, func_graph);
MS_EXCEPTION_IF_NULL(func_conf);
// Keep it in a local variable, otherwise smart pointer will free it.
auto maybe_func_eval_result = func_conf->ObtainEvalResult();
AbstractBasePtr maybe_func = maybe_func_eval_result->abstract();
if (maybe_func == nullptr) {
auto possible_func_eval_result = func_conf->ObtainEvalResult();
AbstractBasePtr possible_func = possible_func_eval_result->abstract();
if (possible_func == nullptr) {
MS_LOG(EXCEPTION) << "No abstract, func_conf: " << func_conf->ToString();
}
return maybe_func;
return possible_func;
}
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(cnode);
AbstractBasePtr maybe_func = GetCNodeOperatorAbstract(cnode, conf->context());
if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
AbstractBasePtr possible_func = GetCNodeOperatorAbstract(cnode, conf->context(), conf->func_graph());
if (possible_func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
MS_LOG(DEBUG) << "EvalCNode eval Undetermined";
return std::make_shared<EvalResult>(maybe_func->Clone(), std::make_shared<AttrValueMap>());
return std::make_shared<EvalResult>(possible_func->Clone(), std::make_shared<AttrValueMap>());
}
AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func);
AbstractFunctionPtr func = dyn_cast<AbstractFunction>(possible_func);
if (func == nullptr) {
MS_LOG(ERROR) << "Can not cast to a AbstractFunction: " << maybe_func->ToString() << ".";
MS_LOG(ERROR) << "Can not cast to a AbstractFunction: " << possible_func->ToString() << ".";
MS_EXCEPTION(ValueError) << "This may be not defined, and it can't be a operator. Please check code.";
}
@ -331,7 +331,7 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
auto &inputs = cnode->inputs();
for (std::size_t i = 1; i < inputs.size(); i++) {
const AnfNodePtr &node = inputs[i];
args_conf_list.push_back(MakeConfig(node, conf->context()));
args_conf_list.push_back(MakeConfig(node, conf->context(), conf->func_graph()));
}
std::vector<EvaluatorPtr> evaluators;

View File

@ -91,8 +91,13 @@ using ConfigPtrList = std::vector<ConfigPtr>;
// Config to a certain node in a certain context.
class AnfNodeConfig : public Config {
public:
AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context)
: Config(), engine_(std::weak_ptr<AnalysisEngine>(engine)), node_(node), context_(nullptr) {
AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context,
const FuncGraphPtr &func_graph)
: Config(),
engine_(std::weak_ptr<AnalysisEngine>(engine)),
node_(node),
context_(nullptr),
func_graph_(func_graph) {
FuncGraphPtr fg;
if (IsValueNode<FuncGraph>(node)) {
auto v = node->cast<ValueNodePtr>();
@ -123,6 +128,8 @@ class AnfNodeConfig : public Config {
AnfNodePtr node() const { return node_; }
FuncGraphPtr func_graph() const { return func_graph_; }
AnalysisEnginePtr engine() const { return engine_.lock(); }
// used by unordered_map;
@ -132,13 +139,14 @@ class AnfNodeConfig : public Config {
if (context_->IsDummyContext() && other.context_->IsDummyContext()) {
return true;
}
// Don't check `func_graph_` equality.
return (node_ == other.node_) && (context_ == other.context_);
}
std::string ToString() const override {
std::ostringstream buffer;
buffer << "Node: " << node_ << "/" << node_->DebugString() << "-uid(" << node_->UniqueId()
<< "), Context: " << context_ << "/" << context_->ToString();
<< "), Context: " << context_ << "/" << context_->ToString() << ", FuncGraph: " << func_graph_->ToString();
return buffer.str();
}
@ -148,7 +156,10 @@ class AnfNodeConfig : public Config {
// weak_ptr to break Config cycle.
std::weak_ptr<AnalysisEngine> engine_;
AnfNodePtr node_;
// Which context the node would be called, usually in owner func graph context.
AnalysisContextPtr context_;
// Where to call the node.
FuncGraphPtr func_graph_;
};
using AnfNodeConfigPtr = std::shared_ptr<AnfNodeConfig>;
@ -236,15 +247,17 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
// Return the Evaluator for the given function.
EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);
AbstractBasePtr GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context);
AbstractBasePtr GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context,
const FuncGraphPtr &func_graph);
AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf);
EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
// Infer the result of fn(args).
EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
void Clear();
void ClearEvaluatorCache();
AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context) {
return std::make_shared<AnfNodeConfig>(shared_from_this(), node, context);
AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context,
const FuncGraphPtr &func_graph) {
return std::make_shared<AnfNodeConfig>(shared_from_this(), node, context, func_graph);
}
// Overloaded function.
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &fn);

View File

@ -79,7 +79,7 @@ TEST_F(TestAbstract, TestParseDataClass) {
ValuePtr obj = std::make_shared<parse::ClassObject>(fn, "TestFoo");
ValueNodePtr fn_node = NewValueNode(obj);
AnfNodeConfigPtr fn_conf = std::make_shared<AnfNodeConfig>(nullptr, fn_node, nullptr);
AnfNodeConfigPtr fn_conf = std::make_shared<AnfNodeConfig>(nullptr, fn_node, nullptr, nullptr);
AbstractBasePtr foo = ToAbstract(obj, nullptr, fn_conf);
ASSERT_TRUE(foo != nullptr);