!16313 Evaluate the func graph in simulative stack, not to call recursively.

From: @zh_qh
Reviewed-by: @hwhewei,@ginfung
Signed-off-by: @hwhewei
This commit is contained in:
mindspore-ci-bot 2021-05-13 08:56:18 +08:00 committed by Gitee
commit eb483a276b
16 changed files with 545 additions and 150 deletions

View File

@ -114,7 +114,7 @@ void DumpGlobalInfoEntry(const FuncGraphPtr &graph, std::ostringstream &buffer)
return;
}
buffer << "#IR entry : @" << graph->ToString() << "." << graph->debug_info()->get_id() << std::endl;
buffer << "#IR entry : @" << graph->ToString() << std::endl;
buffer << "#attrs :" << std::endl;
for (const auto &attr : graph->attrs()) {
buffer << attr.first << " : ";
@ -216,7 +216,7 @@ void DumpOperator(const AnfNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &g
if (IsValueNode<FuncGraph>(op)) {
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(op);
if (fg != nullptr) {
gsub->buffer << "call @" << fg->ToString() << "." << fg->debug_info()->get_id();
gsub->buffer << "call @" << fg->ToString();
}
} else if (op->isa<CNode>()) {
if (gsub->local_var_map.find(op) != gsub->local_var_map.end()) {
@ -224,7 +224,7 @@ void DumpOperator(const AnfNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &g
} else {
auto node = op->cast<CNodePtr>();
auto fg = node->func_graph();
gsub->buffer << "$(" << fg->ToString() << "." << fg->debug_info()->get_id() << ":" << node->ToString() << ")";
gsub->buffer << "$(" << fg->ToString() << ":" << node->ToString() << ")";
}
} else if (op->isa<ValueNode>()) {
gsub->buffer << GetValueNode(op)->ToString();
@ -262,14 +262,14 @@ void DumpOperands(const AnfNodePtr &nd, OrderedMap<AnfNodePtr, int32_t> *para_ma
} else {
auto node = in->cast<CNodePtr>();
auto fg = node->func_graph();
gsub->buffer << "$(" << fg->ToString() << "." << fg->debug_info()->get_id() << ":" << node->ToString() << ")";
gsub->buffer << "$(" << fg->ToString() << ":" << node->ToString() << ")";
}
} else if (in->isa<ValueNode>() && !IsValueNode<FuncGraph>(in)) {
// non Primitive valuenode
gsub->buffer << GetValueNode(in)->ToString();
} else if (IsValueNode<FuncGraph>(in)) {
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(in);
gsub->buffer << "@" << fg->ToString() << "." << fg->debug_info()->get_id();
gsub->buffer << "@" << fg->ToString();
} else {
gsub->buffer << in->ToString();
}
@ -501,8 +501,7 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo>
}
fout << std::endl;
}
fout << "subgraph @" << sg.first->ToString() << ".";
fout << sg.first->debug_info()->get_id() << "(";
fout << "subgraph @" << sg.first->ToString() << "(";
if (sg.first != graph) {
std::vector<AnfNodePtr> parameters = sg.first->parameters();
if (parameters.size() == 1) {

View File

@ -73,7 +73,7 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr
if (!check_integrity_) {
break;
}
MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'";
MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "'";
}
auto param_map = exported[fg];
if (param_map.find(param) != param_map.end()) {
@ -83,7 +83,7 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr
}
if (throw_excp) {
MS_LOG(EXCEPTION) << "Can not find index for param '" << param->DumpText() << "' for func graph '"
<< func_graph->DumpText() << "." << func_graph->debug_info()->get_id() << "'";
<< func_graph->DumpText() << "'";
}
return -1;
}
@ -457,7 +457,7 @@ void AnfExporter::OutputStatementComment(std::ofstream &ofs, const CNodePtr &nod
}
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(arg);
std::string func_graph_id = fg->debug_info()->get_id();
comment << " fg_" << func_graph_id << "=" << fg->ToString() << "." << func_graph_id;
comment << " fg_" << func_graph_id << "=" << fg->ToString();
}
if (has_comment) {
ofs << comment.str();
@ -552,8 +552,7 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &fun
if (*(func_graph->switch_layer_input())) {
ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n";
}
ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "."
<< func_graph->debug_info()->get_id() << "\n";
ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "\n";
if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
ofs << trace::GetDebugInfo(func_graph->debug_info(), "# ", kSourceLineTipDiscard) << "#"
<< label_manage::Label(func_graph->debug_info()) << "\n";

View File

@ -93,7 +93,7 @@ void DumpInferStack(std::ostringstream &oss) {
infer_vec.clear();
break;
}
auto graph_context = graph_infer->graph_context();
auto graph_context = graph_infer->context();
if (graph_context == nullptr) {
MS_LOG(INFO) << "Null context continue";
continue;
@ -253,7 +253,7 @@ std::vector<AnalysisContextPtr> AnalyzedFuncGraphExporter::ProcessFuncGraphCall(
}
auto base_fg_evaluator = dyn_cast<abstract::BaseFuncGraphEvaluator>(evaluator);
auto ctx = base_fg_evaluator->graph_context();
auto ctx = base_fg_evaluator->context();
if (ctx != nullptr && context_map_.insert({ctx, false}).second) {
MS_LOG(DEBUG) << "Add new context, ctx.addr = " << ctx.get() << "ctx = " << ctx->ToString();
context_vec_.push_back(ctx);
@ -298,7 +298,7 @@ void AnalyzedFuncGraphExporter::OutputStatementComment(std::ofstream &ofs, const
}
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(arg);
std::string func_graph_id = fg->debug_info()->get_id();
comment << " fg_" << func_graph_id << "=" << fg->ToString() << "." << func_graph_id;
comment << " fg_" << func_graph_id << "=" << fg->ToString();
if (ctxs.size() > i && ctxs[i] != nullptr) {
comment << "(@ctx.addr=" << ctxs[i].get() << ")";
}
@ -392,8 +392,7 @@ void AnalyzedFuncGraphExporter::ExportOneFuncGraph(std::ofstream &ofs, const Fun
std::vector<AnfNodePtr> parameters = func_graph->parameters();
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map;
ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "."
<< func_graph->debug_info()->get_id();
ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText();
if (cur_ctx_ != nullptr) {
ofs << " @ctx.addr=" << cur_ctx_.get();
}

View File

@ -115,6 +115,8 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
}
}
auto ret = engine->Run(func_graph, args_spec);
MS_LOG(INFO) << "function call max depth: " << engine->function_call_max_depth()
<< ", simulate call max depth: " << engine->stack_frame_max_depth();
MS_LOG(DEBUG) << "AbstractAnalyze end";
return ret;
}

View File

@ -877,7 +877,7 @@ class SideEffectFinder {
const SccPtr &GetScc(const FuncGraphPtr &func_graph) const {
auto found = scc_map_.find(func_graph);
if (found == scc_map_.end()) {
MS_LOG(EXCEPTION) << "SCC not found for " << func_graph->ToString() << "." << func_graph->debug_info()->get_id();
MS_LOG(EXCEPTION) << "SCC not found for " << func_graph->ToString();
}
return found->second;
}

View File

@ -24,6 +24,7 @@
#include "abstract/utils.h"
#include "debug/trace.h"
#include "utils/ms_context.h"
#include "pipeline/jit/static_analysis/stack_frame.h"
namespace mindspore {
namespace abstract {
@ -66,62 +67,145 @@ AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &
return context;
}
EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list);
MS_EXCEPTION_IF_NULL(fg);
std::size_t nargs = fg->parameters().size();
if (args_spec_list.size() != nargs) {
MS_EXCEPTION(TypeError) << "Function " << fg->ToString() << ", The number of parameters of this function is "
<< fg->parameters().size() << ", but the number of provided arguments is "
<< args_spec_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info());
}
MS_EXCEPTION_IF_NULL(parent_context_);
MS_EXCEPTION_IF_NULL(engine);
graph_context_ = parent_context_->NewFuncGraphContext(fg, args_spec_list);
const auto &parameters = fg->parameters();
for (size_t i = 0; i < nargs; i++) {
const auto &arg = args_spec_list[i];
const auto &node = parameters[i];
AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_);
engine->analysis_cache().set_value(conf, std::make_shared<EvalResult>(arg, nullptr));
}
const AnfNodePtr &func_node = fg->get_return();
void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame,
const StackFramePtr &new_stack_frame) {
// Enter new func graph.
auto &current_node = current_stack_frame->CurrentNode();
auto current_context = current_stack_frame->current_context();
AnfNodeConfigPtr call_conf = engine->MakeConfig(current_node, current_context);
auto evaluator = new_stack_frame->evaluator();
MS_EXCEPTION_IF_NULL(evaluator);
trace::TraceGraphEvalEnter(evaluator, call_conf);
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << "/" << fg->ToString()
<< ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString()
<< ", current function call depth: " << engine->function_call_depth();
AbstractBasePtr ret_base = nullptr;
// Increase & Check the func graph call depth.
engine->IncreaseFunctionCallDepth();
engine->IncreaseStackFrameDepth();
if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
MS_LOG(EXCEPTION) << "Exceed function call depth limit "
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
<< ", (function call depth: " << engine->function_call_depth()
<< ", simulate call depth: " << engine->stack_frame_depth()
<< "), please call 'context.set_context(max_call_depth=value)' to adjust this value.";
}
MS_LOG(DEBUG) << evaluator << "(" << evaluator->type_name() << "/" << evaluator->ToString()
<< "), enter, function call depth: " << engine->function_call_depth() << " - "
<< engine->stack_frame_depth();
}
void BaseFuncGraphEvaluator::LeaveStackFrame(const AnalysisEnginePtr &engine,
const StackFramePtr &current_stack_frame) {
// Leave current func graph.
auto evaluator = current_stack_frame->evaluator();
MS_EXCEPTION_IF_NULL(evaluator);
trace::TraceGraphEvalLeave(evaluator);
// Decrease the func graph call depth.
engine->DecreaseFunctionCallDepth();
engine->DecreaseStackFrameDepth();
MS_LOG(DEBUG) << evaluator << "(" << evaluator->type_name() << "/" << evaluator->ToString()
<< "), leave, function call depth: " << engine->function_call_depth() << " - "
<< engine->stack_frame_depth();
}
// Start running stack frames in a Evaluator.
AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg) {
EvalResultPtr eval_result = nullptr;
AbstractBasePtr res_base = nullptr;
std::stack<StackFramePtr> stack_frames;
auto current_stack_frame = std::make_shared<StackFrame>(shared_from_base<Evaluator>(), fg, context_, parent_context_);
MS_LOG(DEBUG) << "[" << this << "/StackFrame] Start at func graph, " << current_stack_frame;
stack_frames.push(current_stack_frame);
while (1) {
current_stack_frame = stack_frames.top();
if (current_stack_frame->Done()) {
MS_EXCEPTION_IF_NULL(res_base);
MS_LOG(DEBUG) << "[" << this << "/StackFrame] Leave from func graph, " << current_stack_frame;
stack_frames.pop();
if (stack_frames.empty()) {
MS_LOG(DEBUG) << "[" << this << "/StackFrame] Finish at func graph, " << current_stack_frame
<< ", res_base: " << res_base->ToString();
break;
}
// Save func graph eval result for specialize.
auto evaluator = current_stack_frame->evaluator();
MS_EXCEPTION_IF_NULL(evaluator);
(*evaluator->evaluator_cache_map())[current_stack_frame->args_abs_list()] = eval_result;
// Leave current func graph.
LeaveStackFrame(engine, current_stack_frame);
// Switch the stack frame.
current_stack_frame = stack_frames.top();
MS_LOG(DEBUG) << "[" << this << "/StackFrame] Back to func graph, " << current_stack_frame;
current_stack_frame->Back(engine, eval_result);
continue;
}
auto new_stack_frame = current_stack_frame->Jump(engine);
if (new_stack_frame != nullptr) {
// Enter new func graph.
EnterStackFrame(engine, current_stack_frame, new_stack_frame);
// Update current stack frame.
stack_frames.push(new_stack_frame);
current_stack_frame = new_stack_frame;
MS_LOG(DEBUG) << "[" << this << "/StackFrame] Jump to new func graph, " << new_stack_frame;
}
eval_result = current_stack_frame->Step(engine);
MS_EXCEPTION_IF_NULL(eval_result);
res_base = eval_result->abstract();
}
return res_base;
}
EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) {
MS_EXCEPTION_IF_NULL(engine);
engine->IncreaseFunctionCallDepth();
if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
MS_LOG(EXCEPTION) << "Exceed function call depth limit "
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
<< ", please call 'context.set_context(max_call_depth=value)' to adjust this value.";
<< ", (function call depth: " << engine->function_call_depth()
<< ", simulate call depth: " << engine->stack_frame_depth()
<< "), please call 'context.set_context(max_call_depth=value)' to adjust this value.";
}
const auto &all_nodes = TopoSort(func_node, SuccIncoming, [&fg](const AnfNodePtr &node) -> IncludeType {
if (node->func_graph() != fg || node->isa<ValueNode>()) {
return EXCLUDE;
}
return FOLLOW;
});
for (const auto &node : all_nodes) {
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << "/" << fg->ToString()
<< ", node_conf: " << node_conf->ToString();
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
ret_base = node_eval_result->abstract();
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << "/" << fg->ToString()
<< ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString();
MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString()
<< "), enter, function call depth: " << engine->function_call_depth() << " - "
<< engine->stack_frame_depth();
// Initialize evaluator starter with args_abs_list.
FuncGraphPtr fg = GetFuncGraph(engine, args_abs_list);
MS_EXCEPTION_IF_NULL(fg);
std::size_t nargs = fg->parameters().size();
if (args_abs_list.size() != nargs) {
MS_EXCEPTION(TypeError) << "Function " << fg->ToString() << ", The number of parameters of this function is "
<< fg->parameters().size() << ", but the number of provided arguments is "
<< args_abs_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info());
}
engine->DecreaseFunctionCallDepth();
MS_EXCEPTION_IF_NULL(ret_base);
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString()
<< ", is stub: " << fg->stub();
MS_EXCEPTION_IF_NULL(parent_context_);
context_ = parent_context_->NewFuncGraphContext(fg, args_abs_list);
const auto &parameters = fg->parameters();
for (size_t i = 0; i < nargs; i++) {
const auto &arg = args_abs_list[i];
const auto &node = parameters[i];
AnfNodeConfigPtr conf = engine->MakeConfig(node, context_);
engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg, nullptr));
}
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString()
<< ", context: " << context_->ToString() << ", return node: " << fg->get_return()->DebugString()
<< ", parent: " << (parent_context_->func_graph() ? parent_context_->func_graph()->ToString() : "NULL")
<< ", current function call depth: " << engine->function_call_depth();
auto res_base = LaunchStackFrame(engine, fg);
MS_EXCEPTION_IF_NULL(res_base);
MS_LOG(DEBUG) << "Analysis FuncGraph end, " << fg << "/" << fg->ToString()
<< ", evaluated abstract: " << res_base->ToString() << ", is stub: " << fg->stub();
if (fg->stub()) {
ret_base = std::make_shared<AbstractUndetermined>();
res_base = std::make_shared<AbstractUndetermined>();
}
return std::make_shared<EvalResult>(ret_base, nullptr);
engine->DecreaseFunctionCallDepth();
MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString()
<< "), leave, function call depth: " << engine->function_call_depth() << " - "
<< engine->stack_frame_depth();
return std::make_shared<EvalResult>(res_base, nullptr);
}
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
@ -158,7 +242,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
if (parent_context_) {
MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString()
<< ", context: " << parent_context_->ToString();
auto last_context = parent_context_->Filter(func_graph_);
auto last_context = parent_context_->FindParentContext(func_graph_);
if (last_context && last_context->func_graph() == func_graph_) {
MS_LOG(DEBUG) << "Find last eval context: " << last_context->ToString();
MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list);

View File

@ -23,6 +23,7 @@
#include <string>
#include <unordered_map>
#include <vector>
#include <stack>
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "utils/ms_context.h"
@ -135,7 +136,7 @@ class SymbolicPrimEvaluator : public PrimEvaluator {
virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
};
// Evaluator will be stored in AnalysisEngine.constructors_
// Evaluator will be stored in AnalysisEngine.evaluators_
using EvaluatorPtrList = std::vector<EvaluatorPtr>;
class DummyEvaluator : public Evaluator {
@ -179,6 +180,11 @@ class TrackedEvaluator : public Evaluator {
EvaluatorPtr sub_evaluator_;
};
using FuncGraphCacheMap =
std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
class StackFrame;
using StackFramePtr = std::shared_ptr<StackFrame>;
class BaseFuncGraphEvaluator : public Evaluator {
public:
explicit BaseFuncGraphEvaluator(const AnalysisContextPtr &context)
@ -192,19 +198,26 @@ class BaseFuncGraphEvaluator : public Evaluator {
virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
AnalysisContextPtr MakeContext(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list);
AnalysisContextPtr graph_context() const { return graph_context_; }
AnalysisContextPtr context() const { return context_; }
void set_context(const AnalysisContextPtr &context) { context_ = context; }
protected:
AnalysisContextPtr parent_context_;
private:
AnalysisContextPtr graph_context_;
// Add functions for stack frame routine.
AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg);
void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame,
const StackFramePtr &new_stack_frame);
void LeaveStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame);
AnalysisContextPtr context_;
};
class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
public:
FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context)
: BaseFuncGraphEvaluator(context->Filter(func_graph)), func_graph_(func_graph) {}
: BaseFuncGraphEvaluator(context->FindParentContext(func_graph)), func_graph_(func_graph) {}
~FuncGraphEvaluator() override = default;
MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator);
@ -219,8 +232,7 @@ class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
private:
FuncGraphPtr func_graph_;
std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>
func_graph_cache_;
FuncGraphCacheMap func_graph_cache_;
std::vector<AbstractBasePtrList> trace_;
};
using FuncGraphEvaluatorPtr = std::shared_ptr<FuncGraphEvaluator>;
@ -243,8 +255,7 @@ class MetaFuncGraphEvaluator : public BaseFuncGraphEvaluator {
private:
MetaFuncGraphPtr meta_func_graph_;
std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>
func_graph_cache_;
FuncGraphCacheMap func_graph_cache_;
ScopePtr scope_;
};

