Dump func graph where to call the node.
This commit is contained in:
parent
607ffdf63a
commit
cbb2d17efb
|
@ -105,7 +105,7 @@ void DumpInferStack(std::ostringstream &oss) {
|
|||
continue;
|
||||
}
|
||||
auto args_spec_list = context->args_spec_list();
|
||||
oss << " #" << index++ << " " << GetGraphParamString(graph, args_spec_list);
|
||||
oss << " #" << index++ << " " << GetGraphParamString(graph, args_spec_list) << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -128,7 +128,7 @@ class AnalyzeFailExporter : public AnfExporter {
|
|||
AnalyzeFailExporter() : AnfExporter(true, false) {}
|
||||
~AnalyzeFailExporter() override = default;
|
||||
|
||||
bool ExportFuncGraph(const std::string &filename, const std::vector<abstract::AnfNodeConfigPtr> &node_cfg_stack);
|
||||
bool ExportFuncGraph(const std::string &filename, const std::vector<abstract::AnfNodeConfigPtr> &node_config_stack);
|
||||
|
||||
private:
|
||||
void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph, const TaggedNodeMap &tagged_cnodes_map);
|
||||
|
@ -139,23 +139,23 @@ class AnalyzeFailExporter : public AnfExporter {
|
|||
|
||||
std::string GetNodeType(const AnfNodePtr &nd) override;
|
||||
AbstractBasePtr GetNodeAbstract(const AnfNodePtr &nd);
|
||||
AnfNodeConfigPtr GetFordwardConfigPtr(const AnfNodeConfigPtr &cfg);
|
||||
AnfNodeConfigPtr GetFordwardConfig(const AnfNodeConfigPtr &cfg);
|
||||
void ProcessFuncGraphCall(const CNodePtr &node, std::string *const op_comment);
|
||||
void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node);
|
||||
|
||||
AnalysisContextPtr cur_ctx_ = nullptr;
|
||||
AnalysisContextPtr current_context_ = nullptr;
|
||||
AnalysisEnginePtr engine_ = nullptr;
|
||||
};
|
||||
|
||||
std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs(
|
||||
const std::vector<abstract::AnfNodeConfigPtr> &node_cfg_stack) {
|
||||
const std::vector<abstract::AnfNodeConfigPtr> &node_config_stack) {
|
||||
std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs;
|
||||
for (size_t i = 0; i < node_cfg_stack.size(); ++i) {
|
||||
auto node_cfg = node_cfg_stack[i];
|
||||
MS_EXCEPTION_IF_NULL(node_cfg);
|
||||
auto fg = node_cfg->context()->func_graph();
|
||||
for (size_t i = 0; i < node_config_stack.size(); ++i) {
|
||||
auto node_config = node_config_stack[i];
|
||||
MS_EXCEPTION_IF_NULL(node_config);
|
||||
auto fg = node_config->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto node = node_cfg->node();
|
||||
auto node = node_config->node();
|
||||
tagged_func_graphs[fg][node] = i;
|
||||
}
|
||||
return tagged_func_graphs;
|
||||
|
@ -167,12 +167,13 @@ bool OutputAnalyzedGraphWithType(const string &file_path) {
|
|||
}
|
||||
|
||||
std::string AnalyzeFailExporter::GetNodeType(const AnfNodePtr &node) {
|
||||
if (cur_ctx_ == nullptr) {
|
||||
if (current_context_ == nullptr) {
|
||||
return AnfExporter::GetNodeType(node);
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(engine_);
|
||||
auto cfg = engine_->MakeConfig(node, cur_ctx_);
|
||||
FuncGraphPtr dummy_call_func_graph = nullptr;
|
||||
auto cfg = engine_->MakeConfig(node, current_context_, dummy_call_func_graph);
|
||||
auto ret = abstract::AnalysisResultCacheMgr::GetInstance().GetValue(cfg);
|
||||
if (ret == nullptr) {
|
||||
return "Undefined";
|
||||
|
@ -181,23 +182,24 @@ std::string AnalyzeFailExporter::GetNodeType(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
AbstractBasePtr AnalyzeFailExporter::GetNodeAbstract(const AnfNodePtr &node) {
|
||||
if (cur_ctx_ == nullptr) {
|
||||
if (current_context_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(engine_);
|
||||
auto cfg = engine_->MakeConfig(node, cur_ctx_);
|
||||
FuncGraphPtr dummy_call_func_graph = nullptr;
|
||||
auto cfg = engine_->MakeConfig(node, current_context_, dummy_call_func_graph);
|
||||
auto ret = abstract::AnalysisResultCacheMgr::GetInstance().GetValue(cfg);
|
||||
return ret == nullptr ? nullptr : ret->abstract();
|
||||
}
|
||||
|
||||
AnfNodeConfigPtr AnalyzeFailExporter::GetFordwardConfigPtr(const AnfNodeConfigPtr &cfg) {
|
||||
AnfNodeConfigPtr AnalyzeFailExporter::GetFordwardConfig(const AnfNodeConfigPtr &cfg) {
|
||||
AnfNodeConfigPtr cur_cfg = cfg;
|
||||
auto iter = engine_->anfnode_config_map().find(cur_cfg);
|
||||
while (iter != engine_->anfnode_config_map().end()) {
|
||||
auto node = cur_cfg->node();
|
||||
cur_cfg = iter->second;
|
||||
MS_LOG(DEBUG) << "Get forword node: " << node.get() << "[" << node->ToString() << "] --> " << cur_cfg->node().get()
|
||||
<< "[" << cur_cfg->node()->ToString() << "]";
|
||||
MS_LOG(DEBUG) << "Get forword node: " << node << "[" << node->DebugString() << "] --> " << cur_cfg->node() << "["
|
||||
<< cur_cfg->node()->DebugString() << "]";
|
||||
iter = engine_->anfnode_config_map().find(cur_cfg);
|
||||
}
|
||||
return cur_cfg;
|
||||
|
@ -205,13 +207,15 @@ AnfNodeConfigPtr AnalyzeFailExporter::GetFordwardConfigPtr(const AnfNodeConfigPt
|
|||
|
||||
void AnalyzeFailExporter::ProcessFuncGraphCall(const CNodePtr &node, std::string *const op_comment) {
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "Node is nullptr";
|
||||
return;
|
||||
}
|
||||
auto cfg = engine_->MakeConfig(node, cur_ctx_);
|
||||
cfg = GetFordwardConfigPtr(cfg);
|
||||
FuncGraphPtr dummy_call_func_graph = nullptr;
|
||||
auto cfg = engine_->MakeConfig(node, current_context_, dummy_call_func_graph);
|
||||
cfg = GetFordwardConfig(cfg);
|
||||
auto cnode = dyn_cast<CNode>(cfg->node());
|
||||
if (cnode == nullptr) {
|
||||
MS_LOG(DEBUG) << "CNode is nullptr";
|
||||
MS_LOG(ERROR) << "CNode is nullptr";
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -369,8 +373,8 @@ void AnalyzeFailExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraph
|
|||
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map;
|
||||
|
||||
ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText();
|
||||
if (cur_ctx_ != nullptr) {
|
||||
ofs << " @ctx.addr=" << cur_ctx_.get();
|
||||
if (current_context_ != nullptr) {
|
||||
ofs << " @ctx.addr=" << current_context_.get();
|
||||
}
|
||||
ofs << "\n";
|
||||
if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
|
||||
|
@ -397,26 +401,24 @@ void AnalyzeFailExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraph
|
|||
}
|
||||
|
||||
bool AnalyzeFailExporter::ExportFuncGraph(const std::string &filename,
|
||||
const std::vector<abstract::AnfNodeConfigPtr> &node_cfg_stack) {
|
||||
if (node_cfg_stack.empty()) {
|
||||
const std::vector<abstract::AnfNodeConfigPtr> &node_config_stack) {
|
||||
if (node_config_stack.empty()) {
|
||||
MS_LOG(DEBUG) << "Node configs is empty";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::ofstream ofs(filename);
|
||||
if (!ofs.is_open()) {
|
||||
MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto tagged_func_graphs = CalcTaggedFuncGraphs(node_cfg_stack);
|
||||
auto tagged_func_graphs = CalcTaggedFuncGraphs(node_config_stack);
|
||||
std::unordered_set<FuncGraphPtr> printed_func_graphs; // Check if func graph has been printed.
|
||||
// Output graph on the analysis stack
|
||||
for (const auto &node_cfg : node_cfg_stack) {
|
||||
auto ctx = node_cfg->context();
|
||||
auto fg = ctx->func_graph();
|
||||
for (const auto &node_config : node_config_stack) {
|
||||
auto fg = node_config->func_graph();
|
||||
if (fg == nullptr) {
|
||||
MS_LOG(ERROR) << "FuncGraph is null, context: " << node_cfg->context()->ToString();
|
||||
MS_LOG(ERROR) << "FuncGraph is null, context: " << node_config->ToString();
|
||||
continue;
|
||||
}
|
||||
if (printed_func_graphs.find(fg) != printed_func_graphs.end()) {
|
||||
|
@ -425,16 +427,16 @@ bool AnalyzeFailExporter::ExportFuncGraph(const std::string &filename,
|
|||
printed_func_graphs.emplace(fg);
|
||||
|
||||
if (engine_ == nullptr) {
|
||||
engine_ = node_cfg->engine();
|
||||
engine_ = node_config->engine();
|
||||
}
|
||||
|
||||
cur_ctx_ = ctx; // Set current context.
|
||||
current_context_ = node_config->context(); // Set current context.
|
||||
ExportOneFuncGraph(ofs, fg, tagged_func_graphs[fg]);
|
||||
ofs << "\n\n";
|
||||
}
|
||||
|
||||
ofs << "#===============================================================================\n";
|
||||
ofs << "# num of function graphs printed: " << printed_func_graphs.size() << "/" << node_cfg_stack.size() << "\n";
|
||||
ofs << "# num of function graphs in stack: " << node_config_stack.size() << "\n";
|
||||
ofs.close();
|
||||
return true;
|
||||
}
|
||||
|
@ -468,9 +470,9 @@ void GetEvalStackInfo(std::ostringstream &oss) {
|
|||
int index = 0;
|
||||
std::string last_location_info = "";
|
||||
for (size_t i = 0; i < stack.size(); ++i) {
|
||||
auto node_cfg = stack[i];
|
||||
auto node_config = stack[i];
|
||||
|
||||
auto cnode = dyn_cast<CNode>(node_cfg->node());
|
||||
auto cnode = dyn_cast<CNode>(node_config->node());
|
||||
if (cnode == nullptr) {
|
||||
MS_LOG(DEBUG) << "CNode of elements[" << i << "] is nullptr.";
|
||||
continue;
|
||||
|
@ -513,7 +515,7 @@ void TraceGraphEvalLeave(const abstract::AnalysisContextPtr &context) {
|
|||
graph_infer_stack.pop();
|
||||
}
|
||||
|
||||
void TraceEvalCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); }
|
||||
void TraceEvalCNodeEnter(const abstract::AnfNodeConfigPtr &node_config) { cnode_debug_stack.push_back(node_config); }
|
||||
|
||||
void TraceEvalCNodeLeave() { cnode_debug_stack.pop_back(); }
|
||||
|
||||
|
|
|
@ -63,7 +63,7 @@ void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, co
|
|||
// Enter new func graph.
|
||||
auto ¤t_node = current_stack_frame->CurrentNode();
|
||||
auto current_context = current_stack_frame->current_context();
|
||||
AnfNodeConfigPtr call_conf = engine->MakeConfig(current_node, current_context);
|
||||
AnfNodeConfigPtr call_conf = engine->MakeConfig(current_node, current_context, current_context->func_graph());
|
||||
auto evaluator = new_stack_frame->evaluator();
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
auto new_context = new_stack_frame->current_context();
|
||||
|
@ -158,7 +158,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine
|
|||
});
|
||||
AbstractBasePtr res_base = nullptr;
|
||||
for (const auto &node : all_nodes) {
|
||||
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, context);
|
||||
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, context, fg);
|
||||
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg << "/" << fg->ToString()
|
||||
<< ", node_conf: " << node_conf->ToString();
|
||||
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
|
||||
|
@ -221,7 +221,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
for (size_t i = 0; i < nargs; i++) {
|
||||
const auto &arg = args_abs_list[i];
|
||||
const auto &node = parameters[i];
|
||||
AnfNodeConfigPtr conf = engine->MakeConfig(node, context);
|
||||
AnfNodeConfigPtr conf = engine->MakeConfig(node, context, fg);
|
||||
engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg, nullptr));
|
||||
MS_LOG(DEBUG) << GetInferThread() << "Set Param: " << conf->ToString() << " = " << arg->ToString();
|
||||
}
|
||||
|
|
|
@ -96,7 +96,7 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
|
||||
args_inputs);
|
||||
}
|
||||
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context());
|
||||
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
|
||||
|
||||
if (out_node->isa<CNode>()) {
|
||||
auto out_cnode = out_node->cast<CNodePtr>();
|
||||
|
@ -181,7 +181,7 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
}
|
||||
ScopeGuard scope_guard(scope);
|
||||
AnfNodePtr new_vnode = NewValueNode(new_graph);
|
||||
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context());
|
||||
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context(), out_conf->func_graph());
|
||||
|
||||
return engine->ForwardConfig(out_conf, fn_conf);
|
||||
}
|
||||
|
@ -263,7 +263,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
|
|||
constexpr size_t source_node_index = 2;
|
||||
AnfNodePtr new_node =
|
||||
MixedPrecisionCastHelper(out_node_inputs[source_node_index], args_spec_list[1], out_node_inputs[1], func_graph);
|
||||
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context());
|
||||
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
|
||||
|
||||
if (new_node->isa<CNode>()) {
|
||||
auto new_cnode = new_node->cast<CNodePtr>();
|
||||
|
@ -813,7 +813,7 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_
|
|||
new_cnode = func_graph->NewCNode({new_cnode});
|
||||
}
|
||||
AnalysisEnginePtr eng = old_conf->engine();
|
||||
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context());
|
||||
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context(), old_conf->func_graph());
|
||||
return eng->ForwardConfig(old_conf, fn_conf);
|
||||
}
|
||||
|
||||
|
@ -859,7 +859,7 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engin
|
|||
func_graph->ReplaceInOrder(out_node, new_node);
|
||||
|
||||
AnalysisEnginePtr eng = out_conf->engine();
|
||||
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context());
|
||||
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
|
||||
return eng->ForwardConfig(out_conf, fn_conf);
|
||||
}
|
||||
|
||||
|
@ -1277,7 +1277,7 @@ class PartialEvaluator : public Evaluator {
|
|||
ScopeGuard scope_guard(scope);
|
||||
|
||||
CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs);
|
||||
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context());
|
||||
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
|
||||
return engine->ForwardConfig(out_conf, fn_conf);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -829,7 +829,7 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin
|
|||
}
|
||||
|
||||
AnfNodeConfigPtr FuncGraphSpecializer::MakeConfig(const AnfNodePtr &node) {
|
||||
return engine_->MakeConfig(node, context_);
|
||||
return engine_->MakeConfig(node, context_, func_graph_); // `func_graph_` is dummy here.
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,7 +25,7 @@ AbstractBasePtrList StackFrame::GenerateArgsAbsList(const AnalysisEnginePtr &eng
|
|||
AbstractBasePtrList args_abs_list;
|
||||
auto &inputs = current_cnode->inputs();
|
||||
for (std::size_t i = 1; i < inputs.size(); i++) {
|
||||
auto config = engine->MakeConfig(inputs[i], current_context_);
|
||||
auto config = engine->MakeConfig(inputs[i], current_context_, current_context_->func_graph());
|
||||
auto abs = config->ObtainEvalResult()->abstract();
|
||||
args_abs_list.push_back(abs);
|
||||
}
|
||||
|
@ -90,7 +90,7 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
|
|||
for (size_t i = 0; i < nargs; i++) {
|
||||
const auto &arg_abs = args_abs_list[i];
|
||||
const auto &node = fg->parameters()[i];
|
||||
AnfNodeConfigPtr conf = engine->MakeConfig(node, new_context);
|
||||
AnfNodeConfigPtr conf = engine->MakeConfig(node, new_context, new_context->func_graph());
|
||||
engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg_abs, nullptr));
|
||||
}
|
||||
|
||||
|
@ -107,14 +107,14 @@ StackFramePtr StackFrame::Jump(const AnalysisEnginePtr &engine) {
|
|||
return nullptr;
|
||||
}
|
||||
auto cnode = current_node->cast<CNodePtr>();
|
||||
auto maybe_func = engine->GetCNodeOperatorAbstract(cnode, current_context_);
|
||||
auto maybe_func = engine->GetCNodeOperatorAbstract(cnode, current_context_, current_context_->func_graph());
|
||||
if (!maybe_func->isa<abstract::MetaFuncGraphAbstractClosure>() &&
|
||||
!maybe_func->isa<abstract::FuncGraphAbstractClosure>()) {
|
||||
return nullptr; // Not call FuncGraph or MetaFuncGraph.
|
||||
}
|
||||
|
||||
// It's FuncGraph Call or MetaFuncGraph Call. `maybe_func` is definitely a AbstractFunction.
|
||||
AnfNodeConfigPtr call_node_conf = engine->MakeConfig(cnode, current_context_);
|
||||
AnfNodeConfigPtr call_node_conf = engine->MakeConfig(cnode, current_context_, current_context_->func_graph());
|
||||
// Enter the call CNode.
|
||||
trace::TraceEvalCNodeEnter(call_node_conf);
|
||||
auto res = DoJump(engine, cnode, dyn_cast<AbstractFunction>(maybe_func));
|
||||
|
@ -129,7 +129,7 @@ EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) {
|
|||
auto ¤t_node = NextNode();
|
||||
MS_LOG(DEBUG) << "current_node: " << current_node->DebugString()
|
||||
<< ", current_context_: " << current_context_->ToString();
|
||||
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_);
|
||||
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_, current_context_->func_graph());
|
||||
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
|
||||
MS_LOG(DEBUG) << GetInferThread() << "Eval(" << node_conf->ToString()
|
||||
<< ") = " << node_eval_result->abstract()->ToString();
|
||||
|
@ -153,7 +153,7 @@ void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last
|
|||
auto ¤t_node = NextNode();
|
||||
MS_LOG(DEBUG) << "current_node: " << current_node->DebugString()
|
||||
<< ", current_context_: " << current_context_->ToString();
|
||||
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_);
|
||||
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_, current_context_->func_graph());
|
||||
engine->SaveEvalResultInCache(node_conf, result);
|
||||
|
||||
// Leave the call CNode.
|
||||
|
|
|
@ -127,18 +127,17 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac
|
|||
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
|
||||
MS_EXCEPTION_IF_NULL(func_graph_manager_);
|
||||
func_graph_manager_->AddFuncGraph(func_graph);
|
||||
|
||||
root_func_graph_ = func_graph;
|
||||
|
||||
AnalysisContextPtr empty_context = AnalysisContext::DummyContext();
|
||||
|
||||
// Running the analyzer.
|
||||
ResetFunctionCallDepth();
|
||||
ResetStackFrameDepth();
|
||||
AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list);
|
||||
AnalysisContextPtr dummy_context = AnalysisContext::DummyContext();
|
||||
AnalysisContextPtr root_context = Run(func_graph, dummy_context, args_conf_list);
|
||||
MS_EXCEPTION_IF_NULL(root_context);
|
||||
MS_EXCEPTION_IF_NULL(root_context->func_graph());
|
||||
AnfNodeConfigPtr output_conf = MakeConfig(root_context->func_graph()->get_return(), root_context);
|
||||
auto root_context_fg = root_context->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(root_context_fg);
|
||||
AnfNodeConfigPtr output_conf = MakeConfig(root_context_fg->get_return(), root_context, root_context_fg);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_LOG(INFO) << func_graph->ToString() << ": Run finished.";
|
||||
|
||||
|
@ -291,7 +290,8 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co
|
|||
return out;
|
||||
}
|
||||
|
||||
AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context) {
|
||||
AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context,
|
||||
const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto &inputs = cnode->inputs();
|
||||
if (inputs.empty()) {
|
||||
|
@ -300,29 +300,29 @@ AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode,
|
|||
AnfNodePtr func_node = inputs[0];
|
||||
MS_EXCEPTION_IF_NULL(func_node);
|
||||
MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString();
|
||||
AnfNodeConfigPtr func_conf = MakeConfig(func_node, context);
|
||||
AnfNodeConfigPtr func_conf = MakeConfig(func_node, context, func_graph);
|
||||
MS_EXCEPTION_IF_NULL(func_conf);
|
||||
// Keep it in a local variable, otherwise smart pointer will free it.
|
||||
auto maybe_func_eval_result = func_conf->ObtainEvalResult();
|
||||
AbstractBasePtr maybe_func = maybe_func_eval_result->abstract();
|
||||
if (maybe_func == nullptr) {
|
||||
auto possible_func_eval_result = func_conf->ObtainEvalResult();
|
||||
AbstractBasePtr possible_func = possible_func_eval_result->abstract();
|
||||
if (possible_func == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "No abstract, func_conf: " << func_conf->ToString();
|
||||
}
|
||||
return maybe_func;
|
||||
return possible_func;
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
AbstractBasePtr maybe_func = GetCNodeOperatorAbstract(cnode, conf->context());
|
||||
if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
|
||||
AbstractBasePtr possible_func = GetCNodeOperatorAbstract(cnode, conf->context(), conf->func_graph());
|
||||
if (possible_func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
|
||||
MS_LOG(DEBUG) << "EvalCNode eval Undetermined";
|
||||
return std::make_shared<EvalResult>(maybe_func->Clone(), std::make_shared<AttrValueMap>());
|
||||
return std::make_shared<EvalResult>(possible_func->Clone(), std::make_shared<AttrValueMap>());
|
||||
}
|
||||
|
||||
AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func);
|
||||
AbstractFunctionPtr func = dyn_cast<AbstractFunction>(possible_func);
|
||||
if (func == nullptr) {
|
||||
MS_LOG(ERROR) << "Can not cast to a AbstractFunction: " << maybe_func->ToString() << ".";
|
||||
MS_LOG(ERROR) << "Can not cast to a AbstractFunction: " << possible_func->ToString() << ".";
|
||||
MS_EXCEPTION(ValueError) << "This may be not defined, and it can't be a operator. Please check code.";
|
||||
}
|
||||
|
||||
|
@ -331,7 +331,7 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
|
|||
auto &inputs = cnode->inputs();
|
||||
for (std::size_t i = 1; i < inputs.size(); i++) {
|
||||
const AnfNodePtr &node = inputs[i];
|
||||
args_conf_list.push_back(MakeConfig(node, conf->context()));
|
||||
args_conf_list.push_back(MakeConfig(node, conf->context(), conf->func_graph()));
|
||||
}
|
||||
std::vector<EvaluatorPtr> evaluators;
|
||||
|
||||
|
|
|
@ -91,8 +91,13 @@ using ConfigPtrList = std::vector<ConfigPtr>;
|
|||
// Config to a certain node in a certain context.
|
||||
class AnfNodeConfig : public Config {
|
||||
public:
|
||||
AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context)
|
||||
: Config(), engine_(std::weak_ptr<AnalysisEngine>(engine)), node_(node), context_(nullptr) {
|
||||
AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context,
|
||||
const FuncGraphPtr &func_graph)
|
||||
: Config(),
|
||||
engine_(std::weak_ptr<AnalysisEngine>(engine)),
|
||||
node_(node),
|
||||
context_(nullptr),
|
||||
func_graph_(func_graph) {
|
||||
FuncGraphPtr fg;
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
auto v = node->cast<ValueNodePtr>();
|
||||
|
@ -123,6 +128,8 @@ class AnfNodeConfig : public Config {
|
|||
|
||||
AnfNodePtr node() const { return node_; }
|
||||
|
||||
FuncGraphPtr func_graph() const { return func_graph_; }
|
||||
|
||||
AnalysisEnginePtr engine() const { return engine_.lock(); }
|
||||
|
||||
// used by unordered_map;
|
||||
|
@ -132,13 +139,14 @@ class AnfNodeConfig : public Config {
|
|||
if (context_->IsDummyContext() && other.context_->IsDummyContext()) {
|
||||
return true;
|
||||
}
|
||||
// Don't check `func_graph_` equality.
|
||||
return (node_ == other.node_) && (context_ == other.context_);
|
||||
}
|
||||
|
||||
std::string ToString() const override {
|
||||
std::ostringstream buffer;
|
||||
buffer << "Node: " << node_ << "/" << node_->DebugString() << "-uid(" << node_->UniqueId()
|
||||
<< "), Context: " << context_ << "/" << context_->ToString();
|
||||
<< "), Context: " << context_ << "/" << context_->ToString() << ", FuncGraph: " << func_graph_->ToString();
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
|
@ -148,7 +156,10 @@ class AnfNodeConfig : public Config {
|
|||
// weak_ptr to break Config cycle.
|
||||
std::weak_ptr<AnalysisEngine> engine_;
|
||||
AnfNodePtr node_;
|
||||
// Which context the node would be called, usually in owner func graph context.
|
||||
AnalysisContextPtr context_;
|
||||
// Where to call the node.
|
||||
FuncGraphPtr func_graph_;
|
||||
};
|
||||
|
||||
using AnfNodeConfigPtr = std::shared_ptr<AnfNodeConfig>;
|
||||
|
@ -236,15 +247,17 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
// Return the Evaluator for the given function.
|
||||
EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);
|
||||
|
||||
AbstractBasePtr GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context);
|
||||
AbstractBasePtr GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context,
|
||||
const FuncGraphPtr &func_graph);
|
||||
AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf);
|
||||
EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
|
||||
// Infer the result of fn(args).
|
||||
EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
|
||||
void Clear();
|
||||
void ClearEvaluatorCache();
|
||||
AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context) {
|
||||
return std::make_shared<AnfNodeConfig>(shared_from_this(), node, context);
|
||||
AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context,
|
||||
const FuncGraphPtr &func_graph) {
|
||||
return std::make_shared<AnfNodeConfig>(shared_from_this(), node, context, func_graph);
|
||||
}
|
||||
// Overloaded function.
|
||||
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &fn);
|
||||
|
|
|
@ -79,7 +79,7 @@ TEST_F(TestAbstract, TestParseDataClass) {
|
|||
ValuePtr obj = std::make_shared<parse::ClassObject>(fn, "TestFoo");
|
||||
|
||||
ValueNodePtr fn_node = NewValueNode(obj);
|
||||
AnfNodeConfigPtr fn_conf = std::make_shared<AnfNodeConfig>(nullptr, fn_node, nullptr);
|
||||
AnfNodeConfigPtr fn_conf = std::make_shared<AnfNodeConfig>(nullptr, fn_node, nullptr, nullptr);
|
||||
AbstractBasePtr foo = ToAbstract(obj, nullptr, fn_conf);
|
||||
ASSERT_TRUE(foo != nullptr);
|
||||
|
||||
|
|
Loading…
Reference in New Issue