forked from mindspore-Ecosystem/mindspore
!16313 Evaluate the func graph in simulative stack, not to call recursively.
From: @zh_qh Reviewed-by: @hwhewei,@ginfung Signed-off-by: @hwhewei
This commit is contained in:
commit
eb483a276b
|
@ -114,7 +114,7 @@ void DumpGlobalInfoEntry(const FuncGraphPtr &graph, std::ostringstream &buffer)
|
|||
return;
|
||||
}
|
||||
|
||||
buffer << "#IR entry : @" << graph->ToString() << "." << graph->debug_info()->get_id() << std::endl;
|
||||
buffer << "#IR entry : @" << graph->ToString() << std::endl;
|
||||
buffer << "#attrs :" << std::endl;
|
||||
for (const auto &attr : graph->attrs()) {
|
||||
buffer << attr.first << " : ";
|
||||
|
@ -216,7 +216,7 @@ void DumpOperator(const AnfNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &g
|
|||
if (IsValueNode<FuncGraph>(op)) {
|
||||
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(op);
|
||||
if (fg != nullptr) {
|
||||
gsub->buffer << "call @" << fg->ToString() << "." << fg->debug_info()->get_id();
|
||||
gsub->buffer << "call @" << fg->ToString();
|
||||
}
|
||||
} else if (op->isa<CNode>()) {
|
||||
if (gsub->local_var_map.find(op) != gsub->local_var_map.end()) {
|
||||
|
@ -224,7 +224,7 @@ void DumpOperator(const AnfNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &g
|
|||
} else {
|
||||
auto node = op->cast<CNodePtr>();
|
||||
auto fg = node->func_graph();
|
||||
gsub->buffer << "$(" << fg->ToString() << "." << fg->debug_info()->get_id() << ":" << node->ToString() << ")";
|
||||
gsub->buffer << "$(" << fg->ToString() << ":" << node->ToString() << ")";
|
||||
}
|
||||
} else if (op->isa<ValueNode>()) {
|
||||
gsub->buffer << GetValueNode(op)->ToString();
|
||||
|
@ -262,14 +262,14 @@ void DumpOperands(const AnfNodePtr &nd, OrderedMap<AnfNodePtr, int32_t> *para_ma
|
|||
} else {
|
||||
auto node = in->cast<CNodePtr>();
|
||||
auto fg = node->func_graph();
|
||||
gsub->buffer << "$(" << fg->ToString() << "." << fg->debug_info()->get_id() << ":" << node->ToString() << ")";
|
||||
gsub->buffer << "$(" << fg->ToString() << ":" << node->ToString() << ")";
|
||||
}
|
||||
} else if (in->isa<ValueNode>() && !IsValueNode<FuncGraph>(in)) {
|
||||
// non Primitive valuenode
|
||||
gsub->buffer << GetValueNode(in)->ToString();
|
||||
} else if (IsValueNode<FuncGraph>(in)) {
|
||||
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(in);
|
||||
gsub->buffer << "@" << fg->ToString() << "." << fg->debug_info()->get_id();
|
||||
gsub->buffer << "@" << fg->ToString();
|
||||
} else {
|
||||
gsub->buffer << in->ToString();
|
||||
}
|
||||
|
@ -501,8 +501,7 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo>
|
|||
}
|
||||
fout << std::endl;
|
||||
}
|
||||
fout << "subgraph @" << sg.first->ToString() << ".";
|
||||
fout << sg.first->debug_info()->get_id() << "(";
|
||||
fout << "subgraph @" << sg.first->ToString() << "(";
|
||||
if (sg.first != graph) {
|
||||
std::vector<AnfNodePtr> parameters = sg.first->parameters();
|
||||
if (parameters.size() == 1) {
|
||||
|
|
|
@ -73,7 +73,7 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr
|
|||
if (!check_integrity_) {
|
||||
break;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'";
|
||||
MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "'";
|
||||
}
|
||||
auto param_map = exported[fg];
|
||||
if (param_map.find(param) != param_map.end()) {
|
||||
|
@ -83,7 +83,7 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr
|
|||
}
|
||||
if (throw_excp) {
|
||||
MS_LOG(EXCEPTION) << "Can not find index for param '" << param->DumpText() << "' for func graph '"
|
||||
<< func_graph->DumpText() << "." << func_graph->debug_info()->get_id() << "'";
|
||||
<< func_graph->DumpText() << "'";
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
@ -457,7 +457,7 @@ void AnfExporter::OutputStatementComment(std::ofstream &ofs, const CNodePtr &nod
|
|||
}
|
||||
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(arg);
|
||||
std::string func_graph_id = fg->debug_info()->get_id();
|
||||
comment << " fg_" << func_graph_id << "=" << fg->ToString() << "." << func_graph_id;
|
||||
comment << " fg_" << func_graph_id << "=" << fg->ToString();
|
||||
}
|
||||
if (has_comment) {
|
||||
ofs << comment.str();
|
||||
|
@ -552,8 +552,7 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &fun
|
|||
if (*(func_graph->switch_layer_input())) {
|
||||
ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n";
|
||||
}
|
||||
ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "."
|
||||
<< func_graph->debug_info()->get_id() << "\n";
|
||||
ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "\n";
|
||||
if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
|
||||
ofs << trace::GetDebugInfo(func_graph->debug_info(), "# ", kSourceLineTipDiscard) << "#"
|
||||
<< label_manage::Label(func_graph->debug_info()) << "\n";
|
||||
|
|
|
@ -93,7 +93,7 @@ void DumpInferStack(std::ostringstream &oss) {
|
|||
infer_vec.clear();
|
||||
break;
|
||||
}
|
||||
auto graph_context = graph_infer->graph_context();
|
||||
auto graph_context = graph_infer->context();
|
||||
if (graph_context == nullptr) {
|
||||
MS_LOG(INFO) << "Null context continue";
|
||||
continue;
|
||||
|
@ -253,7 +253,7 @@ std::vector<AnalysisContextPtr> AnalyzedFuncGraphExporter::ProcessFuncGraphCall(
|
|||
}
|
||||
|
||||
auto base_fg_evaluator = dyn_cast<abstract::BaseFuncGraphEvaluator>(evaluator);
|
||||
auto ctx = base_fg_evaluator->graph_context();
|
||||
auto ctx = base_fg_evaluator->context();
|
||||
if (ctx != nullptr && context_map_.insert({ctx, false}).second) {
|
||||
MS_LOG(DEBUG) << "Add new context, ctx.addr = " << ctx.get() << "ctx = " << ctx->ToString();
|
||||
context_vec_.push_back(ctx);
|
||||
|
@ -298,7 +298,7 @@ void AnalyzedFuncGraphExporter::OutputStatementComment(std::ofstream &ofs, const
|
|||
}
|
||||
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(arg);
|
||||
std::string func_graph_id = fg->debug_info()->get_id();
|
||||
comment << " fg_" << func_graph_id << "=" << fg->ToString() << "." << func_graph_id;
|
||||
comment << " fg_" << func_graph_id << "=" << fg->ToString();
|
||||
if (ctxs.size() > i && ctxs[i] != nullptr) {
|
||||
comment << "(@ctx.addr=" << ctxs[i].get() << ")";
|
||||
}
|
||||
|
@ -392,8 +392,7 @@ void AnalyzedFuncGraphExporter::ExportOneFuncGraph(std::ofstream &ofs, const Fun
|
|||
std::vector<AnfNodePtr> parameters = func_graph->parameters();
|
||||
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map;
|
||||
|
||||
ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "."
|
||||
<< func_graph->debug_info()->get_id();
|
||||
ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText();
|
||||
if (cur_ctx_ != nullptr) {
|
||||
ofs << " @ctx.addr=" << cur_ctx_.get();
|
||||
}
|
||||
|
|
|
@ -115,6 +115,8 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
|
|||
}
|
||||
}
|
||||
auto ret = engine->Run(func_graph, args_spec);
|
||||
MS_LOG(INFO) << "function call max depth: " << engine->function_call_max_depth()
|
||||
<< ", simulate call max depth: " << engine->stack_frame_max_depth();
|
||||
MS_LOG(DEBUG) << "AbstractAnalyze end";
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -877,7 +877,7 @@ class SideEffectFinder {
|
|||
const SccPtr &GetScc(const FuncGraphPtr &func_graph) const {
|
||||
auto found = scc_map_.find(func_graph);
|
||||
if (found == scc_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "SCC not found for " << func_graph->ToString() << "." << func_graph->debug_info()->get_id();
|
||||
MS_LOG(EXCEPTION) << "SCC not found for " << func_graph->ToString();
|
||||
}
|
||||
return found->second;
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "abstract/utils.h"
|
||||
#include "debug/trace.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "pipeline/jit/static_analysis/stack_frame.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
@ -66,62 +67,145 @@ AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &
|
|||
return context;
|
||||
}
|
||||
|
||||
EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
|
||||
FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
std::size_t nargs = fg->parameters().size();
|
||||
if (args_spec_list.size() != nargs) {
|
||||
MS_EXCEPTION(TypeError) << "Function " << fg->ToString() << ", The number of parameters of this function is "
|
||||
<< fg->parameters().size() << ", but the number of provided arguments is "
|
||||
<< args_spec_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info());
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(parent_context_);
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
graph_context_ = parent_context_->NewFuncGraphContext(fg, args_spec_list);
|
||||
const auto ¶meters = fg->parameters();
|
||||
for (size_t i = 0; i < nargs; i++) {
|
||||
const auto &arg = args_spec_list[i];
|
||||
const auto &node = parameters[i];
|
||||
AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_);
|
||||
engine->analysis_cache().set_value(conf, std::make_shared<EvalResult>(arg, nullptr));
|
||||
}
|
||||
const AnfNodePtr &func_node = fg->get_return();
|
||||
void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame,
|
||||
const StackFramePtr &new_stack_frame) {
|
||||
// Enter new func graph.
|
||||
auto ¤t_node = current_stack_frame->CurrentNode();
|
||||
auto current_context = current_stack_frame->current_context();
|
||||
AnfNodeConfigPtr call_conf = engine->MakeConfig(current_node, current_context);
|
||||
auto evaluator = new_stack_frame->evaluator();
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
trace::TraceGraphEvalEnter(evaluator, call_conf);
|
||||
|
||||
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << "/" << fg->ToString()
|
||||
<< ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString()
|
||||
<< ", current function call depth: " << engine->function_call_depth();
|
||||
AbstractBasePtr ret_base = nullptr;
|
||||
// Increase & Check the func graph call depth.
|
||||
engine->IncreaseFunctionCallDepth();
|
||||
engine->IncreaseStackFrameDepth();
|
||||
if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
|
||||
MS_LOG(EXCEPTION) << "Exceed function call depth limit "
|
||||
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
|
||||
<< ", (function call depth: " << engine->function_call_depth()
|
||||
<< ", simulate call depth: " << engine->stack_frame_depth()
|
||||
<< "), please call 'context.set_context(max_call_depth=value)' to adjust this value.";
|
||||
}
|
||||
MS_LOG(DEBUG) << evaluator << "(" << evaluator->type_name() << "/" << evaluator->ToString()
|
||||
<< "), enter, function call depth: " << engine->function_call_depth() << " - "
|
||||
<< engine->stack_frame_depth();
|
||||
}
|
||||
|
||||
void BaseFuncGraphEvaluator::LeaveStackFrame(const AnalysisEnginePtr &engine,
|
||||
const StackFramePtr ¤t_stack_frame) {
|
||||
// Leave current func graph.
|
||||
auto evaluator = current_stack_frame->evaluator();
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
trace::TraceGraphEvalLeave(evaluator);
|
||||
|
||||
// Decrease the func graph call depth.
|
||||
engine->DecreaseFunctionCallDepth();
|
||||
engine->DecreaseStackFrameDepth();
|
||||
MS_LOG(DEBUG) << evaluator << "(" << evaluator->type_name() << "/" << evaluator->ToString()
|
||||
<< "), leave, function call depth: " << engine->function_call_depth() << " - "
|
||||
<< engine->stack_frame_depth();
|
||||
}
|
||||
|
||||
// Start running stack frames in a Evaluator.
|
||||
AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg) {
|
||||
EvalResultPtr eval_result = nullptr;
|
||||
AbstractBasePtr res_base = nullptr;
|
||||
std::stack<StackFramePtr> stack_frames;
|
||||
auto current_stack_frame = std::make_shared<StackFrame>(shared_from_base<Evaluator>(), fg, context_, parent_context_);
|
||||
MS_LOG(DEBUG) << "[" << this << "/StackFrame] Start at func graph, " << current_stack_frame;
|
||||
stack_frames.push(current_stack_frame);
|
||||
while (1) {
|
||||
current_stack_frame = stack_frames.top();
|
||||
if (current_stack_frame->Done()) {
|
||||
MS_EXCEPTION_IF_NULL(res_base);
|
||||
MS_LOG(DEBUG) << "[" << this << "/StackFrame] Leave from func graph, " << current_stack_frame;
|
||||
stack_frames.pop();
|
||||
if (stack_frames.empty()) {
|
||||
MS_LOG(DEBUG) << "[" << this << "/StackFrame] Finish at func graph, " << current_stack_frame
|
||||
<< ", res_base: " << res_base->ToString();
|
||||
break;
|
||||
}
|
||||
// Save func graph eval result for specialize.
|
||||
auto evaluator = current_stack_frame->evaluator();
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
(*evaluator->evaluator_cache_map())[current_stack_frame->args_abs_list()] = eval_result;
|
||||
|
||||
// Leave current func graph.
|
||||
LeaveStackFrame(engine, current_stack_frame);
|
||||
// Switch the stack frame.
|
||||
current_stack_frame = stack_frames.top();
|
||||
MS_LOG(DEBUG) << "[" << this << "/StackFrame] Back to func graph, " << current_stack_frame;
|
||||
current_stack_frame->Back(engine, eval_result);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto new_stack_frame = current_stack_frame->Jump(engine);
|
||||
if (new_stack_frame != nullptr) {
|
||||
// Enter new func graph.
|
||||
EnterStackFrame(engine, current_stack_frame, new_stack_frame);
|
||||
// Update current stack frame.
|
||||
stack_frames.push(new_stack_frame);
|
||||
current_stack_frame = new_stack_frame;
|
||||
MS_LOG(DEBUG) << "[" << this << "/StackFrame] Jump to new func graph, " << new_stack_frame;
|
||||
}
|
||||
|
||||
eval_result = current_stack_frame->Step(engine);
|
||||
MS_EXCEPTION_IF_NULL(eval_result);
|
||||
res_base = eval_result->abstract();
|
||||
}
|
||||
return res_base;
|
||||
}
|
||||
|
||||
EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
engine->IncreaseFunctionCallDepth();
|
||||
if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
|
||||
MS_LOG(EXCEPTION) << "Exceed function call depth limit "
|
||||
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
|
||||
<< ", please call 'context.set_context(max_call_depth=value)' to adjust this value.";
|
||||
<< ", (function call depth: " << engine->function_call_depth()
|
||||
<< ", simulate call depth: " << engine->stack_frame_depth()
|
||||
<< "), please call 'context.set_context(max_call_depth=value)' to adjust this value.";
|
||||
}
|
||||
const auto &all_nodes = TopoSort(func_node, SuccIncoming, [&fg](const AnfNodePtr &node) -> IncludeType {
|
||||
if (node->func_graph() != fg || node->isa<ValueNode>()) {
|
||||
return EXCLUDE;
|
||||
}
|
||||
return FOLLOW;
|
||||
});
|
||||
for (const auto &node : all_nodes) {
|
||||
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
|
||||
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << "/" << fg->ToString()
|
||||
<< ", node_conf: " << node_conf->ToString();
|
||||
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
|
||||
ret_base = node_eval_result->abstract();
|
||||
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << "/" << fg->ToString()
|
||||
<< ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString();
|
||||
MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString()
|
||||
<< "), enter, function call depth: " << engine->function_call_depth() << " - "
|
||||
<< engine->stack_frame_depth();
|
||||
|
||||
// Initialize evaluator starter with args_abs_list.
|
||||
FuncGraphPtr fg = GetFuncGraph(engine, args_abs_list);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
std::size_t nargs = fg->parameters().size();
|
||||
if (args_abs_list.size() != nargs) {
|
||||
MS_EXCEPTION(TypeError) << "Function " << fg->ToString() << ", The number of parameters of this function is "
|
||||
<< fg->parameters().size() << ", but the number of provided arguments is "
|
||||
<< args_abs_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info());
|
||||
}
|
||||
engine->DecreaseFunctionCallDepth();
|
||||
|
||||
MS_EXCEPTION_IF_NULL(ret_base);
|
||||
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString()
|
||||
<< ", is stub: " << fg->stub();
|
||||
|
||||
MS_EXCEPTION_IF_NULL(parent_context_);
|
||||
context_ = parent_context_->NewFuncGraphContext(fg, args_abs_list);
|
||||
const auto ¶meters = fg->parameters();
|
||||
for (size_t i = 0; i < nargs; i++) {
|
||||
const auto &arg = args_abs_list[i];
|
||||
const auto &node = parameters[i];
|
||||
AnfNodeConfigPtr conf = engine->MakeConfig(node, context_);
|
||||
engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg, nullptr));
|
||||
}
|
||||
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString()
|
||||
<< ", context: " << context_->ToString() << ", return node: " << fg->get_return()->DebugString()
|
||||
<< ", parent: " << (parent_context_->func_graph() ? parent_context_->func_graph()->ToString() : "NULL")
|
||||
<< ", current function call depth: " << engine->function_call_depth();
|
||||
auto res_base = LaunchStackFrame(engine, fg);
|
||||
MS_EXCEPTION_IF_NULL(res_base);
|
||||
MS_LOG(DEBUG) << "Analysis FuncGraph end, " << fg << "/" << fg->ToString()
|
||||
<< ", evaluated abstract: " << res_base->ToString() << ", is stub: " << fg->stub();
|
||||
if (fg->stub()) {
|
||||
ret_base = std::make_shared<AbstractUndetermined>();
|
||||
res_base = std::make_shared<AbstractUndetermined>();
|
||||
}
|
||||
return std::make_shared<EvalResult>(ret_base, nullptr);
|
||||
|
||||
engine->DecreaseFunctionCallDepth();
|
||||
MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString()
|
||||
<< "), leave, function call depth: " << engine->function_call_depth() << " - "
|
||||
<< engine->stack_frame_depth();
|
||||
return std::make_shared<EvalResult>(res_base, nullptr);
|
||||
}
|
||||
|
||||
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
|
||||
|
@ -158,7 +242,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
|
|||
if (parent_context_) {
|
||||
MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString()
|
||||
<< ", context: " << parent_context_->ToString();
|
||||
auto last_context = parent_context_->Filter(func_graph_);
|
||||
auto last_context = parent_context_->FindParentContext(func_graph_);
|
||||
if (last_context && last_context->func_graph() == func_graph_) {
|
||||
MS_LOG(DEBUG) << "Find last eval context: " << last_context->ToString();
|
||||
MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list);
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <stack>
|
||||
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -135,7 +136,7 @@ class SymbolicPrimEvaluator : public PrimEvaluator {
|
|||
virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
|
||||
};
|
||||
|
||||
// Evaluator will be stored in AnalysisEngine.constructors_
|
||||
// Evaluator will be stored in AnalysisEngine.evaluators_
|
||||
using EvaluatorPtrList = std::vector<EvaluatorPtr>;
|
||||
|
||||
class DummyEvaluator : public Evaluator {
|
||||
|
@ -179,6 +180,11 @@ class TrackedEvaluator : public Evaluator {
|
|||
EvaluatorPtr sub_evaluator_;
|
||||
};
|
||||
|
||||
using FuncGraphCacheMap =
|
||||
std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
|
||||
class StackFrame;
|
||||
using StackFramePtr = std::shared_ptr<StackFrame>;
|
||||
|
||||
class BaseFuncGraphEvaluator : public Evaluator {
|
||||
public:
|
||||
explicit BaseFuncGraphEvaluator(const AnalysisContextPtr &context)
|
||||
|
@ -192,19 +198,26 @@ class BaseFuncGraphEvaluator : public Evaluator {
|
|||
virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
|
||||
|
||||
AnalysisContextPtr MakeContext(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list);
|
||||
AnalysisContextPtr graph_context() const { return graph_context_; }
|
||||
AnalysisContextPtr context() const { return context_; }
|
||||
void set_context(const AnalysisContextPtr &context) { context_ = context; }
|
||||
|
||||
protected:
|
||||
AnalysisContextPtr parent_context_;
|
||||
|
||||
private:
|
||||
AnalysisContextPtr graph_context_;
|
||||
// Add functions for stack frame routine.
|
||||
AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg);
|
||||
void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame,
|
||||
const StackFramePtr &new_stack_frame);
|
||||
void LeaveStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame);
|
||||
|
||||
AnalysisContextPtr context_;
|
||||
};
|
||||
|
||||
class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
|
||||
public:
|
||||
FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context)
|
||||
: BaseFuncGraphEvaluator(context->Filter(func_graph)), func_graph_(func_graph) {}
|
||||
: BaseFuncGraphEvaluator(context->FindParentContext(func_graph)), func_graph_(func_graph) {}
|
||||
|
||||
~FuncGraphEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator);
|
||||
|
@ -219,8 +232,7 @@ class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
|
|||
|
||||
private:
|
||||
FuncGraphPtr func_graph_;
|
||||
std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>
|
||||
func_graph_cache_;
|
||||
FuncGraphCacheMap func_graph_cache_;
|
||||
std::vector<AbstractBasePtrList> trace_;
|
||||
};
|
||||
using FuncGraphEvaluatorPtr = std::shared_ptr<FuncGraphEvaluator>;
|
||||
|
@ -243,8 +255,7 @@ class MetaFuncGraphEvaluator : public BaseFuncGraphEvaluator {
|
|||
|
||||
private:
|
||||
MetaFuncGraphPtr meta_func_graph_;
|
||||
std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>
|
||||
func_graph_cache_;
|
||||
FuncGraphCacheMap func_graph_cache_;
|
||||
ScopePtr scope_;
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pipeline/jit/static_analysis/stack_frame.h"
|
||||
#include "debug/trace.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
AbstractBasePtrList StackFrame::GenerateArgsAbsList(const AnalysisEnginePtr &engine, const EvaluatorPtr &evaluator,
|
||||
const CNodePtr current_cnode) {
|
||||
AbstractBasePtrList args_abs_list;
|
||||
auto &inputs = current_cnode->inputs();
|
||||
for (std::size_t i = 1; i < inputs.size(); i++) {
|
||||
auto config = engine->MakeConfig(inputs[i], current_context_);
|
||||
auto abs = config->ObtainEvalResult()->abstract();
|
||||
args_abs_list.push_back(abs);
|
||||
}
|
||||
args_abs_list = evaluator->NormalizeArgs(args_abs_list);
|
||||
args_abs_list = evaluator->BroadenUndeterminedArgs(args_abs_list);
|
||||
return args_abs_list;
|
||||
}
|
||||
|
||||
StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr current_cnode,
|
||||
const FuncGraphAbstractClosurePtr &graph_func) {
|
||||
// Get the evaluator for func graph.
|
||||
auto evaluator = engine->GetEvaluatorFor(graph_func);
|
||||
auto fg_evaluator = dyn_cast<BaseFuncGraphEvaluator>(evaluator);
|
||||
if (fg_evaluator == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Evaluator should be a BaseGraphEvaluator, but got " << evaluator->ToString();
|
||||
}
|
||||
fg_evaluator->set_context(current_context_);
|
||||
|
||||
// Evaluate the inputs firstly. Build arguments for the func graph.
|
||||
AbstractBasePtrList args_abs_list = GenerateArgsAbsList(engine, evaluator, current_cnode);
|
||||
|
||||
// Generate func graph with arguments.
|
||||
auto fg = fg_evaluator->GetFuncGraph(engine, args_abs_list);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
std::size_t nargs = fg->parameters().size();
|
||||
if (args_abs_list.size() != nargs) {
|
||||
MS_EXCEPTION(TypeError) << "Function " << fg->ToString() << ", The number of parameters of this function is "
|
||||
<< fg->parameters().size() << ", but the number of provided arguments is "
|
||||
<< args_abs_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info());
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "current_node: " << current_cnode->DebugString() << ", fg: " << fg->ToString()
|
||||
<< ", current_context_: " << current_context_->ToString();
|
||||
|
||||
// Find parent context and create new context.
|
||||
auto branch_fg = graph_func->func_graph();
|
||||
auto parent_context = graph_func->context()->FindParentContext(branch_fg);
|
||||
auto new_context = parent_context->NewFuncGraphContext(fg, args_abs_list);
|
||||
// Evaluate the parameters with new context.
|
||||
for (size_t i = 0; i < nargs; i++) {
|
||||
const auto &arg_abs = args_abs_list[i];
|
||||
const auto &node = fg->parameters()[i];
|
||||
AnfNodeConfigPtr conf = engine->MakeConfig(node, new_context);
|
||||
engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg_abs, nullptr));
|
||||
}
|
||||
|
||||
// Create a new stack frame and set arguments for it.
|
||||
auto new_stack_frame = std::make_shared<StackFrame>(fg_evaluator, fg, new_context, current_context());
|
||||
new_stack_frame->set_args_abs_list(std::move(args_abs_list));
|
||||
return new_stack_frame;
|
||||
}
|
||||
|
||||
// Check if we need branch to another func graph.
|
||||
StackFramePtr StackFrame::Jump(const AnalysisEnginePtr &engine) {
|
||||
auto ¤t_node = CurrentNode();
|
||||
if (!current_node->isa<CNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode = current_node->cast<CNodePtr>();
|
||||
auto func = engine->GetCNodeOperatorAbstract(cnode, current_context_);
|
||||
auto graph_func = dyn_cast<FuncGraphAbstractClosure>(func); // Not handle MetaFuncGraphAbstractClosure by now.
|
||||
if (graph_func == nullptr) {
|
||||
return nullptr; // Not call FuncGraph.
|
||||
}
|
||||
|
||||
// It's FuncGraph Call.
|
||||
return DoJump(engine, cnode, graph_func);
|
||||
}
|
||||
|
||||
EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) {
|
||||
auto ¤t_node = NextNode();
|
||||
MS_LOG(DEBUG) << "current_node: " << current_node->DebugString()
|
||||
<< ", current_context_: " << current_context_->ToString();
|
||||
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_);
|
||||
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
|
||||
return node_eval_result;
|
||||
}
|
||||
|
||||
void StackFrame::Back(const AnalysisEnginePtr &engine, const EvalResultPtr &result) {
|
||||
auto ¤t_node = NextNode();
|
||||
MS_LOG(DEBUG) << "current_node: " << current_node->DebugString()
|
||||
<< ", current_context_: " << current_context_->ToString();
|
||||
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_);
|
||||
engine->SaveEvalResultInCache(node_conf, result);
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,138 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STACK_FRAME_H_
|
||||
#define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STACK_FRAME_H_
|
||||
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "pipeline/jit/static_analysis/evaluator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
class StackFrame;
|
||||
using StackFramePtr = std::shared_ptr<StackFrame>;
|
||||
using EvaluatorWeakPtr = std::weak_ptr<Evaluator>;
|
||||
|
||||
class StackFrame : public Base {
|
||||
public:
|
||||
StackFrame(const EvaluatorPtr &evaluator, const FuncGraphPtr &func_graph, const AnalysisContextPtr ¤t_context,
|
||||
const AnalysisContextPtr &parent_context)
|
||||
: evaluator_(EvaluatorWeakPtr(evaluator)),
|
||||
func_graph_(func_graph),
|
||||
current_context_(current_context),
|
||||
parent_context_(parent_context),
|
||||
slot_index_(0),
|
||||
done_(false) {
|
||||
Load();
|
||||
}
|
||||
virtual ~StackFrame() = default;
|
||||
|
||||
void Load() {
|
||||
node_slots = TopoSort(func_graph_->get_return(), SuccIncoming, [this](const AnfNodePtr &node) -> IncludeType {
|
||||
if (node->func_graph() != func_graph_ || node->isa<ValueNode>()) {
|
||||
return EXCLUDE;
|
||||
}
|
||||
return FOLLOW;
|
||||
});
|
||||
slot_index_ = 0;
|
||||
args_abs_list_.clear();
|
||||
}
|
||||
|
||||
// Check if we need branch to another func graph.
|
||||
StackFramePtr Jump(const AnalysisEnginePtr &engine);
|
||||
// Run one step in current func graph.
|
||||
EvalResultPtr Step(const AnalysisEnginePtr &engine);
|
||||
// Return back from branch func graph.
|
||||
void Back(const AnalysisEnginePtr &engine, const EvalResultPtr &result);
|
||||
|
||||
bool Done() { return done_; }
|
||||
|
||||
AnfNodePtr &CurrentNode() {
|
||||
if (slot_index_ >= node_slots.size()) {
|
||||
MS_LOG(EXCEPTION) << "The stack frame of " << func_graph_->ToAbstract()
|
||||
<< " is invalid. Try to access frame sequence by index " << slot_index_
|
||||
<< ", while the size is " << node_slots.size() << ".";
|
||||
}
|
||||
return node_slots[slot_index_];
|
||||
}
|
||||
|
||||
AnfNodePtr &NextNode() {
|
||||
auto ¤t_node = CurrentNode();
|
||||
// Set `done_` true, if the stack frames is being exhausted.
|
||||
if (current_node == func_graph_->get_return()) {
|
||||
done_ = true;
|
||||
}
|
||||
// Move cursor to next node.
|
||||
slot_index_++;
|
||||
return current_node;
|
||||
}
|
||||
|
||||
EvaluatorPtr evaluator() const { return evaluator_.lock(); }
|
||||
FuncGraphPtr func_graph() const { return func_graph_; }
|
||||
AnalysisContextPtr current_context() const { return current_context_; }
|
||||
AnalysisContextPtr parent_context() const { return parent_context_; }
|
||||
|
||||
AbstractBasePtrList &args_abs_list() { return args_abs_list_; }
|
||||
void set_args_abs_list(const AbstractBasePtrList &&args_abs_list) { args_abs_list_ = args_abs_list; }
|
||||
|
||||
std::string ToString() const override {
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
std::ostringstream buffer;
|
||||
buffer << "StackFrame: " << this << ", " << func_graph_->ToString();
|
||||
if (slot_index_ < node_slots.size()) {
|
||||
auto current_node = node_slots[slot_index_];
|
||||
buffer << "(#" << slot_index_ << " / Running " << current_node->DebugString() << ")";
|
||||
} else {
|
||||
buffer << "(Exhausted..)";
|
||||
}
|
||||
buffer << ", parent: ";
|
||||
auto parent_graph = parent_context_->func_graph();
|
||||
if (parent_graph != nullptr) {
|
||||
buffer << parent_graph << "/" << parent_graph->ToString();
|
||||
} else {
|
||||
buffer << "NULL";
|
||||
}
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
friend std::ostream &operator<<(std::ostream &os, const StackFramePtr &frame) {
|
||||
MS_EXCEPTION_IF_NULL(frame);
|
||||
os << frame->ToString();
|
||||
return os;
|
||||
}
|
||||
|
||||
private:
|
||||
AbstractBasePtrList GenerateArgsAbsList(const AnalysisEnginePtr &engine, const EvaluatorPtr &evaluator,
|
||||
const CNodePtr current_cnode);
|
||||
StackFramePtr DoJump(const AnalysisEnginePtr &engine, const CNodePtr current_cnode,
|
||||
const FuncGraphAbstractClosurePtr &graph_func);
|
||||
|
||||
EvaluatorWeakPtr evaluator_;
|
||||
FuncGraphPtr func_graph_;
|
||||
AnalysisContextPtr current_context_;
|
||||
AnalysisContextPtr parent_context_;
|
||||
AbstractBasePtrList args_abs_list_;
|
||||
std::vector<AnfNodePtr> node_slots;
|
||||
size_t slot_index_;
|
||||
bool done_;
|
||||
};
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STACK_FRAME_H_
|
|
@ -115,6 +115,7 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac
|
|||
|
||||
// Running the analyzer.
|
||||
ResetFunctionCallDepth();
|
||||
ResetStackFrameDepth();
|
||||
AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list);
|
||||
MS_EXCEPTION_IF_NULL(root_context);
|
||||
MS_EXCEPTION_IF_NULL(root_context->func_graph());
|
||||
|
@ -133,7 +134,7 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana
|
|||
const ConfigPtrList &args_conf_list) {
|
||||
std::shared_ptr<FuncGraphEvaluator> eval = std::make_shared<FuncGraphEvaluator>(func_graph, context);
|
||||
(void)eval->Run(shared_from_this(), args_conf_list, nullptr);
|
||||
return eval->graph_context();
|
||||
return eval->context();
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf) {
|
||||
|
@ -152,7 +153,7 @@ EvalResultPtr AnalysisEngine::ObtainEvalResultWithCache(const AnfNodeConfigPtr &
|
|||
}
|
||||
MS_LOG(DEBUG) << "Evaluate node on demond for NodeConfig: " << conf->ToString()
|
||||
<< ", result: " << result->abstract().get() << ", " << result->abstract()->ToString();
|
||||
analysis_cache_.set_value(conf, result);
|
||||
SaveEvalResultInCache(conf, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -179,7 +180,7 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
|
|||
auto abstract = EvalValueNode(value_node, conf);
|
||||
eval_result = std::make_shared<EvalResult>(abstract, std::make_shared<AttrValueMap>());
|
||||
} else if (node->isa<CNode>()) {
|
||||
CheckNoStackInSameFuncGraph(conf);
|
||||
// CheckNoStackInSameFuncGraph(conf);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
trace::TraceEvalCNodeEnter(conf);
|
||||
eval_result = EvalCNode(cnode, conf);
|
||||
|
@ -224,7 +225,7 @@ void AnalysisEngine::CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf) {
|
|||
MS_LOG(EXCEPTION) << "Top evaluator is " << top_evaluator->ToString();
|
||||
}
|
||||
auto top_fg_evaluator = dyn_cast<BaseFuncGraphEvaluator>(top_evaluator);
|
||||
auto top_context_fg = top_fg_evaluator->graph_context()->func_graph();
|
||||
auto top_context_fg = top_fg_evaluator->context()->func_graph();
|
||||
if (current_cnode_fg != top_context_fg) { // Ignore FV call.
|
||||
return;
|
||||
}
|
||||
|
@ -249,18 +250,15 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co
|
|||
return out;
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
AbstractFunctionPtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto &inputs = cnode->inputs();
|
||||
if (inputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "CNode->inputs() is empty, CNode: " << cnode->DebugString();
|
||||
}
|
||||
|
||||
AnfNodePtr func_node = inputs[0];
|
||||
MS_EXCEPTION_IF_NULL(func_node);
|
||||
MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString();
|
||||
AnalysisContextPtr context = conf->context();
|
||||
AnfNodeConfigPtr func_conf = MakeConfig(func_node, context);
|
||||
MS_EXCEPTION_IF_NULL(func_conf);
|
||||
// Keep it in a local variable, otherwise smart pointer will free it.
|
||||
|
@ -270,32 +268,40 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
|
|||
MS_LOG(EXCEPTION) << "No abstract, func_conf: " << func_conf->ToString()
|
||||
<< " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
|
||||
}
|
||||
if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
|
||||
MS_LOG(DEBUG) << "EvalCNode eval Undetermined";
|
||||
return std::make_shared<EvalResult>(maybe_func->Clone(), std::make_shared<AttrValueMap>());
|
||||
}
|
||||
AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func);
|
||||
if (func == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Not AbstractFunction: " << maybe_func->ToString() << ", func_conf: " << func_conf->ToString()
|
||||
<< " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
|
||||
}
|
||||
return func;
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
AbstractFunctionPtr func = GetCNodeOperatorAbstract(cnode, conf->context());
|
||||
if (func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
|
||||
MS_LOG(DEBUG) << "EvalCNode eval Undetermined";
|
||||
return std::make_shared<EvalResult>(func->Clone(), std::make_shared<AttrValueMap>());
|
||||
}
|
||||
|
||||
ConfigPtrList args_conf_list;
|
||||
// ignore the first node which is function name
|
||||
// Ignore the first node which is function name
|
||||
auto &inputs = cnode->inputs();
|
||||
for (std::size_t i = 1; i < inputs.size(); i++) {
|
||||
const AnfNodePtr &node = inputs[i];
|
||||
args_conf_list.push_back(MakeConfig(node, context));
|
||||
args_conf_list.push_back(MakeConfig(node, conf->context()));
|
||||
}
|
||||
std::vector<EvaluatorPtr> infs;
|
||||
std::vector<EvaluatorPtr> evaluators;
|
||||
|
||||
auto build_evaluator = [this, &infs, &cnode](const AbstractFuncAtomPtr &poss) {
|
||||
auto build_evaluator = [this, &evaluators, &cnode](const AbstractFuncAtomPtr &poss) {
|
||||
auto evaluator = this->GetEvaluatorFor(poss);
|
||||
evaluator->set_bound_node(cnode);
|
||||
infs.push_back(evaluator);
|
||||
evaluators.push_back(evaluator);
|
||||
};
|
||||
func->Visit(build_evaluator);
|
||||
|
||||
auto eval_result = ExecuteEvaluators(infs, conf, args_conf_list);
|
||||
auto eval_result = ExecuteEvaluators(evaluators, conf, args_conf_list);
|
||||
return eval_result;
|
||||
}
|
||||
|
||||
|
@ -314,7 +320,7 @@ EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const Abs
|
|||
}
|
||||
|
||||
void AnalysisEngine::ClearEvaluatorCache() {
|
||||
for (std::pair<AbstractFunctionPtr, EvaluatorPtr> element : constructors_) {
|
||||
for (std::pair<AbstractFunctionPtr, EvaluatorPtr> element : evaluators_) {
|
||||
EvaluatorPtr evaluator = element.second;
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_map());
|
||||
|
@ -338,7 +344,7 @@ void AnalysisEngine::Clear() {
|
|||
analysis_cache_.Clear();
|
||||
anfnode_config_map_.clear();
|
||||
eval_trace_.clear();
|
||||
constructors_.clear();
|
||||
evaluators_.clear();
|
||||
constructors_app_.clear();
|
||||
continued_evals_.clear();
|
||||
}
|
||||
|
@ -407,38 +413,38 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
|
|||
} // namespace
|
||||
|
||||
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) {
|
||||
auto inf_pair = constructors_.find(func);
|
||||
if (inf_pair != constructors_.end()) {
|
||||
auto inf_pair = evaluators_.find(func);
|
||||
if (inf_pair != evaluators_.end()) {
|
||||
return inf_pair->second;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
auto primitive = func->prim();
|
||||
auto evaluator = GetPrimEvaluator(primitive, shared_from_this());
|
||||
constructors_[func] = evaluator;
|
||||
evaluators_[func] = evaluator;
|
||||
return evaluator;
|
||||
}
|
||||
|
||||
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &func) {
|
||||
auto inf_pair = constructors_.find(func);
|
||||
if (inf_pair != constructors_.end()) {
|
||||
auto inf_pair = evaluators_.find(func);
|
||||
if (inf_pair != evaluators_.end()) {
|
||||
return inf_pair->second;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
std::shared_ptr<FuncGraphEvaluator> func_graph_evaluator =
|
||||
std::make_shared<FuncGraphEvaluator>(func->func_graph(), func->context());
|
||||
constructors_[func] = func_graph_evaluator;
|
||||
evaluators_[func] = func_graph_evaluator;
|
||||
return func_graph_evaluator;
|
||||
}
|
||||
|
||||
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &func) {
|
||||
auto inf_pair = constructors_.find(func);
|
||||
if (inf_pair != constructors_.end()) {
|
||||
auto inf_pair = evaluators_.find(func);
|
||||
if (inf_pair != evaluators_.end()) {
|
||||
return inf_pair->second;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
std::shared_ptr<MetaFuncGraphEvaluator> evaluator =
|
||||
std::make_shared<MetaFuncGraphEvaluator>(func->meta_func_graph(), func->context(), func->GetScope());
|
||||
constructors_[func] = evaluator;
|
||||
evaluators_[func] = evaluator;
|
||||
return evaluator;
|
||||
}
|
||||
|
||||
|
@ -505,18 +511,18 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
|
|||
}
|
||||
|
||||
EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
MS_LOG(DEBUG) << "The func value: " << func->ToString();
|
||||
if (func->tracking_id() != nullptr) {
|
||||
MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
if (func->tracking_id() == nullptr || func->isa<abstract::MetaFuncGraphAbstractClosure>() ||
|
||||
func->isa<abstract::FuncGraphAbstractClosure>()) {
|
||||
EvaluatorPtr evaluator = _GetEvaluatorFor(func);
|
||||
return evaluator;
|
||||
}
|
||||
auto inf_pair = constructors_.find(func);
|
||||
if (inf_pair != constructors_.end()) {
|
||||
auto inf_pair = evaluators_.find(func);
|
||||
if (inf_pair != evaluators_.end()) {
|
||||
return inf_pair->second;
|
||||
}
|
||||
|
||||
|
@ -524,7 +530,7 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
|
|||
func_generic->set_tracking_id(nullptr);
|
||||
EvaluatorPtr eval = _GetEvaluatorFor(func_generic);
|
||||
auto tracked_eval = std::make_shared<TrackedEvaluator>(eval);
|
||||
constructors_[func] = tracked_eval;
|
||||
evaluators_[func] = tracked_eval;
|
||||
|
||||
return tracked_eval;
|
||||
}
|
||||
|
|
|
@ -89,7 +89,7 @@ class AnfNodeConfig : public Config {
|
|||
}
|
||||
context_ = nullptr;
|
||||
if (context != nullptr) {
|
||||
context_ = context->Filter(fg);
|
||||
context_ = context->FindParentContext(fg);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -116,8 +116,8 @@ class AnfNodeConfig : public Config {
|
|||
|
||||
std::string ToString() const override {
|
||||
std::ostringstream buffer;
|
||||
buffer << "Node: " << node_->DebugString() << "-uid(" << node_->UniqueId()
|
||||
<< "), Context: " << context_->ToString();
|
||||
buffer << "Node: " << node_ << "/" << node_->DebugString() << "-uid(" << node_->UniqueId()
|
||||
<< "), Context: " << context_ << "/" << context_->ToString();
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
|
@ -190,6 +190,9 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
prim_constructors_(prim_evaluator_map),
|
||||
func_graph_manager_(func_graph_manager) {
|
||||
function_call_depth_ = 0;
|
||||
function_call_max_depth_ = 0;
|
||||
stack_frame_depth_ = 0;
|
||||
stack_frame_max_depth_ = 0;
|
||||
forward_count_ = 0;
|
||||
}
|
||||
~AnalysisEngine() = default;
|
||||
|
@ -197,10 +200,16 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
// func_graph: The func_graph to analyze.
|
||||
// args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.
|
||||
AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
|
||||
void SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
analysis_cache_.set_value(conf, result);
|
||||
}
|
||||
EvalResultPtr ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf);
|
||||
// Return the Evaluator for the given function.
|
||||
EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);
|
||||
|
||||
AbstractFunctionPtr GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context);
|
||||
AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf);
|
||||
EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
|
||||
// Infer the result of fn(args).
|
||||
|
@ -231,18 +240,43 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
AnalysisCache analysis_cache_;
|
||||
std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
|
||||
|
||||
void ResetFunctionCallDepth() { function_call_depth_ = 0; }
|
||||
|
||||
void IncreaseFunctionCallDepth() { function_call_depth_++; }
|
||||
|
||||
void ResetFunctionCallDepth() {
|
||||
function_call_depth_ = 0;
|
||||
function_call_max_depth_ = 0;
|
||||
}
|
||||
void IncreaseFunctionCallDepth() {
|
||||
function_call_depth_++;
|
||||
if (function_call_max_depth_ < function_call_depth_) {
|
||||
function_call_max_depth_ = function_call_depth_;
|
||||
}
|
||||
}
|
||||
void DecreaseFunctionCallDepth() {
|
||||
if (function_call_depth_ == 0) {
|
||||
MS_LOG(EXCEPTION) << "Current function call depth is already 0, can not decrease it.";
|
||||
}
|
||||
function_call_depth_--;
|
||||
}
|
||||
size_t function_call_depth() { return function_call_depth_; }
|
||||
size_t function_call_max_depth() { return function_call_max_depth_; }
|
||||
|
||||
uint64_t function_call_depth() { return function_call_depth_; }
|
||||
void ResetStackFrameDepth() {
|
||||
stack_frame_depth_ = 0;
|
||||
stack_frame_max_depth_ = 0;
|
||||
}
|
||||
void IncreaseStackFrameDepth() {
|
||||
stack_frame_depth_++;
|
||||
if (stack_frame_max_depth_ < stack_frame_depth_) {
|
||||
stack_frame_max_depth_ = stack_frame_depth_;
|
||||
}
|
||||
}
|
||||
void DecreaseStackFrameDepth() {
|
||||
if (stack_frame_depth_ == 0) {
|
||||
MS_LOG(EXCEPTION) << "Current stack frame depth is already 0, can not decrease it.";
|
||||
}
|
||||
stack_frame_depth_--;
|
||||
}
|
||||
size_t stack_frame_depth() { return stack_frame_depth_; }
|
||||
size_t stack_frame_max_depth() { return stack_frame_max_depth_; }
|
||||
|
||||
void CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf);
|
||||
|
||||
|
@ -282,7 +316,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
|
||||
const PrimEvaluatorMap &prim_constructors_;
|
||||
FuncGraphManagerPtr func_graph_manager_;
|
||||
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> constructors_;
|
||||
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> evaluators_;
|
||||
std::unordered_map<std::pair<AbstractFunctionPtr, AbstractBasePtrList>, EvaluatorPtr, PartialAppHasher>
|
||||
constructors_app_;
|
||||
AnfNodeConfigMap anfnode_config_map_;
|
||||
|
@ -299,10 +333,15 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
const ConfigPtrList &args_conf_list);
|
||||
EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
|
||||
const ConfigPtrList &args_conf_list);
|
||||
// record current depth of function call statck
|
||||
uint64_t function_call_depth_;
|
||||
// Record current depth of function call stack, including `stack_frame_depth_`.
|
||||
size_t function_call_depth_;
|
||||
size_t function_call_max_depth_;
|
||||
|
||||
uint64_t forward_count_;
|
||||
// Record current depth of stack frames call.
|
||||
size_t stack_frame_depth_;
|
||||
size_t stack_frame_max_depth_;
|
||||
|
||||
size_t forward_count_;
|
||||
|
||||
#ifdef DEBUG
|
||||
std::vector<AnfNodePtr> compute_conf_stack_;
|
||||
|
|
|
@ -133,7 +133,7 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom {
|
|||
// so different tracking_id will produce different FuncGraphAbstractClosure,
|
||||
// different FuncGraphEvaluator.
|
||||
// Espcecially useful for recursive func graph call, so it will not mess up
|
||||
// the graph_context_ in FuncGraphEvaluator.
|
||||
// the `context_` in FuncGraphEvaluator.
|
||||
// Notes: Be careful to use nullptr for this variable.
|
||||
// store it as weak_ptr to break reference cycle.
|
||||
AnfNodeWeakPtr tracking_id_;
|
||||
|
|
|
@ -23,39 +23,39 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
AnalysisContextPtr AnalysisContext::NewContext(AnalysisContextPtr parent, FuncGraphPtr fg,
|
||||
AnalysisContextPtr AnalysisContext::NewContext(AnalysisContextPtr parent_context, FuncGraphPtr fg,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
auto children_context_map_iter = parent->children_cache_.find(fg);
|
||||
if (children_context_map_iter != parent->children_cache_.end()) {
|
||||
auto children_context_map_iter = parent_context->children_cache_.find(fg);
|
||||
if (children_context_map_iter != parent_context->children_cache_.end()) {
|
||||
auto children_context_map = children_context_map_iter->second;
|
||||
auto children_context_iter = children_context_map.find(args_spec_list);
|
||||
if (children_context_iter != children_context_map.end()) {
|
||||
return children_context_iter->second.lock();
|
||||
}
|
||||
}
|
||||
AnalysisContextPtr context_new = std::make_shared<AnalysisContext>(parent, fg, args_spec_list);
|
||||
AnalysisContextPtr new_context = std::make_shared<AnalysisContext>(parent_context, fg, args_spec_list);
|
||||
// Reference to myself, so use weak_ptr to break reference cycle.
|
||||
auto weak_context = std::weak_ptr<AnalysisContext>(context_new);
|
||||
context_new->parent_cache_[fg] = weak_context;
|
||||
parent->children_cache_[fg][args_spec_list] = weak_context;
|
||||
return context_new;
|
||||
auto weak_context = std::weak_ptr<AnalysisContext>(new_context);
|
||||
new_context->parent_cache_[fg] = weak_context;
|
||||
parent_context->children_cache_[fg][args_spec_list] = weak_context;
|
||||
return new_context;
|
||||
}
|
||||
|
||||
AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func_graph,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
FuncGraphPtr graph_parent = func_graph->parent();
|
||||
auto iter = parent_cache_.find(graph_parent);
|
||||
FuncGraphPtr parent_graph = func_graph->parent();
|
||||
AnalysisContextPtr parent_context = nullptr;
|
||||
auto iter = parent_cache_.find(parent_graph);
|
||||
if (iter != parent_cache_.end()) {
|
||||
parent_context = iter->second.lock();
|
||||
}
|
||||
// if this happen, it will be bug in code. but we raise exception to keep the scene.
|
||||
// If this happen, it will be a bug in code. But we raise exception to keep the scene.
|
||||
if (parent_context == nullptr) {
|
||||
std::ostringstream oss;
|
||||
oss << "BUG: cannot found parent_context in current context: " << this->ToString()
|
||||
<< ", func_graph: " << func_graph->ToString() << ", graph_parent: ";
|
||||
if (graph_parent != nullptr) {
|
||||
oss << graph_parent->ToString();
|
||||
oss << "BUG: Failed to find parent context in current context: " << this->ToString()
|
||||
<< ", func_graph: " << func_graph->ToString() << ", parent_graph: ";
|
||||
if (parent_graph != nullptr) {
|
||||
oss << parent_graph->ToString();
|
||||
} else {
|
||||
oss << "nullptr";
|
||||
}
|
||||
|
@ -64,7 +64,7 @@ AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func
|
|||
return NewContext(parent_context, func_graph, args_spec_list);
|
||||
}
|
||||
|
||||
AnalysisContextPtr AnalysisContext::Filter(const FuncGraphPtr &func_graph) {
|
||||
AnalysisContextPtr AnalysisContext::FindParentContext(const FuncGraphPtr &func_graph) {
|
||||
auto p_iter = parent_cache_.find(func_graph);
|
||||
AnalysisContextPtr parent_context = nullptr;
|
||||
if (p_iter != parent_cache_.end()) {
|
||||
|
@ -75,10 +75,10 @@ AnalysisContextPtr AnalysisContext::Filter(const FuncGraphPtr &func_graph) {
|
|||
parent_context = iter_parent->second.lock();
|
||||
}
|
||||
}
|
||||
// if this happen, it will be bug in code. but we raise exception to keep the scene.
|
||||
// If this happen, it would be a bug in code. But we raise exception to keep the scene.
|
||||
if (parent_context == nullptr) {
|
||||
std::ostringstream oss;
|
||||
oss << "BUG: Filter graph failed: " << func_graph->ToString() << ", graph_parent: ";
|
||||
oss << "BUG: Failed to find parent context for: " << func_graph->ToString() << ", parent_graph: ";
|
||||
if (func_graph->parent() != nullptr) {
|
||||
oss << func_graph->parent()->ToString();
|
||||
} else {
|
||||
|
|
|
@ -52,7 +52,7 @@ class AnalysisContext {
|
|||
AnalysisContextPtr NewFuncGraphContext(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
// Return a context restricted to a graph's dependencies.
|
||||
AnalysisContextPtr Filter(const FuncGraphPtr &graph);
|
||||
AnalysisContextPtr FindParentContext(const FuncGraphPtr &graph);
|
||||
bool operator==(const AnalysisContext &other) const;
|
||||
std::size_t hash();
|
||||
static AnalysisContextPtr DummyContext();
|
||||
|
|
|
@ -176,7 +176,11 @@ void FuncGraph::DumpCNodeList() {
|
|||
}
|
||||
|
||||
std::string FuncGraph::ToString() const {
|
||||
return mindspore::label_manage::Label(const_cast<FuncGraph *>(this)->shared_from_base<FuncGraph>()->debug_info());
|
||||
std::ostringstream buffer;
|
||||
auto debug_info = const_cast<FuncGraph *>(this)->shared_from_base<FuncGraph>()->debug_info();
|
||||
buffer << mindspore::label_manage::Label(debug_info);
|
||||
buffer << "." << debug_info->get_id();
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
GraphDebugInfoPtr FuncGraph::debug_info() {
|
||||
|
|
|
@ -295,22 +295,22 @@ TEST_F(TestInferGraph, test_context) {
|
|||
AnalysisContextPtr dummy_context = AnalysisContext::DummyContext();
|
||||
|
||||
AnalysisContextPtr f_context = dummy_context->NewFuncGraphContext(graph_f_, AbstractBasePtrList());
|
||||
ASSERT_TRUE(f_context->Filter(graph_f_) = f_context);
|
||||
ASSERT_TRUE(f_context->Filter(nullptr) = dummy_context);
|
||||
ASSERT_TRUE(f_context->FindParentContext(graph_f_) = f_context);
|
||||
ASSERT_TRUE(f_context->FindParentContext(nullptr) = dummy_context);
|
||||
|
||||
AnalysisContextPtr g_context = f_context->NewFuncGraphContext(graph_g_, AbstractBasePtrList());
|
||||
ASSERT_TRUE(g_context->Filter(graph_g_) = g_context);
|
||||
ASSERT_TRUE(g_context->Filter(graph_f_) = dummy_context);
|
||||
ASSERT_TRUE(g_context->Filter(nullptr) = dummy_context);
|
||||
ASSERT_TRUE(g_context->FindParentContext(graph_g_) = g_context);
|
||||
ASSERT_TRUE(g_context->FindParentContext(graph_f_) = dummy_context);
|
||||
ASSERT_TRUE(g_context->FindParentContext(nullptr) = dummy_context);
|
||||
|
||||
AnalysisContextPtr alpha_context = dummy_context->NewFuncGraphContext(graph_alpha_, AbstractBasePtrList());
|
||||
ASSERT_TRUE(alpha_context->Filter(graph_alpha_) = alpha_context);
|
||||
ASSERT_TRUE(alpha_context->Filter(nullptr) = dummy_context);
|
||||
ASSERT_TRUE(alpha_context->FindParentContext(graph_alpha_) = alpha_context);
|
||||
ASSERT_TRUE(alpha_context->FindParentContext(nullptr) = dummy_context);
|
||||
|
||||
AnalysisContextPtr beta_context = alpha_context->NewFuncGraphContext(graph_beta_, AbstractBasePtrList());
|
||||
ASSERT_TRUE(beta_context->Filter(graph_beta_) = beta_context);
|
||||
ASSERT_TRUE(beta_context->Filter(graph_alpha_) = alpha_context);
|
||||
ASSERT_TRUE(beta_context->Filter(nullptr) = dummy_context);
|
||||
ASSERT_TRUE(beta_context->FindParentContext(graph_beta_) = beta_context);
|
||||
ASSERT_TRUE(beta_context->FindParentContext(graph_alpha_) = alpha_context);
|
||||
ASSERT_TRUE(beta_context->FindParentContext(nullptr) = dummy_context);
|
||||
}
|
||||
|
||||
class TestInferMetaGraph : public UT::Common {
|
||||
|
|
Loading…
Reference in New Issue