View File

@ -0,0 +1,114 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pipeline/jit/static_analysis/stack_frame.h"
#include "debug/trace.h"
namespace mindspore {
namespace abstract {
AbstractBasePtrList StackFrame::GenerateArgsAbsList(const AnalysisEnginePtr &engine, const EvaluatorPtr &evaluator,
const CNodePtr current_cnode) {
AbstractBasePtrList args_abs_list;
auto &inputs = current_cnode->inputs();
for (std::size_t i = 1; i < inputs.size(); i++) {
auto config = engine->MakeConfig(inputs[i], current_context_);
auto abs = config->ObtainEvalResult()->abstract();
args_abs_list.push_back(abs);
}
args_abs_list = evaluator->NormalizeArgs(args_abs_list);
args_abs_list = evaluator->BroadenUndeterminedArgs(args_abs_list);
return args_abs_list;
}
StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr current_cnode,
const FuncGraphAbstractClosurePtr &graph_func) {
// Get the evaluator for func graph.
auto evaluator = engine->GetEvaluatorFor(graph_func);
auto fg_evaluator = dyn_cast<BaseFuncGraphEvaluator>(evaluator);
if (fg_evaluator == nullptr) {
MS_LOG(EXCEPTION) << "Evaluator should be a BaseGraphEvaluator, but got " << evaluator->ToString();
}
fg_evaluator->set_context(current_context_);
// Evaluate the inputs firstly. Build arguments for the func graph.
AbstractBasePtrList args_abs_list = GenerateArgsAbsList(engine, evaluator, current_cnode);
// Generate func graph with arguments.
auto fg = fg_evaluator->GetFuncGraph(engine, args_abs_list);
MS_EXCEPTION_IF_NULL(fg);
std::size_t nargs = fg->parameters().size();
if (args_abs_list.size() != nargs) {
MS_EXCEPTION(TypeError) << "Function " << fg->ToString() << ", The number of parameters of this function is "
<< fg->parameters().size() << ", but the number of provided arguments is "
<< args_abs_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info());
}
MS_LOG(DEBUG) << "current_node: " << current_cnode->DebugString() << ", fg: " << fg->ToString()
<< ", current_context_: " << current_context_->ToString();
// Find parent context and create new context.
auto branch_fg = graph_func->func_graph();
auto parent_context = graph_func->context()->FindParentContext(branch_fg);
auto new_context = parent_context->NewFuncGraphContext(fg, args_abs_list);
// Evaluate the parameters with new context.
for (size_t i = 0; i < nargs; i++) {
const auto &arg_abs = args_abs_list[i];
const auto &node = fg->parameters()[i];
AnfNodeConfigPtr conf = engine->MakeConfig(node, new_context);
engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg_abs, nullptr));
}
// Create a new stack frame and set arguments for it.
auto new_stack_frame = std::make_shared<StackFrame>(fg_evaluator, fg, new_context, current_context());
new_stack_frame->set_args_abs_list(std::move(args_abs_list));
return new_stack_frame;
}
// Check if we need branch to another func graph.
StackFramePtr StackFrame::Jump(const AnalysisEnginePtr &engine) {
auto &current_node = CurrentNode();
if (!current_node->isa<CNode>()) {
return nullptr;
}
auto cnode = current_node->cast<CNodePtr>();
auto func = engine->GetCNodeOperatorAbstract(cnode, current_context_);
auto graph_func = dyn_cast<FuncGraphAbstractClosure>(func); // Not handle MetaFuncGraphAbstractClosure by now.
if (graph_func == nullptr) {
return nullptr; // Not call FuncGraph.
}
// It's FuncGraph Call.
return DoJump(engine, cnode, graph_func);
}
EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) {
auto &current_node = NextNode();
MS_LOG(DEBUG) << "current_node: " << current_node->DebugString()
<< ", current_context_: " << current_context_->ToString();
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_);
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
return node_eval_result;
}
void StackFrame::Back(const AnalysisEnginePtr &engine, const EvalResultPtr &result) {
auto &current_node = NextNode();
MS_LOG(DEBUG) << "current_node: " << current_node->DebugString()
<< ", current_context_: " << current_context_->ToString();
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_);
engine->SaveEvalResultInCache(node_conf, result);
}
} // namespace abstract
} // namespace mindspore

View File

@ -0,0 +1,138 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STACK_FRAME_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STACK_FRAME_H_
#include <utility>
#include <memory>
#include <string>
#include <vector>
#include "pipeline/jit/static_analysis/evaluator.h"
namespace mindspore {
namespace abstract {
class StackFrame;
using StackFramePtr = std::shared_ptr<StackFrame>;
using EvaluatorWeakPtr = std::weak_ptr<Evaluator>;
class StackFrame : public Base {
public:
StackFrame(const EvaluatorPtr &evaluator, const FuncGraphPtr &func_graph, const AnalysisContextPtr &current_context,
const AnalysisContextPtr &parent_context)
: evaluator_(EvaluatorWeakPtr(evaluator)),
func_graph_(func_graph),
current_context_(current_context),
parent_context_(parent_context),
slot_index_(0),
done_(false) {
Load();
}
virtual ~StackFrame() = default;
void Load() {
node_slots = TopoSort(func_graph_->get_return(), SuccIncoming, [this](const AnfNodePtr &node) -> IncludeType {
if (node->func_graph() != func_graph_ || node->isa<ValueNode>()) {
return EXCLUDE;
}
return FOLLOW;
});
slot_index_ = 0;
args_abs_list_.clear();
}
// Check if we need branch to another func graph.
StackFramePtr Jump(const AnalysisEnginePtr &engine);
// Run one step in current func graph.
EvalResultPtr Step(const AnalysisEnginePtr &engine);
// Return back from branch func graph.
void Back(const AnalysisEnginePtr &engine, const EvalResultPtr &result);
bool Done() { return done_; }
AnfNodePtr &CurrentNode() {
if (slot_index_ >= node_slots.size()) {
MS_LOG(EXCEPTION) << "The stack frame of " << func_graph_->ToAbstract()
<< " is invalid. Try to access frame sequence by index " << slot_index_
<< ", while the size is " << node_slots.size() << ".";
}
return node_slots[slot_index_];
}
AnfNodePtr &NextNode() {
auto &current_node = CurrentNode();
// Set `done_` true, if the stack frames is being exhausted.
if (current_node == func_graph_->get_return()) {
done_ = true;
}
// Move cursor to next node.
slot_index_++;
return current_node;
}
EvaluatorPtr evaluator() const { return evaluator_.lock(); }
FuncGraphPtr func_graph() const { return func_graph_; }
AnalysisContextPtr current_context() const { return current_context_; }
AnalysisContextPtr parent_context() const { return parent_context_; }
AbstractBasePtrList &args_abs_list() { return args_abs_list_; }
void set_args_abs_list(const AbstractBasePtrList &&args_abs_list) { args_abs_list_ = args_abs_list; }
std::string ToString() const override {
MS_EXCEPTION_IF_NULL(func_graph_);
std::ostringstream buffer;
buffer << "StackFrame: " << this << ", " << func_graph_->ToString();
if (slot_index_ < node_slots.size()) {
auto current_node = node_slots[slot_index_];
buffer << "(#" << slot_index_ << " / Running " << current_node->DebugString() << ")";
} else {
buffer << "(Exhausted..)";
}
buffer << ", parent: ";
auto parent_graph = parent_context_->func_graph();
if (parent_graph != nullptr) {
buffer << parent_graph << "/" << parent_graph->ToString();
} else {
buffer << "NULL";
}
return buffer.str();
}
friend std::ostream &operator<<(std::ostream &os, const StackFramePtr &frame) {
MS_EXCEPTION_IF_NULL(frame);
os << frame->ToString();
return os;
}
private:
AbstractBasePtrList GenerateArgsAbsList(const AnalysisEnginePtr &engine, const EvaluatorPtr &evaluator,
const CNodePtr current_cnode);
StackFramePtr DoJump(const AnalysisEnginePtr &engine, const CNodePtr current_cnode,
const FuncGraphAbstractClosurePtr &graph_func);
EvaluatorWeakPtr evaluator_;
FuncGraphPtr func_graph_;
AnalysisContextPtr current_context_;
AnalysisContextPtr parent_context_;
AbstractBasePtrList args_abs_list_;
std::vector<AnfNodePtr> node_slots;
size_t slot_index_;
bool done_;
};
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STACK_FRAME_H_

View File

@ -115,6 +115,7 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac
// Running the analyzer.
ResetFunctionCallDepth();
ResetStackFrameDepth();
AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list);
MS_EXCEPTION_IF_NULL(root_context);
MS_EXCEPTION_IF_NULL(root_context->func_graph());
@ -133,7 +134,7 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana
const ConfigPtrList &args_conf_list) {
std::shared_ptr<FuncGraphEvaluator> eval = std::make_shared<FuncGraphEvaluator>(func_graph, context);
(void)eval->Run(shared_from_this(), args_conf_list, nullptr);
return eval->graph_context();
return eval->context();
}
EvalResultPtr AnalysisEngine::ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf) {
@ -152,7 +153,7 @@ EvalResultPtr AnalysisEngine::ObtainEvalResultWithCache(const AnfNodeConfigPtr &
}
MS_LOG(DEBUG) << "Evaluate node on demond for NodeConfig: " << conf->ToString()
<< ", result: " << result->abstract().get() << ", " << result->abstract()->ToString();
analysis_cache_.set_value(conf, result);
SaveEvalResultInCache(conf, result);
return result;
}
@ -179,7 +180,7 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
auto abstract = EvalValueNode(value_node, conf);
eval_result = std::make_shared<EvalResult>(abstract, std::make_shared<AttrValueMap>());
} else if (node->isa<CNode>()) {
CheckNoStackInSameFuncGraph(conf);
// CheckNoStackInSameFuncGraph(conf);
auto cnode = node->cast<CNodePtr>();
trace::TraceEvalCNodeEnter(conf);
eval_result = EvalCNode(cnode, conf);
@ -224,7 +225,7 @@ void AnalysisEngine::CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf) {
MS_LOG(EXCEPTION) << "Top evaluator is " << top_evaluator->ToString();
}
auto top_fg_evaluator = dyn_cast<BaseFuncGraphEvaluator>(top_evaluator);
auto top_context_fg = top_fg_evaluator->graph_context()->func_graph();
auto top_context_fg = top_fg_evaluator->context()->func_graph();
if (current_cnode_fg != top_context_fg) { // Ignore FV call.
return;
}
@ -249,18 +250,15 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co
return out;
}
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf);
AbstractFunctionPtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context) {
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "CNode->inputs() is empty, CNode: " << cnode->DebugString();
}
AnfNodePtr func_node = inputs[0];
MS_EXCEPTION_IF_NULL(func_node);
MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString();
AnalysisContextPtr context = conf->context();
AnfNodeConfigPtr func_conf = MakeConfig(func_node, context);
MS_EXCEPTION_IF_NULL(func_conf);
// Keep it in a local variable, otherwise smart pointer will free it.
@ -270,32 +268,40 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
MS_LOG(EXCEPTION) << "No abstract, func_conf: " << func_conf->ToString()
<< " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
}
if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
MS_LOG(DEBUG) << "EvalCNode eval Undetermined";
return std::make_shared<EvalResult>(maybe_func->Clone(), std::make_shared<AttrValueMap>());
}
AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func);
if (func == nullptr) {
MS_LOG(EXCEPTION) << "Not AbstractFunction: " << maybe_func->ToString() << ", func_conf: " << func_conf->ToString()
<< " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
}
return func;
}
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(cnode);
AbstractFunctionPtr func = GetCNodeOperatorAbstract(cnode, conf->context());
if (func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
MS_LOG(DEBUG) << "EvalCNode eval Undetermined";
return std::make_shared<EvalResult>(func->Clone(), std::make_shared<AttrValueMap>());
}
ConfigPtrList args_conf_list;
// ignore the first node which is function name
// Ignore the first node which is function name
auto &inputs = cnode->inputs();
for (std::size_t i = 1; i < inputs.size(); i++) {
const AnfNodePtr &node = inputs[i];
args_conf_list.push_back(MakeConfig(node, context));
args_conf_list.push_back(MakeConfig(node, conf->context()));
}
std::vector<EvaluatorPtr> infs;
std::vector<EvaluatorPtr> evaluators;
auto build_evaluator = [this, &infs, &cnode](const AbstractFuncAtomPtr &poss) {
auto build_evaluator = [this, &evaluators, &cnode](const AbstractFuncAtomPtr &poss) {
auto evaluator = this->GetEvaluatorFor(poss);
evaluator->set_bound_node(cnode);
infs.push_back(evaluator);
evaluators.push_back(evaluator);
};
func->Visit(build_evaluator);
auto eval_result = ExecuteEvaluators(infs, conf, args_conf_list);
auto eval_result = ExecuteEvaluators(evaluators, conf, args_conf_list);
return eval_result;
}
@ -314,7 +320,7 @@ EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const Abs
}
void AnalysisEngine::ClearEvaluatorCache() {
for (std::pair<AbstractFunctionPtr, EvaluatorPtr> element : constructors_) {
for (std::pair<AbstractFunctionPtr, EvaluatorPtr> element : evaluators_) {
EvaluatorPtr evaluator = element.second;
MS_EXCEPTION_IF_NULL(evaluator);
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_map());
@ -338,7 +344,7 @@ void AnalysisEngine::Clear() {
analysis_cache_.Clear();
anfnode_config_map_.clear();
eval_trace_.clear();
constructors_.clear();
evaluators_.clear();
constructors_app_.clear();
continued_evals_.clear();
}
@ -407,38 +413,38 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
} // namespace
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) {
auto inf_pair = constructors_.find(func);
if (inf_pair != constructors_.end()) {
auto inf_pair = evaluators_.find(func);
if (inf_pair != evaluators_.end()) {
return inf_pair->second;
}
MS_EXCEPTION_IF_NULL(func);
auto primitive = func->prim();
auto evaluator = GetPrimEvaluator(primitive, shared_from_this());
constructors_[func] = evaluator;
evaluators_[func] = evaluator;
return evaluator;
}
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &func) {
auto inf_pair = constructors_.find(func);
if (inf_pair != constructors_.end()) {
auto inf_pair = evaluators_.find(func);
if (inf_pair != evaluators_.end()) {
return inf_pair->second;
}
MS_EXCEPTION_IF_NULL(func);
std::shared_ptr<FuncGraphEvaluator> func_graph_evaluator =
std::make_shared<FuncGraphEvaluator>(func->func_graph(), func->context());
constructors_[func] = func_graph_evaluator;
evaluators_[func] = func_graph_evaluator;
return func_graph_evaluator;
}
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &func) {
auto inf_pair = constructors_.find(func);
if (inf_pair != constructors_.end()) {
auto inf_pair = evaluators_.find(func);
if (inf_pair != evaluators_.end()) {
return inf_pair->second;
}
MS_EXCEPTION_IF_NULL(func);
std::shared_ptr<MetaFuncGraphEvaluator> evaluator =
std::make_shared<MetaFuncGraphEvaluator>(func->meta_func_graph(), func->context(), func->GetScope());
constructors_[func] = evaluator;
evaluators_[func] = evaluator;
return evaluator;
}
@ -505,18 +511,18 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
}
EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
MS_EXCEPTION_IF_NULL(func);
MS_LOG(DEBUG) << "The func value: " << func->ToString();
if (func->tracking_id() != nullptr) {
MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString();
}
MS_EXCEPTION_IF_NULL(func);
if (func->tracking_id() == nullptr || func->isa<abstract::MetaFuncGraphAbstractClosure>() ||
func->isa<abstract::FuncGraphAbstractClosure>()) {
EvaluatorPtr evaluator = _GetEvaluatorFor(func);
return evaluator;
}
auto inf_pair = constructors_.find(func);
if (inf_pair != constructors_.end()) {
auto inf_pair = evaluators_.find(func);
if (inf_pair != evaluators_.end()) {
return inf_pair->second;
}
@ -524,7 +530,7 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
func_generic->set_tracking_id(nullptr);
EvaluatorPtr eval = _GetEvaluatorFor(func_generic);
auto tracked_eval = std::make_shared<TrackedEvaluator>(eval);
constructors_[func] = tracked_eval;
evaluators_[func] = tracked_eval;
return tracked_eval;
}

View File

@ -89,7 +89,7 @@ class AnfNodeConfig : public Config {
}
context_ = nullptr;
if (context != nullptr) {
context_ = context->Filter(fg);
context_ = context->FindParentContext(fg);
}
}
@ -116,8 +116,8 @@ class AnfNodeConfig : public Config {
std::string ToString() const override {
std::ostringstream buffer;
buffer << "Node: " << node_->DebugString() << "-uid(" << node_->UniqueId()
<< "), Context: " << context_->ToString();
buffer << "Node: " << node_ << "/" << node_->DebugString() << "-uid(" << node_->UniqueId()
<< "), Context: " << context_ << "/" << context_->ToString();
return buffer.str();
}
@ -190,6 +190,9 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
prim_constructors_(prim_evaluator_map),
func_graph_manager_(func_graph_manager) {
function_call_depth_ = 0;
function_call_max_depth_ = 0;
stack_frame_depth_ = 0;
stack_frame_max_depth_ = 0;
forward_count_ = 0;
}
~AnalysisEngine() = default;
@ -197,10 +200,16 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
// func_graph: The func_graph to analyze.
// args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.
AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
void SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(result);
analysis_cache_.set_value(conf, result);
}
EvalResultPtr ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf);
// Return the Evaluator for the given function.
EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);
AbstractFunctionPtr GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context);
AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf);
EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
// Infer the result of fn(args).
@ -231,18 +240,43 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
AnalysisCache analysis_cache_;
std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
void ResetFunctionCallDepth() { function_call_depth_ = 0; }
void IncreaseFunctionCallDepth() { function_call_depth_++; }
void ResetFunctionCallDepth() {
function_call_depth_ = 0;
function_call_max_depth_ = 0;
}
void IncreaseFunctionCallDepth() {
function_call_depth_++;
if (function_call_max_depth_ < function_call_depth_) {
function_call_max_depth_ = function_call_depth_;
}
}
void DecreaseFunctionCallDepth() {
if (function_call_depth_ == 0) {
MS_LOG(EXCEPTION) << "Current function call depth is already 0, can not decrease it.";
}
function_call_depth_--;
}
size_t function_call_depth() { return function_call_depth_; }
size_t function_call_max_depth() { return function_call_max_depth_; }
uint64_t function_call_depth() { return function_call_depth_; }
void ResetStackFrameDepth() {
stack_frame_depth_ = 0;
stack_frame_max_depth_ = 0;
}
void IncreaseStackFrameDepth() {
stack_frame_depth_++;
if (stack_frame_max_depth_ < stack_frame_depth_) {
stack_frame_max_depth_ = stack_frame_depth_;
}
}
void DecreaseStackFrameDepth() {
if (stack_frame_depth_ == 0) {
MS_LOG(EXCEPTION) << "Current stack frame depth is already 0, can not decrease it.";
}
stack_frame_depth_--;
}
size_t stack_frame_depth() { return stack_frame_depth_; }
size_t stack_frame_max_depth() { return stack_frame_max_depth_; }
void CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf);
@ -282,7 +316,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
const PrimEvaluatorMap &prim_constructors_;
FuncGraphManagerPtr func_graph_manager_;
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> constructors_;
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> evaluators_;
std::unordered_map<std::pair<AbstractFunctionPtr, AbstractBasePtrList>, EvaluatorPtr, PartialAppHasher>
constructors_app_;
AnfNodeConfigMap anfnode_config_map_;
@ -299,10 +333,15 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
const ConfigPtrList &args_conf_list);
EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list);
// record current depth of function call statck
uint64_t function_call_depth_;
// Record current depth of function call stack, including `stack_frame_depth_`.
size_t function_call_depth_;
size_t function_call_max_depth_;
uint64_t forward_count_;
// Record current depth of stack frames call.
size_t stack_frame_depth_;
size_t stack_frame_max_depth_;
size_t forward_count_;
#ifdef DEBUG
std::vector<AnfNodePtr> compute_conf_stack_;

View File

@ -133,7 +133,7 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom {
// so different tracking_id will produce different FuncGraphAbstractClosure,
// different FuncGraphEvaluator.
// Espcecially useful for recursive func graph call, so it will not mess up
// the graph_context_ in FuncGraphEvaluator.
// the `context_` in FuncGraphEvaluator.
// Notes: Be careful to use nullptr for this variable.
// store it as weak_ptr to break reference cycle.
AnfNodeWeakPtr tracking_id_;

View File

@ -23,39 +23,39 @@
namespace mindspore {
namespace abstract {
AnalysisContextPtr AnalysisContext::NewContext(AnalysisContextPtr parent, FuncGraphPtr fg,
AnalysisContextPtr AnalysisContext::NewContext(AnalysisContextPtr parent_context, FuncGraphPtr fg,
const AbstractBasePtrList &args_spec_list) {
auto children_context_map_iter = parent->children_cache_.find(fg);
if (children_context_map_iter != parent->children_cache_.end()) {
auto children_context_map_iter = parent_context->children_cache_.find(fg);
if (children_context_map_iter != parent_context->children_cache_.end()) {
auto children_context_map = children_context_map_iter->second;
auto children_context_iter = children_context_map.find(args_spec_list);
if (children_context_iter != children_context_map.end()) {
return children_context_iter->second.lock();
}
}
AnalysisContextPtr context_new = std::make_shared<AnalysisContext>(parent, fg, args_spec_list);
AnalysisContextPtr new_context = std::make_shared<AnalysisContext>(parent_context, fg, args_spec_list);
// Reference to myself, so use weak_ptr to break reference cycle.
auto weak_context = std::weak_ptr<AnalysisContext>(context_new);
context_new->parent_cache_[fg] = weak_context;
parent->children_cache_[fg][args_spec_list] = weak_context;
return context_new;
auto weak_context = std::weak_ptr<AnalysisContext>(new_context);
new_context->parent_cache_[fg] = weak_context;
parent_context->children_cache_[fg][args_spec_list] = weak_context;
return new_context;
}
AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func_graph,
const AbstractBasePtrList &args_spec_list) {
FuncGraphPtr graph_parent = func_graph->parent();
auto iter = parent_cache_.find(graph_parent);
FuncGraphPtr parent_graph = func_graph->parent();
AnalysisContextPtr parent_context = nullptr;
auto iter = parent_cache_.find(parent_graph);
if (iter != parent_cache_.end()) {
parent_context = iter->second.lock();
}
// if this happen, it will be bug in code. but we raise exception to keep the scene.
// If this happen, it will be a bug in code. But we raise exception to keep the scene.
if (parent_context == nullptr) {
std::ostringstream oss;
oss << "BUG: cannot found parent_context in current context: " << this->ToString()
<< ", func_graph: " << func_graph->ToString() << ", graph_parent: ";
if (graph_parent != nullptr) {
oss << graph_parent->ToString();
oss << "BUG: Failed to find parent context in current context: " << this->ToString()
<< ", func_graph: " << func_graph->ToString() << ", parent_graph: ";
if (parent_graph != nullptr) {
oss << parent_graph->ToString();
} else {
oss << "nullptr";
}
@ -64,7 +64,7 @@ AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func
return NewContext(parent_context, func_graph, args_spec_list);
}
AnalysisContextPtr AnalysisContext::Filter(const FuncGraphPtr &func_graph) {
AnalysisContextPtr AnalysisContext::FindParentContext(const FuncGraphPtr &func_graph) {
auto p_iter = parent_cache_.find(func_graph);
AnalysisContextPtr parent_context = nullptr;
if (p_iter != parent_cache_.end()) {
@ -75,10 +75,10 @@ AnalysisContextPtr AnalysisContext::Filter(const FuncGraphPtr &func_graph) {
parent_context = iter_parent->second.lock();
}
}
// if this happen, it will be bug in code. but we raise exception to keep the scene.
// If this happen, it would be a bug in code. But we raise exception to keep the scene.
if (parent_context == nullptr) {
std::ostringstream oss;
oss << "BUG: Filter graph failed: " << func_graph->ToString() << ", graph_parent: ";
oss << "BUG: Failed to find parent context for: " << func_graph->ToString() << ", parent_graph: ";
if (func_graph->parent() != nullptr) {
oss << func_graph->parent()->ToString();
} else {

View File

@ -52,7 +52,7 @@ class AnalysisContext {
AnalysisContextPtr NewFuncGraphContext(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
// Return a context restricted to a graph's dependencies.
AnalysisContextPtr Filter(const FuncGraphPtr &graph);
AnalysisContextPtr FindParentContext(const FuncGraphPtr &graph);
bool operator==(const AnalysisContext &other) const;
std::size_t hash();
static AnalysisContextPtr DummyContext();

View File

@ -176,7 +176,11 @@ void FuncGraph::DumpCNodeList() {
}
std::string FuncGraph::ToString() const {
return mindspore::label_manage::Label(const_cast<FuncGraph *>(this)->shared_from_base<FuncGraph>()->debug_info());
std::ostringstream buffer;
auto debug_info = const_cast<FuncGraph *>(this)->shared_from_base<FuncGraph>()->debug_info();
buffer << mindspore::label_manage::Label(debug_info);
buffer << "." << debug_info->get_id();
return buffer.str();
}
GraphDebugInfoPtr FuncGraph::debug_info() {

View File

@ -295,22 +295,22 @@ TEST_F(TestInferGraph, test_context) {
AnalysisContextPtr dummy_context = AnalysisContext::DummyContext();
AnalysisContextPtr f_context = dummy_context->NewFuncGraphContext(graph_f_, AbstractBasePtrList());
ASSERT_TRUE(f_context->Filter(graph_f_) = f_context);
ASSERT_TRUE(f_context->Filter(nullptr) = dummy_context);
ASSERT_TRUE(f_context->FindParentContext(graph_f_) = f_context);
ASSERT_TRUE(f_context->FindParentContext(nullptr) = dummy_context);
AnalysisContextPtr g_context = f_context->NewFuncGraphContext(graph_g_, AbstractBasePtrList());
ASSERT_TRUE(g_context->Filter(graph_g_) = g_context);
ASSERT_TRUE(g_context->Filter(graph_f_) = dummy_context);
ASSERT_TRUE(g_context->Filter(nullptr) = dummy_context);
ASSERT_TRUE(g_context->FindParentContext(graph_g_) = g_context);
ASSERT_TRUE(g_context->FindParentContext(graph_f_) = dummy_context);
ASSERT_TRUE(g_context->FindParentContext(nullptr) = dummy_context);
AnalysisContextPtr alpha_context = dummy_context->NewFuncGraphContext(graph_alpha_, AbstractBasePtrList());
ASSERT_TRUE(alpha_context->Filter(graph_alpha_) = alpha_context);
ASSERT_TRUE(alpha_context->Filter(nullptr) = dummy_context);
ASSERT_TRUE(alpha_context->FindParentContext(graph_alpha_) = alpha_context);
ASSERT_TRUE(alpha_context->FindParentContext(nullptr) = dummy_context);
AnalysisContextPtr beta_context = alpha_context->NewFuncGraphContext(graph_beta_, AbstractBasePtrList());
ASSERT_TRUE(beta_context->Filter(graph_beta_) = beta_context);
ASSERT_TRUE(beta_context->Filter(graph_alpha_) = alpha_context);
ASSERT_TRUE(beta_context->Filter(nullptr) = dummy_context);
ASSERT_TRUE(beta_context->FindParentContext(graph_beta_) = beta_context);
ASSERT_TRUE(beta_context->FindParentContext(graph_alpha_) = alpha_context);
ASSERT_TRUE(beta_context->FindParentContext(nullptr) = dummy_context);
}
class TestInferMetaGraph : public UT::Common {